jdesiree committed on
Commit
ad38d0d
·
verified ·
1 Parent(s): 79845af

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +815 -0
agent.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agents.py
2
+ """
3
+ Unified agent architecture for Mimir Educational AI Assistant.
4
+
5
+ MIGRATED TO LLAMA-3.2-3B-INSTRUCT
6
+
7
+ Components:
8
+ - LlamaSharedAgent: SINGLETON shared Llama-3.2-3B for ALL agents (loads ONCE)
9
+ - ToolDecisionAgent: Uses shared Llama for visualization decisions
10
+ - PromptRoutingAgents: Uses shared Llama for all 4 routing agents
11
+ - ThinkingAgents: Uses shared Llama for all reasoning (including math)
12
+ - ResponseAgent: Uses shared Llama for final responses
13
+
14
+ Key optimization: All agents share ONE Llama-3.2-3B instance to eliminate
15
+ redundant loading. Single model architecture with 1GB memory footprint.
16
+ """
17
+
18
+ import os
19
+ import re
20
+ import torch
21
+ import logging
22
+ import time
23
+ import subprocess
24
+ import threading
25
+ from datetime import datetime
26
+ from typing import Dict, List, Optional, Tuple, Type
27
+ import warnings
28
+
29
+ # Setup main logger first
30
+ logging.basicConfig(level=logging.INFO)
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # ============================================================================
34
+ # MEMORY PROFILING UTILITIES
35
+ # ============================================================================
36
+
37
def log_memory(tag=""):
    """Log current GPU memory usage"""
    try:
        # Guard clause: nothing to report on CPU-only hosts.
        if not torch.cuda.is_available():
            logger.info(f"[{tag}] No CUDA available")
            return
        mb = 1024 ** 2
        allocated = torch.cuda.memory_allocated() / mb
        reserved = torch.cuda.memory_reserved() / mb
        max_allocated = torch.cuda.max_memory_allocated() / mb
        logger.info(
            f"[{tag}] GPU Memory - Allocated: {allocated:.2f} MB, "
            f"Reserved: {reserved:.2f} MB, Peak: {max_allocated:.2f} MB"
        )
    except Exception as e:
        # Best-effort diagnostics only; never let logging break the pipeline.
        logger.warning(f"[{tag}] Error logging GPU memory: {e}")
49
+
50
+
51
def log_nvidia_smi(tag=""):
    """Log full nvidia-smi output for system-wide GPU view"""
    query = [
        'nvidia-smi',
        '--query-gpu=memory.used,memory.total',
        '--format=csv,noheader,nounits',
    ]
    try:
        report = subprocess.check_output(query, encoding='utf-8')
        logger.info(f"[{tag}] NVIDIA-SMI: {report.strip()}")
    except Exception as e:
        # nvidia-smi may be absent (CPU host) — log and carry on.
        logger.warning(f"[{tag}] Error running nvidia-smi: {e}")
58
+
59
+
60
def log_step(step_name, start_time=None):
    """Log a pipeline step with timestamp and duration.

    Args:
        step_name: Human-readable name of the pipeline step.
        start_time: Epoch seconds returned by an earlier log_step() call.
            When provided, logs completion with elapsed time; otherwise
            logs that the step is starting.

    Returns:
        float: Current epoch time; pass back as `start_time` to time the step.
    """
    now = time.time()
    timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]

    # Explicit None check (not truthiness): a start_time of 0.0 is falsy
    # yet still a valid timestamp, and would have been mislogged as a start.
    if start_time is not None:
        duration = now - start_time
        logger.info(f"[{timestamp}] ✓ {step_name} completed in {duration:.2f}s")
    else:
        logger.info(f"[{timestamp}] → {step_name} starting...")

    return now
72
+
73
+
74
def profile_generation(model, tokenizer, inputs, **gen_kwargs):
    """Profile memory and time for model.generate() call.

    Args:
        model: Model exposing a HuggingFace-style .generate() method.
        tokenizer: Unused here; accepted to keep the call-site signature stable.
        inputs: Dict of model inputs, expanded as keyword arguments.
        **gen_kwargs: Extra generation parameters forwarded to generate().

    Returns:
        Tuple of (generate() outputs, wall-clock duration in seconds).
    """
    # Guard CUDA-specific calls: reset_peak_memory_stats / max_memory_allocated
    # raise on CPU-only hosts, which would break profiling entirely.
    cuda = torch.cuda.is_available()
    if cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    log_memory("Before generate()")
    start_time = time.time()

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    end_time = time.time()
    duration = end_time - start_time
    peak_memory = (torch.cuda.max_memory_allocated() / 1024**2) if cuda else 0.0

    log_memory("After generate()")
    logger.info(f"Generation completed in {duration:.2f}s. Peak GPU: {peak_memory:.2f} MB")

    return outputs, duration
