jdesiree commited on
Commit
7e90504
·
verified ·
1 Parent(s): d357d93

Update model_manager.py

Browse files
Files changed (1) hide show
  1. model_manager.py +201 -872
model_manager.py CHANGED
@@ -1,108 +1,20 @@
1
  # model_manager.py
2
- # agents.py
3
  """
4
- Unified agent architecture for Mimir Educational AI Assistant.
5
 
6
- LAZY-LOADING LLAMA-3.2-3B-INSTRUCT
7
-
8
- Components:
9
- - LazyLlamaModel: Singleton lazy-loading model (loads on first use, cached thereafter)
10
- - ToolDecisionAgent: Uses lazy-loaded Llama for visualization decisions
11
- - PromptRoutingAgents: Uses lazy-loaded Llama for all 4 routing agents
12
- - ThinkingAgents: Uses lazy-loaded Llama for all reasoning (including math)
13
- - ResponseAgent: Uses lazy-loaded Llama for final responses
14
-
15
- Key optimization: Model loads on first generate() call and is cached for all
16
- subsequent requests. Single model architecture with ~1GB memory footprint.
17
- No compile or warmup scripts needed - fully automatic.
18
  """
19
 
20
  import os
21
- import re
22
  import torch
23
  import logging
24
- import time
25
- import subprocess
26
- import threading
27
- from datetime import datetime
28
- from typing import Dict, List, Optional, Tuple, Type
29
- import warnings
30
-
31
- # Setup main logger first
32
- logging.basicConfig(level=logging.INFO)
33
- logger = logging.getLogger(__name__)
34
-
35
- # ============================================================================
36
- # MEMORY PROFILING UTILITIES
37
- # ============================================================================
38
-
39
- def log_memory(tag=""):
40
- """Log current GPU memory usage"""
41
- try:
42
- if torch.cuda.is_available():
43
- allocated = torch.cuda.memory_allocated() / 1024**2
44
- reserved = torch.cuda.memory_reserved() / 1024**2
45
- max_allocated = torch.cuda.max_memory_allocated() / 1024**2
46
- logger.info(f"[{tag}] GPU Memory - Allocated: {allocated:.2f} MB, Reserved: {reserved:.2f} MB, Peak: {max_allocated:.2f} MB")
47
- else:
48
- logger.info(f"[{tag}] No CUDA available")
49
- except Exception as e:
50
- logger.warning(f"[{tag}] Error logging GPU memory: {e}")
51
-
52
-
53
- def log_nvidia_smi(tag=""):
54
- """Log full nvidia-smi output for system-wide GPU view"""
55
- try:
56
- output = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,noheader,nounits'], encoding='utf-8')
57
- logger.info(f"[{tag}] NVIDIA-SMI: {output.strip()}")
58
- except Exception as e:
59
- logger.warning(f"[{tag}] Error running nvidia-smi: {e}")
60
-
61
-
62
- def log_step(step_name, start_time=None):
63
- """Log a pipeline step with timestamp and duration"""
64
- now = time.time()
65
- timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
66
-
67
- if start_time:
68
- duration = now - start_time
69
- logger.info(f"[{timestamp}] ✓ {step_name} completed in {duration:.2f}s")
70
- else:
71
- logger.info(f"[{timestamp}] → {step_name} starting...")
72
-
73
- return now
74
-
75
-
76
- def profile_generation(model, tokenizer, inputs, **gen_kwargs):
77
- """Profile memory and time for model.generate() call"""
78
- torch.cuda.empty_cache()
79
- torch.cuda.reset_peak_memory_stats()
80
-
81
- log_memory("Before generate()")
82
- start_time = time.time()
83
-
84
- with torch.no_grad():
85
- outputs = model.generate(**inputs, **gen_kwargs)
86
-
87
- end_time = time.time()
88
- duration = end_time - start_time
89
- peak_memory = torch.cuda.max_memory_allocated() / 1024**2
90
-
91
- log_memory("After generate()")
92
- logger.info(f"Generation completed in {duration:.2f}s. Peak GPU: {peak_memory:.2f} MB")
93
-
94
- return outputs, duration
95
-
96
-
97
- # ============================================================================
98
- # IMPORTS
99
- # ============================================================================
100
-
101
- # Transformers for standard models
102
  from transformers import (
103
  AutoTokenizer,
104
  AutoModelForCausalLM,
105
  BitsAndBytesConfig,
 
106
  )
107
 
108
  # ZeroGPU support
@@ -119,823 +31,240 @@ except ImportError:
119
  return decorator
120
  spaces = DummySpaces()
121
 
