Spaces:
Paused
Paused
| # src/supervisor/graph.py | |
| import operator | |
| from typing import TypedDict, Annotated, List | |
| import logging | |
| from langchain_core.messages import BaseMessage, ToolMessage, AIMessage, HumanMessage | |
| from langchain_anthropic import ChatAnthropic | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.prebuilt import ToolNode | |
| logger = logging.getLogger(__name__) | |
| from src.supervisor.state import AgentState | |
| from src.tools.toolbelt import toolbelt | |
| from src.prompts import ROUTER_PROMPT, REFLECTION_PROMPT, FINAL_ANSWER_FORMATTER_PROMPT | |
| from src.utils.config import config | |
| # Initialize the LLM for our nodes (on demand) | |
| # Using a powerful model for routing and reflection is key. | |
def get_model():
    """Return the shared main ChatAnthropic model, created lazily on first call.

    The instance is cached on the function object itself so repeated calls
    reuse one client. Extended Thinking is enabled; note temperature is 1 and
    max_tokens (16000) exceeds the thinking budget (6000) — the API rejects
    configurations where the budget is not below max_tokens.
    """
    if not hasattr(get_model, '_model'):
        get_model._model = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=1,
            api_key=config.CLAUDE_API_KEY,
            # Thinking budget kept below max_tokens.
            thinking={"type": "enabled", "budget_tokens": 6000},
            max_tokens=16000,
        )
    return get_model._model
def get_final_formatter_model():
    """Lazily build and cache the ChatAnthropic model used to format the final answer.

    Extended Thinking is enabled with an 8000-token budget, which stays below
    max_tokens (16000) as the API requires.
    """
    cached = getattr(get_final_formatter_model, '_model', None)
    if cached is None:
        cached = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=1,
            api_key=config.CLAUDE_API_KEY,
            thinking={"type": "enabled", "budget_tokens": 8000},  # budget < max_tokens
            max_tokens=16000,
        )
        get_final_formatter_model._model = cached
    return cached
def get_model_with_tools():
    """Return the main model with the toolbelt bound, creating the binding once."""
    if getattr(get_model_with_tools, '_model', None) is None:
        get_model_with_tools._model = get_model().bind_tools(toolbelt)
    return get_model_with_tools._model
def get_extraction_model():
    """Lazily create and cache the model that distills tool output.

    Runs at temperature 0 (no Extended Thinking) for deterministic extraction.
    """
    model = getattr(get_extraction_model, '_model', None)
    if model is None:
        model = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=0,
            api_key=config.CLAUDE_API_KEY,
        )
        get_extraction_model._model = model
    return model
| ### NODE DEFINITIONS ### | |
def router_node(state: AgentState) -> dict:
    """The central router: inspects the conversation and decides the next action.

    Short-circuits without an LLM call when the last message already contains
    the "ANSWER FOUND:" sentinel (set by the extraction node). Otherwise
    prepends the router prompt — plus strict file-path instructions when a
    file is attached — and lets the tool-bound model choose a tool or answer.

    Returns:
        dict with a "messages" list holding the single new AIMessage.
    """
    messages = state["messages"]

    # Short-circuit: an upstream node flagged the answer; surface it directly.
    if messages and hasattr(messages[-1], 'content') and isinstance(messages[-1].content, str):
        last_content = messages[-1].content
        if "ANSWER FOUND:" in last_content:
            answer_start = last_content.find("ANSWER FOUND:") + len("ANSWER FOUND:")
            answer_text = last_content[answer_start:].strip()
            response = AIMessage(content=f"Based on the analysis, the answer is: {answer_text}")
            logger.debug("Router Node - Extracted Answer: %s", answer_text)
            logger.debug("Router Node - Response: %s", [response])
            return {"messages": [response]}

    # System prompt includes the tool names so the model knows its options.
    router_prompt = ROUTER_PROMPT.format(tool_names=[t.name for t in toolbelt])

    # If a file is attached, insist on the exact path so the model cannot mangle it.
    if state.get("file_path"):
        file_path = state["file_path"]
        router_prompt += (
            f"\n\n**IMPORTANT: There is an attached file at: {file_path}**"
            "\nYou MUST use the appropriate tool to analyze this file."
            f"\n**CRITICAL:** Use this EXACT file path in your code: {file_path}"
            "\nDO NOT modify or simplify the path. Copy and paste it exactly as shown above."
        )

    context = [HumanMessage(content=router_prompt)] + messages
    logger.debug("Router Node - Input Messages: %s", [msg.content for msg in context])

    response = get_model_with_tools().invoke(context)
    logger.debug("Router Node - Output Message: %s", response.content)
    logger.debug("Router Node - Response: %s", [response])
    return {"messages": [response]}
