# Source: "Mimir" Hugging Face Space — agents.py (revision df75c85, uploaded by jdesiree)
# agents.py
"""
Unified agent architecture for Mimir Educational AI Assistant.
LAZY-LOADING LLAMA-3.2-3B-INSTRUCT
Components:
- LazyLlamaModel: Singleton lazy-loading model (loads on first use, cached thereafter)
- ToolDecisionAgent: Uses lazy-loaded Llama for visualization decisions
- PromptRoutingAgents: Uses lazy-loaded Llama for all 4 routing agents
- ThinkingAgents: Uses lazy-loaded Llama for all reasoning (including math)
- ResponseAgent: Uses lazy-loaded Llama for final responses
Key optimization: Model loads on first generate() call and is cached for all
subsequent requests. Single model architecture with ~1GB memory footprint.
No compile or warmup scripts needed - fully automatic.
"""
import os
import re
import torch
import logging
import time
import subprocess
import threading
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Type
import warnings
# Setup main logger first (before any helper below logs anything).
# basicConfig is a no-op if the root logger is already configured by the host app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# MEMORY PROFILING UTILITIES
# ============================================================================
def log_memory(tag=""):
    """Log current GPU memory usage"""
    try:
        if not torch.cuda.is_available():
            logger.info(f"[{tag}] No CUDA available")
            return
        mb = 1024 ** 2
        allocated = torch.cuda.memory_allocated() / mb
        reserved = torch.cuda.memory_reserved() / mb
        peak = torch.cuda.max_memory_allocated() / mb
        # Single log line so one grep shows the full memory picture for a tag.
        logger.info(
            f"[{tag}] GPU Memory - Allocated: {allocated:.2f} MB, "
            f"Reserved: {reserved:.2f} MB, Peak: {peak:.2f} MB"
        )
    except Exception as e:
        # Memory logging must never take the app down.
        logger.warning(f"[{tag}] Error logging GPU memory: {e}")
def log_nvidia_smi(tag=""):
    """Log full nvidia-smi output for system-wide GPU view"""
    # List form (shell=False) keeps the invocation injection-safe.
    query = [
        'nvidia-smi',
        '--query-gpu=memory.used,memory.total',
        '--format=csv,noheader,nounits',
    ]
    try:
        result = subprocess.check_output(query, encoding='utf-8')
        logger.info(f"[{tag}] NVIDIA-SMI: {result.strip()}")
    except Exception as e:
        # nvidia-smi may be absent (CPU host) — degrade to a warning.
        logger.warning(f"[{tag}] Error running nvidia-smi: {e}")
def log_step(step_name, start_time=None):
    """Log a pipeline step with timestamp and duration.

    Call with no start_time to mark a step as starting; call again passing the
    returned value to log the elapsed time. Returns the current time.time().
    """
    now = time.time()
    stamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]  # HH:MM:SS.mmm
    if not start_time:
        logger.info(f"[{stamp}] → {step_name} starting...")
    else:
        elapsed = now - start_time
        logger.info(f"[{stamp}] ✓ {step_name} completed in {elapsed:.2f}s")
    return now
def profile_generation(model, tokenizer, inputs, **gen_kwargs):
    """Profile memory and time for a model.generate() call.

    Args:
        model: Model exposing a HF-style .generate(**inputs, **gen_kwargs).
        tokenizer: Unused; kept for call-site compatibility.
        inputs: Tokenized inputs, expanded as keyword args into generate().
        **gen_kwargs: Extra generation keyword arguments.

    Returns:
        (outputs, duration): the generate() result and wall-clock seconds.

    Fix: the original called torch.cuda.empty_cache() /
    reset_peak_memory_stats() / max_memory_allocated() unconditionally, which
    raises on CPU-only hosts. CUDA bookkeeping is now guarded, with a 0.0 MB
    peak reported when no GPU is present.
    """
    cuda = torch.cuda.is_available()
    if cuda:
        # Reset stats so the peak reflects only this generate() call.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    log_memory("Before generate()")
    start_time = time.time()
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    duration = time.time() - start_time
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2 if cuda else 0.0
    log_memory("After generate()")
    logger.info(f"Generation completed in {duration:.2f}s. Peak GPU: {peak_memory:.2f} MB")
    return outputs, duration
