# opencode-zerogpu / models.py
# Commit by serenichron: "Fix device handling: check GPU availability before device_map" (6d6c01e)
"""Model loading, caching, and memory management for ZeroGPU inference."""
import gc
import logging
from dataclasses import dataclass, field
from typing import Optional, Generator, Any
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
)
from threading import Thread
from config import get_config, should_quantize
logger = logging.getLogger(__name__)
@dataclass
class LoadedModel:
    """Container for a loaded model and its tokenizer."""
    # HuggingFace model ID this entry was loaded from.
    model_id: str
    # The loaded AutoModelForCausalLM instance (possibly quantized / device-mapped).
    model: Any
    # Matching AutoTokenizer; load_model() guarantees it has a pad token.
    tokenizer: Any
    # Quantization actually applied when loading: "none", "int8", or "int4".
    quantization: str = "none"
# Global model cache (single model at a time due to memory constraints).
# Set by load_model(), cleared by unload_model().
_current_model: Optional[LoadedModel] = None
def get_quantization_config(quantization: str) -> Optional[BitsAndBytesConfig]:
"""Get BitsAndBytes configuration for the specified quantization level."""
if quantization == "int8":
return BitsAndBytesConfig(load_in_8bit=True)
elif quantization == "int4":
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
return None
def clear_gpu_memory() -> None:
    """Reclaim memory: run the garbage collector, then drain the CUDA cache.

    Safe to call on CPU-only hosts — the CUDA steps are skipped entirely.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Return cached allocator blocks to the driver and wait for pending
        # kernels so the freed memory is actually visible afterwards.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    logger.debug("GPU memory cleared")
def unload_model() -> None:
    """Drop the cached model (if any) and reclaim its memory."""
    global _current_model
    if _current_model is None:
        return
    logger.info(f"Unloading model: {_current_model.model_id}")
    # Drop the strong references explicitly so collection can free them.
    del _current_model.model
    del _current_model.tokenizer
    _current_model = None
    clear_gpu_memory()
def load_model(
    model_id: str,
    quantization: Optional[str] = None,
    force_reload: bool = False,
) -> LoadedModel:
    """
    Load a model from HuggingFace Hub, reusing the in-memory cache when possible.

    Args:
        model_id: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        quantization: Force specific quantization ("none", "int8", "int4").
            If None, auto-determine based on model size.
        force_reload: If True, reload even if already loaded.

    Returns:
        LoadedModel with model and tokenizer ready for inference.

    Raises:
        ValueError: If model_id is None or empty.
    """
    global _current_model
    if not model_id:
        raise ValueError("model_id cannot be None or empty")
    # Cache hit requires the model ID to match AND, when the caller explicitly
    # requested a quantization level, that level to match too. (Previously only
    # model_id was compared, so asking for the same model at a different
    # quantization silently returned the stale cached copy.)
    if not force_reload and _current_model is not None:
        same_model = _current_model.model_id == model_id
        same_quant = quantization is None or _current_model.quantization == quantization
        if same_model and same_quant:
            logger.debug(f"Model already loaded: {model_id}")
            return _current_model
    # Auto-determine quantization when not forced by the caller.
    if quantization is None:
        quantization = should_quantize(model_id)
    logger.info(f"Loading model: {model_id} (quantization: {quantization})")
    # Only one model fits in memory: evict the current one before loading.
    unload_model()
    config = get_config()
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=config.hf_token,
        trust_remote_code=True,
    )
    # Many causal-LM tokenizers ship without a pad token; generation needs one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    quant_config = get_quantization_config(quantization)
    model_kwargs = {
        "token": config.hf_token,
        "trust_remote_code": True,
    }
    # On ZeroGPU, use device_map only when a GPU is actually visible;
    # otherwise load on CPU for local testing.
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
        if quant_config is not None:
            model_kwargs["quantization_config"] = quant_config
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16
    else:
        # CPU mode — bitsandbytes quantization needs CUDA, so use float32.
        model_kwargs["device_map"] = "cpu"
        model_kwargs["torch_dtype"] = torch.float32
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    _current_model = LoadedModel(
        model_id=model_id,
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
    )
    logger.info(f"Model loaded successfully: {model_id}")
    return _current_model
def get_current_model() -> Optional[LoadedModel]:
    """Return the model currently resident in the cache, or None when empty."""
    return _current_model
def generate_text(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> str:
    """
    Generate text using the specified model.

    Args:
        model_id: HuggingFace model ID
        prompt: Input prompt (already formatted with chat template)
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature; <= 0 selects greedy decoding
        top_p: Nucleus sampling probability (sampling mode only)
        top_k: Top-k sampling parameter (sampling mode only)
        repetition_penalty: Penalty for repeating tokens
        stop_sequences: Additional stop sequences

    Returns:
        Generated text (without the input prompt)
    """
    loaded = load_model(model_id)
    inputs = loaded.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        # Reserve room in the context window for the tokens we will generate.
        max_length=loaded.tokenizer.model_max_length - max_new_tokens,
    )
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": loaded.tokenizer.pad_token_id,
        "eos_token_id": loaded.tokenizer.eos_token_id,
    }
    # Only pass sampling parameters when actually sampling. Previously
    # temperature/top_p/top_k were always forwarded, and transformers
    # rejects temperature=0.0 (and warns about unused sampling flags) when
    # do_sample is False.
    if temperature > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
    else:
        gen_kwargs["do_sample"] = False  # greedy decoding
    with torch.no_grad():
        outputs = loaded.model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    response = loaded.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    # Truncate at any stop sequence that appears in the output.
    if stop_sequences:
        for stop_seq in stop_sequences:
            if stop_seq in response:
                response = response.split(stop_seq)[0]
    return response
def generate_text_stream(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> Generator[str, None, None]:
    """
    Generate text using streaming output.

    Yields text chunks as they are generated. When stop_sequences are given,
    a small tail of text is buffered so a stop sequence that straddles a
    token boundary is never partially emitted (the previous implementation
    could leak an already-yielded stop-sequence prefix, and crashed on an
    empty stop sequence via str.split("")).
    """
    loaded = load_model(model_id)
    inputs = loaded.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        # Reserve room in the context window for the tokens we will generate.
        max_length=loaded.tokenizer.model_max_length - max_new_tokens,
    )
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
    streamer = TextIteratorStreamer(
        loaded.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    gen_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": loaded.tokenizer.pad_token_id,
        "eos_token_id": loaded.tokenizer.eos_token_id,
        "streamer": streamer,
    }
    # Same parameter hygiene as generate_text(): forward sampling knobs only
    # when sampling, since transformers rejects temperature=0.0.
    if temperature > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
    else:
        gen_kwargs["do_sample"] = False  # greedy decoding
    # Generation runs in a worker thread; the streamer feeds this generator.
    thread = Thread(target=loaded.model.generate, kwargs=gen_kwargs)
    thread.start()
    try:
        if not stop_sequences:
            yield from streamer
            return
        # Hold back max(len(stop)) - 1 characters: the longest possible
        # incomplete stop-sequence prefix that could still be completed by
        # the next token.
        holdback = max(len(s) for s in stop_sequences) - 1
        pending = ""
        for token in streamer:
            pending += token
            hits = [pending.find(s) for s in stop_sequences if s in pending]
            if hits:
                # Emit everything before the earliest stop sequence, then end.
                head = pending[: min(hits)]
                if head:
                    yield head
                return
            if len(pending) > holdback:
                # Safe to flush all but the held-back tail.
                cut = len(pending) - holdback
                yield pending[:cut]
                pending = pending[cut:]
        # Stream ended without a stop sequence: flush the held-back tail.
        if pending:
            yield pending
    finally:
        # Always reap the worker thread, even if the consumer closes early.
        # NOTE(review): join() still waits for generation to finish — a
        # StoppingCriteria would be needed to abort it early; confirm whether
        # that matters for ZeroGPU time limits.
        thread.join()
# Tokenizer cache (separate from model cache, for ZeroGPU compatibility).
# Maps model_id -> tokenizer; populated lazily by get_tokenizer().
_tokenizer_cache: dict[str, Any] = {}
def get_tokenizer(model_id: str):
    """
    Get or load a tokenizer for the specified model.

    This is separate from model loading for ZeroGPU compatibility -
    tokenizers can be loaded outside GPU context.
    """
    cached = _tokenizer_cache.get(model_id)
    if cached is not None:
        return cached
    app_config = get_config()
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=app_config.hf_token,
        trust_remote_code=True,
    )
    # Guarantee a pad token; many causal-LM tokenizers ship without one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    _tokenizer_cache[model_id] = tokenizer
    return tokenizer
def apply_chat_template(
    model_id: str,
    messages: list[dict[str, str]],
    add_generation_prompt: bool = True,
) -> str:
    """
    Apply the model's chat template to format messages.

    Args:
        model_id: HuggingFace model ID
        messages: List of message dicts with "role" and "content"
        add_generation_prompt: Whether to add the generation prompt

    Returns:
        Formatted prompt string
    """
    tokenizer = get_tokenizer(model_id)
    # hasattr() was the wrong check here: modern tokenizers always define the
    # apply_chat_template *method*, so the fallback below was unreachable and
    # templateless tokenizers failed. Check that a template is configured.
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
    # Fallback: simple role-prefixed formatting. Unknown roles are skipped,
    # matching the previous behavior.
    prompt_parts = []
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        if role == "system":
            prompt_parts.append(f"System: {content}\n")
        elif role == "user":
            prompt_parts.append(f"User: {content}\n")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}\n")
    if add_generation_prompt:
        prompt_parts.append("Assistant:")
    return "".join(prompt_parts)