File size: 1,175 Bytes
1ae016f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
from pathlib import Path
from .prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor

# Checkpoints live in a "checkpoints" directory that sits one level above
# the directory containing this module (resolved to an absolute path).
CHECKPOINT_DIR = Path(__file__).resolve().parent.parent.joinpath("checkpoints")

# Registry mapping a short model key to its checkpoint filename inside
# CHECKPOINT_DIR.
MODEL_FILES = dict(
    resnet="best_resnet_model.pt",
    fusion="best_fusion_model_fp16.pth",
    yolo="damage_detector.pt",
)


def get_checkpoint_path(model_key: str) -> Path:
    """Return the checkpoint path for a model registered in ``MODEL_FILES``.

    Args:
        model_key: A key of ``MODEL_FILES`` (e.g. ``"resnet"`` or ``"fusion"``).

    Returns:
        The path ``CHECKPOINT_DIR / MODEL_FILES[model_key]``.

    Raises:
        ValueError: If ``model_key`` is not a key of ``MODEL_FILES``.
        FileNotFoundError: If no regular file exists at the resolved path.
    """
    try:
        filename = MODEL_FILES[model_key]
    except KeyError:
        available = ", ".join(MODEL_FILES)
        # ``from None``: the KeyError adds nothing beyond this message.
        raise ValueError(
            f"Unknown model key: {model_key} (available: {available})"
        ) from None

    path = CHECKPOINT_DIR / filename
    # is_file() rather than exists(): a directory at this location would
    # otherwise pass the check and only fail later, inside the model loader,
    # with a far less helpful error.
    if not path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return path


class ModelLoader:
    """Resolve model checkpoint paths relative to ``self.base_dir``."""

    def __init__(self, base_dir=None):
        """Create a loader.

        Args:
            base_dir: Directory holding the checkpoint files. Defaults to the
                module-level ``CHECKPOINT_DIR``, matching the previous behavior.
        """
        self.base_dir = CHECKPOINT_DIR if base_dir is None else Path(base_dir)

    def get_model_path(self, model_key: str) -> Path:
        """Return the checkpoint path for ``model_key`` under ``self.base_dir``.

        Previously ``base_dir`` was stored but ignored (paths always came from
        the global ``CHECKPOINT_DIR``). It is now honoured. With the default
        ``base_dir`` the result is the same as before.

        Raises:
            ValueError: If ``model_key`` is not a key of ``MODEL_FILES``.
            FileNotFoundError: If no regular file exists at the resolved path.
        """
        try:
            filename = MODEL_FILES[model_key]
        except KeyError:
            available = ", ".join(MODEL_FILES)
            raise ValueError(
                f"Unknown model key: {model_key} (available: {available})"
            ) from None

        path = self.base_dir / filename
        # is_file() rather than exists(): reject a directory at this location.
        if not path.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path


def initialize_models(class_map):
    """Build the ResNet and fusion damage predictors.

    Both checkpoint paths are resolved up front, so a missing or unknown
    checkpoint raises before either predictor is constructed.

    Args:
        class_map: Passed unchanged to both predictor constructors.

    Returns:
        A ``(resnet_predictor, fusion_predictor)`` tuple.
    """
    checkpoints = {key: get_checkpoint_path(key) for key in ("resnet", "fusion")}

    # Tuple elements evaluate left to right: ResNet first, then fusion.
    return (
        ResnetCarDamagePredictor(checkpoints["resnet"], class_map),
        FusionCarDamagePredictor(checkpoints["fusion"], class_map),
    )