# Digital-Image-Processing-OCR / src/evaluate_models.py
# chiruu12 - Initial commit of clean OCR application (9543569)
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import os
import numpy as np
import h5py
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
from models import CNNModel_Small, CNNModel_Medium, CNNModel_Large
# Path of the HDF5 file whose 'images' and 'labels' datasets are evaluated.
DATA_FILE = "data/book_dataset.h5"
# Directory searched for the "<name>_model_finetuned.pth" weight files.
MODEL_DIR = "models/saved_weights_finetuned/"
# Use the MPS backend when PyTorch reports it as available, otherwise the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Batch size handed to the evaluation DataLoader.
BATCH_SIZE = 256
# Upper bound on the samples drawn for each group (digits, uppercase, lowercase).
NUM_SAMPLES_PER_CLASS = 1000
class HDF5Dataset(Dataset):
    """Dataset serving the 'images' and 'labels' arrays of an HDF5 file from memory."""

    def __init__(self, h5_path, transform=None):
        """Read both arrays completely into memory and keep the optional transform.

        Args:
            h5_path: Path of the HDF5 file to read.
            transform: Optional callable applied to each image when it is fetched.
        """
        self.h5_path = h5_path
        self.transform = transform
        # The file is only needed while copying; it is closed on leaving the block.
        with h5py.File(self.h5_path, 'r') as handle:
            self.images, self.labels = handle['images'][:], handle['labels'][:]

    def __len__(self):
        """Return the number of samples, i.e. the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at *idx*, transformed if configured."""
        sample = self.images[idx]
        target = self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target
def load_finetuned_model(name, model_class, num_classes):
    """Instantiate a model, load its fine-tuned weights and put it in eval mode.

    Args:
        name: Prefix of the weight file; "<name>_model_finetuned.pth" inside
            MODEL_DIR is loaded.
        model_class: Callable that builds the network; invoked with the
            ``num_classes`` keyword.
        num_classes: Value forwarded to ``model_class`` as ``num_classes``.

    Returns:
        The model, moved to DEVICE and switched to evaluation mode.

    Raises:
        FileNotFoundError: If the expected weight file does not exist.
    """
    weights_file = os.path.join(MODEL_DIR, f"{name}_model_finetuned.pth")
    if os.path.exists(weights_file):
        network = model_class(num_classes=num_classes).to(DEVICE)
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should ever be placed in MODEL_DIR.
        state = torch.load(weights_file, map_location=DEVICE)
        network.load_state_dict(state)
        network.eval()
        return network
    raise FileNotFoundError(f"Model file not found at '{weights_file}'. Please run the fine-tuning script.")
def create_balanced_test_set(seed=None):
    """Build a shuffled evaluation subset balanced across the three character groups.

    Up to NUM_SAMPLES_PER_CLASS indices are drawn without replacement from each
    of the label-code ranges 48-57 (ASCII digits), 65-90 (uppercase) and
    97-122 (lowercase); a group with fewer samples contributes all of them.
    Samples whose label lies outside those ranges are never selected.

    Args:
        seed: Optional integer seed. When given, sampling and shuffling use a
            dedicated ``numpy.random.RandomState`` so the same subset is
            produced on every call. When ``None`` (the default) the global
            NumPy random state is used, exactly as before.

    Returns:
        A ``torch.utils.data.Subset`` of the full HDF5Dataset (images converted
        with ``transforms.ToTensor()``) restricted to the sampled indices.
    """
    print(f"Creating a balanced test set with up to {NUM_SAMPLES_PER_CLASS} samples per class...")

    # The dataset loads every label into memory anyway, so reuse them instead
    # of opening and reading the HDF5 file a second time.
    full_dataset = HDF5Dataset(DATA_FILE, transform=transforms.ToTensor())
    all_labels = full_dataset.labels

    rng = np.random if seed is None else np.random.RandomState(seed)

    # Inclusive label-code ranges, in the order digits, uppercase, lowercase.
    code_ranges = ((48, 57), (65, 90), (97, 122))
    sampled_groups = []
    for low, high in code_ranges:
        candidates = np.where((all_labels >= low) & (all_labels <= high))[0]
        take = min(NUM_SAMPLES_PER_CLASS, len(candidates))
        sampled_groups.append(rng.choice(candidates, take, replace=False))

    test_indices = np.concatenate(sampled_groups)
    rng.shuffle(test_indices)

    test_subset = Subset(full_dataset, test_indices)
    print(f"Test set created with {len(test_subset)} total samples.")
    return test_subset
