Ihor Kozar committed on
Commit
8989d02
·
1 Parent(s): 9096e64
Files changed (4) hide show
  1. agent.py +42 -69
  2. agent_tools.py +19 -1
  3. requirements.txt +5 -1
  4. test.py +1 -1
agent.py CHANGED
@@ -3,11 +3,12 @@ from typing import TypedDict, Annotated, Optional
3
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
4
  from langchain_openai import ChatOpenAI
5
  from langgraph.graph import StateGraph, START
6
- from langgraph.graph.message import add_messages
7
  from langgraph.prebuilt import ToolNode, tools_condition
8
  from langchain.vectorstores import Chroma
9
  from langchain.embeddings.openai import OpenAIEmbeddings
10
  from langchain.chains import RetrievalQA
 
11
  from agent_tools import *
12
 
13
  load_dotenv()
@@ -35,32 +36,8 @@ sys_msg = SystemMessage(
35
  """
36
  )
37
 
38
- tools = [
39
- multiply,
40
- add,
41
- subtract,
42
- divide,
43
- modulus,
44
- duckduck_websearch,
45
- arvix_search,
46
- wiki_search,
47
- visit_webpage,
48
- youtube_search,
49
- text_splitter,
50
- read_file,
51
- excel_read,
52
- csv_read,
53
- image_caption,
54
- ]
55
-
56
  print("agent.py loaded")
57
 
58
-
59
- class AgentState(TypedDict):
60
- input_file: Optional[str]
61
- messages: Annotated[list[AnyMessage], add_messages]
62
-
63
-
64
  class CUSTOM_AGENT:
65
  """
66
  A simple deterministic agent that leverages our tools directly and avoids
@@ -70,17 +47,9 @@ class CUSTOM_AGENT:
70
  def __init__(self):
71
  self.llm = ChatOpenAI(name="gpt-4o",
72
  api_key=os.getenv("OPENAI_API_KEY"))
73
-
74
- self.tools = tools
75
  self.llm_with_tools = self.llm.bind_tools(self.tools)
76
- initial_state = {
77
- "input_file": None,
78
- "messages": [],
79
- }
80
- self.app = self._graph_compile()
81
- self.initial_state = initial_state
82
  self.sys_msg = sys_msg
83
- # --- Chroma vectorstore + retriever ---
84
  embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
85
  persist_directory = "chroma_db"
86
  self.vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
@@ -92,42 +61,46 @@ class CUSTOM_AGENT:
92
  )
93
 
94
  def _graph_compile(self):
95
- builder = StateGraph(AgentState)
96
- # Define nodes: these do the work
97
  builder.add_node("assistant", self._assistant)
98
  builder.add_node("tools", ToolNode(self.tools))
99
- # Define edges: these determine how the control flow moves
100
- builder.add_edge(START, "assistant")
 
 
101
  builder.add_conditional_edges(
102
  "assistant",
103
  tools_condition,
104
  )
105
  builder.add_edge("tools", "assistant")
106
- react_graph = builder.compile()
107
- return react_graph
108
-
109
- def _assistant(self, state: AgentState):
110
- last_human = next((m for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), None)
111
- messages = [self.sys_msg] + state["messages"]
112
-
113
- if last_human:
114
- question_text = last_human.content
115
- retrieved_output = self.qa_chain.invoke({"query": question_text})
116
- retrieved_docs = retrieved_output["result"]
117
- docs_text = "\n\n---\n\n".join(
118
- [doc.page_content for doc in retrieved_output.get("source_documents", [])]
119
- )
120
- context_message = SystemMessage(content=f"Context from vectorstore:\n{docs_text}")
121
- messages.append(context_message)
122
-
123
- response = self.llm_with_tools.invoke(messages)
124
-
125
- return {
126
- "messages": state["messages"] + [response],
127
- "input_file": state["input_file"]
128
- }
129
-
130
- def extract_after_final_answer(self, text):
 
 
131
  keyword = "FINAL ANSWER: "
132
  index = text.find(keyword)
133
  if index != -1:
@@ -144,17 +117,17 @@ class CUSTOM_AGENT:
144
  else:
145
  question_text = f'{question} with TASK-ID: {task_id}'
146
 
147
- state = self.initial_state.copy()
148
- state["messages"] = [HumanMessage(content=question_text)]
149
 
150
  max_retries = 3
151
  base_sleep = 1
152
  for attempt in range(max_retries):
153
  try:
154
- response = self.app.invoke(state)
155
- final_ans = self.extract_after_final_answer(response['messages'][-1].content)
156
- time.sleep(10) # avoid rate limit
157
- return final_ans
 
158
  except Exception as e:
159
  sleep_time = base_sleep * (attempt + 1)
160
  if attempt < max_retries - 1:
 
3
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
4
  from langchain_openai import ChatOpenAI
5
  from langgraph.graph import StateGraph, START
6
+ from langgraph.graph.message import add_messages, MessagesState
7
  from langgraph.prebuilt import ToolNode, tools_condition
8
  from langchain.vectorstores import Chroma
9
  from langchain.embeddings.openai import OpenAIEmbeddings