# ============================================================================
# IMPORTS
# ============================================================================
# Transformers for standard models
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
# ZeroGPU support: on Hugging Face Spaces the `spaces` package provides the
# @spaces.GPU decorator; anywhere else we substitute a no-op stand-in so that
# decorated functions keep working unchanged.
try:
    import spaces
    HF_SPACES_AVAILABLE = True
except ImportError:
    HF_SPACES_AVAILABLE = False

    class DummySpaces:
        # No-op replacement mirroring the `spaces` module's GPU decorator.
        @staticmethod
        def GPU(duration=90):
            # `duration` is accepted only for signature compatibility; the
            # returned decorator passes the function through untouched.
            def decorator(func):
                return func
            return decorator

    spaces = DummySpaces()
# Accelerate
from accelerate import Accelerator
from accelerate.utils import set_seed
# LangChain Core for proper message handling
from langchain_core.runnables import Runnable
from langchain_core.runnables.utils import Input, Output
from langchain_core.messages import SystemMessage, HumanMessage
# Import ALL prompts from prompt library
from prompt_library import (
# System prompts
CORE_IDENTITY,
TOOL_DECISION,
agent_1_system,
agent_2_system,
agent_3_system,
agent_4_system,
# Thinking agent system prompts
MATH_THINKING,
QUESTION_ANSWER_DESIGN,
REASONING_THINKING,
# Response agent prompts (dynamically applied)
VAUGE_INPUT,
USER_UNDERSTANDING,
GENERAL_FORMATTING,
LATEX_FORMATTING,
GUIDING_TEACHING,
STRUCTURE_PRACTICE_QUESTIONS,
PRACTICE_QUESTION_FOLLOWUP,
TOOL_USE_ENHANCEMENT,
)
# ============================================================================
# MODEL MANAGER - LAZY LOADING
# ============================================================================
# Import the lazy-loading Llama-3.2-3B model manager
from model_manager import get_model as get_shared_llama, LazyLlamaModel as LlamaSharedAgent

# Backwards compatibility aliases: older call sites imported Mistral-named
# symbols; they now resolve to the same Llama objects.
get_shared_mistral = get_shared_llama
MistralSharedAgent = LlamaSharedAgent
# ============================================================================
# CONFIGURATION
# ============================================================================
# Directory formerly used for compiled-model artifacts; kept for compatibility.
CACHE_DIR = "/tmp/compiled_models"
# Hugging Face auth token — either env var name is accepted.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Model info (for logging/diagnostics)
LLAMA_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
def check_model_cache() -> Dict[str, bool]:
    """Check model status (legacy function for compatibility).

    With lazy loading there is nothing to pre-compile, so every entry is
    unconditionally True; callers only inspect the flags.
    """
    logger.info("✓ Llama-3.2-3B uses lazy loading (loads on first generate() call)")
    return {
        "llama": True,  # Lazy-loaded on first use
        "all_compiled": True,
    }
# Call at module load: record the (always-ready) cache status and a GPU
# memory baseline for later profiling comparisons.
_cache_status = check_model_cache()
log_memory("Module load complete")
# ============================================================================
# TOOL DECISION AGENT
# ============================================================================
class ToolDecisionAgent:
    """
    Decides whether visualization/graphing tools should be invoked for a query.

    Backed by the shared lazy-loaded Llama-3.2-3B model, which loads on its
    first generate() call and is cached thereafter.

    decide() returns a boolean: True = use tools, False = skip tools.
    """

    def __init__(self):
        """Bind the shared lazy-loaded Llama model."""
        self.model = get_shared_llama()
        logger.info("ToolDecisionAgent initialized (using lazy-loaded Llama)")

    def decide(self, user_query: str, conversation_history: List[Dict]) -> bool:
        """
        Decide if graphing tools should be used.

        Args:
            user_query: Current user message.
            conversation_history: Full conversation; only the last 3 turns
                (dicts with 'role'/'content') are put in the prompt.

        Returns:
            bool: True if tools should be used; False on any error.
        """
        logger.info("→ ToolDecisionAgent: Analyzing query for tool usage")
        recent_turns = [
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-3:]
        ]
        context = "\n".join(recent_turns)
        analysis_prompt = f"""Previous conversation:
{context}
Current query: {user_query}
Should visualization tools (graphs, charts) be used?"""
        try:
            started = time.time()
            # Short, near-deterministic classification call.
            answer = self.model.generate(
                system_prompt=TOOL_DECISION,
                user_message=analysis_prompt,
                max_tokens=10,
                temperature=0.1,
            )
            elapsed = time.time() - started
            # Keyword parse: any "YES" in the reply counts as a positive.
            use_tools = "YES" in answer.upper()
            logger.info(f"✓ ToolDecision: {'USE TOOLS' if use_tools else 'NO TOOLS'} ({elapsed:.2f}s)")
            return use_tools
        except Exception as e:
            logger.error(f"ToolDecisionAgent error: {e}")
            return False  # Default: no tools
