Nam Fam commited on
Commit
9016439
·
1 Parent(s): 472e1d4

update files

Browse files
Files changed (3) hide show
  1. agents/sql_agent/graph.py +2 -2
  2. agents/sql_agent/nodes.py +178 -2
  3. app.py +9 -3
agents/sql_agent/graph.py CHANGED
@@ -37,8 +37,8 @@ def build_graph(visualize: bool = True) -> StateGraph:
37
  "detect_off_topic",
38
  lambda state: state['error'],
39
  path_map={
40
- # True: "generate_answer",
41
- True: "get_db_info",
42
  False: "get_db_info"
43
  }
44
  )
 
37
  "detect_off_topic",
38
  lambda state: state['error'],
39
  path_map={
40
+ True: "generate_answer",
41
+ # True: "get_db_info",
42
  False: "get_db_info"
43
  }
44
  )
agents/sql_agent/nodes.py CHANGED
@@ -332,6 +332,180 @@ def render_visualization(state: SQLAgentState) -> SQLAgentState:
332
  return state
333
 
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  def finalize_output(state: SQLAgentState) -> SQLAgentState:
336
  """
337
  Node hợp nhất kết quả cuối cùng (answer, visualization_output, error, ...).
@@ -372,7 +546,8 @@ def detect_off_topic(state: SQLAgentState) -> SQLAgentState:
372
  )
373
  metadata = {
374
  "topic": "Database Queries",
375
- "additional_context": "The database is about ecommerce products with tables: products, laptops, phones, tablets, promotions, category"
 
376
  }
377
 
378
  validation_result = validator.validate(question, metadata=metadata)
@@ -461,7 +636,8 @@ def generate_answer(state: SQLAgentState) -> SQLAgentState:
461
  state['error'] = state['error'] or "No results found."
462
  if state["off_topic"] == "OFF_TOPIC":
463
  state['error'] = "The question is off-topic."
464
- state["answer"] = "Sorry, I can't assist you with that request."
 
465
  state['step'] = 'generate_answer'
466
  return state
467
 
 
332
  return state
333
 
334
 
335
+ def render_visualization(state: SQLAgentState) -> SQLAgentState:
336
+ """
337
+ Render the visualization from formatted data.
338
+ Output: path to saved image file.
339
+ """
340
+ import matplotlib.pyplot as plt
341
+ import os
342
+ import uuid
343
+ from typing import Dict, Any, Optional
344
+
345
+ def save_fig(fig: plt.Figure) -> str:
346
+ """Save figure to file and return the file path."""
347
+ try:
348
+ output_dir = "output/plots"
349
+ os.makedirs(output_dir, exist_ok=True)
350
+ file_path = os.path.join(output_dir, f"visualization_{uuid.uuid4().hex[:8]}.png")
351
+ fig.savefig(file_path, format="png", bbox_inches="tight", dpi=100)
352
+ plt.close(fig)
353
+ return file_path
354
+ except Exception as e:
355
+ print(f"Error saving figure: {e}")
356
+ return ""
357
+
358
+ def validate_data(data: Dict[str, Any], required_keys: list) -> bool:
359
+ """Validate that data contains all required keys and has valid values."""
360
+ if not all(key in data for key in required_keys):
361
+ return False
362
+ # Check if there's actual data to plot
363
+ if "values" in data and not data["values"]:
364
+ return False
365
+ if "yValues" in data and not data["yValues"]:
366
+ return False
367
+ return True
368
+
369
+ def render_line(data: Dict[str, Any]) -> Optional[str]:
370
+ """Render line chart."""
371
+ required_keys = ["xValues", "yValues"]
372
+ if not validate_data(data, required_keys):
373
+ return None
374
+
375
+ try:
376
+ fig, ax = plt.subplots(figsize=(10, 6))
377
+ x = data["xValues"]
378
+ for series in data["yValues"]:
379
+ if len(x) == len(series["data"]):
380
+ ax.plot(x, series["data"], label=series.get("label", ""), marker='o')
381
+
382
+ ax.set_xlabel(data.get("xAxisLabel", "X"))
383
+ ax.set_ylabel(data.get("yAxisLabel", "Y"))
384
+ ax.set_title(data.get("title", ""))
385
+ if any(series.get("label") for series in data["yValues"]):
386
+ ax.legend()
387
+ plt.tight_layout()
388
+ return save_fig(fig)
389
+ except Exception as e:
390
+ print(f"Error rendering line chart: {e}")
391
+ return None
392
+
393
+ def render_bar(data: Dict[str, Any], horizontal: bool = False) -> Optional[str]:
394
+ """Render bar chart (vertical or horizontal)."""
395
+ required_keys = ["labels", "values"]
396
+ if not validate_data(data, required_keys) or not data["values"]:
397
+ return None
398
+
399
+ try:
400
+ fig, ax = plt.subplots(figsize=(10, 6))
401
+ labels = data["labels"]
402
+ n_series = len(data["values"])
403
+ width = 0.8 / max(1, n_series) # Prevent division by zero
404
+ x_indexes = list(range(len(labels)))
405
+
406
+ for i, series in enumerate(data["values"]):
407
+ if not series["data"]: # Skip empty series
408
+ continue
409
+
410
+ offset = (i - n_series / 2) * width + width / 2
411
+ if horizontal:
412
+ ax.barh(
413
+ [x + offset for x in x_indexes],
414
+ series["data"],
415
+ height=width,
416
+ label=series.get("label", f"Series {i+1}")
417
+ )
418
+ ax.set_yticks(x_indexes)
419
+ ax.set_yticklabels(labels)
420
+ ax.set_xlabel(data.get("xAxisLabel", "Value"))
421
+ ax.set_ylabel(data.get("yAxisLabel", "Category"))
422
+ else:
423
+ ax.bar(
424
+ [x + offset for x in x_indexes],
425
+ series["data"],
426
+ width=width,
427
+ label=series.get("label", f"Series {i+1}")
428
+ )
429
+ ax.set_xticks(x_indexes)
430
+ ax.set_xticklabels(labels, rotation=45, ha='right')
431
+ ax.set_xlabel(data.get("xAxisLabel", "Category"))
432
+ ax.set_ylabel(data.get("yAxisLabel", "Value"))
433
+
434
+ if any(series.get("label") for series in data["values"]):
435
+ ax.legend()
436
+ ax.set_title(data.get("title", ""))
437
+ plt.tight_layout()
438
+ return save_fig(fig)
439
+ except Exception as e:
440
+ print(f"Error rendering {'horizontal ' if horizontal else ''}bar chart: {e}")
441
+ return None
442
+
443
+ def render_scatter(data: Dict[str, Any]) -> Optional[str]:
444
+ """Render scatter plot."""
445
+ required_keys = ["series"]
446
+ if not validate_data(data, required_keys):
447
+ return None
448
+
449
+ try:
450
+ fig, ax = plt.subplots(figsize=(10, 6))
451
+ for series in data["series"]:
452
+ if not series.get("data"):
453
+ continue
454
+ xs = [point.get("x", 0) for point in series["data"]]
455
+ ys = [point.get("y", 0) for point in series["data"]]
456
+ if len(xs) == len(ys):
457
+ ax.scatter(
458
+ xs,
459
+ ys,
460
+ label=series.get("label"),
461
+ alpha=0.6,
462
+ edgecolors='w'
463
+ )
464
+
465
+ ax.set_xlabel(data.get("xAxisLabel", "X"))
466
+ ax.set_ylabel(data.get("yAxisLabel", "Y"))
467
+ ax.set_title(data.get("title", ""))
468
+ if any(series.get("label") for series in data["series"]):
469
+ ax.legend()
470
+ plt.tight_layout()
471
+ return save_fig(fig)
472
+ except Exception as e:
473
+ print(f"Error rendering scatter plot: {e}")
474
+ return None
475
+
476
+ # Main function logic
477
+ data = state.get("formatted_data_for_visualization")
478
+ visualization = state.get("visualization", "none")
479
+ state["visualization_output"] = None
480
+
481
+ if not data or visualization == "none":
482
+ return state
483
+
484
+ try:
485
+ renderers = {
486
+ "line": lambda: render_line(data),
487
+ "bar": lambda: render_bar(data, horizontal=False),
488
+ "horizontal_bar": lambda: render_bar(data, horizontal=True),
489
+ "scatter": lambda: render_scatter(data)
490
+ }
491
+
492
+ if visualization in renderers:
493
+ image_path = renderers[visualization]()
494
+ if image_path and os.path.exists(image_path):
495
+ state["visualization_output"] = image_path
496
+ else:
497
+ state["error"] = "Failed to generate visualization: No valid data to display"
498
+ else:
499
+ state["error"] = f"Unsupported visualization type: {visualization}"
500
+
501
+ except Exception as e:
502
+ state["error"] = f"Error in visualization: {str(e)}"
503
+ print(f"Visualization error: {e}")
504
+
505
+ state["step"] = "render_visualization"
506
+ return state
507
+
508
+
509
  def finalize_output(state: SQLAgentState) -> SQLAgentState:
510
  """
