Spaces:
Sleeping
Sleeping
File size: 4,034 Bytes
4846644 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | #graph_tool.py
import base64
import io
import json
import math
from typing import Dict, List, Literal, Optional, Tuple

import matplotlib.pyplot as plt
from langchain_core.tools import tool
# Plot types accepted by generate_plot; validated before any figure is created
# so invalid requests never allocate (and leak) a matplotlib figure.
_PLOT_TYPES = ("bar", "line", "pie")


def _error_result(message: str) -> Tuple[str, Dict[str, str]]:
    """Build the (content, artifact) pair returned for a validation failure."""
    return message, {"error": message}


def _resolve_labels(labels: Optional[List[str]], default: List[str]) -> List[str]:
    """Return *labels* if it has exactly one entry per data point, else *default*."""
    if labels and len(labels) == len(default):
        return list(labels)
    return list(default)


def _draw_bar(ax, tick_labels: List[str], values: List[float], x_label: str, y_label: str) -> None:
    """Draw a bar chart with each bar annotated by its value."""
    # Plot at integer positions and attach the labels explicitly: a categorical
    # axis would merge duplicate labels onto one x position and hide bars.
    positions = list(range(len(values)))
    bars = ax.bar(positions, values)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    # Anchor the axis at zero only when that cannot clip negative values.
    if min(values) >= 0:
        ax.set_ylim(bottom=0)
    for bar, value in zip(bars, values):
        # Place the label past the end of the bar: above positive bars,
        # below negative ones.
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            value,
            f"{value}",
            ha="center",
            va="bottom" if value >= 0 else "top",
        )


def _draw_line(ax, tick_labels: List[str], values: List[float], x_label: str, y_label: str) -> None:
    """Draw a line chart with circular markers and a light grid."""
    positions = list(range(len(values)))
    ax.plot(positions, values, marker="o")
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    if min(values) >= 0:
        ax.set_ylim(bottom=0)
    ax.grid(True, alpha=0.3)


def _draw_pie(ax, slice_labels: List[str], values: List[float]) -> None:
    """Draw a pie chart with percentage labels, starting at 12 o'clock."""
    ax.pie(values, labels=slice_labels, autopct="%1.1f%%", startangle=90)
    ax.axis("equal")  # keep the pie circular


def _figure_to_base64_png(fig, dpi: int = 150) -> str:
    """Render *fig* to an in-memory PNG and return it base64-encoded."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# Use the @tool decorator and specify the "content_and_artifact" response format:
# the tool returns a (content, artifact) tuple, where content goes to the model
# and artifact carries the image payload.
@tool(response_format="content_and_artifact")
def generate_plot(
    data: Dict[str, float],
    plot_type: Literal["bar", "line", "pie"],
    title: str = "Generated Plot",
    labels: Optional[List[str]] = None,
    x_label: str = "",
    y_label: str = "",
) -> Tuple[str, Dict[str, str]]:
    """
    Generates a plot (bar, line, or pie) from a dictionary of data and returns it
    as a base64 encoded PNG image artifact.

    Args:
        data (Dict[str, float]): A dictionary where keys are labels and values are the numeric data to plot. Values must be finite numbers.
        plot_type (Literal["bar", "line", "pie"]): The type of plot to generate.
        title (str): The title for the plot.
        labels (List[str]): Optional list of labels to use for the x-axis or pie slices. Used only if it has exactly one label per data point; otherwise the data keys are used.
        x_label (str): The label for the x-axis (for bar and line charts).
        y_label (str): The label for the y-axis (for bar and line charts).

    Returns:
        A tuple containing:
        - A string message confirming the plot was generated, or describing the error.
        - A dictionary artifact with the base64 encoded image string and its format ("base64_image", "format"), or {"error": ...} on failure.
    """
    # --- Input validation (before any figure is allocated) ---
    if not isinstance(data, dict) or not data:
        return _error_result("Error: Data must be a non-empty dictionary.")
    try:
        y_data = [float(val) for val in data.values()]
    except (ValueError, TypeError):
        return _error_result("Error: All data values must be numeric.")
    if not all(math.isfinite(v) for v in y_data):
        return _error_result("Error: All data values must be finite numbers.")
    if plot_type not in _PLOT_TYPES:
        return _error_result(
            f"Error: Invalid plot_type '{plot_type}'. Choose 'bar', 'line', or 'pie'."
        )
    if plot_type == "pie" and (min(y_data) < 0 or sum(y_data) <= 0):
        return _error_result(
            "Error: Pie chart values must be non-negative and sum to more than zero."
        )

    axis_labels = _resolve_labels(labels, list(data.keys()))

    # --- Plot generation ---
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))
        if plot_type == "bar":
            _draw_bar(ax, axis_labels, y_data, x_label, y_label)
        elif plot_type == "line":
            _draw_line(ax, axis_labels, y_data, x_label, y_label)
        else:
            _draw_pie(ax, axis_labels, y_data)
        ax.set_title(title, fontsize=14, fontweight="bold")
        # Operate on this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()
        img_base64 = _figure_to_base64_png(fig)
    except Exception as e:
        # Tool boundary: report failures to the caller as a result tuple
        # instead of raising out of the agent's tool call.
        content = f"An unexpected error occurred while generating the plot: {e}"
        return content, {"error": str(e)}
    finally:
        # Close only the figure this call created so pyplot does not retain it.
        if fig is not None:
            plt.close(fig)

    # --- Return content and artifact ---
    content = f"Successfully generated a {plot_type} plot titled '{title}'."
    artifact = {
        "base64_image": img_base64,
        "format": "png",
    }
    return content, artifact