# ============================================================================
# PROMPT ROUTING AGENTS (4 Specialized Agents)
# ============================================================================
class PromptRoutingAgents:
    """
    Four specialized agents for prompt segment selection.
    All share the same Llama-3.2-3B instance for efficiency.

    Agents:
        1. Practice Question Detector
        2. Discovery Mode Classifier
        3. Follow-up Assessment
        4. Teaching Mode Assessor

    Each agent issues one short, low-temperature classification call to the
    shared model and parses the reply via keyword matching; every agent
    fails closed (returns False) on errors.
    """

    def __init__(self):
        """Initialize with shared Llama model"""
        self.model = get_shared_llama()
        logger.info("PromptRoutingAgents initialized (4 agents, shared Llama)")

    def agent_1_practice_question(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> bool:
        """Agent 1: Detect if practice questions should be generated"""
        logger.info("→ Agent 1: Analyzing for practice question opportunity")
        # Last 4 turns give enough context without bloating the prompt.
        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-4:]
        ])
        analysis_prompt = f"""Conversation:
{context}
New query: {user_query}
Should I create practice questions?"""
        try:
            response = self.model.generate(
                system_prompt=agent_1_system,
                user_message=analysis_prompt,
                max_tokens=10,
                temperature=0.1
            )
            # Any "YES" anywhere in the reply counts as a positive.
            decision = "YES" in response.upper()
            logger.info(f"✓ Agent 1: {'PRACTICE QUESTIONS' if decision else 'NO PRACTICE'}")
            return decision
        except Exception as e:
            logger.error(f"Agent 1 error: {e}")
            return False

    def agent_2_discovery_mode(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> Tuple[bool, bool]:
        """Agent 2: Classify vague input and understanding level.

        Returns:
            (is_vague, low_understanding) — both False on error.
        """
        logger.info("→ Agent 2: Classifying discovery mode")
        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-3:]
        ])
        analysis_prompt = f"""Conversation:
{context}
Query: {user_query}
Classification:
1. Is input vague? (VAGUE/CLEAR)
2. Understanding level? (LOW/MEDIUM/HIGH)"""
        try:
            response = self.model.generate(
                system_prompt=agent_2_system,
                user_message=analysis_prompt,
                max_tokens=20,
                temperature=0.1
            )
            # Both answers are mined from the same short reply.
            vague = "VAGUE" in response.upper()
            low_understanding = "LOW" in response.upper()
            logger.info(f"✓ Agent 2: Vague={vague}, LowUnderstanding={low_understanding}")
            return vague, low_understanding
        except Exception as e:
            logger.error(f"Agent 2 error: {e}")
            return False, False

    def agent_3_followup_assessment(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> bool:
        """Agent 3: Detect if user is responding to practice questions.

        Cheap heuristics (history length, marker words in the last assistant
        message) gate the model call, so most queries never hit the LLM.
        """
        logger.info("→ Agent 3: Checking for practice question follow-up")
        # Check last bot message for practice question indicators
        if len(conversation_history) < 2:
            return False
        last_bot_msg = None
        for msg in reversed(conversation_history):
            if msg['role'] == 'assistant':
                last_bot_msg = msg['content']
                break
        if not last_bot_msg:
            return False
        # Look for practice question markers
        has_practice = any(marker in last_bot_msg.lower() for marker in [
            "practice", "try this", "solve", "calculate", "what is", "question"
        ])
        if not has_practice:
            return False
        # Analyze if current query is an answer attempt; the bot message is
        # truncated to 500 chars to bound prompt size.
        analysis_prompt = f"""Previous message (from me):
{last_bot_msg[:500]}
User response:
{user_query}
Is user answering a practice question?"""
        try:
            response = self.model.generate(
                system_prompt=agent_3_system,
                user_message=analysis_prompt,
                max_tokens=10,
                temperature=0.1
            )
            is_followup = "YES" in response.upper()
            logger.info(f"✓ Agent 3: {'GRADING MODE' if is_followup else 'NOT FOLLOWUP'}")
            return is_followup
        except Exception as e:
            logger.error(f"Agent 3 error: {e}")
            return False

    def agent_4_teaching_mode(
        self,
        user_query: str,
        conversation_history: List[Dict]
    ) -> Tuple[bool, bool]:
        """Agent 4: Assess teaching vs practice mode.

        Returns:
            (needs_teaching, needs_practice) — both False on error.
        """
        logger.info("→ Agent 4: Assessing teaching mode")
        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-3:]
        ])
        analysis_prompt = f"""Conversation:
{context}
Query: {user_query}
Assessment:
1. Need direct teaching? (TEACH/PRACTICE)
2. Create practice questions? (YES/NO)"""
        try:
            response = self.model.generate(
                system_prompt=agent_4_system,
                user_message=analysis_prompt,
                max_tokens=15,
                temperature=0.1
            )
            teaching = "TEACH" in response.upper()
            # NOTE(review): "PRACTICE" answers question 1, so practice is
            # True whenever the model answers PRACTICE there even if question
            # 2 is NO — confirm this coupling is intended.
            practice = "YES" in response.upper() or "PRACTICE" in response.upper()
            logger.info(f"✓ Agent 4: Teaching={teaching}, Practice={practice}")
            return teaching, practice
        except Exception as e:
            logger.error(f"Agent 4 error: {e}")
            return False, False

    def process(
        self,
        user_input: str,
        tool_used: bool = False,
        conversation_history: Optional[List[Dict]] = None
    ) -> Tuple[str, str]:
        """
        Unified process method - runs all 4 routing agents sequentially.

        Args:
            user_input: The user's current message.
            tool_used: Whether a visualization tool ran for this turn.
            conversation_history: Prior turns; defaults to empty.

        Returns:
            Tuple[str, str]: (response_prompts, thinking_prompts) — each a
            newline-separated string of prompt-segment names. The second
            element is always "" here (thinking prompts chosen elsewhere).
        """
        if conversation_history is None:
            conversation_history = []
        response_prompts = []
        thinking_prompts = []
        # Agent 1: Practice Questions
        if self.agent_1_practice_question(user_input, conversation_history):
            response_prompts.append("STRUCTURE_PRACTICE_QUESTIONS")
        # Agent 2: Discovery Mode
        is_vague, low_understanding = self.agent_2_discovery_mode(user_input, conversation_history)
        if is_vague:
            response_prompts.append("VAUGE_INPUT")
        if low_understanding:
            response_prompts.append("USER_UNDERSTANDING")
        # Agent 3: Follow-up Assessment
        if self.agent_3_followup_assessment(user_input, conversation_history):
            response_prompts.append("PRACTICE_QUESTION_FOLLOWUP")
        # Agent 4: Teaching Mode
        # NOTE(review): needs_practice is computed but never used below.
        needs_teaching, needs_practice = self.agent_4_teaching_mode(user_input, conversation_history)
        if needs_teaching:
            response_prompts.append("GUIDING_TEACHING")
        # Always add base formatting
        response_prompts.extend(["GENERAL_FORMATTING", "LATEX_FORMATTING"])
        # Tool enhancement if used
        if tool_used:
            response_prompts.append("TOOL_USE_ENHANCEMENT")
        # Return as newline-separated strings
        response_prompts_str = "\n".join(response_prompts)
        thinking_prompts_str = ""  # Thinking prompts decided elsewhere
        return response_prompts_str, thinking_prompts_str
