File size: 17,502 Bytes
d61265e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1d8c81
c55774d
 
 
 
 
d61265e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
# src/supervisor/graph.py

import operator
from typing import TypedDict, Annotated, List
import logging

from langchain_core.messages import BaseMessage, ToolMessage, AIMessage, HumanMessage
from langchain_anthropic import ChatAnthropic
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode

logger = logging.getLogger(__name__)

from src.supervisor.state import AgentState
from src.tools.toolbelt import toolbelt
from src.prompts import ROUTER_PROMPT, REFLECTION_PROMPT, FINAL_ANSWER_FORMATTER_PROMPT
from src.utils.config import config

# Initialize the LLM for our nodes (on demand)
# Using a powerful model for routing and reflection is key.
def get_model():
    """Return the shared main ChatAnthropic model, created lazily on first call.

    The instance is memoized on the function object itself, so every caller
    shares a single client. Extended Thinking is enabled; per the Anthropic
    API, temperature must be 1 when thinking is on, and ``max_tokens`` must
    exceed the thinking ``budget_tokens``.
    """
    if not hasattr(get_model, '_model'):
        get_model._model = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=1,  # required by the API when Extended Thinking is enabled
            api_key=config.CLAUDE_API_KEY,
            thinking={"type": "enabled", "budget_tokens": 6000},
            max_tokens=16000,  # must be greater than the thinking budget
        )
    return get_model._model

def get_final_formatter_model():
    """Lazily build and cache the Claude model that formats the final answer.

    Extended Thinking is enabled with an 8k-token budget; ``max_tokens``
    is kept above that budget as the API requires. The client is created
    once and memoized on the function object.
    """
    try:
        return get_final_formatter_model._model
    except AttributeError:
        formatter = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=1,
            api_key=config.CLAUDE_API_KEY,
            thinking={"type": "enabled", "budget_tokens": 8000},
            max_tokens=16000,
        )
        get_final_formatter_model._model = formatter
        return formatter

def get_model_with_tools():
    """Return the main model with the toolbelt bound, built once and cached."""
    try:
        return get_model_with_tools._model
    except AttributeError:
        bound = get_model().bind_tools(toolbelt)
        get_model_with_tools._model = bound
        return bound

def get_extraction_model():
    """Lazily build and cache the model used for information extraction.

    Runs at temperature 0 (no Extended Thinking) since extraction should be
    deterministic; the client is memoized on the function object.
    """
    try:
        return get_extraction_model._model
    except AttributeError:
        extractor = ChatAnthropic(
            model="claude-sonnet-4-20250514",
            temperature=0,
            api_key=config.CLAUDE_API_KEY,
        )
        get_extraction_model._model = extractor
        return extractor

### NODE DEFINITIONS ###

def router_node(state: AgentState) -> dict:
    """The central router. Decides what to do next.

    Short-circuits with a direct answer when the extraction node has already
    flagged one; otherwise asks the tool-bound model for the next action.
    """
    messages = state["messages"]

    # Fast path: the extraction step may have already marked the answer.
    # If so, surface it without another LLM round-trip.
    if messages and hasattr(messages[-1], 'content') and isinstance(messages[-1].content, str):
        last_content = messages[-1].content
        marker = "ANSWER FOUND:"
        marker_pos = last_content.find(marker)
        if marker_pos != -1:
            answer_text = last_content[marker_pos + len(marker):].strip()
            response = AIMessage(content=f"Based on the analysis, the answer is: {answer_text}")
            logger.debug(f"Router Node - Extracted Answer: {answer_text}")
            logger.debug(f"Router Node - Response: {[response]}")
            return {"messages": [response]}

    # Build the routing prompt, advertising the available tool names.
    router_prompt = ROUTER_PROMPT.format(tool_names=[t.name for t in toolbelt])

    # When a file is attached, insist on the exact path so the model does not
    # mangle it while writing tool calls.
    file_path = state.get("file_path")
    if file_path:
        router_prompt += f"\n\n**IMPORTANT: There is an attached file at: {file_path}**"
        router_prompt += f"\nYou MUST use the appropriate tool to analyze this file."
        router_prompt += f"\n**CRITICAL:** Use this EXACT file path in your code: {file_path}"
        router_prompt += f"\nDO NOT modify or simplify the path. Copy and paste it exactly as shown above."

    # System prompt goes first, followed by the running conversation.
    context = [HumanMessage(content=router_prompt)] + messages
    logger.debug(f"Router Node - Input Messages: {[msg.content for msg in context]}")

    response = get_model_with_tools().invoke(context)

    logger.debug(f"Router Node - Output Message: {response.content}")
    logger.debug(f"Router Node - Response: {[response]}")
    return {"messages": [response]}

