import os

import h5py
import numpy as np
import torch
from sklearn.metrics import classification_report, accuracy_score
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from tqdm import tqdm

from models import CNNModel_Small, CNNModel_Medium, CNNModel_Large
|
|
# Path of the HDF5 file that holds the 'images' and 'labels' datasets.
DATA_FILE = "data/book_dataset.h5"
# Directory that contains the "<name>_model_finetuned.pth" weight files.
MODEL_DIR = "models/saved_weights_finetuned/"


def _select_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Device every model and every image batch is moved to.
DEVICE = _select_device()
# Number of samples per DataLoader batch during evaluation.
BATCH_SIZE = 256
# Upper bound on the samples drawn from each character group
# (digits / uppercase / lowercase) when the test subset is built.
NUM_SAMPLES_PER_CLASS = 1000
|
|
class HDF5Dataset(Dataset):
    """Map-style dataset backed by an HDF5 file.

    The file's 'images' and 'labels' datasets are read completely into
    memory when the object is constructed; the file is closed again right
    afterwards, so item access never touches the disk.

    Args:
        h5_path: Path of the HDF5 file to read.
        transform: Optional callable applied to each image on access.
    """

    def __init__(self, h5_path, transform=None):
        self.h5_path = h5_path
        self.transform = transform
        with h5py.File(self.h5_path, 'r') as h5_file:
            # Slicing with [:] copies the whole dataset out of the file.
            self.images = h5_file['images'][:]
            self.labels = h5_file['labels'][:]

    def __len__(self):
        """Return the number of samples (one per stored label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is passed through ``self.transform`` when one is set;
        the label is returned exactly as stored.
        """
        sample = self.images[idx]
        target = self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target
|
|
|
|
def load_finetuned_model(name, model_class, num_classes):
    """Build a model, load its fine-tuned weights and switch it to eval mode.

    Args:
        name: Prefix of the weight file; the file looked up is
            ``MODEL_DIR/<name>_model_finetuned.pth``.
        model_class: Callable invoked as ``model_class(num_classes=...)``
            to construct the network.
        num_classes: Value forwarded to ``model_class``.

    Returns:
        The constructed model, moved to ``DEVICE``, with the stored state
        dict loaded and ``eval()`` already called.

    Raises:
        FileNotFoundError: If the weight file does not exist.
    """
    weights_file = os.path.join(MODEL_DIR, f"{name}_model_finetuned.pth")
    if os.path.exists(weights_file):
        network = model_class(num_classes=num_classes).to(DEVICE)
        # map_location places the stored tensors on the evaluation device.
        state_dict = torch.load(weights_file, map_location=DEVICE)
        network.load_state_dict(state_dict)
        network.eval()
        return network

    raise FileNotFoundError(f"Model file not found at '{weights_file}'. Please run the fine-tuning script.")
|
|
|
|
def create_balanced_test_set(seed=None):
    """Draw a group-balanced random subset of the dataset in ``DATA_FILE``.

    Samples are grouped by label code: 48-57 (digits), 65-90 (uppercase)
    and 97-122 (lowercase). From each group at most
    ``NUM_SAMPLES_PER_CLASS`` indices are drawn without replacement, so the
    balance is per group, not per individual character. Labels outside
    these three ranges are never selected.

    Args:
        seed: Optional integer seed. When given, sampling uses a private
            ``numpy.random.RandomState(seed)`` so the same subset is drawn
            on every run. When ``None`` (default) the global NumPy random
            state is used.

    Returns:
        A ``torch.utils.data.Subset`` of ``HDF5Dataset(DATA_FILE)`` whose
        images are converted with ``transforms.ToTensor()``; the selected
        indices are shuffled.
    """
    print(f"Creating a balanced test set with up to {NUM_SAMPLES_PER_CLASS} samples per class...")
    with h5py.File(DATA_FILE, 'r') as hf:
        all_labels = hf['labels'][:]

    # The np.random module and a RandomState instance both expose
    # choice() and shuffle(), so either can serve as the generator.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Inclusive label-code ranges: digits, uppercase, lowercase.
    group_ranges = ((48, 57), (65, 90), (97, 122))

    sampled_groups = []
    for low, high in group_ranges:
        candidates = np.where((all_labels >= low) & (all_labels <= high))[0]
        # Never request more samples than the group actually contains.
        take = min(NUM_SAMPLES_PER_CLASS, len(candidates))
        sampled_groups.append(rng.choice(candidates, take, replace=False))

    test_indices = np.concatenate(sampled_groups)
    rng.shuffle(test_indices)

    full_dataset = HDF5Dataset(DATA_FILE, transform=transforms.ToTensor())
    test_subset = Subset(full_dataset, test_indices)
    print(f"Test set created with {len(test_subset)} total samples.")
    return test_subset