# ============================================================================
# THINKING AGENTS (Preprocessing Layer)
# ============================================================================
class ThinkingAgents:
    """
    Generates reasoning context before final response.
    Uses shared Llama-3.2-3B for all thinking (including math).

    Agents:
        1. Math Thinking (Tree-of-Thought)
        2. Q&A Design (Chain-of-Thought)
        3. General Reasoning (Chain-of-Thought)

    Every agent returns "" on error so downstream assembly can simply skip
    empty outputs.
    """

    def __init__(self):
        """Initialize with shared Llama model"""
        self.model = get_shared_llama()
        logger.info("ThinkingAgents initialized (using shared Llama for all thinking)")

    def math_thinking(
        self,
        user_query: str,
        conversation_history: List[Dict],
        tool_context: str = ""
    ) -> str:
        """
        Generate mathematical reasoning using Tree-of-Thought.

        Returns the reasoning text, or "" on error.
        """
        logger.info("→ Math Thinking Agent: Generating reasoning")
        # Last 3 turns keep the prompt compact.
        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-3:]
        ])
        # The tool-output line is emitted only when tool_context is non-empty.
        thinking_prompt = f"""Conversation context:
{context}
Current query: {user_query}
{f"Tool output: {tool_context}" if tool_context else ""}
Generate mathematical reasoning:"""
        try:
            thinking_start = time.time()
            reasoning = self.model.generate(
                system_prompt=MATH_THINKING,
                user_message=thinking_prompt,
                max_tokens=300,
                temperature=0.7
            )
            thinking_time = time.time() - thinking_start
            logger.info(f"✓ Math Thinking: Generated {len(reasoning)} chars ({thinking_time:.2f}s)")
            return reasoning
        except Exception as e:
            logger.error(f"Math Thinking error: {e}")
            return ""

    def qa_design_thinking(
        self,
        user_query: str,
        conversation_history: List[Dict],
        tool_context: str = ""
    ) -> str:
        """Generate practice question design reasoning; "" on error."""
        logger.info("→ Q&A Design Agent: Generating question strategy")
        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-3:]
        ])
        thinking_prompt = f"""Context:
{context}
Query: {user_query}
{f"Tool data: {tool_context}" if tool_context else ""}
Design practice questions:"""
        try:
            reasoning = self.model.generate(
                system_prompt=QUESTION_ANSWER_DESIGN,
                user_message=thinking_prompt,
                max_tokens=250,
                temperature=0.7
            )
            logger.info(f"✓ Q&A Design: Generated {len(reasoning)} chars")
            return reasoning
        except Exception as e:
            logger.error(f"Q&A Design error: {e}")
            return ""

    def process(
        self,
        user_input: str,
        conversation_history: str = "",
        thinking_prompts: str = "",
        tool_img_output: str = "",
        tool_context: str = ""
    ) -> str:
        """
        Unified process method - runs thinking agents based on active prompts.

        Args:
            user_input: User's query
            conversation_history: Formatted "role: content" lines, or the
                sentinel "No previous conversation".
            thinking_prompts: Newline-separated list of thinking prompts to activate
            tool_img_output: HTML output from visualization tool (currently
                unused here; kept for interface compatibility).
            tool_context: Context from tool usage

        Returns:
            str: Combined thinking context from all activated agents
        """
        thinking_outputs = []
        # Re-parse the flattened history string back into role/content dicts.
        # NOTE(review): splitting each line on the first ':' mis-handles
        # multi-line message content — confirm the upstream format is one
        # turn per line.
        history_list = []
        if conversation_history and conversation_history != "No previous conversation":
            for line in conversation_history.split('\n'):
                if ':' in line:
                    role, content = line.split(':', 1)
                    history_list.append({'role': role.strip(), 'content': content.strip()})
        # Determine which thinking agents to run based on prompts
        prompt_list = [p.strip() for p in thinking_prompts.split('\n') if p.strip()]
        # Math Thinking — any prompt name containing "MATH" activates it.
        if any('MATH' in p.upper() for p in prompt_list):
            math_output = self.math_thinking(
                user_query=user_input,
                conversation_history=history_list,
                tool_context=tool_context
            )
            if math_output:
                thinking_outputs.append(f"[Mathematical Reasoning]\n{math_output}")
        # Q&A Design Thinking — activated by "PRACTICE" or "QUESTION".
        if any('PRACTICE' in p.upper() or 'QUESTION' in p.upper() for p in prompt_list):
            qa_output = self.qa_design_thinking(
                user_query=user_input,
                conversation_history=history_list,
                tool_context=tool_context
            )
            if qa_output:
                thinking_outputs.append(f"[Practice Question Design]\n{qa_output}")
        # General Reasoning (fallback when nothing else produced output, or
        # explicitly requested via "REASONING").
        if not thinking_outputs or any('REASONING' in p.upper() for p in prompt_list):
            general_output = self.general_reasoning(
                user_query=user_input,
                conversation_history=history_list,
                tool_context=tool_context
            )
            if general_output:
                thinking_outputs.append(f"[General Reasoning]\n{general_output}")
        # Combine all thinking outputs
        combined_thinking = "\n\n".join(thinking_outputs) if thinking_outputs else ""
        if combined_thinking:
            logger.info(f"✓ Thinking complete: {len(combined_thinking)} chars from {len(thinking_outputs)} agents")
        return combined_thinking

    def general_reasoning(
        self,
        user_query: str,
        conversation_history: List[Dict],
        tool_context: str = ""
    ) -> str:
        """Generate general reasoning context; "" on error."""
        logger.info("→ General Reasoning Agent: Generating context")
        # This agent takes a slightly wider window (4 turns) than the others.
        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in conversation_history[-4:]
        ])
        thinking_prompt = f"""Conversation:
{context}
Query: {user_query}
{f"Context: {tool_context}" if tool_context else ""}
Analyze and provide reasoning:"""
        try:
            reasoning = self.model.generate(
                system_prompt=REASONING_THINKING,
                user_message=thinking_prompt,
                max_tokens=200,
                temperature=0.7
            )
            logger.info(f"✓ General Reasoning: Generated {len(reasoning)} chars")
            return reasoning
        except Exception as e:
            logger.error(f"General Reasoning error: {e}")
            return ""
