DrekFretson commited on
Commit
b13c775
·
verified ·
1 Parent(s): 4a1006b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +350 -210
agent.py CHANGED
@@ -1,213 +1,353 @@
1
- """LangGraph Agent"""
2
  import os
3
- from dotenv import load_dotenv
4
- from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_groq import ChatGroq
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
- from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
- from langchain_core.messages import SystemMessage, HumanMessage
15
- from langchain_core.tools import tool
16
- from langchain.tools.retriever import create_retriever_tool
17
- from supabase.client import Client, create_client
18
-
19
- load_dotenv()
20
-
21
- @tool
22
- def multiply(a: int, b: int) -> int:
23
- """Multiply two numbers.
24
- Args:
25
- a: first int
26
- b: second int
27
- """
28
- return a * b
29
-
30
- @tool
31
- def add(a: int, b: int) -> int:
32
- """Add two numbers.
33
-
34
- Args:
35
- a: first int
36
- b: second int
37
- """
38
- return a + b
39
-
40
- @tool
41
- def subtract(a: int, b: int) -> int:
42
- """Subtract two numbers.
43
-
44
- Args:
45
- a: first int
46
- b: second int
47
- """
48
- return a - b
49
-
50
- @tool
51
- def divide(a: int, b: int) -> int:
52
- """Divide two numbers.
53
-
54
- Args:
55
- a: first int
56
- b: second int
57
- """
58
- if b == 0:
59
- raise ValueError("Cannot divide by zero.")
60
- return a / b
61
-
62
- @tool
63
- def modulus(a: int, b: int) -> int:
64
- """Get the modulus of two numbers.
65
-
66
- Args:
67
- a: first int
68
- b: second int
69
- """
70
- return a % b
71
-
72
- @tool
73
- def wiki_search(query: str) -> str:
74
- """Search Wikipedia for a query and return maximum 2 results.
75
-
76
- Args:
77
- query: The search query."""
78
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
- formatted_search_docs = "\n\n---\n\n".join(
80
- [
81
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
- for doc in search_docs
83
- ])
84
- return {"wiki_results": formatted_search_docs}
85
-
86
- @tool
87
- def web_search(query: str) -> str:
88
- """Search Tavily for a query and return maximum 3 results.
89
-
90
- Args:
91
- query: The search query."""
92
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
- formatted_search_docs = "\n\n---\n\n".join(
94
- [
95
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
- for doc in search_docs
97
- ])
98
- return {"web_results": formatted_search_docs}
99
-
100
- @tool
101
- def arvix_search(query: str) -> str:
102
- """Search Arxiv for a query and return maximum 3 result.
103
-
104
- Args:
105
- query: The search query."""
106
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
- formatted_search_docs = "\n\n---\n\n".join(
108
- [
109
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
- for doc in search_docs
111
- ])
112
- return {"arvix_results": formatted_search_docs}
113
-
114
-
115
-
116
- # load the system prompt from the file
117
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
- system_prompt = f.read()
119
-
120
- # System message
121
- sys_msg = SystemMessage(content=system_prompt)
122
-
123
- # build a retriever
124
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
- supabase: Client = create_client(
126
- os.environ.get("SUPABASE_URL"),
127
- os.environ.get("SUPABASE_SERVICE_KEY"))
128
- vector_store = SupabaseVectorStore(
129
- client=supabase,
130
- embedding= embeddings,
131
- table_name="documents",
132
- query_name="match_documents_langchain",
133
- )
134
- create_retriever_tool = create_retriever_tool(
135
- retriever=vector_store.as_retriever(),
136
- name="Question Search",
137
- description="A tool to retrieve similar questions from a vector store.",
138
- )
139
-
140
-
141
-
142
- tools = [
143
- multiply,
144
- add,
145
- subtract,
146
- divide,
147
- modulus,
148
- wiki_search,
149
- web_search,
150
- arvix_search,
151
- ]
152
-
153
- # Build graph function
154
- def build_graph(provider: str = "groq"):
155
- """Build the graph"""
156
- # Load environment variables from .env file
157
- if provider == "google":
158
- # Google Gemini
159
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
- elif provider == "groq":
161
- # Groq https://console.groq.com/docs/models
162
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
163
- elif provider == "huggingface":
164
- # TODO: Add huggingface endpoint
165
- llm = ChatHuggingFace(
166
- llm=HuggingFaceEndpoint(
167
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
168
- temperature=0,
169
- ),
170
  )
