DrekFretson committed
Commit 083cf21 · verified · 1 Parent(s): 2a3d964

Update agent.py

Files changed (1):
  1. agent.py +70 -202
agent.py CHANGED
@@ -1,213 +1,81 @@
-"""LangGraph Agent"""
-import os
-from dotenv import load_dotenv
-from langgraph.graph import START, StateGraph, MessagesState
-from langgraph.prebuilt import tools_condition
-from langgraph.prebuilt import ToolNode
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_groq import ChatGroq
-from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
-from langchain_community.tools.tavily_search import TavilySearchResults
-from langchain_community.document_loaders import WikipediaLoader
-from langchain_community.document_loaders import ArxivLoader
-from langchain_community.vectorstores import SupabaseVectorStore
-from langchain_core.messages import SystemMessage, HumanMessage
-from langchain_core.tools import tool
-from langchain.tools.retriever import create_retriever_tool
-from supabase.client import Client, create_client

-load_dotenv()

-@tool
-def multiply(a: int, b: int) -> int:
-    """Multiply two numbers.
-    Args:
-        a: first int
-        b: second int
-    """
-    return a * b

-@tool
-def add(a: int, b: int) -> int:
-    """Add two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
-    return a + b

-@tool
-def subtract(a: int, b: int) -> int:
-    """Subtract two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
-    return a - b

-@tool
-def divide(a: int, b: int) -> int:
-    """Divide two numbers.
-
-    Args:
-        a: first int
-        b: second int
    """
-    if b == 0:
-        raise ValueError("Cannot divide by zero.")
-    return a / b
-
-@tool
-def modulus(a: int, b: int) -> int:
-    """Get the modulus of two numbers.
-
    Args:
-        a: first int
-        b: second int
    """
-    return a % b
-
-@tool
-def wiki_search(query: str) -> str:
-    """Search Wikipedia for a query and return maximum 2 results.
-
-    Args:
-        query: The search query."""
-    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-    formatted_search_docs = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-            for doc in search_docs
-        ])
-    return {"wiki_results": formatted_search_docs}
-
-@tool
-def web_search(query: str) -> str:
-    """Search Tavily for a query and return maximum 3 results.
-
-    Args:
-        query: The search query."""
-    search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-    formatted_search_docs = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-            for doc in search_docs
-        ])
-    return {"web_results": formatted_search_docs}
-
-@tool
-def arvix_search(query: str) -> str:
-    """Search Arxiv for a query and return maximum 3 result.
-
-    Args:
-        query: The search query."""
-    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-    formatted_search_docs = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-            for doc in search_docs
-        ])
-    return {"arvix_results": formatted_search_docs}

-
-
-# load the system prompt from the file
-with open("system_prompt.txt", "r", encoding="utf-8") as f:
-    system_prompt = f.read()
-
-# System message
-sys_msg = SystemMessage(content=system_prompt)
-
-# build a retriever
-embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
-supabase: Client = create_client(
-    os.environ.get("SUPABASE_URL"),
-    os.environ.get("SUPABASE_SERVICE_KEY"))
-vector_store = SupabaseVectorStore(
-    client=supabase,
-    embedding=embeddings,
-    table_name="documents",
-    query_name="match_documents_langchain",
-)
-create_retriever_tool = create_retriever_tool(
-    retriever=vector_store.as_retriever(),
-    name="Question Search",
-    description="A tool to retrieve similar questions from a vector store.",
-)
-
-
-
-tools = [
-    multiply,
-    add,
-    subtract,
-    divide,
-    modulus,
-    wiki_search,
-    web_search,
-    arvix_search,
-]
-
-# Build graph function
-def build_graph(provider: str = "groq"):
-    """Build the graph"""
-    # Load environment variables from .env file
-    if provider == "google":
-        # Google Gemini
-        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-    elif provider == "groq":
-        # Groq https://console.groq.com/docs/models
-        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional: qwen-qwq-32b gemma2-9b-it
-    elif provider == "huggingface":
-        # TODO: Add huggingface endpoint
-        llm = ChatHuggingFace(
-            llm=HuggingFaceEndpoint(
-                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
-                temperature=0,
-            ),
        )
-    else:
-        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-    # Bind tools to LLM
-    llm_with_tools = llm.bind_tools(tools)
-
-    # Node
-    def assistant(state: MessagesState):
-        """Assistant node"""
-        return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-    def retriever(state: MessagesState):
-        """Retriever node"""
-        similar_question = vector_store.similarity_search(state["messages"][0].content)
-        example_msg = HumanMessage(
-            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
-        return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
-    builder = StateGraph(MessagesState)
-    builder.add_node("retriever", retriever)
-    builder.add_node("assistant", assistant)
-    builder.add_node("tools", ToolNode(tools))
-    builder.add_edge(START, "retriever")
-    builder.add_edge("retriever", "assistant")
-    builder.add_conditional_edges(
-        "assistant",
-        tools_condition,
-    )
-    builder.add_edge("tools", "assistant")
-
-    # Compile graph
-    return builder.compile()
-
-# test
-if __name__ == "__main__":
-    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-    # Build the graph
-    graph = build_graph(provider="groq")
-    # Run the graph
-    messages = [HumanMessage(content=question)]
-    messages = graph.invoke({"messages": messages})
-    for m in messages["messages"]:
-        m.pretty_print()
 
+from typing import Any, List, Optional

+from smolagents import CodeAgent

+from utils.logger import get_logger

+logger = get_logger(__name__)


+class Agent:
    """
+    Agent class that wraps a CodeAgent and provides a callable interface for answering questions.
    Args:
+        model (Any): The language model to use.
+        tools (Optional[List[Any]]): List of tools to provide to the agent.
+        prompt (Optional[str]): Custom prompt template for the agent.
    """

+    def __init__(
+        self,
+        model: Any,
+        tools: Optional[List[Any]] = None,
+        prompt: Optional[str] = None,
+    ):
+        logger.info("Initializing Agent")
+        self.model = model
+        self.tools = tools
+        self.imports = [
+            "pandas",
+            "numpy",
+            "os",
+            "requests",
+            "tempfile",
+            "datetime",
+            "json",
+            "time",
+            "re",
+            "openpyxl",
+            "pathlib",
+            "sys",
+        ]
+        self.agent = CodeAgent(
+            model=self.model,
+            tools=self.tools,
+            add_base_tools=True,
+            additional_authorized_imports=self.imports,
        )
+        self.prompt = prompt or (
+            """
+            You are an advanced AI assistant specialized in solving complex, real-world tasks that require multi-step reasoning, factual accuracy, and use of external tools.
+
+            Follow these principles:
+            - Be precise and concise. The final answer must strictly match the required format, with no extra commentary.
+            - Use tools intelligently. If a question involves external information, structured data, images, or audio, call the appropriate tool to retrieve or process it.
+            - Reason step by step. Think through the solution logically and plan your actions carefully before answering.
+            - Validate information. Always verify facts when possible instead of guessing.
+            - Use code if needed. For calculations, parsing, or transformations, generate Python code and execute it. But be careful: some questions contain time-consuming tasks, so watch what code you run. Analyze the question first and choose the best way to solve it.
+            - Don't forget to use `final_answer` to give the final answer.
+            - Use the file name ONLY FROM the "FILE:" section. THIS IS ALWAYS A FILE.
+            IMPORTANT: When giving the final answer, output only the direct required result, without any extra text like "Final Answer:" or explanations. YOU MUST RESPOND IN THE EXACT FORMAT THE QUESTION REQUIRES.
+            QUESTION: {question}
+            FILE: {context}
+            ANSWER:
+            """
        )
+        logger.info("Agent initialized")
+
+    def __call__(self, question: str, file_path: Optional[str] = None) -> str:
+        """
+        Run the agent to answer a question, optionally using a file as context.
+        Args:
+            question (str): The question to answer.
+            file_path (Optional[str]): Path to a file to use as context (if any).
+        Returns:
+            str: The agent's answer as a string.
+        """
+        answer = self.agent.run(
+            self.prompt.format(question=question, context=file_path)
+        )
+        answer = str(answer).strip("'").strip('"').strip()
+        return answer
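
For reference, a minimal usage sketch of the new interface (not part of this commit): it assumes the model is supplied through smolagents' OpenAIServerModel wrapper, and the model id and environment variable below are illustrative placeholders; any smolagents-compatible model object works the same way.

import os

from smolagents import OpenAIServerModel

from agent import Agent

# Build a model wrapper; OpenAIServerModel is one of several smolagents wrappers.
model = OpenAIServerModel(
    model_id="gpt-4o-mini",                    # placeholder model id (assumption)
    api_key=os.environ.get("OPENAI_API_KEY"),  # assumes this env var is set
)

# tools=[] relies on add_base_tools=True inside Agent for the built-in tools.
agent = Agent(model=model, tools=[])
print(agent("What is 6 * 7?"))  # __call__ formats the prompt and runs the CodeAgent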