import logging
import os
import torch
from typing import Optional, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

# huggingface_hub is optional; degrade gracefully when it is not installed
try:
    from huggingface_hub.exceptions import GatedRepoError
    from huggingface_hub import login as hf_login
except ImportError:
    # Fall back to a generic exception type and skip programmatic login
    GatedRepoError = Exception
    hf_login = None

# Settings may live in a package (relative import) or a flat module;
# run without them if neither import works
try:
    from .config import settings
except ImportError:
    try:
        from config import settings
    except ImportError:
        settings = None

logger = logging.getLogger(__name__)


class LocalModelLoader:
    """
    Loads and manages models locally on GPU for faster inference.
    Optimized for NVIDIA T4 Medium with 16 GB VRAM using 4-bit quantization.
    """

    def __init__(self, device: Optional[str] = None):
        """Initialize the model loader with GPU device detection."""
        # Auto-detect CUDA when no device is specified
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                self.device_name = torch.cuda.get_device_name(0)
                logger.info(f"GPU detected: {self.device_name}")
                logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                self.device_name = "CPU"
                logger.warning("No GPU detected, using CPU")
        else:
            self.device = device
            self.device_name = device

        # Resolve cache directory and HF token from settings, else environment
        if settings:
            self.cache_dir = settings.hf_cache_dir
            self.hf_token = settings.hf_token
        else:
            self.cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/tmp/huggingface"
            self.hf_token = os.getenv("HF_TOKEN", "")

        os.makedirs(self.cache_dir, exist_ok=True)

        # Point the HF libraries at the cache directory if not already set
        if not os.getenv("HF_HOME"):
            os.environ["HF_HOME"] = self.cache_dir
        if not os.getenv("TRANSFORMERS_CACHE"):
            os.environ["TRANSFORMERS_CACHE"] = self.cache_dir
| | logger.info(f"Cache directory: {self.cache_dir}") |
| | |
| | |
| | if self.hf_token and hf_login: |
| | try: |
| | hf_login(token=self.hf_token, add_to_git_credential=False) |
| | logger.info("✓ HF_TOKEN authenticated for gated model access") |
| | except Exception as e: |
| | logger.warning(f"HF_TOKEN login failed (may not be needed): {e}") |
| | |
| | |
| | self.loaded_models: Dict[str, Any] = {} |
| | self.loaded_tokenizers: Dict[str, Any] = {} |
| | self.loaded_embedding_models: Dict[str, Any] = {} |
| | |

    def load_chat_model(self, model_id: str, load_in_8bit: bool = False, load_in_4bit: bool = False) -> tuple:
        """
        Load a chat model and tokenizer on GPU.

        Args:
            model_id: HuggingFace model identifier
            load_in_8bit: Use 8-bit quantization (saves memory)
            load_in_4bit: Use 4-bit quantization (saves more memory)

        Returns:
            Tuple of (model, tokenizer)
        """
        if model_id in self.loaded_models:
            logger.info(f"Model {model_id} already loaded, reusing")
            return self.loaded_models[model_id], self.loaded_tokenizers[model_id]

        try:
            logger.info(f"Loading model {model_id} on {self.device}...")

            # Ids may carry an API routing suffix after a colon; strip it and
            # load the base repository
            base_model_id = model_id.split(':')[0]
            if base_model_id != model_id:
                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")

            # The tokenizer downloads first, so gated-repo errors usually
            # surface here
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    base_model_id,
                    cache_dir=self.cache_dir,
                    token=self.hf_token if self.hf_token else None,
                    trust_remote_code=True
                )
            except Exception as e:
                error_str = str(e).lower()
                if "gated" in error_str or "authorized" in error_str or "access" in error_str:
                    # Re-raise as a GatedRepoError with actionable guidance
                    try:
                        from huggingface_hub.exceptions import GatedRepoError as RealGatedRepoError
                        if isinstance(e, RealGatedRepoError):
                            logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
                            logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
                            logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
                            logger.error(f"   Error details: {e}")
                            raise RealGatedRepoError(
                                f"Cannot access gated repository {base_model_id}. "
                                f"Visit https://huggingface.co/{base_model_id} to request access."
                            ) from e
                    except ImportError:
                        pass

                raise

            # Build an optional quantization config (CUDA only)
            quantization_config = None
            if load_in_4bit and self.device == "cuda":
                try:
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4"
                    )
                    logger.info("Using 4-bit quantization")
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
            elif load_in_8bit and self.device == "cuda":
                try:
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                    logger.info("Using 8-bit quantization")
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
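
            # Rough sizing intuition (back-of-envelope assumption, not measured
            # here): a 7B-parameter model needs ~14 GB of weights in fp16 but
            # ~3.5 GB in NF4, which is why 4-bit fits a 16 GB T4 comfortably.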

            # Common from_pretrained kwargs
            load_kwargs = {
                "cache_dir": self.cache_dir,
                "token": self.hf_token if self.hf_token else None,
                "trust_remote_code": True
            }

            if self.device == "cuda":
                # fp16 halves weight memory relative to fp32
                load_kwargs["torch_dtype"] = torch.float16

            # First attempt: load with quantization, falling back to plain
            # fp16 if bitsandbytes misbehaves (version/hardware issues)
            model = None

            if quantization_config is not None and self.device == "cuda":
                try:
                    load_kwargs["quantization_config"] = quantization_config
                    # device_map="auto" lets accelerate place quantized weights
                    load_kwargs["device_map"] = "auto"

                    model = AutoModelForCausalLM.from_pretrained(
                        base_model_id,
                        **load_kwargs
                    )
                    logger.info("✓ Model loaded with quantization")
                except (RuntimeError, ModuleNotFoundError, ImportError) as e:
                    error_str = str(e).lower()
                    if "bitsandbytes" in error_str or "int8_mm_dequant" in error_str or "validate_bnb_backend" in error_str:
                        logger.warning(f"⚠ BitsAndBytes error detected: {e}")
                        logger.warning("⚠ Falling back to loading without quantization")
                        # Strip quantization kwargs before retrying; device_map
                        # must go too, or the .to(device) call below would fail
                        load_kwargs.pop("quantization_config", None)
                        load_kwargs.pop("device_map", None)
                    else:
                        raise

            # Second attempt (or first, if quantization was never requested)
            if model is None:
                try:
                    if self.device == "cuda":
                        model = AutoModelForCausalLM.from_pretrained(
                            base_model_id,
                            **load_kwargs
                        )
                        model = model.to(self.device)
                        logger.info(f"✓ Model loaded without quantization on {self.device}")
                    else:
                        # CPU path: fp32 is the safe default
                        load_kwargs["torch_dtype"] = torch.float32
                        model = AutoModelForCausalLM.from_pretrained(
                            base_model_id,
                            **load_kwargs
                        )
                        model = model.to(self.device)
                except Exception as e:
                    error_str = str(e).lower()
                    if "bitsandbytes" in error_str or "int8_mm_dequant" in error_str:
                        # bitsandbytes should not be involved on this path
                        logger.error(f"❌ Unexpected BitsAndBytes error: {e}")
                        raise RuntimeError(f"BitsAndBytes compatibility issue: {e}") from e

                    # Surface gated-repo failures with actionable guidance
                    try:
                        from huggingface_hub.exceptions import GatedRepoError as RealGatedRepoError
                        if isinstance(e, RealGatedRepoError) or "gated" in error_str or "authorized" in error_str:
                            logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
                            logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
                            logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
                            logger.error(f"   Error details: {e}")
                            raise RealGatedRepoError(
                                f"Cannot access gated repository {base_model_id}. "
                                f"Visit https://huggingface.co/{base_model_id} to request access."
                            ) from e
                    except ImportError:
                        pass

                    raise

            # Many causal LMs ship without a pad token; reuse EOS
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Cache under the original (possibly suffixed) id so reuse works
            self.loaded_models[model_id] = model
            self.loaded_tokenizers[model_id] = tokenizer

            if self.device == "cuda":
                allocated = torch.cuda.memory_allocated(0) / 1024**3
                reserved = torch.cuda.memory_reserved(0) / 1024**3
                logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

            logger.info(f"✓ Model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
            return model, tokenizer

        except GatedRepoError:
            # Already logged with guidance above
            raise
        except Exception as e:
            logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
            raise
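
    # Usage sketch (the model id is an illustrative assumption, not a
    # project requirement):
    #   loader = LocalModelLoader()
    #   model, tok = loader.load_chat_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    #                                       load_in_4bit=True)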

    def load_embedding_model(self, model_id: str) -> SentenceTransformer:
        """
        Load a sentence transformer model for embeddings.

        Args:
            model_id: HuggingFace model identifier

        Returns:
            SentenceTransformer model
        """
        if model_id in self.loaded_embedding_models:
            logger.info(f"Embedding model {model_id} already loaded, reusing")
            return self.loaded_embedding_models[model_id]

        try:
            logger.info(f"Loading embedding model {model_id}...")

            # Strip any ":suffix" API routing tag, as in load_chat_model
            base_model_id = model_id.split(':')[0]
            if base_model_id != model_id:
                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")

            # SentenceTransformer picks up the HF_HOME/TRANSFORMERS_CACHE
            # variables set in __init__, so no explicit cache_dir is passed
            try:
                model = SentenceTransformer(
                    base_model_id,
                    device=self.device
                )
            except GatedRepoError as e:
                logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
                logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
                logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
                logger.error(f"   Error details: {e}")
                raise GatedRepoError(
                    f"Cannot access gated repository {base_model_id}. "
                    f"Visit https://huggingface.co/{base_model_id} to request access."
                ) from e

            self.loaded_embedding_models[model_id] = model

            logger.info(f"✓ Embedding model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
            return model

        except GatedRepoError:
            # Propagate with the guidance already logged
            raise
        except Exception as e:
            logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
            raise
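
    # Usage sketch (illustrative model id):
    #   embedder = loader.load_embedding_model("sentence-transformers/all-MiniLM-L6-v2")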

    def generate_text(
        self,
        model_id: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate text using a loaded chat model.

        Args:
            model_id: Model identifier
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated text
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        try:
            inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

            generation_kwargs = {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                "do_sample": True,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
            }

            # Phi models hit a DynamicCache incompatibility with some
            # transformers versions; disabling the KV cache avoids it
            if "phi" in model_id.lower():
                generation_kwargs["use_cache"] = False
                logger.debug("Using use_cache=False for Phi models to avoid DynamicCache compatibility issues")

            # Caller-supplied kwargs take precedence
            generation_kwargs.update(kwargs)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    **generation_kwargs
                )

            # Decode only the newly generated tokens; decoding the full output
            # and string-stripping the prompt is unreliable
            prompt_length = inputs["input_ids"].shape[1]
            generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

            return generated_text.strip()

        except AttributeError as e:
            # transformers version mismatches can surface as DynamicCache
            # attribute errors; retry once with the KV cache disabled
            if "seen_tokens" in str(e) or "DynamicCache" in str(e):
                logger.warning(f"DynamicCache compatibility issue detected ({e}), retrying without cache")
                try:
                    with torch.no_grad():
                        outputs = model.generate(
                            **inputs,
                            max_new_tokens=max_tokens,
                            temperature=temperature,
                            do_sample=True,
                            use_cache=False,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            **{k: v for k, v in kwargs.items() if k != "use_cache"}
                        )
                    prompt_length = inputs["input_ids"].shape[1]
                    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
                    logger.info("✓ Generation successful after DynamicCache workaround")
                    return generated_text.strip()
                except Exception as retry_error:
                    logger.error(f"Retry without cache also failed: {retry_error}", exc_info=True)
                    raise RuntimeError(f"Generation failed even with cache disabled: {retry_error}") from retry_error

            raise
        except Exception as e:
            logger.error(f"Error generating text: {e}", exc_info=True)
            raise
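
    # Usage sketch (illustrative model id and prompt):
    #   text = loader.generate_text("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    #                               "Explain KV caching in one sentence.",
    #                               max_tokens=64)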
| | |
| | def generate_chat_completion( |
| | self, |
| | model_id: str, |
| | messages: list, |
| | max_tokens: int = 512, |
| | temperature: float = 0.7, |
| | **kwargs |
| | ) -> str: |
| | """ |
| | Generate chat completion using a loaded model. |
| | |
| | Args: |
| | model_id: Model identifier |
| | messages: List of message dicts with 'role' and 'content' |
| | max_tokens: Maximum tokens to generate |
| | temperature: Sampling temperature |
| | |
| | Returns: |
| | Generated response |
| | """ |
| | if model_id not in self.loaded_models: |
| | raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.") |

        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Prefer the tokenizer's own chat template when one is defined
            if getattr(tokenizer, "chat_template", None):
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                # Fallback: a simple role-prefixed transcript
                prompt = "\n".join(
                    f"{msg['role']}: {msg['content']}" for msg in messages
                ) + "\nassistant: "

            return self.generate_text(
                model_id=model_id,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                **kwargs
            )

        except Exception as e:
            logger.error(f"Error generating chat completion: {e}", exc_info=True)
            raise
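
    # Usage sketch (illustrative model id):
    #   reply = loader.generate_chat_completion(
    #       "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    #       [{"role": "user", "content": "Hello!"}],
    #   )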

    def get_embedding(self, model_id: str, text: str) -> list:
        """
        Get the embedding vector for a piece of text.

        Args:
            model_id: Embedding model identifier
            text: Input text

        Returns:
            Embedding vector as a list of floats
        """
        if model_id not in self.loaded_embedding_models:
            raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")

        model = self.loaded_embedding_models[model_id]

        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}", exc_info=True)
            raise
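
    # Usage sketch:
    #   vec = loader.get_embedding("sentence-transformers/all-MiniLM-L6-v2", "hello world")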

    def clear_cache(self):
        """Clear all loaded models from memory."""
        logger.info("Clearing model cache...")

        # Drop all Python references so the tensors become collectable
        self.loaded_models.clear()
        self.loaded_tokenizers.clear()
        self.loaded_embedding_models.clear()

        # Return cached CUDA blocks to the driver
        if self.device == "cuda":
            torch.cuda.empty_cache()

        logger.info("✓ Model cache cleared")

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get current GPU memory usage in GB."""
        if self.device != "cuda":
            return {"device": "cpu", "gpu_available": False}

        return {
            "device": self.device_name,
            "gpu_available": True,
            "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
        }
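

if __name__ == "__main__":
    # Smoke-test sketch. The model ids below are illustrative assumptions,
    # not project requirements; swap in whatever you actually deploy. The
    # 4-bit path additionally assumes bitsandbytes and accelerate are installed.
    logging.basicConfig(level=logging.INFO)
    loader = LocalModelLoader()

    chat_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # small, ungated example
    loader.load_chat_model(chat_id, load_in_4bit=True)
    print(loader.generate_chat_completion(
        chat_id,
        [{"role": "user", "content": "Say hello in five words."}],
        max_tokens=32,
    ))

    embed_id = "sentence-transformers/all-MiniLM-L6-v2"  # standard small embedder
    loader.load_embedding_model(embed_id)
    print(len(loader.get_embedding(embed_id, "hello world")))

    print(loader.get_memory_usage())
    loader.clear_cache()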