# QuerySphere / embeddings / model_loader.py
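"""
Embedding model loader for QuerySphere.

Loads SentenceTransformer models with per model-and-device caching, applies
inference-time optimizations (eval mode, optional fp16 on CUDA, frozen
parameters), and exposes a process-wide singleton loader plus a convenience
load function.
"""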
# DEPENDENCIES
import os
import gc
import torch
from typing import Optional
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import EmbeddingError
from sentence_transformers import SentenceTransformer
# Set up settings and logging
settings = get_settings()
logger = get_logger(__name__)
class EmbeddingModelLoader:
"""
    Manages loading and caching of embedding models. Supports multiple models and devices with efficient resource management
"""
def __init__(self):
self.logger = logger
self._loaded_model = None
self._model_name = None
self._device = None
        # Cache keyed by "{model_name}_{device}" so multiple models/devices can stay loaded
        self._model_cache: dict = {}
@handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
def load_model(self, model_name: Optional[str] = None, device: Optional[str] = None, force_reload: bool = False) -> SentenceTransformer:
"""
Load embedding model with caching and device optimization
Arguments:
----------
model_name { str } : Name of model to load (default from settings)
device { str } : Device to load on ('cpu', 'cuda', 'mps', 'auto')
force_reload { bool } : Force reload even if model is cached
Returns:
--------
{ SentenceTransformer } : Loaded model instance
Raises:
-------
EmbeddingError : If model loading fails
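        Example (illustrative):
        -----------------------
        >>> loader = EmbeddingModelLoader()
        >>> model = loader.load_model()  # model/device defaults from settings
        >>> _ = model.encode(["hello world"])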
"""
model_name = model_name or settings.EMBEDDING_MODEL
device = self._resolve_device(device)
# Check cache first
cache_key = f"{model_name}_{device}"
if ((not force_reload) and (cache_key in self._model_cache)):
self.logger.debug(f"Using cached model: {cache_key}")
self._loaded_model = self._model_cache[cache_key]
self._model_name = model_name
self._device = device
return self._loaded_model
try:
self.logger.info(f"Loading embedding model: {model_name} on device: {device}")
# Load model with optimized settings
model = SentenceTransformer(model_name,
device = device,
cache_folder = os.path.expanduser("~/.cache/sentence_transformers"),
)
# Model-specific optimizations
model = self._optimize_model(model = model,
device = device,
)
# Cache the model
self._model_cache[cache_key] = model
self._loaded_model = model
self._model_name = model_name
self._device = device
# Log model info
self._log_model_info(model = model,
device = device,
)
self.logger.info(f"Successfully loaded model: {model_name}")
return model
        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {repr(e)}")
            # Chain the original exception so the full traceback is preserved
            raise EmbeddingError(f"Model loading failed: {repr(e)}") from e
def _resolve_device(self, device: Optional[str] = None) -> str:
"""
Resolve the best available device
Arguments:
----------
device { str } : Requested device
Returns:
--------
{ str } : Actual device to use
"""
if (device and (device != "auto")):
return device
# Auto device selection
if (settings.EMBEDDING_DEVICE != "auto"):
return settings.EMBEDDING_DEVICE
# Automatic detection
        if torch.cuda.is_available():
            return "cuda"
        # Guard the attribute lookup: the MPS backend only exists on PyTorch
        # builds that ship Apple-silicon support
        elif (getattr(torch.backends, "mps", None) is not None) and torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"
def _optimize_model(self, model: SentenceTransformer, device: str) -> SentenceTransformer:
"""
Apply optimizations to the model
Arguments:
----------
model { SentenceTransformer } : Model to optimize
device { str } : Device model is on
Returns:
--------
{ SentenceTransformer } : Optimized model
"""
# Enable eval mode for inference
model.eval()
# GPU optimizations
if (device == "cuda"):
# Use half precision for GPU if supported
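            # fp16 halves GPU memory use and speeds up inference on most CUDA
            # devices, at a small precision cost in the resulting embeddings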
try:
model = model.half()
self.logger.debug("Enabled half precision for GPU")
except Exception as e:
self.logger.warning(f"Could not enable half precision: {repr(e)}")
# Disable gradient computation
for param in model.parameters():
param.requires_grad = False
return model
def _log_model_info(self, model: SentenceTransformer, device: str):
"""
Log detailed model information
Arguments:
----------
model { SentenceTransformer } : Model to log info for
device { str } : Device model is on
"""
try:
# Get model architecture info
if hasattr(model, '_modules'):
modules = list(model._modules.keys())
else:
modules = ["unknown"]
# Get embedding dimension
if hasattr(model, 'get_sentence_embedding_dimension'):
dimension = model.get_sentence_embedding_dimension()
else:
dimension = "unknown"
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
self.logger.info(f"Model Info: {len(modules)} modules, dimension={dimension}, parameters={total_params:,}, device={device}")
except Exception as e:
self.logger.debug(f"Could not get detailed model info: {repr(e)}")
def get_loaded_model(self) -> Optional[SentenceTransformer]:
"""
Get currently loaded model
Returns:
--------
{ SentenceTransformer } : Currently loaded model or None
"""
return self._loaded_model
def get_model_info(self) -> dict:
"""
Get information about loaded model
Returns:
--------
{ dict } : Model information dictionary
"""
if self._loaded_model is None:
return {"loaded": False}
info = {"loaded" : True,
"model_name" : self._model_name,
"device" : self._device,
"cache_size" : len(self._model_cache),
}
try:
if hasattr(self._loaded_model, 'get_sentence_embedding_dimension'):
info["embedding_dimension"] = self._loaded_model.get_sentence_embedding_dimension()
info["model_class"] = type(self._loaded_model).__name__
except Exception as e:
            self.logger.warning(f"Could not get detailed model info: {repr(e)}")
return info
def clear_cache(self, model_name: Optional[str] = None):
"""
Clear model cache
Arguments:
----------
model_name { str } : Specific model to clear (None = all)
"""
if model_name:
            # Clear this model's entries for every device (keys are "{model_name}_{device}");
            # the trailing underscore prevents matching other models that share a name prefix
            keys_to_remove = [k for k in self._model_cache if k.startswith(f"{model_name}_")]
for key in keys_to_remove:
del self._model_cache[key]
self.logger.info(f"Cleared cache for model: {model_name}")
else:
# Clear all cache
cache_size = len(self._model_cache)
self._model_cache.clear()
self.logger.info(f"Cleared all model cache ({cache_size} models)")
def unload_model(self):
"""
Unload current model and free memory
"""
        # Explicit None check: truth-testing a torch Sequential delegates to its length
        if self._loaded_model is not None:
model_name = self._model_name
# Clear from cache
if self._model_name and self._device:
cache_key = f"{self._model_name}_{self._device}"
self._model_cache.pop(cache_key, None)
# Clear references
self._loaded_model = None
self._model_name = None
self._device = None
# Force garbage collection
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.logger.info(f"Unloaded model: {model_name}")
# Global model loader instance (lazy singleton)
_model_loader: Optional[EmbeddingModelLoader] = None
def get_model_loader() -> EmbeddingModelLoader:
"""
Get global model loader instance (singleton)
Returns:
--------
{ EmbeddingModelLoader } : Model loader instance
"""
global _model_loader
if _model_loader is None:
_model_loader = EmbeddingModelLoader()
return _model_loader
def load_embedding_model(model_name: Optional[str] = None, device: Optional[str] = None) -> SentenceTransformer:
"""
Convenience function to load embedding model
Arguments:
----------
model_name { str } : Model name
device { str } : Device
Returns:
--------
{ SentenceTransformer } : Loaded model
"""
loader = get_model_loader()
return loader.load_model(model_name, device)
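# Minimal usage sketch (illustrative: "all-MiniLM-L6-v2" is an assumed public
# sentence-transformers checkpoint, not one this module pins or requires)
if __name__ == "__main__":
    loader = get_model_loader()
    model = loader.load_model(model_name = "all-MiniLM-L6-v2", device = "cpu")
    print(loader.get_model_info())
    # encode() returns a numpy array shaped (n_sentences, embedding_dim)
    print(model.encode(["hello world"]).shape)
    loader.unload_model()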