93
+
94
+
95
+ # ============================================================================
96
+ # IMPORTS
97
+ # ============================================================================
98
+
99
+ # Transformers for standard models
100
+ from transformers import (
101
+ AutoTokenizer,
102
+ AutoModelForCausalLM,
103
+ BitsAndBytesConfig,
104
+ )
105
+
106
+ # ZeroGPU support
107
try:
    import spaces
    HF_SPACES_AVAILABLE = True
except ImportError:
    HF_SPACES_AVAILABLE = False

    class DummySpaces:
        """No-op stand-in for the HF `spaces` package outside ZeroGPU Spaces."""
        @staticmethod
        def GPU(duration=90):
            # Support BOTH decorator forms, matching the real spaces.GPU:
            #   @spaces.GPU                    -> `duration` receives the function
            #   @spaces.GPU(duration=90)       -> returns a decorator
            if callable(duration):
                return duration

            def decorator(func):
                return func
            return decorator

    spaces = DummySpaces()
119
+
120
+ # Accelerate
121
+ from accelerate import Accelerator
122
+ from accelerate.utils import set_seed
123
+
124
+ # LangChain Core for proper message handling
125
+ from langchain_core.runnables import Runnable
126
+ from langchain_core.runnables.utils import Input, Output
127
+ from langchain_core.messages import SystemMessage, HumanMessage
128
+
129
+ # Import ALL prompts from prompt library
130
+ from prompt_library import (
131
+ # System prompts
132
+ CORE_IDENTITY,
133
+ TOOL_DECISION,
134
+ agent_1_system,
135
+ agent_2_system,
136
+ agent_3_system,
137
+ agent_4_system,
138
+
139
+ # Thinking agent system prompts
140
+ MATH_THINKING,
141
+ QUESTION_ANSWER_DESIGN,
142
+ REASONING_THINKING,
143
+
144
+ # Response agent prompts (dynamically applied)
145
+ VAUGE_INPUT,
146
+ USER_UNDERSTANDING,
147
+ GENERAL_FORMATTING,
148
+ LATEX_FORMATTING,
149
+ GUIDING_TEACHING,
150
+ STRUCTURE_PRACTICE_QUESTIONS,
151
+ PRACTICE_QUESTION_FOLLOWUP,
152
+ TOOL_USE_ENHANCEMENT,
153
+ )
154
+
155
+ # ============================================================================
156
+ # SHARED MODEL IMPORT - CRITICAL CHANGE
157
+ # ============================================================================
158
+ # Import the shared Llama-3.2-3B agent from shared_models.py
159
+ from shared_models import get_shared_llama, LlamaSharedAgent
160
+
161
+ # Backwards compatibility alias
162
+ get_shared_mistral = get_shared_llama
163
+ MistralSharedAgent = LlamaSharedAgent
164
+
165
+ # ============================================================================
166
+ # CONFIGURATION
167
+ # ============================================================================
168
+
169
+ CACHE_DIR = "/tmp/compiled_models"
170
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
171
+
172
+ # Suppress warnings
173
+ warnings.filterwarnings("ignore", category=UserWarning)
174
+ warnings.filterwarnings("ignore", category=FutureWarning)
175
+
176
+ # Model info (for logging/diagnostics)
177
+ LLAMA_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
178
+
179
+
180
def check_model_cache() -> Dict[str, bool]:
    """Check which models are pre-compiled (legacy function, kept for compatibility)"""
    logger.info("✓ Llama-3.2-3B uses transformers cache (automatic)")

    # Llama caching is delegated entirely to the transformers hub cache,
    # so every entry is unconditionally True.
    return {
        "llama": True,  # Handled by transformers cache
        "all_compiled": True,
    }
