Tomkuijpers2232 committed on
Commit
18c23b1
·
verified ·
1 Parent(s): 26f874f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +439 -2
agent.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  from dotenv import load_dotenv
3
- from typing import TypedDict, Annotated
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.graph.message import add_messages
6
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
@@ -70,7 +70,444 @@ tavily_search_tool = TavilySearch(
70
  topic="general",
71
  )
72
 
73
- tools = [multiply, add, subtract, divide, wikidata_search, tavily_search_tool]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  def build_graph():
76
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", api_key=os.getenv("GOOGLE_API_KEY"))
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from typing import List, Dict, Any, Optional
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.graph.message import add_messages
6
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
 
70
  topic="general",
71
  )
72
 
73
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file in the system temporary directory and return the path.
    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
            Only the final path component is used, so the file is always written inside the temporary directory.
    """
    temp_dir = tempfile.gettempdir()

    # Keep only the base name: os.path.join would otherwise honour absolute
    # paths and "../" segments and write outside the temporary directory.
    safe_name = os.path.basename(filename) if filename else ""
    if safe_name in (".", ".."):
        safe_name = ""

    if safe_name:
        filepath = os.path.join(temp_dir, safe_name)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)
    else:
        # Write through the temporary file's own handle so it is closed
        # deterministically instead of being left open and reopened by name.
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, dir=temp_dir, encoding="utf-8"
        ) as f:
            f.write(content)
            filepath = f.name

    return f"File saved to {filepath}. You can read this file to process its contents."
92
+
93
+
94
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to the system temporary directory.
    Args:
        url (str): the URL of the file to download.
        filename (str, optional): the name of the file. If not provided, the last segment of the URL
            path is used, or a random name when the URL has none. Only the final path component is kept.
    Returns a message with the saved path, or an error message if anything fails.
    """
    try:
        # Fall back to the last segment of the URL path when no name is given
        if not filename:
            filename = urlparse(url).path

        # Keep only the base name so the file cannot land outside the temp directory
        filename = os.path.basename(filename)
        if filename in ("", ".", ".."):
            filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), filename)

        # A timeout stops an unresponsive server from blocking the agent forever;
        # the context manager releases the streamed connection when done.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()

            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        return f"File downloaded to {filepath}. You can read this file to process its contents."
    except Exception as e:
        return f"Error downloading file: {str(e)}"
126
+
127
+
128
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using OCR library pytesseract (if available).
    Args:
        image_path (str): the path to the image file.
    Returns the extracted text, or an error message if the image cannot be read or OCR fails.
    """
    try:
        # Open the image in a context manager so the file handle is released
        # as soon as OCR has finished.
        with Image.open(image_path) as image:
            text = pytesseract.image_to_string(image)

        return f"Extracted text from image:\n\n{text}"
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
145
+
146
+
147
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Load a CSV file with pandas and report its size, column names and summary statistics.
    Args:
        file_path (str): the path to the CSV file.
        query (str): Question about the data. It is not used to select an analysis;
            the same summary is returned for every query.
    Returns the summary text, or an error message if the file cannot be read.
    """
    try:
        # Read the CSV file
        df = pd.read_csv(file_path)

        # Column labels are not guaranteed to be strings, so convert before joining
        column_names = ", ".join(str(column) for column in df.columns)

        result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        result += f"Columns: {column_names}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result

    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