# ============================================================================
# RESPONSE AGENT (Final Response Generation)
# ============================================================================
class ResponseAgent(Runnable):
    """
    Generates final educational responses using lazy-loaded Llama-3.2-3B.
    Model loads automatically on first use.

    Features:
    - Dynamic prompt assembly based on agent decisions
    - Streaming word-by-word output
    - Educational tone enforcement
    - LaTeX support for math
    - Context integration (thinking outputs, tool outputs)
    """

    def __init__(self):
        """Initialize with lazy-loaded Llama model"""
        super().__init__()
        self.model = get_shared_llama()
        logger.info("ResponseAgent initialized (using lazy-loaded Llama)")

    def invoke(self, input_data: Dict) -> Dict:
        """
        Generate the final response (non-streaming).

        Args:
            input_data: {
                'user_query': str,
                'conversation_history': List[Dict],
                'active_prompts': List[str],
                'thinking_context': str,
                'tool_context': str,
            }

        Returns:
            {'response': str, 'metadata': Dict}; on failure the response is a
            user-facing apology and metadata carries the error string.
        """
        logger.info("→ ResponseAgent: Generating final response")
        # Extract inputs (all optional, with safe defaults)
        user_query = input_data.get('user_query', '')
        conversation_history = input_data.get('conversation_history', [])
        active_prompts = input_data.get('active_prompts', [])
        thinking_context = input_data.get('thinking_context', '')
        tool_context = input_data.get('tool_context', '')
        # Build system prompt from the routing agents' active segments
        system_prompt = self._build_system_prompt(active_prompts)
        # Build user message with all context
        user_message = self._build_user_message(
            user_query,
            conversation_history,
            thinking_context,
            tool_context
        )
        try:
            response_start = time.time()
            # Generate response (streaming handled at app.py level)
            response = self.model.generate(
                system_prompt=system_prompt,
                user_message=user_message,
                max_tokens=600,
                temperature=0.7
            )
            response_time = time.time() - response_start
            # Clean up response
            response = self._clean_response(response)
            logger.info(f"✓ ResponseAgent: Generated {len(response)} chars ({response_time:.2f}s)")
            return {
                'response': response,
                'metadata': {
                    'generation_time': response_time,
                    'model': LLAMA_MODEL_ID,
                    'active_prompts': active_prompts
                }
            }
        except Exception as e:
            logger.error(f"ResponseAgent error: {e}")
            return {
                'response': "I apologize, but I encountered an error generating a response. Please try again.",
                'metadata': {'error': str(e)}
            }

    def _build_system_prompt(self, active_prompts: List[str]) -> str:
        """Assemble the system prompt from active segment names.

        CORE_IDENTITY and GENERAL_FORMATTING are always included first;
        requested segments are appended in order, skipping duplicates.
        Unknown names are silently ignored.
        """
        prompt_map = {
            'CORE_IDENTITY': CORE_IDENTITY,
            'GENERAL_FORMATTING': GENERAL_FORMATTING,
            'LATEX_FORMATTING': LATEX_FORMATTING,
            'VAUGE_INPUT': VAUGE_INPUT,
            'USER_UNDERSTANDING': USER_UNDERSTANDING,
            'GUIDING_TEACHING': GUIDING_TEACHING,
            'STRUCTURE_PRACTICE_QUESTIONS': STRUCTURE_PRACTICE_QUESTIONS,
            'PRACTICE_QUESTION_FOLLOWUP': PRACTICE_QUESTION_FOLLOWUP,
            'TOOL_USE_ENHANCEMENT': TOOL_USE_ENHANCEMENT,
        }
        # Always include core identity + base formatting
        segments = [CORE_IDENTITY, GENERAL_FORMATTING]
        for prompt_name in active_prompts:
            if prompt_name in prompt_map and prompt_map[prompt_name] not in segments:
                segments.append(prompt_map[prompt_name])
        return "\n\n".join(segments)

    def _build_user_message(
        self,
        user_query: str,
        conversation_history: List[Dict],
        thinking_context: str,
        tool_context: str
    ) -> str:
        """Build the user message: recent history, hidden reasoning, tool
        output, then the current query — blank-line separated."""
        parts = []
        # Conversation history: last 3 turns, each truncated to 200 chars to
        # bound prompt size.
        if conversation_history:
            history_text = "\n".join([
                f"{msg['role']}: {msg['content'][:200]}"
                for msg in conversation_history[-3:]
            ])
            parts.append(f"Recent conversation:\n{history_text}")
        # Thinking context (invisible to user, guides response)
        if thinking_context:
            parts.append(f"[Internal reasoning context]: {thinking_context}")
        # Tool context
        if tool_context:
            parts.append(f"[Tool output]: {tool_context}")
        # Current query
        parts.append(f"Student query: {user_query}")
        return "\n\n".join(parts)

    def _clean_response(self, response: str) -> str:
        """Strip generation artifacts and trailing incomplete sentences.

        Bug fix vs the original: it scanned delimiters in a fixed order
        ('. ', '! ', '? ') and cut at the first KIND found rather than the
        last sentence boundary, so "Good. Really! trunca" collapsed to
        "Good." (losing a complete sentence). It also checked the last
        character before stripping, so trailing whitespace falsely triggered
        truncation. We now strip first and cut at the LAST boundary of any
        kind.
        """
        # Remove common artifacts
        artifacts = ['<|im_end|>', '<|endoftext|>', '###', '<|end|>']
        for artifact in artifacts:
            response = response.replace(artifact, '')
        response = response.strip()
        # Remove a trailing incomplete sentence, if any
        if response and response[-1] not in '.!?':
            # Index of the last complete sentence boundary; -1 if none found,
            # in which case the (possibly partial) text is kept as-is.
            cut = max(response.rfind(d) for d in ('. ', '! ', '? '))
            if cut != -1:
                response = response[:cut + 1]
        return response.strip()

    def stream(self, input_data: Dict):
        """
        Stream response chunks as they are generated.

        Yields:
            str: Response chunks
        """
        logger.info("→ ResponseAgent: Streaming response")
        # Build prompts exactly as invoke() does
        system_prompt = self._build_system_prompt(input_data.get('active_prompts', []))
        user_message = self._build_user_message(
            input_data.get('user_query', ''),
            input_data.get('conversation_history', []),
            input_data.get('thinking_context', ''),
            input_data.get('tool_context', '')
        )
        try:
            # Use streaming generation from shared model
            for chunk in self.model.generate_streaming(
                system_prompt=system_prompt,
                user_message=user_message,
                max_tokens=600,
                temperature=0.7
            ):
                yield chunk
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield "I apologize, but I encountered an error. Please try again."
# ============================================================================
# MODULE INITIALIZATION
# ============================================================================
# Module-load banner summarizing the agent architecture.
# Fix: dropped the f-prefix on literals with no placeholders (logged text is
# byte-identical to before).
logger.info("=" * 60)
logger.info("MIMIR AGENTS MODULE INITIALIZED")
logger.info("=" * 60)
logger.info(" Model: Llama-3.2-3B-Instruct (lazy-loaded)")
logger.info(" Agents: Tool, Routing (4x), Thinking (3x), Response")
logger.info(" Memory: ~1GB (loads on first use)")
logger.info(" Architecture: Single unified model with caching")
logger.info("=" * 60)