File size: 4,018 Bytes
27e06fc
d4d436a
b2f206d
d4d436a
 
 
 
3fa11a9
d4d436a
 
 
 
 
 
 
 
 
 
3fa11a9
d4d436a
 
 
27e06fc
d4d436a
 
 
 
 
 
 
27e06fc
d4d436a
 
 
3fa11a9
d4d436a
 
 
 
 
2ebfa62
 
d4d436a
 
 
 
 
 
 
2ebfa62
d4d436a
2ebfa62
27e06fc
d4d436a
27e06fc
d4d436a
 
 
27e06fc
 
0a6f3b4
27e06fc
 
d4d436a
 
27e06fc
d4d436a
 
27e06fc
 
0a6f3b4
27e06fc
d4d436a
27e06fc
d4d436a
 
 
 
27e06fc
d4d436a
 
 
 
 
27e06fc
d4d436a
 
 
 
 
 
27e06fc
 
d4d436a
 
 
 
 
 
 
 
 
b2f206d
d4d436a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import base64
import io
import json
import math
from typing import Dict, List, Literal, Optional, Tuple

import matplotlib.pyplot as plt
from langchain_core.tools import tool

# response_format="content_and_artifact": the tool returns a (content, artifact)
# tuple -- a short text message plus a dict carrying the rendered image or an error.
@tool(response_format="content_and_artifact")
def generate_plot(
    data: Dict[str, float],
    plot_type: Literal["bar", "line", "pie"],
    title: str = "Generated Plot",
    labels: Optional[List[str]] = None,
    x_label: str = "",
    y_label: str = "",
) -> Tuple[str, Dict[str, str]]:
    """
    Generates a plot (bar, line, or pie) from a dictionary of data and returns it
    as a base64 encoded PNG image artifact.

    Args:
        data (Dict[str, float]): A dictionary where keys are labels and values are the numeric data to plot.
            Values must be finite numbers; pie charts additionally require non-negative values that are not all zero.
        plot_type (Literal["bar", "line", "pie"]): The type of plot to generate.
        title (str): The title for the plot.
        labels (List[str]): Optional list of labels to use for the x-axis or pie slices. Used only when its
            length matches the number of data points; otherwise the data keys are used.
        x_label (str): The label for the x-axis (for bar and line charts).
        y_label (str): The label for the y-axis (for bar and line charts).

    Returns:
        A tuple containing:
        - A string message describing the outcome (success or error).
        - A dictionary artifact: on success {"base64_image": <str>, "format": "png"},
          on failure {"error": <message>}.
    """

    def _error(message: str) -> Tuple[str, Dict[str, str]]:
        # Validation failures use the same (content, artifact) shape as success.
        return message, {"error": message}

    # --- Input Validation ---
    if not isinstance(data, dict) or not data:
        return _error("Error: Data must be a non-empty dictionary.")

    try:
        y_data = [float(val) for val in data.values()]
    except (ValueError, TypeError):
        return _error("Error: All data values must be numeric.")

    # Reject unknown plot types before any figure exists, so nothing needs cleanup.
    if plot_type not in ("bar", "line", "pie"):
        return _error(f"Error: Invalid plot_type '{plot_type}'. Choose 'bar', 'line', or 'pie'.")

    # float('nan') / float('inf') survive the float() conversion but cannot be drawn meaningfully.
    if not all(math.isfinite(val) for val in y_data):
        return _error("Error: All data values must be finite numbers.")

    # Matplotlib rejects negative wedge sizes, and an all-zero pie has no proportions to show.
    if plot_type == "pie" and (any(val < 0 for val in y_data) or not any(y_data)):
        return _error("Error: Pie chart values must be non-negative and not all zero.")

    x_data = list(data.keys())
    # Caller-supplied labels are used only when they line up one-to-one with the data.
    axis_labels = labels if labels and len(labels) == len(x_data) else x_data

    # --- Plot Generation ---
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))

        if plot_type == "pie":
            ax.pie(y_data, labels=axis_labels, autopct='%1.1f%%', startangle=90)
            ax.axis('equal')
        else:
            # Plot against explicit integer positions instead of letting matplotlib build
            # a categorical axis from the label strings: a categorical axis maps duplicate
            # labels to the same x position, so repeated labels would overlap.
            positions = list(range(len(y_data)))
            if plot_type == "bar":
                bars = ax.bar(positions, y_data, tick_label=axis_labels)
                for bar, value in zip(bars, y_data):
                    # Place the value just past the end of the bar: above positive bars,
                    # below negative ones (where get_height() is negative).
                    ax.text(
                        bar.get_x() + bar.get_width() / 2.0,
                        bar.get_height(),
                        f'{value}',
                        ha='center',
                        va='bottom' if value >= 0 else 'top',
                    )
            else:  # plot_type == "line"
                ax.plot(positions, y_data, marker='o')
                ax.set_xticks(positions)
                ax.set_xticklabels(axis_labels)
                ax.grid(True, alpha=0.3)

            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            # Anchor the y-axis at zero only when no value is negative; forcing
            # bottom=0 otherwise would clip negative data out of view.
            if min(y_data) >= 0:
                ax.set_ylim(bottom=0)

        ax.set_title(title, fontsize=14, fontweight='bold')
        # Use this figure's own methods rather than pyplot's global "current figure".
        fig.tight_layout()

        # --- In-Memory Image Conversion ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')

    except Exception as e:
        # Tool boundary: report the failure as a result instead of raising.
        content = f"An unexpected error occurred while generating the plot: {str(e)}"
        artifact = {"error": str(e)}
        return content, artifact

    finally:
        # Close only the figure this call created. pyplot retains every open figure
        # until it is closed, so skipping this leaks memory across calls.
        if fig is not None:
            plt.close(fig)

    # --- Return Content and Artifact ---
    content = f"Successfully generated a {plot_type} plot titled '{title}'."
    artifact = {
        "base64_image": img_base64,
        "format": "png",
    }
    return content, artifact