122
- # Accelerate
123
- from accelerate import Accelerator
124
- from accelerate.utils import set_seed
125
-
126
- # LangChain Core for proper message handling
127
- from langchain_core.runnables import Runnable
128
- from langchain_core.runnables.utils import Input, Output
129
- from langchain_core.messages import SystemMessage, HumanMessage
130
-
131
- # Import ALL prompts from prompt library
132
- from prompt_library import (
133
- # System prompts
134
- CORE_IDENTITY,
135
- TOOL_DECISION,
136
- agent_1_system,
137
- agent_2_system,
138
- agent_3_system,
139
- agent_4_system,
140
-
141
- # Thinking agent system prompts
142
- MATH_THINKING,
143
- QUESTION_ANSWER_DESIGN,
144
- REASONING_THINKING,
145
-
146
- # Response agent prompts (dynamically applied)
147
- VAUGE_INPUT,
148
- USER_UNDERSTANDING,
149
- GENERAL_FORMATTING,
150
- LATEX_FORMATTING,
151
- GUIDING_TEACHING,
152
- STRUCTURE_PRACTICE_QUESTIONS,
153
- PRACTICE_QUESTION_FOLLOWUP,
154
- TOOL_USE_ENHANCEMENT,
155
- )
156
-
157
- # ============================================================================
158
- # MODEL MANAGER - LAZY LOADING
159
- # ============================================================================
160
- # Import the lazy-loading Llama-3.2-3B model manager
161
- from model_manager import get_model as get_shared_llama, LazyLlamaModel as LlamaSharedAgent
162
-
163
- # Backwards compatibility aliases
164
- get_shared_mistral = get_shared_llama
165
- MistralSharedAgent = LlamaSharedAgent
166
-
167
- # ============================================================================
168
- # CONFIGURATION
169
- # ============================================================================
170
 
171
- CACHE_DIR = "/tmp/compiled_models"
 
172
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
173
 
174
- # Suppress warnings
175
- warnings.filterwarnings("ignore", category=UserWarning)
176
- warnings.filterwarnings("ignore", category=FutureWarning)
177
-
178
- # Model info (for logging/diagnostics)
179
- LLAMA_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
180
-
181
-
182
- def check_model_cache() -> Dict[str, bool]:
183
- """Check model status (legacy function for compatibility)"""
184
- cache_status = {
185
- "llama": True, # Lazy-loaded on first use
186
- "all_compiled": True,
187
- }
188
-
189
- logger.info("✓ Llama-3.2-3B uses lazy loading (loads on first generate() call)")
190
-
191
- return cache_status
192
-
193
-
194
- # Call at module load
195
- _cache_status = check_model_cache()
196
- log_memory("Module load complete")
197
-
198
-
199
- # ============================================================================
200
- # TOOL DECISION AGENT
201
- # ============================================================================
202
 
203
- class ToolDecisionAgent:
204
  """
205
- Analyzes if visualization/graphing tools should be used.
206
 
207
- Uses lazy-loaded Llama-3.2-3B for decision-making.
208
- Model loads automatically on first use.
209
-
210
- Returns: Boolean (True = use tools, False = skip tools)
211
  """
212
 
213
- def __init__(self):
214
- """Initialize with lazy-loaded Llama model"""
215
- self.model = get_shared_llama()
216
- logger.info("ToolDecisionAgent initialized (using lazy-loaded Llama)")
217
-
218
- def decide(self, user_query: str, conversation_history: List[Dict]) -> bool:
219
- """
220
- Decide if graphing tools should be used.
221
-
222
- Args:
223
- user_query: Current user message
224
- conversation_history: Full conversation context
225
-
226
- Returns:
227
- bool: True if tools should be used
228
- """
229
- logger.info("→ ToolDecisionAgent: Analyzing query for tool usage")
230
-
231
- # Format conversation context
232
- context = "\n".join([
233
- f"{msg['role']}: {msg['content']}"
234
- for msg in conversation_history[-3:] # Last 3 turns
235
- ])
236
-
237
- # Decision prompt
238
- analysis_prompt = f"""Previous conversation:
239
- {context}
240
-
241
- Current query: {user_query}
242
-
243
- Should visualization tools (graphs, charts) be used?"""
244
-
245
- try:
246
- decision_start = time.time()
247
-
248
- # Use shared Llama for decision
249
- response = self.model.generate(
250
- system_prompt=TOOL_DECISION,
251
- user_message=analysis_prompt,
252
- max_tokens=10,
253
- temperature=0.1
254
- )
255
-
256
- decision_time = time.time() - decision_start
257
-
258
- # Parse decision
259
- decision = "YES" in response.upper()
260
-
261
- logger.info(f"✓ ToolDecision: {'USE TOOLS' if decision else 'NO TOOLS'} ({decision_time:.2f}s)")
262
-
263
- return decision
264
-
265
- except Exception as e:
266
- logger.error(f"ToolDecisionAgent error: {e}")
267
- return False # Default: no tools
268
-
269
-
270
- # ============================================================================
271
- # PROMPT ROUTING AGENTS (4 Specialized Agents)
272
- # ============================================================================
273
-
274
- class PromptRoutingAgents:
275
- """
276
- Four specialized agents for prompt segment selection.
277
- All share the same Llama-3.2-3B instance for efficiency.
278
 
279
- Agents:
280
- 1. Practice Question Detector
281
- 2. Discovery Mode Classifier
282
- 3. Follow-up Assessment
283
- 4. Teaching Mode Assessor
284
- """
285
 
