| import torch |
| from models import CNNModel_Small, CNNModel_Medium, CNNModel_Large |
| from config import settings |
| import os |
|
|
| ARCHITECTURE_MAP = { |
| "small": CNNModel_Small, |
| "medium": CNNModel_Medium, |
| "large": CNNModel_Large |
| } |
|
|
|
|
| def _load_model(config: dict, model_type: str, model_name: str): |
| """Generic helper function to load a finetuned model based on its config.""" |
| model_class = ARCHITECTURE_MAP[config["architecture"]] |
|
|
| model_path = os.path.join(settings.MODELS_DIR, f"{model_name}_model_finetuned.pth") |
| if not os.path.exists(model_path): |
| raise FileNotFoundError(f"Fine-tuned model not found at '{model_path}'. Please run the training script.") |
|
|
| num_classes = config["num_classes"] |
| model = model_class(num_classes=num_classes) |
| model.load_state_dict(torch.load(model_path, map_location=settings.DEVICE)) |
| model.to(settings.DEVICE) |
| model.eval() |
|
|
| print(f"Successfully loaded fine-tuned {model_type} model: '{model_name}' ({config['architecture']})") |
| return model |
|
|
|
|
| def load_all_models() -> dict: |
| """ |
| Loads the finetuned triage model and all three expert models. |
| """ |
| print("Loading all fine-tuned models for the OCR pipeline...") |
|
|
| models = {"triage": _load_model(settings.TRIAGE_CONFIG, "triage", "triage")} |
|
|
| for name, config in settings.EXPERT_CONFIG.items(): |
| models[name] = _load_model(config, f"expert", name) |
|
|
| print("All fine-tuned models loaded and ready.") |
| return models |
|
|
|
|