jdesiree committed on
Commit
82d9923
·
verified ·
1 Parent(s): 51b6648

Create shared_models.py

Browse files
Files changed (1) hide show
  1. shared_models.py +285 -0
shared_models.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # shared_models.py
2
+ """
3
+ Shared model manager for Mimir agents.
4
+ Uses Llama-3.2-3B-Instruct with transformers for all agents.
5
+ """
6
+ import torch
7
+ import threading
8
+ import logging
9
+ import os
10
+ from typing import Optional, List
11
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # ZeroGPU support
16
+ try:
17
+ import spaces
18
+ HF_SPACES_AVAILABLE = True
19
+ logger.info("✅ ZeroGPU (spaces) available")
20
+ except ImportError:
21
+ HF_SPACES_AVAILABLE = False
22
+ class DummySpaces:
23
+ @staticmethod
24
+ def GPU(duration=90):
25
+ def decorator(func):
26
+ return func
27
+ return decorator
28
+ spaces = DummySpaces()
29
+ logger.warning("⚠️ ZeroGPU not available - running without GPU allocation")
30
+
31
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
32
+
33
+ # Model configuration
34
+ LLAMA_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
35
+
36
+
37
+ class LlamaSharedAgent:
38
+ """
39
+ Singleton agent using Llama-3.2-3B-Instruct for all Mimir operations.
40
+ Thread-safe with ZeroGPU allocation management.
41
+
42
+ Used by:
43
+ - ToolDecisionAgent
44
+ - PromptRoutingAgents (all 4 agents)
45
+ - ThinkingAgents (all reasoning agents)
46
+ - ResponseAgent
47
+ """
48
+
49
+ _instance = None
50
+ _lock = threading.Lock()
51
+
52
+ def __new__(cls):
53
+ """Ensure only one instance exists (singleton pattern)"""
54
+ if cls._instance is None:
55
+ with cls._lock:
56
+ if cls._instance is None:
57
+ cls._instance = super().__new__(cls)
58
+ return cls._instance
59
+
60
+ def __init__(self):
61
+ """Initialize only once"""
62
+ if hasattr(self, '_initialized'):
63
+ return
64
+
65
+ self.pipe = None
66
+ self.tokenizer = None
67
+ self.model = None
68
+ self.model_loaded = False
69
+ self._initialized = True
70
+ logger.info("LlamaSharedAgent instance created (singleton)")
71
+
72
+ @spaces.GPU(duration=120)
73
+ def _ensure_loaded(self):
74
+ """
75
+ Load model with GPU allocation (ZeroGPU).
76
+ Only ONE @spaces.GPU decorator for Llama across entire app!
77
+ """
78
+ if self.model_loaded:
79
+ logger.info("✅ Llama-3.2-3B already loaded, reusing existing instance")
80
+ return
81
+
82
+ logger.info("="*60)
83
+ logger.info("LOADING SHARED LLAMA-3.2-3B-INSTRUCT")
84
+ logger.info("="*60)
85
+
86
+ try:
87
+ # 4-bit quantization config for memory efficiency
88
+ quantization_config = BitsAndBytesConfig(
89
+ load_in_4bit=True,
90
+ bnb_4bit_quant_type="nf4",
91
+ bnb_4bit_compute_dtype=torch.bfloat16,
92
+ bnb_4bit_use_double_quant=True,
93
+ )
94
+
95
+ logger.info(f"Loading model: {LLAMA_MODEL_ID}")
96
+ logger.info("Configuration: 4-bit NF4 quantization")
97
+
98
+ # Load tokenizer
99
+ self.tokenizer = AutoTokenizer.from_pretrained(
100
+ LLAMA_MODEL_ID,
101
+ token=HF_TOKEN,
102
+ trust_remote_code=True,
103
+ )
104
+
105
+ # Load model with quantization
106
+ self.model = AutoModelForCausalLM.from_pretrained(
107
+ LLAMA_MODEL_ID,
108
+ quantization_config=quantization_config,
109
+ device_map="auto",
110
+ token=HF_TOKEN,
111
+ trust_remote_code=True,
112
+ torch_dtype=torch.bfloat16,
113
+ )
114
+
115
+ # Create pipeline
116
+ self.pipe = pipeline(
117
+ "text-generation",
118
+ model=self.model,
119
+ tokenizer=self.tokenizer,
120
+ torch_dtype=torch.bfloat16,
121
+ device_map="auto",
122
+ )
123
+
124
+ self.model_loaded = True
125
+
126
+ logger.info("="*60)
127
+ logger.info("✅ SHARED LLAMA-3.2-3B LOADED SUCCESSFULLY")
128
+ logger.info(f" Model: {LLAMA_MODEL_ID}")
129
+ logger.info(f" Quantization: 4-bit NF4")
130
+ logger.info(f" Memory: ~1GB VRAM (vs 3.3GB GGUF)")
131
+ logger.info(" Context: 128K tokens")
132
+ logger.info(" This model will be reused by:")
133
+ logger.info(" - ToolDecisionAgent")
134
+ logger.info(" - PromptRoutingAgents (all 4 agents)")
135
+ logger.info(" - ThinkingAgents (all reasoning)")
136
+ logger.info(" - ResponseAgent (final responses)")
137
+ logger.info("="*60)
138
+
139
+ except Exception as e:
140
+ logger.error(f"Failed to load Llama-3.2-3B: {e}")
141
+ raise
142
+
143
+ def generate(
144
+ self,
145
+ system_prompt: str,
146
+ user_message: str,
147
+ max_tokens: int = 100,
148
+ temperature: float = 0.7,
149
+ stop_sequences: Optional[List[str]] = None
150
+ ) -> str:
151
+ """
152
+ Generate response using shared Llama-3.2-3B model.
153
+
154
+ Args:
155
+ system_prompt: System instruction
156
+ user_message: User query
157
+ max_tokens: Max tokens to generate
158
+ temperature: Sampling temperature
159
+ stop_sequences: Optional list of stop sequences (not used with pipeline)
160
+
161
+ Returns:
162
+ Generated text
163
+ """
164
+ # Ensure model is loaded (triggers @spaces.GPU only once)
165
+ self._ensure_loaded()
166
+
167
+ # Format messages using Llama 3.2 chat template (handled automatically)
168
+ messages = [
169
+ {"role": "system", "content": system_prompt},
170
+ {"role": "user", "content": user_message},
171
+ ]
172
+
173
+ try:
174
+ # Generate using pipeline
175
+ outputs = self.pipe(
176
+ messages,
177
+ max_new_tokens=max_tokens,
178
+ temperature=temperature,
179
+ do_sample=True,
180
+ top_p=0.9,
181
+ top_k=40,
182
+ repetition_penalty=1.15,
183
+ )
184
+
185
+ # Extract generated text (pipeline returns full conversation)
186
+ result = outputs[0]["generated_text"][-1]["content"]
187
+
188
+ return result.strip()
189
+
190
+ except Exception as e:
191
+ logger.error(f"Generation error: {e}")
192
+ return ""
193
+
194
+ def generate_streaming(
195
+ self,
196
+ system_prompt: str,
197
+ user_message: str,
198
+ max_tokens: int = 512,
199
+ temperature: float = 0.7,
200
+ ):
201
+ """
202
+ Generate response with streaming (for ResponseAgent).
203
+
204
+ Yields:
205
+ str: Generated text chunks
206
+ """
207
+ self._ensure_loaded()
208
+
209
+ messages = [
210
+ {"role": "system", "content": system_prompt},
211
+ {"role": "user", "content": user_message},
212
+ ]
213
+
214
+ try:
215
+ # Use TextIteratorStreamer for streaming
216
+ from transformers import TextIteratorStreamer
217
+ from threading import Thread
218
+
219
+ # Apply chat template
220
+ input_ids = self.tokenizer.apply_chat_template(
221
+ messages,
222
+ add_generation_prompt=True,
223
+ return_tensors="pt"
224
+ ).to(self.model.device)
225
+
226
+ streamer = TextIteratorStreamer(
227
+ self.tokenizer,
228
+ skip_prompt=True,
229
+ skip_special_tokens=True
230
+ )
231
+
232
+ generation_kwargs = dict(
233
+ input_ids=input_ids,
234
+ streamer=streamer,
235
+ max_new_tokens=max_tokens,
236
+ temperature=temperature,
237
+ do_sample=True,
238
+ top_p=0.9,
239
+ top_k=40,
240
+ repetition_penalty=1.15,
241
+ )
242
+
243
+ # Generate in separate thread
244
+ thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
245
+ thread.start()
246
+
247
+ # Yield generated text
248
+ for text in streamer:
249
+ yield text
250
+
251
+ except Exception as e:
252
+ logger.error(f"Streaming generation error: {e}")
253
+ yield ""
254
+
255
+ def get_model_info(self) -> dict:
256
+ """Get model information for diagnostics"""
257
+ return {
258
+ "status": "loaded" if self.model_loaded else "not_loaded",
259
+ "model_id": LLAMA_MODEL_ID,
260
+ "model_type": "llama-3.2-3b-instruct",
261
+ "quantization": "4-bit NF4",
262
+ "size_gb": 1.0,
263
+ "context_length": 128000,
264
+ "zerogpu_ready": True,
265
+ "transformers_pipeline": True,
266
+ "shared_instance": True,
267
+ }
268
+
269
+
270
+ # Global singleton instance
271
+ _shared_llama = None
272
+
273
+ def get_shared_llama() -> LlamaSharedAgent:
274
+ """Get or create the shared Llama-3.2-3B agent instance"""
275
+ global _shared_llama
276
+ if _shared_llama is None:
277
+ _shared_llama = LlamaSharedAgent()
278
+ return _shared_llama
279
+
280
+
281
+ # Backwards compatibility aliases
282
+ Qwen3SharedAgent = LlamaSharedAgent
283
+ MistralSharedAgent = LlamaSharedAgent
284
+ get_shared_qwen3 = get_shared_llama
285
+ get_shared_mistral = get_shared_llama