import torch
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
from backend.models.registry import (
    build as build_model,
    get_model_info as get_registry_model_info,
    choices,
)
from backend.config import TARGET_LEN
from backend.pydantic_models import ModelInfo


class ModelManager:
    """
    Centralized manager for discovering, loading, and caching ML models and
    their weights.

    Two caches are kept:

    * ``_model_cache`` maps a lowercased model name to a dict holding the
      built model plus its load metadata (``weights_loaded``,
      ``weights_path``, ``target_len``, ``device``).
    * ``_weights_cache`` maps ``"<path>:<mtime>"`` to the state dict read
      from that file, so an unchanged file is only deserialized once.
    """

    def __init__(self):
        self._model_cache: Dict[str, Dict[str, Any]] = {}
        # Values are state dicts as returned by torch.load, not modules.
        self._weights_cache: Dict[str, Dict[str, Any]] = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"✅ ModelManager initialized on {self.device}")

    @staticmethod
    def _candidate_weight_paths(model_name_lower: str) -> List[Path]:
        """Return the weight files probed for a model, in priority order.

        The directory is anchored on this file's location so the result does
        not depend on the process working directory.

        Args:
            model_name_lower: Lowercased model name used in the file names.
        """
        weights_dir = Path(__file__).parent.parent / "models" / "weights"
        return [
            weights_dir / f"{model_name_lower}_model.pth",
            weights_dir / f"{model_name_lower}.pth",
        ]

    def _load_state_dict(self, model_path: Path) -> Optional[Dict]:
        """Load a state dict from disk with ``weights_only=True``, cached by mtime.

        Args:
            model_path: Path of the weights file to read.

        Returns:
            The loaded state dict, or ``None`` when the file does not exist
            or loading fails for any reason (the error is printed, not raised).
        """
        try:
            if not model_path.exists():
                return None
            mtime = model_path.stat().st_mtime
            path_prefix = f"{model_path}:"
            cache_key = f"{path_prefix}{mtime}"
            if cache_key not in self._weights_cache:
                # Strictly enforced security load
                state_dict = torch.load(
                    model_path, map_location=self.device, weights_only=True
                )
                # The file changed on disk: drop entries for older mtimes of
                # this same path so superseded state dicts do not accumulate.
                stale_keys = [
                    key for key in self._weights_cache if key.startswith(path_prefix)
                ]
                for key in stale_keys:
                    del self._weights_cache[key]
                self._weights_cache[cache_key] = state_dict
            return self._weights_cache[cache_key]
        except Exception as e:
            # Deliberate best-effort boundary: callers treat None as
            # "no usable weights at this path".
            print(f"❌ Security/Load Error for {model_path.name}: {e}")
            return None

    def load_model(
        self, model_name: str, target_len: int = TARGET_LEN
    ) -> Tuple[Optional[torch.nn.Module], bool, Path]:
        """
        Load a trained model for inference, including its weights.
        Caches the loaded model.

        A cached model is reused only when it was built with the same
        ``target_len``; otherwise the model is rebuilt and the cache entry
        for that name is replaced.

        Args:
            model_name (str): Name of the model architecture (from registry).
                Matched case-insensitively.
            target_len (int): Expected input length for the model.

        Returns:
            Tuple[Optional[torch.nn.Module], bool, Path]: The loaded model, a
            boolean indicating if weights were successfully loaded, and the
            path to the loaded weights. For a name that is not in the
            registry the tuple is ``(None, False, Path(""))``. When no
            weights could be loaded the path is ``Path("")``.
        """
        # Always use lowercase for filenames and cache keys
        model_name_lower = model_name.lower()

        model_entry = self._model_cache.get(model_name_lower)
        if model_entry is not None and model_entry["target_len"] == target_len:
            return (
                model_entry["model"],
                model_entry["weights_loaded"],
                model_entry["weights_path"],
            )

        if model_name_lower not in {m.lower() for m in choices()}:
            print(f"⚠️ Model '{model_name_lower}' not found in registry.")
            return None, False, Path("")

        model = build_model(model_name_lower, target_len)
        weights_loaded = False
        loaded_path = Path("")

        for weight_path in self._candidate_weight_paths(model_name_lower):
            print(f"🔍 Checking for weights at {weight_path}")  # Debug log
            if not weight_path.exists():
                print(f"🔍 Weights not found for {model_name_lower} at {weight_path}")
                continue
            try:
                state_dict = self._load_state_dict(weight_path)
                if state_dict:
                    model.load_state_dict(state_dict, strict=True)
                    weights_loaded = True
                    loaded_path = weight_path
                    print(
                        f"✅ Loaded weights for {model_name_lower} from {loaded_path}"
                    )
                    break
            except (OSError, RuntimeError, KeyError) as e:
                # E.g. a strict key/shape mismatch: try the next candidate.
                print(
                    f"❌ Error loading weights for {model_name_lower} from {weight_path}: {e}"
                )
                continue

        if not weights_loaded:
            print(
                f"⚠️ No weights loaded for model '{model_name_lower}'. Model will use random initialization."
            )

        # Move to the target device and switch to eval mode whether or not
        # weights were loaded.
        model.to(self.device)
        model.eval()

        self._model_cache[model_name_lower] = {
            "model": model,
            "weights_loaded": weights_loaded,
            "weights_path": loaded_path,
            "target_len": target_len,
            "device": self.device,
        }
        return model, weights_loaded, loaded_path

    def get_model_info(self, model_name: str) -> Optional[Dict[str, Any]]:
        """Get detailed information for a specific model.

        Registry information is extended with the runtime keys
        ``weights_loaded``, ``weights_path``, ``device`` and ``available``.

        Args:
            model_name: Model name exactly as listed by the registry's
                ``choices()``.

        Returns:
            The info dict, or ``None`` when ``model_name`` is not in
            ``choices()``.
        """
        if model_name not in choices():
            return None

        info = get_registry_model_info(model_name)

        # load_model stores cache entries and weight files under the
        # lowercased name, so look both up the same way.
        model_name_lower = model_name.lower()
        cached_info = self._model_cache.get(model_name_lower)

        if cached_info is not None:
            # Add runtime info if model is loaded
            info["weights_loaded"] = cached_info["weights_loaded"]
            info["weights_path"] = str(cached_info["weights_path"])
            info["device"] = str(cached_info["device"])
            info["available"] = True
        else:
            # Check if weights exist even if not loaded yet
            weights_exist = any(
                path.exists()
                for path in self._candidate_weight_paths(model_name_lower)
            )
            info["weights_loaded"] = False
            info["weights_path"] = None
            info["device"] = str(self.device)
            # Mark as available if weights are present
            info["available"] = weights_exist

        return info

    def get_available_models(self) -> List[ModelInfo]:
        """Get a list of all models with their availability status.

        Returns:
            One ``ModelInfo`` per registry entry for which ``get_model_info``
            returned a non-empty dict; missing fields fall back to the
            defaults given below.
        """
        models_list = []
        for model_name in choices():
            info = self.get_model_info(model_name)
            if not info:
                continue
            models_list.append(
                ModelInfo(
                    name=model_name,
                    description=info.get("description", ""),
                    input_length=info.get("input_length", TARGET_LEN),
                    num_classes=info.get("num_classes", 2),
                    supported_modalities=info.get("modalities", ["raman", "ftir"]),
                    performance=info.get("performance", {}),
                    parameters=info.get("parameters"),
                    speed=info.get("speed"),
                    citation=info.get("citation"),
                    # Use the 'available' status from get_model_info
                    available=info.get("available", False),
                )
            )
        return models_list


# Global instance of the ModelManager
model_manager = ModelManager()