CiccioQuinto committed on
Commit
cee45ae
·
verified ·
1 Parent(s): 0b77a82

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +673 -86
agent.py CHANGED
@@ -1,15 +1,36 @@
1
- """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_groq import ChatGroq
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
 
 
 
 
 
 
 
 
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
@@ -18,117 +39,664 @@ from supabase.client import Client, create_client
18
 
19
  load_dotenv()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @tool
22
- def multiply(a: int, b: int) -> int:
23
- """Multiply two numbers.
24
 
25
  Args:
26
- a: first int
27
- b: second int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  """
29
  return a * b
30
 
 
31
  @tool
32
- def add(a: int, b: int) -> int:
33
- """Add two numbers.
34
-
 
35
  Args:
36
- a: first int
37
- b: second int
38
  """
39
  return a + b
40
 
 
41
  @tool
42
- def subtract(a: int, b: int) -> int:
43
- """Subtract two numbers.
44
-
 
45
  Args:
46
- a: first int
47
- b: second int
48
  """
49
  return a - b
50
 
 
51
  @tool
52
- def divide(a: int, b: int) -> int:
53
- """Divide two numbers.
54
-
 
55
  Args:
56
- a: first int
57
- b: second int
58
  """
59
  if b == 0:
60
- raise ValueError("Cannot divide by zero.")
61
  return a / b
62
 
 
63
  @tool
64
  def modulus(a: int, b: int) -> int:
65
- """Get the modulus of two numbers.
66
-
 
67
  Args:
68
- a: first int
69
- b: second int
70
  """
71
  return a % b
72
 
 
73
  @tool
74
- def wiki_search(query: str) -> str:
75
- """Search Wikipedia for a query and return maximum 2 results.
76
-
 
77
  Args:
78
- query: The search query."""
79
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
80
- formatted_search_docs = "\n\n---\n\n".join(
81
- [
82
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
83
- for doc in search_docs
84
- ])
85
- return {"wiki_results": formatted_search_docs}
86
 
87
  @tool
88
- def web_search(query: str) -> str:
89
- """Search Tavily for a query and return maximum 3 results.
90
-
 
91
  Args:
92
- query: The search query."""
93
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
94
- formatted_search_docs = "\n\n---\n\n".join(
95
- [
96
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
97
- for doc in search_docs
98
- ])
99
- return {"web_results": formatted_search_docs}
 
100
 
101
  @tool
102
- def arvix_search(query: str) -> str:
103
- """Search Arxiv for a query and return maximum 3 result.
104
-
 
105
  Args:
