File size: 5,811 Bytes
49c347b
6c44e19
715a633
49c347b
f247400
 
 
715a633
49c347b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c44e19
 
49c347b
 
 
 
 
 
 
 
 
715a633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c44e19
715a633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758564e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from typing import Iterable, Optional
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from src.schemas import ComplexityLevel, ExecutionReport, PlannerPlan
from src.prompts.prompts import COMPLEXITY_ASSESSOR_PROMPT
from src.state import AgentState

def log_stage(title: str, subtitle: Optional[str] = None, icon: str = "🚀") -> None:
    """Print a framed banner announcing the current execution stage.

    The banner is a rule line, the space-padded title, and the same rule line
    again, each prefixed with ``icon``. The rule is at least 20 characters wide.
    When ``subtitle`` is truthy it is printed on one more icon-prefixed line.
    """
    padded = f" {title.strip()} "
    rule = icon + " " + "═" * max(len(padded), 20)
    banner_lines = (rule, f"{icon} {padded}", rule)
    print("\n" + "\n".join(banner_lines))
    if subtitle:
        print(f"{icon} {subtitle}")


def log_key_values(pairs: Iterable[tuple[str, str]]) -> None:
    """Print every (key, value) pair as an indented bullet line ``key: value``."""
    for entry in pairs:
        label, shown = entry
        print("   •", f"{label}: {shown}")


def format_plan_overview(plan: PlannerPlan) -> str:
    """Summarise a plan as one ``"<id>: <goal> [<tool>]"`` line per step.

    Steps without a tool are tagged ``[no tool]``. A placeholder string is
    returned when ``plan`` is falsy or has no steps.
    """
    steps = plan.steps if plan else None
    if not steps:
        return "(no steps – direct response)"
    return "\n".join(
        f"{step.id}: {step.goal} [{step.tool or 'no tool'}]" for step in steps
    )


def display_plan(plan: PlannerPlan) -> None:
    """Print plan contents in a compact, readable form.

    Output order:

    1. A "PLANNER OUTPUT" banner (via ``log_stage``).
    2. The task type and summary.
    3. Any assumptions.
    4. One entry per step.
    5. The answer guidelines, when present.

    Each step is shown as ``<id>: <goal>``, followed by indented ``tool``,
    ``inputs``, ``expected`` and ``on_fail`` lines. The ``inputs`` and
    ``on_fail`` lines appear only when set. A plan with no steps prints the
    same placeholder that ``format_plan_overview`` returns.
    """
    log_stage("PLANNER OUTPUT", icon="🧭")
    print(f"Task type: {plan.task_type}")
    print(f"Summary: {plan.summary}")
    if plan.assumptions:
        print("Assumptions:")
        for item in plan.assumptions:
            print(f"   - {item}")
    print("Steps:")
    steps = plan.steps or []
    if not steps:
        # Same placeholder as format_plan_overview for direct-response plans.
        print("   (no steps – direct response)")
    for step in steps:
        # Separate id and goal, matching the "<id>: <goal>" layout of format_plan_overview.
        print(f"   {step.id}: {step.goal}")
        print(f"      tool: {step.tool if step.tool else '(none)'}")
        if step.inputs:
            print(f"      inputs: {step.inputs}")
        print(f"      expected: {step.expected_result}")
        if step.on_fail:
            print(f"      on_fail: {step.on_fail}")
    if plan.answer_guidelines:
        print(f"Answer guidelines: {plan.answer_guidelines}")


def clean_message_history(messages):
    """
    Remove incomplete tool-call cycles from a message history.

    An ``AIMessage`` that has ``tool_calls`` is kept only if every one of its
    tool-call ids is answered in the run of ``ToolMessage``s immediately after
    it. Otherwise the AI message and that whole run are dropped.

    Some tool messages answer no kept tool call: strays with an unknown id
    inside a run, or tool messages with no preceding tool-call message at all.
    These are dropped as well, so the result never contains a tool response
    without a matching tool call. All other messages keep their original order.

    Args:
        messages: List of chat messages, in conversation order.

    Returns:
        A new list of the retained messages; the input list is not modified.
    """
    cleaned_messages = []
    total = len(messages)
    i = 0

    while i < total:
        msg = messages[i]
        tool_calls = getattr(msg, "tool_calls", None)

        if not tool_calls:
            if isinstance(msg, ToolMessage):
                # A tool response with no tool-call message directly before its run.
                print(f"Removing orphan tool message: {getattr(msg, 'tool_call_id', None)}")
            else:
                # Ordinary message: keep as-is.
                cleaned_messages.append(msg)
            i += 1
            continue

        tool_call_ids = {tc['id'] for tc in tool_calls}

        # Collect the contiguous run of ToolMessages that follows the AI message.
        j = i + 1
        while j < total and isinstance(messages[j], ToolMessage):
            j += 1
        responses = messages[i + 1:j]
        matching = [m for m in responses if m.tool_call_id in tool_call_ids]
        found_responses = {m.tool_call_id for m in matching}

        if found_responses == tool_call_ids:
            # Every tool call is answered: keep the AI message and its matching responses.
            cleaned_messages.append(msg)
            cleaned_messages.extend(matching)
            stray = [m.tool_call_id for m in responses if m.tool_call_id not in tool_call_ids]
            if stray:
                print(f"Removing stray tool messages: {stray}")
        else:
            # Drop the incomplete block: the AI message together with its partial responses.
            print(f"Removing incomplete tool call block: {tool_call_ids - found_responses}")
        i = j

    return cleaned_messages

