polymer-aging-with-ml / backend /utils /model_manager.py
devjas1
Initial Release: Polymer Aging With ML [Standalone Appliance]
4a0e21d
Raw
History Blame Contribute Delete
7.39 kB
import torch
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
from backend.models.registry import (
build as build_model,
get_model_info as get_registry_model_info,
choices,
)
from backend.config import TARGET_LEN
from backend.pydantic_models import ModelInfo
class ModelManager:
    """
    Centralized manager for discovering, loading, and caching ML models and their weights.

    Ensures consistent model loading logic across different services.

    Attributes:
        device: Torch device chosen at construction time (CUDA when
            ``torch.cuda.is_available()`` is true, otherwise CPU).
    """

    def __init__(self):
        # Loaded models, keyed by the lowercase model name. Each entry stores
        # the model plus the runtime facts reported by ``get_model_info``.
        self._model_cache: Dict[str, Dict[str, Any]] = {}
        # Deserialized state dicts, keyed by "<path>:<mtime>" so that a
        # rewritten weights file is reloaded instead of served from cache.
        self._weights_cache: Dict[str, Dict[str, Any]] = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"✅ ModelManager initialized on {self.device}")

    @staticmethod
    def _candidate_weight_paths(model_name_lower: str) -> List[Path]:
        """Return the weight-file locations probed for a model, in priority order."""
        # Resolved relative to this file so the result does not depend on the
        # working directory of the process.
        weights_dir = Path(__file__).parent.parent / "models" / "weights"
        return [
            weights_dir / f"{model_name_lower}_model.pth",
            weights_dir / f"{model_name_lower}.pth",
        ]

    @staticmethod
    def _resolve_registry_name(model_name: str) -> Optional[str]:
        """Return the registry's own spelling of ``model_name`` (case-insensitive), or None."""
        wanted = model_name.lower()
        for registered in choices():
            if registered.lower() == wanted:
                return registered
        return None

    def _load_state_dict(self, model_path: Path) -> Optional[Dict]:
        """
        Load a state dict from disk with ``weights_only=True``, caching by path and mtime.

        Args:
            model_path (Path): Location of the weights file.

        Returns:
            Optional[Dict]: The loaded state dict, or None when the file does
            not exist or could not be loaded (the error is printed).
        """
        try:
            if not model_path.exists():
                return None
            mtime = model_path.stat().st_mtime
            cache_key = f"{model_path}:{mtime}"
            if cache_key not in self._weights_cache:
                # Strictly enforced security load
                state_dict = torch.load(
                    model_path, map_location=self.device, weights_only=True
                )
                # Drop entries for older mtimes of this same file so that
                # superseded state dicts are not kept in memory.
                path_prefix = f"{model_path}:"
                stale_keys = [
                    key for key in self._weights_cache if key.startswith(path_prefix)
                ]
                for stale_key in stale_keys:
                    del self._weights_cache[stale_key]
                self._weights_cache[cache_key] = state_dict
            return self._weights_cache[cache_key]
        except Exception as e:
            # Deliberately broad: any failure here is reported and turned into
            # "no weights" so the caller can try the next candidate path.
            print(f"❌ Security/Load Error for {model_path.name}: {e}")
            return None

    def load_model(
        self, model_name: str, target_len: int = TARGET_LEN
    ) -> Tuple[Optional[torch.nn.Module], bool, Path]:
        """
        Load a trained model for inference, including its weights.

        Caches the loaded model. A cached model is reused only when it was
        built for the same ``target_len``; otherwise it is rebuilt and the
        cache entry is replaced.

        Args:
            model_name (str): Name of the model architecture (from registry),
                matched case-insensitively.
            target_len (int): Expected input length for the model.

        Returns:
            Tuple[Optional[torch.nn.Module], bool, Path]: The loaded model, a boolean
            indicating if weights were successfully loaded, and the path to the
            loaded weights. For a name that is not in the registry the result
            is ``(None, False, Path(""))``.
        """
        # Always use lowercase for filenames and cache keys
        model_name_lower = model_name.lower()

        cached_entry = self._model_cache.get(model_name_lower)
        if cached_entry is not None and cached_entry["target_len"] == target_len:
            return (
                cached_entry["model"],
                cached_entry["weights_loaded"],
                cached_entry["weights_path"],
            )

        registry_name = self._resolve_registry_name(model_name)
        if registry_name is None:
            print(f"⚠️ Model '{model_name_lower}' not found in registry.")
            return None, False, Path("")

        # Build with the registry's own spelling of the name.
        model = build_model(registry_name, target_len)
        weights_loaded = False
        loaded_path = Path("")

        for weight_path in self._candidate_weight_paths(model_name_lower):
            print(f"🔍 Checking for weights at {weight_path}")  # Debug log
            if not weight_path.exists():
                print(f"🔍 Weights not found for {model_name_lower} at {weight_path}")
                continue
            try:
                state_dict = self._load_state_dict(weight_path)
                if state_dict is None:
                    # The loader already reported why; try the next candidate.
                    continue
                model.load_state_dict(state_dict, strict=True)
            except (OSError, RuntimeError, KeyError) as e:
                print(
                    f"❌ Error loading weights for {model_name_lower} from {weight_path}: {e}"
                )
                continue
            weights_loaded = True
            loaded_path = weight_path
            print(f"✅ Loaded weights for {model_name_lower} from {loaded_path}")
            break

        if not weights_loaded:
            print(
                f"⚠️ No weights loaded for model '{model_name_lower}'. Model will use random initialization."
            )

        # Move to the target device and switch to eval mode whether or not
        # weights were loaded.
        model.to(self.device)
        model.eval()

        self._model_cache[model_name_lower] = {
            "model": model,
            "weights_loaded": weights_loaded,
            "weights_path": loaded_path,
            "target_len": target_len,
            "device": self.device,
        }
        return model, weights_loaded, loaded_path

    def get_model_info(self, model_name: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed information for a specific model.

        Args:
            model_name (str): Name of the model (from registry), matched
                case-insensitively.

        Returns:
            Optional[Dict[str, Any]]: A copy of the registry info extended with
            ``weights_loaded``, ``weights_path``, ``device`` and ``available``,
            or None when the model is not in the registry.
        """
        registry_name = self._resolve_registry_name(model_name)
        if registry_name is None:
            return None

        # Copy so the runtime fields added below never leak into whatever
        # object the registry handed back.
        info = dict(get_registry_model_info(registry_name))

        # The model cache is keyed by lowercase name (see load_model).
        cache_key = registry_name.lower()
        cached_info = self._model_cache.get(cache_key)
        if cached_info is not None:
            # Add runtime info if model is loaded
            info["weights_loaded"] = cached_info["weights_loaded"]
            info["weights_path"] = str(cached_info["weights_path"])
            info["device"] = str(cached_info["device"])
            # NOTE(review): a cached model is reported as available even when
            # it was cached without weights — confirm this is intended.
            info["available"] = True
        else:
            # Check if weights exist even if not loaded yet, using the same
            # candidate paths that load_model probes.
            weights_exist = any(
                path.exists() for path in self._candidate_weight_paths(cache_key)
            )
            info["weights_loaded"] = False
            info["weights_path"] = None
            info["device"] = str(self.device)
            # Mark as available if weights are present
            info["available"] = weights_exist
        return info

    def get_available_models(self) -> List[ModelInfo]:
        """
        Get a list of all models with their availability status.

        Returns:
            List[ModelInfo]: One entry per registry model for which
            ``get_model_info`` returned information.
        """
        models_list = []
        for model_name in choices():
            info = self.get_model_info(model_name)
            if not info:
                continue
            models_list.append(
                ModelInfo(
                    name=model_name,
                    description=info.get("description", ""),
                    input_length=info.get("input_length", TARGET_LEN),
                    num_classes=info.get("num_classes", 2),
                    supported_modalities=info.get("modalities", ["raman", "ftir"]),
                    performance=info.get("performance", {}),
                    parameters=info.get("parameters"),
                    speed=info.get("speed"),
                    citation=info.get("citation"),
                    # Use the 'available' status from get_model_info
                    available=info.get("available", False),
                )
            )
        return models_list
# Global instance of the ModelManager, created when this module is imported.
# Construction selects the torch device and prints an initialization message.
model_manager = ModelManager()