106
- query: The search query."""
107
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
108
- formatted_search_docs = "\n\n---\n\n".join(
109
- [
110
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
111
- for doc in search_docs
112
- ])
113
- return {"arvix_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
  # load the system prompt from the file
118
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
  system_prompt = f.read()
 
120
 
121
  # System message
122
  sys_msg = SystemMessage(content=system_prompt)
123
 
124
  # build a retriever
125
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
 
 
126
  supabase: Client = create_client(
127
- os.environ.get("SUPABASE_URL"),
128
- os.environ.get("SUPABASE_SERVICE_KEY"))
129
  vector_store = SupabaseVectorStore(
130
  client=supabase,
131
- embedding= embeddings,
132
  table_name="documents",
133
  query_name="match_documents_langchain",
134
  )
@@ -139,38 +707,53 @@ create_retriever_tool = create_retriever_tool(
139
  )
140
 
141
 
142
-
143
  tools = [
 
 
 
144
  multiply,
145
  add,
146
  subtract,
147
  divide,
148
  modulus,
149
- wiki_search,
150
- web_search,
151
- arvix_search,
 
 
 
 
 
 
 
 
 
 
152
  ]
153
 
 
154
  # Build graph function
155
  def build_graph(provider: str = "groq"):
156
  """Build the graph"""
157
  # Load environment variables from .env file
158
- if provider == "google":
159
- # Google Gemini
160
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
161
- elif provider == "groq":
162
  # Groq https://console.groq.com/docs/models
163
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
164
  elif provider == "huggingface":
165
  # TODO: Add huggingface endpoint
166
  llm = ChatHuggingFace(
167
  llm=HuggingFaceEndpoint(
168
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
 
 
 
 
169
  temperature=0,
170
  ),
 
171
  )
172
  else:
173
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
174
  # Bind tools to LLM
175
  llm_with_tools = llm.bind_tools(tools)
176
 
@@ -178,14 +761,19 @@ def build_graph(provider: str = "groq"):
178
  def assistant(state: MessagesState):
179
  """Assistant node"""
180
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
181
-
182
  def retriever(state: MessagesState):
183
  """Retriever node"""
184
  similar_question = vector_store.similarity_search(state["messages"][0].content)
185
- example_msg = HumanMessage(
186
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
187
- )
188
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
189
 
190
  builder = StateGraph(MessagesState)
191
  builder.add_node("retriever", retriever)
@@ -202,13 +790,12 @@ def build_graph(provider: str = "groq"):
202
  # Compile graph
203
  return builder.compile()
204
 
 
205
  # test
206
  if __name__ == "__main__":
207
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
208
- # Build the graph
209
  graph = build_graph(provider="groq")
210
- # Run the graph
211
  messages = [HumanMessage(content=question)]
212
  messages = graph.invoke({"messages": messages})
213
  for m in messages["messages"]:
214
- m.pretty_print()
 
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from typing import List, Dict, Any, Optional
4
+ import tempfile
5
+ import re
6
+ import json
7
+ import requests
8
+ from urllib.parse import urlparse
9
+ import pytesseract
10
+ from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
11
+ import cmath
12
+ import pandas as pd
13
+ import uuid
14
+ import numpy as np
15
+ from code_interpreter import CodeInterpreter
16
+
17
+ interpreter_instance = CodeInterpreter()
18
+
19
+ from image_processing import *
20
+
21
+ """Langraph"""
22
  from langgraph.graph import START, StateGraph, MessagesState
 
 
 
 
 
23
  from langchain_community.tools.tavily_search import TavilySearchResults
24
  from langchain_community.document_loaders import WikipediaLoader
25
  from langchain_community.document_loaders import ArxivLoader
26
+ from langgraph.prebuilt import ToolNode, tools_condition
27
+ from langchain_google_genai import ChatGoogleGenerativeAI
28
+ from langchain_groq import ChatGroq
29
+ from langchain_huggingface import (
30
+ ChatHuggingFace,
31
+ HuggingFaceEndpoint,
32
+ HuggingFaceEmbeddings,
33
+ )
34
  from langchain_community.vectorstores import SupabaseVectorStore
35
  from langchain_core.messages import SystemMessage, HumanMessage
36
  from langchain_core.tools import tool
 
39
 
40
  load_dotenv()
41
 
42
+ ### =============== BROWSER TOOLS =============== ###
43
+
44
+
45
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Use .get() for "source": not every loader document is guaranteed to carry
    # that metadata key, and a KeyError here would abort the whole tool call.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
59
+
60
+
61
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # Runnable.invoke takes its input positionally; `invoke(query=query)` raises
    # a TypeError because `invoke` has no `query` keyword parameter.
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    # Tavily returns plain result dicts (keys like "url"/"content"), not
    # Document objects, so read them as dicts instead of `.metadata`/`.page_content`.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.get("url", "")}"/>\n{doc.get("content", "")}\n</Document>'
        for doc in search_docs
    )
    return {"web_results": formatted_search_docs}
75
+
76
+
77
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader metadata does not reliably contain a "source" key (it exposes
    # Title/Authors/Published/...), so fall back to the paper title instead of
    # crashing with a KeyError.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arxiv_results": formatted_search_docs}
91
+
92
+
93
+ ### =============== CODE INTERPRETER TOOLS =============== ###
94
+
95
+
96
@tool
def execute_code_multilang(code: str, language: str = "python") -> str:
    """Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results.

    Args:
        code (str): The source code to execute.
        language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java".

    Returns:
        A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any).
    """
    supported_languages = ["python", "bash", "sql", "c", "java"]
    # Normalise so "Python"/"PYTHON" etc. are accepted.
    language = language.lower()

    if language not in supported_languages:
        return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(supported_languages)}"

    # Delegates to the module-level CodeInterpreter instance.
    # NOTE(review): the result is assumed to be a dict with keys "status",
    # "stdout", "stderr", "result", "dataframes", "plots" — confirm against
    # code_interpreter.py.
    result = interpreter_instance.execute_code(code, language=language)

    # Build a markdown report section by section; join at the end.
    response = []

    if result["status"] == "success":
        response.append(f"✅ Code executed successfully in **{language.upper()}**")

        if result.get("stdout"):
            response.append(
                "\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```"
            )

        if result.get("stderr"):
            response.append(
                "\n**Standard Error (if any):**\n```\n"
                + result["stderr"].strip()
                + "\n```"
            )

        # `is not None` (not truthiness): a result of 0 / "" / [] must still be shown.
        if result.get("result") is not None:
            response.append(
                "\n**Execution Result:**\n```\n"
                + str(result["result"]).strip()
                + "\n```"
            )

        if result.get("dataframes"):
            for df_info in result["dataframes"]:
                response.append(
                    f"\n**DataFrame `{df_info['name']}` (Shape: {df_info['shape']})**"
                )
                # Rebuild a DataFrame from the serialized head for display.
                df_preview = pd.DataFrame(df_info["head"])
                response.append("First 5 rows:\n```\n" + str(df_preview) + "\n```")

        if result.get("plots"):
            response.append(
                f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)"
            )

    else:
        # Failure path: report only the error log, if any.
        response.append(f"❌ Code execution failed in **{language.upper()}**")
        if result.get("stderr"):
            response.append(
                "\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```"
            )

    return "\n".join(response)
160
+
161
+
162
+ ### =============== MATHEMATICAL TOOLS =============== ###
163
+
164
+
165
@tool
def multiply(a: float, b: float) -> float:
    """
    Multiplies two numbers.

    Args:
        a (float): the first number
        b (float): the second number
    """
    # Plain arithmetic; kept as a named intermediate for readability.
    product = a * b
    return product
175
 
176
+
177
@tool
def add(a: float, b: float) -> float:
    """
    Adds two numbers.

    Args:
        a (float): the first number
        b (float): the second number
    """
    # Plain arithmetic; kept as a named intermediate for readability.
    total = a + b
    return total
187
 
188
+
189
@tool
def subtract(a: float, b: float) -> float:
    """
    Subtracts two numbers.

    Args:
        a (float): the first number
        b (float): the second number
    """
    # Return annotation fixed: subtracting two floats yields a float, not an int.
    return a - b
199
 
200
+
201
@tool
def divide(a: float, b: float) -> float:
    """
    Divides two numbers.

    Args:
        a (float): the first float number
        b (float): the second float number
    """
    if b == 0:
        # Grammar fix in the user-facing message ("Cannot divided" -> "Cannot divide").
        raise ValueError("Cannot divide by zero.")
    return a / b
213
 
214
+
215
@tool
def modulus(a: int, b: int) -> int:
    """
    Get the modulus of two numbers.

    Args:
        a (int): the first number
        b (int): the second number
    """
    # Python's % follows the sign of the divisor.
    remainder = a % b
    return remainder
225
 
226
+
227
@tool
def power(a: float, b: float) -> float:
    """
    Get the power of two numbers.

    Args:
        a (float): the first number
        b (float): the second number
    """
    # Exponentiation via the ** operator (same semantics as the original).
    result = a ** b
    return result
237
+
 
 
 
238
 
239
@tool
def square_root(a: float) -> float | complex:
    """
    Get the square root of a number.

    Args:
        a (float): the number to get the square root of
    """
    if a < 0:
        # Negative input: fall back to the complex square root.
        return cmath.sqrt(a)
    return a ** 0.5
250
+
251
+
252
+ ### =============== DOCUMENT PROCESSING TOOLS =============== ###
253
+
254
 
255
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file and return the path.

    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    temp_dir = tempfile.gettempdir()
    if filename is None:
        # Close the handle immediately: we only need a unique path. Keeping it
        # open leaks the descriptor and blocks the re-open below on Windows.
        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
        filepath = temp_file.name
        temp_file.close()
    else:
        filepath = os.path.join(temp_dir, filename)

    # Explicit encoding so behaviour does not depend on the platform default.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
275
+
276
+
277
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.

    Args:
        url (str): the URL of the file to download.
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    try:
        # Parse URL to get filename if not provided
        if not filename:
            path = urlparse(url).path
            filename = os.path.basename(path)
            if not filename:
                filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        # Create temporary file
        temp_dir = tempfile.gettempdir()
        filepath = os.path.join(temp_dir, filename)

        # Download the file. A timeout is mandatory: requests.get without one
        # can block the agent forever on an unresponsive host.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        # Save the file in chunks to keep memory bounded for large downloads.
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        return f"File downloaded to {filepath}. You can read this file to process its contents."
    except Exception as e:
        # Report failures in-band so the agent can react instead of crashing.
        return f"Error downloading file: {str(e)}"
310
+
311
+
312
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using OCR library pytesseract (if available).

    Args:
        image_path (str): the path to the image file.
    """
    try:
        # Load the image and run OCR over it in one pass.
        extracted = pytesseract.image_to_string(Image.open(image_path))
        return f"Extracted text from image:\n\n{extracted}"
    except Exception as e:
        # OCR failures (missing file, missing tesseract binary, ...) are
        # reported in-band rather than raised.
        return f"Error extracting text from image: {str(e)}"
