File size: 2,414 Bytes
bd97fec
 
 
 
 
 
 
 
 
 
 
 
 
 
927e380
bd97fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import re
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
import seaborn as sns
sns.set_theme(style="whitegrid", palette="muted")
from typing import List
from pathlib import Path
from smolagents import tool

# Output folder for rendered chart PNGs, located next to this module.
# It is created at import time if it does not already exist.
CHARTS_DIR = Path(__file__).parent.joinpath("charts")
CHARTS_DIR.mkdir(exist_ok=True)

# Shared, mutable module state. generate_chart records the path of the most
# recently saved PNG under the "path" key (None until a chart is written).
LAST_CHART: dict = dict(path=None)
# Module-level list; nothing in this section writes to it — presumably it is
# filled by agent-step callbacks elsewhere (TODO confirm).
AGENT_STEPS: list = list()


@tool
def generate_chart(
    labels: List[str],
    values: List[float],
    chart_type: str,
    title: str,
    xlabel: str = "",
    ylabel: str = "",
) -> str:
    """
    Generate a bar, pie, or line chart from data and save it as a PNG file.
    Call this whenever the data contains a distribution, ranking, comparison,
    count breakdown, or proportion that would be clearer as a visual.
    The labels and values lists must have the same length.

    Args:
        labels: List of category labels (strings).
        values: List of numeric values, one per label (same length as labels).
        chart_type: 'bar' for counts/rankings/comparisons, 'pie' for proportions, 'line' for trends (case-insensitive; any other value draws a bar chart).
        title: Title displayed on the chart.
        xlabel: X-axis label (bar / line only).
        ylabel: Y-axis label (bar / line only).

    Returns:
        A confirmation message ending with the file path of the saved PNG image.
    """
    import math  # local import keeps the tool's source self-contained

    # Fail fast with a clear message rather than an opaque error raised deep
    # inside seaborn/matplotlib.
    if len(labels) != len(values):
        raise ValueError(
            "labels and values must have the same length "
            f"(got {len(labels)} labels and {len(values)} values)"
        )

    kind = str(chart_type).strip().lower()

    # Build a filesystem-safe name from the (truncated) title. Every whitespace
    # character becomes "_" so tabs/newlines never reach the filename, and a
    # title with no usable characters still gets a readable name.
    safe_title = re.sub(r"[^\w\s-]", "", title[:40]).strip()
    safe_title = re.sub(r"\s", "_", safe_title) or "untitled"
    filename = str(CHARTS_DIR / f"chart_{safe_title}.png")

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        if kind == "pie":
            ax.pie(values, labels=labels, autopct="%1.1f%%", startangle=140,
                   colors=sns.color_palette("muted", len(labels)))
            ax.axis("equal")
        elif kind == "line":
            sns.lineplot(x=labels, y=values, marker="o", linewidth=2.5, ax=ax)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            # Operate on this figure's axes, not pyplot's "current" axes.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        else:
            sns.barplot(x=labels, y=values, hue=labels, palette="muted",
                        legend=False, ax=ax)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            for p in ax.patches:
                height = p.get_height()
                # Skip non-finite heights (e.g. NaN placeholder bars) so no
                # "nan" label is drawn.
                if not math.isfinite(height):
                    continue
                # Whole numbers print without decimals; fractional values keep
                # two so e.g. 0.45 is not rounded down to "0".
                if float(height).is_integer():
                    text = f"{height:.0f}"
                else:
                    text = f"{height:.2f}"
                ax.annotate(text,
                            (p.get_x() + p.get_width() / 2, height),
                            ha="center",
                            # Place labels outside the bar for negative values too.
                            va="bottom" if height >= 0 else "top",
                            fontsize=9)

        ax.set_title(title, fontsize=13, fontweight="bold", pad=12)
        fig.tight_layout()
        fig.savefig(filename, dpi=150)
    finally:
        # Always release this specific figure, even when plotting fails, so
        # repeated tool calls do not accumulate open figures.
        plt.close(fig)

    LAST_CHART["path"] = filename
    return f"Chart saved → {filename}"