dleandro commited on
Commit
3927a42
·
1 Parent(s): 0feea5d

Tools, Retriever, System prompt and Agent creation

Browse files
Files changed (4) hide show
  1. agent.py +120 -0
  2. retriever.py +44 -0
  3. system_prompt.txt +12 -0
  4. tools.py +300 -0
agent.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from typing import TypedDict, List, Dict, Any, Optional
4
+
5
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
6
+ from langchain_huggingface.chat_models import ChatHuggingFace
7
+ from langchain_groq.chat_models import ChatGroq
8
+
9
+ from langgraph.graph.message import add_messages
10
+ from langgraph.graph import StateGraph, START, END, MessagesState
11
+ from langgraph.prebuilt import ToolNode, tools_condition
12
+
13
+ from tools import (
14
+ add,
15
+ subtract, multiply, div, modulus, power,
16
+ wikipedia_search, search_web, arxiv_search,
17
+ save_and_read_file, download_file_from_url, extract_text_from_image,
18
+ pdf_loader
19
+ )
20
+ from retriever import get_retriever_tool
21
+
22
+ load_dotenv(dotenv_path = ".env")
23
+
24
# Configurations
SYSTEM_PROMPT_PATH = "system_prompt.txt"
DEFAULT_PROVIDER = "groq"
MODEL_NAME = "llama3-70b-8192"

def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
    """Load the agent's system prompt from a text file.

    Args:
        path: Path to the prompt file; defaults to SYSTEM_PROMPT_PATH.

    Returns:
        The full prompt text.

    Raises:
        ValueError: If no file exists at `path`.
    """
    if not os.path.exists(path):
        # Fixed typo in the error message ("foud" -> "found").
        raise ValueError(f"System prompt file not found at: {path}")
    with open(path, "r", encoding = "utf-8") as f:
        return f.read()
34
+
35
+
36
# Load the system prompt once at import time and wrap it as the
# SystemMessage that the retriever node prepends to every conversation.
system_prompt = load_system_prompt()
sys_msg = SystemMessage(content = system_prompt)

# Load tools
# get_retriever_tool() (retriever.py) builds the Supabase-backed vector
# store, a retriever over it, and a LangChain retriever tool.
vector_store, vector_retriever, retriever_tool = get_retriever_tool()

# All tools bound to the LLM and registered with the graph's ToolNode.
TOOLS = [
    # Math
    add, subtract, multiply, div, modulus, power,
    # Documents Search
    wikipedia_search, search_web, arxiv_search,
    # Process Files
    save_and_read_file, download_file_from_url, extract_text_from_image,
    pdf_loader,
    # Retriever
    retriever_tool
]
53
+
54
def get_llm(provider: str = DEFAULT_PROVIDER):
    """Instantiate the chat model for the requested provider.

    Args:
        provider: Either 'groq' or 'huggingface'.

    Returns:
        A chat model instance (currently only ChatGroq is supported).

    Raises:
        NotImplementedError: If provider is 'huggingface'.
        ValueError: For any other unknown provider name.
    """
    if provider == "huggingface":
        raise NotImplementedError("HuggingFace support not yet implemented.")
    if provider != "groq":
        raise ValueError("Invalid LLM provider. Choose 'groq' or 'huggingface'")
    return ChatGroq(model = MODEL_NAME, temperature = 0)
61
+
62
+
63
def build_graph(provider: str = DEFAULT_PROVIDER):
    """Build and compile the LangGraph agent graph.

    Flow: START -> retriever -> assistant -> (tools -> assistant)* -> END,
    where tools_condition decides between "tools" and END after each
    assistant turn.

    Args:
        provider: LLM provider name forwarded to get_llm().

    Returns:
        The compiled graph, invokable via .invoke({"messages": [...]}).
    """
    llm = get_llm(provider)

    # Add tools to the LLM
    llm_with_tools = llm.bind_tools(TOOLS)

    def assistant(state: MessagesState):
        # NOTE(review): returns a single message rather than a list; the
        # add_messages reducer appears to tolerate this, but wrapping in a
        # list would be the conventional shape — confirm.
        return {"messages": llm_with_tools.invoke(state["messages"])}

    def retriever(state: MessagesState):
        # Use the first message in the state (the original user question)
        # as the similarity-search query.
        query = state["messages"][0].content
        similar_qas = vector_store.similarity_search(query)

        if similar_qas:
            # Attach the closest stored Q/A pair as a worked example for
            # the assistant.
            reference = similar_qas[0].page_content
            example_qa = HumanMessage(
                content = f"I provide a similar question and answer for reference:\n\n{reference}"
            )
            return {"messages": [sys_msg] + state["messages"] + [example_qa]}
        else:
            return {"messages": [sys_msg] + state["messages"]}

    # Graph
    builder = StateGraph(MessagesState)

    # Nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(TOOLS))

    # Edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the last AI message contains
    # tool calls, otherwise to END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()