10
  from langchain.chains import RetrievalQA
11
+
12
  from agent_tools import *
13
 
14
  load_dotenv()
 
36
  """
37
  )
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  print("agent.py loaded")
40
 
 
 
 
 
 
 
41
  class CUSTOM_AGENT:
42
  """
43
  A simple deterministic agent that leverages our tools directly and avoids
 
47
  def __init__(self):
48
  self.llm = ChatOpenAI(name="gpt-4o",
49
  api_key=os.getenv("OPENAI_API_KEY"))
50
+ self.tools = TOOLS
 
51
  self.llm_with_tools = self.llm.bind_tools(self.tools)
 
 
 
 
 
 
52
  self.sys_msg = sys_msg
 
53
  embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
54
  persist_directory = "chroma_db"
55
  self.vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
 
61
  )
62
 
63
  def _graph_compile(self):
64
+ builder = StateGraph(MessagesState)
65
+ builder.add_node("retriever", self._retriever_node)
66
  builder.add_node("assistant", self._assistant)
67
  builder.add_node("tools", ToolNode(self.tools))
68
+
69
+ builder.add_edge(START, "retriever")
70
+ builder.add_edge("retriever", "assistant")
71
+
72
  builder.add_conditional_edges(
73
  "assistant",
74
  tools_condition,
75
  )
76
  builder.add_edge("tools", "assistant")
77
+ return builder.compile()
78
+
79
+ def _retriever_node(self, state: MessagesState):
80
+ """Retriever node"""
81
+ question = state["messages"][-1].content
82
+ docs = self.retriever.get_relevant_documents(question)
83
+
84
+ if docs:
85
+ content = "\n".join([d.page_content for d in docs])
86
+ else:
87
+ content = "No relevant documents found"
88
+
89
+ return {"messages": [HumanMessage(content=content)]}
90
+
91
+ def _assistant(self, state: MessagesState):
92
+ """Assistant node"""
93
+ if not any(isinstance(m, SystemMessage) for m in state["messages"]):
94
+ messages = [self.sys_msg] + state["messages"]
95
+ else:
96
+ messages = state["messages"]
97
+
98
+ llm_response = self.llm_with_tools.invoke(messages)
99
+
100
+ return {"messages": [llm_response]}
101
+
102
+ @staticmethod
103
+ def extract_after_final_answer(text):
104
  keyword = "FINAL ANSWER: "
105
  index = text.find(keyword)
106
  if index != -1:
 
117
  else:
118
  question_text = f'{question} with TASK-ID: {task_id}'
119
 
120
+ graph = self._graph_compile()
 
121
 
122
  max_retries = 3
123
  base_sleep = 1
124
  for attempt in range(max_retries):
125
  try:
126
+ messages: list[HumanMessage] = [HumanMessage(content=question_text)]
127
+ result = graph.invoke({"messages": messages})
128
+
129
+ final_text = result["messages"][-1].content
130
+ return self.extract_after_final_answer(final_text)
131
  except Exception as e:
132
  sleep_time = base_sleep * (attempt + 1)
133
  if attempt < max_retries - 1:
agent_tools.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import re
3
- from typing import List
4
 
5
  import pandas as pd
6
  import requests
@@ -284,3 +284,21 @@ def arvix_search(query: str) -> str:
284
  ])
285
  return formatted_search_docs
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
+ from typing import List, Callable, Any
4
 
5
  import pandas as pd
6
  import requests
 
284
  ])
285
  return formatted_search_docs
286
 
287
+
288
# Registry of every callable exposed to the LLM as a tool; agent.py pulls
# this in via `from agent_tools import *` and binds it with
# `llm.bind_tools(...)`.
TOOLS: List[Callable[..., Any]] = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    duckduck_websearch,
    # NOTE(review): "arvix" is presumably a typo for "arxiv" — the name is
    # used consistently across files, so renaming must be coordinated.
    arvix_search,
    wiki_search,
    visit_webpage,
    youtube_search,
    text_splitter,
    read_file,
    excel_read,
    csv_read,
    image_caption,
]
requirements.txt CHANGED
@@ -26,4 +26,8 @@ langchain_google_genai
26
  langchain_openai
27
  google-genai
28
  openpyxl
29
- chromadb
 
 
 
 
 
26
  langchain_openai
27
  google-genai
28
  openpyxl
29
+ chromadb
30
+ # langchain-google-genai removed: duplicate of langchain_google_genai above (pip normalizes "-"/"_" and rejects double requirements)
31
+ pytesseract
32
+ matplotlib
33
+ sentence_transformers
test.py CHANGED
@@ -107,7 +107,7 @@ questions = [
107
  # Test
108
  if __name__ == "__main__":
109
  agent = CUSTOM_AGENT()
110
- q = questions[1]
111
  print("Question:", q["question"])
112
  answer = agent.run(q)
113
  print("Answer:", answer)
 
107
  # Test
108
  if __name__ == "__main__":
109
  agent = CUSTOM_AGENT()
110
+ q = questions[0]
111
  print("Question:", q["question"])
112
  answer = agent.run(q)
113
  print("Answer:", answer)