# new_car/scripts/model_loader.py
# DamageLens project — model loader (initial commit c5377b5, author: junaid17)
import os
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download, HfApi  # NOTE(review): HfApi is imported but unused in this file — confirm before removing
from dotenv import load_dotenv
import yaml
# Populate os.environ from a local .env file so ModelLoader can read
# HF_MODEL_REPO, HF_TOKEN and MODEL_CACHE_DIR below.
load_dotenv()
class ModelLoader:
    """Resolve, cache, and load DamageLens model files from the HuggingFace Hub.

    Configuration sources:
      * Environment variables: HF_MODEL_REPO (repo id), HF_TOKEN (auth token,
        may be None for public repos), MODEL_CACHE_DIR (local cache directory).
      * ``model_config.yaml`` in the working directory, expected to contain a
        ``models`` mapping of model type -> {"filename": <file in the repo>}.
    """

    def __init__(self):
        # Environment-driven settings with the project's defaults.
        self.hf_repo = os.getenv("HF_MODEL_REPO", "junaid17/damagelens-models")
        self.hf_token = os.getenv("HF_TOKEN", None)
        self.cache_dir = Path(os.getenv("MODEL_CACHE_DIR", "./model_cache"))
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Load model config (maps model type to its repo filename).
        with open("model_config.yaml", "r") as f:
            self.config = yaml.safe_load(f)

    def download_model(self, model_type: str) -> str:
        """
        Download a model from HuggingFace Hub and return local path.

        Args:
            model_type: 'resnet', 'fusion', or 'yolo'
        Returns:
            Path to cached model file
        Raises:
            ValueError: if model_type is not listed in the config.
            RuntimeError: if the download from the Hub fails.
        """
        if model_type not in self.config["models"]:
            raise ValueError(f"Unknown model type: {model_type}")
        model_info = self.config["models"][model_type]
        filename = model_info["filename"]
        # Fast path: a file placed directly in cache_dir short-circuits the
        # Hub download entirely (hf_hub_download uses its own nested layout,
        # so this only matches files dropped there manually or by tooling).
        local_path = self.cache_dir / filename
        if local_path.exists():
            print(f"βœ“ Using cached {model_type} model: {local_path}")
            return str(local_path)
        # Download from HuggingFace Hub.
        print(f"⏳ Downloading {model_type} model from {self.hf_repo}...")
        try:
            # Fix: dropped deprecated `resume_download=True` — it is ignored
            # and warns in huggingface_hub >= 0.23 (downloads always resume).
            downloaded_path = hf_hub_download(
                repo_id=self.hf_repo,
                filename=filename,
                cache_dir=str(self.cache_dir),
                token=self.hf_token,
            )
            print(f"βœ“ Downloaded {model_type} model: {downloaded_path}")
            return downloaded_path
        except Exception as e:
            print(f"❌ Failed to download {model_type} model: {str(e)}")
            # Fix: chain the original exception so the traceback keeps the
            # underlying Hub error instead of "During handling ... another
            # exception occurred".
            raise RuntimeError(f"Could not load {model_type} model from HuggingFace Hub: {str(e)}") from e

    def load_checkpoint(self, model_type: str, device: str = None) -> dict:
        """
        Download and load a model checkpoint.

        Args:
            model_type: 'resnet', 'fusion', or 'yolo'
            device: torch device (auto-detect if None)
        Returns:
            Model state dict
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_path = self.download_model(model_type)
        # SECURITY NOTE(review): torch.load unpickles arbitrary objects; only
        # load checkpoints from the trusted project repo. Consider
        # weights_only=True if the checkpoints are plain state dicts.
        checkpoint = torch.load(model_path, map_location=device)
        return checkpoint

    def get_model_path(self, model_type: str) -> str:
        """Get local path to model (downloads if needed)."""
        return self.download_model(model_type)
def initialize_models(class_map: dict):
    """Build all predictor instances from models hosted on the HuggingFace Hub.

    Args:
        class_map: mapping of class indices to labels, forwarded to each predictor.
    Returns:
        (resnet_model, fusion_model, loader) ready for inference.
    """
    # Local import keeps module import light and avoids a circular dependency.
    from scripts.prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor

    loader = ModelLoader()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    banner = "=" * 50
    print("\n" + banner)
    print("πŸš€ Initializing Models from HuggingFace Hub")
    print(banner)

    # Load each predictor in turn, reporting progress as we go.
    predictors = []
    for config_key, predictor_cls, label in (
        ("resnet", ResnetCarDamagePredictor, "ResNet"),
        ("fusion", FusionCarDamagePredictor, "Fusion"),
    ):
        weights_path = loader.get_model_path(config_key)
        predictors.append(predictor_cls(weights_path, class_map))
        print(f"βœ“ {label} model loaded successfully")

    print(banner + "\n")
    resnet_model, fusion_model = predictors
    return resnet_model, fusion_model, loader
if __name__ == "__main__":
    # Smoke-test the loader: print its resolved settings without downloading.
    loader = ModelLoader()
    for label, value in (
        ("Repository", loader.hf_repo),
        ("Cache dir", loader.cache_dir),
        ("Models config", loader.config["models"].keys()),
    ):
        print(f"{label}: {value}")