def reflection_node(state: AgentState) -> dict:
    """Self-reflection node: asks the model to diagnose a failed tool call and retry."""
    messages = state["messages"]
    failed_message = messages[-1]
    logger.debug(f"Reflection Node - Last Message: {failed_message}")

    # The failed tool call (if recorded) and its error text both live on the last message.
    prompt = REFLECTION_PROMPT.format(
        question=state["question"],
        messages=messages[:-1],  # history without the error message itself
        tool_call=failed_message.additional_kwargs.get("tool_calls", []),
        error_message=failed_message.content,
    )
    correction = get_model_with_tools().invoke(prompt)
    logger.debug(f"Reflection Node - Response: {[correction]}")
    return {"messages": [correction]}
def information_extraction_node(state: AgentState) -> dict:
    """Extract only question-relevant information from the latest tool result.

    No-ops (empty message delta) when the history is empty or the last message
    is not a tool result. Otherwise asks the extraction model to summarize and
    to emit either "ANSWER FOUND:" or "CONTINUE SEARCHING:", which downstream
    routing uses to decide whether to stop.

    Returns:
        dict with a "messages" list (empty on no-op or extraction failure).
    """
    question = state["question"]
    messages = state["messages"]

    # Guard against an empty history; messages[-1] would raise IndexError.
    if not messages:
        return {"messages": []}

    last_message = messages[-1]
    # Only process tool results. NOTE(review): loose duck-type check — assumes
    # only tool messages carry a `name` attribute; confirm against the
    # installed langchain_core version.
    if not hasattr(last_message, 'name'):
        return {"messages": []}  # No changes needed

    tool_result = last_message.content

    # Prompt body sits at column 0 inside the f-string so the model sees clean text.
    extraction_prompt = f"""You are an information extraction assistant. Your job is to read the result from a tool and determine if the original question can be answered.
Original Question: {question}
Tool Result: {tool_result}
**Instructions:**
1. **Summarize Key Facts:** List the most important facts you found in the tool result.
2. **Assess Progress:** Can the original question be fully answered with the information found?
3. **Decision:** If you have sufficient information to answer the question, start your response with "ANSWER FOUND:" followed by the answer. If not, start with "CONTINUE SEARCHING:" and explain what's still needed.
**Your Output:**
[Your analysis and decision]
"""
    try:
        response = get_extraction_model().invoke(extraction_prompt)
        extracted_info = response.content

        if "ANSWER FOUND:" in extracted_info:
            # Append a sentinel sentence so should_continue() ends the run.
            extraction_message = AIMessage(
                content=f"Key Information Extracted: {extracted_info}\n\nI have found sufficient information to answer the question. No more searching needed."
            )
        else:
            extraction_message = AIMessage(
                content=f"Key Information Extracted: {extracted_info}"
            )
        logger.debug(f"Information Extraction Node - Original: {tool_result[:200]}...")
        logger.debug(f"Information Extraction Node - Extracted: {extracted_info}")
        return {"messages": [extraction_message]}
    except Exception as e:
        # Best-effort: extraction failure should not kill the run; log and skip.
        logger.error(f"Information extraction error: {e}")
        return {"messages": []}