330
+
331
+
332
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file using pandas and answer a question about it.

    Args:
        file_path (str): the path to the CSV file.
        query (str): Question about the data
    """
    try:
        df = pd.read_csv(file_path)
        # NOTE(review): `query` is accepted but not consulted — the same
        # generic summary is produced regardless of the question.
        parts = [
            f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n",
            f"Columns: {', '.join(df.columns)}\n\n",
            "Summary statistics:\n",
            str(df.describe()),
        ]
        return "".join(parts)
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
357
+
358
+
359
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Analyze an Excel file using pandas and answer a question about it.

    Args:
        file_path (str): the path to the Excel file.
        query (str): Question about the data
    """
    try:
        df = pd.read_excel(file_path)
        # NOTE(review): `query` is accepted but not consulted — the same
        # generic summary is produced regardless of the question.
        parts = [
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n",
            f"Columns: {', '.join(df.columns)}\n\n",
            "Summary statistics:\n",
            str(df.describe()),
        ]
        return "".join(parts)
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
386
+
387
+
388
+ ### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
389
+
390
+
391
@tool
def analyze_image(image_base64: str) -> Dict[str, Any]:
    """
    Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).

    Args:
        image_base64 (str): Base64 encoded image string

    Returns:
        Dictionary with analysis result
    """
    try:
        # decode_image/save_image/encode_image come from the star import of
        # `image_processing` — NOTE(review): confirm their contracts there.
        img = decode_image(image_base64)
        width, height = img.size
        mode = img.mode

        if mode in ("RGB", "RGBA"):
            # Mean over height and width leaves one value per channel.
            arr = np.array(img)
            avg_colors = arr.mean(axis=(0, 1))
            # Only the first three channels are considered (alpha is ignored).
            dominant = ["Red", "Green", "Blue"][np.argmax(avg_colors[:3])]
            brightness = avg_colors.mean()
            color_analysis = {
                "average_rgb": avg_colors.tolist(),
                "brightness": brightness,
                "dominant_color": dominant,
            }
        else:
            # Non-RGB(A) modes (e.g. "L", "P") get no per-channel analysis.
            color_analysis = {"note": f"No color analysis for mode {mode}"}

        # Work on a copy: PIL's thumbnail() mutates the image in place.
        thumbnail = img.copy()
        thumbnail.thumbnail((100, 100))
        thumb_path = save_image(thumbnail, "thumbnails")
        thumbnail_base64 = encode_image(thumb_path)

        return {
            "dimensions": (width, height),
            "mode": mode,
            "color_analysis": color_analysis,
            "thumbnail": thumbnail_base64,
        }
    except Exception as e:
        # Errors are reported in-band so the agent sees a structured result.
        return {"error": str(e)}
