# Snapshot header (Hugging Face Space scrape residue, converted to comments):
# Grux3 — src/supervisor/graph.py, last change "Update src/supervisor/graph.py"
# by BladeSzaSza, commit a1d8c81 (verified).
# src/supervisor/graph.py
import operator
from typing import TypedDict, Annotated, List
import logging
from langchain_core.messages import BaseMessage, ToolMessage, AIMessage, HumanMessage
from langchain_anthropic import ChatAnthropic
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
from src.supervisor.state import AgentState
from src.tools.toolbelt import toolbelt
from src.prompts import ROUTER_PROMPT, REFLECTION_PROMPT, FINAL_ANSWER_FORMATTER_PROMPT
from src.utils.config import config
# Initialize the LLM for our nodes (on demand)
# Using a powerful model for routing and reflection is key.
def get_model():
    """Return the shared router/reflection LLM, created lazily on first use.

    The instance is cached as a function attribute so every caller reuses
    the same client (one API client per process).

    Returns:
        ChatAnthropic: Claude Sonnet 4 with Extended Thinking enabled.
    """
    if not hasattr(get_model, '_model'):
        get_model._model = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            # NOTE: Anthropic's API requires temperature=1 when extended
            # thinking is enabled.
            temperature=1,
            api_key=config.CLAUDE_API_KEY,
            thinking={"type": "enabled", "budget_tokens": 6000},
            # max_tokens must exceed the thinking budget so the visible
            # answer still has room after thinking tokens are spent.
            max_tokens=16000,
        )
    return get_model._model
def get_final_formatter_model():
    """Return the lazily-created final-answer formatter LLM.

    Uses Claude Sonnet 4 with Extended Thinking; cached as a function
    attribute so the client is built exactly once per process.
    """
    cached = getattr(get_final_formatter_model, '_model', None)
    if cached is None:
        cached = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=1,
            api_key=config.CLAUDE_API_KEY,
            thinking={"type": "enabled", "budget_tokens": 8000},  # Reduced budget
            max_tokens=16000,  # Now greater than budget
        )
        get_final_formatter_model._model = cached
    return cached
def get_model_with_tools():
    """Return the router model with the toolbelt bound, cached after first call."""
    try:
        return get_model_with_tools._model
    except AttributeError:
        bound = get_model().bind_tools(toolbelt)
        get_model_with_tools._model = bound
        return bound
def get_extraction_model():
    """Return the information-extraction LLM, created lazily on first use."""
    if getattr(get_extraction_model, '_model', None) is None:
        get_extraction_model._model = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=0,  # deterministic output for extraction
            api_key=config.CLAUDE_API_KEY,
        )
    return get_extraction_model._model
### NODE DEFINITIONS ###
def router_node(state: AgentState) -> dict:
    """The central router. Decides what to do next.

    If the most recent message already carries an extracted answer (the
    "ANSWER FOUND:" marker emitted by the extraction node), short-circuits
    and returns a final response without another model call. Otherwise
    prepends the router system prompt — extended with file-attachment
    instructions when ``state["file_path"]`` is set — and invokes the
    tool-bound model.

    Returns:
        dict: ``{"messages": [response]}`` to append to the graph state.
    """
    messages = state["messages"]
    # Short-circuit: the extraction node flags completed answers inline.
    if messages and hasattr(messages[-1], 'content') and isinstance(messages[-1].content, str):
        last_content = messages[-1].content
        if "ANSWER FOUND:" in last_content:
            answer_start = last_content.find("ANSWER FOUND:") + len("ANSWER FOUND:")
            answer_text = last_content[answer_start:].strip()
            response = AIMessage(content=f"Based on the analysis, the answer is: {answer_text}")
            logger.debug(f"Router Node - Extracted Answer: {answer_text}")
            logger.debug(f"Router Node - Response: {[response]}")
            return {"messages": [response]}
    # Build the system prompt in one pass instead of repeated `+=`
    # concatenation; also drops the needless f-prefix on the two
    # placeholder-free strings.
    parts = [ROUTER_PROMPT.format(tool_names=[t.name for t in toolbelt])]
    if state.get("file_path"):
        file_path = state["file_path"]
        parts.append(f"\n\n**IMPORTANT: There is an attached file at: {file_path}**")
        parts.append("\nYou MUST use the appropriate tool to analyze this file.")
        parts.append(f"\n**CRITICAL:** Use this EXACT file path in your code: {file_path}")
        parts.append("\nDO NOT modify or simplify the path. Copy and paste it exactly as shown above.")
    router_prompt = "".join(parts)
    context = [HumanMessage(content=router_prompt)] + messages
    logger.debug(f"Router Node - Input Messages: {[msg.content for msg in context]}")
    response = get_model_with_tools().invoke(context)
    logger.debug(f"Router Node - Output Message: {response.content}")
    logger.debug(f"Router Node - Response: {[response]}")
    return {"messages": [response]}