106
+
107
if __name__ == "__main__":
    import random
    import json

    # metadata.jsonl holds one JSON object per line, each with at least a
    # "Question" field. Read with an explicit encoding so the JSONL is not
    # decoded with a platform-dependent default.
    with open("metadata.jsonl", encoding = "utf-8") as dataset_file:
        json_list = list(dataset_file)

    # Skip blank lines so a trailing newline does not crash json.loads.
    QAs = [json.loads(qa) for qa in json_list if qa.strip()]
    question = random.choice(QAs)["Question"]

    # Smoke-test the agent on one random dataset question.
    graph = build_graph()
    messages = [HumanMessage(content = question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()
retriever.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ from huggingface_hub import login
5
+ from supabase import Client, create_client
6
+ from supabase.client import ClientOptions
7
+
8
+ from langchain_community.vectorstores import SupabaseVectorStore
9
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
10
+ from langchain.tools.retriever import create_retriever_tool
11
+
12
load_dotenv()

# Embedding model and Supabase table / RPC configuration.
MODEL_NAME = "BAAI/bge-base-en-v1.5"
TBL_NAME = "documents_tbl"
QUERY_NAME = "match_documents"

def get_retriever_tool():
    """Build the Supabase-backed vector store and a retriever tool over it.

    Requires SUPABASE_URL and SUPABASE_ANON_KEY in the environment.

    Returns:
        tuple: (vector_store, vector_retriever, retriever_tool) where
        retriever_tool is a LangChain tool that searches stored Q/A pairs.
    """
    embedding_model = HuggingFaceEmbeddings(model_name = MODEL_NAME)
    # Removed the unused DIMS_EMBEDDING lookup, which also reached into the
    # private `_client` attribute of the embeddings wrapper.

    # Supabase client
    supabase: Client = create_client(
        os.environ.get("SUPABASE_URL"),
        os.environ.get("SUPABASE_ANON_KEY"),
        options = ClientOptions(schema = "public")
    )

    # Vector Store backed by the Supabase table and its match RPC.
    vector_store = SupabaseVectorStore(
        client = supabase,
        embedding = embedding_model,
        table_name = TBL_NAME,
        query_name = QUERY_NAME
    )

    vector_retriever = vector_store.as_retriever()

    retriever_tool = create_retriever_tool(
        retriever = vector_retriever,
        name = "question_search_tool",
        description = "A tool to retrieve similar questions based on embedding from Supabase vector store."
    )
    return vector_store, vector_retriever, retriever_tool
system_prompt.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a general AI assistant tasked with answering questions using a set of tools.
2
+ I will ask you a question. Report your thoughts, and finish your answer with the following template:
3
+ FINAL ANSWER: [YOUR FINAL ANSWER].
4
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
5
+ If you are asked for a number, don’t use commas to write your number, nor units such as $ or
6
+ percent sign unless specified otherwise.
7
+ If you are asked for a string, don’t use articles, nor abbreviations (e.g. for cities), and write the
8
+ digits in plain text unless specified otherwise.
9
+ If you are asked for a comma separated list, apply the above rules depending on whether the element
10
+ to be put in the list is a number or a string, ensure there is exactly one space after each comma.
11
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
12
+
tools.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - `Search Engine` (arXiv, Wikipedia, DuckDuckGo)
2
+ # - `Calculator` (add, subtract, divide, multiply, modulus, etc.)
3
+ # - `Access` and `Download Files` from Web
4
+ # - `Excel`/`Google Sheets`: Process Downloaded files
5
+
6
+ import os
7
+ import requests
8
+ import tempfile
9
+ import uuid
10
+
11
+ import pytesseract
12
+
13
+ from datetime import datetime, timezone
14
+ from PIL import Image
15
+ from urllib.parse import urlparse
16
+ from typing import Optional
17
+
18
+ from langchain_core.tools import tool
19
+ from langchain_community.document_loaders import (
20
+ WikipediaLoader, ArxivLoader, PyPDFLoader
21
+ )
22
+ from langchain_community.tools import DuckDuckGoSearchResults
23
+
24
+ #* === MATH TOOLS ===
25
@tool
def add(a: float, b: float) -> float:
    """
    Adds two numbers.

    Args:
        a (float): First number to add.
        b (float): Second number to add.

    Returns:
        float: Sum of the two provided numbers.
    """
    # Return annotation corrected: float params produce a float sum.
    return a + b
38
+
39
@tool
def subtract(a: int, b: int) -> int:
    """
    Subtracts one integer from another.

    Args:
        a (int): The minuend (number subtracted from).
        b (int): The subtrahend (number to subtract).

    Returns:
        int: The difference a - b.
    """
    difference = a - b
    return difference
52
+
53
@tool
def multiply(a: float, b: float) -> float:
    """
    Multiplies two numbers.

    Args:
        a (float): First number to multiply.
        b (float): Second number to multiply.

    Returns:
        float: Product of the two provided numbers.
    """
    # Return annotation corrected: float params produce a float product.
    return a * b
66
+
67
@tool
def div(a: float, b: float) -> float:
    """
    Divides two numbers.

    Args:
        a (float): The dividend.
        b (float): The divisor.

    Returns:
        float: The result of dividing a by b.

    Raises:
        ZeroDivisionError: If b is zero.
    """
    # Added the missing return annotation; true division always yields float.
    return a / b
83
+
84
@tool
def modulus(a: int, b: int) -> int:
    """
    Computes the modulus (remainder) of dividing two integers.

    Args:
        a (int): The dividend.
        b (int): The divisor.

    Returns:
        int: The remainder when a is divided by b.

    Raises:
        ZeroDivisionError: If b is zero.
    """
    # Added the missing return annotation.
    return a % b
100
+
101
@tool
def power(a: float, b: float) -> float:
    """
    Raises a number `a` to the power of `b`.

    Args:
        a (float): Base number.
        b (float): Exponent.

    Returns:
        float: Result of a raised to the b-th power.
    """
    return pow(a, b)
114
+
115
+ #* === SEARCH TOOLS ===
116
+
117
@tool
def wikipedia_search(query: str) -> dict:
    """
    Search Wikipedia for a query and return up to 3 formatted results.

    Args:
        query (str): The topic to search for.

    Returns:
        dict: A dictionary with the key 'wikipedia_results' containing the
        formatted documents, each wrapped in a <Document ...>...</Document> tag.
    """
    search_docs = WikipediaLoader(query = query, load_max_docs = 3).load()
    formatted_docs = "\n\n---\n\n".join(
        # Well-formed wrapper: the original opening tag was missing its '<'
        # and self-closed with '/>' despite a matching '</Document>'.
        f"<Document source='{doc.metadata['source']}' page={doc.metadata.get('page', '')}>\n"
        f"{doc.page_content}\n</Document>"
        for doc in search_docs
    )
    return {"wikipedia_results": formatted_docs}
137
+
138
@tool
def search_web(query: str) -> dict:
    """
    Performs a web search using DuckDuckGo and returns the formatted results.

    Args:
        query (str): The search query to submit to DuckDuckGo.

    Returns:
        dict: A dictionary with a single key "web_results" containing the
        search results as a single string.
    """
    # DuckDuckGoSearchResults.invoke() returns one formatted string
    # (snippet/title/link per result), NOT Document objects — the previous
    # code iterated it as documents and would crash on doc.metadata.
    results = DuckDuckGoSearchResults(max_results = 4).invoke(query)
    return {"web_results": results}
159
+
160
@tool
def arxiv_search(query: str) -> dict:
    """
    Perform a search on the arXiv academic paper repository and return the top results.

    Args:
        query (str): The search query to use on arXiv.

    Returns:
        dict: A dictionary containing a string under the key "arxiv_results",
        which includes a formatted summary of the top retrieved documents.
        Each entry contains the document's source, optional page number, and
        the first 1000 characters of the content.
    """
    search_docs = ArxivLoader(query = query, load_max_docs = 3).load()
    formatted_docs = "\n\n---\n\n".join(
        # Well-formed wrapper: the original opening tag was missing its '<'
        # and self-closed with '/>' despite a matching '</Document>'.
        f"<Document source='{doc.metadata['source']}' page={doc.metadata.get('page', '')}>\n"
        f"{doc.page_content[:1000]}\n</Document>"
        for doc in search_docs
    )
    return {"arxiv_results": formatted_docs}
182
+
183
+ #* === FILE PROCESSING TOOLS ===
184
+
185
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Saves the provided text content to a temporary file and returns its path.

    If no filename is provided, a random temporary filename will be generated.
    The file is saved in the system's temporary directory.

    Args:
        content (str): The text content to be written to the file.
        filename (Optional[str]): Optional name for the file. If not provided, a temporary name is used.

    Returns:
        str: A message with the path to the saved file, indicating it is ready for processing.
    """
    try:
        temp_dir = tempfile.gettempdir()
        if filename is None:
            # Close the handle right away: we only need a unique path here.
            # The previous code leaked the open handle, which also blocks
            # re-opening the file on Windows.
            with tempfile.NamedTemporaryFile(delete = False, dir = temp_dir) as temp_file:
                filepath = temp_file.name
        else:
            filepath = os.path.join(temp_dir, filename)

        with open(filepath, "w", encoding = "utf-8") as f:
            f.write(content)
        return f"File saved to {filepath}. It is available to read for processing its contents."
    except Exception as e:
        return f"Error saving file: {str(e)}"
213
+
214
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Downloads a file from a given URL and saves it to a temporary directory.

    If no filename is provided, it attempts to extract it from the URL. If the URL
    does not contain a valid filename, a temporary unique filename will be generated.

    Args:
        url (str): The URL of the file to download.
        filename (Optional[str]): Optional name for the downloaded file.

    Returns:
        str: A string indicating the path to the downloaded file, or an error message.
    """
    try:
        # Parse URL to get filename if not provided
        if not filename:
            path = urlparse(url).path
            filename = os.path.basename(path)
            if not filename:
                ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
                filename = f"downloaded_{ts}_{uuid.uuid4().hex[:8]}.tmp"

        # Create temporary file
        temp_dir = tempfile.gettempdir()
        filepath = os.path.join(temp_dir, filename)

        # Download the file. A timeout keeps the agent from hanging forever
        # on an unresponsive server; the context manager closes the response.
        with requests.get(url, stream = True, timeout = 30) as response:
            response.raise_for_status()

            # Save the file in chunks to avoid loading it fully into memory
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size = 8192):
                    if chunk:
                        f.write(chunk)
        return f"File downloaded to {filepath}. It is available to read for processing its contents."
    except Exception as e:
        return f"Error downloading file: {str(e)}"
254
+
255
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extracts text content from an image file using Optical Character Recognition (OCR).

    Args:
        image_path (str): The path to the image file from which text will be extracted.

    Returns:
        str: Extracted text content. If extraction fails, returns an error message.
    """
    # Decorator normalized from "@ tool" to "@tool" for consistency with the
    # other tools in this module.
    try:
        # Context manager ensures the image file handle is closed (the
        # previous code left it open).
        with Image.open(image_path) as image:
            # Extract text from image via Tesseract OCR
            text = pytesseract.image_to_string(image)
        return f"Text extracted from image:\n\n{text.strip()}"
    except Exception as e:
        return f"Error extracting text from image '{image_path}': {str(e)}"
274
+
275
@tool
def pdf_loader(filepath: str) -> dict:
    """
    Loads a PDF file from the given file path, parses its contents,
    and returns a preview of each page's content (up to 1000 characters per page).

    Args:
        filepath (str): The full path to the PDF file.

    Returns:
        dict: A dictionary containing formatted PDF page previews under the key 'pdf_results'.
              Each page is separated by "\n\n---\n\n".
    """
    try:
        pdf_content = PyPDFLoader(file_path = filepath).load()
        formatted_content = "\n\n---\n\n".join(
            # Well-formed wrapper: the original opening tag was missing its
            # '<' and self-closed with '/>' despite a matching '</Document>'.
            f"<Document source='{doc.metadata['source']}' page={doc.metadata.get('page', '')}>\n"
            f"{doc.page_content[:1000]}\n</Document>"
            for doc in pdf_content
        )
        return {"pdf_results": formatted_content}
    except Exception as e:
        return {"pdf_results": f"Error reading PDF file: {str(e)}"}
300
+