DamageLensAI / scripts /model_loader.py
junaid17's picture
Upload 15 files
1ae016f verified
import os
from pathlib import Path
from .prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor
# Checkpoints live in a "checkpoints" folder located next to the directory
# that contains this module (i.e. one level above this file's own directory).
CHECKPOINT_DIR = Path(__file__).resolve().parents[1].joinpath("checkpoints")

# Lookup table: short model key -> checkpoint filename inside CHECKPOINT_DIR.
MODEL_FILES = dict(
    resnet="best_resnet_model.pt",
    fusion="best_fusion_model_fp16.pth",
    yolo="damage_detector.pt",
)
def get_checkpoint_path(model_key: str) -> Path:
    """Resolve a model key to the path of its checkpoint file.

    Args:
        model_key: One of the keys registered in ``MODEL_FILES``.

    Returns:
        ``CHECKPOINT_DIR`` joined with the filename registered for the key.

    Raises:
        ValueError: If ``model_key`` is not a key of ``MODEL_FILES``.
        FileNotFoundError: If the resolved path does not exist on disk.
    """
    filename = MODEL_FILES.get(model_key)
    if filename is None:
        raise ValueError(f"Unknown model key: {model_key}")

    checkpoint = CHECKPOINT_DIR / filename
    if checkpoint.exists():
        return checkpoint
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
class ModelLoader:
    """Object-style access to checkpoint paths.

    Lookups are delegated to the module-level ``get_checkpoint_path``; the
    instance only records the checkpoint directory in ``base_dir``.
    """

    def __init__(self) -> None:
        """Store the module's checkpoint directory on the instance."""
        # Directory that checkpoint filenames are resolved against.
        self.base_dir = CHECKPOINT_DIR

    def get_model_path(self, model_key: str) -> Path:
        """Return the checkpoint path for ``model_key``.

        Raises whatever ``get_checkpoint_path`` raises: ``ValueError`` for an
        unknown key, ``FileNotFoundError`` when the file is missing.
        """
        checkpoint = get_checkpoint_path(model_key)
        return checkpoint
def initialize_models(class_map):
    """Build the ResNet and fusion damage predictors from their checkpoints.

    Both checkpoint paths are resolved before any predictor is constructed,
    so a missing file is reported before either model is instantiated.

    Args:
        class_map: Passed unchanged to both predictor constructors.
            NOTE(review): presumably a class-index/label mapping; confirm
            against the predictor classes in ``prediction_helper``.

    Returns:
        A ``(resnet_predictor, fusion_predictor)`` tuple.

    Raises:
        FileNotFoundError: If the "resnet" or "fusion" checkpoint is missing
            (raised by ``get_checkpoint_path``).
    """
    # Resolve in the fixed order resnet -> fusion before building anything.
    resnet_ckpt, fusion_ckpt = (
        get_checkpoint_path(key) for key in ("resnet", "fusion")
    )
    return (
        ResnetCarDamagePredictor(resnet_ckpt, class_map),
        FusionCarDamagePredictor(fusion_ckpt, class_map),
    )