286
  def __init__(self):
287
- """Initialize with shared Llama model"""
288
- self.model = get_shared_llama()
289
- logger.info("PromptRoutingAgents initialized (4 agents, shared Llama)")
290
-
291
- def agent_1_practice_question(
292
- self,
293
- user_query: str,
294
- conversation_history: List[Dict]
295
- ) -> bool:
296
- """Agent 1: Detect if practice questions should be generated"""
297
- logger.info("→ Agent 1: Analyzing for practice question opportunity")
298
-
299
- context = "\n".join([
300
- f"{msg['role']}: {msg['content']}"
301
- for msg in conversation_history[-4:]
302
- ])
303
-
304
- analysis_prompt = f"""Conversation:
305
- {context}
306
-
307
- New query: {user_query}
308
-
309
- Should I create practice questions?"""
310
-
311
- try:
312
- response = self.model.generate(
313
- system_prompt=agent_1_system,
314
- user_message=analysis_prompt,
315
- max_tokens=10,
316
- temperature=0.1
317
- )
318
-
319
- decision = "YES" in response.upper()
320
- logger.info(f"✓ Agent 1: {'PRACTICE QUESTIONS' if decision else 'NO PRACTICE'}")
321
-
322
- return decision
323
-
324
- except Exception as e:
325
- logger.error(f"Agent 1 error: {e}")
326
- return False
327
-
328
- def agent_2_discovery_mode(
329
- self,
330
- user_query: str,
331
- conversation_history: List[Dict]
332
- ) -> Tuple[bool, bool]:
333
- """Agent 2: Classify vague input and understanding level"""
334
- logger.info("→ Agent 2: Classifying discovery mode")
335
-
336
- context = "\n".join([
337
- f"{msg['role']}: {msg['content']}"
338
- for msg in conversation_history[-3:]
339
- ])
340
-
341
- analysis_prompt = f"""Conversation:
342
- {context}
343
-
344
- Query: {user_query}
345
-
346
- Classification:
347
- 1. Is input vague? (VAGUE/CLEAR)
348
- 2. Understanding level? (LOW/MEDIUM/HIGH)"""
349
-
350
- try:
351
- response = self.model.generate(
352
- system_prompt=agent_2_system,
353
- user_message=analysis_prompt,
354
- max_tokens=20,
355
- temperature=0.1
356
- )
357
-
358
- vague = "VAGUE" in response.upper()
359
- low_understanding = "LOW" in response.upper()
360
-
361
- logger.info(f"✓ Agent 2: Vague={vague}, LowUnderstanding={low_understanding}")
362
-
363
- return vague, low_understanding
364
-
365
- except Exception as e:
366
- logger.error(f"Agent 2 error: {e}")
367
- return False, False
368
-
369
- def agent_3_followup_assessment(
370
- self,
371
- user_query: str,
372
- conversation_history: List[Dict]
373
- ) -> bool:
374
- """Agent 3: Detect if user is responding to practice questions"""
375
- logger.info("→ Agent 3: Checking for practice question follow-up")
376
-
377
- # Check last bot message for practice question indicators
378
- if len(conversation_history) < 2:
379
- return False
380
-
381
- last_bot_msg = None
382
- for msg in reversed(conversation_history):
383
- if msg['role'] == 'assistant':
384
- last_bot_msg = msg['content']
385
- break
386
-
387
- if not last_bot_msg:
388
- return False
389
-
390
- # Look for practice question markers
391
- has_practice = any(marker in last_bot_msg.lower() for marker in [
392
- "practice", "try this", "solve", "calculate", "what is", "question"
393
- ])
394
-
395
- if not has_practice:
396
- return False
397
-
398
- # Analyze if current query is an answer attempt
399
- analysis_prompt = f"""Previous message (from me):
400
- {last_bot_msg[:500]}
401
-
402
- User response:
403
- {user_query}
404
-
405
- Is user answering a practice question?"""
406
-
407
- try:
408
- response = self.model.generate(
409
- system_prompt=agent_3_system,
410
- user_message=analysis_prompt,
411
- max_tokens=10,
412
- temperature=0.1
413
- )
414
-
415
- is_followup = "YES" in response.upper()
416
- logger.info(f"✓ Agent 3: {'GRADING MODE' if is_followup else 'NOT FOLLOWUP'}")
417
-
418
- return is_followup
419
-
420
- except Exception as e:
421
- logger.error(f"Agent 3 error: {e}")
422
- return False
423
-
424
- def agent_4_teaching_mode(
425
- self,
426
- user_query: str,
427
- conversation_history: List[Dict]
428
- ) -> Tuple[bool, bool]:
429
- """Agent 4: Assess teaching vs practice mode"""
430
- logger.info("→ Agent 4: Assessing teaching mode")
431
-
432
- context = "\n".join([
433
- f"{msg['role']}: {msg['content']}"
434
- for msg in conversation_history[-3:]
435
- ])
436
-
437
- analysis_prompt = f"""Conversation:
438
- {context}
439
-
440
- Query: {user_query}
441
-
442
- Assessment:
443
- 1. Need direct teaching? (TEACH/PRACTICE)
444
- 2. Create practice questions? (YES/NO)"""
445
-
446
- try:
447
- response = self.model.generate(
448
- system_prompt=agent_4_system,
449
- user_message=analysis_prompt,
450
- max_tokens=15,
451
- temperature=0.1
452
- )
453
-
454
- teaching = "TEACH" in response.upper()
455
- practice = "YES" in response.upper() or "PRACTICE" in response.upper()
456
-
457
- logger.info(f"✓ Agent 4: Teaching={teaching}, Practice={practice}")
458
-
459
- return teaching, practice
460
-
461
- except Exception as e:
462
- logger.error(f"Agent 4 error: {e}")
463
- return False, False
464
-
465
- def process(
466
- self,
467
- user_input: str,
468
- tool_used: bool = False,
469
- conversation_history: Optional[List[Dict]] = None
470
- ) -> Tuple[str, str]:
471
- """
472
- Unified process method - runs all 4 routing agents sequentially.
473
-
474
- Returns:
475
- Tuple[str, str]: (response_prompts, thinking_prompts)
476
- """
477
- if conversation_history is None:
478
- conversation_history = []
479
-
480
- response_prompts = []
481
- thinking_prompts = []
482
-
483
- # Agent 1: Practice Questions
484
- if self.agent_1_practice_question(user_input, conversation_history):
485
- response_prompts.append("STRUCTURE_PRACTICE_QUESTIONS")
486
 
