rakesh-dvg commited on
Commit
38f7123
·
verified ·
1 Parent(s): 3ae99b4

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -790
agent.py DELETED
@@ -1,790 +0,0 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from typing import List, Dict, Any, Optional
4
- import tempfile
5
- import re
6
- import json
7
- import requests
8
- from urllib.parse import urlparse
9
- import pytesseract
10
- from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
11
- import cmath
12
- import pandas as pd
13
- import uuid
14
- import numpy as np
15
- from code_interpreter import CodeInterpreter
16
-
17
- interpreter_instance = CodeInterpreter()
18
-
19
-
20
- # agent.py
21
- from supabase import create_client, Client
22
- import os
23
-
24
- def get_supabase_client() -> Client:
25
- supabase_url = os.getenv("SUPABASE_URL")
26
- supabase_key = os.getenv("SUPABASE_KEY")
27
-
28
- print("✅ agent.py - SUPABASE_URL:", repr(supabase_url))
29
- print("✅ agent.py - SUPABASE_KEY:", repr(supabase_key))
30
-
31
- return create_client(supabase_url, supabase_key)
32
-
33
-
34
-
35
- from image_processing import *
36
-
37
- """Langraph"""
38
- from langgraph.graph import START, StateGraph, MessagesState
39
- from langchain_community.tools.tavily_search import TavilySearchResults
40
- from langchain_community.document_loaders import WikipediaLoader
41
- from langchain_community.document_loaders import ArxivLoader
42
- from langgraph.prebuilt import ToolNode, tools_condition
43
- from langchain_google_genai import ChatGoogleGenerativeAI
44
- from langchain_groq import ChatGroq
45
- from langchain_huggingface import (
46
- ChatHuggingFace,
47
- HuggingFaceEndpoint,
48
- HuggingFaceEmbeddings,
49
- )
50
- from langchain_community.vectorstores import SupabaseVectorStore
51
- from langchain_core.messages import SystemMessage, HumanMessage
52
- from langchain_core.tools import tool
53
- from langchain.tools.retriever import create_retriever_tool
54
- from supabase.client import Client, create_client
55
-
56
- load_dotenv()
57
-
58
- ### =============== BROWSER TOOLS =============== ###
59
-
60
-
61
- @tool
62
- def wiki_search(query: str) -> str:
63
- """Search Wikipedia for a query and return maximum 2 results.
64
- Args:
65
- query: The search query."""
66
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
67
- formatted_search_docs = "\n\n---\n\n".join(
68
- [
69
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
70
- for doc in search_docs
71
- ]
72
- )
73
- return {"wiki_results": formatted_search_docs}
74
-
75
-
76
- @tool
77
- def web_search(query: str) -> str:
78
- """Search Tavily for a query and return maximum 3 results.
79
- Args:
80
- query: The search query."""
81
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
82
- formatted_search_docs = "\n\n---\n\n".join(
83
- [
84
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
85
- for doc in search_docs
86
- ]
87
- )
88
- return {"web_results": formatted_search_docs}
89
-
90
-
91
- @tool
92
- def arxiv_search(query: str) -> str:
93
- """Search Arxiv for a query and return maximum 3 result.
94
- Args:
95
- query: The search query."""
96
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
97
- formatted_search_docs = "\n\n---\n\n".join(
98
- [
99
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
100
- for doc in search_docs
101
- ]
102
- )
103
- return {"arxiv_results": formatted_search_docs}
104
-
105
-
106
- ### =============== CODE INTERPRETER TOOLS =============== ###
107
-
108
-
109
- @tool
110
- def execute_code_multilang(code: str, language: str = "python") -> str:
111
- """Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results.
112
- Args:
113
- code (str): The source code to execute.
114
- language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java".
115
- Returns:
116
- A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any).
117
- """
118
- supported_languages = ["python", "bash", "sql", "c", "java"]
119
- language = language.lower()
120
-
121
- if language not in supported_languages:
122
- return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(supported_languages)}"
123
-
124
- result = interpreter_instance.execute_code(code, language=language)
125
-
126
- response = []
127
-
128
- if result["status"] == "success":
129
- response.append(f"✅ Code executed successfully in **{language.upper()}**")
130
-
131
- if result.get("stdout"):
132
- response.append(
133
- "\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```"
134
- )
135
-
136
- if result.get("stderr"):
137
- response.append(
138
- "\n**Standard Error (if any):**\n```\n"
139
- + result["stderr"].strip()
140
- + "\n```"
141
- )
142
-
143
- if result.get("result") is not None:
144
- response.append(
145
- "\n**Execution Result:**\n```\n"
146
- + str(result["result"]).strip()
147
- + "\n```"
148
- )
149
-
150
- if result.get("dataframes"):
151
- for df_info in result["dataframes"]:
152
- response.append(
153
- f"\n**DataFrame `{df_info['name']}` (Shape: {df_info['shape']})**"
154
- )
155
- df_preview = pd.DataFrame(df_info["head"])
156
- response.append("First 5 rows:\n```\n" + str(df_preview) + "\n```")
157
-
158
- if result.get("plots"):
159
- response.append(
160
- f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)"
161
- )
162
-
163
- else:
164
- response.append(f"❌ Code execution failed in **{language.upper()}**")
165
- if result.get("stderr"):
166
- response.append(
167
- "\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```"
168
- )
169
-
170
- return "\n".join(response)
171
-
172
-
173
- ### =============== MATHEMATICAL TOOLS =============== ###
174
-
175
-
176
- @tool
177
- def multiply(a: float, b: float) -> float:
178
- """
179
- Multiplies two numbers.
180
- Args:
181
- a (float): the first number
182
- b (float): the second number
183
- """
184
- return a * b
185
-
186
-
187
- @tool
188
- def add(a: float, b: float) -> float:
189
- """
190
- Adds two numbers.
191
- Args:
192
- a (float): the first number
193
- b (float): the second number
194
- """
195
- return a + b
196
-
197
-
198
- @tool
199
- def subtract(a: float, b: float) -> int:
200
- """
201
- Subtracts two numbers.
202
- Args:
203
- a (float): the first number
204
- b (float): the second number
205
- """
206
- return a - b
207
-
208
-
209
- @tool
210
- def divide(a: float, b: float) -> float:
211
- """
212
- Divides two numbers.
213
- Args:
214
- a (float): the first float number
215
- b (float): the second float number
216
- """
217
- if b == 0:
218
- raise ValueError("Cannot divided by zero.")
219
- return a / b
220
-
221
-
222
- @tool
223
- def modulus(a: int, b: int) -> int:
224
- """
225
- Get the modulus of two numbers.
226
- Args:
227
- a (int): the first number
228
- b (int): the second number
229
- """
230
- return a % b
231
-
232
-
233
- @tool
234
- def power(a: float, b: float) -> float:
235
- """
236
- Get the power of two numbers.
237
- Args:
238
- a (float): the first number
239
- b (float): the second number
240
- """
241
- return a**b
242
-
243
-
244
- @tool
245
- def square_root(a: float) -> float | complex:
246
- """
247
- Get the square root of a number.
248
- Args:
249
- a (float): the number to get the square root of
250
- """
251
- if a >= 0:
252
- return a**0.5
253
- return cmath.sqrt(a)
254
-
255
-
256
- ### =============== DOCUMENT PROCESSING TOOLS =============== ###
257
-
258
-
259
- @tool
260
- def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
261
- """
262
- Save content to a file and return the path.
263
- Args:
264
- content (str): the content to save to the file
265
- filename (str, optional): the name of the file. If not provided, a random name file will be created.
266
- """
267
- temp_dir = tempfile.gettempdir()
268
- if filename is None:
269
- temp_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
270
- filepath = temp_file.name
271
- else:
272
- filepath = os.path.join(temp_dir, filename)
273
-
274
- with open(filepath, "w") as f:
275
- f.write(content)
276
-
277
- return f"File saved to {filepath}. You can read this file to process its contents."
278
-
279
-
280
- @tool
281
- def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
282
- """
283
- Download a file from a URL and save it to a temporary location.
284
- Args:
285
- url (str): the URL of the file to download.
286
- filename (str, optional): the name of the file. If not provided, a random name file will be created.
287
- """
288
- try:
289
- # Parse URL to get filename if not provided
290
- if not filename:
291
- path = urlparse(url).path
292
- filename = os.path.basename(path)
293
- if not filename:
294
- filename = f"downloaded_{uuid.uuid4().hex[:8]}"
295
-
296
- # Create temporary file
297
- temp_dir = tempfile.gettempdir()
298
- filepath = os.path.join(temp_dir, filename)
299
-
300
- # Download the file
301
- response = requests.get(url, stream=True)
302
- response.raise_for_status()
303
-
304
- # Save the file
305
- with open(filepath, "wb") as f:
306
- for chunk in response.iter_content(chunk_size=8192):
307
- f.write(chunk)
308
-
309
- return f"File downloaded to {filepath}. You can read this file to process its contents."
310
- except Exception as e:
311
- return f"Error downloading file: {str(e)}"
312
-
313
-
314
- @tool
315
- def extract_text_from_image(image_path: str) -> str:
316
- """
317
- Extract text from an image using OCR library pytesseract (if available).
318
- Args:
319
- image_path (str): the path to the image file.
320
- """
321
- try:
322
- # Open the image
323
- image = Image.open(image_path)
324
-
325
- # Extract text from the image
326
- text = pytesseract.image_to_string(image)
327
-
328
- return f"Extracted text from image:\n\n{text}"
329
- except Exception as e:
330
- return f"Error extracting text from image: {str(e)}"
331
-
332
-
333
- @tool
334
- def analyze_csv_file(file_path: str, query: str) -> str:
335
- """
336
- Analyze a CSV file using pandas and answer a question about it.
337
- Args:
338
- file_path (str): the path to the CSV file.
339
- query (str): Question about the data
340
- """
341
- try:
342
- # Read the CSV file
343
- df = pd.read_csv(file_path)
344
-
345
- # Run various analyses based on the query
346
- result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
347
- result += f"Columns: {', '.join(df.columns)}\n\n"
348
-
349
- # Add summary statistics
350
- result += "Summary statistics:\n"
351
- result += str(df.describe())
352
-
353
- return result
354
-
355
- except Exception as e:
356
- return f"Error analyzing CSV file: {str(e)}"
357
-
358
-
359
- @tool
360
- def analyze_excel_file(file_path: str, query: str) -> str:
361
- """
362
- Analyze an Excel file using pandas and answer a question about it.
363
- Args:
364
- file_path (str): the path to the Excel file.
365
- query (str): Question about the data
366
- """
367
- try:
368
- # Read the Excel file
369
- df = pd.read_excel(file_path)
370
-
371
- # Run various analyses based on the query
372
- result = (
373
- f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
374
- )
375
- result += f"Columns: {', '.join(df.columns)}\n\n"
376
-
377
- # Add summary statistics
378
- result += "Summary statistics:\n"
379
- result += str(df.describe())
380
-
381
- return result
382
-
383
- except Exception as e:
384
- return f"Error analyzing Excel file: {str(e)}"
385
-
386
-
387
- ### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
388
-
389
-
390
- @tool
391
- def analyze_image(image_base64: str) -> Dict[str, Any]:
392
- """
393
- Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).
394
- Args:
395
- image_base64 (str): Base64 encoded image string
396
- Returns:
397
- Dictionary with analysis result
398
- """
399
- try:
400
- img = decode_image(image_base64)
401
- width, height = img.size
402
- mode = img.mode
403
-
404
- if mode in ("RGB", "RGBA"):
405
- arr = np.array(img)
406
- avg_colors = arr.mean(axis=(0, 1))
407
- dominant = ["Red", "Green", "Blue"][np.argmax(avg_colors[:3])]
408
- brightness = avg_colors.mean()
409
- color_analysis = {
410
- "average_rgb": avg_colors.tolist(),
411
- "brightness": brightness,
412
- "dominant_color": dominant,
413
- }
414
- else:
415
- color_analysis = {"note": f"No color analysis for mode {mode}"}
416
-
417
- thumbnail = img.copy()
418
- thumbnail.thumbnail((100, 100))
419
- thumb_path = save_image(thumbnail, "thumbnails")
420
- thumbnail_base64 = encode_image(thumb_path)
421
-
422
- return {
423
- "dimensions": (width, height),
424
- "mode": mode,
425
- "color_analysis": color_analysis,
426
- "thumbnail": thumbnail_base64,
427
- }
428
- except Exception as e:
429
- return {"error": str(e)}
430
-
431
-
432
- @tool
433
- def transform_image(
434
- image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
435
- ) -> Dict[str, Any]:
436
- """
437
- Apply transformations: resize, rotate, crop, flip, brightness, contrast, blur, sharpen, grayscale.
438
- Args:
439
- image_base64 (str): Base64 encoded input image
440
- operation (str): Transformation operation
441
- params (Dict[str, Any], optional): Parameters for the operation
442
- Returns:
443
- Dictionary with transformed image (base64)
444
- """
445
- try:
446
- img = decode_image(image_base64)
447
- params = params or {}
448
-
449
- if operation == "resize":
450
- img = img.resize(
451
- (
452
- params.get("width", img.width // 2),
453
- params.get("height", img.height // 2),
454
- )
455
- )
456
- elif operation == "rotate":
457
- img = img.rotate(params.get("angle", 90), expand=True)
458
- elif operation == "crop":
459
- img = img.crop(
460
- (
461
- params.get("left", 0),
462
- params.get("top", 0),
463
- params.get("right", img.width),
464
- params.get("bottom", img.height),
465
- )
466
- )
467
- elif operation == "flip":
468
- if params.get("direction", "horizontal") == "horizontal":
469
- img = img.transpose(Image.FLIP_LEFT_RIGHT)
470
- else:
471
- img = img.transpose(Image.FLIP_TOP_BOTTOM)
472
- elif operation == "adjust_brightness":
473
- img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
474
- elif operation == "adjust_contrast":
475
- img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
476
- elif operation == "blur":
477
- img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
478
- elif operation == "sharpen":
479
- img = img.filter(ImageFilter.SHARPEN)
480
- elif operation == "grayscale":
481
- img = img.convert("L")
482
- else:
483
- return {"error": f"Unknown operation: {operation}"}
484
-
485
- result_path = save_image(img)
486
- result_base64 = encode_image(result_path)
487
- return {"transformed_image": result_base64}
488
-
489
- except Exception as e:
490
- return {"error": str(e)}
491
-
492
-
493
- @tool
494
- def draw_on_image(
495
- image_base64: str, drawing_type: str, params: Dict[str, Any]
496
- ) -> Dict[str, Any]:
497
- """
498
- Draw shapes (rectangle, circle, line) or text onto an image.
499
- Args:
500
- image_base64 (str): Base64 encoded input image
501
- drawing_type (str): Drawing type
502
- params (Dict[str, Any]): Drawing parameters
503
- Returns:
504
- Dictionary with result image (base64)
505
- """
506
- try:
507
- img = decode_image(image_base64)
508
- draw = ImageDraw.Draw(img)
509
- color = params.get("color", "red")
510
-
511
- if drawing_type == "rectangle":
512
- draw.rectangle(
513
- [params["left"], params["top"], params["right"], params["bottom"]],
514
- outline=color,
515
- width=params.get("width", 2),
516
- )
517
- elif drawing_type == "circle":
518
- x, y, r = params["x"], params["y"], params["radius"]
519
- draw.ellipse(
520
- (x - r, y - r, x + r, y + r),
521
- outline=color,
522
- width=params.get("width", 2),
523
- )
524
- elif drawing_type == "line":
525
- draw.line(
526
- (
527
- params["start_x"],
528
- params["start_y"],
529
- params["end_x"],
530
- params["end_y"],
531
- ),
532
- fill=color,
533
- width=params.get("width", 2),
534
- )
535
- elif drawing_type == "text":
536
- font_size = params.get("font_size", 20)
537
- try:
538
- font = ImageFont.truetype("arial.ttf", font_size)
539
- except IOError:
540
- font = ImageFont.load_default()
541
- draw.text(
542
- (params["x"], params["y"]),
543
- params.get("text", "Text"),
544
- fill=color,
545
- font=font,
546
- )
547
- else:
548
- return {"error": f"Unknown drawing type: {drawing_type}"}
549
-
550
- result_path = save_image(img)
551
- result_base64 = encode_image(result_path)
552
- return {"result_image": result_base64}
553
-
554
- except Exception as e:
555
- return {"error": str(e)}
556
-
557
-
558
- @tool
559
- def generate_simple_image(
560
- image_type: str,
561
- width: int = 500,
562
- height: int = 500,
563
- params: Optional[Dict[str, Any]] = None,
564
- ) -> Dict[str, Any]:
565
- """
566
- Generate a simple image (gradient, noise, pattern, chart).
567
- Args:
568
- image_type (str): Type of image
569
- width (int), height (int)
570
- params (Dict[str, Any], optional): Specific parameters
571
- Returns:
572
- Dictionary with generated image (base64)
573
- """
574
- try:
575
- params = params or {}
576
-
577
- if image_type == "gradient":
578
- direction = params.get("direction", "horizontal")
579
- start_color = params.get("start_color", (255, 0, 0))
580
- end_color = params.get("end_color", (0, 0, 255))
581
-
582
- img = Image.new("RGB", (width, height))
583
- draw = ImageDraw.Draw(img)
584
-
585
- if direction == "horizontal":
586
- for x in range(width):
587
- r = int(
588
- start_color[0] + (end_color[0] - start_color[0]) * x / width
589
- )
590
- g = int(
591
- start_color[1] + (end_color[1] - start_color[1]) * x / width
592
- )
593
- b = int(
594
- start_color[2] + (end_color[2] - start_color[2]) * x / width
595
- )
596
- draw.line([(x, 0), (x, height)], fill=(r, g, b))
597
- else:
598
- for y in range(height):
599
- r = int(
600
- start_color[0] + (end_color[0] - start_color[0]) * y / height
601
- )
602
- g = int(
603
- start_color[1] + (end_color[1] - start_color[1]) * y / height
604
- )
605
- b = int(
606
- start_color[2] + (end_color[2] - start_color[2]) * y / height
607
- )
608
- draw.line([(0, y), (width, y)], fill=(r, g, b))
609
-
610
- elif image_type == "noise":
611
- noise_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
612
- img = Image.fromarray(noise_array, "RGB")
613
-
614
- else:
615
- return {"error": f"Unsupported image_type {image_type}"}
616
-
617
- result_path = save_image(img)
618
- result_base64 = encode_image(result_path)
619
- return {"generated_image": result_base64}
620
-
621
- except Exception as e:
622
- return {"error": str(e)}
623
-
624
-
625
- @tool
626
- def combine_images(
627
- images_base64: List[str], operation: str, params: Optional[Dict[str, Any]] = None
628
- ) -> Dict[str, Any]:
629
- """
630
- Combine multiple images (collage, stack, blend).
631
- Args:
632
- images_base64 (List[str]): List of base64 images
633
- operation (str): Combination type
634
- params (Dict[str, Any], optional)
635
- Returns:
636
- Dictionary with combined image (base64)
637
- """
638
- try:
639
- images = [decode_image(b64) for b64 in images_base64]
640
- params = params or {}
641
-
642
- if operation == "stack":
643
- direction = params.get("direction", "horizontal")
644
- if direction == "horizontal":
645
- total_width = sum(img.width for img in images)
646
- max_height = max(img.height for img in images)
647
- new_img = Image.new("RGB", (total_width, max_height))
648
- x = 0
649
- for img in images:
650
- new_img.paste(img, (x, 0))
651
- x += img.width
652
- else:
653
- max_width = max(img.width for img in images)
654
- total_height = sum(img.height for img in images)
655
- new_img = Image.new("RGB", (max_width, total_height))
656
- y = 0
657
- for img in images:
658
- new_img.paste(img, (0, y))
659
- y += img.height
660
- else:
661
- return {"error": f"Unsupported combination operation {operation}"}
662
-
663
- result_path = save_image(new_img)
664
- result_base64 = encode_image(result_path)
665
- return {"combined_image": result_base64}
666
-
667
- except Exception as e:
668
- return {"error": str(e)}
669
-
670
-
671
- # load the system prompt from the file
672
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
673
- system_prompt = f.read()
674
- print(system_prompt)
675
-
676
- # System message
677
- sys_msg = SystemMessage(content=system_prompt)
678
-
679
- # build a retriever
680
- embeddings = HuggingFaceEmbeddings(
681
- model_name="sentence-transformers/all-mpnet-base-v2"
682
- ) # dim=768
683
- supabase: Client = create_client(
684
- os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
685
- )
686
- vector_store = SupabaseVectorStore(
687
- client=supabase,
688
- embedding=embeddings,
689
- table_name="documents2",
690
- query_name="match_documents_2",
691
- )
692
- create_retriever_tool = create_retriever_tool(
693
- retriever=vector_store.as_retriever(),
694
- name="Question Search",
695
- description="A tool to retrieve similar questions from a vector store.",
696
- )
697
-
698
-
699
- tools = [
700
- web_search,
701
- wiki_search,
702
- arxiv_search,
703
- multiply,
704
- add,
705
- subtract,
706
- divide,
707
- modulus,
708
- power,
709
- square_root,
710
- save_and_read_file,
711
- download_file_from_url,
712
- extract_text_from_image,
713
- analyze_csv_file,
714
- analyze_excel_file,
715
- execute_code_multilang,
716
- analyze_image,
717
- transform_image,
718
- draw_on_image,
719
- generate_simple_image,
720
- combine_images,
721
- ]
722
-
723
-
724
- # Build graph function
725
- def build_graph(provider: str = "groq"):
726
- """Build the graph"""
727
- # Load environment variables from .env file
728
- if provider == "groq":
729
- # Groq https://console.groq.com/docs/models
730
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
731
- elif provider == "huggingface":
732
- # TODO: Add huggingface endpoint
733
- llm = ChatHuggingFace(
734
- llm=HuggingFaceEndpoint(
735
- repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
736
- task="text-generation", # for chat‐style use “text-generation”
737
- max_new_tokens=1024,
738
- do_sample=False,
739
- repetition_penalty=1.03,
740
- temperature=0,
741
- ),
742
- verbose=True,
743
- )
744
- else:
745
- raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
746
- # Bind tools to LLM
747
- llm_with_tools = llm.bind_tools(tools)
748
-
749
- # Node
750
- def assistant(state: MessagesState):
751
- """Assistant node"""
752
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
753
-
754
- def retriever(state: MessagesState):
755
- """Retriever node"""
756
- similar_question = vector_store.similarity_search(state["messages"][0].content)
757
-
758
- if similar_question: # Check if the list is not empty
759
- example_msg = HumanMessage(
760
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
761
- )
762
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
763
- else:
764
- # Handle the case when no similar questions are found
765
- return {"messages": [sys_msg] + state["messages"]}
766
-
767
- builder = StateGraph(MessagesState)
768
- builder.add_node("retriever", retriever)
769
- builder.add_node("assistant", assistant)
770
- builder.add_node("tools", ToolNode(tools))
771
- builder.add_edge(START, "retriever")
772
- builder.add_edge("retriever", "assistant")
773
- builder.add_conditional_edges(
774
- "assistant",
775
- tools_condition,
776
- )
777
- builder.add_edge("tools", "assistant")
778
-
779
- # Compile graph
780
- return builder.compile()
781
-
782
-
783
- # test
784
- if __name__ == "__main__":
785
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
786
- graph = build_graph(provider="groq")
787
- messages = [HumanMessage(content=question)]
788
- messages = graph.invoke({"messages": messages})
789
- for m in messages["messages"]:
790
- m.pretty_print()