190
+
191
+
192
# Call at module load
# Legacy cache probe: result is only used for its log output today.
_cache_status = check_model_cache()
# Record a baseline GPU-memory snapshot before any model work happens.
log_memory("Module load complete")
195
+
196
+
197
+ # ============================================================================
198
+ # TOOL DECISION AGENT
199
+ # ============================================================================
200
+
201
class ToolDecisionAgent:
    """
    Decides whether visualization/graphing tools should be invoked.

    Runs on the process-wide shared Llama-3.2-3B instance.

    Returns: Boolean (True = use tools, False = skip tools)
    """

    def __init__(self):
        """Bind the shared Llama model (loaded once for the whole process)."""
        self.model = get_shared_llama()
        logger.info("ToolDecisionAgent initialized (using shared Llama)")

    def decide(self, user_query: str, conversation_history: List[Dict]) -> bool:
        """
        Decide if graphing tools should be used.

        Args:
            user_query: Current user message
            conversation_history: Full conversation context

        Returns:
            bool: True if tools should be used
        """
        logger.info("→ ToolDecisionAgent: Analyzing query for tool usage")

        # Only the last three turns matter for the tool decision.
        recent_turns = conversation_history[-3:]
        context = "\n".join(
            f"{turn['role']}: {turn['content']}" for turn in recent_turns
        )

        analysis_prompt = f"""Previous conversation:
{context}

Current query: {user_query}

Should visualization tools (graphs, charts) be used?"""

        try:
            started = time.time()

            verdict = self.model.generate(
                system_prompt=TOOL_DECISION,
                user_message=analysis_prompt,
                max_tokens=10,
                temperature=0.1,
            )

            elapsed = time.time() - started

            # A YES anywhere in the (short) reply means "use tools".
            use_tools = "YES" in verdict.upper()
            logger.info(f"✓ ToolDecision: {'USE TOOLS' if use_tools else 'NO TOOLS'} ({elapsed:.2f}s)")
            return use_tools

        except Exception as e:
            logger.error(f"ToolDecisionAgent error: {e}")
            # Fail safe: render without visualizations.
            return False
265
+
266
+
267
+ # ============================================================================
268
+ # PROMPT ROUTING AGENTS (4 Specialized Agents)
269
+ # ============================================================================
270
+
271
class PromptRoutingAgents:
    """
    Four specialized agents for prompt segment selection.
    All share the same Llama-3.2-3B instance for efficiency.

    Agents:
    1. Practice Question Detector
    2. Discovery Mode Classifier
    3. Follow-up Assessment
    4. Teaching Mode Assessor
    """

    def __init__(self):
        """Initialize with shared Llama model"""
        self.model = get_shared_llama()
        logger.info("PromptRoutingAgents initialized (4 agents, shared Llama)")

    def agent_1_practice_question(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> bool:
        """Agent 1: Detect if practice questions should be generated.

        Args:
            user_query: Current user message.
            conversation_history: Prior turns; only the last 4 are used.

        Returns:
            bool: True when practice questions should be created.
        """
        logger.info("→ Agent 1: Analyzing for practice question opportunity")

        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-4:]
        ])

        analysis_prompt = f"""Conversation:
{context}

New query: {user_query}

Should I create practice questions?"""

        try:
            response = self.model.generate(
                system_prompt=agent_1_system,
                user_message=analysis_prompt,
                max_tokens=10,
                temperature=0.1
            )

            decision = "YES" in response.upper()
            # Fixed mojibake ("��" -> "✓") so this log matches Agents 2-4.
            logger.info(f"✓ Agent 1: {'PRACTICE QUESTIONS' if decision else 'NO PRACTICE'}")

            return decision

        except Exception as e:
            logger.error(f"Agent 1 error: {e}")
            return False

    def agent_2_discovery_mode(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> Tuple[bool, bool]:
        """Agent 2: Classify vague input and understanding level.

        Returns:
            Tuple[bool, bool]: (input_is_vague, understanding_is_low).
        """
        logger.info("→ Agent 2: Classifying discovery mode")

        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-3:]
        ])

        analysis_prompt = f"""Conversation:
{context}

Query: {user_query}

Classification:
1. Is input vague? (VAGUE/CLEAR)
2. Understanding level? (LOW/MEDIUM/HIGH)"""

        try:
            response = self.model.generate(
                system_prompt=agent_2_system,
                user_message=analysis_prompt,
                max_tokens=20,
                temperature=0.1
            )

            # Keyword sniffing on the short classification reply.
            vague = "VAGUE" in response.upper()
            low_understanding = "LOW" in response.upper()

            logger.info(f"✓ Agent 2: Vague={vague}, LowUnderstanding={low_understanding}")

            return vague, low_understanding

        except Exception as e:
            logger.error(f"Agent 2 error: {e}")
            return False, False

    def agent_3_followup_assessment(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> bool:
        """Agent 3: Detect if user is responding to practice questions.

        Cheap heuristics run first (enough history? last assistant message
        looks like a practice question?) before spending an LLM call.

        Returns:
            bool: True when the user appears to be answering a question.
        """
        logger.info("→ Agent 3: Checking for practice question follow-up")

        # Need at least one prior exchange to have asked a question.
        if len(conversation_history) < 2:
            return False

        last_bot_msg = None
        for msg in reversed(conversation_history):
            if msg['role'] == 'assistant':
                last_bot_msg = msg['content']
                break

        if not last_bot_msg:
            return False

        # Look for practice question markers
        has_practice = any(marker in last_bot_msg.lower() for marker in [
            "practice", "try this", "solve", "calculate", "what is", "question"
        ])

        if not has_practice:
            return False

        # Only now ask the model whether the query is an answer attempt.
        analysis_prompt = f"""Previous message (from me):
{last_bot_msg[:500]}

User response:
{user_query}

Is user answering a practice question?"""

        try:
            response = self.model.generate(
                system_prompt=agent_3_system,
                user_message=analysis_prompt,
                max_tokens=10,
                temperature=0.1
            )

            is_followup = "YES" in response.upper()
            logger.info(f"✓ Agent 3: {'GRADING MODE' if is_followup else 'NOT FOLLOWUP'}")

            return is_followup

        except Exception as e:
            logger.error(f"Agent 3 error: {e}")
            return False

    def agent_4_teaching_mode(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> Tuple[bool, bool]:
        """Agent 4: Assess teaching vs practice mode.

        Returns:
            Tuple[bool, bool]: (needs_direct_teaching, create_practice_questions).
        """
        logger.info("→ Agent 4: Assessing teaching mode")

        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-3:]
        ])

        analysis_prompt = f"""Conversation:
{context}

Query: {user_query}

Assessment:
1. Need direct teaching? (TEACH/PRACTICE)
2. Create practice questions? (YES/NO)"""

        try:
            response = self.model.generate(
                system_prompt=agent_4_system,
                user_message=analysis_prompt,
                max_tokens=15,
                temperature=0.1
            )

            teaching = "TEACH" in response.upper()
            # Either an explicit YES or a PRACTICE verdict triggers questions.
            practice = "YES" in response.upper() or "PRACTICE" in response.upper()

            logger.info(f"✓ Agent 4: Teaching={teaching}, Practice={practice}")

            return teaching, practice

        except Exception as e:
            logger.error(f"Agent 4 error: {e}")
            return False, False
