| """ | |
| Model Manager for handling local model loading and inference. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| from typing import Dict, Any, Optional, List, Union | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| AutoModel, | |
| pipeline, | |
| BitsAndBytesConfig | |
| ) | |
| from sentence_transformers import SentenceTransformer | |

# Add parent directories to the import path so the config package
# resolves regardless of the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.dirname(current_dir)
root_dir = os.path.dirname(src_dir)
for path in [root_dir, src_dir]:
    if path not in sys.path:
        sys.path.insert(0, path)

try:
    from config.model_config import get_model_path, get_model_config, ModelConfig
except ImportError:
    # Fallback import paths
    try:
        from src.config.model_config import get_model_path, get_model_config, ModelConfig
    except ImportError:
        # Define a minimal fallback config
        from dataclasses import dataclass
        from typing import Literal

        ModelType = Literal["text-generation", "text-embedding", "vision", "multimodal"]
        DeviceType = Literal["auto", "cpu", "cuda"]

        @dataclass  # required so ModelConfig(...) below accepts keyword arguments
        class ModelConfig:
            model_id: str
            model_path: str
            model_type: ModelType
            device: DeviceType = "auto"
            quantize: bool = True
            use_safetensors: bool = True
            trust_remote_code: bool = True
            description: str = ""
            size_gb: float = 0.0
            recommended: bool = False

        # Fallback model configurations
        DEFAULT_MODELS = {
            "tiny-llama": ModelConfig(
                model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                model_path="./models/tiny-llama-1.1b-chat",
                model_type="text-generation",
                quantize=False,  # Disable quantization for HF Spaces
                description="Very small and fast model, good for quick testing",
                size_gb=1.1,
                recommended=True
            ),
            "mistral-7b": ModelConfig(
                model_id="microsoft/DialoGPT-small",  # Use a smaller model for HF Spaces
                model_path="./models/dialogpt-small",
                model_type="text-generation",
                quantize=False,
                description="Small conversational model",
                size_gb=0.5,
                recommended=True
            )
        }

        def get_model_config(model_name: str) -> Optional[ModelConfig]:
            return DEFAULT_MODELS.get(model_name)

        def get_model_path(model_name: str) -> str:
            config = get_model_config(model_name)
            if not config:
                raise ValueError(f"Unknown model: {model_name}")
            # For HF Spaces, use the model_id directly
            return config.model_id


class ModelManager:
    """Manages loading and using local models."""

    def __init__(self, device: Optional[str] = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.models: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}
        self.pipelines: Dict[str, Any] = {}

    def load_model(self, model_name: str, **kwargs) -> Any:
        """Load a model by name, caching it for reuse."""
        if model_name in self.models:
            return self.models[model_name]

        config = get_model_config(model_name)
        if not config:
            raise ValueError(f"Unknown model: {model_name}")

        # Dispatch on model type; each loader resolves its own path.
        if config.model_type == "text-generation":
            model = self._load_text_generation_model(model_name, config, **kwargs)
        elif config.model_type == "text-embedding":
            model = self._load_embedding_model(model_name, config, **kwargs)
        else:
            raise ValueError(f"Unsupported model type: {config.model_type}")

        self.models[model_name] = model
        return model

    def _load_text_generation_model(self, model_name: str, config: ModelConfig, **kwargs):
        """Load a text generation model."""
        model_path = get_model_path(model_name)
        try:
            # Try quantized loading if the checkpoint is GPTQ-quantized
            if config.quantize and "gptq" in config.model_id.lower():
                try:
                    from auto_gptq import AutoGPTQForCausalLM
                    model = AutoGPTQForCausalLM.from_quantized(
                        model_path,
                        device_map="auto",
                        trust_remote_code=config.trust_remote_code,
                        use_safetensors=config.use_safetensors,
                        **kwargs
                    )
                except ImportError:
                    # Fall back to regular loading if auto_gptq is not available
                    model = AutoModelForCausalLM.from_pretrained(
                        model_path,
                        device_map="auto" if torch.cuda.is_available() else "cpu",
                        trust_remote_code=config.trust_remote_code,
                        **kwargs
                    )
            else:
                # Load the full-precision model (fp16 on GPU, fp32 on CPU)
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="auto" if torch.cuda.is_available() else "cpu",
                    trust_remote_code=config.trust_remote_code,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    **kwargs
                )
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            # Last-resort fallback: CPU loading
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="cpu",
                trust_remote_code=config.trust_remote_code,
                torch_dtype=torch.float32,
                **kwargs
            )

        # Load the tokenizer
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=config.trust_remote_code
            )
        except Exception as e:
            print(f"Error loading tokenizer for {model_name}: {e}")
            # Use a basic tokenizer as a fallback
            tokenizer = AutoTokenizer.from_pretrained("gpt2")

        self.tokenizers[model_name] = tokenizer
        return model

    def _load_embedding_model(self, model_name: str, config: ModelConfig, **kwargs):
        """Load a text embedding model."""
        model_path = get_model_path(model_name)
        # Use sentence-transformers for embedding models
        model = SentenceTransformer(
            model_path,
            device=self.device,
            **kwargs
        )
        return model

    def generate_text(
        self,
        model_name: str,
        prompt: str,
        max_length: int = 512,
        temperature: float = 0.7,
        **generation_kwargs
    ) -> str:
        """Generate text using the specified model.

        Note: the returned string includes the prompt, and ``max_length``
        counts prompt tokens as well as newly generated ones.
        """
        if model_name not in self.models:
            self.load_model(model_name)

        model = self.models[model_name]
        tokenizer = self.tokenizers[model_name]

        # Encode the input and move it to wherever the model actually landed
        # (device_map / the CPU fallback may differ from self.device).
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                **generation_kwargs
            )

        # Decode and return the generated text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text

    def get_embeddings(
        self,
        model_name: str,
        texts: Union[str, List[str]],
        batch_size: int = 32,
        **kwargs
    ) -> torch.Tensor:
        """Get embeddings for the input texts."""
        if model_name not in self.models:
            self.load_model(model_name)

        model = self.models[model_name]

        # sentence-transformers expects a list; wrap a single string
        if isinstance(texts, str):
            texts = [texts]

        # Get embeddings using sentence-transformers
        embeddings = model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=len(texts) > 1,
            convert_to_tensor=True,
            **kwargs
        )
        return embeddings


# Global model manager instance
model_manager = ModelManager()
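

# --- Usage sketch ---
# A minimal smoke test for the manager above. It assumes the "tiny-llama"
# entry (from the fallback DEFAULT_MODELS here, or an equivalent entry in
# config.model_config) resolves to a checkpoint that can be downloaded and
# loaded in the current environment; the prompt and parameter values below
# are illustrative, not part of any fixed API.
if __name__ == "__main__":
    manager = ModelManager()
    completion = manager.generate_text(
        "tiny-llama",
        prompt="Briefly explain what a tokenizer does.",
        max_length=128,   # counts prompt tokens too; kept small for a quick test
        temperature=0.7,
    )
    print(completion)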