487
- # Agent 2: Discovery Mode
488
- is_vague, low_understanding = self.agent_2_discovery_mode(user_input, conversation_history)
489
- if is_vague:
490
- response_prompts.append("VAUGE_INPUT")
491
- if low_understanding:
492
- response_prompts.append("USER_UNDERSTANDING")
493
 
494
- # Agent 3: Follow-up Assessment
495
- if self.agent_3_followup_assessment(user_input, conversation_history):
496
- response_prompts.append("PRACTICE_QUESTION_FOLLOWUP")
497
-
498
- # Agent 4: Teaching Mode
499
- needs_teaching, needs_practice = self.agent_4_teaching_mode(user_input, conversation_history)
500
- if needs_teaching:
501
- response_prompts.append("GUIDING_TEACHING")
502
-
503
- # Always add base formatting
504
- response_prompts.extend(["GENERAL_FORMATTING", "LATEX_FORMATTING"])
505
-
506
- # Tool enhancement if used
507
- if tool_used:
508
- response_prompts.append("TOOL_USE_ENHANCEMENT")
509
-
510
- # Return as newline-separated strings
511
- response_prompts_str = "\n".join(response_prompts)
512
- thinking_prompts_str = "" # Thinking prompts decided elsewhere
513
-
514
- return response_prompts_str, thinking_prompts_str
515
-
516
- # ============================================================================
517
- # THINKING AGENTS (Preprocessing Layer)
518
- # ============================================================================
519
-
520
- class ThinkingAgents:
521
- """
522
- Generates reasoning context before final response.
523
- Uses shared Llama-3.2-3B for all thinking (including math).
524
 
525
- Agents:
526
- 1. Math Thinking (Tree-of-Thought)
527
- 2. Q&A Design (Chain-of-Thought)
528
- 3. General Reasoning (Chain-of-Thought)
529
- """
530
-
531
- def __init__(self):
532
- """Initialize with shared Llama model"""
533
- self.model = get_shared_llama()
534
- logger.info("ThinkingAgents initialized (using shared Llama for all thinking)")
535
-
536
- def math_thinking(
537
- self,
538
- user_query: str,
539
- conversation_history: List[Dict],
540
- tool_context: str = ""
541
- ) -> str:
542
  """
543
- Generate mathematical reasoning using Tree-of-Thought.
544
- Now uses Llama-3.2-3B instead of GGUF.
545
  """
546
- logger.info("→ Math Thinking Agent: Generating reasoning")
547
-
548
- context = "\n".join([
549
- f"{msg['role']}: {msg['content']}"
550
- for msg in conversation_history[-3:]
551
- ])
552
-
553
- thinking_prompt = f"""Conversation context:
554
- {context}
555
-
556
- Current query: {user_query}
557
-
558
- {f"Tool output: {tool_context}" if tool_context else ""}
559
-
560
- Generate mathematical reasoning:"""
561
-
562
- try:
563
- thinking_start = time.time()
564
-
565
- reasoning = self.model.generate(
566
- system_prompt=MATH_THINKING,
567
- user_message=thinking_prompt,
568
- max_tokens=300,
569
- temperature=0.7
570
- )
571
-
572
- thinking_time = time.time() - thinking_start
573
- logger.info(f"✓ Math Thinking: Generated {len(reasoning)} chars ({thinking_time:.2f}s)")
574
-
575
- return reasoning
576
-
577
- except Exception as e:
578
- logger.error(f"Math Thinking error: {e}")
579
- return ""
580
-
581
- def qa_design_thinking(
582
- self,
583
- user_query: str,
584
- conversation_history: List[Dict],
585
- tool_context: str = ""
586
- ) -> str:
587
- """Generate practice question design reasoning"""
588
- logger.info("→ Q&A Design Agent: Generating question strategy")
589
-
590
- context = "\n".join([
591
- f"{msg['role']}: {msg['content']}"
592
- for msg in conversation_history[-3:]
593
- ])
594
-
595
- thinking_prompt = f"""Context:
596
- {context}
597
-
598
- Query: {user_query}
599
-
600
- {f"Tool data: {tool_context}" if tool_context else ""}
601
-
602
- Design practice questions:"""
603
 
