Spaces:
Sleeping
Sleeping
Patryk Studzinski
refactor: enhance model unloading and memory management for improved GPU efficiency
371aac9
| """ | |
| GPU-optimized Transformers implementation using bitsandbytes quantization. | |
| Automatically offloads to GPU if available, falls back to CPU gracefully. | |
| """ | |
| import os | |
| import asyncio | |
| import traceback | |
| from typing import List, Dict, Any, Optional | |
| from app.models.base_llm import BaseLLM | |
| try: | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| HAS_TRANSFORMERS = True | |
| except ImportError: | |
| HAS_TRANSFORMERS = False | |
| try: | |
| import bitsandbytes as bnb | |
| HAS_BITSANDBYTES = True | |
| except ImportError: | |
| HAS_BITSANDBYTES = False | |
| import torch | |
class TransformersModel(BaseLLM):
    """
    Wrapper for HuggingFace Transformers causal language models.

    Loads the model with ``AutoModelForCausalLM`` and, when CUDA and
    bitsandbytes are both available, with 8-bit quantization. Without a GPU
    the model is loaded in float32 on the CPU. Generated responses are kept
    in a small per-instance FIFO cache keyed by the request parameters.
    """

    def __init__(self, name: str, model_id: str, use_8bit: bool = True, device_map: str = "auto", enable_cpu_offload: bool = False):
        """
        Store the load configuration; the model itself is loaded by ``initialize``.

        Args:
            name: Display name, used as the prefix of every log line.
            model_id: Identifier passed to ``from_pretrained``.
            use_8bit: Request 8-bit bitsandbytes quantization. Reset to
                ``False`` during loading if it cannot be honoured.
            device_map: ``device_map`` used when the model is loaded without
                quantization.
            enable_cpu_offload: Allow fp32 CPU offload of quantized weights.
                Also enabled by a truthy ``TRANSFORMERS_ENABLE_CPU_OFFLOAD``
                environment variable.

        Raises:
            ImportError: If the ``transformers`` package is not installed.
        """
        super().__init__(name, model_id)
        self.use_8bit = use_8bit
        self.device_map = device_map
        env_cpu_offload = os.getenv("TRANSFORMERS_ENABLE_CPU_OFFLOAD", "").strip().lower() in ("1", "true", "yes", "on")
        self.enable_cpu_offload = enable_cpu_offload or env_cpu_offload
        self.offload_dir = os.getenv("HF_OFFLOAD_DIR", "/tmp/hf-offload")
        self.pipeline = None
        self.tokenizer = None
        self.model = None
        # FIFO response cache: request key -> generated text.
        self._response_cache: Dict[str, str] = {}
        self._max_cache_size = 100
        if not HAS_TRANSFORMERS:
            raise ImportError("transformers is not installed. Cannot use Transformers models.")

    async def initialize(self) -> None:
        """
        Load the tokenizer and model once; later calls are no-ops.

        Relies on ``self._initialized``, which is not set in this class
        (presumably provided by ``BaseLLM`` - verify against the base class).

        Raises:
            RuntimeError: If loading fails; the original exception is chained.
        """
        if self._initialized:
            return
        try:
            print(f"[{self.name}] Initializing Transformers model: {self.model_id}")
            print(f"[{self.name}] Device map: {self.device_map}, 8-bit quantization: {self.use_8bit}")
            # Load in a worker thread to avoid blocking the event loop.
            await asyncio.to_thread(self._load_model)
            self._initialized = True
            print(f"[{self.name}] Transformers Model loaded successfully")
        except Exception as e:
            error_msg = str(e) or repr(e)
            print(f"[{self.name}] Failed to load Transformers model: {error_msg}")
            traceback.print_exc()
            raise RuntimeError(f"Failed to load Transformers model: {error_msg}") from e

    def _load_model(self) -> None:
        """
        Load tokenizer and model synchronously (runs in a worker thread).

        Sets ``self.tokenizer`` and ``self.model``. May reset ``self.use_8bit``
        to ``False`` when quantization cannot be applied.
        """
        import gc

        # Only set the allocator option when the user has not configured one.
        # NOTE(review): this likely has to happen before the CUDA allocator is
        # first used to take effect - confirm against the PyTorch docs.
        if not os.getenv("PYTORCH_CUDA_ALLOC_CONF"):
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            print(f"[{self.name}] Set PYTORCH_CUDA_ALLOC_CONF to prevent GPU memory fragmentation")

        # Release whatever a previously loaded model left behind.
        gc.collect()
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            torch.cuda.empty_cache()

        self._log_load_diagnostics(cuda_available)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        model_kwargs = self._build_model_kwargs(cuda_available)
        self.model = self._load_with_offload_retry(model_kwargs, cuda_available)

        # Log final state.
        model_device = next(self.model.parameters()).device
        quantization_status = "8-bit quantized" if self.use_8bit else "full precision"
        print(f"[{self.name}] Model loaded successfully")
        print(f"[{self.name}] Dtype: {self.model.dtype} | Quantization: {quantization_status}")
        print(f"[{self.name}] Device: {model_device}")

    def _log_load_diagnostics(self, cuda_available: bool) -> None:
        """Print CUDA availability and the requested load configuration."""
        cuda_device_count = torch.cuda.device_count() if cuda_available else 0
        device = "cuda" if cuda_available else "cpu"
        print(f"[{self.name}] === MODEL LOADING DIAGNOSTICS ===")
        print(f"[{self.name}] torch.cuda.is_available(): {cuda_available}")
        print(f"[{self.name}] torch.cuda.device_count(): {cuda_device_count}")
        if cuda_available:
            # Purely informational; a failing device query must not abort loading.
            try:
                print(f"[{self.name}] Current CUDA device: {torch.cuda.current_device()}")
                print(f"[{self.name}] CUDA device name: {torch.cuda.get_device_name(0)}")
            except Exception as e:
                print(f"[{self.name}] Could not query CUDA device details: {e}")
        print(f"[{self.name}] ===================================")
        print(f"[{self.name}] Loading model: {self.model_id}")
        print(f"[{self.name}] Device to use: {device}")
        print(f"[{self.name}] Device map: {self.device_map}")
        print(f"[{self.name}] 8-bit quantization requested: {self.use_8bit}")

    def _build_model_kwargs(self, cuda_available: bool) -> Dict[str, Any]:
        """
        Build the keyword arguments for ``AutoModelForCausalLM.from_pretrained``.

        Resets ``self.use_8bit`` to ``False`` whenever quantization was
        requested but cannot be applied, so later code and logs reflect how
        the model is really loaded.
        """
        # Use float16 for GPU, float32 for CPU.
        dtype = torch.float16 if cuda_available else torch.float32
        # Models with "11b" in their id or name are treated as large.
        is_large_model = "11b" in self.model_id.lower() or "11b" in self.name.lower()
        cpu_offload_enabled = self.enable_cpu_offload or is_large_model

        model_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
        }

        if self.use_8bit and HAS_BITSANDBYTES and cuda_available:
            try:
                print(f"[{self.name}] Using 8-bit quantization for memory efficiency")
                quant_kwargs: Dict[str, Any] = {
                    "quantization_config": BitsAndBytesConfig(
                        load_in_8bit=True,
                        llm_int8_enable_fp32_cpu_offload=cpu_offload_enabled,
                    ),
                    "device_map": "auto",
                }
                if cpu_offload_enabled:
                    os.makedirs(self.offload_dir, exist_ok=True)
                    quant_kwargs["offload_folder"] = self.offload_dir
            except Exception as e:
                print(f"[{self.name}] Failed to setup 8-bit quantization: {e}")
                print(f"[{self.name}] Falling back to full precision")
                self.use_8bit = False
                model_kwargs["device_map"] = self.device_map
            else:
                # Merge only after the whole setup succeeded, so a failed
                # setup never leaves a quantization config behind.
                model_kwargs.update(quant_kwargs)
        elif self.use_8bit and not cuda_available:
            # 8-bit quantization requested but no GPU available.
            print(f"[{self.name}] WARNING: 8-bit quantization requested but no GPU available")
            print(f"[{self.name}] Falling back to full precision on CPU (model may be very slow)")
            self.use_8bit = False
            model_kwargs["device_map"] = "cpu"
        else:
            # Either quantization is disabled, or it was requested but
            # bitsandbytes is not installed; in both cases load full precision.
            if self.use_8bit:
                self.use_8bit = False
            if self.use_8bit is not None:
                print(f"[{self.name}] bitsandbytes not available or quantization disabled - using full precision")
            if is_large_model and cuda_available:
                print(f"[{self.name}] WARNING: Loading large 11B model without quantization on GPU")
                print(f"[{self.name}] WARNING: This may cause out-of-memory errors on 16GB GPUs")
                print(f"[{self.name}] WARNING: Consider enabling use_8bit=True in registry.py")
                # Keep the whole model on the CPU rather than risk a GPU OOM.
                model_kwargs["device_map"] = "cpu"
            else:
                model_kwargs["device_map"] = self.device_map

        return model_kwargs

    def _load_with_offload_retry(self, model_kwargs: Dict[str, Any], cuda_available: bool):
        """
        Load the model, retrying once with fp32 CPU offload when needed.

        The retry only happens for a quantized GPU load whose ``ValueError``
        message says modules were "dispatched on the CPU or the disk"; any
        other error propagates unchanged.
        """
        try:
            return AutoModelForCausalLM.from_pretrained(
                self.model_id,
                **model_kwargs
            )
        except ValueError as e:
            should_retry_with_offload = (
                self.use_8bit
                and HAS_BITSANDBYTES
                and cuda_available
                and "dispatched on the cpu or the disk" in str(e).lower()
            )
            if not should_retry_with_offload:
                raise

            print(f"[{self.name}] Retrying load with explicit fp32 CPU offload")
            os.makedirs(self.offload_dir, exist_ok=True)
            retry_kwargs = dict(model_kwargs)
            retry_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
            )
            retry_kwargs["device_map"] = "auto"
            retry_kwargs["offload_folder"] = self.offload_dir
            try:
                # Cap GPU 0 at 90% of its total memory (at least 1 GiB).
                total_mem = torch.cuda.get_device_properties(0).total_memory
                gpu_gib = max(1, int((total_mem / (1024 ** 3)) * 0.9))
                retry_kwargs["max_memory"] = {0: f"{gpu_gib}GiB", "cpu": "64GiB"}
            except Exception as mem_err:
                # Best effort: without the cap the retry runs with defaults.
                print(f"[{self.name}] Could not determine GPU memory for max_memory: {mem_err}")
            return AutoModelForCausalLM.from_pretrained(
                self.model_id,
                **retry_kwargs
            )

    async def generate(
        self,
        prompt: Optional[str] = None,
        chat_messages: Optional[List[Dict[str, str]]] = None,
        max_new_tokens: int = 150,
        temperature: float = 0.7,
        top_p: float = 0.9,
        grammar: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Generate a completion for a raw prompt or a list of chat messages.

        ``chat_messages`` takes precedence over ``prompt`` when non-empty.
        Results are cached per (input, max_new_tokens, temperature, top_p),
        so a repeated request returns the earlier text without re-running
        the model. ``grammar`` and extra ``kwargs`` are accepted but not used
        for generation.

        Raises:
            RuntimeError: If the model has not been initialized.
            ValueError: If neither ``prompt`` nor ``chat_messages`` is given.
        """
        if not self._initialized or self.model is None:
            raise RuntimeError(f"[{self.name}] Model not initialized")

        prompt_text = self._build_prompt_from_messages(chat_messages) if chat_messages else prompt
        if not prompt_text:
            raise ValueError("Either prompt or chat_messages required")

        # Cache check.
        import json
        cache_key = f"{json.dumps(chat_messages or prompt_text)}_{max_new_tokens}_{temperature}_{top_p}"
        if cache_key in self._response_cache:
            return self._response_cache[cache_key]

        print("DEBUG: Generating with Transformers model", flush=True)
        if grammar:
            print("DEBUG: Note - GBNF grammar not supported in Transformers, using prompt engineering instead", flush=True)

        # Generate in a worker thread to avoid blocking the event loop.
        response_text = await asyncio.to_thread(
            self._generate_text,
            prompt_text,
            max_new_tokens,
            temperature,
            top_p
        )

        # Cache store: evict the oldest entry once the cache is full.
        if len(self._response_cache) >= self._max_cache_size:
            oldest_key = next(iter(self._response_cache))
            del self._response_cache[oldest_key]
        self._response_cache[cache_key] = response_text

        print(f"DEBUG: Extracted text: {response_text[:200]}", flush=True)
        return response_text

    def _build_prompt_from_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        Render chat messages as ``<|im_start|>role\\ncontent<|im_end|>\\n`` turns.

        A missing ``role`` defaults to ``"user"`` and a missing ``content`` to
        an empty string. An opening assistant turn is appended so the model
        continues as the assistant.
        """
        turns = [
            f"<|im_start|>{msg.get('role', 'user')}\n{msg.get('content', '')}<|im_end|>\n"
            for msg in messages
        ]
        turns.append("<|im_start|>assistant\n")
        return "".join(turns)

    def _generate_text(
        self,
        prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float
    ) -> str:
        """
        Tokenize ``prompt``, run ``model.generate`` and decode the new tokens.

        Samples with ``temperature``/``top_p`` when ``temperature`` is
        positive, and falls back to greedy decoding otherwise. Returns only
        the text generated after the prompt, stripped of surrounding
        whitespace.
        """
        model_inputs = self.tokenizer(prompt, return_tensors="pt")
        # Inputs must live on the same device as the model's first parameter.
        target_device = next(self.model.parameters()).device
        model_inputs = {key: tensor.to(target_device) for key, tensor in model_inputs.items()}

        # Some tokenizers define no pad token; reuse EOS in that case.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": pad_token_id,
            "use_cache": False,  # Disabled: KV cache causes degradation after ~50 requests
            "num_beams": 1,  # No beam search
        }
        if temperature is not None and temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            # Sampling requires a strictly positive temperature.
            generation_kwargs["do_sample"] = False

        try:
            with torch.no_grad():
                outputs = self.model.generate(**model_inputs, **generation_kwargs)
            # Decode only the tokens produced after the prompt.
            prompt_length = model_inputs["input_ids"].shape[1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )
        finally:
            # Release cached GPU blocks even when generation fails.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return generated_text.strip()

    def get_info(self) -> Dict[str, Any]:
        """
        Return model information for the /models endpoint.

        ``device`` and ``dtype`` are ``"unknown"`` while no model is loaded.
        The ``optimization`` and ``note`` entries are fixed strings and do not
        reflect the actual quantization state.
        """
        device = "unknown"
        dtype = "unknown"
        if self.model is not None:
            device = str(next(self.model.parameters()).device)
            dtype = str(self.model.dtype)
        return {
            "name": self.name,
            "model_id": self.model_id,
            "type": "transformers",
            "backend": "huggingface-transformers",
            "loaded": self._initialized,
            "device": device,
            "dtype": dtype,
            "optimization": "float16, KV cache disabled (prevents degradation), 8-bit quantization",
            "note": "KV cache disabled to prevent quality degradation after 50+ requests"
        }

    async def cleanup(self) -> None:
        """
        Drop the model, tokenizer, pipeline and cached responses.

        Afterwards runs the garbage collector and, when CUDA is available,
        empties the CUDA cache and resets the peak-memory statistics.
        """
        import gc

        # Drop every reference this instance holds so the collector can
        # reclaim the weights.
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self._response_cache.clear()
        self._initialized = False

        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            try:
                # Reset peak-memory statistics for the current device.
                torch.cuda.reset_peak_memory_stats(torch.cuda.current_device())
            except Exception as e:
                print(f"[{self.name}] Could not reset CUDA peak memory stats: {e}")

        print(f"[{self.name}] Transformers Model unloaded and memory freed")