Spaces:
Sleeping
Sleeping
| """Model loading, caching, and memory management for ZeroGPU inference.""" | |
| import gc | |
| import logging | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Generator, Any | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| TextIteratorStreamer, | |
| ) | |
| from threading import Thread | |
| from config import get_config, should_quantize | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class LoadedModel:
    """Container for a loaded model and its tokenizer.

    Declared as a dataclass so that instances can be built from keyword
    arguments (``LoadedModel(model_id=..., model=..., tokenizer=...)``);
    a bare class holding only annotations would accept no constructor
    arguments at all.

    Attributes:
        model_id: Identifier the model was loaded under.
        model: The loaded model object.
        tokenizer: The tokenizer paired with ``model``.
        quantization: Name of the quantization level requested for the
            model; defaults to ``"none"``.
    """

    model_id: str
    model: Any
    tokenizer: Any
    quantization: str = "none"
# Global model cache (single model at a time due to memory constraints).
# Holds the most recently loaded model, or None when nothing is loaded;
# functions in this module rebind it through `global _current_model`.
_current_model: Optional[LoadedModel] = None
def get_quantization_config(quantization: str) -> Optional[BitsAndBytesConfig]:
    """Translate a quantization level name into a BitsAndBytes configuration.

    Args:
        quantization: Level name. ``"int8"`` selects 8-bit loading and
            ``"int4"`` selects 4-bit NF4 loading with double quantization
            and bfloat16 compute.

    Returns:
        The matching ``BitsAndBytesConfig``, or ``None`` for any other value.
    """
    if quantization == "int8":
        bnb_options = {"load_in_8bit": True}
    elif quantization == "int4":
        bnb_options = {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": torch.bfloat16,
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4",
        }
    else:
        # Unrecognised levels (including "none") mean "do not quantize".
        return None
    return BitsAndBytesConfig(**bnb_options)
