cuizhanming commited on
Commit
a491c35
·
1 Parent(s): 535a222
Files changed (3) hide show
  1. agent.py +116 -695
  2. app.py +6 -8
  3. metadata.jsonl +0 -0
agent.py CHANGED
@@ -1,36 +1,15 @@
 
1
  import os
2
  from dotenv import load_dotenv
3
- from typing import List, Dict, Any, Optional
4
- import tempfile
5
- import re
6
- import json
7
- import requests
8
- from urllib.parse import urlparse
9
- import pytesseract
10
- from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
11
- import cmath
12
- import pandas as pd
13
- import uuid
14
- import numpy as np
15
- from code_interpreter import CodeInterpreter
16
-
17
- interpreter_instance = CodeInterpreter()
18
-
19
- from image_processing import *
20
-
21
- """Langraph"""
22
  from langgraph.graph import START, StateGraph, MessagesState
 
 
 
 
 
23
  from langchain_community.tools.tavily_search import TavilySearchResults
24
  from langchain_community.document_loaders import WikipediaLoader
25
  from langchain_community.document_loaders import ArxivLoader
26
- from langgraph.prebuilt import ToolNode, tools_condition
27
- from langchain_google_genai import ChatGoogleGenerativeAI
28
- from langchain_groq import ChatGroq
29
- from langchain_huggingface import (
30
- ChatHuggingFace,
31
- HuggingFaceEndpoint,
32
- HuggingFaceEmbeddings,
33
- )
34
  from langchain_community.vectorstores import SupabaseVectorStore
35
  from langchain_core.messages import SystemMessage, HumanMessage
36
  from langchain_core.tools import tool
@@ -39,666 +18,118 @@ from supabase.client import Client, create_client
39
 
40
  load_dotenv()
41
 
42
- ### =============== BROWSER TOOLS =============== ###
43
-
44
-
45
- @tool
46
- def wiki_search(query: str) -> str:
47
- """Search Wikipedia for a query and return maximum 2 results.
48
-
49
- Args:
50
- query: The search query."""
51
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
52
- formatted_search_docs = "\n\n---\n\n".join(
53
- [
54
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
55
- for doc in search_docs
56
- ]
57
- )
58
- return {"wiki_results": formatted_search_docs}
59
-
60
-
61
- @tool
62
- def web_search(query: str) -> str:
63
- """Search Tavily for a query and return maximum 3 results.
64
-
65
- Args:
66
- query: The search query."""
67
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
68
- formatted_search_docs = "\n\n---\n\n".join(
69
- [
70
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
71
- for doc in search_docs
72
- ]
73
- )
74
- return {"web_results": formatted_search_docs}
75
-
76
-
77
- @tool
78
- def arxiv_search(query: str) -> str:
79
- """Search Arxiv for a query and return maximum 3 result.
80
-
81
- Args:
82
- query: The search query."""
83
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
84
- formatted_search_docs = "\n\n---\n\n".join(
85
- [
86
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
87
- for doc in search_docs
88
- ]
89
- )
90
- return {"arxiv_results": formatted_search_docs}
91
-
92
-
93
- ### =============== CODE INTERPRETER TOOLS =============== ###
94
-
95
-
96
- @tool
97
- def execute_code_multilang(code: str, language: str = "python") -> str:
98
- """Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results.
99
-
100
- Args:
101
- code (str): The source code to execute.
102
- language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java".
103
-
104
- Returns:
105
- A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any).
106
- """
107
- supported_languages = ["python", "bash", "sql", "c", "java"]
108
- language = language.lower()
109
-
110
- if language not in supported_languages:
111
- return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(supported_languages)}"
112
-
113
- result = interpreter_instance.execute_code(code, language=language)
114
-
115
- response = []
116
-
117
- if result["status"] == "success":
118
- response.append(f"✅ Code executed successfully in **{language.upper()}**")
119
-
120
- if result.get("stdout"):
121
- response.append(
122
- "\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```"
123
- )
124
-
125
- if result.get("stderr"):
126
- response.append(
127
- "\n**Standard Error (if any):**\n```\n"
128
- + result["stderr"].strip()
129
- + "\n```"
130
- )
131
-
132
- if result.get("result") is not None:
133
- response.append(
134
- "\n**Execution Result:**\n```\n"
135
- + str(result["result"]).strip()
136
- + "\n```"
137
- )
138
-
139
- if result.get("dataframes"):
140
- for df_info in result["dataframes"]:
141
- response.append(
142
- f"\n**DataFrame `{df_info['name']}` (Shape: {df_info['shape']})**"
143
- )
144
- df_preview = pd.DataFrame(df_info["head"])
145
- response.append("First 5 rows:\n```\n" + str(df_preview) + "\n```")
146
-
147
- if result.get("plots"):
148
- response.append(
149
- f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)"
150
- )
151
-
152
- else:
153
- response.append(f"❌ Code execution failed in **{language.upper()}**")
154
- if result.get("stderr"):
155
- response.append(
156
- "\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```"
157
- )
158
-
159
- return "\n".join(response)
160
-
161
-
162
- ### =============== MATHEMATICAL TOOLS =============== ###
163
-
164
-
165
  @tool