# Pre-built LangGraph node that executes whatever tool calls the router emits.
tool_node = ToolNode(toolbelt)
def final_formatting_node(state: AgentState) -> dict:
    """Condense the full conversation and produce the formatted final answer.

    Builds a compact textual transcript (question, AI reasoning, tool calls,
    filtered tool results), feeds it to the formatter model, and unwraps the
    text from the structured (thinking-enabled) response.

    Returns:
        dict with key "final_answer" mapped to the formatted answer string.
    """
    question = state["question"]
    messages = state["messages"]
    logger.debug(f"Final Formatting Node - Received {len(messages)} messages")

    def extract_text_content(content):
        """Return plain text from either a string or a structured content list.

        Structured content (thinking-enabled models) is a list of dicts; only
        the 'text'-type blocks are kept, joined with spaces. Anything else is
        stringified as a fallback.
        """
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            # Handle structured content with text and tool_use blocks
            text_parts = []
            for item in content:
                if isinstance(item, dict) and item.get('type') == 'text':
                    text_parts.append(item.get('text', ''))
            return ' '.join(text_parts) if text_parts else str(content)
        else:
            return str(content)

    def extract_key_info(text):
        """Filter a tool result down to lines likely relevant to the question.

        Drops separators, URLs, and search-result headers (keeping the content
        after the header colon), and skips lines matching a hard-coded
        blocklist. Returns '' when nothing survives the filter.
        """
        lines = text.split('\n')
        key_info = []
        for line in lines:
            line = line.strip()
            # Skip empty lines and obvious metadata
            if not line or line.startswith('---'):
                continue
            if line.startswith('URL:'):
                continue
            # Skip web search result headers but keep the content
            if line.startswith('Web Search Result') and ':' in line:
                # Extract the content after the colon
                parts = line.split(':', 1)
                if len(parts) > 1:
                    content = parts[1].strip()
                    if content:
                        line = content
                    else:
                        continue
            # HACK: hard-coded blocklist of terms seen polluting past search
            # results — brittle, and silently drops any line containing them.
            irrelevant_terms = ['arthritis', 'αλφουζοσίνη', 'ουρικό', 'stohr', 'bischof',
                                'vatikan', 'google arama', 'whatsapp', 'calculatrice',
                                'hotmail', 'sfr mail', 'orange.pl', '50 cent', 'rapper',
                                'flight delay', 'ec 261', 'insurance', 'generali']
            if any(term in line.lower() for term in irrelevant_terms):
                continue
            key_info.append(line)
        return '\n'.join(key_info) if key_info else ''

    # Build a readable transcript for the formatter: drop tool-call metadata,
    # keep the reasoning and filtered results.
    formatted_messages = []
    # First, add the original question
    formatted_messages.append(f"Question: {question}")
    for i, msg in enumerate(messages):
        logger.debug(f"Processing message {i}: type={type(msg).__name__}, has_content={hasattr(msg, 'content')}")
        if hasattr(msg, 'content'):
            msg_type = type(msg).__name__
            if msg_type == "HumanMessage":
                # Skip the first human message as we already added the question
                if i > 0:
                    text_content = extract_text_content(msg.content)
                    if text_content.strip():
                        formatted_messages.append(f"Human: {text_content}")
            elif msg_type == "AIMessage":
                # Handle AI messages
                text_content = extract_text_content(msg.content)
                # Check if this is an extraction result
                if "Key Information Extracted:" in text_content:
                    formatted_messages.append(f"Extracted: {text_content}")
                elif hasattr(msg, 'tool_calls') and msg.tool_calls:
                    # AI message with tool calls - include the reasoning
                    if text_content.strip():
                        formatted_messages.append(f"AI Reasoning: {text_content}")
                    # Add tool call info
                    for tool_call in msg.tool_calls:
                        tool_name = tool_call.get('name', 'Unknown')
                        formatted_messages.append(f"Tool Called: {tool_name}")
                else:
                    # Regular AI message
                    if text_content.strip():
                        formatted_messages.append(f"AI: {text_content}")
            elif msg_type == "ToolMessage" or hasattr(msg, 'name'):
                # Handle tool messages: filter the result, then truncate to 500 chars.
                tool_name = getattr(msg, 'name', 'Unknown Tool')
                text_content = extract_text_content(msg.content)
                # Extract key information from tool results
                key_info = extract_key_info(text_content)
                if key_info:
                    formatted_messages.append(f"Tool Result ({tool_name}): {key_info[:500]}...")
                else:
                    # If no key info extracted, include a summary
                    formatted_messages.append(f"Tool Result ({tool_name}): [No relevant information found]")
    conversation_history = "\n".join(formatted_messages)
    # Log the conversation history for debugging
    logger.debug(f"Final Formatting Node - Conversation History Length: {len(conversation_history)}")
    logger.debug(f"Final Formatting Node - Conversation History Preview: {conversation_history[:500]}...")

    # Fallback: if filtering removed everything, rebuild a minimal transcript
    # from raw message contents (truncated) so the formatter has some context.
    if not conversation_history or conversation_history.strip() == f"Question: {question}":
        logger.warning("Conversation history is empty or minimal, constructing from raw messages")
        conversation_history = f"Question: {question}\n"
        for msg in messages[1:]:  # Skip first message (the question)
            if hasattr(msg, 'content'):
                content = str(msg.content)[:200]
                conversation_history += f"\n{type(msg).__name__}: {content}..."

    prompt = FINAL_ANSWER_FORMATTER_PROMPT.format(question=question, messages=conversation_history)
    response = get_final_formatter_model().invoke(prompt)

    # Thinking-enabled responses are lists of blocks; take the first 'text' block.
    if isinstance(response.content, list):
        # Find the text content from the structured response
        text_content = ""
        for item in response.content:
            if isinstance(item, dict) and item.get('type') == 'text':
                text_content = item.get('text', '')
                break
        final_answer = text_content
    else:
        # Fallback for simple string responses
        final_answer = response.content
    logger.debug(f"Final Formatting Node - Generated Answer: {final_answer[:100]}...")
    return {"final_answer": final_answer}
