Spaces:
Running
Running
| import os | |
| from pathlib import Path | |
| from .prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor | |
# Directory holding the serialized model weights: a "checkpoints" folder
# located two levels above this file (i.e. next to this file's parent folder).
CHECKPOINT_DIR = Path(__file__).resolve().parents[1].joinpath("checkpoints")

# Maps a short model key to the checkpoint file name inside CHECKPOINT_DIR.
MODEL_FILES = {
    # Consumed by initialize_models() for the ResNet predictor.
    "resnet": "best_resnet_model.pt",
    # Consumed by initialize_models() for the fusion predictor.
    "fusion": "best_fusion_model_fp16.pth",
    # Not referenced in the visible part of this module.
    "yolo": "damage_detector.pt",
}
def get_checkpoint_path(model_key: str) -> Path:
    """Return the on-disk path of the checkpoint registered under *model_key*.

    Args:
        model_key: One of the keys of ``MODEL_FILES``.

    Returns:
        ``CHECKPOINT_DIR`` joined with the file name registered for the key.

    Raises:
        ValueError: If *model_key* is not a key of ``MODEL_FILES``.
        FileNotFoundError: If the resolved path does not exist.
    """
    file_name = MODEL_FILES.get(model_key)
    if file_name is None:
        raise ValueError(f"Unknown model key: {model_key}")

    checkpoint = CHECKPOINT_DIR / file_name
    if checkpoint.exists():
        return checkpoint

    # Fail here rather than letting a later load call hit a missing file.
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
class ModelLoader:
    """Thin object wrapper around the module-level checkpoint lookup.

    NOTE(review): ``base_dir`` is recorded on the instance, but the lookup
    below does not read it in the visible code; paths are always resolved
    through ``get_checkpoint_path``.
    """

    def __init__(self):
        """Record the checkpoint directory on the instance as ``base_dir``."""
        self.base_dir = CHECKPOINT_DIR

    def get_model_path(self, model_key: str) -> Path:
        """Return the checkpoint path for *model_key*.

        Delegates to ``get_checkpoint_path`` and propagates whatever it raises.
        """
        checkpoint = get_checkpoint_path(model_key)
        return checkpoint
def initialize_models(class_map):
    """Build the ResNet and fusion damage predictors from their checkpoints.

    Both checkpoint paths are resolved before either predictor is
    constructed, so a failed lookup raises before any predictor is built.

    Args:
        class_map: Passed unchanged as the second argument to both predictor
            constructors.

    Returns:
        A ``(resnet_predictor, fusion_predictor)`` tuple.
    """
    # Resolve in the same order as the keys: "resnet" first, then "fusion".
    resnet_ckpt, fusion_ckpt = (
        get_checkpoint_path(key) for key in ("resnet", "fusion")
    )
    return (
        ResnetCarDamagePredictor(resnet_ckpt, class_map),
        FusionCarDamagePredictor(fusion_ckpt, class_map),
    )