166
- def multiply(a: float, b: float) -> float:
167
- """
168
- Multiplies two numbers.
169
-
170
  Args:
171
- a (float): the first number
172
- b (float): the second number
173
  """
174
  return a * b
175
 
176
-
177
  @tool
178
- def add(a: float, b: float) -> float:
179
- """
180
- Adds two numbers.
181
-
182
  Args:
183
- a (float): the first number
184
- b (float): the second number
185
  """
186
  return a + b
187
 
188
-
189
  @tool
190
- def subtract(a: float, b: float) -> int:
191
- """
192
- Subtracts two numbers.
193
-
194
  Args:
195
- a (float): the first number
196
- b (float): the second number
197
  """
198
  return a - b
199
 
200
-
201
  @tool
202
- def divide(a: float, b: float) -> float:
203
- """
204
- Divides two numbers.
205
-
206
  Args:
207
- a (float): the first float number
208
- b (float): the second float number
209
  """
210
  if b == 0:
211
- raise ValueError("Cannot divided by zero.")
212
  return a / b
213
 
214
-
215
  @tool
216
  def modulus(a: int, b: int) -> int:
217
- """
218
- Get the modulus of two numbers.
219
-
220
  Args:
221
- a (int): the first number
222
- b (int): the second number
223
  """
224
  return a % b
225
 