def reflection_node(state: AgentState) -> dict:
    """Node for self-reflection and error correction.

    Feeds the failed tool call and its error text back to the model so it
    can propose a corrected next step.
    """
    question = state["question"]
    messages = state["messages"]
    last_message = messages[-1]
    logger.debug(f"Reflection Node - Last Message: {last_message}")

    # Pull the failing call (if recorded) and the error text from the
    # message that triggered reflection.
    failed_calls = last_message.additional_kwargs.get("tool_calls", [])
    failure_text = last_message.content

    reflection_prompt = REFLECTION_PROMPT.format(
        question=question,
        messages=messages[:-1],  # context excludes the error message itself
        tool_call=failed_calls,
        error_message=failure_text,
    )

    response = get_model_with_tools().invoke(reflection_prompt)
    logger.debug(f"Reflection Node - Response: {[response]}")
    return {"messages": [response]}

def information_extraction_node(state: AgentState) -> dict:
    """Extracts only relevant information from tool results based on the original question.

    Reads the most recent tool result, asks the extraction model to summarize
    it, and emits an AIMessage that either signals "answer found" (so the
    router can stop) or that searching should continue. On extraction failure
    the error is logged and no message is added (best-effort by design).
    """
    question = state["question"]
    messages = state["messages"]

    # Get the last tool result
    last_message = messages[-1]

    # Only process if the last message is a tool result.
    # NOTE(review): this relies on tool results being the only messages with
    # a `name` attribute — confirm other message classes used here don't
    # also expose `name`, or tighten to isinstance(last_message, ToolMessage).
    if not hasattr(last_message, 'name'):
        return {"messages": []}  # No changes needed

    tool_result = last_message.content

    # Create extraction prompt
    extraction_prompt = f"""You are an information extraction assistant. Your job is to read the result from a tool and determine if the original question can be answered.

Original Question: {question}
Tool Result: {tool_result}

**Instructions:**
1.  **Summarize Key Facts:** List the most important facts you found in the tool result.
2.  **Assess Progress:** Can the original question be fully answered with the information found? 
3.  **Decision:** If you have sufficient information to answer the question, start your response with "ANSWER FOUND:" followed by the answer. If not, start with "CONTINUE SEARCHING:" and explain what's still needed.

**Your Output:**
[Your analysis and decision]
"""

    try:
        response = get_extraction_model().invoke(extraction_prompt)
        extracted_info = response.content

        # Check if the extraction indicates the answer has been found
        if "ANSWER FOUND:" in extracted_info:
            # Signal that we have the answer and should stop
            extraction_message = AIMessage(
                content=f"Key Information Extracted: {extracted_info}\n\nI have found sufficient information to answer the question. No more searching needed."
            )
        else:
            # Continue normal extraction
            extraction_message = AIMessage(
                content=f"Key Information Extracted: {extracted_info}"
            )

        logger.debug(f"Information Extraction Node - Original: {tool_result[:200]}...")
        logger.debug(f"Information Extraction Node - Extracted: {extracted_info}")

        return {"messages": [extraction_message]}

    except Exception as e:
        logger.error(f"Information extraction error: {e}")
        return {"messages": []}  # Return empty if extraction fails

# This is a pre-built node from LangGraph that executes tools.
# It runs whichever tool the last AIMessage's tool_calls request and
# appends the result to the message list as a ToolMessage.
tool_node = ToolNode(toolbelt)

