Digital-Image-Processing-OCR / src /model_loader.py
chiruu12
Initial commit of clean OCR application
9543569
Raw
History Blame Contribute Delete
1.48 kB
import torch
from models import CNNModel_Small, CNNModel_Medium, CNNModel_Large
from config import settings
import os
# Maps the "architecture" value found in a model config to the CNN class
# that is instantiated for it.
ARCHITECTURE_MAP = {
    "small": CNNModel_Small,
    "medium": CNNModel_Medium,
    "large": CNNModel_Large,
}
def _load_model(config: dict, model_type: str, model_name: str):
    """Load one fine-tuned model and prepare it for inference.

    Args:
        config: Model settings. Reads ``config["architecture"]`` (must be a
            key of ``ARCHITECTURE_MAP``) and ``config["num_classes"]``.
        model_type: Label used only in the success message (e.g. "triage").
        model_name: Checkpoint stem; weights are read from
            ``<settings.MODELS_DIR>/<model_name>_model_finetuned.pth``.

    Returns:
        The instantiated model with its weights loaded, moved to
        ``settings.DEVICE`` and switched to eval mode.

    Raises:
        KeyError: If ``config`` lacks a required key, or names an
            architecture that is not in ``ARCHITECTURE_MAP``.
        FileNotFoundError: If the checkpoint file does not exist.
    """
    architecture = config["architecture"]
    try:
        model_class = ARCHITECTURE_MAP[architecture]
    except KeyError:
        # Same exception type as a plain dict lookup, but say which model
        # was being loaded and which architecture names are accepted.
        raise KeyError(
            f"Unknown architecture {architecture!r} for {model_type} model "
            f"'{model_name}'. Expected one of: {sorted(ARCHITECTURE_MAP)}."
        ) from None

    model_path = os.path.join(settings.MODELS_DIR, f"{model_name}_model_finetuned.pth")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Fine-tuned model not found at '{model_path}'. Please run the training script.")

    num_classes = config["num_classes"]
    model = model_class(num_classes=num_classes)

    # NOTE(review): torch.load unpickles the checkpoint. If checkpoints can
    # ever come from an untrusted source, consider passing weights_only=True
    # (confirm the installed PyTorch version supports it).
    state_dict = torch.load(model_path, map_location=settings.DEVICE)
    model.load_state_dict(state_dict)
    model.to(settings.DEVICE)
    model.eval()  # switch to eval mode before handing the model out

    print(f"Successfully loaded fine-tuned {model_type} model: '{model_name}' ({config['architecture']})")
    return model
def load_all_models() -> dict:
    """Load the fine-tuned triage model and every configured expert model.

    The triage model is built from ``settings.TRIAGE_CONFIG``; one expert
    model is built for each entry of ``settings.EXPERT_CONFIG``.

    Returns:
        A dict mapping ``"triage"`` to the triage model and each key of
        ``settings.EXPERT_CONFIG`` to its expert model.

    Any exception raised by ``_load_model`` propagates to the caller.
    """
    print("Loading all fine-tuned models for the OCR pipeline...")
    models = {"triage": _load_model(settings.TRIAGE_CONFIG, "triage", "triage")}

    # NOTE(review): an expert keyed "triage" in EXPERT_CONFIG would replace
    # the triage entry above - confirm the config never uses that name.
    for name, config in settings.EXPERT_CONFIG.items():
        models[name] = _load_model(config, "expert", name)

    print("All fine-tuned models loaded and ready.")
    return models