226
-
227
- @tool
228
- def power(a: float, b: float) -> float:
229
- """
230
- Get the power of two numbers.
231
-
232
- Args:
233
- a (float): the first number
234
- b (float): the second number
235
- """
236
- return a**b
237
-
238
-
239
- @tool
240
- def square_root(a: float) -> float | complex:
241
- """
242
- Get the square root of a number.
243
-
244
- Args:
245
- a (float): the number to get the square root of
246
- """
247
- if a >= 0:
248
- return a**0.5
249
- return cmath.sqrt(a)
250
-
251
-
252
- ### =============== DOCUMENT PROCESSING TOOLS =============== ###
253
-
254
-
255
- @tool
256
- def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
257
- """
258
- Save content to a file and return the path.
259
-
260
- Args:
261
- content (str): the content to save to the file
262
- filename (str, optional): the name of the file. If not provided, a random name file will be created.
263
- """
264
- temp_dir = tempfile.gettempdir()
265
- if filename is None:
266
- temp_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
267
- filepath = temp_file.name
268
- else:
269
- filepath = os.path.join(temp_dir, filename)
270
-
271
- with open(filepath, "w") as f:
272
- f.write(content)
273
-
274
- return f"File saved to {filepath}. You can read this file to process its contents."
275
-
276
-
277
- @tool
278
- def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
279
- """
280
- Download a file from a URL and save it to a temporary location.
281
-
282
- Args:
283
- url (str): the URL of the file to download.
284
- filename (str, optional): the name of the file. If not provided, a random name file will be created.
285
- """
286
- try:
287
- # Parse URL to get filename if not provided
288
- if not filename:
289
- path = urlparse(url).path
290
- filename = os.path.basename(path)
291
- if not filename:
292
- filename = f"downloaded_{uuid.uuid4().hex[:8]}"
293
-
294
- # Create temporary file
295
- temp_dir = tempfile.gettempdir()
296
- filepath = os.path.join(temp_dir, filename)
297
-
298
- # Download the file
299
- response = requests.get(url, stream=True)
300
- response.raise_for_status()
301
-
302
- # Save the file
303
- with open(filepath, "wb") as f:
304
- for chunk in response.iter_content(chunk_size=8192):
305
- f.write(chunk)
306
-
307
- return f"File downloaded to {filepath}. You can read this file to process its contents."
308
- except Exception as e:
309
- return f"Error downloading file: {str(e)}"
310
-
311
-
312
- @tool
313
- def extract_text_from_image(image_path: str) -> str:
314
- """
315
- Extract text from an image using OCR library pytesseract (if available).
316
-
317
- Args:
318
- image_path (str): the path to the image file.
319
- """
320
- try:
321
- # Open the image
322
- image = Image.open(image_path)
323
-
324
- # Extract text from the image
325
- text = pytesseract.image_to_string(image)
326
-
327
- return f"Extracted text from image:\n\n{text}"
328
- except Exception as e:
329
- return f"Error extracting text from image: {str(e)}"
330
-
331
-
332
- @tool
333
- def analyze_csv_file(file_path: str, query: str) -> str:
334
- """
335
- Analyze a CSV file using pandas and answer a question about it.
336
-
337
- Args:
338
- file_path (str): the path to the CSV file.
339
- query (str): Question about the data
340
- """
341
- try:
342
- # Read the CSV file
343
- df = pd.read_csv(file_path)
344
-
345
- # Run various analyses based on the query
346
- result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
347
- result += f"Columns: {', '.join(df.columns)}\n\n"
348
-
349
- # Add summary statistics
350
- result += "Summary statistics:\n"
351
- result += str(df.describe())
352
-
353
- return result
354
-
355
- except Exception as e:
356
- return f"Error analyzing CSV file: {str(e)}"
357
-
358
-
359
- @tool
360
- def analyze_excel_file(file_path: str, query: str) -> str:
361
- """
362
- Analyze an Excel file using pandas and answer a question about it.
363
-
364
- Args:
365
- file_path (str): the path to the Excel file.
366
- query (str): Question about the data
367
- """
368
- try:
369
- # Read the Excel file
370
- df = pd.read_excel(file_path)
371
-
372
- # Run various analyses based on the query
373
- result = (
374
- f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
375
- )
376
- result += f"Columns: {', '.join(df.columns)}\n\n"
377
-
378
- # Add summary statistics
379
- result += "Summary statistics:\n"
380
- result += str(df.describe())
381
-
382
- return result
383
-
384
- except Exception as e:
385
- return f"Error analyzing Excel file: {str(e)}"
386
-
387
-
388
- ### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
389
-
390
-
391
- @tool
392
- def analyze_image(image_base64: str) -> Dict[str, Any]:
393
- """
394
- Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).
395
-
396
- Args:
397
- image_base64 (str): Base64 encoded image string
398
-
399
- Returns:
400
- Dictionary with analysis result
401
- """
402
- try:
403
- img = decode_image(image_base64)
404
- width, height = img.size
405
- mode = img.mode
406
-
407
- if mode in ("RGB", "RGBA"):
408
- arr = np.array(img)
409
- avg_colors = arr.mean(axis=(0, 1))
410
- dominant = ["Red", "Green", "Blue"][np.argmax(avg_colors[:3])]
411
- brightness = avg_colors.mean()
412
- color_analysis = {
413
- "average_rgb": avg_colors.tolist(),
414
- "brightness": brightness,
415
- "dominant_color": dominant,
416
- }
417
- else:
418
- color_analysis = {"note": f"No color analysis for mode {mode}"}
419
-
420
- thumbnail = img.copy()
421
- thumbnail.thumbnail((100, 100))
422
- thumb_path = save_image(thumbnail, "thumbnails")
423
- thumbnail_base64 = encode_image(thumb_path)
424
-
425
- return {
426
- "dimensions": (width, height),
427
- "mode": mode,
428
- "color_analysis": color_analysis,
429
- "thumbnail": thumbnail_base64,
430
- }
431
- except Exception as e:
432
- return {"error": str(e)}
433
-
434
-
435
  @tool
