File size: 8,644 Bytes
eff7d5f
 
7e90504
d287c8b
7e90504
 
eff7d5f
d287c8b
eff7d5f
 
 
7e90504
eff7d5f
 
 
 
7e90504
eff7d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e90504
d287c8b
7e90504
 
eff7d5f
d287c8b
 
7e90504
eff7d5f
7e90504
d287c8b
7e90504
 
eff7d5f
 
7e90504
 
d287c8b
7e90504
 
 
 
eff7d5f
 
7e90504
 
 
bdff161
7e90504
 
 
 
d287c8b
7e90504
 
d287c8b
7e90504
d287c8b
7e90504
 
d287c8b
7e90504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d287c8b
7e90504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d287c8b
7e90504
 
 
 
eff7d5f
 
7e90504
eff7d5f
7e90504
 
eff7d5f
7e90504
 
 
 
 
 
 
 
d287c8b
 
7e90504
 
 
 
 
d287c8b
7e90504
 
 
 
 
d287c8b
7e90504
 
 
 
 
 
 
 
 
 
eff7d5f
7e90504
 
d287c8b
7e90504
 
 
 
 
 
 
 
d287c8b
7e90504
d287c8b
7e90504
d287c8b
7e90504
 
d287c8b
7e90504
 
 
 
d287c8b
 
7e90504
 
 
 
 
eff7d5f
7e90504
 
 
 
d287c8b
eff7d5f
7e90504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eff7d5f
7e90504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eff7d5f
 
7e90504
 
 
 
eff7d5f
7e90504
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# model_manager.py
"""
Lazy-loading Llama-3.2-3B-Instruct with proper ZeroGPU context management.

KEY FIX: Each generate() call is wrapped with @spaces.GPU to ensure
the model is accessible during generation.
"""

import os
import torch
import logging
from typing import Optional, Iterator
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline as create_pipeline
)

# ZeroGPU support: the `spaces` package only exists on Hugging Face Spaces.
# Outside Spaces we install a stand-in whose GPU() factory returns an
# identity decorator, so `@spaces.GPU(duration=...)` is always harmless.
try:
    import spaces
    HF_SPACES_AVAILABLE = True
except ImportError:
    HF_SPACES_AVAILABLE = False

    class DummySpaces:
        """No-op replacement for the `spaces` module."""

        @staticmethod
        def GPU(duration=90):
            # Mirror spaces.GPU's decorator-factory shape; the wrapped
            # function is returned unchanged.
            def passthrough(fn):
                return fn
            return passthrough

    spaces = DummySpaces()

logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")


class LazyLlamaModel:
    """
    Singleton lazy-loading wrapper around Llama-3.2-3B-Instruct.

    Model components are loaded lazily *inside* @spaces.GPU-decorated
    methods so the GPU context is held for the entire generation
    (the documented ZeroGPU pattern).
    """

    _instance = None      # the one shared instance
    _initialized = False  # guards __init__ against re-running on each call

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.model_id = MODEL_ID
            self.token = HF_TOKEN

            # Deferred: created inside GPU-decorated calls, not here.
            self.tokenizer = None
            self.model = None
            self.pipeline = None

            LazyLlamaModel._initialized = True
            # FIX: was an f-string with no placeholders.
            logger.info("LazyLlamaModel initialized (model will load on first generate)")

    def _load_model_components(self):
        """
        Load tokenizer, quantized model, and pipeline.

        Called INSIDE @spaces.GPU decorated functions so the GPU context
        is maintained. Idempotent: returns early once everything exists.
        """
        # FIX: also check self.pipeline — previously a load that failed
        # after the model step would never create the pipeline on retry.
        if (self.model is not None
                and self.tokenizer is not None
                and self.pipeline is not None):
            return  # Already fully loaded in this context

        logger.info("=" * 60)
        logger.info("LOADING LLAMA-3.2-3B-INSTRUCT")
        logger.info("=" * 60)

        # Load tokenizer
        logger.info(f"Loading: {self.model_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            token=self.token,
            trust_remote_code=True
        )
        logger.info(f"✓ Tokenizer loaded: {type(self.tokenizer).__name__}")

        # Configure 4-bit quantization
        logger.info("Config: 4-bit NF4 quantization")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )

        # Load model with quantization
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=bnb_config,
            device_map="auto",
            token=self.token,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        logger.info(f"✓ Model loaded: {type(self.model).__name__}")

        # Create pipeline
        self.pipeline = create_pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device_map="auto"
        )
        logger.info("✓ Pipeline created and verified: TextGenerationPipeline")

        logger.info("=" * 60)
        logger.info("✅ MODEL LOADED & CACHED")
        logger.info(f"  Model: {self.model_id}")
        logger.info(f"  Tokenizer: {type(self.tokenizer).__name__}")
        logger.info(f"  Pipeline: {type(self.pipeline).__name__}")
        logger.info(f"  Memory: ~1GB VRAM")
        logger.info(f"  Context: 128K tokens")
        logger.info("=" * 60)

    def _build_prompt(self, system_prompt: str, user_message: str) -> str:
        """Apply the model's chat template to a system+user message pair."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    @spaces.GPU(duration=90)
    def generate(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7
    ) -> str:
        """
        Generate a complete response.

        The @spaces.GPU decorator keeps the model in GPU context for the
        whole call. Raises RuntimeError if the pipeline failed to load.
        """
        self._load_model_components()

        if self.pipeline is None:
            raise RuntimeError(
                "Pipeline is None after loading. This may be a ZeroGPU context issue. "
                "Check that _load_model_components() completed successfully."
            )

        prompt = self._build_prompt(system_prompt, user_message)

        # FIX: only pass `temperature` when sampling — transformers warns
        # (and ignores it) when do_sample=False.
        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "return_full_text": False,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature

        outputs = self.pipeline(prompt, **gen_kwargs)
        return outputs[0]['generated_text'].strip()

    @spaces.GPU(duration=90)
    def generate_streaming(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7
    ) -> Iterator[str]:
        """
        Generate a response, yielding text chunks as they are produced.

        FIX: the previous implementation called model.generate() once per
        token with max_new_tokens=1, re-running the forward pass over the
        whole growing sequence each step (quadratic, no KV-cache reuse).
        This uses the canonical TextIteratorStreamer pattern: one
        generate() call in a background thread, tokens consumed here.
        """
        from threading import Thread
        from transformers import TextIteratorStreamer

        self._load_model_components()

        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "Model is None after loading. This may be a ZeroGPU context issue."
            )

        prompt = self._build_prompt(system_prompt, user_message)

        # Tokenize and move to the model's device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # skip_prompt: only newly generated text is yielded;
        # skip_special_tokens: EOS etc. are not surfaced to the caller.
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        do_sample = temperature > 0
        gen_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature

        # generate() blocks, so it runs in a worker thread while this
        # generator drains the streamer.
        worker = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        worker.start()
        try:
            for new_text in streamer:
                if new_text:
                    yield new_text
        finally:
            worker.join()


# Module-level singleton handle
_model_instance = None


def get_model() -> LazyLlamaModel:
    """Return the process-wide model instance, creating it on first call."""
    global _model_instance
    if _model_instance is not None:
        return _model_instance
    _model_instance = LazyLlamaModel()
    return _model_instance


# Backwards-compatibility aliases, bound locally (no self-import)
get_shared_llama = get_model
MistralSharedAgent = LazyLlamaModel
LlamaSharedAgent = LazyLlamaModel

# NOTE: do not re-import these names from model_manager within this module
# (e.g. `from model_manager import get_model ...`) — that creates a
# circular import.