def reflection_node(state: AgentState) -> dict:
    """Node for self-reflection and error correction.

    Reached when a tool execution returned an error. Builds a reflection
    prompt with the question, the conversation so far, the failed tool
    call, and the error text, then asks the tool-bound model to correct
    its approach.

    Returns:
        dict: ``{"messages": [response]}`` with the model's retry.
    """
    question = state["question"]
    messages = state["messages"]
    last_message = messages[-1]
    logger.debug(f"Reflection Node - Last Message: {last_message}")
    # BUG FIX: the error is a ToolMessage, which never carries tool_calls —
    # the previous code read tool_calls off it and always got []. The call
    # that failed lives on the most recent prior AI message, so walk back.
    tool_call = []
    for prior in reversed(messages[:-1]):
        calls = getattr(prior, 'tool_calls', None)
        if not calls:
            calls = getattr(prior, 'additional_kwargs', {}).get("tool_calls", [])
        if calls:
            tool_call = calls
            break
    error_message = last_message.content
    prompt = REFLECTION_PROMPT.format(
        question=question,
        messages=messages[:-1],  # Exclude the error message itself
        tool_call=tool_call,
        error_message=error_message
    )
    response = get_model_with_tools().invoke(prompt)
    logger.debug(f"Reflection Node - Response: {[response]}")
    return {"messages": [response]}
def information_extraction_node(state: AgentState) -> dict:
    """Extracts only relevant information from tool results based on the original question.

    Sends the most recent tool result to the extraction model, which
    decides whether the question can now be answered ("ANSWER FOUND:")
    or the search must continue ("CONTINUE SEARCHING:").

    Returns:
        dict: ``{"messages": [AIMessage]}`` with the extraction summary,
        or ``{"messages": []}`` when the last message is not a tool
        result or extraction fails.
    """
    question = state["question"]
    messages = state["messages"]
    last_message = messages[-1]
    # BUG FIX: every LangChain message type exposes an optional `name`
    # attribute, so the old hasattr(last_message, 'name') guard matched
    # all messages; isinstance is the correct "tool result only" filter.
    if not isinstance(last_message, ToolMessage):
        return {"messages": []}  # No changes needed
    tool_result = last_message.content
    # (the unused `tool_name` local was removed)
    extraction_prompt = f"""You are an information extraction assistant. Your job is to read the result from a tool and determine if the original question can be answered.
Original Question: {question}
Tool Result: {tool_result}
**Instructions:**
1. **Summarize Key Facts:** List the most important facts you found in the tool result.
2. **Assess Progress:** Can the original question be fully answered with the information found?
3. **Decision:** If you have sufficient information to answer the question, start your response with "ANSWER FOUND:" followed by the answer. If not, start with "CONTINUE SEARCHING:" and explain what's still needed.
**Your Output:**
[Your analysis and decision]
"""
    try:
        response = get_extraction_model().invoke(extraction_prompt)
        extracted_info = response.content
        if "ANSWER FOUND:" in extracted_info:
            # Signal downstream routing (should_continue) that we can stop.
            extraction_message = AIMessage(
                content=f"Key Information Extracted: {extracted_info}\n\nI have found sufficient information to answer the question. No more searching needed."
            )
        else:
            # Continue normal extraction
            extraction_message = AIMessage(
                content=f"Key Information Extracted: {extracted_info}"
            )
        logger.debug(f"Information Extraction Node - Original: {tool_result[:200]}...")
        logger.debug(f"Information Extraction Node - Extracted: {extracted_info}")
        return {"messages": [extraction_message]}
    except Exception as e:
        # Best-effort node: a failed extraction should not kill the run.
        logger.error(f"Information extraction error: {e}")
        return {"messages": []}  # Return empty if extraction fails