461
+
462
+
463
+ # ============================================================================
464
+ # THINKING AGENTS (Preprocessing Layer)
465
+ # ============================================================================
466
+
467
class ThinkingAgents:
    """
    Produces reasoning context that precedes the final response.
    Every thinking agent (math included) runs on the shared Llama-3.2-3B.

    Agents:
    1. Math Thinking (Tree-of-Thought)
    2. Q&A Design (Chain-of-Thought)
    3. General Reasoning (Chain-of-Thought)
    """

    def __init__(self):
        """Bind the shared Llama model (loaded once per process)."""
        self.model = get_shared_llama()
        logger.info("ThinkingAgents initialized (using shared Llama for all thinking)")

    @staticmethod
    def _recent_context(conversation_history: List[Dict], turns: int) -> str:
        """Render the last `turns` messages as newline-joined 'role: content' lines."""
        recent = conversation_history[-turns:]
        return "\n".join(f"{msg['role']}: {msg['content']}" for msg in recent)

    def math_thinking(
        self,
        user_query: str,
        conversation_history: List[Dict],
        tool_context: str = ""
    ) -> str:
        """
        Generate mathematical reasoning using Tree-of-Thought.
        Now uses Llama-3.2-3B instead of GGUF.
        """
        logger.info("→ Math Thinking Agent: Generating reasoning")

        context = self._recent_context(conversation_history, 3)
        tool_line = f"Tool output: {tool_context}" if tool_context else ""

        thinking_prompt = f"""Conversation context:
{context}

Current query: {user_query}

{tool_line}

Generate mathematical reasoning:"""

        try:
            started = time.time()

            reasoning = self.model.generate(
                system_prompt=MATH_THINKING,
                user_message=thinking_prompt,
                max_tokens=300,
                temperature=0.7,
            )

            elapsed = time.time() - started
            logger.info(f"✓ Math Thinking: Generated {len(reasoning)} chars ({elapsed:.2f}s)")
            return reasoning

        except Exception as e:
            logger.error(f"Math Thinking error: {e}")
            return ""

    def qa_design_thinking(
        self,
        user_query: str,
        conversation_history: List[Dict],
        tool_context: str = ""
    ) -> str:
        """Generate practice question design reasoning"""
        logger.info("→ Q&A Design Agent: Generating question strategy")

        context = self._recent_context(conversation_history, 3)
        tool_line = f"Tool data: {tool_context}" if tool_context else ""

        thinking_prompt = f"""Context:
{context}

Query: {user_query}

{tool_line}

Design practice questions:"""

        try:
            reasoning = self.model.generate(
                system_prompt=QUESTION_ANSWER_DESIGN,
                user_message=thinking_prompt,
                max_tokens=250,
                temperature=0.7,
            )

            logger.info(f"✓ Q&A Design: Generated {len(reasoning)} chars")
            return reasoning

        except Exception as e:
            logger.error(f"Q&A Design error: {e}")
            return ""

    def general_reasoning(
        self,
        user_query: str,
        conversation_history: List[Dict],
        tool_context: str = ""
    ) -> str:
        """Generate general reasoning context"""
        logger.info("→ General Reasoning Agent: Generating context")

        # General reasoning takes one extra turn of history than the others.
        context = self._recent_context(conversation_history, 4)
        tool_line = f"Context: {tool_context}" if tool_context else ""

        thinking_prompt = f"""Conversation:
{context}

Query: {user_query}

{tool_line}

Analyze and provide reasoning:"""

        try:
            reasoning = self.model.generate(
                system_prompt=REASONING_THINKING,
                user_message=thinking_prompt,
                max_tokens=200,
                temperature=0.7,
            )

            logger.info(f"✓ General Reasoning: Generated {len(reasoning)} chars")
            return reasoning

        except Exception as e:
            logger.error(f"General Reasoning error: {e}")
            return ""
