Spaces:
Running
Running
| # DEPENDENCIES | |
| import os | |
| import gc | |
| import torch | |
| from typing import Optional | |
| from config.settings import get_settings | |
| from config.logging_config import get_logger | |
| from utils.error_handler import handle_errors | |
| from utils.error_handler import EmbeddingError | |
| from sentence_transformers import SentenceTransformer | |
# Module-level configuration object and logger used throughout this file
settings = get_settings()
logger   = get_logger(__name__)
class EmbeddingModelLoader:
    """
    Manages loading and caching of embedding models: Supports multiple models with efficient resource management

    Every loaded model is cached per (model name, device) pair. The model that was most recently
    loaded (or served from the cache) is additionally tracked as the "current" model, which is what
    get_loaded_model(), get_model_info() and unload_model() operate on
    """
    def __init__(self):
        self.logger        = logger
        self._loaded_model = None
        self._model_name   = None
        self._device       = None

        # Model cache for multiple models, keyed by (model_name, device) tuples so that
        # lookups and evictions match the model name exactly (never by string prefix)
        self._model_cache  = dict()


    def load_model(self, model_name: Optional[str] = None, device: Optional[str] = None, force_reload: bool = False) -> "SentenceTransformer":
        """
        Load embedding model with caching and device optimization

        Arguments:
        ----------
            model_name   { str }  : Name of model to load (default from settings)

            device       { str }  : Device to load on ('cpu', 'cuda', 'mps', 'auto')

            force_reload { bool } : Force reload even if model is cached

        Returns:
        --------
            { SentenceTransformer } : Loaded model instance

        Raises:
        -------
            EmbeddingError          : If model loading fails (original exception attached as __cause__)
        """
        model_name  = model_name or settings.EMBEDDING_MODEL
        device      = self._resolve_device(device)

        # Tuple key for exact lookups; the string label is kept only for log output
        cache_key   = (model_name, device)
        cache_label = f"{model_name}_{device}"

        # Check cache first
        if ((not force_reload) and (cache_key in self._model_cache)):
            self.logger.debug(f"Using cached model: {cache_label}")

            self._loaded_model = self._model_cache[cache_key]
            self._model_name   = model_name
            self._device       = device

            return self._loaded_model

        try:
            self.logger.info(f"Loading embedding model: {model_name} on device: {device}")

            # Load model with optimized settings
            model = SentenceTransformer(model_name,
                                        device       = device,
                                        cache_folder = os.path.expanduser("~/.cache/sentence_transformers"),
                                       )

            # Model-specific optimizations
            model = self._optimize_model(model  = model,
                                         device = device,
                                        )

            # Cache the model and make it the current one
            self._model_cache[cache_key] = model
            self._loaded_model           = model
            self._model_name             = model_name
            self._device                 = device

            # Log model info
            self._log_model_info(model  = model,
                                 device = device,
                                )

            self.logger.info(f"Successfully loaded model: {model_name}")

            return model

        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {repr(e)}")
            # Chain the cause so the original traceback is not lost
            raise EmbeddingError(f"Model loading failed: {repr(e)}") from e


    def _resolve_device(self, device: Optional[str] = None) -> str:
        """
        Resolve the best available device

        Resolution order: explicit argument (unless 'auto'), then settings.EMBEDDING_DEVICE
        (unless 'auto'), then automatic detection (cuda, then mps, then cpu)

        Arguments:
        ----------
            device { str } : Requested device

        Returns:
        --------
               { str }     : Actual device to use
        """
        if (device and (device != "auto")):
            return device

        # Fall back to the configured device
        if (settings.EMBEDDING_DEVICE != "auto"):
            return settings.EMBEDDING_DEVICE

        # Automatic detection
        if torch.cuda.is_available():
            return "cuda"

        # torch.backends.mps is absent on older torch builds, so look it up defensively
        mps_backend = getattr(torch.backends, "mps", None)

        if ((mps_backend is not None) and mps_backend.is_available()):
            return "mps"

        return "cpu"


    def _optimize_model(self, model: "SentenceTransformer", device: str) -> "SentenceTransformer":
        """
        Apply optimizations to the model

        Arguments:
        ----------
            model  { SentenceTransformer } : Model to optimize

            device         { str }         : Device model is on

        Returns:
        --------
            { SentenceTransformer }        : Optimized model
        """
        # Enable eval mode for inference
        model.eval()

        # GPU optimizations
        if (device == "cuda"):
            # Use half precision for GPU if supported; failure is non-fatal and only logged
            try:
                model = model.half()
                self.logger.debug("Enabled half precision for GPU")

            except Exception as e:
                self.logger.warning(f"Could not enable half precision: {repr(e)}")

        # Disable gradient computation
        for param in model.parameters():
            param.requires_grad = False

        return model


    def _log_model_info(self, model: "SentenceTransformer", device: str):
        """
        Log detailed model information (best effort: any failure is logged at debug level and swallowed)

        Arguments:
        ----------
            model  { SentenceTransformer } : Model to log info for

            device         { str }         : Device model is on
        """
        try:
            # Get model architecture info
            if hasattr(model, '_modules'):
                modules = list(model._modules.keys())

            else:
                modules = ["unknown"]

            # Get embedding dimension
            if hasattr(model, 'get_sentence_embedding_dimension'):
                dimension = model.get_sentence_embedding_dimension()

            else:
                dimension = "unknown"

            # Count parameters
            total_params = sum(p.numel() for p in model.parameters())

            self.logger.info(f"Model Info: {len(modules)} modules, dimension={dimension}, parameters={total_params:,}, device={device}")

        except Exception as e:
            self.logger.debug(f"Could not get detailed model info: {repr(e)}")


    def get_loaded_model(self) -> Optional["SentenceTransformer"]:
        """
        Get currently loaded model

        Returns:
        --------
            { SentenceTransformer } : Currently loaded model or None
        """
        return self._loaded_model


    def get_model_info(self) -> dict:
        """
        Get information about loaded model

        Returns:
        --------
            { dict } : {"loaded": False} when no model is loaded; otherwise the model name, device,
                       cache size, model class and (when the model exposes it) embedding dimension
        """
        if self._loaded_model is None:
            return {"loaded": False}

        info = {"loaded"     : True,
                "model_name" : self._model_name,
                "device"     : self._device,
                "cache_size" : len(self._model_cache),
               }

        try:
            if hasattr(self._loaded_model, 'get_sentence_embedding_dimension'):
                info["embedding_dimension"] = self._loaded_model.get_sentence_embedding_dimension()

            info["model_class"] = type(self._loaded_model).__name__

        except Exception as e:
            self.logger.warning(f"Could not get detailed model info: {e}")

        return info


    def clear_cache(self, model_name: Optional[str] = None):
        """
        Clear model cache

        Only the cache entries are dropped; the currently loaded model reference is left untouched

        Arguments:
        ----------
            model_name { str } : Specific model to clear (None = all)
        """
        if model_name:
            # Clear specific model from all devices: exact name match, so clearing "foo" never evicts "foo-v2"
            keys_to_remove = [key for key in self._model_cache if (key[0] == model_name)]

            for key in keys_to_remove:
                del self._model_cache[key]

            self.logger.info(f"Cleared cache for model: {model_name}")

        else:
            # Clear all cache
            cache_size = len(self._model_cache)
            self._model_cache.clear()

            self.logger.info(f"Cleared all model cache ({cache_size} models)")


    def unload_model(self):
        """
        Unload current model and free memory

        Removes the current model from the cache, drops the loader's references to it, runs garbage
        collection and, when CUDA is available, empties the CUDA cache. No-op when nothing is loaded
        """
        # Explicit None check: do not rely on the truthiness of the model object
        if self._loaded_model is None:
            return

        model_name = self._model_name

        # Clear from cache
        if self._model_name and self._device:
            self._model_cache.pop((self._model_name, self._device), None)

        # Clear references
        self._loaded_model = None
        self._model_name   = None
        self._device       = None

        # Force garbage collection
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info(f"Unloaded model: {model_name}")
# Process-wide loader singleton; stays None until the first get_model_loader() call
_model_loader = None


def get_model_loader() -> EmbeddingModelLoader:
    """
    Return the shared model loader, creating it on first use (singleton)

    Returns:
    --------
        { EmbeddingModelLoader } : Model loader instance
    """
    global _model_loader

    # Fast path: the loader has already been created
    if _model_loader is not None:
        return _model_loader

    _model_loader = EmbeddingModelLoader()

    return _model_loader
def load_embedding_model(model_name: Optional[str] = None, device: Optional[str] = None) -> SentenceTransformer:
    """
    Shortcut that loads an embedding model through the global model loader

    Arguments:
    ----------
        model_name { str } : Model name (passed through to the loader's load_model)

        device     { str } : Device (passed through to the loader's load_model)

    Returns:
    --------
        { SentenceTransformer } : Loaded model
    """
    # Delegate to the singleton loader
    return get_model_loader().load_model(model_name, device)