171
+
172
+
173
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Load an Excel file with pandas and report its size, column names and summary statistics.
    Args:
        file_path (str): the path to the Excel file.
        query (str): Question about the data. It is not used to select an analysis;
            the same summary is returned for every query.
    Returns the summary text, or an error message if the file cannot be read.
    """
    try:
        # Read the Excel file
        df = pd.read_excel(file_path)

        # Column labels are not guaranteed to be strings, so convert before joining
        column_names = ", ".join(str(column) for column in df.columns)

        result = (
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        result += f"Columns: {column_names}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result

    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
199
+
200
+
201
+ ### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
202
+ import os
203
+ import io
204
+ import base64
205
+ import uuid
206
+ from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
207
+
208
+ # Helper functions for image processing
209
def encode_image(image_path: str) -> str:
    """Read the file at `image_path` and return its bytes as a base64 text string."""
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
213
+
214
+
215
def decode_image(base64_string: str) -> Image.Image:
    """Decode a base64 string and open the resulting bytes as a PIL Image."""
    raw_bytes = base64.b64decode(base64_string)
    # Image.open needs a file-like object, so wrap the bytes in an in-memory buffer
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
219
+
220
+
221
def save_image(image: Image.Image, directory: str = "image_outputs") -> str:
    """Write a PIL Image to `directory` as a uniquely named PNG and return its path."""
    # Create the output directory if needed; an existing one is not an error
    os.makedirs(directory, exist_ok=True)
    # A fresh UUID as the file stem keeps successive saves from overwriting each other
    target_path = os.path.join(directory, f"{uuid.uuid4()}.png")
    image.save(target_path)
    return target_path
228
+
229
@tool
def analyze_image(image_base64: str) -> Dict[str, Any]:
    """
    Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).
    Args:
        image_base64 (str): Base64 encoded image string
    Returns:
        Dictionary with "dimensions" (width, height), "mode", "color_analysis" and a base64
        "thumbnail" of at most 100x100 pixels, or {"error": message} on failure.
        Color analysis is only computed for RGB and RGBA images.
    """
    try:
        img = decode_image(image_base64)
        width, height = img.size
        mode = img.mode

        if mode in ("RGB", "RGBA"):
            # Use only the first three channels so an alpha channel does not
            # leak into the colour averages or the brightness value.
            rgb = np.asarray(img)[..., :3]
            avg_rgb = rgb.mean(axis=(0, 1))
            color_analysis = {
                "average_rgb": avg_rgb.tolist(),
                # Plain float rather than a numpy scalar, so the result serializes cleanly
                "brightness": float(avg_rgb.mean()),
                "dominant_color": ["Red", "Green", "Blue"][int(np.argmax(avg_rgb))],
            }
        else:
            color_analysis = {"note": f"No color analysis for mode {mode}"}

        # thumbnail() resizes in place, so work on a copy of the decoded image
        thumbnail = img.copy()
        thumbnail.thumbnail((100, 100))
        thumb_path = save_image(thumbnail, "thumbnails")
        thumbnail_base64 = encode_image(thumb_path)

        return {
            "dimensions": (width, height),
            "mode": mode,
            "color_analysis": color_analysis,
            "thumbnail": thumbnail_base64,
        }
    except Exception as e:
        return {"error": str(e)}
269
+
270
+
271
@tool
def transform_image(
    image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Apply one transformation to an image.
    Supported operations and their optional params:
        resize: width, height (default: half of the current size)
        rotate: angle (default 90); the canvas expands to fit the rotated image
        crop: left, top, right, bottom (default: the full image)
        flip: direction, "horizontal" (default) or any other value for vertical
        adjust_brightness (alias: brightness): factor (default 1.5)
        adjust_contrast (alias: contrast): factor (default 1.5)
        blur: radius (default 2)
        sharpen: no params
        grayscale: no params
    Args:
        image_base64 (str): Base64 encoded input image
        operation (str): Transformation operation
        params (Dict[str, Any], optional): Parameters for the operation
    Returns:
        Dictionary with transformed image (base64), or {"error": message} on failure
    """
    try:
        img = decode_image(image_base64)
        params = params or {}

        if operation == "resize":
            img = img.resize(
                (
                    params.get("width", img.width // 2),
                    params.get("height", img.height // 2),
                )
            )
        elif operation == "rotate":
            img = img.rotate(params.get("angle", 90), expand=True)
        elif operation == "crop":
            img = img.crop(
                (
                    params.get("left", 0),
                    params.get("top", 0),
                    params.get("right", img.width),
                    params.get("bottom", img.height),
                )
            )
        elif operation == "flip":
            if params.get("direction", "horizontal") == "horizontal":
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            else:
                img = img.transpose(Image.FLIP_TOP_BOTTOM)
        # The short names are accepted as aliases because the tool description
        # has always listed "brightness" and "contrast" as operations.
        elif operation in ("adjust_brightness", "brightness"):
            img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
        elif operation in ("adjust_contrast", "contrast"):
            img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
        elif operation == "blur":
            img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
        elif operation == "sharpen":
            img = img.filter(ImageFilter.SHARPEN)
        elif operation == "grayscale":
            img = img.convert("L")
        else:
            return {"error": f"Unknown operation: {operation}"}

        # Persist the result, then return the saved file re-encoded as base64
        result_path = save_image(img)
        result_base64 = encode_image(result_path)
        return {"transformed_image": result_base64}

    except Exception as e:
        return {"error": str(e)}
330
+
331
+
332
@tool
def draw_on_image(
    image_base64: str, drawing_type: str, params: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Draw shapes (rectangle, circle, line) or text onto an image.
    Required params per drawing type:
        rectangle: left, top, right, bottom
        circle: x, y, radius
        line: start_x, start_y, end_x, end_y
        text: x, y (optional: text, default "Text"; font_size, default 20)
    Optional for all types: color (default "red"); width (default 2) for shapes and lines.
    Args:
        image_base64 (str): Base64 encoded input image
        drawing_type (str): Drawing type
        params (Dict[str, Any]): Drawing parameters
    Returns:
        Dictionary with result image (base64), or {"error": message} on failure
    """
    try:
        img = decode_image(image_base64)
        draw = ImageDraw.Draw(img)
        # Treat a missing params object as empty so the error names the
        # missing key instead of reporting an attribute error on None.
        params = params or {}
        color = params.get("color", "red")

        if drawing_type == "rectangle":
            draw.rectangle(
                [params["left"], params["top"], params["right"], params["bottom"]],
                outline=color,
                width=params.get("width", 2),
            )
        elif drawing_type == "circle":
            x, y, r = params["x"], params["y"], params["radius"]
            # A circle is drawn as the ellipse inscribed in its bounding square
            draw.ellipse(
                (x - r, y - r, x + r, y + r),
                outline=color,
                width=params.get("width", 2),
            )
        elif drawing_type == "line":
            draw.line(
                (
                    params["start_x"],
                    params["start_y"],
                    params["end_x"],
                    params["end_y"],
                ),
                fill=color,
                width=params.get("width", 2),
            )
        elif drawing_type == "text":
            font_size = params.get("font_size", 20)
            try:
                font = ImageFont.truetype("arial.ttf", font_size)
            except IOError:
                # arial.ttf could not be loaded; fall back to Pillow's built-in font
                font = ImageFont.load_default()
            draw.text(
                (params["x"], params["y"]),
                params.get("text", "Text"),
                fill=color,
                font=font,
            )
        else:
            return {"error": f"Unknown drawing type: {drawing_type}"}

        result_path = save_image(img)
        result_base64 = encode_image(result_path)
        return {"result_image": result_base64}

    except KeyError as missing:
        # str(KeyError) is just the quoted key; say explicitly what is wrong
        return {
            "error": f"Missing required parameter for {drawing_type}: {missing.args[0]}"
        }
    except Exception as e:
        return {"error": str(e)}
395
+
396
+
397
@tool
def generate_simple_image(
    image_type: str,
    width: int = 500,
    height: int = 500,
    params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Generate a simple image. Supported image types: "gradient" and "noise".
    Args:
        image_type (str): Type of image ("gradient" or "noise")
        width (int), height (int): Size of the image in pixels
        params (Dict[str, Any], optional): Specific parameters. For "gradient":
            direction ("horizontal" by default, any other value is vertical),
            start_color and end_color as (r, g, b) (default red to blue).
    Returns:
        Dictionary with generated image (base64), or {"error": message} on failure
    """
    try:
        params = params or {}

        if image_type == "gradient":
            direction = params.get("direction", "horizontal")
            start_color = params.get("start_color", (255, 0, 0))
            end_color = params.get("end_color", (0, 0, 255))

            img = Image.new("RGB", (width, height))
            draw = ImageDraw.Draw(img)

            horizontal = direction == "horizontal"
            steps = width if horizontal else height
            # Divide by the last index so the final line is exactly end_color;
            # max() avoids a zero divisor for a one-pixel gradient.
            last_index = max(steps - 1, 1)

            for i in range(steps):
                fraction = i / last_index
                r = int(start_color[0] + (end_color[0] - start_color[0]) * fraction)
                g = int(start_color[1] + (end_color[1] - start_color[1]) * fraction)
                b = int(start_color[2] + (end_color[2] - start_color[2]) * fraction)
                if horizontal:
                    draw.line([(i, 0), (i, height)], fill=(r, g, b))
                else:
                    draw.line([(0, i), (width, i)], fill=(r, g, b))

        elif image_type == "noise":
            # A (height, width, 3) uint8 array is interpreted as an RGB image
            noise_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            img = Image.fromarray(noise_array)

        else:
            return {"error": f"Unsupported image_type {image_type}"}

        result_path = save_image(img)
        result_base64 = encode_image(result_path)
        return {"generated_image": result_base64}

    except Exception as e:
        return {"error": str(e)}
462
+
463
+
464
@tool
def combine_images(
    images_base64: List[str], operation: str, params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Combine multiple images into one. The only supported operation is "stack".
    Args:
        images_base64 (List[str]): List of base64 images
        operation (str): Combination type ("stack")
        params (Dict[str, Any], optional): For "stack": direction, "horizontal" (default)
            or any other value for vertical
    Returns:
        Dictionary with combined image (base64), or {"error": message} on failure
    """
    try:
        # Without this guard an empty list surfaces as the opaque
        # "max() arg is an empty sequence" error.
        if not images_base64:
            return {"error": "No images provided to combine"}

        images = [decode_image(b64) for b64 in images_base64]
        params = params or {}

        if operation != "stack":
            return {"error": f"Unsupported combination operation {operation}"}

        direction = params.get("direction", "horizontal")
        if direction == "horizontal":
            # Side by side: widths add up, and the tallest image sets the canvas height
            total_width = sum(img.width for img in images)
            max_height = max(img.height for img in images)
            new_img = Image.new("RGB", (total_width, max_height))
            x = 0
            for img in images:
                new_img.paste(img, (x, 0))
                x += img.width
        else:
            # Top to bottom: heights add up, and the widest image sets the canvas width
            max_width = max(img.width for img in images)
            total_height = sum(img.height for img in images)
            new_img = Image.new("RGB", (max_width, total_height))
            y = 0
            for img in images:
                new_img.paste(img, (0, y))
                y += img.height

        result_path = save_image(new_img)
        result_base64 = encode_image(result_path)
        return {"combined_image": result_base64}

    except Exception as e:
        return {"error": str(e)}
508
+
509
+
510
# Every tool exposed to the agent. The download tool is defined in this file as
# download_file_from_url; the previous entry download_file_from_image named a
# function that does not exist and raised NameError when the module was imported.
tools = [
    multiply,
    add,
    subtract,
    divide,
    wikidata_search,
    tavily_search_tool,
    combine_images,
    analyze_image,
    transform_image,
    draw_on_image,
    generate_simple_image,
    analyze_csv_file,
    analyze_excel_file,
    save_and_read_file,
    download_file_from_url,
    extract_text_from_image,
]
511
 
512
  def build_graph():
513
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", api_key=os.getenv("GOOGLE_API_KEY"))