| import json | |
| import torch | |
| from pathlib import Path | |
| from typing import Dict, Any | |
| from src.app.model_utils.factory import ModelFactory | |
| class ModelManager: | |
| """ | |
| Manages model loading and inference operations | |
| Args: | |
| model_dir: Directory containing model artifacts | |
| """ | |
| def __init__(self, model_dir: str = "../pretrained") -> None: | |
| self.model_dir = Path(model_dir) | |
| self.loaded_models: Dict[str, Any] = {} | |
| self._load_model_artifacts() | |
| def _load_model_artifacts(self) -> None: | |
| """ | |
| Load model configuration and vocabulary | |
| """ | |
| with open(self.model_dir / "config.json", "r") as f: | |
| self.config = json.load(f) | |
| with open(self.model_dir / "vocab.json", "r") as f: | |
| self.vocab = json.load(f) | |
| self.idx_to_label = {0: "Negative", 1: "Positive"} | |
| def get_model(self) -> torch.nn.Module: | |
| """ | |
| Get the loaded model (cached for performance) | |
| Returns: | |
| Loaded PyTorch model in evaluation mode | |
| """ | |
| model_type = self.config["model_type"] | |
| if model_type not in self.loaded_models: | |
| model = ModelFactory.create_model( | |
| model_type=model_type, | |
| model_params=self.config["model_params"], | |
| state_dict_path=self.model_dir / "best_model.pth" | |
| ) | |
| self.loaded_models[model_type] = model | |
| return self.loaded_models[model_type] | |
| def get_vocab(self) -> Dict[str, int]: | |
| """ | |
| Get vocabulary mapping | |
| """ | |
| return self.vocab | |
| def get_config(self) -> Dict[str, Any]: | |
| """ | |
| Get model configuration | |
| """ | |
| return self.config | |