""" Model Manager for handling local model loading and inference. """ import os import sys import torch from typing import Dict, Any, Optional, List, Union from transformers import ( AutoModelForCausalLM, AutoTokenizer, AutoModel, pipeline, BitsAndBytesConfig ) from sentence_transformers import SentenceTransformer # Add parent directories to path current_dir = os.path.dirname(os.path.abspath(__file__)) src_dir = os.path.dirname(current_dir) root_dir = os.path.dirname(src_dir) for path in [root_dir, src_dir]: if path not in sys.path: sys.path.insert(0, path) try: from config.model_config import get_model_path, get_model_config, ModelConfig except ImportError: # Fallback import paths try: from src.config.model_config import get_model_path, get_model_config, ModelConfig except ImportError: # Define minimal fallback config from dataclasses import dataclass from typing import Literal ModelType = Literal["text-generation", "text-embedding", "vision", "multimodal"] DeviceType = Literal["auto", "cpu", "cuda"] @dataclass class ModelConfig: model_id: str model_path: str model_type: ModelType device: DeviceType = "auto" quantize: bool = True use_safetensors: bool = True trust_remote_code: bool = True description: str = "" size_gb: float = 0.0 recommended: bool = False # Fallback model configurations DEFAULT_MODELS = { "tiny-llama": ModelConfig( model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0", model_path="./models/tiny-llama-1.1b-chat", model_type="text-generation", quantize=False, # Disable quantization for HF spaces description="Very small and fast model, good for quick testing", size_gb=1.1, recommended=True ), "mistral-7b": ModelConfig( model_id="microsoft/DialoGPT-small", # Use smaller model for HF spaces model_path="./models/dialogpt-small", model_type="text-generation", quantize=False, description="Small conversational model", size_gb=0.5, recommended=True ) } def get_model_config(model_name: str) -> Optional[ModelConfig]: return DEFAULT_MODELS.get(model_name) def get_model_path(model_name: str) -> str: config = get_model_config(model_name) if not config: raise ValueError(f"Unknown model: {model_name}") # For HF spaces, use the model_id directly return config.model_id class ModelManager: """Manages loading and using local models.""" def __init__(self, device: str = None): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.models: Dict[str, Any] = {} self.tokenizers: Dict[str, Any] = {} self.pipelines: Dict[str, Any] = {} def load_model(self, model_name: str, **kwargs) -> Any: """Load a model by name.""" if model_name in self.models: return self.models[model_name] config = get_model_config(model_name) if not config: raise ValueError(f"Unknown model: {model_name}") model_path = get_model_path(model_name) # Load model based on type if config.model_type == "text-generation": model = self._load_text_generation_model(model_name, config, **kwargs) elif config.model_type == "text-embedding": model = self._load_embedding_model(model_name, config, **kwargs) else: raise ValueError(f"Unsupported model type: {config.model_type}") self.models[model_name] = model return model def _load_text_generation_model(self, model_name: str, config: ModelConfig, **kwargs): """Load a text generation model.""" model_path = get_model_path(model_name) try: # Try to load with quantization if supported if config.quantize and "gptq" in config.model_id.lower(): try: from auto_gptq import AutoGPTQForCausalLM model = AutoGPTQForCausalLM.from_quantized( model_path, device_map="auto", 
trust_remote_code=config.trust_remote_code, use_safetensors=config.use_safetensors, **kwargs ) except ImportError: # Fallback to regular loading if auto_gptq is not available model = AutoModelForCausalLM.from_pretrained( model_path, device_map="auto" if torch.cuda.is_available() else "cpu", trust_remote_code=config.trust_remote_code, **kwargs ) else: # Load full precision model model = AutoModelForCausalLM.from_pretrained( model_path, device_map="auto" if torch.cuda.is_available() else "cpu", trust_remote_code=config.trust_remote_code, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, **kwargs ) except Exception as e: print(f"Error loading model {model_name}: {e}") # Fallback to CPU loading model = AutoModelForCausalLM.from_pretrained( model_path, device_map="cpu", trust_remote_code=config.trust_remote_code, torch_dtype=torch.float32, **kwargs ) # Load tokenizer try: tokenizer = AutoTokenizer.from_pretrained( model_path, trust_remote_code=config.trust_remote_code ) except Exception as e: print(f"Error loading tokenizer for {model_name}: {e}") # Use a basic tokenizer as fallback tokenizer = AutoTokenizer.from_pretrained("gpt2") self.tokenizers[model_name] = tokenizer return model def _load_embedding_model(self, model_name: str, config: ModelConfig, **kwargs): """Load a text embedding model.""" model_path = get_model_path(model_name) # Use sentence-transformers for embedding models model = SentenceTransformer( model_path, device=self.device, **kwargs ) return model def generate_text( self, model_name: str, prompt: str, max_length: int = 512, temperature: float = 0.7, **generation_kwargs ) -> str: """Generate text using the specified model.""" if model_name not in self.models: self.load_model(model_name) model = self.models[model_name] tokenizer = self.tokenizers[model_name] # Encode the input inputs = tokenizer(prompt, return_tensors="pt").to(self.device) # Generate text with torch.no_grad(): outputs = model.generate( **inputs, max_length=max_length, temperature=temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id, **generation_kwargs ) # Decode and return the generated text generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) return generated_text def get_embeddings( self, model_name: str, texts: Union[str, List[str]], batch_size: int = 32, **kwargs ) -> torch.Tensor: """Get embeddings for the input texts.""" if model_name not in self.models: self.load_model(model_name) model = self.models[model_name] # Get embeddings using sentence-transformers if isinstance(texts, str): texts = [texts] embeddings = model.encode( texts, batch_size=batch_size, show_progress_bar=len(texts) > 1, convert_to_tensor=True, **kwargs ) return embeddings # Global model manager instance model_manager = ModelManager()
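

# Illustrative usage sketch (not part of the original module): it shows how the
# global model_manager would be exercised from the command line. It assumes the
# fallback "tiny-llama" entry above resolves to a model that can be downloaded
# in the current environment; no "text-embedding" model is defined in the
# fallback DEFAULT_MODELS, so that call is only indicated in a comment.
if __name__ == "__main__":
    # Text generation with the small fallback chat model
    reply = model_manager.generate_text(
        "tiny-llama",
        prompt="Briefly explain what a tokenizer does.",
        max_length=128,
        temperature=0.7,
    )
    print(reply)

    # Embeddings would work the same way once a "text-embedding" entry exists
    # in the model config, e.g.:
    # vectors = model_manager.get_embeddings("my-embedding-model", ["hello", "world"])
    # print(vectors.shape)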