604
- try:
605
- reasoning = self.model.generate(
606
- system_prompt=QUESTION_ANSWER_DESIGN,
607
- user_message=thinking_prompt,
608
- max_tokens=250,
609
- temperature=0.7
610
- )
611
-
612
- logger.info(f"✓ Q&A Design: Generated {len(reasoning)} chars")
613
-
614
- return reasoning
615
-
616
- except Exception as e:
617
- logger.error(f"Q&A Design error: {e}")
618
- return ""
619
-
620
- def process(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  self,
622
- user_input: str,
623
- conversation_history: str = "",
624
- thinking_prompts: str = "",
625
- tool_img_output: str = "",
626
- tool_context: str = ""
627
  ) -> str:
628
  """
629
- Unified process method - runs thinking agents based on active prompts.
630
 
631
- Args:
632
- user_input: User's query
633
- conversation_history: Formatted conversation history string
634
- thinking_prompts: Newline-separated list of thinking prompts to activate
635
- tool_img_output: HTML output from visualization tool
636
- tool_context: Context from tool usage
637
-
638
- Returns:
639
- str: Combined thinking context from all activated agents
640
  """
641
- thinking_outputs = []
642
-
643
- # Convert history string to list format for agent methods
644
- history_list = []
645
- if conversation_history and conversation_history != "No previous conversation":
646
- for line in conversation_history.split('\n'):
647
- if ':' in line:
648
- role, content = line.split(':', 1)
649
- history_list.append({'role': role.strip(), 'content': content.strip()})
650
-
651
- # Determine which thinking agents to run based on prompts
652
- prompt_list = [p.strip() for p in thinking_prompts.split('\n') if p.strip()]
653
-
654
- # Math Thinking
655
- if any('MATH' in p.upper() for p in prompt_list):
656
- math_output = self.math_thinking(
657
- user_query=user_input,
658
- conversation_history=history_list,
659
- tool_context=tool_context
660
- )
661
- if math_output:
662
- thinking_outputs.append(f"[Mathematical Reasoning]\n{math_output}")
663
-
664
- # Q&A Design Thinking
665
- if any('PRACTICE' in p.upper() or 'QUESTION' in p.upper() for p in prompt_list):
666
- qa_output = self.qa_design_thinking(
667
- user_query=user_input,
668
- conversation_history=history_list,
669
- tool_context=tool_context
670
- )
671
- if qa_output:
672
- thinking_outputs.append(f"[Practice Question Design]\n{qa_output}")
673
-
674
- # General Reasoning (fallback or when no specific thinking needed)
675
- if not thinking_outputs or any('REASONING' in p.upper() for p in prompt_list):
676
- general_output = self.general_reasoning(
677
- user_query=user_input,
678
- conversation_history=history_list,
679
- tool_context=tool_context
680
  )
681
- if general_output:
682
- thinking_outputs.append(f"[General Reasoning]\n{general_output}")
683
 
684
- # Combine all thinking outputs
685
- combined_thinking = "\n\n".join(thinking_outputs) if thinking_outputs else ""
 
 
 
686
 
687
- if combined_thinking:
688
- logger.info(f"✓ Thinking complete: {len(combined_thinking)} chars from {len(thinking_outputs)} agents")
689
-
690
- return combined_thinking
691
-
692
- def general_reasoning(
693
- self,
694
- user_query: str,
695
- conversation_history: List[Dict],
696
- tool_context: str = ""
697
- ) -> str:
698
- """Generate general reasoning context"""
699
- logger.info("→ General Reasoning Agent: Generating context")
700
-
701
- context = "\n".join([
702
- f"{msg['role']}: {msg['content']}"
703
- for msg in conversation_history[-4:]
704
- ])
705
 
706
- thinking_prompt = f"""Conversation:
707
- {context}
708
-
709
- Query: {user_query}
710
-
711
- {f"Context: {tool_context}" if tool_context else ""}
712
-
713
- Analyze and provide reasoning:"""
 
 
714
 