def final_formatting_node(state: AgentState) -> dict:
    """Extracts and formats the final answer.

    Condenses the full message history into a compact text transcript
    (dropping tool-call metadata and irrelevant search noise), asks the
    formatter model for the final answer, and returns it under the
    ``final_answer`` state key.
    """
    question = state["question"]
    messages = state["messages"]

    logger.debug(f"Final Formatting Node - Received {len(messages)} messages")

    def extract_text_content(content):
        """Helper function to extract text from structured content."""
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            # Handle structured content with text and tool_use blocks
            text_parts = []
            for item in content:
                if isinstance(item, dict) and item.get('type') == 'text':
                    text_parts.append(item.get('text', ''))
            return ' '.join(text_parts) if text_parts else str(content)
        else:
            return str(content)

    # Terms marking a line as off-topic noise for this run. Hoisted out of
    # the per-line loop so the list is not rebuilt for every line of every
    # tool result.
    irrelevant_terms = ['arthritis', 'αλφουζοσίνη', 'ουρικό', 'stohr', 'bischof',
                        'vatikan', 'google arama', 'whatsapp', 'calculatrice',
                        'hotmail', 'sfr mail', 'orange.pl', '50 cent', 'rapper',
                        'flight delay', 'ec 261', 'insurance', 'generali']

    def extract_key_info(text):
        """Extract only relevant information from search results."""
        lines = text.split('\n')
        key_info = []

        for line in lines:
            line = line.strip()
            # Skip empty lines and obvious metadata
            if not line or line.startswith('---'):
                continue
            if line.startswith('URL:'):
                continue

            # Skip web search result headers but keep the content
            if line.startswith('Web Search Result') and ':' in line:
                # Extract the content after the colon
                parts = line.split(':', 1)
                if len(parts) > 1:
                    content = parts[1].strip()
                    if content:
                        line = content
                else:
                    continue

            # Skip completely irrelevant content
            if any(term in line.lower() for term in irrelevant_terms):
                continue

            key_info.append(line)

        return '\n'.join(key_info) if key_info else ''

    # Format messages for better readability by the final formatter.
    # Filter out tool calls and metadata, keep only essential reasoning.
    formatted_messages = []

    # First, add the original question
    formatted_messages.append(f"Question: {question}")

    for i, msg in enumerate(messages):
        logger.debug(f"Processing message {i}: type={type(msg).__name__}, has_content={hasattr(msg, 'content')}")

        if hasattr(msg, 'content'):
            msg_type = type(msg).__name__

            if msg_type == "HumanMessage":
                # Skip the first human message as we already added the question
                if i > 0:
                    text_content = extract_text_content(msg.content)
                    if text_content.strip():
                        formatted_messages.append(f"Human: {text_content}")

            elif msg_type == "AIMessage":
                # Handle AI messages
                text_content = extract_text_content(msg.content)

                # Check if this is an extraction result
                if "Key Information Extracted:" in text_content:
                    formatted_messages.append(f"Extracted: {text_content}")
                elif hasattr(msg, 'tool_calls') and msg.tool_calls:
                    # AI message with tool calls - include the reasoning
                    if text_content.strip():
                        formatted_messages.append(f"AI Reasoning: {text_content}")
                    # Add tool call info
                    for tool_call in msg.tool_calls:
                        tool_name = tool_call.get('name', 'Unknown')
                        formatted_messages.append(f"Tool Called: {tool_name}")
                else:
                    # Regular AI message
                    if text_content.strip():
                        formatted_messages.append(f"AI: {text_content}")

            elif msg_type == "ToolMessage" or hasattr(msg, 'name'):
                # Handle tool messages
                tool_name = getattr(msg, 'name', 'Unknown Tool')
                text_content = extract_text_content(msg.content)

                # Extract key information from tool results; truncate to keep
                # the transcript within a reasonable prompt budget.
                key_info = extract_key_info(text_content)
                if key_info:
                    formatted_messages.append(f"Tool Result ({tool_name}): {key_info[:500]}...")
                else:
                    # If no key info extracted, include a summary
                    formatted_messages.append(f"Tool Result ({tool_name}): [No relevant information found]")

    conversation_history = "\n".join(formatted_messages)

    # Log the conversation history for debugging
    logger.debug(f"Final Formatting Node - Conversation History Length: {len(conversation_history)}")
    logger.debug(f"Final Formatting Node - Conversation History Preview: {conversation_history[:500]}...")

    # If conversation history is still empty, create a minimal one
    if not conversation_history or conversation_history.strip() == f"Question: {question}":
        logger.warning("Conversation history is empty or minimal, constructing from raw messages")
        conversation_history = f"Question: {question}\n"
        for msg in messages[1:]:  # Skip first message (the question)
            if hasattr(msg, 'content'):
                content = str(msg.content)[:200]
                conversation_history += f"\n{type(msg).__name__}: {content}..."

    prompt = FINAL_ANSWER_FORMATTER_PROMPT.format(question=question, messages=conversation_history)
    response = get_final_formatter_model().invoke(prompt)

    # With Extended Thinking enabled the response content is a list of
    # blocks; pull out the first text block as the answer.
    if isinstance(response.content, list):
        # Find the text content from the structured response
        text_content = ""
        for item in response.content:
            if isinstance(item, dict) and item.get('type') == 'text':
                text_content = item.get('text', '')
                break
        final_answer = text_content
    else:
        # Fallback for simple string responses
        final_answer = response.content

    logger.debug(f"Final Formatting Node - Generated Answer: {final_answer[:100]}...")

    return {"final_answer": final_answer}