605
+
606
+
607
+ # ============================================================================
608
+ # RESPONSE AGENT (Final Response Generation)
609
+ # ============================================================================
610
+
611
class ResponseAgent(Runnable):
    """
    Generates final educational responses using shared Llama-3.2-3B.

    Features:
    - Dynamic prompt assembly based on agent decisions
    - Streaming word-by-word output
    - Educational tone enforcement
    - LaTeX support for math
    - Context integration (thinking outputs, tool outputs)
    """

    def __init__(self):
        """Initialize with shared Llama model"""
        super().__init__()
        self.model = get_shared_llama()
        logger.info("ResponseAgent initialized (using shared Llama)")

    def invoke(self, input_data: Dict) -> Dict:
        """
        Generate final response with streaming.

        Args:
            input_data: {
                'user_query': str,
                'conversation_history': List[Dict],
                'active_prompts': List[str],
                'thinking_context': str,
                'tool_context': str,
            }

        Returns:
            {'response': str, 'metadata': Dict}
        """
        logger.info("→ ResponseAgent: Generating final response")

        # Extract inputs (all optional; missing keys degrade gracefully)
        user_query = input_data.get('user_query', '')
        conversation_history = input_data.get('conversation_history', [])
        active_prompts = input_data.get('active_prompts', [])
        thinking_context = input_data.get('thinking_context', '')
        tool_context = input_data.get('tool_context', '')

        # Build system prompt from active segments
        system_prompt = self._build_system_prompt(active_prompts)

        # Build user message with context
        user_message = self._build_user_message(
            user_query,
            conversation_history,
            thinking_context,
            tool_context
        )

        try:
            response_start = time.time()

            # Generate response (streaming handled at app.py level)
            response = self.model.generate(
                system_prompt=system_prompt,
                user_message=user_message,
                max_tokens=600,
                temperature=0.7
            )

            response_time = time.time() - response_start

            # Clean up response
            response = self._clean_response(response)

            logger.info(f"✓ ResponseAgent: Generated {len(response)} chars ({response_time:.2f}s)")

            return {
                'response': response,
                'metadata': {
                    'generation_time': response_time,
                    'model': LLAMA_MODEL_ID,
                    'active_prompts': active_prompts
                }
            }

        except Exception as e:
            logger.error(f"ResponseAgent error: {e}")
            return {
                'response': "I apologize, but I encountered an error generating a response. Please try again.",
                'metadata': {'error': str(e)}
            }

    def _build_system_prompt(self, active_prompts: List[str]) -> str:
        """Assemble system prompt from active segments.

        CORE_IDENTITY and GENERAL_FORMATTING are always included first;
        additional segments are appended in request order, deduplicated
        by segment text.
        """
        prompt_map = {
            'CORE_IDENTITY': CORE_IDENTITY,
            'GENERAL_FORMATTING': GENERAL_FORMATTING,
            'LATEX_FORMATTING': LATEX_FORMATTING,
            'VAUGE_INPUT': VAUGE_INPUT,
            'USER_UNDERSTANDING': USER_UNDERSTANDING,
            'GUIDING_TEACHING': GUIDING_TEACHING,
            'STRUCTURE_PRACTICE_QUESTIONS': STRUCTURE_PRACTICE_QUESTIONS,
            'PRACTICE_QUESTION_FOLLOWUP': PRACTICE_QUESTION_FOLLOWUP,
            'TOOL_USE_ENHANCEMENT': TOOL_USE_ENHANCEMENT,
        }

        # Always include core identity
        segments = [CORE_IDENTITY, GENERAL_FORMATTING]

        # Add active prompts (skip unknown names and already-included text)
        for prompt_name in active_prompts:
            if prompt_name in prompt_map and prompt_map[prompt_name] not in segments:
                segments.append(prompt_map[prompt_name])

        return "\n\n".join(segments)

    def _build_user_message(
        self,
        user_query: str,
        conversation_history: List[Dict],
        thinking_context: str,
        tool_context: str
    ) -> str:
        """Build user message with all context.

        Assembles, in order: recent history (last 3 turns, content capped
        at 200 chars each), hidden reasoning context, tool output, then
        the student query itself.
        """
        parts = []

        # Conversation history (last 3 turns)
        if conversation_history:
            history_text = "\n".join([
                f"{msg['role']}: {msg['content'][:200]}"
                for msg in conversation_history[-3:]
            ])
            parts.append(f"Recent conversation:\n{history_text}")

        # Thinking context (invisible to user, guides response)
        if thinking_context:
            parts.append(f"[Internal reasoning context]: {thinking_context}")

        # Tool context
        if tool_context:
            parts.append(f"[Tool output]: {tool_context}")

        # Current query
        parts.append(f"Student query: {user_query}")

        return "\n\n".join(parts)

    def _clean_response(self, response: str) -> str:
        """Clean up response artifacts and trailing incomplete sentences."""
        # Remove common artifacts
        artifacts = ['<|im_end|>', '<|endoftext|>', '###', '<|end|>']
        for artifact in artifacts:
            response = response.replace(artifact, '')

        # If generation was cut off mid-sentence, truncate at the LAST
        # complete sentence. (Previously this scanned delimiters in a fixed
        # order and split at the first '. ' found, which could drop whole
        # trailing sentences ending in '!' or '?'.)
        if response and response[-1] not in '.!?':
            cut = max(response.rfind(d) for d in ('. ', '! ', '? '))
            if cut != -1:
                # Keep text up to and including the punctuation mark.
                response = response[:cut + 1]

        return response.strip()

    def stream(self, input_data: Dict):
        """
        Stream response word-by-word.

        Args:
            input_data: Same shape as invoke()'s input_data.

        Yields:
            str: Response chunks
        """
        logger.info("→ ResponseAgent: Streaming response")

        # Build prompts
        system_prompt = self._build_system_prompt(input_data.get('active_prompts', []))
        user_message = self._build_user_message(
            input_data.get('user_query', ''),
            input_data.get('conversation_history', []),
            input_data.get('thinking_context', ''),
            input_data.get('tool_context', '')
        )

        try:
            # Use streaming generation from shared model
            for chunk in self.model.generate_streaming(
                system_prompt=system_prompt,
                user_message=user_message,
                max_tokens=600,
                temperature=0.7
            ):
                yield chunk

        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield "I apologize, but I encountered an error. Please try again."
802
+
803
+
804
+ # ============================================================================
805
+ # MODULE INITIALIZATION
806
+ # ============================================================================
807
+
808
# Announce successful module initialization with a summary banner.
_BANNER = "=" * 60
logger.info(_BANNER)
logger.info("MIMIR AGENTS MODULE INITIALIZED")
logger.info(_BANNER)
logger.info(" Model: Llama-3.2-3B-Instruct (shared)")
logger.info(" Agents: Tool, Routing (4x), Thinking (3x), Response")
logger.info(" Memory: ~1GB total (shared instance)")
logger.info(" Architecture: Single unified model")
logger.info(_BANNER)