File size: 4,034 Bytes
4846644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#graph_tool.py

import base64
import io
import json
import math
from typing import Dict, List, Literal, Optional, Tuple

import matplotlib.pyplot as plt
from langchain_core.tools import tool

# Use the @tool decorator and specify the "content_and_artifact" response format:
# every return path below yields a (content, artifact) pair -- a short message
# plus a dict payload (the encoded image on success, an "error" key on failure).
# NOTE(review): @tool presumably uses the docstring below as the tool description
# shown to the model, so implementation notes are kept in '#' comments instead.
@tool(response_format="content_and_artifact")
def generate_plot(
    data: Dict[str, float],
    plot_type: Literal["bar", "line", "pie"],
    title: str = "Generated Plot",
    labels: Optional[List[str]] = None,
    x_label: str = "",
    y_label: str = "",
) -> Tuple[str, Dict[str, str]]:
    """
    Generates a plot (bar, line, or pie) from a dictionary of data and returns it
    as a base64 encoded PNG image artifact.

    Args:
        data (Dict[str, float]): A dictionary where keys are labels and values are the numeric data to plot. Values must be finite numbers; pie charts additionally require non-negative values with a positive total.
        plot_type (Literal["bar", "line", "pie"]): The type of plot to generate.
        title (str): The title for the plot.
        labels (List[str]): Optional list of labels to use for the x-axis or pie slices. If not provided, or if its length does not match the data, data keys are used.
        x_label (str): The label for the x-axis (for bar and line charts).
        y_label (str): The label for the y-axis (for bar and line charts).

    Returns:
        A tuple containing:
        - A string message confirming the plot was generated, or describing the error.
        - A dictionary artifact with the base64 encoded image string and its format, or an "error" entry on failure.
    """
    # --- Input Validation (done before any figure is allocated) ---
    if not isinstance(data, dict) or not data:
        content = "Error: Data must be a non-empty dictionary."
        return content, {"error": content}

    try:
        y_data = [float(val) for val in data.values()]
    except (ValueError, TypeError):
        content = "Error: All data values must be numeric."
        return content, {"error": content}

    # float() happily accepts "nan" / "inf", which cannot be plotted meaningfully.
    if not all(math.isfinite(val) for val in y_data):
        content = "Error: All data values must be finite numbers."
        return content, {"error": content}

    if plot_type not in ("bar", "line", "pie"):
        content = f"Error: Invalid plot_type '{plot_type}'. Choose 'bar', 'line', or 'pie'."
        return content, {"error": content}

    # Pie wedges are proportions of the total: negatives or a zero total are meaningless.
    if plot_type == "pie" and (min(y_data) < 0 or sum(y_data) <= 0):
        content = "Error: Pie charts require non-negative values with a positive total."
        return content, {"error": content}

    # Keys are stringified so they always act as labels, never as numeric x positions.
    x_data = [str(key) for key in data]
    # Caller-supplied labels override the keys only when there is exactly one per value.
    if labels and len(labels) == len(x_data):
        tick_labels = [str(label) for label in labels]
    else:
        tick_labels = x_data

    # --- Plot Generation ---
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))

        if plot_type == "pie":
            ax.pie(y_data, labels=tick_labels, autopct="%1.1f%%", startangle=90)
            ax.axis("equal")
        else:
            # Plot at integer positions and attach labels as ticks, so duplicate
            # labels still get their own bar/point instead of collapsing together.
            positions = list(range(len(y_data)))
            if plot_type == "bar":
                bars = ax.bar(positions, y_data)
                for bar, value in zip(bars, y_data):
                    # Put each value just beyond the bar's end: above positive bars,
                    # below negative ones.
                    ax.text(
                        bar.get_x() + bar.get_width() / 2.0,
                        value,
                        f"{value}",
                        ha="center",
                        va="bottom" if value >= 0 else "top",
                    )
            else:
                ax.plot(positions, y_data, marker="o")
                ax.grid(True, alpha=0.3)
            ax.set_xticks(positions)
            ax.set_xticklabels(tick_labels)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            # Anchor the y-axis at zero only when that cannot hide negative data.
            if min(y_data) >= 0:
                ax.set_ylim(bottom=0)

        ax.set_title(title, fontsize=14, fontweight="bold")
        # Operate on this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()

        # --- In-Memory Image Conversion ---
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=150)
            img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        content = f"An unexpected error occurred while generating the plot: {str(e)}"
        return content, {"error": str(e)}
    finally:
        # Close only the figure this call created; pyplot otherwise keeps it alive.
        if fig is not None:
            plt.close(fig)

    # --- Return Content and Artifact ---
    content = f"Successfully generated a {plot_type} plot titled '{title}'."
    artifact = {
        "base64_image": img_base64,
        "format": "png",
    }
    return content, artifact