jdesiree commited on
Commit
eff7d5f
·
verified ·
1 Parent(s): 038e223

Create model_manager.py

Browse files
Files changed (1) hide show
  1. model_manager.py +270 -0
model_manager.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_manager.py
2
+ """
3
+ Lazy-loading model manager for Llama-3.2-3B-Instruct.
4
+ Model is loaded on first use and cached for subsequent calls.
5
+ """
6
+ import os
7
+ import torch
8
+ import logging
9
+ import threading
10
+ from typing import Optional, List
11
+ from transformers import (
12
+ AutoTokenizer,
13
+ AutoModelForCausalLM,
14
+ BitsAndBytesConfig,
15
+ TextIteratorStreamer,
16
+ pipeline
17
+ )
18
+
19
logger = logging.getLogger(__name__)

# ZeroGPU support: use the real `spaces` package when it is installed,
# otherwise install a no-op stand-in so @spaces.GPU stays usable.
try:
    import spaces
    HF_SPACES_AVAILABLE = True
    logger.info("✅ ZeroGPU (spaces) available")
except ImportError:
    HF_SPACES_AVAILABLE = False

    class DummySpaces:
        """No-op replacement: its GPU decorator returns the function unchanged."""

        @staticmethod
        def GPU(duration=90):
            # Accepts the same `duration` keyword as the real decorator,
            # but performs no GPU allocation.
            return lambda func: func

    spaces = DummySpaces()
    logger.warning("⚠️ ZeroGPU not available - running without GPU allocation")

# Configuration
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
LLAMA_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
40
+
41
+
42
class LazyLlamaModel:
    """
    Lazy-loading Llama-3.2-3B model with caching.
    Thread-safe singleton pattern - model loads on first use.

    The heavy weights are fetched on the first generate() /
    generate_streaming() call rather than at import time, then cached
    on the singleton for all subsequent calls.
    """

    _instance = None            # singleton storage
    _lock = threading.Lock()    # guards singleton construction

    def __new__(cls):
        # Double-checked locking: lock-free fast path once the instance
        # exists; the re-check under the lock stops two threads racing
        # to construct separate instances.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize only once (__init__ runs on every instantiation of a singleton)."""
        if self._initialized:
            return

        self.model = None      # AutoModelForCausalLM; set by _load_model()
        self.tokenizer = None  # AutoTokenizer; set by _load_model()
        self.pipe = None       # text-generation pipeline; assigned LAST in
                               # _load_model(), so it is the "fully loaded" sentinel
        self._initialized = True
        logger.info("LazyLlamaModel created (model not loaded yet)")

    @spaces.GPU(duration=120)
    def _load_model(self):
        """
        Load tokenizer, model (4-bit NF4 quantized) and pipeline.

        Called automatically on first use; idempotent once fully loaded.
        On failure, any partially-created state is rolled back so a later
        call retries from scratch instead of seeing a half-loaded model.

        Raises:
            Exception: re-raises whatever the underlying load raised.
        """
        # Check the pipeline, not the model: pipe is assigned last, so a
        # non-None pipe guarantees every component finished loading.
        if self.pipe is not None:
            return  # Already loaded

        logger.info("="*60)
        logger.info("LOADING LLAMA-3.2-3B-INSTRUCT (First Use)")
        logger.info("="*60)

        try:
            # 4-bit quantization for memory efficiency
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

            logger.info(f"Loading: {LLAMA_MODEL_ID}")
            logger.info("Config: 4-bit NF4 quantization")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                LLAMA_MODEL_ID,
                token=HF_TOKEN,
                trust_remote_code=True,
            )

            # Load model (torch_dtype sets the dtype of the non-quantized
            # modules; the quantized weights use the bnb compute dtype above)
            self.model = AutoModelForCausalLM.from_pretrained(
                LLAMA_MODEL_ID,
                quantization_config=quantization_config,
                device_map="auto",
                token=HF_TOKEN,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )

            # Create pipeline around the already-placed model
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

            logger.info("="*60)
            logger.info("✅ MODEL LOADED & CACHED")
            logger.info(f"   Model: {LLAMA_MODEL_ID}")
            logger.info(f"   Memory: ~1GB VRAM")
            logger.info(f"   Context: 128K tokens")
            logger.info("="*60)

        except Exception as e:
            # Roll back partial state (e.g. tokenizer loaded but model
            # failed) so the next call can retry cleanly.
            self.model = None
            self.tokenizer = None
            self.pipe = None
            logger.error(f"Failed to load model: {e}")
            raise

    def generate(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        stop_sequences: Optional[List[str]] = None
    ) -> str:
        """
        Generate response. Automatically loads model on first call.

        Args:
            system_prompt: System instruction
            user_message: User query
            max_tokens: Max tokens to generate
            temperature: Sampling temperature
            stop_sequences: Optional stop sequences (not used with pipeline)

        Returns:
            Generated text, or "" if generation failed.
        """
        # Lazy load on first use (pipe is the fully-loaded sentinel; the
        # pipeline itself is what this method actually calls)
        if self.pipe is None:
            self._load_model()

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        try:
            outputs = self.pipe(
                messages,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.15,
            )

            # Chat-style pipeline returns the full conversation; the last
            # message is the assistant reply.
            result = outputs[0]["generated_text"][-1]["content"]
            return result.strip()

        except Exception as e:
            logger.error(f"Generation error: {e}")
            return ""

    def generate_streaming(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ):
        """
        Generate response with streaming. Automatically loads model on first call.

        Yields:
            str: Generated text chunks
        """
        # Lazy load on first use (pipe is assigned last during loading,
        # so it is the reliable fully-loaded marker)
        if self.pipe is None:
            self._load_model()

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        try:
            # Apply chat template
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)

            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = dict(
                input_ids=input_ids,
                streamer=streamer,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.15,
            )

            # Generate in separate thread; the streamer yields chunks as
            # the model produces them.
            thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            # Yield generated text
            for text in streamer:
                yield text

            # Reap the worker once the streamer is exhausted so generation
            # threads do not accumulate across calls (original left the
            # thread dangling).
            thread.join()

        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield ""

    def is_loaded(self) -> bool:
        """Check if model weights are loaded."""
        return self.model is not None

    def get_model_info(self) -> dict:
        """Get static model metadata plus current load state."""
        return {
            "model_id": LLAMA_MODEL_ID,
            "loaded": self.is_loaded(),
            "quantization": "4-bit NF4",
            "size_gb": 1.0,
            "context_length": 128000,
            "lazy_loading": True,
        }
250
+
251
+
252
# Global instance - model loads on first use
_model_instance = None


def get_model() -> LazyLlamaModel:
    """
    Get the lazy-loading model instance.
    Model will automatically load on first generate() call.
    """
    global _model_instance
    # Guard clause: hand back the cached singleton when it already exists.
    if _model_instance is not None:
        return _model_instance
    _model_instance = LazyLlamaModel()
    return _model_instance


# Backwards compatibility aliases
get_shared_llama = get_model
get_shared_qwen3 = get_model
get_shared_mistral = get_model
LlamaSharedAgent = LazyLlamaModel