# Mimir / graph_tool.py
# Last commit d4d436a (verified) by jdesiree: "Redesigned for LangGraph Agentic workflow".
# (Header recovered from the hosting page's file view; original file size 4.02 kB.)
import base64
import io
import json
import math
from typing import Dict, List, Literal, Optional, Tuple

import matplotlib.pyplot as plt
from langchain_core.tools import tool
def _error_result(message: str) -> Tuple[str, Dict[str, str]]:
    """Build the (content, artifact) pair returned to the agent for a user-facing error."""
    return message, {"error": message}


def _resolve_labels(labels: Optional[List[str]], fallback: List[str]) -> List[str]:
    """Return ``labels`` (as strings) when it matches ``fallback`` in length, else ``fallback``.

    A missing, empty, or wrong-length ``labels`` list silently falls back to the
    data keys, as the tool's docstring promises.
    """
    if labels and len(labels) == len(fallback):
        return [str(label) for label in labels]
    return list(fallback)


def _figure_to_base64(fig) -> str:
    """Render ``fig`` to an in-memory PNG (150 dpi) and return it base64-encoded as text."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=150)
        return base64.b64encode(buf.getvalue()).decode("utf-8")


# Use the @tool decorator and specify the "content_and_artifact" response format.
@tool(response_format="content_and_artifact")
def generate_plot(
    data: Dict[str, float],
    plot_type: Literal["bar", "line", "pie"],
    title: str = "Generated Plot",
    labels: Optional[List[str]] = None,
    x_label: str = "",
    y_label: str = "",
) -> Tuple[str, Dict[str, str]]:
    """
    Generates a plot (bar, line, or pie) from a dictionary of data and returns it
    as a base64 encoded PNG image artifact.

    Args:
        data (Dict[str, float]): A dictionary where keys are labels and values are the numeric data to plot.
            Values must be finite numbers; pie charts also require non-negative values with a positive total.
        plot_type (Literal["bar", "line", "pie"]): The type of plot to generate.
        title (str): The title for the plot.
        labels (List[str]): Optional list of labels to use for the x-axis or pie slices.
            If not provided, or if its length differs from the number of data points, data keys are used.
        x_label (str): The label for the x-axis (for bar and line charts).
        y_label (str): The label for the y-axis (for bar and line charts).

    Returns:
        A tuple containing:
        - A string message confirming the plot was generated, or describing the error.
        - A dictionary artifact with the base64 encoded image string and its format
          ({"base64_image": ..., "format": "png"}), or {"error": ...} on failure.
    """
    # --- Input Validation ---
    # Every check runs before a figure exists, so no early return can leak one.
    if not isinstance(data, dict) or not data:
        return _error_result("Error: Data must be a non-empty dictionary.")
    try:
        y_data = [float(val) for val in data.values()]
    except (ValueError, TypeError):
        return _error_result("Error: All data values must be numeric.")
    if not all(math.isfinite(val) for val in y_data):
        return _error_result("Error: All data values must be finite numbers (no NaN or infinity).")
    if plot_type not in ("bar", "line", "pie"):
        return _error_result(
            f"Error: Invalid plot_type '{plot_type}'. Choose 'bar', 'line', or 'pie'."
        )
    if plot_type == "pie" and (min(y_data) < 0 or sum(y_data) <= 0):
        return _error_result(
            "Error: Pie chart values must be non-negative and sum to a positive total."
        )

    tick_labels = _resolve_labels(labels, [str(key) for key in data.keys()])

    # --- Plot Generation ---
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))
        if plot_type == "pie":
            ax.pie(y_data, labels=tick_labels, autopct="%1.1f%%", startangle=90)
            ax.axis("equal")
        else:
            # Plot at explicit integer positions and attach labels as ticks, so
            # duplicate labels stay separate points instead of being merged by
            # matplotlib's categorical axis.
            positions = list(range(len(y_data)))
            if plot_type == "bar":
                bars = ax.bar(positions, y_data)
                for bar, value in zip(bars, y_data):
                    # Label the outer end of each bar: above positive bars, below negative ones.
                    ax.text(
                        bar.get_x() + bar.get_width() / 2.0,
                        bar.get_height(),
                        f"{value}",
                        ha="center",
                        va="bottom" if value >= 0 else "top",
                    )
            else:
                ax.plot(positions, y_data, marker="o")
                ax.grid(True, alpha=0.3)
            ax.set_xticks(positions)
            ax.set_xticklabels(tick_labels)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            # Anchor the axis at zero only when that cannot hide negative values.
            if min(y_data) >= 0:
                ax.set_ylim(bottom=0)

        ax.set_title(title, fontsize=14, fontweight="bold")
        # Operate on this figure explicitly rather than pyplot's "current figure",
        # which another concurrent call may have changed.
        fig.tight_layout()

        # --- In-Memory Image Conversion ---
        img_base64 = _figure_to_base64(fig)
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        content = f"An unexpected error occurred while generating the plot: {str(e)}"
        return content, {"error": str(e)}
    finally:
        # Close only the figure this call created; plt.close('all') would also
        # destroy figures other callers are still drawing.
        if fig is not None:
            plt.close(fig)

    # --- Return Content and Artifact ---
    content = f"Successfully generated a {plot_type} plot titled '{title}'."
    artifact = {
        "base64_image": img_base64,
        "format": "png",
    }
    return content, artifact