436
- def transform_image(
437
- image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
438
- ) -> Dict[str, Any]:
439
- """
440
- Apply transformations: resize, rotate, crop, flip, brightness, contrast, blur, sharpen, grayscale.
441
-
442
- Args:
443
- image_base64 (str): Base64 encoded input image
444
- operation (str): Transformation operation
445
- params (Dict[str, Any], optional): Parameters for the operation
446
-
447
- Returns:
448
- Dictionary with transformed image (base64)
449
- """
450
- try:
451
- img = decode_image(image_base64)
452
- params = params or {}
453
-
454
- if operation == "resize":
455
- img = img.resize(
456
- (
457
- params.get("width", img.width // 2),
458
- params.get("height", img.height // 2),
459
- )
460
- )
461
- elif operation == "rotate":
462
- img = img.rotate(params.get("angle", 90), expand=True)
463
- elif operation == "crop":
464
- img = img.crop(
465
- (
466
- params.get("left", 0),
467
- params.get("top", 0),
468
- params.get("right", img.width),
469
- params.get("bottom", img.height),
470
- )
471
- )
472
- elif operation == "flip":
473
- if params.get("direction", "horizontal") == "horizontal":
474
- img = img.transpose(Image.FLIP_LEFT_RIGHT)
475
- else:
476
- img = img.transpose(Image.FLIP_TOP_BOTTOM)
477
- elif operation == "adjust_brightness":
478
- img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
479
- elif operation == "adjust_contrast":
480
- img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
481
- elif operation == "blur":
482
- img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
483
- elif operation == "sharpen":
484
- img = img.filter(ImageFilter.SHARPEN)
485
- elif operation == "grayscale":
486
- img = img.convert("L")
487
- else:
488
- return {"error": f"Unknown operation: {operation}"}
489
-
490
- result_path = save_image(img)
491
- result_base64 = encode_image(result_path)
492
- return {"transformed_image": result_base64}
493
-
494
- except Exception as e:
495
- return {"error": str(e)}
496
-
497
-
498
- @tool
499
- def draw_on_image(
500
- image_base64: str, drawing_type: str, params: Dict[str, Any]
501
- ) -> Dict[str, Any]:
502
- """
503
- Draw shapes (rectangle, circle, line) or text onto an image.
504
-
505
  Args:
506
- image_base64 (str): Base64 encoded input image
507
- drawing_type (str): Drawing type
508
- params (Dict[str, Any]): Drawing parameters
509
-
510
- Returns:
511
- Dictionary with result image (base64)
512
- """
513
- try:
514
- img = decode_image(image_base64)
515
- draw = ImageDraw.Draw(img)
516
- color = params.get("color", "red")
517
-
518
- if drawing_type == "rectangle":
519
- draw.rectangle(
520
- [params["left"], params["top"], params["right"], params["bottom"]],
521
- outline=color,
522
- width=params.get("width", 2),
523
- )
524
- elif drawing_type == "circle":
525
- x, y, r = params["x"], params["y"], params["radius"]
526
- draw.ellipse(
527
- (x - r, y - r, x + r, y + r),
528
- outline=color,
529
- width=params.get("width", 2),
530
- )
531
- elif drawing_type == "line":
532
- draw.line(
533
- (
534
- params["start_x"],
535
- params["start_y"],
536
- params["end_x"],
537
- params["end_y"],
538
- ),
539
- fill=color,
540
- width=params.get("width", 2),
541
- )
542
- elif drawing_type == "text":
543
- font_size = params.get("font_size", 20)
544
- try:
545
- font = ImageFont.truetype("arial.ttf", font_size)
546
- except IOError:
547
- font = ImageFont.load_default()
548
- draw.text(
549
- (params["x"], params["y"]),
550
- params.get("text", "Text"),
551
- fill=color,
552
- font=font,
553
- )
554
- else:
555
- return {"error": f"Unknown drawing type: {drawing_type}"}
556
-
557
- result_path = save_image(img)
558
- result_base64 = encode_image(result_path)
559
- return {"result_image": result_base64}
560
-
561
- except Exception as e:
562
- return {"error": str(e)}
563
-
564
 
565
  @tool
566
- def generate_simple_image(
567
- image_type: str,
568
- width: int = 500,
569
- height: int = 500,
570
- params: Optional[Dict[str, Any]] = None,
571
- ) -> Dict[str, Any]:
572
- """
573
- Generate a simple image (gradient, noise, pattern, chart).
574
-
575
  Args:
576
- image_type (str): Type of image
577
- width (int), height (int)
578
- params (Dict[str, Any], optional): Specific parameters
579
-
580
- Returns:
581
- Dictionary with generated image (base64)
582
- """
583
- try:
584
- params = params or {}
585
-
586
- if image_type == "gradient":
587
- direction = params.get("direction", "horizontal")
588
- start_color = params.get("start_color", (255, 0, 0))
589
- end_color = params.get("end_color", (0, 0, 255))
590
-
591
- img = Image.new("RGB", (width, height))
592
- draw = ImageDraw.Draw(img)
593
-
594
- if direction == "horizontal":
595
- for x in range(width):
596
- r = int(
597
- start_color[0] + (end_color[0] - start_color[0]) * x / width
598
- )
599
- g = int(
600
- start_color[1] + (end_color[1] - start_color[1]) * x / width
601
- )
602
- b = int(
603
- start_color[2] + (end_color[2] - start_color[2]) * x / width
604
- )
605
- draw.line([(x, 0), (x, height)], fill=(r, g, b))
606
- else:
607
- for y in range(height):
608
- r = int(
609
- start_color[0] + (end_color[0] - start_color[0]) * y / height
610
- )
611
- g = int(
612
- start_color[1] + (end_color[1] - start_color[1]) * y / height
613
- )
614
- b = int(
615
- start_color[2] + (end_color[2] - start_color[2]) * y / height
616
- )
617
- draw.line([(0, y), (width, y)], fill=(r, g, b))
618
-
619
- elif image_type == "noise":
620
- noise_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
621
- img = Image.fromarray(noise_array, "RGB")
622
-
623
- else:
624
- return {"error": f"Unsupported image_type {image_type}"}
625
-
626
- result_path = save_image(img)
627
- result_base64 = encode_image(result_path)
628
- return {"generated_image": result_base64}
629
-
630
- except Exception as e:
631
- return {"error": str(e)}
632
-
633
 
634
  @tool
635
- def combine_images(
636
- images_base64: List[str], operation: str, params: Optional[Dict[str, Any]] = None
637
- ) -> Dict[str, Any]:
638
- """
639
- Combine multiple images (collage, stack, blend).
640
-
641
  Args:
642
- images_base64 (List[str]): List of base64 images
643
- operation (str): Combination type
644
- params (Dict[str, Any], optional)
645
-
646
- Returns:
647
- Dictionary with combined image (base64)
648
- """
649
- try:
650
- images = [decode_image(b64) for b64 in images_base64]
651
- params = params or {}
652
-
653
- if operation == "stack":
654
- direction = params.get("direction", "horizontal")
655
- if direction == "horizontal":
656
- total_width = sum(img.width for img in images)
657
- max_height = max(img.height for img in images)
658
- new_img = Image.new("RGB", (total_width, max_height))
659
- x = 0
660
- for img in images:
661
- new_img.paste(img, (x, 0))
662
- x += img.width
663
- else:
664
- max_width = max(img.width for img in images)
665
- total_height = sum(img.height for img in images)
666
- new_img = Image.new("RGB", (max_width, total_height))
667
- y = 0
668
- for img in images:
669
- new_img.paste(img, (0, y))
670
- y += img.height
671
- else:
672
- return {"error": f"Unsupported combination operation {operation}"}
673
-
674
- result_path = save_image(new_img)
675
- result_base64 = encode_image(result_path)
676
- return {"combined_image": result_base64}
677
 
678
- except Exception as e:
679
- return {"error": str(e)}
680
 
681
 
682
  # load the system prompt from the file
683
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
684
  system_prompt = f.read()
685
- print(system_prompt)
686
 
687
  # System message
688
  sys_msg = SystemMessage(content=system_prompt)
689
 
690
  # build a retriever
691
- embeddings = HuggingFaceEmbeddings(
692
- model_name="sentence-transformers/all-mpnet-base-v2"
693
- ) # dim=768
694
  supabase: Client = create_client(
695
- os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
696
- )
697
  vector_store = SupabaseVectorStore(
698
  client=supabase,
699
- embedding=embeddings,
700
- table_name="documents2",
701
- query_name="match_documents_2",
702
  )
703
  create_retriever_tool = create_retriever_tool(
704
  retriever=vector_store.as_retriever(),
@@ -707,54 +138,38 @@ create_retriever_tool = create_retriever_tool(
707
  )
708
 
709
 
 
710
  tools = [
711
- web_search,
712
- wiki_search,
713
- arxiv_search,
714
  multiply,
715
  add,
716
  subtract,
717
  divide,
718
  modulus,
719
- power,
720
- square_root,
721
- save_and_read_file,
722
- download_file_from_url,
723
- extract_text_from_image,
724
- analyze_csv_file,
725
- analyze_excel_file,
726
- execute_code_multilang,
727
- analyze_image,
728
- transform_image,
729
- draw_on_image,
730
- generate_simple_image,
731
- combine_images,
732
  ]
733
 
734
-
735
  # Build graph function
736
- def build_graph(provider: str = "groq"):
737
  """Build the graph"""
738
  # Load environment variables from .env file
739
- if provider == "groq":
 
 
 
740
  # Groq https://console.groq.com/docs/models
741
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
742
  elif provider == "huggingface":
743
  # TODO: Add huggingface endpoint
744
  llm = ChatHuggingFace(
745
  llm=HuggingFaceEndpoint(
746
- url="https://api-inference.huggingface.co/models/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
747
- repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
748
- task="text-generation", # for chat‐style use “text-generation”
749
- max_new_tokens=1024,
750
- do_sample=False,
751
- repetition_penalty=1.03,
752
  temperature=0,
753
  ),
754
- verbose=True,
755
  )
756
  else:
757
- raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
758
  # Bind tools to LLM
759
  llm_with_tools = llm.bind_tools(tools)
760
 
@@ -762,41 +177,47 @@ def build_graph(provider: str = "groq"):
762
  def assistant(state: MessagesState):
763
  """Assistant node"""
764
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
 
 
 
 
 
765
 
766
  def retriever(state: MessagesState):
767
- """Retriever node"""
768
- similar_question = vector_store.similarity_search(state["messages"][0].content)
769
 
770
- if similar_question: # Check if the list is not empty
771
- example_msg = HumanMessage(
772
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
773
- )
774
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
775
  else:
776
- # Handle the case when no similar questions are found
777
- return {"messages": [sys_msg] + state["messages"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
778
 
779
  builder = StateGraph(MessagesState)
780
  builder.add_node("retriever", retriever)
781
- builder.add_node("assistant", assistant)
782
- builder.add_node("tools", ToolNode(tools))
783
- builder.add_edge(START, "retriever")
784
- builder.add_edge("retriever", "assistant")
785
- builder.add_conditional_edges(
786
- "assistant",
787
- tools_condition,
788
- )
789
- builder.add_edge("tools", "assistant")
790
-
791
- # Compile graph
792
- return builder.compile()
793
 
 
 
 
794
 
795
- # test
796
- if __name__ == "__main__":
797
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
798
- graph = build_graph(provider="groq")
799
- messages = [HumanMessage(content=question)]
800
- messages = graph.invoke({"messages": messages})
801
- for m in messages["messages"]:
802
- m.pretty_print()
 
1
+ """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
 
 
 
 
 
 
 
 
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
 
18
 
19
  load_dotenv()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two numbers.
 
 
24
  Args:
25
+ a: first int
26
+ b: second int
27
  """
28
  return a * b
29
 
 
30
  @tool
31
+ def add(a: int, b: int) -> int:
32
+ """Add two numbers.
33
+
 
34
  Args:
35
+ a: first int
36
+ b: second int
37
  """
38
  return a + b
39
 
 
40
  @tool
41
+ def subtract(a: int, b: int) -> int:
42
+ """Subtract two numbers.
43
+
 
44
  Args:
45
+ a: first int
46
+ b: second int
47
  """
48
  return a - b
49
 
 
50
  @tool
51
+ def divide(a: int, b: int) -> int:
52
+ """Divide two numbers.
53
+
 
54
  Args:
55
+ a: first int
56
+ b: second int
57
  """
58
  if b == 0:
59
+ raise ValueError("Cannot divide by zero.")
60
  return a / b
61
 
 
62
  @tool
63
  def modulus(a: int, b: int) -> int:
64
+ """Get the modulus of two numbers.
65
+
 
66
  Args:
67
+ a: first int
68
+ b: second int
69
  """
70
  return a % b
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  @tool
73
+ def wiki_search(query: str) -> str:
74
+ """Search Wikipedia for a query and return maximum 2 results.
75
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  Args:
77
+ query: The search query."""
78
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
+ formatted_search_docs = "\n\n---\n\n".join(
80
+ [
81
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
+ for doc in search_docs
83
+ ])
84
+ return {"wiki_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  @tool
87
+ def web_search(query: str) -> str:
88
+ """Search Tavily for a query and return maximum 3 results.
89
+
 
 
 
 
 
 
90
  Args:
91
+ query: The search query."""
92
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
+ formatted_search_docs = "\n\n---\n\n".join(
94
+ [
95
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
+ for doc in search_docs
97
+ ])
98
+ return {"web_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  @tool
101
+ def arvix_search(query: str) -> str:
102
+ """Search Arxiv for a query and return maximum 3 result.
103
+
 
 
 
104
  Args:
105
+ query: The search query."""
106
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
+ formatted_search_docs = "\n\n---\n\n".join(
108
+ [
109
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
+ for doc in search_docs
111
+ ])
112
+ return {"arvix_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
 
 
114
 
115
 
116
  # load the system prompt from the file
117
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
  system_prompt = f.read()
 
119
 
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
  # build a retriever
124
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
 
 
125
  supabase: Client = create_client(
126
+ os.environ.get("SUPABASE_URL"),
127
+ os.environ.get("SUPABASE_SERVICE_KEY"))
128
  vector_store = SupabaseVectorStore(
129
  client=supabase,
130
+ embedding= embeddings,
131
+ table_name="documents",
132
+ query_name="match_documents_langchain",
133
  )
134
  create_retriever_tool = create_retriever_tool(
135
  retriever=vector_store.as_retriever(),
 
138
  )
139
 
140
 
141
+
142
  tools = [
 
 
 
143
  multiply,
144
  add,
145
  subtract,
146
  divide,
147
  modulus,
148
+ wiki_search,
149
+ web_search,
150
+ arvix_search,
 
 
 
 
 
 
 
 
 
 
151
  ]
152
 
 
153
  # Build graph function
154
+ def build_graph(provider: str = "google"):
155
  """Build the graph"""
156
  # Load environment variables from .env file
157
+ if provider == "google":
158
+ # Google Gemini
159
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
+ elif provider == "groq":
161
  # Groq https://console.groq.com/docs/models
162
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
163
  elif provider == "huggingface":
164
  # TODO: Add huggingface endpoint
165
  llm = ChatHuggingFace(
166
  llm=HuggingFaceEndpoint(
167
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
 
 
 
 
 
168
  temperature=0,
169
  ),
 
170
  )
171
  else:
172
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
173
  # Bind tools to LLM
174
  llm_with_tools = llm.bind_tools(tools)
175
 
 
177
  def assistant(state: MessagesState):
178
  """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
+
181
+ # def retriever(state: MessagesState):
182
+ # """Retriever node"""
183
+ # similar_question = vector_store.similarity_search(state["messages"][0].content)
184
+ #example_msg = HumanMessage(
185
+ # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
+ # )
187
+ # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
+
189
+ from langchain_core.messages import AIMessage
190
 
191
  def retriever(state: MessagesState):
192
+ query = state["messages"][-1].content
193
+ similar_doc = vector_store.similarity_search(query, k=1)[0]
194
 
195
+ content = similar_doc.page_content
196
+ if "Final answer :" in content:
197
+ answer = content.split("Final answer :")[-1].strip()
 
 
198
  else:
199
+ answer = content.strip()
200
+
201
+ return {"messages": [AIMessage(content=answer)]}
202
+
203
+ # builder = StateGraph(MessagesState)
204
+ #builder.add_node("retriever", retriever)
205
+ #builder.add_node("assistant", assistant)
206
+ #builder.add_node("tools", ToolNode(tools))
207
+ #builder.add_edge(START, "retriever")
208
+ #builder.add_edge("retriever", "assistant")
209
+ #builder.add_conditional_edges(
210
+ # "assistant",
211
+ # tools_condition,
212
+ #)
213
+ #builder.add_edge("tools", "assistant")
214
 
215
  builder = StateGraph(MessagesState)
216
  builder.add_node("retriever", retriever)
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+ # Retriever ist Start und Endpunkt
219
+ builder.set_entry_point("retriever")
220
+ builder.set_finish_point("retriever")
221
 
222
+ # Compile graph
223
+ return builder.compile()
 
 
 
 
 
 
app.py CHANGED
@@ -4,7 +4,6 @@ import inspect
4
  import gradio as gr
5
  import requests
6
  import pandas as pd
7
- import time
8
  from langchain_core.messages import HumanMessage
9
  from agent import build_graph
10
 
@@ -26,11 +25,11 @@ class BasicAgent:
26
 
27
  def __call__(self, question: str) -> str:
28
  print(f"Agent received question (first 50 chars): {question[:50]}...")
29
- # Wrap the question in a HumanMessage from langchain_core
30
  messages = [HumanMessage(content=question)]
31
- messages = self.graph.invoke({"messages": messages})
32
- answer = messages['messages'][-1].content
33
- return answer[14:]
 
34
 
35
 
36
  def run_and_submit_all( profile: gr.OAuthProfile | None):
@@ -93,9 +92,6 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
93
  if not task_id or question_text is None:
94
  print(f"Skipping item with missing task_id or question: {item}")
95
  continue
96
-
97
- # time.sleep(10)
98
-
99
  try:
100
  submitted_answer = agent(question_text)
101
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
@@ -163,9 +159,11 @@ with gr.Blocks() as demo:
163
  gr.Markdown(
164
  """
165
  **Instructions:**
 
166
  1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
167
  2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
168
  3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
 
169
  ---
170
  **Disclaimers:**
171
  Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
 
4
  import gradio as gr
5
  import requests
6
  import pandas as pd
 
7
  from langchain_core.messages import HumanMessage
8
  from agent import build_graph
9
 
 
25
 
26
  def __call__(self, question: str) -> str:
27
  print(f"Agent received question (first 50 chars): {question[:50]}...")
 
28
  messages = [HumanMessage(content=question)]
29
+ result = self.graph.invoke({"messages": messages})
30
+ answer = result['messages'][-1].content
31
+ return answer # kein [14:] mehr nötig!
32
+
33
 
34
 
35
  def run_and_submit_all( profile: gr.OAuthProfile | None):
 
92
  if not task_id or question_text is None:
93
  print(f"Skipping item with missing task_id or question: {item}")
94
  continue
 
 
 
95
  try:
96
  submitted_answer = agent(question_text)
97
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
 
159
  gr.Markdown(
160
  """
161
  **Instructions:**
162
+
163
  1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
164
  2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
165
  3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
166
+
167
  ---
168
  **Disclaimers:**
169
  Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
metadata.jsonl ADDED
The diff for this file is too large to render. See raw diff