### CONDITIONAL EDGE LOGIC ###

def should_continue(state: AgentState) -> str:
    """Determines the next step after the router or reflection node.

    Returns:
        "use_tool" when the last message requests a tool invocation,
        "end" when an answer has been found or no tool call was made.
    """
    last_message = state["messages"][-1]

    # Stop as soon as any node has signalled that the answer is known.
    if hasattr(last_message, 'content') and isinstance(last_message.content, str):
        if "I have found sufficient information to answer the question" in last_message.content:
            return "end"
        if "ANSWER FOUND:" in last_message.content:
            return "end"

    # If the model produced a tool call, we execute it. getattr guards
    # against message types without a `tool_calls` attribute, which would
    # otherwise raise AttributeError here.
    if getattr(last_message, 'tool_calls', None):
        return "use_tool"

    # If there are no tool calls, we are done
    return "end"

def after_tool_use(state: AgentState) -> str:
    """Determines the next step after a tool has been used.

    Failed tool calls (an "Error:" marker inside the ToolMessage) route to
    reflection; successful ones proceed to information extraction.
    """
    last_message = state["messages"][-1]
    logger.debug(f"After Tool Use - Last Message: {last_message}")

    tool_failed = isinstance(last_message, ToolMessage) and "Error:" in last_message.content
    return "reflect" if tool_failed else "extract"

def after_extraction(state: AgentState) -> str:
    """Determines the next step after information extraction.

    Extraction unconditionally hands control back to the router, which
    then decides whether to keep searching or finish.
    """
    return "continue"

### GRAPH ASSEMBLY ###

def create_agent_graph() -> tuple:
    """Builds and compiles the agent state machine.

    Wires router -> tool execution -> extraction/reflection loops, with the
    router deciding when to stop.

    Returns:
        A 2-tuple of (compiled agent graph, final_formatting_node callable).
        The formatter is returned separately rather than being part of the
        graph. (The original annotation claimed ``StateGraph``, which did
        not match the actual tuple return value.)
    """
    graph = StateGraph(AgentState)

    # Add nodes to the graph
    graph.add_node("router", router_node)
    graph.add_node("tool_node", tool_node)
    graph.add_node("extraction", information_extraction_node)
    graph.add_node("reflector", reflection_node)

    # Define the graph's entry point
    graph.set_entry_point("router")

    # Router either invokes a tool or finishes.
    graph.add_conditional_edges(
        "router",
        should_continue,
        {
            "use_tool": "tool_node",
            "end": END
        }
    )

    # Successful tool runs go to extraction; failures go to reflection.
    graph.add_conditional_edges(
        "tool_node",
        after_tool_use,
        {
            "extract": "extraction",
            "reflect": "reflector"
        }
    )

    # After extraction, always go back to router
    graph.add_conditional_edges(
        "extraction",
        after_extraction,
        {
            "continue": "router"
        }
    )

    # Reflection re-enters the same continue/stop decision as the router.
    graph.add_conditional_edges(
        "reflector",
        should_continue,
        {
            "use_tool": "tool_node",
            "end": END
        }
    )

    # Compile the graph into a runnable object
    agent_graph = graph.compile()

    return agent_graph, final_formatting_node