import operator
from typing import Annotated, Sequence, TypedDict

from langchain import hub
from langchain.agents import create_react_agent
from langchain_community.llms import HuggingFaceHub
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.graph import END, StateGraph

from ai_tools import get_tools
|
|
class AgentState(TypedDict):
    """Shared state flowing through the LangGraph workflow.

    Both fields are annotated with ``operator.add`` as their reducer, so
    values returned by graph nodes are *appended* to the existing lists
    rather than replacing them.
    """

    # Conversation history; node updates are concatenated onto this list.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # ReAct scratchpad entries (agent outcomes / tool observations), accumulated.
    intermediate_steps: Annotated[list, operator.add]
|
|
def build_graph():
    """Build and compile a ReAct agent workflow as a LangGraph state machine.

    The graph alternates between an ``agent`` node (the LLM deciding the next
    action) and a ``tool`` node (executing that action), looping until the
    agent emits an ``AgentFinish``.

    Returns:
        A compiled LangGraph runnable that accepts an ``AgentState`` dict
        with ``messages`` and ``intermediate_steps`` keys.
    """
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={"temperature": 0.1, "max_new_tokens": 500},
    )

    # Standard ReAct prompt; expects "input" and "agent_scratchpad" variables.
    prompt = hub.pull("hwchase17/react")
    tools = get_tools()
    agent = create_react_agent(llm, tools, prompt)

    def _latest_outcome(state: AgentState):
        """Return the most recent agent outcome, unwrapping a list if needed."""
        last_step = state["intermediate_steps"][-1]
        return last_step[0] if isinstance(last_step, list) else last_step

    def agent_node(state: AgentState):
        """Run the ReAct agent one step and record its outcome."""
        # Only (action, observation) tuples belong in the ReAct scratchpad;
        # bare AgentAction entries (stored for routing) are filtered out so
        # the prompt formatter receives the shape it expects.
        scratchpad = [
            step for step in state["intermediate_steps"]
            if isinstance(step, tuple)
        ]
        question = state["messages"][-1].content  # renamed: avoid shadowing builtin `input`
        outcome = agent.invoke({
            "input": question,
            "intermediate_steps": scratchpad,
        })
        update = {"intermediate_steps": [outcome]}
        # Bug fix: when the agent finishes, surface the final answer as an
        # AIMessage so callers reading `messages[-1]` get the answer rather
        # than the original question.
        finish = outcome[0] if isinstance(outcome, list) else outcome
        if isinstance(finish, AgentFinish):
            update["messages"] = [
                AIMessage(content=str(finish.return_values.get("output", "")))
            ]
        return update

    def tool_node(state: AgentState):
        """Execute the tool requested by the agent's latest action."""
        action = _latest_outcome(state)

        if not isinstance(action, AgentAction):
            return {"messages": [AIMessage(content="Invalid action format")]}

        tool = next((t for t in tools if t.name == action.tool), None)
        if not tool:
            return {"messages": [AIMessage(content=f"Tool {action.tool} not found")]}

        observation = tool.run(action.tool_input)
        # Bug fix: record the (action, observation) pair so the next agent
        # turn actually sees the tool result in its scratchpad.
        return {
            "messages": [AIMessage(content=str(observation))],
            "intermediate_steps": [(action, str(observation))],
        }

    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tool", tool_node)

    def route_action(state: AgentState):
        """Route to END once the agent finishes; otherwise run the tool."""
        if isinstance(_latest_outcome(state), AgentFinish):
            return END
        return "tool"

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        route_action,
        {"tool": "tool", END: END},
    )
    workflow.add_edge("tool", "agent")

    return workflow.compile()
|
|
class BasicAgent:
    """Thin callable wrapper around the compiled LangGraph agent workflow."""

    def __init__(self):
        print("BasicAgent initialized.")
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Run the workflow on *question* and return the final answer text."""
        print(f"Agent received question: {question[:50]}...")
        initial_state = {
            "messages": [HumanMessage(content=question)],
            "intermediate_steps": [],
        }
        final_state = self.graph.invoke(initial_state)

        # The last message in the final state holds the agent's reply.
        answer = final_state["messages"][-1].content
        return answer.strip()