# This is a pre-built node from LangGraph that executes tools: it runs the
# tool_calls found on the latest AI message and appends their results to
# the message list as ToolMessages.
tool_node = ToolNode(toolbelt)
def final_formatting_node(state: AgentState):
    """Extracts and formats the final answer.

    Condenses the full message history into a compact transcript —
    flattening structured content, filtering tool-result noise and
    known-irrelevant search hits — then asks the formatter model for the
    final answer.

    Returns:
        dict: ``{"final_answer": <answer string>}``.
    """
    question = state["question"]
    messages = state["messages"]
    logger.debug(f"Final Formatting Node - Received {len(messages)} messages")
    # PERF FIX: hoisted out of the per-line loop in extract_key_info below;
    # this collection was previously rebuilt on every line of every tool
    # result. The entries are observed off-topic search hits to drop.
    irrelevant_terms = ('arthritis', 'αλφουζοσίνη', 'ουρικό', 'stohr', 'bischof',
                        'vatikan', 'google arama', 'whatsapp', 'calculatrice',
                        'hotmail', 'sfr mail', 'orange.pl', '50 cent', 'rapper',
                        'flight delay', 'ec 261', 'insurance', 'generali')
    def extract_text_content(content):
        """Flatten message content to plain text (handles the structured
        list-of-blocks content produced with Extended Thinking)."""
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            text_parts = [item.get('text', '') for item in content
                          if isinstance(item, dict) and item.get('type') == 'text']
            return ' '.join(text_parts) if text_parts else str(content)
        else:
            return str(content)
    def extract_key_info(text):
        """Keep only content lines from a tool result, dropping separators,
        URLs, search-result headers, and known-irrelevant hits."""
        key_info = []
        for line in text.split('\n'):
            line = line.strip()
            # Skip empty lines and obvious metadata
            if not line or line.startswith('---'):
                continue
            if line.startswith('URL:'):
                continue
            # Web search result headers: keep the payload after the colon.
            if line.startswith('Web Search Result') and ':' in line:
                parts = line.split(':', 1)
                if len(parts) > 1:
                    content = parts[1].strip()
                    if content:
                        line = content
                    else:
                        continue
            # PERF FIX: lower-case once per line, not once per term.
            lowered = line.lower()
            if any(term in lowered for term in irrelevant_terms):
                continue
            key_info.append(line)
        return '\n'.join(key_info) if key_info else ''
    # Format messages for better readability by the final formatter,
    # starting from the original question.
    formatted_messages = [f"Question: {question}"]
    for i, msg in enumerate(messages):
        logger.debug(f"Processing message {i}: type={type(msg).__name__}, has_content={hasattr(msg, 'content')}")
        if hasattr(msg, 'content'):
            msg_type = type(msg).__name__
            if msg_type == "HumanMessage":
                # Skip the first human message as we already added the question
                if i > 0:
                    text_content = extract_text_content(msg.content)
                    if text_content.strip():
                        formatted_messages.append(f"Human: {text_content}")
            elif msg_type == "AIMessage":
                text_content = extract_text_content(msg.content)
                if "Key Information Extracted:" in text_content:
                    formatted_messages.append(f"Extracted: {text_content}")
                elif hasattr(msg, 'tool_calls') and msg.tool_calls:
                    # AI message with tool calls - include the reasoning
                    if text_content.strip():
                        formatted_messages.append(f"AI Reasoning: {text_content}")
                    for tool_call in msg.tool_calls:
                        tool_name = tool_call.get('name', 'Unknown')
                        formatted_messages.append(f"Tool Called: {tool_name}")
                else:
                    # Regular AI message
                    if text_content.strip():
                        formatted_messages.append(f"AI: {text_content}")
            elif msg_type == "ToolMessage" or hasattr(msg, 'name'):
                tool_name = getattr(msg, 'name', 'Unknown Tool')
                text_content = extract_text_content(msg.content)
                key_info = extract_key_info(text_content)
                if key_info:
                    # Truncate so the formatter prompt stays manageable.
                    formatted_messages.append(f"Tool Result ({tool_name}): {key_info[:500]}...")
                else:
                    formatted_messages.append(f"Tool Result ({tool_name}): [No relevant information found]")
    conversation_history = "\n".join(formatted_messages)
    logger.debug(f"Final Formatting Node - Conversation History Length: {len(conversation_history)}")
    logger.debug(f"Final Formatting Node - Conversation History Preview: {conversation_history[:500]}...")
    # Fallback: if filtering removed everything, rebuild a raw transcript.
    if not conversation_history or conversation_history.strip() == f"Question: {question}":
        logger.warning("Conversation history is empty or minimal, constructing from raw messages")
        conversation_history = f"Question: {question}\n"
        for msg in messages[1:]:  # Skip first message (the question)
            if hasattr(msg, 'content'):
                content = str(msg.content)[:200]
                conversation_history += f"\n{type(msg).__name__}: {content}..."
    prompt = FINAL_ANSWER_FORMATTER_PROMPT.format(question=question, messages=conversation_history)
    response = get_final_formatter_model().invoke(prompt)
    # With Extended Thinking enabled the response content is a list of
    # blocks; pull the first 'text' block out as the answer.
    if isinstance(response.content, list):
        text_content = ""
        for item in response.content:
            if isinstance(item, dict) and item.get('type') == 'text':
                text_content = item.get('text', '')
                break
        final_answer = text_content
    else:
        # Fallback for simple string responses
        final_answer = response.content
    logger.debug(f"Final Formatting Node - Generated Answer: {final_answer[:100]}...")
    return {"final_answer": final_answer}
