"""Locate model checkpoint files and build the car-damage predictors."""

import os
from pathlib import Path
from typing import Optional

from .prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor

# "checkpoints" directory located one level above the directory that
# contains this module (parents[0] is this file's directory).
CHECKPOINT_DIR = Path(__file__).resolve().parents[1] / "checkpoints"

# Maps a short model key to the checkpoint file name inside the checkpoint
# directory.
MODEL_FILES = {
    "resnet": "best_resnet_model.pt",
    "fusion": "best_fusion_model_fp16.pth",
    "yolo": "damage_detector.pt",
}


def get_checkpoint_path(model_key: str, base_dir: Optional[Path] = None) -> Path:
    """Return the path of the checkpoint file registered under *model_key*.

    Args:
        model_key: One of the keys of ``MODEL_FILES``.
        base_dir: Directory to resolve the checkpoint file name against.
            Defaults to ``CHECKPOINT_DIR`` when omitted or ``None``.

    Returns:
        The path ``base_dir / MODEL_FILES[model_key]``.

    Raises:
        ValueError: If *model_key* is not a key of ``MODEL_FILES``.
        FileNotFoundError: If the resolved path is not an existing regular
            file.
    """
    if model_key not in MODEL_FILES:
        raise ValueError(f"Unknown model key: {model_key}")

    root = CHECKPOINT_DIR if base_dir is None else Path(base_dir)
    path = root / MODEL_FILES[model_key]

    # is_file() rather than exists(): a directory that happens to carry the
    # checkpoint's name is not a usable checkpoint and should be reported
    # as missing here instead of being handed to the caller.
    if not path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return path


class ModelLoader:
    """Resolve checkpoint paths relative to a configurable base directory."""

    def __init__(self, base_dir: Optional[Path] = None):
        """Create a loader.

        Args:
            base_dir: Directory holding the checkpoint files. Defaults to
                ``CHECKPOINT_DIR`` when omitted or ``None``.
        """
        self.base_dir = CHECKPOINT_DIR if base_dir is None else Path(base_dir)

    def get_model_path(self, model_key: str) -> Path:
        """Return the checkpoint path for *model_key* under ``self.base_dir``.

        Raises the same exceptions as ``get_checkpoint_path``.
        """
        # Pass base_dir through so that the attribute set in __init__ is
        # actually honoured instead of silently falling back to the global.
        return get_checkpoint_path(model_key, self.base_dir)


def initialize_models(class_map):
    """Construct the ResNet and fusion predictors from their checkpoints.

    Only the "resnet" and "fusion" checkpoints are resolved here; the
    "yolo" entry of ``MODEL_FILES`` is not used by this function.

    Args:
        class_map: Passed unchanged as the second argument to both predictor
            constructors.

    Returns:
        A ``(resnet_predictor, fusion_predictor)`` tuple.

    Raises:
        FileNotFoundError: If either checkpoint file is missing.
    """
    # Resolve both paths before constructing anything so a missing file is
    # reported before any predictor is built.
    resnet_path = get_checkpoint_path("resnet")
    fusion_path = get_checkpoint_path("fusion")

    resnet_predictor = ResnetCarDamagePredictor(resnet_path, class_map)
    fusion_predictor = FusionCarDamagePredictor(fusion_path, class_map)
    return resnet_predictor, fusion_predictor