# graph_tool.py
import base64
import io
import json
import math
from typing import Dict, List, Literal, Optional, Tuple

import matplotlib.pyplot as plt
from langchain_core.tools import tool
from matplotlib.figure import Figure

# Plot types accepted by generate_plot; mirrors its Literal annotation.
_PLOT_TYPES = ("bar", "line", "pie")


def _error_result(message: str) -> Tuple[str, Dict[str, str]]:
    """Return the (content, artifact) pair used to report a validation error."""
    return message, {"error": message}


def _resolve_labels(labels: Optional[List[str]], keys: List[str]) -> List[str]:
    """Use ``labels`` when it has exactly one entry per data point, else ``keys``."""
    return labels if labels and len(labels) == len(keys) else keys


def _set_category_ticks(ax, tick_labels: List[str], count: int) -> List[int]:
    """Put ``tick_labels`` on integer x positions 0..count-1 and return the positions.

    Plotting against integer positions (instead of passing the labels as
    categorical x values) keeps duplicate labels as separate data points
    rather than merging them into a single category.
    """
    positions = list(range(count))
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    return positions


def _draw_bar(ax, tick_labels: List[str], values: List[float], x_label: str, y_label: str) -> None:
    """Draw a bar chart on matplotlib Axes ``ax``, annotating each bar with its value."""
    positions = _set_category_ticks(ax, tick_labels, len(values))
    bars = ax.bar(positions, values)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    # Pin the baseline at zero only when nothing is negative; otherwise negative bars would be cut off.
    if min(values) >= 0:
        ax.set_ylim(bottom=0)
    for bar, value in zip(bars, values):
        # Annotate at the outer end of each bar: above positive bars, below negative ones.
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            bar.get_height(),
            f"{value}",
            ha="center",
            va="bottom" if value >= 0 else "top",
        )


def _draw_line(ax, tick_labels: List[str], values: List[float], x_label: str, y_label: str) -> None:
    """Draw a line chart with point markers and a light grid on matplotlib Axes ``ax``."""
    positions = _set_category_ticks(ax, tick_labels, len(values))
    ax.plot(positions, values, marker="o")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    # Same rule as the bar chart: never clip negative values by forcing a zero baseline.
    if min(values) >= 0:
        ax.set_ylim(bottom=0)
    ax.grid(True, alpha=0.3)


def _draw_pie(ax, slice_labels: List[str], values: List[float]) -> None:
    """Draw a pie chart with percentage labels on matplotlib Axes ``ax``."""
    ax.pie(values, labels=slice_labels, autopct="%1.1f%%", startangle=90)
    ax.axis("equal")  # keep the pie circular


# Use the @tool decorator and specify the "content_and_artifact" response format.
@tool(response_format="content_and_artifact")
def generate_plot(
    data: Dict[str, float],
    plot_type: Literal["bar", "line", "pie"],
    title: str = "Generated Plot",
    labels: Optional[List[str]] = None,
    x_label: str = "",
    y_label: str = "",
) -> Tuple[str, Dict[str, str]]:
    """
    Generates a plot (bar, line, or pie) from a dictionary of data and returns it
    as a base64 encoded PNG image artifact.

    Args:
        data (Dict[str, float]): A dictionary where keys are labels and values are the
            numeric data to plot. Values must be finite numbers; bar and line charts
            accept negative values, pie charts require non-negative values with a
            positive total.
        plot_type (Literal["bar", "line", "pie"]): The type of plot to generate.
        title (str): The title for the plot.
        labels (List[str]): Optional list of labels to use for the x-axis or pie slices.
            Used only when it has exactly one entry per data point; otherwise the data
            keys are used.
        x_label (str): The label for the x-axis (for bar and line charts).
        y_label (str): The label for the y-axis (for bar and line charts).

    Returns:
        A tuple containing:
        - A string message confirming the plot was generated (or describing an error).
        - A dictionary artifact with the base64 encoded image string and its format,
          or {"error": ...} on failure.
    """
    # --- Input Validation (all checks run before any figure is created) ---
    if not isinstance(data, dict) or not data:
        return _error_result("Error: Data must be a non-empty dictionary.")

    try:
        y_data = [float(val) for val in data.values()]
    except (ValueError, TypeError):
        return _error_result("Error: All data values must be numeric.")

    # float() accepts "nan"/"inf"; such values cannot be plotted meaningfully.
    if not all(math.isfinite(value) for value in y_data):
        return _error_result("Error: All data values must be finite numbers.")

    if plot_type not in _PLOT_TYPES:
        return _error_result(
            f"Error: Invalid plot_type '{plot_type}'. Choose 'bar', 'line', or 'pie'."
        )

    if plot_type == "pie" and (min(y_data) < 0 or sum(y_data) <= 0):
        return _error_result(
            "Error: Pie charts require non-negative values with a positive total."
        )

    tick_labels = _resolve_labels(labels, list(data.keys()))

    # --- Plot Generation ---
    try:
        # Build the Figure directly rather than through pyplot: pyplot's global
        # figure registry is shared, not thread-safe, and keeps figures alive
        # until explicitly closed. A standalone Figure is garbage-collected.
        fig = Figure(figsize=(10, 6))
        ax = fig.subplots()

        if plot_type == "bar":
            _draw_bar(ax, tick_labels, y_data, x_label, y_label)
        elif plot_type == "line":
            _draw_line(ax, tick_labels, y_data, x_label, y_label)
        else:  # "pie" — plot_type was validated above
            _draw_pie(ax, tick_labels, y_data)

        ax.set_title(title, fontsize=14, fontweight="bold")
        fig.tight_layout()

        # --- In-Memory Image Conversion ---
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150)
        img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        content = f"An unexpected error occurred while generating the plot: {str(e)}"
        artifact = {"error": str(e)}
        return content, artifact

    # --- Return Content and Artifact ---
    content = f"Successfully generated a {plot_type} plot titled '{title}'."
    artifact = {
        "base64_image": img_base64,
        "format": "png",
    }
    return content, artifact