715
- try:
716
- reasoning = self.model.generate(
717
- system_prompt=REASONING_THINKING,
718
- user_message=thinking_prompt,
719
- max_tokens=200,
720
- temperature=0.7
721
- )
722
-
723
- logger.info(f"✓ General Reasoning: Generated {len(reasoning)} chars")
724
-
725
- return reasoning
726
-
727
- except Exception as e:
728
- logger.error(f"General Reasoning error: {e}")
729
- return ""
730
-
731
-
732
- # ============================================================================
733
- # RESPONSE AGENT (Final Response Generation)
734
- # ============================================================================
735
-
736
- class ResponseAgent(Runnable):
737
- """
738
- Generates final educational responses using lazy-loaded Llama-3.2-3B.
739
- Model loads automatically on first use.
740
-
741
- Features:
742
- - Dynamic prompt assembly based on agent decisions
743
- - Streaming word-by-word output
744
- - Educational tone enforcement
745
- - LaTeX support for math
746
- - Context integration (thinking outputs, tool outputs)
747
- """
748
-
749
- def __init__(self):
750
- """Initialize with lazy-loaded Llama model"""
751
- super().__init__()
752
- self.model = get_shared_llama()
753
- logger.info("ResponseAgent initialized (using lazy-loaded Llama)")
754
 
755
- def invoke(self, input_data: Dict) -> Dict:
 
 
 
 
 
 
 
756
  """
757
- Generate final response with streaming.
758
 
759
- Args:
760
- input_data: {
761
- 'user_query': str,
762
- 'conversation_history': List[Dict],
763
- 'active_prompts': List[str],
764
- 'thinking_context': str,
765
- 'tool_context': str,
766
- }
767
-
768
- Returns:
769
- {'response': str, 'metadata': Dict}
770
  """
771
- logger.info("→ ResponseAgent: Generating final response")
772
-
773
- # Extract inputs
774
- user_query = input_data.get('user_query', '')
775
- conversation_history = input_data.get('conversation_history', [])
776
- active_prompts = input_data.get('active_prompts', [])
777
- thinking_context = input_data.get('thinking_context', '')
778
- tool_context = input_data.get('tool_context', '')
779
-
780
- # Build system prompt from active segments
781
- system_prompt = self._build_system_prompt(active_prompts)
782
-
783
- # Build user message with context
784
- user_message = self._build_user_message(
785
- user_query,
786
- conversation_history,
787
- thinking_context,
788
- tool_context
789
- )
790
 
791
- try:
792
- response_start = time.time()
793
-
794
- # Generate response (streaming handled at app.py level)
795
- response = self.model.generate(
796
- system_prompt=system_prompt,
797
- user_message=user_message,
798
- max_tokens=600,
799
- temperature=0.7
800
  )
801
-
802
- response_time = time.time() - response_start
803
-
804
- # Clean up response
805
- response = self._clean_response(response)
806
-
807
- logger.info(f"✓ ResponseAgent: Generated {len(response)} chars ({response_time:.2f}s)")
808
-
809
- return {
810
- 'response': response,
811
- 'metadata': {
812
- 'generation_time': response_time,
813
- 'model': LLAMA_MODEL_ID,
814
- 'active_prompts': active_prompts
815
- }
816
- }
817
-
818
- except Exception as e:
819
- logger.error(f"ResponseAgent error: {e}")
820
- return {
821
- 'response': "I apologize, but I encountered an error generating a response. Please try again.",
822
- 'metadata': {'error': str(e)}
823
- }
824
-
825
- def _build_system_prompt(self, active_prompts: List[str]) -> str:
826
- """Assemble system prompt from active segments"""
827
- prompt_map = {
828
- 'CORE_IDENTITY': CORE_IDENTITY,
829
- 'GENERAL_FORMATTING': GENERAL_FORMATTING,
830
- 'LATEX_FORMATTING': LATEX_FORMATTING,
831
- 'VAUGE_INPUT': VAUGE_INPUT,
832
- 'USER_UNDERSTANDING': USER_UNDERSTANDING,
833
- 'GUIDING_TEACHING': GUIDING_TEACHING,
834
- 'STRUCTURE_PRACTICE_QUESTIONS': STRUCTURE_PRACTICE_QUESTIONS,
835
- 'PRACTICE_QUESTION_FOLLOWUP': PRACTICE_QUESTION_FOLLOWUP,
836
- 'TOOL_USE_ENHANCEMENT': TOOL_USE_ENHANCEMENT,
837
- }
838
-
839
- # Always include core identity
840
- segments = [CORE_IDENTITY, GENERAL_FORMATTING]
841
-
842
- # Add active prompts
843
- for prompt_name in active_prompts:
844
- if prompt_name in prompt_map and prompt_map[prompt_name] not in segments:
845
- segments.append(prompt_map[prompt_name])
846
-
847
- return "\n\n".join(segments)
848
-
849
- def _build_user_message(
850
- self,
851
- user_query: str,
852
- conversation_history: List[Dict],
853
- thinking_context: str,
854
- tool_context: str
855
- ) -> str:
856
- """Build user message with all context"""
857
- parts = []
858
-
859
- # Conversation history (last 3 turns)
860
- if conversation_history:
861
- history_text = "\n".join([
862
- f"{msg['role']}: {msg['content'][:200]}"
863
- for msg in conversation_history[-3:]
864
- ])
865
- parts.append(f"Recent conversation:\n{history_text}")
866
-
867
- # Thinking context (invisible to user, guides response)
868
- if thinking_context:
869
- parts.append(f"[Internal reasoning context]: {thinking_context}")
870
-
871
- # Tool context
872
- if tool_context:
873
- parts.append(f"[Tool output]: {tool_context}")
874
-
875
- # Current query
876
- parts.append(f"Student query: {user_query}")
877
-
878
- return "\n\n".join(parts)
879
-
880
- def _clean_response(self, response: str) -> str:
881
- """Clean up response artifacts"""
882
- # Remove common artifacts
883
- artifacts = ['<|im_end|>', '<|endoftext|>', '###', '<|end|>']
884
- for artifact in artifacts:
885
- response = response.replace(artifact, '')
886
-
887
- # Remove trailing incomplete sentences
888
- if response and response[-1] not in '.!?':
889
- # Find last complete sentence
890
- for delimiter in ['. ', '! ', '? ']:
891
- if delimiter in response:
892
- response = response.rsplit(delimiter, 1)[0] + delimiter[0]
893
- break
894
 