def main():
    """Evaluate the triage + expert OCR pipeline and print an accuracy report.

    Loads the four fine-tuned models (a 3-way triage network and one expert
    each for digits, uppercase and lowercase), runs them over the balanced
    test set and prints three sections:

    1. End-to-end accuracy: triage picks an expert, the expert's prediction is
       mapped back to a character code and compared with the true code.
    2. A per-class classification report for the triage model alone.
    3. Per-expert accuracy, where every sample is scored by the expert that
       matches its *true* group regardless of what triage predicted.

    Raises:
        FileNotFoundError: Propagated from ``load_finetuned_model`` when a
            weight file is missing.
    """
    print("--- Starting Full Model Evaluation ---")
    print(f"Using device: {DEVICE}")
    print("\nLoading all fine-tuned models...")
    models = {
        'triage': load_finetuned_model('triage', CNNModel_Large, 3),
        'digits': load_finetuned_model('digits', CNNModel_Small, 10),
        'uppercase': load_finetuned_model('uppercase', CNNModel_Medium, 26),
        'lowercase': load_finetuned_model('lowercase', CNNModel_Medium, 26),
    }
    print("All models loaded.")

    test_dataset = create_balanced_test_set()
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    triage_class_names = ['digits', 'uppercase', 'lowercase']
    # Triage output index -> name of the expert responsible for that group.
    triage_map = dict(enumerate(triage_class_names))

    all_true_labels, all_pred_labels = [], []
    all_true_triage, all_pred_triage = [], []
    expert_true = {name: [] for name in triage_class_names}
    expert_preds = {name: [] for name in triage_class_names}

    # Character code -> expert output index, per expert.
    expert_remaps = {
        'digits': {code: i for i, code in enumerate(range(48, 58))},
        'uppercase': {code: i for i, code in enumerate(range(65, 91))},
        'lowercase': {code: i for i, code in enumerate(range(97, 123))},
    }
    # Expert output index -> character code. Built once here rather than
    # once per sample inside the evaluation loop.
    expert_inverse = {
        name: {index: code for code, index in remap.items()}
        for name, remap in expert_remaps.items()
    }

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating Full System"):
            images = images.to(DEVICE)
            triage_out = models['triage'](images)
            pred_triage_indices = torch.argmax(triage_out, dim=1).tolist()

            for image, true_label_code, pred_triage_class_idx in zip(
                    images, labels.tolist(), pred_triage_indices):
                # Any code that is neither a digit nor uppercase is treated as
                # lowercase; a code outside 97-122 would then fail the remap
                # lookup below with KeyError. The balanced test set only
                # contains codes from the three known ranges.
                if 48 <= true_label_code <= 57:
                    true_triage_class_idx = 0
                elif 65 <= true_label_code <= 90:
                    true_triage_class_idx = 1
                else:
                    true_triage_class_idx = 2
                true_expert_name = triage_map[true_triage_class_idx]

                # End-to-end path: run the expert that triage selected.
                expert_to_use = triage_map[pred_triage_class_idx]
                sample = image.unsqueeze(0)
                pred_expert_idx = torch.argmax(models[expert_to_use](sample), dim=1).item()
                pred_label_code = expert_inverse[expert_to_use].get(pred_expert_idx, -1)

                all_true_labels.append(true_label_code)
                all_pred_labels.append(pred_label_code)
                all_true_triage.append(true_triage_class_idx)
                all_pred_triage.append(pred_triage_class_idx)

                # Expert-only path: score the expert for the true group. When
                # triage was right this is the model that just ran, so reuse
                # its prediction instead of doing a second forward pass.
                if expert_to_use == true_expert_name:
                    pred_correct_expert_idx = pred_expert_idx
                else:
                    correct_expert_out = models[true_expert_name](sample)
                    pred_correct_expert_idx = torch.argmax(correct_expert_out, dim=1).item()

                true_expert_label = expert_remaps[true_expert_name][true_label_code]
                expert_true[true_expert_name].append(true_expert_label)
                expert_preds[true_expert_name].append(pred_correct_expert_idx)

    print("\n\n" + "=" * 50)
    print(" OCR EVALUATION REPORT")
    print("=" * 50)

    overall_accuracy = accuracy_score(all_true_labels, all_pred_labels)
    print("\n--- 1. Overall OCR System Accuracy ---")
    print(f" End-to-End Accuracy: {overall_accuracy:.4f} ({overall_accuracy * 100:.2f}%)")

    print("\n--- 2. Triage Model Performance (Class-wise) ---")
    # Pass the label set explicitly: without it the report raises when one of
    # the three classes is absent from both the truth and the predictions,
    # because the number of observed classes no longer matches target_names.
    print(classification_report(
        all_true_triage,
        all_pred_triage,
        labels=list(triage_map),
        target_names=triage_class_names,
    ))

    print("\n--- 3. Individual Expert Accuracy ---")
    print("(This measures accuracy assuming the Triage model was correct)")
    for name in triage_class_names:
        if not expert_true[name]:
            # Accuracy is undefined when the test set has no sample for this group.
            print(f" - {name.capitalize()} Expert Accuracy: n/a (no samples)")
            continue
        expert_acc = accuracy_score(expert_true[name], expert_preds[name])
        print(f" - {name.capitalize()} Expert Accuracy: {expert_acc:.4f} ({expert_acc * 100:.2f}%)")
    print("\n" + "=" * 50)


if __name__ == "__main__":
    main()