jdesiree commited on
Commit
4846644
·
verified ·
1 Parent(s): 8ac50bc

Upload graph_tool.py

Browse files
Files changed (1) hide show
  1. graph_tool.py +109 -0
graph_tool.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #graph_tool.py
2
+
3
+ import base64
4
+ import io
5
+ import json
6
+ from typing import Dict, List, Literal, Tuple
7
+
8
+ import matplotlib.pyplot as plt
9
+ from langchain_core.tools import tool
10
+
11
+ # Use the @tool decorator and specify the "content_and_artifact" response format.
12
+ @tool(response_format="content_and_artifact")
13
+ def generate_plot(
14
+ data: Dict[str, float],
15
+ plot_type: Literal["bar", "line", "pie"],
16
+ title: str = "Generated Plot",
17
+ labels: List[str] = None,
18
+ x_label: str = "",
19
+ y_label: str = ""
20
+ ) -> Tuple:
21
+ """
22
+ Generates a plot (bar, line, or pie) from a dictionary of data and returns it
23
+ as a base64 encoded PNG image artifact.
24
+
25
+ Args:
26
+ data (Dict[str, float]): A dictionary where keys are labels and values are the numeric data to plot.
27
+ plot_type (Literal["bar", "line", "pie"]): The type of plot to generate.
28
+ title (str): The title for the plot.
29
+ labels (List[str]): Optional list of labels to use for the x-axis or pie slices. If not provided, data keys are used.
30
+ x_label (str): The label for the x-axis (for bar and line charts).
31
+ y_label (str): The label for the y-axis (for bar and line charts).
32
+
33
+ Returns:
34
+ A tuple containing:
35
+ - A string message confirming the plot was generated.
36
+ - A dictionary artifact with the base64 encoded image string and its format.
37
+ """
38
+ # --- Input Validation ---
39
+ if not isinstance(data, dict) or not data:
40
+ content = "Error: Data must be a non-empty dictionary."
41
+ artifact = {"error": content}
42
+ return content, artifact
43
+
44
+ try:
45
+ y_data = [float(val) for val in data.values()]
46
+ except (ValueError, TypeError):
47
+ content = "Error: All data values must be numeric."
48
+ artifact = {"error": content}
49
+ return content, artifact
50
+
51
+ x_data = list(data.keys())
52
+
53
+ # --- Plot Generation ---
54
+ try:
55
+ fig, ax = plt.subplots(figsize=(10, 6))
56
+
57
+ if plot_type == 'bar':
58
+ # Use provided labels if they match the data length, otherwise use data keys
59
+ bar_labels = labels if labels and len(labels) == len(x_data) else x_data
60
+ bars = ax.bar(bar_labels, y_data)
61
+ ax.set_xlabel(x_label)
62
+ ax.set_ylabel(y_label)
63
+ ax.set_ylim(bottom=0)
64
+ for bar, value in zip(bars, y_data):
65
+ height = bar.get_height()
66
+ ax.text(bar.get_x() + bar.get_width()/2., height, f'{value}', ha='center', va='bottom')
67
+
68
+ elif plot_type == 'line':
69
+ line_labels = labels if labels and len(labels) == len(x_data) else x_data
70
+ ax.plot(line_labels, y_data, marker='o')
71
+ ax.set_xlabel(x_label)
72
+ ax.set_ylabel(y_label)
73
+ ax.set_ylim(bottom=0)
74
+ ax.grid(True, alpha=0.3)
75
+
76
+ elif plot_type == 'pie':
77
+ pie_labels = labels if labels and len(labels) == len(y_data) else list(data.keys())
78
+ ax.pie(y_data, labels=pie_labels, autopct='%1.1f%%', startangle=90)
79
+ ax.axis('equal')
80
+
81
+ else:
82
+ content = f"Error: Invalid plot_type '{plot_type}'. Choose 'bar', 'line', or 'pie'."
83
+ artifact = {"error": content}
84
+ return content, artifact
85
+
86
+ ax.set_title(title, fontsize=14, fontweight='bold')
87
+ plt.tight_layout()
88
+
89
+ # --- In-Memory Image Conversion ---
90
+ buf = io.BytesIO()
91
+ plt.savefig(buf, format='png', dpi=150)
92
+ plt.close(fig)
93
+ buf.seek(0)
94
+
95
+ img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
96
+
97
+ # --- Return Content and Artifact ---
98
+ content = f"Successfully generated a {plot_type} plot titled '{title}'."
99
+ artifact = {
100
+ "base64_image": img_base64,
101
+ "format": "png"
102
+ }
103
+ return content, artifact
104
+
105
+ except Exception as e:
106
+ plt.close('all')
107
+ content = f"An unexpected error occurred while generating the plot: {str(e)}"
108
+ artifact = {"error": str(e)}
109
+ return content, artifact