433
+
434
+
435
@tool
def transform_image(
    image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Apply transformations: resize, rotate, crop, flip, brightness, contrast, blur, sharpen, grayscale.

    Args:
        image_base64 (str): Base64 encoded input image
        operation (str): Transformation operation
        params (Dict[str, Any], optional): Parameters for the operation

    Returns:
        Dictionary with transformed image (base64)
    """
    try:
        # decode_image/save_image/encode_image come from the `image_processing`
        # star import — NOTE(review): confirm their contracts there.
        img = decode_image(image_base64)
        params = params or {}

        if operation == "resize":
            # Defaults to halving both dimensions when no size is given.
            img = img.resize(
                (
                    params.get("width", img.width // 2),
                    params.get("height", img.height // 2),
                )
            )
        elif operation == "rotate":
            # expand=True grows the canvas so the rotated image is not clipped.
            img = img.rotate(params.get("angle", 90), expand=True)
        elif operation == "crop":
            # Missing bounds default to the full image extent.
            img = img.crop(
                (
                    params.get("left", 0),
                    params.get("top", 0),
                    params.get("right", img.width),
                    params.get("bottom", img.height),
                )
            )
        elif operation == "flip":
            if params.get("direction", "horizontal") == "horizontal":
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            else:
                img = img.transpose(Image.FLIP_TOP_BOTTOM)
        elif operation == "adjust_brightness":
            # factor > 1 brightens, < 1 darkens (PIL ImageEnhance semantics).
            img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
        elif operation == "adjust_contrast":
            img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
        elif operation == "blur":
            img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
        elif operation == "sharpen":
            img = img.filter(ImageFilter.SHARPEN)
        elif operation == "grayscale":
            img = img.convert("L")
        else:
            return {"error": f"Unknown operation: {operation}"}

        result_path = save_image(img)
        result_base64 = encode_image(result_path)
        return {"transformed_image": result_base64}

    except Exception as e:
        # Errors are reported in-band so the agent sees a structured result.
        return {"error": str(e)}
496
+
497
+
498
@tool
def draw_on_image(
    image_base64: str, drawing_type: str, params: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Draw shapes (rectangle, circle, line) or text onto an image.

    Args:
        image_base64 (str): Base64 encoded input image
        drawing_type (str): Drawing type
        params (Dict[str, Any]): Drawing parameters

    Returns:
        Dictionary with result image (base64)
    """
    try:
        # decode_image/save_image/encode_image come from the `image_processing`
        # star import — NOTE(review): confirm their contracts there.
        img = decode_image(image_base64)
        draw = ImageDraw.Draw(img)
        color = params.get("color", "red")

        if drawing_type == "rectangle":
            # Required keys: left/top/right/bottom — a missing one raises
            # KeyError and is reported via the except branch below.
            draw.rectangle(
                [params["left"], params["top"], params["right"], params["bottom"]],
                outline=color,
                width=params.get("width", 2),
            )
        elif drawing_type == "circle":
            x, y, r = params["x"], params["y"], params["radius"]
            # Circle drawn as an ellipse inside its bounding box.
            draw.ellipse(
                (x - r, y - r, x + r, y + r),
                outline=color,
                width=params.get("width", 2),
            )
        elif drawing_type == "line":
            draw.line(
                (
                    params["start_x"],
                    params["start_y"],
                    params["end_x"],
                    params["end_y"],
                ),
                fill=color,
                width=params.get("width", 2),
            )
        elif drawing_type == "text":
            font_size = params.get("font_size", 20)
            # Fall back to PIL's built-in font when arial.ttf is unavailable.
            try:
                font = ImageFont.truetype("arial.ttf", font_size)
            except IOError:
                font = ImageFont.load_default()
            draw.text(
                (params["x"], params["y"]),
                params.get("text", "Text"),
                fill=color,
                font=font,
            )
        else:
            return {"error": f"Unknown drawing type: {drawing_type}"}

        result_path = save_image(img)
        result_base64 = encode_image(result_path)
        return {"result_image": result_base64}

    except Exception as e:
        # Errors are reported in-band so the agent sees a structured result.
        return {"error": str(e)}
563
+
564
+
565
@tool
def generate_simple_image(
    image_type: str,
    width: int = 500,
    height: int = 500,
    params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Generate a simple image (gradient, noise, pattern, chart).

    Args:
        image_type (str): Type of image
        width (int), height (int)
        params (Dict[str, Any], optional): Specific parameters

    Returns:
        Dictionary with generated image (base64)
    """
    try:
        params = params or {}

        # NOTE(review): only "gradient" and "noise" are implemented; the
        # "pattern"/"chart" types mentioned in the docstring fall through to
        # the unsupported-type error below.
        if image_type == "gradient":
            direction = params.get("direction", "horizontal")
            start_color = params.get("start_color", (255, 0, 0))
            end_color = params.get("end_color", (0, 0, 255))

            img = Image.new("RGB", (width, height))
            draw = ImageDraw.Draw(img)

            # Linear per-channel interpolation, one 1-pixel line per column/row.
            if direction == "horizontal":
                for x in range(width):
                    r = int(
                        start_color[0] + (end_color[0] - start_color[0]) * x / width
                    )
                    g = int(
                        start_color[1] + (end_color[1] - start_color[1]) * x / width
                    )
                    b = int(
                        start_color[2] + (end_color[2] - start_color[2]) * x / width
                    )
                    draw.line([(x, 0), (x, height)], fill=(r, g, b))
            else:
                for y in range(height):
                    r = int(
                        start_color[0] + (end_color[0] - start_color[0]) * y / height
                    )
                    g = int(
                        start_color[1] + (end_color[1] - start_color[1]) * y / height
                    )
                    b = int(
                        start_color[2] + (end_color[2] - start_color[2]) * y / height
                    )
                    draw.line([(0, y), (width, y)], fill=(r, g, b))

        elif image_type == "noise":
            # Uniform random RGB noise; uint8 is required by Image.fromarray.
            noise_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            img = Image.fromarray(noise_array, "RGB")

        else:
            return {"error": f"Unsupported image_type {image_type}"}

        # save_image/encode_image come from the `image_processing` star import
        # — NOTE(review): confirm their contracts there.
        result_path = save_image(img)
        result_base64 = encode_image(result_path)
        return {"generated_image": result_base64}

    except Exception as e:
        # Errors are reported in-band so the agent sees a structured result.
        return {"error": str(e)}
632
+
633
+
634
@tool
def combine_images(
    images_base64: List[str], operation: str, params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Combine multiple images (collage, stack, blend).

    Args:
        images_base64 (List[str]): List of base64 images
        operation (str): Combination type
        params (Dict[str, Any], optional)

    Returns:
        Dictionary with combined image (base64)
    """
    try:
        # decode_image/save_image/encode_image come from the `image_processing`
        # star import — NOTE(review): confirm their contracts there.
        images = [decode_image(b64) for b64 in images_base64]
        params = params or {}

        # NOTE(review): only "stack" is implemented; "collage"/"blend" from the
        # docstring fall through to the unsupported-operation error below.
        if operation == "stack":
            direction = params.get("direction", "horizontal")
            if direction == "horizontal":
                # Side-by-side: total width, tallest image sets the height.
                total_width = sum(img.width for img in images)
                max_height = max(img.height for img in images)
                new_img = Image.new("RGB", (total_width, max_height))
                x = 0
                for img in images:
                    new_img.paste(img, (x, 0))
                    x += img.width
            else:
                # Top-to-bottom: total height, widest image sets the width.
                max_width = max(img.width for img in images)
                total_height = sum(img.height for img in images)
                new_img = Image.new("RGB", (max_width, total_height))
                y = 0
                for img in images:
                    new_img.paste(img, (0, y))
                    y += img.height
        else:
            return {"error": f"Unsupported combination operation {operation}"}

        result_path = save_image(new_img)
        result_base64 = encode_image(result_path)
        return {"combined_image": result_base64}

    except Exception as e:
        # Errors are reported in-band so the agent sees a structured result.
        return {"error": str(e)}
680
 
681
 
682
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
# NOTE(review): debug print of the full prompt at import time — confirm this
# is intentional before shipping.
print(system_prompt)

# System message
sys_msg = SystemMessage(content=system_prompt)

# build a retriever
# The embedding model must match the dimension of the vectors already stored
# in the Supabase "documents" table (all-mpnet-base-v2 produces 768-dim vectors).
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)  # dim=768
# NOTE(review): os.environ.get returns None when a variable is unset, so a
# missing SUPABASE_URL / SUPABASE_SERVICE_ROLE_KEY surfaces as a confusing
# create_client error — verify both are set in the deployment environment.
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
 
707
  )
