Spaces:
Sleeping
Sleeping
File size: 4,018 Bytes
27e06fc d4d436a b2f206d d4d436a 3fa11a9 d4d436a 3fa11a9 d4d436a 27e06fc d4d436a 27e06fc d4d436a 3fa11a9 d4d436a 2ebfa62 d4d436a 2ebfa62 d4d436a 2ebfa62 27e06fc d4d436a 27e06fc d4d436a 27e06fc 0a6f3b4 27e06fc d4d436a 27e06fc d4d436a 27e06fc 0a6f3b4 27e06fc d4d436a 27e06fc d4d436a 27e06fc d4d436a 27e06fc d4d436a 27e06fc d4d436a b2f206d d4d436a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | import base64
import io
import json
import math
from typing import Dict, List, Literal, Optional, Tuple

import matplotlib.pyplot as plt
from langchain_core.tools import tool
def _error_result(message: str) -> Tuple[str, Dict[str, str]]:
    """Build the (content, artifact) pair returned by the tool on a validation failure."""
    return message, {"error": message}


# Use the @tool decorator and specify the "content_and_artifact" response format.
@tool(response_format="content_and_artifact")
def generate_plot(
    data: Dict[str, float],
    plot_type: Literal["bar", "line", "pie"],
    title: str = "Generated Plot",
    labels: Optional[List[str]] = None,
    x_label: str = "",
    y_label: str = "",
) -> Tuple[str, Dict[str, str]]:
    """
    Generates a plot (bar, line, or pie) from a dictionary of data and returns it
    as a base64 encoded PNG image artifact.

    Args:
        data (Dict[str, float]): A dictionary where keys are labels and values are the numeric data to plot. Values must be finite numbers; pie charts additionally require non-negative values with a positive total.
        plot_type (Literal["bar", "line", "pie"]): The type of plot to generate.
        title (str): The title for the plot.
        labels (List[str]): Optional list of labels to use for the x-axis or pie slices. If not provided, or if its length differs from the number of data points, the data keys are used.
        x_label (str): The label for the x-axis (for bar and line charts).
        y_label (str): The label for the y-axis (for bar and line charts).

    Returns:
        A tuple containing:
        - A string message confirming the plot was generated, or describing the error.
        - A dictionary artifact with the base64 encoded image string and its format,
          or {"error": <message>} on failure.
    """
    # --- Input Validation ---
    # All checks run before a figure exists, so no early return can leak one.
    if not isinstance(data, dict) or not data:
        return _error_result("Error: Data must be a non-empty dictionary.")
    try:
        y_data = [float(val) for val in data.values()]
    except (ValueError, TypeError):
        return _error_result("Error: All data values must be numeric.")
    # float() happily parses "nan"/"inf"; such values cannot be laid out on an axis.
    if not all(math.isfinite(val) for val in y_data):
        return _error_result("Error: All data values must be finite numbers.")
    if plot_type not in ("bar", "line", "pie"):
        return _error_result(
            f"Error: Invalid plot_type '{plot_type}'. Choose 'bar', 'line', or 'pie'."
        )
    if plot_type == "pie" and (min(y_data) < 0 or sum(y_data) <= 0):
        return _error_result(
            "Error: Pie charts require non-negative values with a positive total."
        )

    x_data = list(data.keys())
    # Caller-supplied labels are used only when they match the data one-to-one.
    axis_labels = labels if labels and len(labels) == len(x_data) else x_data

    # --- Plot Generation ---
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))
        if plot_type == "pie":
            ax.pie(y_data, labels=axis_labels, autopct="%1.1f%%", startangle=90)
            ax.axis("equal")
        else:
            if plot_type == "bar":
                bars = ax.bar(axis_labels, y_data)
                for bar, value in zip(bars, y_data):
                    # Put the value just past the bar's tip: above positive bars,
                    # below negative ones (so it is not drawn inside the bar).
                    ax.text(
                        bar.get_x() + bar.get_width() / 2.0,
                        bar.get_height(),
                        f"{value}",
                        ha="center",
                        va="bottom" if value >= 0 else "top",
                    )
            else:  # plot_type == "line"
                ax.plot(axis_labels, y_data, marker="o")
                ax.grid(True, alpha=0.3)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            # Anchor the y-axis at zero only when doing so cannot hide negative values.
            if min(y_data) >= 0:
                ax.set_ylim(bottom=0)
        ax.set_title(title, fontsize=14, fontweight="bold")
        # Act on this figure explicitly rather than on pyplot's implicit
        # "current figure", which some other figure could have replaced.
        fig.tight_layout()

        # --- In-Memory Image Conversion ---
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150)
        img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:
        # Tool boundary: report rendering failures to the agent instead of raising.
        content = f"An unexpected error occurred while generating the plot: {str(e)}"
        return content, {"error": str(e)}
    finally:
        # Release only the figure created here; plt.close('all') would also
        # close figures owned by other code.
        if fig is not None:
            plt.close(fig)

    # --- Return Content and Artifact ---
    content = f"Successfully generated a {plot_type} plot titled '{title}'."
    artifact = {
        "base64_image": img_base64,
        "format": "png",
    }
    return content, artifact