tsrrus commited on
Commit
f002bf5
·
verified ·
1 Parent(s): 497b4bc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +685 -83
agent.py CHANGED
@@ -1,17 +1,36 @@
1
- """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_openai import ChatOpenAI
9
- from langchain.agents import initialize_agent, Tool
10
- from langchain_groq import ChatGroq
11
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_community.document_loaders import WikipediaLoader
14
  from langchain_community.document_loaders import ArxivLoader
 
 
 
 
 
 
 
 
15
  from langchain_community.vectorstores import SupabaseVectorStore
16
  from langchain_core.messages import SystemMessage, HumanMessage
17
  from langchain_core.tools import tool
@@ -20,56 +39,13 @@ from supabase.client import Client, create_client
20
 
21
  load_dotenv()
22
 
23
- @tool
24
- def multiply(a: int, b: int) -> int:
25
- """Multiply two numbers.
26
- Args:
27
- a: first int
28
- b: second int
29
- """
30
- return a * b
31
 
32
- @tool
33
- def add(a: int, b: int) -> int:
34
- """Add two numbers.
35
- Args:
36
- a: first int
37
- b: second int
38
- """
39
- return a + b
40
-
41
- @tool
42
- def subtract(a: int, b: int) -> int:
43
- """Subtract two numbers.
44
- Args:
45
- a: first int
46
- b: second int
47
- """
48
- return a - b
49
-
50
- @tool
51
- def divide(a: int, b: int) -> int:
52
- """Divide two numbers.
53
- Args:
54
- a: first int
55
- b: second int
56
- """
57
- if b == 0:
58
- raise ValueError("Cannot divide by zero.")
59
- return a / b
60
-
61
- @tool
62
- def modulus(a: int, b: int) -> int:
63
- """Get the modulus of two numbers.
64
- Args:
65
- a: first int
66
- b: second int
67
- """
68
- return a % b
69
 
70
  @tool
71
  def wiki_search(query: str) -> str:
72
  """Search Wikipedia for a query and return maximum 2 results.
 
73
  Args:
74
  query: The search query."""
75
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
@@ -77,12 +53,15 @@ def wiki_search(query: str) -> str:
77
  [
78
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
79
  for doc in search_docs
80
- ])
 
81
  return {"wiki_results": formatted_search_docs}
82
 
 
83
  @tool
84
  def web_search(query: str) -> str:
85
  """Search Tavily for a query and return maximum 3 results.
 
86
  Args:
87
  query: The search query."""
88
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
@@ -90,12 +69,15 @@ def web_search(query: str) -> str:
90
  [
91
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
92
  for doc in search_docs
93
- ])
 
94
  return {"web_results": formatted_search_docs}
95
 
 
96
  @tool
97
- def arvix_search(query: str) -> str:
98
  """Search Arxiv for a query and return maximum 3 result.
 
99
  Args:
100
  query: The search query."""
101
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
@@ -103,28 +85,620 @@ def arvix_search(query: str) -> str:
103
  [
104
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
105
  for doc in search_docs
106
- ])
107
- return {"arvix_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
 
111
  # load the system prompt from the file
112
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
113
  system_prompt = f.read()
 
114
 
115
  # System message
116
  sys_msg = SystemMessage(content=system_prompt)
117
 
118
  # build a retriever
119
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
 
 
120
  supabase: Client = create_client(
121
- os.environ.get("SUPABASE_URL"),
122
- os.environ.get("SUPABASE_SERVICE_KEY"))
123
  vector_store = SupabaseVectorStore(
124
  client=supabase,
125
- embedding= embeddings,
126
- table_name="documents",
127
- query_name="match_documents_langchain",
128
  )
129
  create_retriever_tool = create_retriever_tool(
130
  retriever=vector_store.as_retriever(),
@@ -133,40 +707,53 @@ create_retriever_tool = create_retriever_tool(
133
  )
134
 
135
 
136
-
137
  tools = [
 
 
 
138
  multiply,
139
  add,
140
  subtract,
141
  divide,
142
  modulus,
143
- wiki_search,
144
- web_search,
145
- arvix_search,
 
 
 
 
 
 
 
 
 
 
146
  ]
147
 
 
148
  # Build graph function
149
  def build_graph(provider: str = "groq"):
150
  """Build the graph"""
151
  # Load environment variables from .env file
152
- if provider == "google":
153
- # Google Gemini
154
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
155
- elif provider == "groq":
156
  # Groq https://console.groq.com/docs/models
157
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
158
- elif provider == "openai":
159
- # OpenAI
160
- llm = ChatOpenAI(model="gpt-4", temperature=0)
161
  elif provider == "huggingface":
 
162
  llm = ChatHuggingFace(
163
  llm=HuggingFaceEndpoint(
164
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
 
 
 
 
165
  temperature=0,
166
  ),
 
167
  )
168
  else:
169
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
170
  # Bind tools to LLM
171
  llm_with_tools = llm.bind_tools(tools)
172
 
@@ -178,10 +765,15 @@ def build_graph(provider: str = "groq"):
178
  def retriever(state: MessagesState):
179
  """Retriever node"""
180
  similar_question = vector_store.similarity_search(state["messages"][0].content)
181
- example_msg = HumanMessage(
182
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
183
- )
184
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
185
 
186
  builder = StateGraph(MessagesState)
187
  builder.add_node("retriever", retriever)
@@ -196,4 +788,14 @@ def build_graph(provider: str = "groq"):
196
  builder.add_edge("tools", "assistant")
197
 
198
  # Compile graph
199
- return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from typing import List, Dict, Any, Optional
4
+ import tempfile
5
+ import re
6
+ import json
7
+ import requests
8
+ from urllib.parse import urlparse
9
+ import pytesseract
10
+ from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
11
+ import cmath
12
+ import pandas as pd
13
+ import uuid
14
+ import numpy as np
15
+ from code_interpreter import CodeInterpreter
16
+
17
+ interpreter_instance = CodeInterpreter()
18
+
19
+ from image_processing import *
20
+
21
+ """Langraph"""
22
  from langgraph.graph import START, StateGraph, MessagesState
 
 
 
 
 
 
 
23
  from langchain_community.tools.tavily_search import TavilySearchResults
24
  from langchain_community.document_loaders import WikipediaLoader
25
  from langchain_community.document_loaders import ArxivLoader
26
+ from langgraph.prebuilt import ToolNode, tools_condition
27
+ from langchain_google_genai import ChatGoogleGenerativeAI
28
+ from langchain_groq import ChatGroq
29
+ from langchain_huggingface import (
30
+ ChatHuggingFace,
31
+ HuggingFaceEndpoint,
32
+ HuggingFaceEmbeddings,
33
+ )
34
  from langchain_community.vectorstores import SupabaseVectorStore
35
  from langchain_core.messages import SystemMessage, HumanMessage
36
  from langchain_core.tools import tool
 
39
 
40
  load_dotenv()
41
 
42
+ ### =============== BROWSER TOOLS =============== ###
 
 
 
 
 
 
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  @tool
46
  def wiki_search(query: str) -> str:
47
  """Search Wikipedia for a query and return maximum 2 results.
48
+
49
  Args:
50
  query: The search query."""
51
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
 
53
  [
54
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
55
  for doc in search_docs
56
+ ]
57
+ )
58
  return {"wiki_results": formatted_search_docs}
59
 
60
+
61
  @tool
62
  def web_search(query: str) -> str:
63
  """Search Tavily for a query and return maximum 3 results.
64
+
65
  Args:
66
  query: The search query."""
67
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
 
69
  [
70
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
71
  for doc in search_docs
72
+ ]
73
+ )
74
  return {"web_results": formatted_search_docs}
75
 
76
+
77
  @tool
78
+ def arxiv_search(query: str) -> str:
79
  """Search Arxiv for a query and return maximum 3 result.
80
+
81
  Args:
82
  query: The search query."""
83
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
 
85
  [
86
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
87
  for doc in search_docs
88
+ ]
89
+ )
90
+ return {"arxiv_results": formatted_search_docs}
91
+
92
+
93
+ ### =============== CODE INTERPRETER TOOLS =============== ###
94
+
95
+
96
+ @tool
97
+ def execute_code_multilang(code: str, language: str = "python") -> str:
98
+ """Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results.
99
+
100
+ Args:
101
+ code (str): The source code to execute.
102
+ language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java".
103
+
104
+ Returns:
105
+ A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any).
106
+ """
107
+ supported_languages = ["python", "bash", "sql", "c", "java"]
108
+ language = language.lower()
109
+
110
+ if language not in supported_languages:
111
+ return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(supported_languages)}"
112
+
113
+ result = interpreter_instance.execute_code(code, language=language)
114
+
115
+ response = []
116
+
117
+ if result["status"] == "success":
118
+ response.append(f"✅ Code executed successfully in **{language.upper()}**")
119
+
120
+ if result.get("stdout"):
121
+ response.append(
122
+ "\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```"
123
+ )
124
+
125
+ if result.get("stderr"):
126
+ response.append(
127
+ "\n**Standard Error (if any):**\n```\n"
128
+ + result["stderr"].strip()
129
+ + "\n```"
130
+ )
131
+
132
+ if result.get("result") is not None:
133
+ response.append(
134
+ "\n**Execution Result:**\n```\n"
135
+ + str(result["result"]).strip()
136
+ + "\n```"
137
+ )
138
+
139
+ if result.get("dataframes"):
140
+ for df_info in result["dataframes"]:
141
+ response.append(
142
+ f"\n**DataFrame `{df_info['name']}` (Shape: {df_info['shape']})**"
143
+ )
144
+ df_preview = pd.DataFrame(df_info["head"])
145
+ response.append("First 5 rows:\n```\n" + str(df_preview) + "\n```")
146
+
147
+ if result.get("plots"):
148
+ response.append(
149
+ f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)"
150
+ )
151
+
152
+ else:
153
+ response.append(f"❌ Code execution failed in **{language.upper()}**")
154
+ if result.get("stderr"):
155
+ response.append(
156
+ "\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```"
157
+ )
158
+
159
+ return "\n".join(response)
160
+
161
+
162
+ ### =============== MATHEMATICAL TOOLS =============== ###
163
+
164
+
165
+ @tool
166
+ def multiply(a: float, b: float) -> float:
167
+ """
168
+ Multiplies two numbers.
169
+
170
+ Args:
171
+ a (float): the first number
172
+ b (float): the second number
173
+ """
174
+ return a * b
175
+
176
+
177
+ @tool
178
+ def add(a: float, b: float) -> float:
179
+ """
180
+ Adds two numbers.
181
+
182
+ Args:
183
+ a (float): the first number
184
+ b (float): the second number
185
+ """
186
+ return a + b
187
+
188
+
189
+ @tool
190
+ def subtract(a: float, b: float) -> int:
191
+ """
192
+ Subtracts two numbers.
193
+
194
+ Args:
195
+ a (float): the first number
196
+ b (float): the second number
197
+ """
198
+ return a - b
199
+
200
+
201
+ @tool
202
+ def divide(a: float, b: float) -> float:
203
+ """
204
+ Divides two numbers.
205
+
206
+ Args:
207
+ a (float): the first float number
208
+ b (float): the second float number
209
+ """
210
+ if b == 0:
211
+ raise ValueError("Cannot divided by zero.")
212
+ return a / b
213
+
214
+
215
+ @tool
216
+ def modulus(a: int, b: int) -> int:
217
+ """
218
+ Get the modulus of two numbers.
219
+
220
+ Args:
221
+ a (int): the first number
222
+ b (int): the second number
223
+ """
224
+ return a % b
225
+
226
+
227
+ @tool
228
+ def power(a: float, b: float) -> float:
229
+ """
230
+ Get the power of two numbers.
231
+
232
+ Args:
233
+ a (float): the first number
234
+ b (float): the second number
235
+ """
236
+ return a**b
237
+
238
+
239
+ @tool
240
+ def square_root(a: float) -> float | complex:
241
+ """
242
+ Get the square root of a number.
243
+
244
+ Args:
245
+ a (float): the number to get the square root of
246
+ """
247
+ if a >= 0:
248
+ return a**0.5
249
+ return cmath.sqrt(a)
250
+
251
+
252
+ ### =============== DOCUMENT PROCESSING TOOLS =============== ###
253
+
254
 
255
+ @tool
256
+ def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
257
+ """
258
+ Save content to a file and return the path.
259
+
260
+ Args:
261
+ content (str): the content to save to the file
262
+ filename (str, optional): the name of the file. If not provided, a random name file will be created.
263
+ """
264
+ temp_dir = tempfile.gettempdir()
265
+ if filename is None:
266
+ temp_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
267
+ filepath = temp_file.name
268
+ else:
269
+ filepath = os.path.join(temp_dir, filename)
270
+
271
+ with open(filepath, "w") as f:
272
+ f.write(content)
273
+
274
+ return f"File saved to {filepath}. You can read this file to process its contents."
275
+
276
+
277
+ @tool
278
+ def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
279
+ """
280
+ Download a file from a URL and save it to a temporary location.
281
+
282
+ Args:
283
+ url (str): the URL of the file to download.
284
+ filename (str, optional): the name of the file. If not provided, a random name file will be created.
285
+ """
286
+ try:
287
+ # Parse URL to get filename if not provided
288
+ if not filename:
289
+ path = urlparse(url).path
290
+ filename = os.path.basename(path)
291
+ if not filename:
292
+ filename = f"downloaded_{uuid.uuid4().hex[:8]}"
293
+
294
+ # Create temporary file
295
+ temp_dir = tempfile.gettempdir()
296
+ filepath = os.path.join(temp_dir, filename)
297
+
298
+ # Download the file
299
+ response = requests.get(url, stream=True)
300
+ response.raise_for_status()
301
+
302
+ # Save the file
303
+ with open(filepath, "wb") as f:
304
+ for chunk in response.iter_content(chunk_size=8192):
305
+ f.write(chunk)
306
+
307
+ return f"File downloaded to {filepath}. You can read this file to process its contents."
308
+ except Exception as e:
309
+ return f"Error downloading file: {str(e)}"
310
+
311
+
312
+ @tool
313
+ def extract_text_from_image(image_path: str) -> str:
314
+ """
315
+ Extract text from an image using OCR library pytesseract (if available).
316
+
317
+ Args:
318
+ image_path (str): the path to the image file.
319
+ """
320
+ try:
321
+ # Open the image
322
+ image = Image.open(image_path)
323
+
324
+ # Extract text from the image
325
+ text = pytesseract.image_to_string(image)
326
+
327
+ return f"Extracted text from image:\n\n{text}"
328
+ except Exception as e:
329
+ return f"Error extracting text from image: {str(e)}"
330
+
331
+
332
+ @tool
333
+ def analyze_csv_file(file_path: str, query: str) -> str:
334
+ """
335
+ Analyze a CSV file using pandas and answer a question about it.
336
+
337
+ Args:
338
+ file_path (str): the path to the CSV file.
339
+ query (str): Question about the data
340
+ """
341
+ try:
342
+ # Read the CSV file
343
+ df = pd.read_csv(file_path)
344
+
345
+ # Run various analyses based on the query
346
+ result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
347
+ result += f"Columns: {', '.join(df.columns)}\n\n"
348
+
349
+ # Add summary statistics
350
+ result += "Summary statistics:\n"
351
+ result += str(df.describe())
352
+
353
+ return result
354
+
355
+ except Exception as e:
356
+ return f"Error analyzing CSV file: {str(e)}"
357
+
358
+
359
+ @tool
360
+ def analyze_excel_file(file_path: str, query: str) -> str:
361
+ """
362
+ Analyze an Excel file using pandas and answer a question about it.
363
+
364
+ Args:
365
+ file_path (str): the path to the Excel file.
366
+ query (str): Question about the data
367
+ """
368
+ try:
369
+ # Read the Excel file
370
+ df = pd.read_excel(file_path)
371
+
372
+ # Run various analyses based on the query
373
+ result = (
374
+ f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
375
+ )
376
+ result += f"Columns: {', '.join(df.columns)}\n\n"
377
+
378
+ # Add summary statistics
379
+ result += "Summary statistics:\n"
380
+ result += str(df.describe())
381
+
382
+ return result
383
+
384
+ except Exception as e:
385
+ return f"Error analyzing Excel file: {str(e)}"
386
+
387
+
388
+ ### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
389
+
390
+
391
+ @tool
392
+ def analyze_image(image_base64: str) -> Dict[str, Any]:
393
+ """
394
+ Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).
395
+
396
+ Args:
397
+ image_base64 (str): Base64 encoded image string
398
+
399
+ Returns:
400
+ Dictionary with analysis result
401
+ """
402
+ try:
403
+ img = decode_image(image_base64)
404
+ width, height = img.size
405
+ mode = img.mode
406
+
407
+ if mode in ("RGB", "RGBA"):
408
+ arr = np.array(img)
409
+ avg_colors = arr.mean(axis=(0, 1))
410
+ dominant = ["Red", "Green", "Blue"][np.argmax(avg_colors[:3])]
411
+ brightness = avg_colors.mean()
412
+ color_analysis = {
413
+ "average_rgb": avg_colors.tolist(),
414
+ "brightness": brightness,
415
+ "dominant_color": dominant,
416
+ }
417
+ else:
418
+ color_analysis = {"note": f"No color analysis for mode {mode}"}
419
+
420
+ thumbnail = img.copy()
421
+ thumbnail.thumbnail((100, 100))
422
+ thumb_path = save_image(thumbnail, "thumbnails")
423
+ thumbnail_base64 = encode_image(thumb_path)
424
+
425
+ return {
426
+ "dimensions": (width, height),
427
+ "mode": mode,
428
+ "color_analysis": color_analysis,
429
+ "thumbnail": thumbnail_base64,
430
+ }
431
+ except Exception as e:
432
+ return {"error": str(e)}
433
+
434
+
435
+ @tool
436
+ def transform_image(
437
+ image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
438
+ ) -> Dict[str, Any]:
439
+ """
440
+ Apply transformations: resize, rotate, crop, flip, brightness, contrast, blur, sharpen, grayscale.
441
+
442
+ Args:
443
+ image_base64 (str): Base64 encoded input image
444
+ operation (str): Transformation operation
445
+ params (Dict[str, Any], optional): Parameters for the operation
446
+
447
+ Returns:
448
+ Dictionary with transformed image (base64)
449
+ """
450
+ try:
451
+ img = decode_image(image_base64)
452
+ params = params or {}
453
+
454
+ if operation == "resize":
455
+ img = img.resize(
456
+ (
457
+ params.get("width", img.width // 2),
458
+ params.get("height", img.height // 2),
459
+ )
460
+ )
461
+ elif operation == "rotate":
462
+ img = img.rotate(params.get("angle", 90), expand=True)
463
+ elif operation == "crop":
464
+ img = img.crop(
465
+ (
466
+ params.get("left", 0),
467
+ params.get("top", 0),
468
+ params.get("right", img.width),
469
+ params.get("bottom", img.height),
470
+ )
471
+ )
472
+ elif operation == "flip":
473
+ if params.get("direction", "horizontal") == "horizontal":
474
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
475
+ else:
476
+ img = img.transpose(Image.FLIP_TOP_BOTTOM)
477
+ elif operation == "adjust_brightness":
478
+ img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
479
+ elif operation == "adjust_contrast":
480
+ img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
481
+ elif operation == "blur":
482
+ img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
483
+ elif operation == "sharpen":
484
+ img = img.filter(ImageFilter.SHARPEN)
485
+ elif operation == "grayscale":
486
+ img = img.convert("L")
487
+ else:
488
+ return {"error": f"Unknown operation: {operation}"}
489
+
490
+ result_path = save_image(img)
491
+ result_base64 = encode_image(result_path)
492
+ return {"transformed_image": result_base64}
493
+
494
+ except Exception as e:
495
+ return {"error": str(e)}
496
+
497
+
498
+ @tool
499
+ def draw_on_image(
500
+ image_base64: str, drawing_type: str, params: Dict[str, Any]
501
+ ) -> Dict[str, Any]:
502
+ """
503
+ Draw shapes (rectangle, circle, line) or text onto an image.
504
+
505
+ Args:
506
+ image_base64 (str): Base64 encoded input image
507
+ drawing_type (str): Drawing type
508
+ params (Dict[str, Any]): Drawing parameters
509
+
510
+ Returns:
511
+ Dictionary with result image (base64)
512
+ """
513
+ try:
514
+ img = decode_image(image_base64)
515
+ draw = ImageDraw.Draw(img)
516
+ color = params.get("color", "red")
517
+
518
+ if drawing_type == "rectangle":
519
+ draw.rectangle(
520
+ [params["left"], params["top"], params["right"], params["bottom"]],
521
+ outline=color,
522
+ width=params.get("width", 2),
523
+ )
524
+ elif drawing_type == "circle":
525
+ x, y, r = params["x"], params["y"], params["radius"]
526
+ draw.ellipse(
527
+ (x - r, y - r, x + r, y + r),
528
+ outline=color,
529
+ width=params.get("width", 2),
530
+ )
531
+ elif drawing_type == "line":
532
+ draw.line(
533
+ (
534
+ params["start_x"],
535
+ params["start_y"],
536
+ params["end_x"],
537
+ params["end_y"],
538
+ ),
539
+ fill=color,
540
+ width=params.get("width", 2),
541
+ )
542
+ elif drawing_type == "text":
543
+ font_size = params.get("font_size", 20)
544
+ try:
545
+ font = ImageFont.truetype("arial.ttf", font_size)
546
+ except IOError:
547
+ font = ImageFont.load_default()
548
+ draw.text(
549
+ (params["x"], params["y"]),
550
+ params.get("text", "Text"),
551
+ fill=color,
552
+ font=font,
553
+ )
554
+ else:
555
+ return {"error": f"Unknown drawing type: {drawing_type}"}
556
+
557
+ result_path = save_image(img)
558
+ result_base64 = encode_image(result_path)
559
+ return {"result_image": result_base64}
560
+
561
+ except Exception as e:
562
+ return {"error": str(e)}
563
+
564
+
565
+ @tool
566
+ def generate_simple_image(
567
+ image_type: str,
568
+ width: int = 500,
569
+ height: int = 500,
570
+ params: Optional[Dict[str, Any]] = None,
571
+ ) -> Dict[str, Any]:
572
+ """
573
+ Generate a simple image (gradient, noise, pattern, chart).
574
+
575
+ Args:
576
+ image_type (str): Type of image
577
+ width (int), height (int)
578
+ params (Dict[str, Any], optional): Specific parameters
579
+
580
+ Returns:
581
+ Dictionary with generated image (base64)
582
+ """
583
+ try:
584
+ params = params or {}
585
+
586
+ if image_type == "gradient":
587
+ direction = params.get("direction", "horizontal")
588
+ start_color = params.get("start_color", (255, 0, 0))
589
+ end_color = params.get("end_color", (0, 0, 255))
590
+
591
+ img = Image.new("RGB", (width, height))
592
+ draw = ImageDraw.Draw(img)
593
+
594
+ if direction == "horizontal":
595
+ for x in range(width):
596
+ r = int(
597
+ start_color[0] + (end_color[0] - start_color[0]) * x / width
598
+ )
599
+ g = int(
600
+ start_color[1] + (end_color[1] - start_color[1]) * x / width
601
+ )
602
+ b = int(
603
+ start_color[2] + (end_color[2] - start_color[2]) * x / width
604
+ )
605
+ draw.line([(x, 0), (x, height)], fill=(r, g, b))
606
+ else:
607
+ for y in range(height):
608
+ r = int(
609
+ start_color[0] + (end_color[0] - start_color[0]) * y / height
610
+ )
611
+ g = int(
612
+ start_color[1] + (end_color[1] - start_color[1]) * y / height
613
+ )
614
+ b = int(
615
+ start_color[2] + (end_color[2] - start_color[2]) * y / height
616
+ )
617
+ draw.line([(0, y), (width, y)], fill=(r, g, b))
618
+
619
+ elif image_type == "noise":
620
+ noise_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
621
+ img = Image.fromarray(noise_array, "RGB")
622
+
623
+ else:
624
+ return {"error": f"Unsupported image_type {image_type}"}
625
+
626
+ result_path = save_image(img)
627
+ result_base64 = encode_image(result_path)
628
+ return {"generated_image": result_base64}
629
+
630
+ except Exception as e:
631
+ return {"error": str(e)}
632
+
633
+
634
+ @tool
635
+ def combine_images(
636
+ images_base64: List[str], operation: str, params: Optional[Dict[str, Any]] = None
637
+ ) -> Dict[str, Any]:
638
+ """
639
+ Combine multiple images (collage, stack, blend).
640
+
641
+ Args:
642
+ images_base64 (List[str]): List of base64 images
643
+ operation (str): Combination type
644
+ params (Dict[str, Any], optional)
645
+
646
+ Returns:
647
+ Dictionary with combined image (base64)
648
+ """
649
+ try:
650
+ images = [decode_image(b64) for b64 in images_base64]
651
+ params = params or {}
652
+
653
+ if operation == "stack":
654
+ direction = params.get("direction", "horizontal")
655
+ if direction == "horizontal":
656
+ total_width = sum(img.width for img in images)
657
+ max_height = max(img.height for img in images)
658
+ new_img = Image.new("RGB", (total_width, max_height))
659
+ x = 0
660
+ for img in images:
661
+ new_img.paste(img, (x, 0))
662
+ x += img.width
663
+ else:
664
+ max_width = max(img.width for img in images)
665
+ total_height = sum(img.height for img in images)
666
+ new_img = Image.new("RGB", (max_width, total_height))
667
+ y = 0
668
+ for img in images:
669
+ new_img.paste(img, (0, y))
670
+ y += img.height
671
+ else:
672
+ return {"error": f"Unsupported combination operation {operation}"}
673
+
674
+ result_path = save_image(new_img)
675
+ result_base64 = encode_image(result_path)
676
+ return {"combined_image": result_base64}
677
+
678
+ except Exception as e:
679
+ return {"error": str(e)}
680
 
681
 
682
  # load the system prompt from the file
683
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
684
  system_prompt = f.read()
685
+ print(system_prompt)
686
 
687
  # System message
688
  sys_msg = SystemMessage(content=system_prompt)
689
 
690
  # build a retriever
691
+ embeddings = HuggingFaceEmbeddings(
692
+ model_name="sentence-transformers/all-mpnet-base-v2"
693
+ ) # dim=768
694
  supabase: Client = create_client(
695
+ os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
696
+ )
697
  vector_store = SupabaseVectorStore(
698
  client=supabase,
699
+ embedding=embeddings,
700
+ table_name="documents2",
701
+ query_name="match_documents_2",
702
  )
703
  create_retriever_tool = create_retriever_tool(
704
  retriever=vector_store.as_retriever(),
 
707
  )
708
 
709
 
 
710
  tools = [
711
+ web_search,
712
+ wiki_search,
713
+ arxiv_search,
714
  multiply,
715
  add,
716
  subtract,
717
  divide,
718
  modulus,
719
+ power,
720
+ square_root,
721
+ save_and_read_file,
722
+ download_file_from_url,
723
+ extract_text_from_image,
724
+ analyze_csv_file,
725
+ analyze_excel_file,
726
+ execute_code_multilang,
727
+ analyze_image,
728
+ transform_image,
729
+ draw_on_image,
730
+ generate_simple_image,
731
+ combine_images,
732
  ]
733
 
734
+
735
  # Build graph function
736
  def build_graph(provider: str = "groq"):
737
  """Build the graph"""
738
  # Load environment variables from .env file
739
+ if provider == "groq":
 
 
 
740
  # Groq https://console.groq.com/docs/models
741
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
 
 
 
742
  elif provider == "huggingface":
743
+ # TODO: Add huggingface endpoint
744
  llm = ChatHuggingFace(
745
  llm=HuggingFaceEndpoint(
746
+ repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
747
+ task="text-generation", # for chat‐style use “text-generation”
748
+ max_new_tokens=1024,
749
+ do_sample=False,
750
+ repetition_penalty=1.03,
751
  temperature=0,
752
  ),
753
+ verbose=True,
754
  )
755
  else:
756
+ raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
757
  # Bind tools to LLM
758
  llm_with_tools = llm.bind_tools(tools)
759
 
 
765
  def retriever(state: MessagesState):
766
  """Retriever node"""
767
  similar_question = vector_store.similarity_search(state["messages"][0].content)
768
+
769
+ if similar_question: # Check if the list is not empty
770
+ example_msg = HumanMessage(
771
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
772
+ )
773
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
774
+ else:
775
+ # Handle the case when no similar questions are found
776
+ return {"messages": [sys_msg] + state["messages"]}
777
 
778
  builder = StateGraph(MessagesState)
779
  builder.add_node("retriever", retriever)
 
788
  builder.add_edge("tools", "assistant")
789
 
790
  # Compile graph
791
+ return builder.compile()
792
+
793
+
794
+ # test
795
+ if __name__ == "__main__":
796
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
797
+ graph = build_graph(provider="groq")
798
+ messages = [HumanMessage(content=question)]
799
+ messages = graph.invoke({"messages": messages})
800
+ for m in messages["messages"]:
801
+ m.pretty_print()