# (Removed: "Spaces: Sleeping" status text — Hugging Face Spaces UI residue
#  captured during extraction; it was never part of the source file.)
# model_manager.py
"""
Lazy-loading Llama-3.2-3B-Instruct with proper ZeroGPU context management.

KEY FIX: Each generate() call is wrapped with @spaces.GPU to ensure
the model is accessible during generation.
"""
import logging
import os
from threading import Thread
from typing import Iterator, Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    pipeline as create_pipeline,
)
# ZeroGPU support: on Hugging Face Spaces the real `spaces` package supplies
# the @spaces.GPU decorator; everywhere else we substitute a no-op stub so
# the same code runs locally.
try:
    import spaces
    HF_SPACES_AVAILABLE = True
except ImportError:
    HF_SPACES_AVAILABLE = False

    class DummySpaces:
        """No-op stand-in for the `spaces` module outside HF Spaces."""

        @staticmethod
        def GPU(func=None, *, duration=90):
            """Accept both `@spaces.GPU` and `@spaces.GPU(duration=...)`.

            BUGFIX: the original defined GPU as an instance method without
            `self`, so `spaces.GPU(duration=...)` passed the instance as
            `duration` and raised TypeError.
            """
            def decorator(f):
                return f
            # Bare `@spaces.GPU` hands us the function directly.
            return decorator(func) if callable(func) else decorator

    spaces = DummySpaces()
# Module logger (lazy %-style args are used at call sites).
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# `or` (rather than a getenv default) so an *empty* HF_TOKEN also falls
# through to the legacy HUGGINGFACEHUB_API_TOKEN variable.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
class LazyLlamaModel:
    """
    Singleton lazy-loading wrapper for Llama-3.2-3B-Instruct.

    The tokenizer, 4-bit quantized model, and text-generation pipeline are
    created on first use inside `_load_model_components()` — not in
    `__init__` — so importing this module stays cheap and, on ZeroGPU
    Spaces, heavy CUDA work only happens once a GPU context is active.

    NOTE(review): the file's comments claim each generate() call is wrapped
    with @spaces.GPU, but no decorator is applied to these methods —
    presumably the calling UI handlers are decorated instead. Confirm at
    the call site before relying on GPU context here.
    """

    # Singleton bookkeeping shared by every constructor call.
    _instance = None
    _initialized = False

    def __new__(cls):
        # Classic singleton: always hand back the one shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Initialize only once; later LazyLlamaModel() calls are no-ops so
        # an already-loaded model is never discarded.
        if not self._initialized:
            self.model_id = MODEL_ID
            self.token = HF_TOKEN
            # Deliberately deferred — loaded in _load_model_components().
            self.tokenizer = None
            self.model = None
            self.pipeline = None
            LazyLlamaModel._initialized = True
            logger.info("LazyLlamaModel initialized (model will load on first generate)")

    def _load_model_components(self):
        """
        Load tokenizer, 4-bit quantized model, and generation pipeline.

        Intended to be called from inside a @spaces.GPU-decorated call so a
        GPU context is live while weights are placed. Idempotent: returns
        immediately when everything is already loaded.
        """
        if self.model is not None and self.tokenizer is not None:
            return  # Already loaded in this context

        logger.info("=" * 60)
        logger.info("LOADING LLAMA-3.2-3B-INSTRUCT")
        logger.info("=" * 60)

        # Tokenizer first — cheap, and required for the chat template.
        logger.info("Loading: %s", self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            token=self.token,
            trust_remote_code=True,
        )
        logger.info("✓ Tokenizer loaded: %s", type(self.tokenizer).__name__)

        # 4-bit NF4 double quantization keeps the 3B model in a small VRAM budget.
        logger.info("Config: 4-bit NF4 quantization")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=bnb_config,
            device_map="auto",
            token=self.token,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        logger.info("✓ Model loaded: %s", type(self.model).__name__)

        self.pipeline = create_pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device_map="auto",
        )
        logger.info("✓ Pipeline created and verified: TextGenerationPipeline")

        logger.info("=" * 60)
        logger.info("✅ MODEL LOADED & CACHED")
        logger.info("   Model: %s", self.model_id)
        logger.info("   Tokenizer: %s", type(self.tokenizer).__name__)
        logger.info("   Pipeline: %s", type(self.pipeline).__name__)
        logger.info("   Memory: ~1GB VRAM")  # NOTE(review): rough figure — verify
        logger.info("   Context: 128K tokens")
        logger.info("=" * 60)

    def _build_prompt(self, system_prompt: str, user_message: str) -> str:
        """Apply the model's chat template to a system+user message pair."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def _sampling_kwargs(self, temperature: float) -> dict:
        """Sampling settings; temperature <= 0 means greedy decoding.

        BUGFIX: the original passed temperature=0 together with
        do_sample=False, which transformers warns on/rejects — the
        temperature key is now omitted entirely in the greedy case.
        """
        if temperature > 0:
            return {"do_sample": True, "temperature": temperature}
        return {"do_sample": False}

    def generate(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7,
    ) -> str:
        """
        Generate a single (non-streaming) completion.

        Args:
            system_prompt: System-role content for the chat template.
            user_message: User-role content.
            max_tokens: Cap on newly generated tokens.
            temperature: Sampling temperature; <= 0 selects greedy decoding.

        Returns:
            The generated text with surrounding whitespace stripped.

        Raises:
            RuntimeError: If the pipeline is unavailable after loading.
        """
        self._load_model_components()

        if self.pipeline is None:
            raise RuntimeError(
                "Pipeline is None after loading. This may be a ZeroGPU context issue. "
                "Check that _load_model_components() completed successfully."
            )

        prompt = self._build_prompt(system_prompt, user_message)

        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "return_full_text": False,
        }
        gen_kwargs.update(self._sampling_kwargs(temperature))

        outputs = self.pipeline(prompt, **gen_kwargs)
        return outputs[0]["generated_text"].strip()

    def generate_streaming(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7,
    ) -> Iterator[str]:
        """
        Yield the completion incrementally as decoded text chunks.

        Uses transformers' TextIteratorStreamer with generation running on a
        background thread. This replaces the original one-token-at-a-time
        loop, which re-ran the full forward pass for every token
        (quadratic cost, no KV-cache reuse).

        Raises:
            RuntimeError: If the pipeline is unavailable after loading.
        """
        self._load_model_components()

        if self.pipeline is None:
            raise RuntimeError(
                "Pipeline is None after loading. This may be a ZeroGPU context issue."
            )

        prompt = self._build_prompt(system_prompt, user_message)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # skip_prompt: only newly generated text is yielded;
        # skip_special_tokens: EOS etc. never reach the caller.
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        gen_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        gen_kwargs.update(self._sampling_kwargs(temperature))

        # generate() blocks, so it runs on a worker thread while the
        # streamer feeds this generator.
        worker = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        worker.start()
        try:
            for chunk in streamer:
                if chunk:
                    yield chunk
        finally:
            worker.join()
# Module-level singleton instance, created lazily by get_model().
_model_instance = None


def get_model() -> LazyLlamaModel:
    """Return the process-wide model singleton, creating it on first call."""
    global _model_instance
    if _model_instance is None:
        _model_instance = LazyLlamaModel()
    return _model_instance


# Backwards-compatibility aliases (defined in the same module — no import).
# NOTE(review): `MistralSharedAgent` now points at a Llama class; it is kept
# only so older call sites keep working.
get_shared_llama = get_model
MistralSharedAgent = LazyLlamaModel
LlamaSharedAgent = LazyLlamaModel

# DO NOT re-import these names from this module
# ("from model_manager import ...") — that creates a circular import.