File size: 5,811 Bytes
49c347b 6c44e19 715a633 49c347b f247400 715a633 49c347b 6c44e19 49c347b 715a633 6c44e19 715a633 758564e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
from typing import Iterable, Optional
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from src.schemas import ComplexityLevel, ExecutionReport, PlannerPlan
from src.prompts.prompts import COMPLEXITY_ASSESSOR_PROMPT
from src.state import AgentState
def log_stage(title: str, subtitle: Optional[str] = None, icon: str = "🚀") -> None:
    """Render a banner for the current execution stage.

    Args:
        title: Stage name; surrounding whitespace is stripped.
        subtitle: Optional extra line printed under the banner.
        icon: Prefix glyph for every banner line.
    """
    title_line = f" {title.strip()} "
    # Compute the horizontal rule once (the original built it twice);
    # short titles are padded to a minimum width of 20 rule characters.
    rule = "═" * max(len(title_line), 20)
    print(f"\n{icon} {rule}\n{icon} {title_line}\n{icon} {rule}")
    if subtitle:
        print(f"{icon} {subtitle}")
def log_key_values(pairs: Iterable[tuple[str, str]]) -> None:
    """Pretty-print simple key/value diagnostics, one bullet per pair."""
    for label, detail in pairs:
        print(f"  • {label}: {detail}")
def format_plan_overview(plan: PlannerPlan) -> str:
    """Create a human-readable summary of plan steps.

    Returns a one-line-per-step listing, or a placeholder when the plan
    is empty (direct-response case).
    """
    if not plan or not plan.steps:
        return "(no steps – direct response)"
    return "\n".join(
        f"{step.id}: {step.goal} [{step.tool if step.tool else 'no tool'}]"
        for step in plan.steps
    )
def display_plan(plan: PlannerPlan) -> None:
    """Print plan contents in a compact, readable form.

    Emits the banner, plan metadata, per-step details, and optional
    answer guidelines in that fixed order.
    """
    log_stage("PLANNER OUTPUT", icon="🧭")
    print(f"Task type: {plan.task_type}")
    print(f"Summary: {plan.summary}")
    if plan.assumptions:
        print("Assumptions:")
        for assumption in plan.assumptions:
            print(f"  - {assumption}")
    print("Steps:")
    for step in plan.steps:
        print(f"  {step.id} → {step.goal}")
        # A step may have no tool bound; show an explicit placeholder.
        print(f"     tool: {step.tool}" if step.tool else "     tool: (none)")
        if step.inputs:
            print(f"     inputs: {step.inputs}")
        print(f"     expected: {step.expected_result}")
        if step.on_fail:
            print(f"     on_fail: {step.on_fail}")
    if plan.answer_guidelines:
        print(f"Answer guidelines: {plan.answer_guidelines}")
def clean_message_history(messages):
    """Remove incomplete tool_call/response cycles from a message history.

    An AIMessage carrying ``tool_calls`` is kept only when every one of its
    tool-call ids has a matching ToolMessage in the run of ToolMessages that
    immediately follows it. Incomplete blocks (the AIMessage plus any partial
    responses) are dropped entirely; all other messages pass through as-is.
    """
    kept = []
    idx = 0
    total = len(messages)
    while idx < total:
        current = messages[idx]
        calls = getattr(current, 'tool_calls', None)
        if not calls:
            # Ordinary message — keep it and advance.
            kept.append(current)
            idx += 1
            continue
        expected_ids = {call['id'] for call in calls}
        # Scan the contiguous run of ToolMessages directly after this message.
        cursor = idx + 1
        answered = set()
        while cursor < total and isinstance(messages[cursor], ToolMessage):
            if messages[cursor].tool_call_id in expected_ids:
                answered.add(messages[cursor].tool_call_id)
            cursor += 1
        if answered == expected_ids:
            # Complete cycle: keep the AIMessage and its trailing ToolMessages.
            kept.extend(messages[idx:cursor])
        else:
            # Partial cycle: discard the whole block.
            print(f"Removing incomplete tool call block: {expected_ids - answered}")
        idx = cursor
    return kept
def format_final_answer(report: "ExecutionReport", complexity: "ComplexityLevel") -> str:
    """Format the final answer based on complexity and report content.

    Args:
        report: Execution report exposing ``final_answer``, ``query_summary``,
            ``key_findings``, ``data_sources`` and ``limitations``.
        complexity: Assessment exposing a ``level`` attribute. (The previous
            ``dict`` annotation was wrong — the code uses attribute access,
            not key lookup.)

    Returns:
        The bare answer for simple queries; for complex ones, the answer
        plus summary, findings, and optional sources/limitations sections.
    """
    if complexity.level == 'simple':
        # Simple queries get only the answer line.
        return f"FINAL ANSWER: {report.final_answer}"
    # Complex queries: build bullet lists up front instead of the chr(10) hack.
    findings = "\n".join(f"• {finding}" for finding in report.key_findings)
    formatted = (
        f"FINAL ANSWER: {report.final_answer}\n"
        f"SUMMARY:\n{report.query_summary}\n"
        f"KEY FINDINGS:\n{findings}"
    )
    if report.data_sources:
        # Cap at five sources to keep the answer compact.
        sources = "\n".join(f"• {source}" for source in report.data_sources[:5])
        formatted += f"\nSOURCES:\n{sources}"
    if report.limitations:
        limitations = "\n".join(f"• {limitation}" for limitation in report.limitations)
        formatted += f"\nLIMITATIONS:\n{limitations}"
    return formatted
def complexity_assessor(state: AgentState) -> AgentState:
    """Assess query complexity and determine if planning is needed.

    Invokes a structured-output LLM over the user query and returns the
    assessment plus the prompt messages appended to the running history.
    """
    print("=== COMPLEXITY ASSESSMENT ===")
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.25)
    structured_llm = llm.with_structured_output(ComplexityLevel)
    prompt_messages = [
        SystemMessage(content=COMPLEXITY_ASSESSOR_PROMPT.strip()),
        HumanMessage(content=f"Query: {state['query']}"),
    ]
    result = structured_llm.invoke(prompt_messages)
    # Surface the three assessment fields as diagnostics.
    for label, value in (("Complexity", result.level),
                         ("Needs planning", result.needs_planning),
                         ("Reasoning", result.reasoning)):
        print(f"{label}: {value}")
    return {
        "complexity_assessment": result,
        "messages": state["messages"] + prompt_messages,
    }
def trim(s: str, max_len: int = 10_000) -> str:
    """Truncate *s* to *max_len* characters, appending a truncation marker.

    Falsy inputs (empty string, None) are returned unchanged.
    """
    if not s or len(s) <= max_len:
        return s
    return s[:max_len] + "... [truncated]"
|