| import os |
| import torch |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download, HfApi |
| from dotenv import load_dotenv |
| import yaml |
|
|
# Load a local .env file so HF_MODEL_REPO / HF_TOKEN / MODEL_CACHE_DIR can be overridden.
load_dotenv()
|
|
class ModelLoader:
    """Resolve model weight files from HuggingFace Hub with a local file cache.

    Configuration sources:
      - env HF_MODEL_REPO   : hub repo id (default "junaid17/damagelens-models")
      - env HF_TOKEN        : optional auth token for private repos
      - env MODEL_CACHE_DIR : local cache directory (default ./model_cache)
      - model_config.yaml   : read from the current working directory; expected
        schema is {"models": {"<type>": {"filename": "<weights file>"}}}
        (schema inferred from usage below — TODO confirm against the file).
    """

    def __init__(self):
        self.hf_repo = os.getenv("HF_MODEL_REPO", "junaid17/damagelens-models")
        self.hf_token = os.getenv("HF_TOKEN", None)
        self.cache_dir = Path(os.getenv("MODEL_CACHE_DIR", "./model_cache"))
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # NOTE: resolved relative to the process CWD, not this file's directory.
        with open("model_config.yaml", "r") as f:
            self.config = yaml.safe_load(f)

    def download_model(self, model_type: str) -> str:
        """
        Download a model from HuggingFace Hub and return local path.

        Args:
            model_type: 'resnet', 'fusion', or 'yolo'

        Returns:
            Path to cached model file

        Raises:
            ValueError: model_type is not declared in model_config.yaml.
            RuntimeError: the Hub download failed.
        """
        if model_type not in self.config["models"]:
            raise ValueError(f"Unknown model type: {model_type}")

        model_info = self.config["models"][model_type]
        filename = model_info["filename"]

        # Fast path: file already present in our flat cache directory.
        local_path = self.cache_dir / filename
        if local_path.exists():
            print(f"β Using cached {model_type} model: {local_path}")
            return str(local_path)

        print(f"β³ Downloading {model_type} model from {self.hf_repo}...")
        try:
            # BUGFIX: the original passed cache_dir=, which stores files under
            # hub-style "models--org--repo/snapshots/<rev>/" subdirectories, so
            # the exists() check above never hit and every run re-contacted the
            # Hub. local_dir= places the file directly at cache_dir/filename.
            # (resume_download was dropped: deprecated in huggingface_hub,
            # downloads always resume now.)
            downloaded_path = hf_hub_download(
                repo_id=self.hf_repo,
                filename=filename,
                local_dir=str(self.cache_dir),
                token=self.hf_token,
            )
            print(f"β Downloaded {model_type} model: {downloaded_path}")
            return downloaded_path

        except Exception as e:
            print(f"β Failed to download {model_type} model: {str(e)}")
            raise RuntimeError(
                f"Could not load {model_type} model from HuggingFace Hub: {str(e)}"
            ) from e

    def load_checkpoint(self, model_type: str, device: str = None) -> dict:
        """
        Download and load a model checkpoint.

        Args:
            model_type: 'resnet', 'fusion', or 'yolo'
            device: torch device (auto-detect if None)

        Returns:
            Model state dict
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_path = self.download_model(model_type)
        # SECURITY NOTE: torch.load unpickles arbitrary objects; only load
        # checkpoints from the trusted repo configured above.
        checkpoint = torch.load(model_path, map_location=device)
        return checkpoint

    def get_model_path(self, model_type: str) -> str:
        """Get local path to model (downloads if needed)."""
        return self.download_model(model_type)
|
|
|
|
def initialize_models(class_map: dict):
    """
    Initialize all models from HuggingFace Hub.

    Args:
        class_map: label mapping passed through to both predictor constructors
            (exact schema defined by scripts.prediction_helper — not visible here).

    Returns:
        (resnet_model, fusion_model, loader) — two predictor instances ready
        for inference plus the ModelLoader used to fetch their weights.
    """
    # Imported lazily to avoid a hard dependency at module import time.
    from scripts.prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor

    loader = ModelLoader()
    # FIX: removed unused `device = torch.device(...)` local — it was computed
    # but never passed to either predictor.

    print("\n" + "="*50)
    print("π Initializing Models from HuggingFace Hub")
    print("="*50)

    resnet_path = loader.get_model_path("resnet")
    resnet_model = ResnetCarDamagePredictor(resnet_path, class_map)
    print("β ResNet model loaded successfully")

    fusion_path = loader.get_model_path("fusion")
    fusion_model = FusionCarDamagePredictor(fusion_path, class_map)
    print("β Fusion model loaded successfully")

    print("="*50 + "\n")

    return resnet_model, fusion_model, loader
|
|
|
|
if __name__ == "__main__":
    # Smoke test: report where models come from and which types are configured.
    inspector = ModelLoader()
    for label, value in (
        ("Repository", inspector.hf_repo),
        ("Cache dir", inspector.cache_dir),
        ("Models config", inspector.config["models"].keys()),
    ):
        print(f"{label}: {value}")
|
|