Spaces:
Build error
Build error
| from typing import Dict, List, cast | |
| from langchain_core.messages import AIMessage | |
| from langgraph.graph import StateGraph | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from src.config import Configuration | |
| from src.model import GoogleModel | |
| from src.state import InputState, State | |
| from src.tools import TOOLS | |
class GaiaAgent:
    """Agent that alternates between an LLM call and tool execution.

    The LangGraph graph is built and compiled once, at construction time, and
    stored on ``self.graph``.
    """

    def __init__(self):
        self.graph = self._build_graph()

    def _build_graph(self) -> StateGraph:
        """Assemble and compile the agent graph.

        The graph has two nodes that are cycled between:

        * ``call_model`` -- invokes the LLM (see ``_call_model``).
        * ``tools`` -- a ``ToolNode`` that executes tool calls using ``TOOLS``.

        Execution starts at ``call_model``. After each model call,
        ``tools_condition`` routes to ``tools`` when the latest message
        requests a tool; otherwise the run ends with the model's direct
        response. Tool output is always routed back to ``call_model``.

        Returns:
            The graph produced by ``builder.compile(...)``.
            NOTE(review): the annotation says ``StateGraph``, but ``compile()``
            returns a compiled graph object rather than the builder itself.
        """
        builder = StateGraph(State, input=InputState, config_schema=Configuration)

        # Define the two nodes we will cycle between.
        builder.add_node("call_model", self._call_model)
        builder.add_node("tools", ToolNode(TOOLS))

        # `call_model` is the entry point: it is the first node called.
        builder.add_edge("__start__", "call_model")

        # If the latest message requires a tool, route to `tools`;
        # otherwise finish with the model's direct response.
        builder.add_conditional_edges(
            "call_model",
            tools_condition,
        )

        # Tool results are always handed back to the model.
        builder.add_edge("tools", "call_model")

        graph = builder.compile(name="GAIA Agent", debug=False)
        return graph

    def _call_model(self, state: State) -> Dict[str, List[AIMessage]]:
        """Call the LLM powering the agent and return its reply.

        Reads the run configuration via ``Configuration.from_context()``,
        builds a ``GoogleModel`` with ``TOOLS`` bound, and prepends the
        configured system prompt to the conversation before invoking the model.
        When the state has a ``file_name``, the task id is appended to the
        system prompt so the model can download the file.

        Args:
            state: The current graph state. ``messages``, ``file_name``,
                ``task_id`` and ``is_last_step`` are read.

        Returns:
            ``{"messages": [message]}``. The message is returned as a list so
            it is added to the existing messages. On the last allowed step,
            if the model still requests a tool, a fixed apology message (with
            the response's id) is returned instead of the model's response.
        """
        configuration = Configuration.from_context()

        # Initialize the model with tool binding. Change the model or add more tools here.
        model = GoogleModel(
            model=configuration.google_model,
            temperature=configuration.temperature,
            tools=TOOLS,
        )

        # Format the system prompt. Customize this to change the agent's behavior.
        system_message = configuration.system_prompt
        if state.file_name:
            file_prompt = (
                f"\n\nThe task id is {state.task_id}.\n"
                "Please use this to download the file."
            )
            system_message += file_prompt

        # Get the model's response.
        response = cast(
            AIMessage,
            model.llm.invoke(
                [
                    {"role": "system", "content": system_message},
                    *state.messages,
                ]
            ),
        )

        # Out of steps but the model still wants a tool: stop with a fixed
        # message instead of emitting tool calls that can no longer be run.
        if state.is_last_step and response.tool_calls:
            return {
                "messages": [
                    AIMessage(
                        id=response.id,
                        content="Sorry, I could not find an answer to your question in the specified number of steps.",
                    )
                ]
            }

        # Return the model's response as a list to be added to existing messages.
        return {"messages": [response]}