# dealflow-ai / src/tools/chart_generator.py
# Commit 8dcf472 (verified) by PeterBot22:
# feat: DealFlow AI MVP — 3-agent CrewAI due diligence system on HF Spaces
"""
Chart Generator Tool β€” creates financial comparison charts from JSON data.
Used by the Financial Analyst agent.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from crewai.tools import BaseTool
from loguru import logger
from pydantic import Field
class ChartGeneratorTool(BaseTool):
    """CrewAI tool that renders a bar or line chart to a PNG file.

    The agent passes a JSON string describing the chart. The tool validates it,
    draws the chart with matplotlib's non-interactive ``Agg`` backend and writes
    the PNG into ``output_dir``. The result is always a string (a success
    message containing the file path, or a message starting with ``"Error"``)
    so the calling agent can read it directly; the tool never raises.
    """

    name: str = "chart_generator"
    description: str = (
        "Generate a bar chart or line chart from financial data. "
        "Input JSON string with keys: "
        "'chart_type' (bar|line), 'title' (str), 'labels' (list[str]), "
        "'values' (list[float]). "
        "Charts are always saved to the tool's configured output directory. "
        "Returns the file path of the saved PNG chart."
    )
    # Directory charts are written to. Deliberately NOT overridable from the
    # agent's JSON input, so an LLM cannot make the tool write elsewhere.
    output_dir: str = Field(default="./outputs")

    def _run(self, input_json: str) -> str:
        """Render the chart described by *input_json* and return where it was saved.

        Args:
            input_json: JSON object with keys ``chart_type`` ("bar" or "line",
                case-insensitive, default "bar"), ``title`` (default
                "Financial Chart"), ``labels`` and ``values`` (non-empty,
                equal-length arrays; values must be finite numbers). Any
                ``output_dir`` key is ignored. An already-decoded dict is
                accepted too.

        Returns:
            ``"Chart saved to: <path>"`` on success, otherwise a message that
            starts with ``"Error"``.
        """
        if isinstance(input_json, dict):
            data = input_json
        else:
            try:
                data = json.loads(input_json)
            except (json.JSONDecodeError, TypeError) as exc:
                return f"Error: invalid JSON input — {exc}"
        if not isinstance(data, dict):
            return "Error: input must be a JSON object."

        # Validate before touching the filesystem or matplotlib.
        try:
            chart_type, title, labels, values = self._parse_chart_spec(data)
        except ValueError as exc:
            return f"Error: {exc}"

        try:
            # Always use the configured base dir; ignore any LLM-supplied path override.
            out_dir = Path(self.output_dir).resolve()
            out_dir.mkdir(parents=True, exist_ok=True)
            out_path = self._unique_path(out_dir, title)
            self._render_chart(chart_type, title, labels, values, out_path)
        except Exception as exc:  # tool boundary: report to the agent instead of raising
            logger.exception(f"ChartGeneratorTool error: {exc}")
            return f"Error generating chart: {exc}"

        logger.info(f"Chart saved: {out_path}")
        return f"Chart saved to: {out_path}"

    @staticmethod
    def _parse_chart_spec(data: dict) -> tuple[str, str, list[str], list[float]]:
        """Validate the decoded payload; raise ValueError with an agent-facing message."""
        import math

        chart_type = str(data.get("chart_type") or "bar").strip().lower()

        title = data.get("title", "Financial Chart")
        if title is None:
            title = "Financial Chart"
        title = str(title)

        labels = data.get("labels", [])
        values = data.get("values", [])
        if not labels or not values:
            raise ValueError("'labels' and 'values' are required and must be non-empty.")
        if not isinstance(labels, list) or not isinstance(values, list):
            raise ValueError("'labels' and 'values' must be JSON arrays.")
        if len(labels) != len(values):
            raise ValueError(
                f"labels ({len(labels)}) and values ({len(values)}) length mismatch."
            )

        numbers: list[float] = []
        for index, raw in enumerate(values):
            # bool is an int subclass; a true/false "value" is almost certainly a mistake.
            number = None if isinstance(raw, bool) else _to_float(raw)
            if number is None or not math.isfinite(number):
                raise ValueError(f"value at index {index} ({raw!r}) is not a finite number.")
            numbers.append(number)

        # Labels are only used as tick text, so any JSON scalar is acceptable.
        return chart_type, title, [str(label) for label in labels], numbers

    @staticmethod
    def _unique_path(out_dir: Path, title: str) -> Path:
        """Return a PNG path in *out_dir* derived from *title* that does not overwrite an existing file."""
        # Keep only alphanumerics so the title cannot inject path separators.
        stem = "".join(c if c.isalnum() else "_" for c in title)[:40] or "chart"
        candidate = out_dir / f"{stem}.png"
        counter = 1
        while candidate.exists():
            candidate = out_dir / f"{stem}_{counter}.png"
            counter += 1
        return candidate

    @staticmethod
    def _format_value(val: float) -> str:
        """Format a bar label: large values as whole numbers, small ones (e.g. ratios) keep up to 2 decimals."""
        if abs(val) >= 100:
            return f"{val:,.0f}"
        return f"{val:,.2f}".rstrip("0").rstrip(".")

    @staticmethod
    def _render_chart(
        chart_type: str,
        title: str,
        labels: list[str],
        values: list[float],
        out_path: Path,
    ) -> None:
        """Draw the chart with matplotlib and save it to *out_path*; the figure is always closed."""
        import matplotlib

        matplotlib.use("Agg")  # non-interactive backend — safe for servers
        import matplotlib.pyplot as plt

        # Plot against explicit positions rather than the labels themselves:
        # numeric labels would otherwise be used as x coordinates, and duplicate
        # string labels would collapse onto one categorical slot.
        positions = list(range(len(labels)))
        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            if chart_type == "line":
                ax.plot(positions, values, marker="o", linewidth=2, color="#2563EB")
                ax.fill_between(positions, values, alpha=0.1, color="#2563EB")
            else:
                bars = ax.bar(
                    positions, values, color="#2563EB", edgecolor="white", linewidth=0.8
                )
                for bar, val in zip(bars, values):
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        val,
                        ChartGeneratorTool._format_value(val),
                        ha="center",
                        # Put the label outside the bar for negative values too.
                        va="bottom" if val >= 0 else "top",
                        fontsize=8,
                    )
            ax.set_title(title, fontsize=14, fontweight="bold", pad=12)
            ax.set_xticks(positions)
            ax.set_xticklabels(labels, rotation=15, ha="right")
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
            fig.tight_layout()
            fig.savefig(str(out_path), dpi=150, bbox_inches="tight")
        finally:
            # pyplot keeps every open figure alive; close even on failure to avoid leaks.
            plt.close(fig)


def _to_float(raw: object) -> float | None:
    """Convert *raw* to float, returning None when it is not numeric."""
    try:
        return float(raw)
    except (TypeError, ValueError):
        return None