511
  Node hợp nhất kết quả cuối cùng (answer, visualization_output, error, ...).
 
546
  )
547
  metadata = {
548
  "topic": "Database Queries",
549
+ "additional_context": "Only accept queries related to the data on Database/CSV"
550
+ # "additional_context": "The database is about ecommerce products with tables: products, laptops, phones, tablets, promotions, category"
551
  }
552
 
553
  validation_result = validator.validate(question, metadata=metadata)
 
636
  state['error'] = state['error'] or "No results found."
637
  if state["off_topic"] == "OFF_TOPIC":
638
  state['error'] = "The question is off-topic."
639
+ # state["answer"] = "Sorry, I can't assist you with that request."
640
+ state["answer"] = "Sorry, I can only help you with questions about the data! What information would you like to explore from the data?"
641
  state['step'] = 'generate_answer'
642
  return state
643
 
app.py CHANGED
@@ -464,11 +464,17 @@ else:
464
  st.write(step_details.get('answer', ''))
465
  chat_history.append({"role": "assistant", "content": step_details.get('answer', ''), "timestamp": datetime.now()})
466
  elif step_name == 'render_visualization':
467
- # with st.expander("Chart", expanded=False):
468
- st.image(step_details.get('visualization_output', ''))
 
 
 
 
 
 
469
  except Exception as e:
470
  err = f"SQL Agent Error: {e}"
471
- st.error(err)
472
  chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
473
  else:
474
  # Use DataFrame agent for selected CSV
 
464
  st.write(step_details.get('answer', ''))
465
  chat_history.append({"role": "assistant", "content": step_details.get('answer', ''), "timestamp": datetime.now()})
466
  elif step_name == 'render_visualization':
467
+ try:
468
+ visualization_output = step_details.get('visualization_output')
469
+ if visualization_output and os.path.exists(visualization_output):
470
+ st.image(visualization_output)
471
+ else:
472
+ print("No visualization was generated for this query.")
473
+ except Exception as e:
474
+ print(f"Could not display visualization: {str(e)}")
475
  except Exception as e:
476
  err = f"SQL Agent Error: {e}"
477
+ print(err)
478
  chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
479
  else:
480
  # Use DataFrame agent for selected CSV