Katya Beresneva committed on
Commit
6800d28
·
1 Parent(s): b97b365

fix AgentState

Browse files
Files changed (2) hide show
  1. agent.py +100 -15
  2. app.py +2 -0
agent.py CHANGED
@@ -1,8 +1,33 @@
 
 
 
 
 
 
 
 
 
1
 
2
- class KateState(AgentState):
3
- task_id: str
4
- question: str
5
- file_name: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  KATE_AGENT_PROMPT = """
8
  You are Kate's Advanced AI Assistant designed to solve complex tasks efficiently.
@@ -25,19 +50,79 @@ QUESTION: Capital of France? FINAL ANSWER: Paris
25
  QUESTION: 2+2? FINAL ANSWER: 4
26
  """
27
 
 
 
 
 
 
 
28
  class KateMultiModalAgent:
29
  def __init__(self, model_name: str | None = None):
30
- self.model_name = model_name or AGENT_MODEL_NAME
31
- self.llm = self._get_llm()
32
- self.tools = self._get_tools()
33
- self.agent = self._create_agent()
34
-
35
- def _create_agent(self):
36
- return create_react_agent(
37
- self.llm,
38
- tools=self.tools,
39
  state_schema=KateState,
40
  state_modifier=KATE_AGENT_PROMPT,
41
- checkpointer=MemorySaver(),
42
- max_iterations=10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_core.messages import HumanMessage
3
+ from langchain_core.runnables.config import RunnableConfig
4
+ from langgraph.checkpoint.memory import MemorySaver
5
+ from langchain.globals import set_debug
6
+ from langchain.globals import set_verbose
7
+ from langgraph.prebuilt import create_react_agent
8
+ from langgraph.prebuilt import ToolNode
9
+ from langgraph.prebuilt import AgentState
10
 
11
+ from smolagents import DuckDuckGoSearchTool
12
+ from smolagents import PythonInterpreterTool
13
+ from tools import analyze_audio
14
+ from tools import analyze_excel
15
+ from tools import analyze_image
16
+ from tools import analyze_video
17
+ from tools import download_file_for_task
18
+ from tools import read_file_contents
19
+ from tools import search_arxiv
20
+ from tools import search_tavily
21
+ from tools import search_wikipedia
22
+ from tools import SmolagentToolWrapper
23
+ from tools import tavily_extract_tool
24
+ from utils import get_llm
25
+
# API key handed to the LLM provider; the module refuses to import without it.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
if GOOGLE_API_KEY == "":
    raise ValueError("GOOGLE_API_KEY environment variable is not set.")

# Default model name, overridable through the AGENT_MODEL_NAME variable.
AGENT_MODEL_NAME = os.environ.get("AGENT_MODEL_NAME", "gemini-2.0-flash")
31
 
32
  KATE_AGENT_PROMPT = """
33
  You are Kate's Advanced AI Assistant designed to solve complex tasks efficiently.
 
50
  QUESTION: 2+2? FINAL ANSWER: 4
51
  """
52
 
class KateState(AgentState):
    """Agent graph state: the prebuilt ``AgentState`` plus per-task fields."""

    # Identifier of the task being answered.
    task_id: str
    # Text of the question put to the agent.
    question: str
    # File name supplied with the task.
    file_name: str
57
+
58
+
59
  class KateMultiModalAgent:
60
  def __init__(self, model_name: str | None = None):
61
+ if model_name is None:
62
+ model_name = AGENT_MODEL_NAME
63
+ llm = self._get_llm(model_name)
64
+ tools = self._get_tools()
65
+ self.agent = create_react_agent(
66
+ llm,
67
+ tools=tools,
 
 
68
  state_schema=KateState,
69
  state_modifier=KATE_AGENT_PROMPT,
70
+ checkpointer=MemorySaver()
71
+ )
72
+
73
+ def _get_llm(self, model_name: str):
74
+ return get_llm(
75
+ llm_provider_api_key=GOOGLE_API_KEY,
76
+ model_name=model_name,
77
+ )
78
+
79
+ def _get_tools(self):
80
+ tools = [
81
+ SmolagentToolWrapper(DuckDuckGoSearchTool()),
82
+ SmolagentToolWrapper(PythonInterpreterTool()),
83
+ download_file_for_task,
84
+ read_file_contents,
85
+ analyze_audio,
86
+ analyze_image,
87
+ analyze_excel,
88
+ analyze_video,
89
+ search_arxiv,
90
+ search_tavily,
91
+ search_wikipedia,
92
+ tavily_extract_tool,
93
+ ]
94
+ return ToolNode(tools)
95
+
96
+ async def __call__(
97
+ self, task_id: str, question: str, file_name: str | None = None
98
+ ) -> str:
99
+ config = RunnableConfig(
100
+ recursion_limit=64,
101
+ configurable={"thread_id": task_id}
102
  )
103
+
104
+ if not file_name:
105
+ file_name = "None - no file present"
106
+
107
+ message = HumanMessage(
108
+ content=[
109
+ {
110
+ "type": "text",
111
+ "text": f"Task Id: {task_id}, Question: {question}, Filename: {file_name}. If a filename is present (and is not 'None'), download the file for the task that's referenced in the question. If there isn't a filename present, please use tools where applicable."
112
+ }
113
+ ]
114
+ )
115
+
116
+ answer = await self.agent.ainvoke(
117
+ {
118
+ "messages": [message],
119
+ "question": question,
120
+ "task_id": task_id,
121
+ "file_name": file_name
122
+ }, config)
123
+
124
+ final_answer = answer['messages'][-1].content
125
+ if "FINAL ANSWER: " in final_answer:
126
+ return final_answer.split("FINAL ANSWER: ", 1)[1].strip()
127
+ else:
128
+ return final_answer
app.py CHANGED
@@ -5,6 +5,8 @@ import requests
5
  import pandas as pd
6
  from agent import KateMultiModalAgent
7
 
 
 
8
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
9
  AGENT_NAME = "Kate's Advanced Agent"
10
 
 
5
  import pandas as pd
6
  from agent import KateMultiModalAgent
7
 
8
+ agent = KateMultiModalAgent()
9
+
10
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
  AGENT_NAME = "Kate's Advanced Agent"
12