708
 
709
 
 
710
# All tools exposed to the LLM, grouped by category.
tools = [
    # Browser / search tools
    web_search,
    wiki_search,
    arxiv_search,
    # Mathematical tools
    multiply,
    add,
    subtract,
    divide,
    modulus,
    power,
    square_root,
    # Document processing tools
    save_and_read_file,
    download_file_from_url,
    extract_text_from_image,
    analyze_csv_file,
    analyze_excel_file,
    # Code interpreter
    execute_code_multilang,
    # Image processing and generation tools
    analyze_image,
    transform_image,
    draw_on_image,
    generate_simple_image,
    combine_images,
]
733
 
734
+
735
  # Build graph function
736
  def build_graph(provider: str = "groq"):
737
  """Build the graph"""
738
  # Load environment variables from .env file
739
+ if provider == "groq":
 
 
 
740
  # Groq https://console.groq.com/docs/models
741
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
742
  elif provider == "huggingface":
743
  # TODO: Add huggingface endpoint
744
  llm = ChatHuggingFace(
745
  llm=HuggingFaceEndpoint(
746
+ repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
747
+ task="text-generation", # for chat‐style use “text-generation”
748
+ max_new_tokens=1024,
749
+ do_sample=False,
750
+ repetition_penalty=1.03,
751
  temperature=0,
752
  ),
753
+ verbose=True,
754
  )
755
  else:
756
+ raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
757
  # Bind tools to LLM
758
  llm_with_tools = llm.bind_tools(tools)
759
 
 
761
  def assistant(state: MessagesState):
762
  """Assistant node"""
763
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
764
+
765
  def retriever(state: MessagesState):
766
  """Retriever node"""
767
  similar_question = vector_store.similarity_search(state["messages"][0].content)
768
+
769
+ if similar_question: # Check if the list is not empty
770
+ example_msg = HumanMessage(
771
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
772
+ )
773
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
774
+ else:
775
+ # Handle the case when no similar questions are found
776
+ return {"messages": [sys_msg] + state["messages"]}
777
 
778
  builder = StateGraph(MessagesState)
779
  builder.add_node("retriever", retriever)
 
790
  # Compile graph
791
  return builder.compile()
792
 
793
+
794
# test
if __name__ == "__main__":
    # Smoke-test: run the agent end-to-end on a single GAIA-style question.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="groq")
    messages = [HumanMessage(content=question)]
    # invoke() returns the final graph state; "messages" holds the transcript.
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()