# Mimir / model_manager.py
# (Hub viewer metadata, kept as a comment: commit 7e90504, "Update model_manager.py" by jdesiree)
# model_manager.py
"""
Lazy-loading Llama-3.2-3B-Instruct with proper ZeroGPU context management.
KEY FIX: Each generate() call is wrapped with @spaces.GPU to ensure
the model is accessible during generation.
"""
import os
import torch
import logging
from typing import Optional, Iterator
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
pipeline as create_pipeline
)
# ZeroGPU support: on HF Spaces the real `spaces` package supplies the GPU
# decorator; everywhere else we substitute a no-op stand-in so the decorated
# methods below still work during local development and testing.
try:
    import spaces
    HF_SPACES_AVAILABLE = True
except ImportError:
    HF_SPACES_AVAILABLE = False

    class DummySpaces:
        """No-op replacement for the `spaces` module outside HF Spaces."""

        @staticmethod
        def GPU(func=None, duration=90):
            """
            Mimic `spaces.GPU`, usable both as `@GPU` and `@GPU(duration=...)`.

            FIX: the original fallback only supported the called form
            `@spaces.GPU(duration=90)`; the real decorator also supports bare
            `@spaces.GPU`, which would have silently broken locally.
            `duration` is accepted for signature compatibility and ignored.
            """
            def decorator(f):
                return f
            if callable(func):
                # Used as bare @spaces.GPU — `func` is the decorated function.
                return func
            return decorator

    spaces = DummySpaces()

logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# Either env var may carry the token; HF_TOKEN takes precedence.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
class LazyLlamaModel:
    """
    Singleton lazy-loading wrapper for Llama-3.2-3B-Instruct with ZeroGPU
    context management.

    Model components are created INSIDE the @spaces.GPU-decorated methods so
    that, on ZeroGPU Spaces, loading and generation both happen while a GPU
    is actually attached to the process.
    """
    _instance = None      # the single shared instance (set on first __new__)
    _initialized = False  # guards __init__ so it runs only once

    def __new__(cls):
        # Standard singleton: every construction returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.model_id = MODEL_ID
            self.token = HF_TOKEN
            # Deliberately deferred — loaded inside GPU-decorated methods.
            self.tokenizer = None
            self.model = None
            self.pipeline = None
            LazyLlamaModel._initialized = True
            # FIX: was an f-string with no placeholders.
            logger.info("LazyLlamaModel initialized (model will load on first generate)")

    def _load_model_components(self):
        """
        Load tokenizer, 4-bit quantized model, and text-generation pipeline.

        Must be called from inside a @spaces.GPU-decorated function so the GPU
        context spans the whole load + generate cycle. Idempotent: returns
        immediately if the components are already present.
        """
        if self.model is not None and self.tokenizer is not None:
            return  # Already loaded in this context
        logger.info("="*60)
        logger.info("LOADING LLAMA-3.2-3B-INSTRUCT")
        logger.info("="*60)
        # Load tokenizer
        logger.info(f"Loading: {self.model_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            token=self.token,
            trust_remote_code=True
        )
        logger.info(f"✓ Tokenizer loaded: {type(self.tokenizer).__name__}")
        # 4-bit NF4 quantization keeps the 3B model within ZeroGPU memory limits.
        logger.info("Config: 4-bit NF4 quantization")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )
        # Load model with quantization
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=bnb_config,
            device_map="auto",
            token=self.token,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        logger.info(f"✓ Model loaded: {type(self.model).__name__}")
        # Create pipeline
        self.pipeline = create_pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device_map="auto"
        )
        logger.info("✓ Pipeline created and verified: TextGenerationPipeline")
        logger.info("="*60)
        logger.info("✅ MODEL LOADED & CACHED")
        logger.info(f"   Model: {self.model_id}")
        logger.info(f"   Tokenizer: {type(self.tokenizer).__name__}")
        logger.info(f"   Pipeline: {type(self.pipeline).__name__}")
        logger.info(f"   Memory: ~1GB VRAM")
        logger.info(f"   Context: 128K tokens")
        logger.info("="*60)

    @staticmethod
    def _sampling_kwargs(temperature: float) -> dict:
        """
        Build decoding kwargs for generate()/pipeline().

        FIX: the original always forwarded `temperature`, even alongside
        do_sample=False; transformers warns on (and some versions reject)
        that combination. Temperature is only passed when actually sampling.
        """
        if temperature > 0:
            return {"do_sample": True, "temperature": temperature}
        return {"do_sample": False}

    def _build_prompt(self, system_prompt: str, user_message: str) -> str:
        """Render a system + user turn through the model's chat template."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    @spaces.GPU(duration=90)
    def generate(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7
    ) -> str:
        """
        Generate a complete response for one system/user turn.

        The @spaces.GPU decorator keeps the model in GPU context for the whole
        load + generate cycle (the original module's key ZeroGPU fix).

        :param system_prompt: system message for the chat template
        :param user_message: user message for the chat template
        :param max_tokens: cap on newly generated tokens
        :param temperature: 0 → greedy decoding, >0 → sampling
        :return: generated text, stripped of surrounding whitespace
        :raises RuntimeError: if the pipeline is missing after loading
        """
        self._load_model_components()
        if self.pipeline is None:
            raise RuntimeError(
                "Pipeline is None after loading. This may be a ZeroGPU context issue. "
                "Check that _load_model_components() completed successfully."
            )
        prompt = self._build_prompt(system_prompt, user_message)
        outputs = self.pipeline(
            prompt,
            max_new_tokens=max_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            return_full_text=False,
            **self._sampling_kwargs(temperature),
        )
        response = outputs[0]['generated_text']
        return response.strip()

    @spaces.GPU(duration=90)
    def generate_streaming(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7
    ) -> Iterator[str]:
        """
        Generate a response, yielding text deltas as tokens are produced.

        FIX: the original decoded from `inputs['input_ids'].shape[1]`, but
        `inputs` was rebuilt each iteration to include generated tokens, so
        after the first token the decoded string was only the newest token
        and the `len(...) > last_output_len` check silently dropped nearly
        all output. The prompt length is now captured once, before the loop.

        :raises RuntimeError: if the pipeline is missing after loading
        """
        self._load_model_components()
        if self.pipeline is None:
            raise RuntimeError(
                "Pipeline is None after loading. This may be a ZeroGPU context issue."
            )
        prompt = self._build_prompt(system_prompt, user_message)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # Capture once: everything past this offset is generated text.
        prompt_len = inputs['input_ids'].shape[1]
        eos_id = self.tokenizer.eos_token_id
        sampling = self._sampling_kwargs(temperature)
        emitted = 0  # chars of generated text already yielded
        with torch.no_grad():
            for _ in range(max_tokens):
                # NOTE: feeds the whole sequence back every step (no KV-cache
                # reuse across generate() calls) — O(n^2) overall, tolerable
                # for the few-hundred-token budgets used here.
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=1,
                    pad_token_id=eos_id,
                    eos_token_id=eos_id,
                    **sampling,
                )
                # Decode ALL generated tokens from the fixed prompt offset,
                # then emit only the not-yet-yielded suffix.
                decoded = self.tokenizer.decode(
                    outputs[0][prompt_len:],
                    skip_special_tokens=True
                )
                if len(decoded) > emitted:
                    yield decoded[emitted:]
                    emitted = len(decoded)
                # FIX: compare a plain int (.item()), not a 0-dim tensor.
                if outputs[0][-1].item() == eos_id:
                    break
                # Feed the extended sequence back in for the next token.
                inputs = {
                    'input_ids': outputs,
                    'attention_mask': torch.ones_like(outputs)
                }
# Module-level holder for the shared model wrapper.
_model_instance: Optional[LazyLlamaModel] = None


def get_model() -> LazyLlamaModel:
    """Return the process-wide LazyLlamaModel, creating it lazily on first call."""
    global _model_instance
    if _model_instance is None:
        _model_instance = LazyLlamaModel()
    return _model_instance
# Backwards compatibility aliases (within same module - no import).
# Plain assignments so older call sites keep working; presumably the
# *Mistral* name is left over from an earlier model revision — TODO confirm
# against the callers before removing.
get_shared_llama = get_model
MistralSharedAgent = LazyLlamaModel
LlamaSharedAgent = LazyLlamaModel
# DO NOT ADD THIS LINE - IT CAUSES CIRCULAR IMPORT:
# from model_manager import get_model as get_shared_llama, LazyLlamaModel as LlamaSharedAgent