171
- else:
172
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
173
- # Bind tools to LLM
174
- llm_with_tools = llm.bind_tools(tools)
175
-
176
- # Node
177
- def assistant(state: MessagesState):
178
- """Assistant node"""
179
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
-
181
- def retriever(state: MessagesState):
182
- """Retriever node"""
183
- similar_question = vector_store.similarity_search(state["messages"][0].content)
184
- example_msg = HumanMessage(
185
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
  )
187
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
-
189
- builder = StateGraph(MessagesState)
190
- builder.add_node("retriever", retriever)
191
- builder.add_node("assistant", assistant)
192
- builder.add_node("tools", ToolNode(tools))
193
- builder.add_edge(START, "retriever")
194
- builder.add_edge("retriever", "assistant")
195
- builder.add_conditional_edges(
196
- "assistant",
197
- tools_condition,
198
- )
199
- builder.add_edge("tools", "assistant")
200
-
201
- # Compile graph
202
- return builder.compile()
203
-
204
- # test
205
- if __name__ == "__main__":
206
- question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
207
- # Build the graph
208
- graph = build_graph(provider="groq")
209
- # Run the graph
210
- messages = [HumanMessage(content=question)]
211
- messages = graph.invoke({"messages": messages})
212
- for m in messages["messages"]:
213
- m.pretty_print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
  import os