895
- return response.strip()
896
-
897
- def stream(self, input_data: Dict):
898
- """
899
- Stream response word-by-word.
900
 
901
- Yields:
902
- str: Response chunks
903
- """
904
- logger.info("→ ResponseAgent: Streaming response")
905
-
906
- # Build prompts
907
- system_prompt = self._build_system_prompt(input_data.get('active_prompts', []))
908
- user_message = self._build_user_message(
909
- input_data.get('user_query', ''),
910
- input_data.get('conversation_history', []),
911
- input_data.get('thinking_context', ''),
912
- input_data.get('tool_context', '')
913
  )
914
 
915
- try:
916
- # Use streaming generation from shared model
917
- for chunk in self.model.generate_streaming(
918
- system_prompt=system_prompt,
919
- user_message=user_message,
920
- max_tokens=600,
921
- temperature=0.7
922
- ):
923
- yield chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
924
 
925
- except Exception as e:
926
- logger.error(f"Streaming error: {e}")
927
- yield "I apologize, but I encountered an error. Please try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
928
 
929
 
930
- # ============================================================================
931
- # MODULE INITIALIZATION
932
- # ============================================================================
 
933
 
934
- logger.info("="*60)
935
- logger.info("MIMIR AGENTS MODULE INITIALIZED")
936
- logger.info("="*60)
937
- logger.info(f" Model: Llama-3.2-3B-Instruct (lazy-loaded)")
938
- logger.info(f" Agents: Tool, Routing (4x), Thinking (3x), Response")
939
- logger.info(f" Memory: ~1GB (loads on first use)")
940
- logger.info(f" Architecture: Single unified model with caching")
941
- logger.info("="*60)
 
1
  # model_manager.py
 
2
  """
3
+ Lazy-loading Llama-3.2-3B-Instruct with proper ZeroGPU context management.
4
 
5
+ KEY FIX: Each generate() call is wrapped with @spaces.GPU to ensure
6
+ the model is accessible during generation.
 
 
 
 
 
 
 
 
 
 
7
  """
8
 
9
  import os
 
10
  import torch
11
  import logging
12
+ from typing import Optional, Iterator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from transformers import (
14
  AutoTokenizer,
15
  AutoModelForCausalLM,
16
  BitsAndBytesConfig,
17
+ pipeline as create_pipeline
18
  )
19
 
20
  # ZeroGPU support
 
31
  return decorator
32
  spaces = DummySpaces()
33
 
