Spaces:
Sleeping
Sleeping
| import os | |
| import asyncio | |
| from datetime import datetime | |
| from typing import TypedDict, Sequence, Optional, Annotated, List | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessageChunk | |
| from langchain_core.runnables import RunnableConfig | |
| from langgraph.graph import StateGraph, MessagesState, END, START | |
| from langgraph.graph.message import add_messages | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from utils.model_loader import ModelLoader | |
| from prompt_library.prompt import SYSTEM_PROMPT | |
| from tools.weather_info_tool import WeatherInfoTool | |
| from tools.place_search_tool import PlaceSearchTool | |
| from tools.expense_calculator_tool import CalculatorTool | |
| from tools.currency_conversion_tool import CurrencyConverterTool | |
class AgentState(TypedDict):
    """Graph state holding the conversation as a sequence of messages."""

    # `add_messages` is attached as the reducer metadata for this key, so
    # node updates are combined with the existing messages by that function.
    messages: Annotated[Sequence[BaseMessage], add_messages]
class GraphBuilder:
    """Assemble and run a tool-calling agent as a LangGraph state machine.

    The graph has two nodes: ``agent`` (the LLM bound to the registered
    tools) and ``tools`` (a ``ToolNode`` over the same tools). The agent
    node runs first; while its latest message carries tool calls the graph
    routes to ``tools`` and then back to ``agent``, otherwise it ends.
    """

    def __init__(self, model_provider: str = "openai"):
        """Load the LLM, collect the tools, and bind them to the model.

        Args:
            model_provider: Provider name forwarded to ``ModelLoader``.
        """
        self.model_loader = ModelLoader(model_provider=model_provider)
        self.llm = self.model_loader.load_llm()

        # Tool providers; each one exposes a list attribute of tools.
        self.weather_tools = WeatherInfoTool()
        self.place_search_tools = PlaceSearchTool()
        self.calculator_tools = CalculatorTool()
        self.currency_converter_tools = CurrencyConverterTool()

        # Flat registry of every tool, shared by the LLM binding and ToolNode.
        self.tools = [
            *self.weather_tools.weather_tool_list,
            *self.place_search_tools.place_search_tool_list,
            *self.calculator_tools.calculator_tool_list,
            *self.currency_converter_tools.currency_converter_tool_list,
        ]

        self.llm_with_tools = self.llm.bind_tools(tools=self.tools)
        self.graph = None  # compiled lazily by build_graph()
        self.system_prompt = SYSTEM_PROMPT

    async def async_agent_function(self, state: "MessagesState"):
        """Run the LLM on the conversation, echoing the streamed text.

        Args:
            state: Graph state; ``state["messages"]`` is the history so far.

        Returns:
            A state update ``{"messages": [message]}`` where ``message`` is
            the concatenation of every streamed chunk.

        Raises:
            RuntimeError: If the model stream yields no chunks at all.
        """
        prompt = [self.system_prompt, *state["messages"]]

        merged = None
        async for chunk in self.llm_with_tools.astream(prompt):
            # Chunks are deltas: only their sum carries the full text and the
            # complete tool calls, so accumulate instead of keeping the last.
            merged = chunk if merged is None else merged + chunk
            if isinstance(chunk, AIMessageChunk):
                content = chunk.content or ""
                print(f"[LLM] {content}", end="", flush=True)
        print()  # newline after full message

        if merged is None:
            raise RuntimeError("LLM stream produced no chunks")
        return {"messages": [merged]}

    def should_continue(self, state: "AgentState") -> str:
        """Decide whether to route to the tools node or end the graph.

        Args:
            state: Graph state; only the last entry of ``state["messages"]``
                is inspected.

        Returns:
            ``"continue"`` when the last message has a non-empty
            ``tool_calls`` attribute, ``"end"`` otherwise (including when
            there are no messages or the message has no such attribute).
        """
        messages = state["messages"]
        if not messages:
            return "end"
        # Not every message type defines tool_calls; treat absence as "none".
        pending_calls = getattr(messages[-1], "tool_calls", None)
        return "continue" if pending_calls else "end"

    def build_graph(self):
        """Build and compile the agent/tools state machine.

        Returns:
            The compiled graph, which is also stored on ``self.graph``.
        """
        graph_builder = StateGraph(MessagesState)

        # Nodes
        graph_builder.add_node("agent", self.async_agent_function)
        graph_builder.add_node("tools", ToolNode(tools=self.tools))

        # Transitions: START -> agent -> (tools -> agent)* -> END
        graph_builder.add_edge(START, "agent")
        graph_builder.add_conditional_edges(
            source="agent",
            path=self.should_continue,
            path_map={
                "end": END,
                "continue": "tools",
            },
        )
        graph_builder.add_edge("tools", "agent")

        self.graph = graph_builder.compile()
        return self.graph

    def _ensure_graph(self):
        """Return the compiled graph, building it on first use."""
        if self.graph is None:
            self.build_graph()
        return self.graph

    @staticmethod
    def _iter_event_messages(event):
        """Yield the messages carried by one streamed graph event.

        Two event shapes are accepted: a state-like mapping with a top-level
        ``"messages"`` key, and a mapping of node name to that node's update
        (``{"agent": {"messages": [...]}}``), which is what the graph's
        default stream mode produces. Anything else yields nothing.
        """
        if not isinstance(event, dict):
            return
        if "messages" in event:
            yield from event["messages"] or ()
            return
        for node_update in event.values():
            if isinstance(node_update, dict):
                yield from node_update.get("messages") or ()

    async def stream(self, user_input: str):
        """Run the graph on ``user_input`` and print each produced message.

        Args:
            user_input: Text of the user's message.
        """
        graph = self._ensure_graph()
        print(f"[User] {user_input}\n")
        async for event in graph.astream(
            {"messages": [HumanMessage(content=user_input)]},
            config=RunnableConfig(),
        ):
            for msg in self._iter_event_messages(event):
                print(f"\n[Tool/Final] {msg.content}\n")

    async def __call__(self, user_input: str) -> str:
        """Run the graph on ``user_input`` and return the last message content.

        Args:
            user_input: Text of the user's message.

        Returns:
            Content of the last message emitted by the graph, or ``""`` if
            the graph emitted none.
        """
        graph = self._ensure_graph()
        final_output = ""
        async for event in graph.astream(
            {"messages": [HumanMessage(content=user_input)]},
            config=RunnableConfig(),
        ):
            for msg in self._iter_event_messages(event):
                final_output = msg.content  # only keep last
                print(f"[Tool/LLM] {msg.content}")
        return final_output