"""Model loading, caching, and memory management for ZeroGPU inference.""" import gc import logging from dataclasses import dataclass, field from typing import Optional, Generator, Any import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer, ) from threading import Thread from config import get_config, should_quantize logger = logging.getLogger(__name__) @dataclass class LoadedModel: """Container for a loaded model and its tokenizer.""" model_id: str model: Any tokenizer: Any quantization: str = "none" # Global model cache (single model at a time due to memory constraints) _current_model: Optional[LoadedModel] = None def get_quantization_config(quantization: str) -> Optional[BitsAndBytesConfig]: """Get BitsAndBytes configuration for the specified quantization level.""" if quantization == "int8": return BitsAndBytesConfig(load_in_8bit=True) elif quantization == "int4": return BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) return None def clear_gpu_memory() -> None: """Clear GPU memory by running garbage collection and emptying CUDA cache.""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() logger.debug("GPU memory cleared") def unload_model() -> None: """Unload the currently loaded model and free memory.""" global _current_model if _current_model is not None: logger.info(f"Unloading model: {_current_model.model_id}") del _current_model.model del _current_model.tokenizer _current_model = None clear_gpu_memory() def load_model( model_id: str, quantization: Optional[str] = None, force_reload: bool = False, ) -> LoadedModel: """ Load a model from HuggingFace Hub. Args: model_id: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct") quantization: Force specific quantization ("none", "int8", "int4") If None, auto-determine based on model size force_reload: If True, reload even if already loaded Returns: LoadedModel with model and tokenizer ready for inference Raises: ValueError: If model_id is None or empty """ global _current_model if not model_id: raise ValueError("model_id cannot be None or empty") # Check if already loaded if not force_reload and _current_model is not None: if _current_model.model_id == model_id: logger.debug(f"Model already loaded: {model_id}") return _current_model # Determine quantization if quantization is None: quantization = should_quantize(model_id) logger.info(f"Loading model: {model_id} (quantization: {quantization})") # Unload current model first unload_model() config = get_config() # Load tokenizer tokenizer = AutoTokenizer.from_pretrained( model_id, token=config.hf_token, trust_remote_code=True, ) # Ensure tokenizer has pad token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load model with appropriate configuration quant_config = get_quantization_config(quantization) model_kwargs = { "token": config.hf_token, "trust_remote_code": True, } # On ZeroGPU, use device_map only when GPU is available # Otherwise load to CPU for local testing if torch.cuda.is_available(): model_kwargs["device_map"] = "auto" if quant_config is not None: model_kwargs["quantization_config"] = quant_config else: model_kwargs["torch_dtype"] = torch.bfloat16 else: # CPU mode - no quantization, float32 model_kwargs["device_map"] = "cpu" model_kwargs["torch_dtype"] = torch.float32 model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) _current_model = LoadedModel( model_id=model_id, model=model, tokenizer=tokenizer, quantization=quantization, ) logger.info(f"Model loaded successfully: {model_id}") return _current_model def get_current_model() -> Optional[LoadedModel]: """Get the currently loaded model, if any.""" return _current_model def generate_text( model_id: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95, top_k: int = 50, repetition_penalty: float = 1.1, stop_sequences: Optional[list[str]] = None, ) -> str: """ Generate text using the specified model. Args: model_id: HuggingFace model ID prompt: Input prompt (already formatted with chat template) max_new_tokens: Maximum tokens to generate temperature: Sampling temperature top_p: Nucleus sampling probability top_k: Top-k sampling parameter repetition_penalty: Penalty for repeating tokens stop_sequences: Additional stop sequences Returns: Generated text (without the input prompt) """ loaded = load_model(model_id) inputs = loaded.tokenizer( prompt, return_tensors="pt", truncation=True, max_length=loaded.tokenizer.model_max_length - max_new_tokens, ) if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} # Build generation config gen_kwargs = { "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, "do_sample": temperature > 0, "pad_token_id": loaded.tokenizer.pad_token_id, "eos_token_id": loaded.tokenizer.eos_token_id, } with torch.no_grad(): outputs = loaded.model.generate(**inputs, **gen_kwargs) # Decode only the new tokens input_length = inputs["input_ids"].shape[1] generated_tokens = outputs[0][input_length:] response = loaded.tokenizer.decode(generated_tokens, skip_special_tokens=True) # Handle stop sequences if stop_sequences: for stop_seq in stop_sequences: if stop_seq in response: response = response.split(stop_seq)[0] return response def generate_text_stream( model_id: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95, top_k: int = 50, repetition_penalty: float = 1.1, stop_sequences: Optional[list[str]] = None, ) -> Generator[str, None, None]: """ Generate text using streaming output. Yields tokens as they are generated. """ loaded = load_model(model_id) inputs = loaded.tokenizer( prompt, return_tensors="pt", truncation=True, max_length=loaded.tokenizer.model_max_length - max_new_tokens, ) if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} # Create streamer streamer = TextIteratorStreamer( loaded.tokenizer, skip_prompt=True, skip_special_tokens=True, ) # Build generation config gen_kwargs = { **inputs, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, "do_sample": temperature > 0, "pad_token_id": loaded.tokenizer.pad_token_id, "eos_token_id": loaded.tokenizer.eos_token_id, "streamer": streamer, } # Run generation in separate thread thread = Thread(target=loaded.model.generate, kwargs=gen_kwargs) thread.start() # Stream tokens accumulated = "" for token in streamer: accumulated += token # Check for stop sequences should_stop = False if stop_sequences: for stop_seq in stop_sequences: if stop_seq in accumulated: # Yield everything before the stop sequence before_stop = accumulated.split(stop_seq)[0] if before_stop: yield before_stop[len(accumulated) - len(token):] should_stop = True break if should_stop: break yield token thread.join() # Tokenizer cache (separate from model cache, for ZeroGPU compatibility) _tokenizer_cache: dict = {} def get_tokenizer(model_id: str): """ Get or load a tokenizer for the specified model. This is separate from model loading for ZeroGPU compatibility - tokenizers can be loaded outside GPU context. """ if model_id in _tokenizer_cache: return _tokenizer_cache[model_id] config = get_config() tokenizer = AutoTokenizer.from_pretrained( model_id, token=config.hf_token, trust_remote_code=True, ) # Ensure tokenizer has pad token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token _tokenizer_cache[model_id] = tokenizer return tokenizer def apply_chat_template( model_id: str, messages: list[dict[str, str]], add_generation_prompt: bool = True, ) -> str: """ Apply the model's chat template to format messages. Args: model_id: HuggingFace model ID messages: List of message dicts with "role" and "content" add_generation_prompt: Whether to add the generation prompt Returns: Formatted prompt string """ tokenizer = get_tokenizer(model_id) # Check if tokenizer has chat template if hasattr(tokenizer, "apply_chat_template"): return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=add_generation_prompt, ) # Fallback: simple formatting prompt_parts = [] for msg in messages: role = msg["role"] content = msg["content"] if role == "system": prompt_parts.append(f"System: {content}\n") elif role == "user": prompt_parts.append(f"User: {content}\n") elif role == "assistant": prompt_parts.append(f"Assistant: {content}\n") if add_generation_prompt: prompt_parts.append("Assistant:") return "".join(prompt_parts)