def clear_gpu_memory() -> None:
    """Run the garbage collector and, when CUDA is available, empty its cache.

    The Python collector always runs first; the CUDA cache is then emptied
    and the device synchronized only if ``torch.cuda.is_available()``.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Same order as before: empty the cache, then synchronize.
        for cuda_step in (torch.cuda.empty_cache, torch.cuda.synchronize):
            cuda_step()
    logger.debug("GPU memory cleared")
def unload_model() -> None:
    """Drop the cached model, if any, and release its memory.

    Does nothing when no model is cached. Otherwise removes the ``model``
    and ``tokenizer`` attributes from the cached container, resets the
    module-level cache to ``None`` and calls ``clear_gpu_memory()``.
    """
    global _current_model

    if _current_model is None:
        return

    logger.info(f"Unloading model: {_current_model.model_id}")
    # Remove the attributes so the container no longer references them.
    for attr_name in ("model", "tokenizer"):
        delattr(_current_model, attr_name)
    _current_model = None
    clear_gpu_memory()
def load_model(
    model_id: str,
    quantization: Optional[str] = None,
    force_reload: bool = False,
) -> LoadedModel:
    """
    Load a model from HuggingFace Hub, reusing the cached one when possible.

    Only one model is cached at a time: any previously cached model is
    unloaded before the new one is loaded.

    Args:
        model_id: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        quantization: Force specific quantization ("none", "int8", "int4").
            If None, the value returned by ``should_quantize(model_id)``
            is used. An explicit value that differs from the cached
            model's quantization triggers a reload.
        force_reload: If True, reload even if already loaded

    Returns:
        LoadedModel with model and tokenizer ready for inference

    Raises:
        ValueError: If model_id is None or empty
    """
    global _current_model

    if not model_id:
        raise ValueError("model_id cannot be None or empty")

    # Check if already loaded. The cache is keyed on the model ID, but an
    # explicit quantization request must also match: otherwise a caller
    # asking for "int4" could silently get a cached unquantized model.
    cached = _current_model
    if not force_reload and cached is not None and cached.model_id == model_id:
        if quantization is None or quantization == cached.quantization:
            logger.debug(f"Model already loaded: {model_id}")
            return cached
        logger.info(
            "Reloading %s: cached quantization %r differs from requested %r",
            model_id,
            cached.quantization,
            quantization,
        )

    # Determine quantization
    if quantization is None:
        quantization = should_quantize(model_id)

    logger.info(f"Loading model: {model_id} (quantization: {quantization})")

    # Unload current model first so both never occupy memory together.
    unload_model()

    config = get_config()

    # NOTE(security): trust_remote_code=True lets the model repository run
    # its own Python code during loading; only load repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=config.hf_token,
        trust_remote_code=True,
    )

    # Ensure tokenizer has pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model with appropriate configuration
    quant_config = get_quantization_config(quantization)
    model_kwargs = {
        "token": config.hf_token,
        "trust_remote_code": True,
    }

    # On ZeroGPU, use device_map only when GPU is available
    # Otherwise load to CPU for local testing
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
        if quant_config is not None:
            model_kwargs["quantization_config"] = quant_config
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16
    else:
        # CPU mode - no quantization, float32
        if quant_config is not None:
            logger.warning(
                "CUDA is not available; ignoring %r quantization for %s",
                quantization,
                model_id,
            )
        model_kwargs["device_map"] = "cpu"
        model_kwargs["torch_dtype"] = torch.float32

    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

    # The requested level is recorded even in CPU mode, so a repeated
    # explicit request for it still hits the cache.
    _current_model = LoadedModel(
        model_id=model_id,
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
    )
    logger.info(f"Model loaded successfully: {model_id}")
    return _current_model
def get_current_model() -> Optional[LoadedModel]:
    """Get the currently loaded model, if any.

    Returns:
        The value of the module-level ``_current_model`` cache: a
        ``LoadedModel``, or ``None`` when nothing is cached. This function
        never loads a model itself.
    """
    return _current_model
def generate_text(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> str:
    """
    Generate text using the specified model.

    Args:
        model_id: HuggingFace model ID
        prompt: Input prompt (already formatted with chat template)
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature; values <= 0 switch to
            non-sampling (greedy) decoding
        top_p: Nucleus sampling probability (used only when sampling)
        top_k: Top-k sampling parameter (used only when sampling)
        repetition_penalty: Penalty for repeating tokens
        stop_sequences: Additional stop sequences; the response is cut at
            the earliest occurrence of any of them. Empty strings are ignored.

    Returns:
        Generated text (without the input prompt)

    Raises:
        ValueError: If ``max_new_tokens`` leaves no room for the prompt
            within the tokenizer's ``model_max_length``.
    """
    loaded = load_model(model_id)
    tokenizer = loaded.tokenizer

    # Truncate the prompt so prompt + new tokens fit the context limit.
    # Some tokenizers report an enormous "no limit" sentinel as
    # model_max_length; passing that on as max_length can overflow, so
    # truncation is only enabled for plausible limits.
    encode_kwargs = {"return_tensors": "pt"}
    context_limit = getattr(tokenizer, "model_max_length", None)
    if context_limit is not None and context_limit <= 1_000_000:
        prompt_budget = context_limit - max_new_tokens
        if prompt_budget < 1:
            raise ValueError(
                f"max_new_tokens ({max_new_tokens}) leaves no room for the "
                f"prompt within the model context limit ({context_limit})"
            )
        encode_kwargs.update(truncation=True, max_length=prompt_budget)
    inputs = tokenizer(prompt, **encode_kwargs)

    # Put the inputs on the model's own device instead of assuming the
    # default CUDA device; fall back to the old behavior if it is unknown.
    target_device = getattr(loaded.model, "device", None)
    if target_device is None and torch.cuda.is_available():
        target_device = "cuda"
    if target_device is not None:
        inputs = {name: tensor.to(target_device) for name, tensor in inputs.items()}

    # Build generation config
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling-only parameters are meaningless for greedy decoding
        # (temperature <= 0), so they are passed only when sampling.
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    with torch.no_grad():
        outputs = loaded.model.generate(**inputs, **gen_kwargs)

    # Decode only the new tokens
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Handle stop sequences: cut at the earliest match. Empty strings are
    # skipped because they match everywhere and carry no information.
    if stop_sequences:
        cut_points = [
            position
            for position in (response.find(stop) for stop in stop_sequences if stop)
            if position != -1
        ]
        if cut_points:
            response = response[: min(cut_points)]

    return response
def _stop_prefix_overlap(text: str, stops: list[str]) -> int:
    """Return the length of the longest suffix of *text* that is a proper prefix of a stop sequence."""
    longest = 0
    for stop in stops:
        # Only sizes larger than the best match so far can improve it.
        for size in range(min(len(stop) - 1, len(text)), longest, -1):
            if text.endswith(stop[:size]):
                longest = size
                break
    return longest


def generate_text_stream(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> Generator[str, None, None]:
    """
    Generate text using streaming output.

    Yields text chunks as they are generated. Generation runs in a
    background thread; that thread is asked to stop and is joined as soon
    as a stop sequence is found, the consumer closes the generator, or
    generation finishes.

    Args:
        model_id: HuggingFace model ID
        prompt: Input prompt (already formatted with chat template)
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature; values <= 0 switch to
            non-sampling (greedy) decoding
        top_p: Nucleus sampling probability (used only when sampling)
        top_k: Top-k sampling parameter (used only when sampling)
        repetition_penalty: Penalty for repeating tokens
        stop_sequences: Additional stop sequences. Output ends just before
            the earliest match; text that might be the start of a stop
            sequence is held back until it is known not to be one, so no
            part of a stop sequence is ever yielded. Empty strings are
            ignored.

    Yields:
        Non-empty pieces of generated text, in order.

    Raises:
        ValueError: If ``max_new_tokens`` leaves no room for the prompt
            within the tokenizer's ``model_max_length``.
        Exception: Any exception raised by ``model.generate`` in the
            background thread is re-raised once streaming ends.
    """
    # Function-scope imports: these names are not in the file-level imports.
    from threading import Event

    from transformers import StoppingCriteria, StoppingCriteriaList

    loaded = load_model(model_id)
    tokenizer = loaded.tokenizer

    # Truncate the prompt so prompt + new tokens fit the context limit.
    # Some tokenizers report an enormous "no limit" sentinel as
    # model_max_length; passing that on as max_length can overflow, so
    # truncation is only enabled for plausible limits.
    encode_kwargs = {"return_tensors": "pt"}
    context_limit = getattr(tokenizer, "model_max_length", None)
    if context_limit is not None and context_limit <= 1_000_000:
        prompt_budget = context_limit - max_new_tokens
        if prompt_budget < 1:
            raise ValueError(
                f"max_new_tokens ({max_new_tokens}) leaves no room for the "
                f"prompt within the model context limit ({context_limit})"
            )
        encode_kwargs.update(truncation=True, max_length=prompt_budget)
    inputs = tokenizer(prompt, **encode_kwargs)

    # Put the inputs on the model's own device instead of assuming the
    # default CUDA device; fall back to the old behavior if it is unknown.
    target_device = getattr(loaded.model, "device", None)
    if target_device is None and torch.cuda.is_available():
        target_device = "cuda"
    if target_device is not None:
        inputs = {name: tensor.to(target_device) for name, tensor in inputs.items()}

    # Create streamer
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Lets the consumer side tell the generation thread to stop early
    # instead of running on to max_new_tokens after a stop sequence.
    cancel = Event()

    class _CancelOnEvent(StoppingCriteria):
        """Stopping criterion that fires once ``cancel`` has been set."""

        def __call__(self, input_ids, scores, **kwargs):
            # One flag per batch row.
            return torch.full(
                (input_ids.shape[0],),
                cancel.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    # Build generation config
    do_sample = temperature > 0
    gen_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList([_CancelOnEvent()]),
    }
    if do_sample:
        # Sampling-only parameters are meaningless for greedy decoding
        # (temperature <= 0), so they are passed only when sampling.
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    failures = []

    def _run_generation() -> None:
        try:
            loaded.model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            # Without the end-of-stream signal the consumer loop below
            # would block forever waiting for more text.
            streamer.end()

    # Run generation in separate thread
    worker = Thread(target=_run_generation)
    worker.start()

    stops = [stop for stop in (stop_sequences or ()) if stop]
    accumulated = ""
    emitted = 0  # number of characters of `accumulated` already yielded

    try:
        for piece in streamer:
            accumulated += piece

            # A stop sequence can only begin in text not yet yielded,
            # so searching from `emitted` is sufficient.
            hits = [
                position
                for position in (accumulated.find(stop, emitted) for stop in stops)
                if position != -1
            ]
            if hits:
                cut = min(hits)
                if cut > emitted:
                    yield accumulated[emitted:cut]
                break

            # Hold back any tail that could still grow into a stop sequence.
            safe_end = len(accumulated) - _stop_prefix_overlap(accumulated, stops)
            if safe_end > emitted:
                yield accumulated[emitted:safe_end]
                emitted = safe_end
        else:
            # Stream ended without a stop match: flush the held-back tail.
            if len(accumulated) > emitted:
                yield accumulated[emitted:]
    finally:
        # Runs on normal completion, on a stop match, and when the
        # consumer closes the generator early.
        cancel.set()
        worker.join()

    if failures:
        raise failures[0]
# Tokenizer cache (separate from model cache, for ZeroGPU compatibility).
# Maps a model ID to the tokenizer loaded for it; filled by get_tokenizer().
_tokenizer_cache: dict = {}
def get_tokenizer(model_id: str):
    """
    Return the tokenizer for ``model_id``, loading and caching it on first use.

    Kept apart from model loading for ZeroGPU compatibility -
    tokenizers can be loaded outside GPU context.
    """
    try:
        return _tokenizer_cache[model_id]
    except KeyError:
        pass  # not cached yet: load it below

    hub_token = get_config().hf_token
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=hub_token,
        trust_remote_code=True,
    )

    # Fall back to the EOS token when no pad token is defined.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    _tokenizer_cache[model_id] = loaded_tokenizer
    return loaded_tokenizer
def apply_chat_template(
    model_id: str,
    messages: list[dict[str, str]],
    add_generation_prompt: bool = True,
) -> str:
    """
    Apply the model's chat template to format messages.

    The tokenizer's own chat template is used when it has one. If the
    tokenizer has no template configured, the messages are rendered with a
    simple ``"Role: content"`` layout instead.

    Args:
        model_id: HuggingFace model ID
        messages: List of message dicts with "role" and "content"
        add_generation_prompt: Whether to add the generation prompt

    Returns:
        Formatted prompt string

    Raises:
        ValueError: If the tokenizer has a chat template but rejects the
            messages with a ``ValueError``.
    """
    tokenizer = get_tokenizer(model_id)

    # The method exists on any tokenizer class that defines it, even when
    # no template is configured, so hasattr() alone cannot select the
    # fallback. Try the template and fall back only when the failure comes
    # from a missing template.
    if hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
        except ValueError:
            if getattr(tokenizer, "chat_template", None) is not None:
                # A template exists, so this is a real error for the caller.
                raise
            logger.warning(
                "Tokenizer for %s has no chat template; using simple formatting",
                model_id,
            )

    # Fallback: simple formatting. Messages with any other role are skipped.
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    prompt_parts = [
        f"{role_labels[msg['role']]}: {msg['content']}\n"
        for msg in messages
        if msg["role"] in role_labels
    ]

    if add_generation_prompt:
        prompt_parts.append("Assistant:")

    return "".join(prompt_parts)