"""Optional matplotlib export for plot actions (declarations → real PNG files). The environment normally only *records* plot intent for grading. When ``AUTODATALAB_PLOT_DIR`` is set, the server can call :func:`save_plot_to_png` after a ``plot`` action. Install: ``pip install -e ".[plot]"`` (adds matplotlib). """ from __future__ import annotations import argparse import sys from pathlib import Path from typing import Literal, Optional import pandas as pd PlotKind = Literal["scatter", "bar", "histogram"] def _point_label_column(df: pd.DataFrame, x: str, y: str) -> Optional[str]: """Prefer a label column for annotating scatter points when not used as an axis.""" for c in ("Name", "name", "Product", "product"): if c in df.columns and c not in (x, y): return c return None def _series_numeric_or_datetime(s: pd.Series) -> pd.Series: """Use numeric values when possible; otherwise parse datetimes (e.g. ``OrderDate`` strings).""" num = pd.to_numeric(s, errors="coerce") if num.notna().any(): return num dt = pd.to_datetime(s, errors="coerce", utc=False) if dt.notna().any(): return dt return num def save_plot_to_png( df: pd.DataFrame, plot_type: Optional[str], x: Optional[str], y: Optional[str], out_path: Path, *, title: str = "", ) -> None: """Render a simple figure from the current table and write *out_path* (``.png``).""" import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt out_path = Path(out_path) out_path.parent.mkdir(parents=True, exist_ok=True) pt = (plot_type or "scatter").lower() w, h = (8.5, 5.2) if pt == "scatter" else (7.6, 4.8) fig, ax = plt.subplots(figsize=(w, h)) title = title or f"{pt}: {x!r} vs {y!r}" ax.set_facecolor("#f8fafc") fig.patch.set_facecolor("white") ax.grid(axis="y", color="#cbd5e1", linewidth=0.8, alpha=0.55) for spine in ("top", "right"): ax.spines[spine].set_visible(False) ax.spines["left"].set_color("#94a3b8") ax.spines["bottom"].set_color("#94a3b8") if pt == "scatter": if not x or not y or x not in df.columns or y not in df.columns: raise ValueError(f"scatter requires valid x,y columns; got x={x!r} y={y!r}") xs = _series_numeric_or_datetime(df[x]) ys = pd.to_numeric(df[y], errors="coerce") view = pd.DataFrame({"_x": xs, "_y": ys}).dropna().copy() if pd.api.types.is_datetime64_any_dtype(view["_x"]): view = view.groupby("_x", as_index=False)["_y"].sum().sort_values("_x") ax.plot(view["_x"], view["_y"], color="#93c5fd", linewidth=1.4, zorder=1) ax.scatter(view["_x"], view["_y"], s=42, color="#2563eb", edgecolors="white", linewidths=0.6, alpha=0.92, zorder=2) ax.set_xlabel(x) ax.set_ylabel(y) label_col = _point_label_column(df, x, y) if label_col is not None: for i in range(len(df)): lab = df[label_col].iloc[i] if pd.isna(lab) or (pd.isna(xs.iloc[i]) and pd.isna(ys.iloc[i])): continue if pd.isna(xs.iloc[i]) or pd.isna(ys.iloc[i]): continue ax.annotate( str(lab), (xs.iloc[i], float(ys.iloc[i])), fontsize=7, alpha=0.78, xytext=(4, 4), textcoords="offset points", zorder=3, ) ax.set_title(f"{title} (labels: {label_col})") else: ax.set_title(title) elif pt == "bar": if not x or x not in df.columns: raise ValueError(f"bar requires valid column x={x!r}") if y and y in df.columns: # Category vs sales / revenue: aggregate numeric y per category on x vals = pd.to_numeric(df[y], errors="coerce") g = df.assign(_y=vals).groupby(x, dropna=False, sort=True)["_y"].sum() g = g.dropna(how="all") g = g.sort_values(ascending=False).head(20) g.plot(kind="bar", ax=ax, color="#2563eb", edgecolor="#1e3a8a", width=0.72) ax.set_ylabel(y) else: s = df[x].value_counts().head(20) s.plot(kind="bar", ax=ax, color="#2563eb", edgecolor="#1e3a8a", width=0.72) ax.set_xlabel(x) ax.tick_params(axis="x", rotation=25) elif pt == "histogram": col = x or y if not col or col not in df.columns: raise ValueError(f"histogram requires a column; got x={x!r} y={y!r}") ax.hist( pd.to_numeric(df[col], errors="coerce").dropna(), bins=min(20, max(5, len(df))), color="#2563eb", edgecolor="white", linewidth=0.8, ) ax.set_xlabel(col) else: raise ValueError(f"unsupported plot_type: {plot_type!r}") if pt != "scatter": ax.set_title(title) fig.tight_layout() fig.savefig(out_path, dpi=120, bbox_inches="tight") plt.close(fig) def main(argv: Optional[list[str]] = None) -> int: """CLI: render a CSV + plot spec to PNG (for agent pipelines / debugging).""" p = argparse.ArgumentParser( description="Render a plot from a CSV file (optional artifact export for AutoDataLab)." ) p.add_argument("csv", type=Path, help="Path to CSV (same shape as env working table)") p.add_argument("plot_type", choices=("scatter", "bar", "histogram")) p.add_argument("x", help="X column (or primary column for histogram)") p.add_argument("y", nargs="?", default=None, help="Y column (scatter only)") p.add_argument("-o", "--output", type=Path, default=Path("plot_out.png")) args = p.parse_args(argv) df = pd.read_csv(args.csv) try: save_plot_to_png(df, args.plot_type, args.x, args.y, args.output) except ImportError: print("matplotlib is required: pip install matplotlib", file=sys.stderr) return 1 except Exception as e: print(f"error: {e}", file=sys.stderr) return 1 print(args.output) return 0 def _entry() -> None: raise SystemExit(main()) if __name__ == "__main__": raise SystemExit(main())