Spaces:
Running
Running
| # agents/forecast_agent.py | |
| """ | |
| Forecast Agent (robust) | |
| - Accepts a list of invoice states (dicts or InvoiceProcessingState models). | |
| - Produces monthly historical spend and a simple flat forecast (the mean of the historical monthly totals, repeated for each forecast month). | |
| - Performs lightweight anomaly detection. | |
| - Returns a dict containing a Plotly chart and numeric summary. | |
| """ | |
| from typing import List, Dict, Any, Union | |
| from datetime import datetime | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import math | |
| import os | |
| # keep the type import only for hints; we do NOT require reconstructing models | |
| try: | |
| from state import InvoiceProcessingState | |
| except Exception: | |
| InvoiceProcessingState = None # type: ignore | |
class ForecastAgent:
    """Monthly spend forecasting and anomaly detection over invoice states.

    Public methods accept a list whose items are plain dicts or model-like
    objects (anything exposing ``model_dump()``, ``dict()`` or public
    attributes). States are only read as plain dicts; no models are rebuilt.
    """

    def __init__(self):
        pass

    # ---- Internal: conversion helpers ----
    @staticmethod
    def _as_dict(obj: Any) -> Dict[str, Any]:
        """Return a plain-dict view of ``obj`` without constructing any models.

        ``None`` -> ``{}``; dict -> shallow copy; otherwise ``model_dump()``,
        then ``dict()``, then the object's public non-callable attributes.
        Exceptions raised by those accessors propagate to the caller.
        """
        if obj is None:
            return {}
        if isinstance(obj, dict):
            return dict(obj)
        if hasattr(obj, "model_dump"):
            return obj.model_dump(exclude_none=False)
        if hasattr(obj, "dict"):
            return obj.dict()
        return {
            name: getattr(obj, name)
            for name in dir(obj)
            if not name.startswith("_") and not callable(getattr(obj, name))
        }

    @staticmethod
    def _parse_date(value: Any) -> Any:
        """Parse one date-like value into a tz-naive Timestamp, or NaT.

        Parsing element by element (rather than ``pd.to_datetime`` on the whole
        column) keeps rows whose string format differs from the first row's:
        pandas >= 2.0 infers one format from the first element and, with
        ``errors="coerce"``, turns every non-matching row into NaT.
        Timezone-aware values keep their wall-clock time but drop the tzinfo,
        so aware and naive values can live in one datetime64 column.
        """
        try:
            ts = pd.to_datetime(value, errors="coerce")
        except (TypeError, ValueError, OverflowError):
            return pd.NaT
        if not isinstance(ts, pd.Timestamp) or pd.isna(ts):
            return pd.NaT
        if ts.tzinfo is not None:
            ts = ts.tz_localize(None)
        return ts

    def _state_to_row(self, state: Union[dict, object]) -> Dict[str, Any]:
        """Flatten one state into raw (not yet type-coerced) row fields."""
        raw = self._as_dict(state)
        inv = self._as_dict(raw.get("invoice_data"))

        # Risk lives under risk_assessment.risk_score, with a top-level
        # risk_score as fallback. An unreadable risk object counts as "no score".
        risk_score = None
        risk_src = raw.get("risk_assessment")
        if risk_src:
            try:
                risk_score = self._as_dict(risk_src).get("risk_score")
            except Exception:
                risk_score = None
        if risk_score is None:
            risk_score = raw.get("risk_score")

        return {
            "file_name": inv.get("file_name") or raw.get("file_name") or "unknown",
            # Only genuine due dates here; the invoice-date fallback is applied
            # after parsing, so an unparseable due date still falls back.
            "due_date": inv.get("due_date") or raw.get("due_date"),
            "invoice_date": inv.get("invoice_date") or raw.get("invoice_date"),
            "total": inv.get("total") or inv.get("amount") or raw.get("total") or 0.0,
            "vendor": (
                inv.get("customer_name")
                or inv.get("vendor_name")
                or raw.get("vendor")
                or raw.get("customer_name")
                or "Unknown"
            ),
            "risk_score": 0.0 if risk_score is None else risk_score,
            "status": raw.get("overall_status") or inv.get("status") or "unknown",
        }

    # ---- Internal: normalize input states -> DataFrame ----
    def _normalize_states_to_df(self, states: List[Union[dict, object]]) -> pd.DataFrame:
        """Flatten states into a cleaned DataFrame, one row per readable state.

        Columns: ['file_name', 'due_date', 'invoice_date', 'total', 'vendor',
        'risk_score', 'status', 'date'], where ``date`` is the due date falling
        back to the invoice date. Dates are tz-naive datetime64 (NaT when
        missing or unparseable); unparseable ``total`` / ``risk_score`` values
        become 0.0. States that raise while being read are skipped and logged.
        ``None`` or an empty list yields an empty DataFrame.
        """
        import logging

        logger = logging.getLogger(__name__)
        if states is None:
            states = []

        rows = []
        for position, state in enumerate(states):
            try:
                rows.append(self._state_to_row(state))
            except Exception:
                # Best effort: one bad state must not sink the whole report,
                # but leave a trace instead of dropping it silently.
                logger.warning(
                    "ForecastAgent: skipping malformed state at index %d",
                    position,
                    exc_info=True,
                )

        df = pd.DataFrame(rows)
        if df.empty:
            return df

        df["due_date"] = pd.to_datetime(df["due_date"].map(self._parse_date))
        df["invoice_date"] = pd.to_datetime(df["invoice_date"].map(self._parse_date))
        df["date"] = df["due_date"].fillna(df["invoice_date"])
        df["total"] = pd.to_numeric(df["total"], errors="coerce").fillna(0.0)
        df["risk_score"] = pd.to_numeric(df["risk_score"], errors="coerce").fillna(0.0)
        df["vendor"] = df["vendor"].fillna("Unknown")
        return df

    # ---- Public: predict monthly cashflow and return a plotly chart ----
    def predict_cashflow(self, states: List[Union[dict, object]], months: int = 6) -> Dict[str, Any]:
        """Build monthly historical spend plus a flat forecast ``months`` ahead.

        Each state is bucketed by the calendar month of its ``date`` (due date,
        else invoice date); states without a usable date are ignored. The
        forecast repeats the mean of the historical monthly totals, averaged
        over months that contain at least one dated state (months with no
        invoices are not counted as zero).

        Args:
            states: invoice states (dicts or model-like objects).
            months: number of future months to forecast; values below 1 give
                an empty forecast instead of raising.

        Returns:
            ``{"message": "No data to forecast", "chart": None}`` when no state
            has a usable date; otherwise
            {
                "chart": plotly Figure,
                "average_monthly_spend": float (rounded to 2 dp),
                "total_forecast": float (rounded to 2 dp),
                "forecast_values": {"YYYY-MM": float, ...},
                "historical": pandas.Series of monthly totals by month start,
                "forecast_start_month": "YYYY-MM" (None when months < 1),
                "forecast_end_month": "YYYY-MM" (None when months < 1),
            }
        """
        df = self._normalize_states_to_df(states)
        if df.empty or df["date"].dropna().empty:
            return {"message": "No data to forecast", "chart": None}

        # An empty forecast range used to crash on forecast_index[0]; clamp so
        # that months <= 0 simply means "no forecast".
        months = max(int(months), 0)

        # .copy(): a column is added below, so never write into a view of df.
        dated = df.dropna(subset=["date"]).copy()
        dated["month"] = dated["date"].dt.to_period("M").dt.to_timestamp()
        monthly_hist = dated.groupby("month")["total"].sum().sort_index()

        average_month = float(monthly_hist.mean()) if not monthly_hist.empty else 0.0

        # Forecast starts the month after the last historical month.
        last_hist_month = monthly_hist.index.max()
        if pd.isnull(last_hist_month):
            start_month = pd.Timestamp.now().to_period("M").to_timestamp()
        else:
            start_month = (last_hist_month + pd.offsets.MonthBegin(1)).normalize()
        forecast_index = pd.date_range(start=start_month, periods=months, freq="MS")

        # Simple forecast: repeat the historical mean (interpretable and safe).
        forecast_series = pd.Series(
            [average_month] * len(forecast_index),
            index=[month.strftime("%Y-%m") for month in forecast_index],
            dtype=float,
        )

        fig = go.Figure()
        if not monthly_hist.empty:
            fig.add_trace(go.Scatter(
                x=monthly_hist.index,
                y=monthly_hist.values,
                mode="lines+markers",
                name="Historical Spend",
                line=dict(dash="solid"),
            ))
        if len(forecast_index):
            fig.add_trace(go.Scatter(
                x=forecast_index,
                y=forecast_series.values,
                mode="lines+markers",
                name="Forecast",
                line=dict(dash="dash"),
                marker=dict(symbol="circle-open"),
            ))
        fig.update_layout(
            title="Monthly Spend (Historical + Forecast)",
            xaxis_title="Month",
            yaxis_title="Total Spend (USD)",
            hovermode="x unified",
            template="plotly_dark",
        )

        has_forecast = len(forecast_index) > 0
        return {
            "chart": fig,
            "average_monthly_spend": round(average_month, 2),
            "total_forecast": round(float(forecast_series.sum()), 2),
            "forecast_values": forecast_series.to_dict(),
            "historical": monthly_hist,
            "forecast_start_month": forecast_index[0].strftime("%Y-%m") if has_forecast else None,
            "forecast_end_month": forecast_index[-1].strftime("%Y-%m") if has_forecast else None,
        }

    # ---- Public: detect anomalies on sanitized data ----
    def detect_anomalies(
        self,
        states: List[Union[dict, object]],
        *,
        spend_multiplier: float = 2.0,
        risk_threshold: float = 0.7,
    ) -> pd.DataFrame:
        """Flag states whose spend or risk stands out.

        A row is flagged when ``total > spend_multiplier * mean(total)`` or
        ``risk_score >= risk_threshold`` (defaults 2.0 and 0.7). The mean is
        taken over every readable state, including totals coerced to 0.0.

        Returns:
            DataFrame with columns ['file_name', 'invoice_date', 'vendor',
            'total', 'risk_score', 'anomaly_reason'] in input order.
            ``invoice_date`` is the parsed invoice date, falling back to the
            due date when the invoice date is missing. ``anomaly_reason`` is
            "High Spend" when the spend rule fires, else "High Risk". An empty
            DataFrame (no columns) is returned when there is no data or
            nothing is flagged.
        """
        df = self._normalize_states_to_df(states)
        if df.empty:
            return pd.DataFrame()

        spend_limit = df["total"].mean() * spend_multiplier
        high_spend = df["total"] > spend_limit
        high_risk = df["risk_score"] >= risk_threshold
        flagged = df.loc[high_spend | high_risk]
        if flagged.empty:
            return pd.DataFrame()

        anomalies = pd.DataFrame({
            "file_name": flagged["file_name"],
            # Previously this column held `date` (due date first) under an
            # "invoice_date" label; show the real invoice date when known.
            "invoice_date": flagged["invoice_date"].fillna(flagged["date"]),
            "vendor": flagged["vendor"],
            "total": flagged["total"],
            "risk_score": flagged["risk_score"],
            "anomaly_reason": high_spend.loc[flagged.index].map(
                lambda hit: "High Spend" if hit else "High Risk"
            ),
        })
        return anomalies.reset_index(drop=True)