File size: 3,136 Bytes
be5f49d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# agent_plan_solve/graph/nodes.py

from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from  agent_plan_solve.graph.state import Plan, PlanState
from agent_plan_solve.utils.llm_config import get_openai_llm
from agent_plan_solve.tools.calculator import calculator_tool
from langchain.agents import load_tools
from typing import Literal

# Chat model shared by the planner, the step-execution agent and the
# final-answer chain defined in this module.
llm = get_openai_llm()

# Planner: a chat prompt with a single `{task}` variable, piped into the LLM
# configured to return a structured `Plan`.
planner_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "For the given task, come up with a step-by-step plan.\n"
            "Each step should be self-contained and lead to the final answer.\n"
            "Avoid superfluous steps.",
        ),
        (
            "user",
            "Prepare a plan to solve the following task:\n{task}\n",
        ),
    ]
)

planner = planner_prompt | llm.with_structured_output(Plan)

# Toolset: the "ddg-search", "arxiv" and "wikipedia" tools loaded by name,
# followed by the project's calculator tool.
tools = [
    *load_tools(tool_names=["ddg-search", "arxiv", "wikipedia"], llm=llm),
    calculator_tool,
]

# Execution agent prompt.
# This template is handed to `create_react_agent` as its `prompt`, which renders
# it from the full agent state before every model call. The running message
# history therefore has to be part of the template: without the trailing
# placeholder the model would be shown the same system/user pair on every turn
# and would never see its own tool calls or the tool results.
step_prompt = ChatPromptTemplate.from_messages([
    ("system", 
     "You're a smart assistant that carefully helps solve complex tasks.\n"
     "Use tools to verify facts, compute results, and avoid assumptions."),
    ("user", 
     "TASK:\n{task}\n\nPLAN:\n{plan}\n\nSTEP TO EXECUTE:\n{step}\n"),
    # Optional slot for the agent's accumulated messages (AI tool calls and
    # tool outputs); it renders to nothing on the first turn.
    ("placeholder", "{messages}"),
])

class StepState(AgentState):
    """State schema of the step-execution agent.

    Extends the prebuilt ``AgentState`` with the three values that
    ``step_prompt`` interpolates (``{task}``, ``{plan}`` and ``{step}``).
    """
    # The overall task being solved.
    task: str
    # Text rendering of the plan, as produced by `get_full_plan`.
    plan: str
    # The single plan step the agent is asked to execute in this run.
    step: str

# ReAct-style agent that executes one plan step at a time: it runs on the
# shared LLM with the toolset above, keeps its state in `StepState`, and is
# prompted through `step_prompt`.
execution_agent = create_react_agent(
    llm,
    tools,
    prompt=step_prompt,
    state_schema=StepState,
)

# Final response prompt: takes the task and the plan annotated with step
# results, and asks the model for the final answer.
final_prompt = PromptTemplate.from_template(
    template=(
        "You're a helpful assistant that has executed a plan.\n"
        "Given the results, prepare the final response.\n"
        "TASK:\n{task}\n\nPLAN WITH RESULTS:\n{plan}\nFINAL RESPONSE:\n"
    )
)

# Utility functions
def get_current_step(state: PlanState) -> int:
    """Return how many plan steps have already produced a result.

    The count is the length of ``state["past_steps"]``; a state without that
    key counts as zero. Because steps are executed in order, the value is
    also the zero-based index of the next step to run.
    """
    completed_results = state.get("past_steps", [])
    return len(completed_results)

def get_full_plan(state: PlanState) -> str:
    """Render the plan as text, attaching results to the steps already run.

    Every step of ``state["plan"].steps`` becomes a numbered
    ``"# N. Planned step: ..."`` entry. Steps whose position is below the
    number of recorded results also get a ``"Result: ..."`` line taken from
    ``state["past_steps"]``. Entries are joined with a newline.
    """
    done_count = get_current_step(state)
    rendered = []
    for number, planned in enumerate(state["plan"].steps, start=1):
        entry = f"# {number}. Planned step: {planned}\n"
        # Only steps that were already executed have a stored result.
        if number <= done_count:
            entry += f"Result: {state['past_steps'][number - 1]}\n"
        rendered.append(entry)
    return "\n".join(rendered)

# Node functions
async def _build_initial_plan(state: PlanState) -> PlanState:
    """Graph node: ask the planner for a plan for ``state["task"]``.

    The raw task string is handed to the planner chain as its input (the
    planner prompt declares only the ``{task}`` variable).

    Returns:
        A partial state update of the form ``{"plan": <planner output>}``.
    """
    task = state["task"]
    generated_plan = await planner.ainvoke(task)
    return {"plan": generated_plan}

async def _run_step(state: PlanState) -> PlanState:
    """Graph node: execute the next pending plan step with the agent.

    The agent receives the task, the plan rendered with the results so far,
    and the step at the index returned by ``get_current_step``.

    Returns:
        ``{"past_steps": [<content of the agent's last message>]}``. The
        one-element list is presumably appended to the existing results by a
        reducer on ``PlanState`` (confirm in ``graph/state.py``).
    """
    step_index = get_current_step(state)
    agent_input = {
        "plan": get_full_plan(state),
        "step": state["plan"].steps[step_index],
        "task": state["task"],
    }
    agent_output = await execution_agent.ainvoke(agent_input)
    # The agent's answer for this step is its final message.
    last_message = agent_output["messages"][-1]
    return {"past_steps": [last_message.content]}

async def _get_final_response(state: PlanState) -> PlanState:
    """Graph node: produce the final answer from the executed plan.

    Pipes ``final_prompt`` into the LLM and invokes the chain with the task
    and the plan rendered together with its step results.

    Returns:
        ``{"final_response": <chain output>}``. The chain output is stored
        as returned by the LLM, without extracting its text content.
    """
    answer_chain = final_prompt | llm
    chain_input = {
        "task": state["task"],
        "plan": get_full_plan(state),
    }
    return {"final_response": await answer_chain.ainvoke(chain_input)}

def _should_continue(state: PlanState) -> Literal["run", "response"]:
    """Routing function: choose between running a step and answering.

    Returns:
        ``"run"`` while fewer results are recorded than the plan has steps,
        otherwise ``"response"``.
    """
    completed = get_current_step(state)
    total = len(state["plan"].steps)
    if completed < total:
        return "run"
    return "response"