def format_final_answer(
    report: "ExecutionReport",
    complexity: "ComplexityLevel | dict",
    *,
    max_sources: int = 5,
) -> str:
    """Format the final answer based on complexity and report content.

    Args:
        report: Execution report. Reads ``final_answer``, ``query_summary``,
            ``key_findings``, ``data_sources`` and ``limitations``.
        complexity: Complexity assessment whose ``level`` is read. It may be an
            attribute-style object, such as the ``ComplexityLevel`` produced by
            ``complexity_assessor``, or a dict with a ``"level"`` key.
        max_sources: Maximum number of data sources to list (default 5).

    Returns:
        For a ``'simple'`` level, just the ``"FINAL ANSWER: ..."`` line.
        Otherwise that line followed by SUMMARY and KEY FINDINGS sections,
        plus SOURCES and LIMITATIONS sections when those lists are non-empty.
        Sections are separated by a blank line.
    """
    # Accept plain dicts as well as objects exposing ``.level``.
    level = complexity.get("level") if isinstance(complexity, dict) else complexity.level

    if level == 'simple':
        # For simple queries, just return the answer.
        return f"FINAL ANSWER: {report.final_answer}"

    def _bullets(items) -> str:
        """Render items as newline-separated "• item" lines."""
        return "\n".join(f"• {item}" for item in items)

    sections = [
        f"FINAL ANSWER: {report.final_answer}",
        f"SUMMARY:\n{report.query_summary}",
        f"KEY FINDINGS:\n{_bullets(report.key_findings)}",
    ]
    if report.data_sources:
        sections.append(f"SOURCES:\n{_bullets(report.data_sources[:max_sources])}")
    if report.limitations:
        sections.append(f"LIMITATIONS:\n{_bullets(report.limitations)}")
    return "\n\n".join(sections)


def complexity_assessor(state: AgentState) -> AgentState:
    """Assess the complexity of ``state['query']`` with a structured-output LLM call.

    Steps:

    1. Build a two-message prompt: ``COMPLEXITY_ASSESSOR_PROMPT`` (stripped)
       as the system message, and ``"Query: <query>"`` as the human message.
    2. Send it to ``gpt-4o-mini`` (temperature 0.25), parsing the reply into
       ``ComplexityLevel``.
    3. Print the level, the planning flag and the reasoning.

    Returns:
        A state update with the parsed assessment under
        ``"complexity_assessment"``, and ``state["messages"]`` extended with
        the two prompt messages under ``"messages"``.
    """
    print("=== COMPLEXITY ASSESSMENT ===")

    base_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.25)
    structured_llm = base_llm.with_structured_output(ComplexityLevel)

    prompt = [
        SystemMessage(content=COMPLEXITY_ASSESSOR_PROMPT.strip()),
        HumanMessage(content=f"Query: {state['query']}"),
    ]
    result = structured_llm.invoke(prompt)

    print(f"Complexity: {result.level}")
    print(f"Needs planning: {result.needs_planning}")
    print(f"Reasoning: {result.reasoning}")

    history = state["messages"] + prompt
    return {"complexity_assessment": result, "messages": history}


def trim(s: str, max_len: int = 10_000) -> str:
    """Cut ``s`` to its first ``max_len`` characters and append a truncation marker.

    Strings of length ``max_len`` or less are returned unchanged, as are falsy
    values such as ``""`` or ``None``. The marker is appended after the cut, so
    a truncated result is longer than ``max_len``.
    """
    if not s or len(s) <= max_len:
        return s
    head = s[:max_len]
    return head + "... [truncated]"