jdesiree committed on
Commit
3ed10cd
·
verified ·
1 Parent(s): 7ea174c

Delete shared_models.py

Browse files
Files changed (1) hide show
  1. shared_models.py +0 -285
shared_models.py DELETED
@@ -1,285 +0,0 @@
1
- # shared_models.py
2
- """
3
- Shared model manager for Mimir agents.
4
- Uses Llama-3.2-3B-Instruct with transformers for all agents.
5
- """
6
- import torch
7
- import threading
8
- import logging
9
- import os
10
- from typing import Optional, List
11
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
# ZeroGPU support: use the real `spaces` package when present; otherwise
# install a no-op stand-in so `@spaces.GPU(...)` decorators stay valid.
try:
    import spaces
    HF_SPACES_AVAILABLE = True
    logger.info("✅ ZeroGPU (spaces) available")
except ImportError:
    HF_SPACES_AVAILABLE = False

    class DummySpaces:
        """Drop-in replacement whose GPU decorator leaves functions untouched."""

        @staticmethod
        def GPU(duration=90):
            # Return an identity decorator: the wrapped function is unchanged.
            return lambda fn: fn

    spaces = DummySpaces()
    logger.warning("⚠️ ZeroGPU not available - running without GPU allocation")
-
31
# Hugging Face access token: prefer HF_TOKEN, falling back to the legacy
# HUGGINGFACEHUB_API_TOKEN variable; None when neither is set (gated model
# downloads will then fail unless credentials are cached).
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Model configuration
LLAMA_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
-
36
-
37
class LlamaSharedAgent:
    """
    Singleton agent using Llama-3.2-3B-Instruct for all Mimir operations.
    Thread-safe with ZeroGPU allocation management.

    Used by:
    - ToolDecisionAgent
    - PromptRoutingAgents (all 4 agents)
    - ThinkingAgents (all reasoning agents)
    - ResponseAgent
    """

    # Class-level singleton storage, guarded by _lock for first creation.
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Ensure only one instance exists (double-checked locking singleton)."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize only once; repeated calls on the singleton are no-ops."""
        if hasattr(self, '_initialized'):
            return

        self.pipe = None          # transformers text-generation pipeline (lazy)
        self.tokenizer = None     # AutoTokenizer, set by _ensure_loaded()
        self.model = None         # AutoModelForCausalLM, set by _ensure_loaded()
        self.model_loaded = False
        self._initialized = True
        logger.info("LlamaSharedAgent instance created (singleton)")

    @spaces.GPU(duration=120)
    def _ensure_loaded(self):
        """
        Load model with GPU allocation (ZeroGPU).
        Only ONE @spaces.GPU decorator for Llama across entire app!

        Raises:
            Exception: any model/tokenizer loading failure is logged and re-raised.
        """
        if self.model_loaded:
            logger.info("✅ Llama-3.2-3B already loaded, reusing existing instance")
            return

        logger.info("="*60)
        logger.info("LOADING SHARED LLAMA-3.2-3B-INSTRUCT")
        logger.info("="*60)

        try:
            # 4-bit quantization config for memory efficiency
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

            logger.info(f"Loading model: {LLAMA_MODEL_ID}")
            logger.info("Configuration: 4-bit NF4 quantization")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                LLAMA_MODEL_ID,
                token=HF_TOKEN,
                trust_remote_code=True,
            )

            # FIX: Llama 3.x ships with no dedicated pad token; without one,
            # pipeline()/generate() emit warnings and may pad with the wrong id.
            # Reusing EOS as pad is the standard workaround for causal LMs.
            # NOTE(review): confirm against the model's generation_config.
            if self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                LLAMA_MODEL_ID,
                quantization_config=quantization_config,
                device_map="auto",
                token=HF_TOKEN,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )

            # Create pipeline
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

            self.model_loaded = True

            logger.info("="*60)
            logger.info("✅ SHARED LLAMA-3.2-3B LOADED SUCCESSFULLY")
            logger.info(f"   Model: {LLAMA_MODEL_ID}")
            logger.info(f"   Quantization: 4-bit NF4")
            logger.info(f"   Memory: ~1GB VRAM (vs 3.3GB GGUF)")
            logger.info("   Context: 128K tokens")
            logger.info("   This model will be reused by:")
            logger.info("   - ToolDecisionAgent")
            logger.info("   - PromptRoutingAgents (all 4 agents)")
            logger.info("   - ThinkingAgents (all reasoning)")
            logger.info("   - ResponseAgent (final responses)")
            logger.info("="*60)

        except Exception as e:
            logger.error(f"Failed to load Llama-3.2-3B: {e}")
            raise

    def generate(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        stop_sequences: Optional[List[str]] = None
    ) -> str:
        """
        Generate response using shared Llama-3.2-3B model.

        Args:
            system_prompt: System instruction
            user_message: User query
            max_tokens: Max tokens to generate
            temperature: Sampling temperature
            stop_sequences: Optional list of stop sequences (not used with pipeline)

        Returns:
            Generated text, or "" on any generation error (best-effort contract:
            callers treat an empty string as "no answer").
        """
        # Ensure model is loaded (triggers @spaces.GPU only once)
        self._ensure_loaded()

        # Format messages using Llama 3.2 chat template (handled automatically)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        try:
            # Generate using pipeline
            outputs = self.pipe(
                messages,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.15,
                # FIX: explicit pad id (EOS) — Llama has no pad token.
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Extract generated text (pipeline returns full conversation;
            # the last message is the assistant's reply)
            result = outputs[0]["generated_text"][-1]["content"]

            return result.strip()

        except Exception as e:
            logger.error(f"Generation error: {e}")
            return ""

    def generate_streaming(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ):
        """
        Generate response with streaming (for ResponseAgent).

        Yields:
            str: Generated text chunks; a single "" chunk on error.
        """
        self._ensure_loaded()

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        try:
            # Use TextIteratorStreamer for streaming
            from transformers import TextIteratorStreamer
            from threading import Thread

            # Apply chat template
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)

            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = dict(
                input_ids=input_ids,
                streamer=streamer,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.15,
                # FIX: explicit pad id (EOS) — Llama has no pad token.
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Generate in separate thread so we can consume the streamer here
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            try:
                # Yield generated text
                for text in streamer:
                    yield text
            finally:
                # FIX: always reap the worker thread, even if the consumer
                # abandons the generator early; prevents a thread leak.
                thread.join()

        except Exception as e:
            logger.error(f"Streaming generation error: {e}")
            yield ""

    def get_model_info(self) -> dict:
        """Get model information for diagnostics (static metadata + load state)."""
        return {
            "status": "loaded" if self.model_loaded else "not_loaded",
            "model_id": LLAMA_MODEL_ID,
            "model_type": "llama-3.2-3b-instruct",
            "quantization": "4-bit NF4",
            "size_gb": 1.0,
            "context_length": 128000,
            "zerogpu_ready": True,
            "transformers_pipeline": True,
            "shared_instance": True,
        }
268
-
269
-
270
# Module-level handle to the process-wide agent (created lazily).
_shared_llama = None


def get_shared_llama() -> LlamaSharedAgent:
    """Return the shared Llama-3.2-3B agent, constructing it on first use."""
    global _shared_llama
    agent = _shared_llama
    if agent is None:
        agent = LlamaSharedAgent()
        _shared_llama = agent
    return agent
279
-
280
-
281
# Backwards compatibility aliases
# Earlier revisions of this module exposed Qwen3-/Mistral-named agents;
# they now all resolve to the shared Llama implementation so existing
# imports keep working.
Qwen3SharedAgent = LlamaSharedAgent
MistralSharedAgent = LlamaSharedAgent
get_shared_qwen3 = get_shared_llama
get_shared_mistral = get_shared_llama