34
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # Configuration
37
+ MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
38
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
class LazyLlamaModel:
    """
    Lazy-loading singleton wrapper around Llama-3.2-3B-Instruct.

    The heavyweight components (tokenizer, 4-bit quantized model, pipeline)
    are NOT loaded at construction time.  They are loaded on the first call
    to generate()/generate_streaming(), both of which run inside a
    @spaces.GPU context so the weights are materialized while a GPU is
    actually attached (required on ZeroGPU Spaces).
    """

    # Class-level singleton state: one shared instance per process.
    _instance = None
    _initialized = False

    def __new__(cls):
        # Classic singleton: every construction returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard so repeated LazyLlamaModel() calls don't reset loaded state.
        if not self._initialized:
            self.model_id = MODEL_ID
            self.token = HF_TOKEN

            # Deferred: populated by _load_model_components() inside a
            # @spaces.GPU decorated call, never here.
            self.tokenizer = None
            self.model = None
            self.pipeline = None

            LazyLlamaModel._initialized = True
            logger.info("LazyLlamaModel initialized (model will load on first generate)")

    def _load_model_components(self):
        """
        Load tokenizer, 4-bit quantized model, and text-generation pipeline.

        Must be called from inside a @spaces.GPU decorated method so the
        GPU context is live when weights are placed on device.  Idempotent:
        returns immediately if components are already loaded.
        """
        if self.model is not None and self.tokenizer is not None:
            return  # Already loaded in this context

        logger.info("=" * 60)
        logger.info("LOADING LLAMA-3.2-3B-INSTRUCT")
        logger.info("=" * 60)

        # Tokenizer
        logger.info("Loading: %s", self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            token=self.token,
            trust_remote_code=True,
        )
        logger.info("✓ Tokenizer loaded: %s", type(self.tokenizer).__name__)

        # 4-bit NF4 double quantization keeps the 3B model near ~1GB VRAM.
        logger.info("Config: 4-bit NF4 quantization")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Quantized model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=bnb_config,
            device_map="auto",
            token=self.token,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        logger.info("✓ Model loaded: %s", type(self.model).__name__)

        # High-level text-generation pipeline used by generate().
        self.pipeline = create_pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device_map="auto",
        )
        logger.info("✓ Pipeline created and verified: TextGenerationPipeline")

        logger.info("=" * 60)
        logger.info("✅ MODEL LOADED & CACHED")
        logger.info("   Model: %s", self.model_id)
        logger.info("   Tokenizer: %s", type(self.tokenizer).__name__)
        logger.info("   Pipeline: %s", type(self.pipeline).__name__)
        logger.info("   Memory: ~1GB VRAM")
        logger.info("   Context: 128K tokens")
        logger.info("=" * 60)

    @spaces.GPU(duration=90)
    def generate(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7,
    ) -> str:
        """
        Generate a complete response for a system/user message pair.

        The @spaces.GPU decorator keeps the model in GPU context for the
        whole generation.  Raises RuntimeError if the pipeline failed to
        load (likely a ZeroGPU context issue).
        """
        # Load model components if not already loaded.
        self._load_model_components()

        if self.pipeline is None:
            raise RuntimeError(
                "Pipeline is None after loading. This may be a ZeroGPU context issue. "
                "Check that _load_model_components() completed successfully."
            )

        # Format prompt with the model's chat template.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "return_full_text": False,
        }
        # BUG FIX: only pass temperature when sampling — supplying it with
        # do_sample=False triggers transformers generation warnings.
        if temperature > 0:
            gen_kwargs["temperature"] = temperature

        outputs = self.pipeline(prompt, **gen_kwargs)
        return outputs[0]["generated_text"].strip()

    @spaces.GPU(duration=90)
    def generate_streaming(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 500,
        temperature: float = 0.7,
    ) -> Iterator[str]:
        """
        Generate a response token-by-token, yielding text chunks.

        The @spaces.GPU decorator keeps the model in GPU context throughout.
        NOTE: each step re-runs a full forward pass over the growing
        sequence (no KV cache reuse across generate() calls), so this is
        O(n^2) in output length — acceptable for short answers.
        """
        self._load_model_components()

        if self.pipeline is None:
            raise RuntimeError(
                "Pipeline is None after loading. This may be a ZeroGPU context issue."
            )

        # Format prompt with the model's chat template.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize and move to the model's device.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # BUG FIX: pin the decode offset to the ORIGINAL prompt length.
        # The previous version recomputed it from the growing input each
        # iteration, so the decoded window shrank to the latest token while
        # the cumulative length counter kept growing — later chunks were
        # silently dropped.
        prompt_len = input_ids.shape[1]
        emitted = 0
        do_sample = temperature > 0

        with torch.no_grad():
            for _ in range(max_tokens):
                step_kwargs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "max_new_tokens": 1,
                    "do_sample": do_sample,
                    "pad_token_id": self.tokenizer.eos_token_id,
                    "eos_token_id": self.tokenizer.eos_token_id,
                }
                # Same fix as generate(): temperature only when sampling.
                if do_sample:
                    step_kwargs["temperature"] = temperature

                output_ids = self.model.generate(**step_kwargs)

                # Decode the full continuation so multi-token characters
                # (e.g. merged BPE pieces) render correctly, then yield
                # only the not-yet-emitted suffix.
                text = self.tokenizer.decode(
                    output_ids[0][prompt_len:],
                    skip_special_tokens=True,
                )
                if len(text) > emitted:
                    yield text[emitted:]
                    emitted = len(text)

                # Stop once the model emits EOS.
                if output_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break

                # Feed the extended sequence back in for the next token.
                input_ids = output_ids
                attention_mask = torch.ones_like(output_ids)
+ }
251
+
252
+
253
# Module-level cache for the shared model instance.
_model_instance = None


def get_model() -> LazyLlamaModel:
    """Return the process-wide LazyLlamaModel, creating it on first call."""
    global _model_instance
    if _model_instance is None:
        _model_instance = LazyLlamaModel()
    return _model_instance


# Backwards compatibility aliases (defined in this module — no self-import).
get_shared_llama = get_model
MistralSharedAgent = LazyLlamaModel
LlamaSharedAgent = LazyLlamaModel

# DO NOT ADD THIS LINE - IT CAUSES CIRCULAR IMPORT:
# from model_manager import get_model as get_shared_llama, LazyLlamaModel as LlamaSharedAgent