File size: 3,334 Bytes
8dcf472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Chart Generator Tool — creates financial comparison charts from JSON data.
Used by the Financial Analyst agent.
"""
from __future__ import annotations

import json
import os
from pathlib import Path

from crewai.tools import BaseTool
from loguru import logger
from pydantic import Field


class ChartGeneratorTool(BaseTool):
    """CrewAI tool that renders a bar or line chart from JSON and saves it as a PNG.

    The chart is always written under ``output_dir`` (configured on the tool);
    any directory supplied in the agent's input is ignored. The tool never
    raises: it returns ``"Chart saved to: <path>"`` on success and a message
    starting with ``"Error"`` otherwise, so the calling agent can read the
    outcome.
    """

    name: str = "chart_generator"
    # Shown to the LLM. 'output_dir' is deliberately not advertised because
    # _run never honours an agent-supplied directory.
    description: str = (
        "Generate a bar chart or line chart from financial data. "
        "Input JSON string with keys: "
        "'chart_type' (bar|line), 'title' (str), 'labels' (list[str]), "
        "'values' (list[float]). "
        "Returns the file path of the saved PNG chart."
    )

    # Directory charts are written to; resolved to an absolute path and
    # created on demand before saving.
    output_dir: str = Field(default="./outputs")

    def _run(self, input_json: str) -> str:
        """Render the chart described by *input_json* and save it as a PNG.

        Args:
            input_json: JSON object with keys ``chart_type`` ("bar" or "line",
                case-insensitive; anything else is drawn as bars), ``title``,
                ``labels`` and ``values`` (numbers, same length as ``labels``).
                An ``output_dir`` key, if present, is ignored.

        Returns:
            ``"Chart saved to: <path>"`` on success, otherwise a string that
            starts with ``"Error"``. Errors are returned rather than raised so
            the agent can read and react to them.
        """
        try:
            data = json.loads(input_json)
        except json.JSONDecodeError as exc:
            return f"Error: invalid JSON input — {exc}"
        # Valid JSON that is not an object (array, string, number) would
        # otherwise fail later on .get() with an unhelpful message.
        if not isinstance(data, dict):
            return "Error: input JSON must be an object."

        labels = data.get("labels", [])
        values = data.get("values", [])
        if not labels or not values:
            return "Error: 'labels' and 'values' are required and must be non-empty."
        if not isinstance(labels, list) or not isinstance(values, list):
            return "Error: 'labels' and 'values' must be JSON arrays."
        if len(labels) != len(values):
            return f"Error: labels ({len(labels)}) and values ({len(values)}) length mismatch."
        try:
            numbers = [float(v) for v in values]
        except (TypeError, ValueError):
            return "Error: every entry in 'values' must be numeric."

        chart_type = str(data.get("chart_type", "bar")).lower()
        # Fall back to the default for a missing, null or empty title so the
        # derived filename is never empty (which would produce ".png").
        title = str(data.get("title") or "Financial Chart")
        tick_labels = [str(label) for label in labels]

        # Always use the configured base dir; ignore any LLM-supplied path override.
        out_dir = Path(self.output_dir).resolve()
        safe_title = "".join(c if c.isalnum() else "_" for c in title)[:40]
        out_path = out_dir / f"{safe_title}.png"

        fig = None
        try:
            import matplotlib
            matplotlib.use("Agg")  # non-interactive backend — safe for servers
            import matplotlib.pyplot as plt

            out_dir.mkdir(parents=True, exist_ok=True)
            fig, ax = plt.subplots(figsize=(10, 5))
            self._draw_series(ax, chart_type, numbers)

            ax.set_title(title, fontsize=14, fontweight="bold", pad=12)
            ax.set_xticks(range(len(tick_labels)))
            ax.set_xticklabels(tick_labels, rotation=15, ha="right")
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
            fig.tight_layout()
            fig.savefig(str(out_path), dpi=150, bbox_inches="tight")
        except Exception as exc:
            # Tool boundary: report instead of raising, but keep the traceback in the log.
            logger.exception(f"ChartGeneratorTool error: {exc}")
            return f"Error generating chart: {exc}"
        finally:
            # pyplot keeps every open figure alive until closed; close it on
            # the error path too so a long-running process does not leak them.
            if fig is not None:
                plt.close(fig)

        logger.info(f"Chart saved: {out_path}")
        return f"Chart saved to: {out_path}"

    @staticmethod
    def _draw_series(ax, chart_type: str, values: list[float]) -> None:
        """Plot *values* on *ax* at x positions 0..n-1, as a line or as labelled bars."""
        color = "#2563EB"
        # Explicit integer positions (rather than passing the labels as x data)
        # keep every point aligned with the tick labels set by the caller, even
        # when labels are numeric (e.g. years, which matplotlib would place at
        # their numeric value) or repeated (which categorical axes merge).
        positions = range(len(values))
        if chart_type == "line":
            ax.plot(positions, values, marker="o", linewidth=2, color=color)
            ax.fill_between(positions, values, alpha=0.1, color=color)
            return

        bars = ax.bar(positions, values, color=color, edgecolor="white", linewidth=0.8)
        for bar, val in zip(bars, values):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height(),
                f"{val:,.0f}",
                ha="center",
                # Place the label beyond the end of the bar, including for negative bars.
                va="bottom" if val >= 0 else "top",
                fontsize=8,
            )