3
+ import time
4
+ from langchain.tools import Tool, tool
5
+ from typing import Tuple, List
6
+ from typing_extensions import TypedDict, Annotated, Optional
7
+ from langgraph.graph.message import add_messages
8
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
9
+ from langgraph.graph import START, StateGraph, END
10
+ from langgraph.prebuilt import ToolNode, tools_condition
11
+ from langchain_litellm import ChatLiteLLM
12
+ from IPython.display import Image, display
13
+ import asyncio
14
+ from tools import (search_tool,
15
+ download_tool,
16
+ get_web_page,
17
+ add,
18
+ subtract,
19
+ multiply,
20
+ divide,
21
+ power,
22
+ square_root,
23
+ get_information_from_wikipedia,
24
+ get_information_from_arxiv,
25
+ get_information_from_youtube,
26
+ python_tool,
27
+ get_information_from_json,
28
+ get_information_from_audio,
29
+ get_information_from_xml,
30
+ get_information_from_docx,
31
+ get_information_from_txt,
32
+ get_information_from_pdf,
33
+ get_information_from_csv,
34
+ get_information_from_excel,
35
+ get_information_from_pdb,
36
+ get_information_from_image,
37
+ get_information_from_pptx,
38
+ get_all_files_from_zip,
39
+ get_information_from_python)
40
+
41
+ DELAY = 5
42
+ TIME_SLEEP = 60/15 + DELAY
43
+
44
+ GEMINI_API_KEY_1 = os.getenv("GOOGLE_API_KEY_1")
45
+ GEMINI_API_KEY_2 = os.getenv("GOOGLE_API_KEY_2")
46
+ GEMINI_API_KEY_3 = os.getenv("GOOGLE_API_KEY_3")
47
+
48
+ chat_model_1 = ChatLiteLLM(model="gemini/gemini-2.0-flash",
49
+ temperature=0,
50
+ api_key=GEMINI_API_KEY_1,
51
+ max_retries=10,
52
+ verbose=True)
53
+
54
+ chat_model_2 = ChatLiteLLM(model="gemini/gemini-2.0-flash",
55
+ temperature=0,
56
+ api_key=GEMINI_API_KEY_2,
57
+ max_retries=10,
58
+ verbose=True)
59
+
60
+ chat_model_3 = ChatLiteLLM(model="gemini/gemini-2.0-flash",
61
+ temperature=0,
62
+ api_key=GEMINI_API_KEY_3,
63
+ max_retries=10,
64
+ verbose=True)
65
+
66
+ class AgentState(TypedDict):
67
+ messages: Annotated[list[AnyMessage], add_messages]
68
+ question: Optional[str]
69
+ file_path: Optional[str]
70
+ task_id: Optional[str]
71
+ new_messages: Optional[int]
72
+ final_answer: Optional[str]
73
+ attempt: Optional[int]
74
+ chat_model: Optional[int]
75
+
76
+ class MyAgent:
77
+ def __init__(self, web_tools=None):
78
+ print("MyAgent initialized.")
79
+
80
+ self.chat_1 = chat_model_1
81
+ self.chat_2 = chat_model_2
82
+ self.chat_3 = chat_model_3
83
+
84
+ self.tools = [search_tool,
85
+ download_tool,
86
+ get_web_page,
87
+ add,
88
+ subtract,
89
+ multiply,
90
+ divide,
91
+ power,
92
+ square_root,
93
+ get_information_from_wikipedia,
94
+ get_information_from_arxiv,
95
+ get_information_from_youtube,
96
+ python_tool,
97
+ get_information_from_json,
98
+ get_information_from_audio,
99
+ get_information_from_xml,
100
+ get_information_from_docx,
101
+ get_information_from_txt,
102
+ get_information_from_pdf,
103
+ get_information_from_csv,
104
+ get_information_from_excel,
105
+ get_information_from_pdb,
106
+ get_information_from_image,
107
+ get_information_from_pptx,
108
+ get_all_files_from_zip] + web_tools
109
+
110
+ self.chat_with_tools_1 = self.chat_1.bind_tools(self.tools, verbose=True)
111
+ self.chat_with_tools_2 = self.chat_2.bind_tools(self.tools, verbose=True)
112
+ self.chat_with_tools_3 = self.chat_3.bind_tools(self.tools, verbose=True)
113
+ self.chats = [self.chat_with_tools_1, self.chat_with_tools_2, self.chat_with_tools_3]
114
+
115
+ self.builder = StateGraph(AgentState)
116
+ self.builder.add_node("assistant", self.assistant)
117
+ self.builder.add_node("tools", ToolNode(self.tools))
118
+ self.builder.add_node("extract_data_from_file", self.extract_data_from_file)
119
+ self.builder.add_node("postprocess", self.postprocess)
120
+
121
+ self.builder.add_edge(START, "extract_data_from_file")
122
+ self.builder.add_edge("extract_data_from_file", "assistant")
123
+ self.builder.add_conditional_edges(
124
+ "assistant",
125
+ self.assistant_router,
126
+ {
127
+ "tools": "tools",
128
+ "postprocess": "postprocess"
129
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
+ self.builder.add_edge("tools", "assistant")
132
+ self.builder.add_conditional_edges(
133
+ "postprocess",
134
+ self.answer_evaluation,
135
+ {
136
+ "RETRY": "assistant",
137
+ "END": END
138
+ }
 
 
 
 
 
 
 
139
  )
140
+
141
+ self.agent = self.builder.compile()
142
+
143
+ async def __call__(self, question: str, file_path: str, task_id: str) -> str:
144
+ print("\033[1m\033[93m"+"="*150+"\033[0m")
145
+ print(f"QUESTION: {question}")
146
+ print(f"File: {file_path}")
147
+ prompt = f"""You are a general AI assistant. You will receive a user question and extracted data from associated files.
148
+ Follow this process:
149
+ 1. Identify the required output type (e.g., number, string, list) and key concepts in the question.
150
+ 2. Before using any tools, check if the answer can be deduced or recalled directly. If yes, answer immediately. Never guess.
151
+ 3. If tools are needed:
152
+ - Create a plan with:
153
+ - The reasoning approach and tool sequence.
154
+ - A rephrased version of the question optimized for search engines (DuckDuckGo or Google).
155
+ - Search queries must:
156
+ - Be keyword-focused (avoid full sentences).
157
+ - Use advanced operators if helpful: `site:` for domains, `inurl:` for internal paths, `filetype:` for formats.
158
+ - Avoid punctuation, commas, quotes, or special characters.
159
+ - Cover multiple query angles if needed.
160
+ 4. Do not run any tool until the plan is complete.
161
+ 5. If a tool fails or returns no useful result:
162
+ - Reformulate the query with synonyms or tighter context.
163
+ - Retry or use a fallback tool.
164
+ 6. Analyze tool results carefully. If multiple source links appear, use `navigate_browser` to explore and extract relevant information from each.
165
+ Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
166
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
167
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
168
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
169
+ If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
170
+
171
+ user_message = f"""Question: {question}
172
+ Filepath: {file_path}"""
173
+
174
+ messages = [SystemMessage(content=prompt, name="SYSTEM"),
175
+ HumanMessage(content=user_message, name="USER")]
176
+
177
+ response = await self.agent.ainvoke({"messages": messages,
178
+ "question": question,
179
+ "file_path": file_path,
180
+ "task_id": task_id,
181
+ "new_messages": 1,
182
+ "chat_model": 0,
183
+ "final_answer": "",
184
+ "attempt": 0},
185
+ {"recursion_limit": 100})
186
+
187
+ print("\033[1m\033[93m"+"="*150+"\033[0m")
188
+ return response['final_answer']
189
+
190
+ async def call_chat(self, chat, state: AgentState, max_retries=5):
191
+ from google.api_core.exceptions import GoogleAPICallError
192
+ for i in range(max_retries):
193
+ try:
194
+ return await chat.ainvoke(state["messages"])
195
+ except GoogleAPICallError as e:
196
+ if "503" in str(e) or "UNAVAILABLE" in str(e):
197
+ wait = 2 ** i
198
+ print(f"[Gemini] Overloaded (attempt {i+1}), retrying in {wait:.1f}s...")
199
+ await asyncio.sleep(wait)
200
+ else:
201
+ raise e
202
+ raise RuntimeError("Gemini failed after multiple retries")
203
+
204
+ async def assistant(self, state: AgentState):
205
+ new_messages = state["new_messages"]
206
+
207
+ for i in reversed(range(1, new_messages+1)):
208
+ print("\033[1m\033[92m"+"+"*150+"\033[0m")
209
+ name = state["messages"][-i].name
210
+ content = state["messages"][-i].content
211
+ print(f'\033[1m\033[96m{name}\033[0m: {content if len(content) < 5000 else content[:5000]}')
212
+
213
+ chat = self.chats[state["chat_model"]]
214
+ result = await self.call_chat(chat=chat, state=state)
215
+
216
+ state["chat_model"] += 1
217
+ state["chat_model"] %= len(self.chats)
218
+
219
+ result.name="ASSISTANT"
220
+ await asyncio.sleep(TIME_SLEEP)
221
+
222
+ print("\033[1m\033[92m"+"+"*150+"\033[0m")
223
+ content = result.content[:-2] if result.content[-2:] == '\n\n' else result.content
224
+ print(f'\033[1m\033[96m{result.name}\033[0m: {content}')
225
+ state["new_messages"] = 1
226
+ state["messages"].append(result)
227
+
228
+ return state
229
+
230
+ def extract_data_from_file(self, state: AgentState) -> str:
231
+ path = state["file_path"]
232
+ new_messages = state["new_messages"]
233
+ prompt = ""
234
+ messages = []
235
+
236
+ if path and "." in path:
237
+ ext = path.strip().split(".")[-1].lower()
238
+ print(f"Extension detected: {ext}")
239
+
240
+ if ext == "zip":
241
+ files, prompt = get_all_files_from_zip(path)
242
+ name = "get_all_file_from_zip"
243
+ messages.append(AIMessage(content=prompt, name=name))
244
+ else:
245
+ files = [path]
246
+
247
+ for file_path in files:
248
+ ext = file_path.strip().split(".")[-1].lower()
249
+ print(f"Extension detected: {ext}")
250
+
251
+ prompt = f"Information extracted from {file_path}.\n\n"
252
+ match ext:
253
+ case "csv":
254
+ content = get_information_from_csv.invoke(file_path)
255
+ name = "get_information_from_csv"
256
+ case "txt":
257
+ content = get_information_from_txt.invoke(file_path)
258
+ name = "get_information_from_txt"
259
+ case "pdf":
260
+ content = get_information_from_pdf.invoke(file_path)
261
+ name = "get_information_from_pdf"
262
+ case "json":
263
+ content = get_information_from_json.invoke(file_path)
264
+ name = "get_information_from_json"
265
+ case "jsonld":
266
+ content = get_information_from_json.invoke(file_path)
267
+ name = "get_information_from_json"
268
+ case "xml":
269
+ content = get_information_from_xml.invoke(file_path)
270
+ name = "get_information_from_xml"
271
+ case "pdb":
272
+ content = get_information_from_pdb.invoke(file_path)
273
+ name = "get_information_from_pdb"
274
+ case "mp3":
275
+ content = get_information_from_audio.invoke(file_path)
276
+ name = "get_information_from_audio"
277
+ case "m4a":
278
+ content = get_information_from_audio.invoke(file_path)
279
+ name = "get_information_from_audio"
280
+ case "docx":
281
+ content = get_information_from_docx.invoke(file_path)
282
+ name = "get_information_from_docx"
283
+ case "xlsx":
284
+ content = get_information_from_excel.invoke(file_path)
285
+ name = "get_information_from_excel"
286
+ case "xls":
287
+ content = get_information_from_excel.invoke(file_path)
288
+ name = "get_information_from_excel"
289
+ case "png":
290
+ content = get_information_from_image.invoke({"file_path": file_path, "question": state["question"]})
291
+ name = "get_information_from_image"
292
+ case "jpg":
293
+ content = get_information_from_image.invoke({"file_path": file_path, "question": state["question"]})
294
+ name = "get_information_from_image"
295
+ case "py":
296
+ content = get_information_from_python.invoke(file_path)
297
+ name = "get_information_from_python"
298
+ case "pptx":
299
+ content = get_information_from_pptx.invoke(file_path)
300
+ name = "get_information_from_pptx"
301
+ case _:
302
+ content = "Try to use some available tool to answer the user question."
303
+ name = "handle_no_file"
304
+ prompt += f"{content}"
305
+ messages.append(AIMessage(content=prompt, name=name))
306
+ new_messages += 1
307
+ else:
308
+ prompt = "The question doesn't have an attached file."
309
+ name = "handle_no_file"
310
+
311
+ return {"messages": messages, "new_messages": new_messages}
312
+
313
+ def assistant_router(self, state: AgentState) -> str:
314
+ tool_decision = tools_condition(state)
315
+ if tool_decision == "tools":
316
+ return "tools"
317
+ else:
318
+ return "postprocess"
319
+
320
+ def postprocess(self, state: AgentState) -> AgentState:
321
+ last_msg = state["messages"][-1]
322
+ content = last_msg.content
323
+ index = content.find("FINAL ANSWER: ")
324
+ if index != -1:
325
+ content = content[index+len("FINAL ANSWER: "):].replace("\n", "")
326
+ state["final_answer"] = content
327
+ return state
328
+ else:
329
+ state["attempt"] += 1
330
+ prompt = f"""You were unable to find a satisfactory answer to the user's question.
331
+ Now, try again, but use a different approach. You may:
332
+ - Focus on a different angle of the question,
333
+ - Reformulate it using alternative terminology,
334
+ - Search for related concepts,
335
+ - Or use a different reasoning path.
336
+ Be creative and precise. Your goal is to uncover useful information that may have been missed previously.
337
+ Original question:
338
+ {state["question"]}"""
339
+
340
+ state["messages"].append(AIMessage(content=prompt, name="ASSISTANT"))
341
+ return state
342
+ def answer_evaluation(self, state: AgentState):
343
+ if state["final_answer"] != "":
344
+ return "END"
345
+ elif state["attempt"] >= 3:
346
+ state["final_answer"] = "Unable to find the answer."
347
+ return "END"
348
+ else:
349
+ return "RETRY"
350
+
351
+ def draw_graph(self):
352
+ display(Image(self.agent.get_graph().draw_mermaid_png()))
353
+ return