|
|
|
|
def main():
    """Evaluate the triage + expert OCR pipeline and print a report.

    Steps:
      1. Load the four fine-tuned models (triage, digits, uppercase,
         lowercase).
      2. Build the test subset with ``create_balanced_test_set``.
      3. For every sample record
         - the end-to-end prediction (the triage model selects an expert,
           that expert selects the character code),
         - the triage prediction versus the true character group,
         - the prediction of the expert that matches the true group,
           regardless of what the triage model chose.
      4. Print overall accuracy, a triage classification report and the
         accuracy of each expert.
    """
    print("--- Starting Full Model Evaluation ---")
    print(f"Using device: {DEVICE}")

    print("\nLoading all fine-tuned models...")
    models = {
        'triage': load_finetuned_model('triage', CNNModel_Large, 3),
        'digits': load_finetuned_model('digits', CNNModel_Small, 10),
        'uppercase': load_finetuned_model('uppercase', CNNModel_Medium, 26),
        'lowercase': load_finetuned_model('lowercase', CNNModel_Medium, 26)
    }
    print("All models loaded.")

    test_dataset = create_balanced_test_set()
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    all_true_labels, all_pred_labels = [], []
    all_true_triage, all_pred_triage = [], []
    expert_true = {'digits': [], 'uppercase': [], 'lowercase': []}
    expert_preds = {'digits': [], 'uppercase': [], 'lowercase': []}

    # Position in this list == triage class index.
    triage_class_names = ['digits', 'uppercase', 'lowercase']
    triage_map = dict(enumerate(triage_class_names))

    # Label code -> expert output index, per expert.
    expert_remaps = {
        'digits': {code: i for i, code in enumerate(range(48, 58))},
        'uppercase': {code: i for i, code in enumerate(range(65, 91))},
        'lowercase': {code: i for i, code in enumerate(range(97, 123))}
    }
    # Expert output index -> label code; built once instead of per sample.
    inv_remaps = {
        name: {idx: code for code, idx in remap.items()}
        for name, remap in expert_remaps.items()
    }

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating Full System"):
            images = images.to(DEVICE)

            pred_triage_indices = models['triage'](images).argmax(dim=1).tolist()

            # Run every expert once on the whole batch. Both the routed
            # prediction and the "correct expert" prediction are then plain
            # list lookups instead of single-sample forward passes.
            batch_expert_preds = {
                name: models[name](images).argmax(dim=1).tolist()
                for name in triage_class_names
            }

            for i, true_label_code in enumerate(labels.tolist()):
                if 48 <= true_label_code <= 57:
                    true_triage_class_idx, true_expert_name = 0, 'digits'
                elif 65 <= true_label_code <= 90:
                    true_triage_class_idx, true_expert_name = 1, 'uppercase'
                else:
                    # The test subset only contains the three code ranges,
                    # so everything else is treated as lowercase.
                    true_triage_class_idx, true_expert_name = 2, 'lowercase'

                # End-to-end path: follow the triage model's routing.
                pred_triage_class_idx = pred_triage_indices[i]
                expert_to_use = triage_map[pred_triage_class_idx]
                pred_expert_idx = batch_expert_preds[expert_to_use][i]
                pred_label_code = inv_remaps[expert_to_use].get(pred_expert_idx, -1)

                all_true_labels.append(true_label_code)
                all_pred_labels.append(pred_label_code)
                all_true_triage.append(true_triage_class_idx)
                all_pred_triage.append(pred_triage_class_idx)

                # Expert-only path: ask the expert of the true group.
                true_expert_label = expert_remaps[true_expert_name][true_label_code]
                expert_true[true_expert_name].append(true_expert_label)
                expert_preds[true_expert_name].append(batch_expert_preds[true_expert_name][i])

    print("\n\n" + "=" * 50)
    print(" OCR EVALUATION REPORT")
    print("=" * 50)

    overall_accuracy = accuracy_score(all_true_labels, all_pred_labels)
    print("\n--- 1. Overall OCR System Accuracy ---")
    print(f" End-to-End Accuracy: {overall_accuracy:.4f} ({overall_accuracy * 100:.2f}%)")

    print("\n--- 2. Triage Model Performance (Class-wise) ---")
    # Passing labels explicitly keeps target_names aligned even when one of
    # the three classes never occurs in the truth or the predictions.
    print(classification_report(
        all_true_triage,
        all_pred_triage,
        labels=[0, 1, 2],
        target_names=triage_class_names,
    ))

    print("\n--- 3. Individual Expert Accuracy ---")
    print("(This measures accuracy assuming the Triage model was correct)")
    for name in triage_class_names:
        if not expert_true[name]:
            # No sample of this group in the test set: nothing to score.
            print(f" - {name.capitalize()} Expert Accuracy: N/A (no test samples)")
            continue
        expert_acc = accuracy_score(expert_true[name], expert_preds[name])
        print(f" - {name.capitalize()} Expert Accuracy: {expert_acc:.4f} ({expert_acc * 100:.2f}%)")

    print("\n" + "=" * 50)
|
|
|
|
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|