| ### CONDITIONAL EDGE LOGIC ### | |
def should_continue(state: AgentState) -> str:
    """Route after the router/reflector: "use_tool" to execute a tool call, else "end".

    Ends early when the last message carries either sentinel an upstream node
    uses to signal that the answer has been found.
    """
    last_message = state["messages"][-1]

    # Stop as soon as an upstream node marked the answer as found.
    if hasattr(last_message, 'content') and isinstance(last_message.content, str):
        if "I have found sufficient information to answer the question" in last_message.content:
            return "end"
        if "ANSWER FOUND:" in last_message.content:
            return "end"

    # Execute any pending tool call. getattr guards message types without a
    # `tool_calls` attribute (e.g. ToolMessage), which previously raised
    # AttributeError here.
    if getattr(last_message, "tool_calls", None):
        return "use_tool"
    # No tool calls pending: we are done.
    return "end"
def after_tool_use(state: AgentState) -> str:
    """Route after tool execution: "reflect" on a failed call, otherwise "extract"."""
    last_message = state["messages"][-1]
    logger.debug(f"After Tool Use - Last Message: {last_message}")
    # The ToolNode appends a ToolMessage; an "Error:" marker in it means the call failed.
    tool_failed = isinstance(last_message, ToolMessage) and "Error:" in last_message.content
    return "reflect" if tool_failed else "extract"
def after_extraction(state: AgentState) -> str:
    """After information extraction, control always returns to the router."""
    # Unconditional: the router decides whether to search again or finish.
    return "continue"
| ### GRAPH ASSEMBLY ### | |
def create_agent_graph():
    """Build and compile the agent state machine.

    Wiring: router -> (tool_node | END); tool_node -> (extraction | reflector);
    extraction -> router; reflector -> (tool_node | END). The final formatting
    step runs OUTSIDE the graph, so it is returned alongside the compiled
    graph for the caller to invoke after END.

    Returns:
        tuple: (compiled agent graph, final_formatting_node callable).
        Note: the previous ``-> StateGraph`` annotation was wrong — a 2-tuple
        was always returned.
    """
    graph = StateGraph(AgentState)

    # Add nodes to the graph
    graph.add_node("router", router_node)
    graph.add_node("tool_node", tool_node)
    graph.add_node("extraction", information_extraction_node)
    graph.add_node("reflector", reflection_node)

    # Define the graph's entry point
    graph.set_entry_point("router")

    # Router either calls a tool or finishes.
    graph.add_conditional_edges(
        "router",
        should_continue,
        {
            "use_tool": "tool_node",
            "end": END
        }
    )
    # Tool results go to extraction on success, reflection on error.
    graph.add_conditional_edges(
        "tool_node",
        after_tool_use,
        {
            "extract": "extraction",
            "reflect": "reflector"
        }
    )
    # After extraction, always go back to the router.
    graph.add_conditional_edges(
        "extraction",
        after_extraction,
        {
            "continue": "router"
        }
    )
    # Reflection retries a corrected tool call or finishes, same logic as router.
    graph.add_conditional_edges(
        "reflector",
        should_continue,
        {
            "use_tool": "tool_node",
            "end": END
        }
    )

    # Compile the graph into a runnable object
    agent_graph = graph.compile()
    return agent_graph, final_formatting_node