### CONDITIONAL EDGE LOGIC ###
def should_continue(state: "AgentState") -> str:
    """Determines the next step after the router or reflection node.

    Returns:
        str: ``"use_tool"`` when the last message requests a tool call,
        ``"end"`` when the answer has been found or no tool call was made.
    """
    last_message = state["messages"][-1]
    content = getattr(last_message, 'content', None)
    if isinstance(content, str):
        # Markers embedded by the extraction/router nodes when done.
        if "I have found sufficient information to answer the question" in content:
            return "end"
        if "ANSWER FOUND:" in content:
            return "end"
    # ROBUSTNESS FIX: messages without a .tool_calls attribute previously
    # raised AttributeError here; treat them as "no tool call" instead.
    if getattr(last_message, 'tool_calls', None):
        return "use_tool"
    # If there are no tool calls, we are done
    return "end"
def after_tool_use(state: AgentState) -> str:
    """Determines the next step after a tool has been used.

    Returns ``"reflect"`` when the ToolMessage reports an error,
    otherwise ``"extract"`` to distill the successful result.
    """
    last_message = state["messages"][-1]
    logger.debug(f"After Tool Use - Last Message: {last_message}")
    # The ToolNode appends a ToolMessage; an "Error:" marker means failure.
    failed = isinstance(last_message, ToolMessage) and "Error:" in last_message.content
    return "reflect" if failed else "extract"
def after_extraction(state: "AgentState") -> str:
    """Routing after information extraction: control always returns to the router."""
    return "continue"
### GRAPH ASSEMBLY ###
def create_agent_graph():
    """Builds and compiles the agent state machine.

    Graph shape::

        router --(tool call)--> tool_node --(ok)----> extraction --> router
                                          --(error)-> reflector ---> (loop)
        router --(answer found / no tool call)--> END

    Returns:
        tuple: ``(compiled_graph, final_formatting_node)``.
        NOTE: the previous ``-> StateGraph`` annotation was wrong — this
        function has always returned this 2-tuple, and callers unpack it.
    """
    graph = StateGraph(AgentState)
    # Add nodes to the graph
    graph.add_node("router", router_node)
    graph.add_node("tool_node", tool_node)
    graph.add_node("extraction", information_extraction_node)
    graph.add_node("reflector", reflection_node)
    # Define the graph's entry point
    graph.set_entry_point("router")
    # Router: either execute a requested tool or finish.
    graph.add_conditional_edges(
        "router",
        should_continue,
        {
            "use_tool": "tool_node",
            "end": END
        }
    )
    # Tool results: reflect on errors, otherwise extract information.
    graph.add_conditional_edges(
        "tool_node",
        after_tool_use,
        {
            "extract": "extraction",
            "reflect": "reflector"
        }
    )
    # After extraction, always go back to router
    graph.add_conditional_edges(
        "extraction",
        after_extraction,
        {
            "continue": "router"
        }
    )
    # Reflection output is routed exactly like router output.
    graph.add_conditional_edges(
        "reflector",
        should_continue,
        {
            "use_tool": "tool_node",
            "end": END
        }
    )
    # Compile the graph into a runnable object
    agent_graph = graph.compile()
    return agent_graph, final_formatting_node