# InvoiceAgenticAI / agents / forecast_agent.py
# PARTHASAKHAPAUL
# Restructure project and sync local as source of truth
# 9107a5d
# agents/forecast_agent.py
"""
Forecast Agent (robust)
- Accepts a list of invoice states (dicts or InvoiceProcessingState models).
- Produces monthly historical spend and a simple forecast (the historical monthly mean repeated for each future month).
- Performs lightweight anomaly detection.
- Returns a dict containing a Plotly chart and numeric summary.
"""
import logging
import math
import os
from datetime import datetime
from typing import List, Dict, Any, Union

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# keep the type import only for hints; we do NOT require reconstructing models
try:
    from state import InvoiceProcessingState
except Exception:
    InvoiceProcessingState = None  # type: ignore
class ForecastAgent:
    """Forecast monthly spend and flag anomalous invoices.

    Every public method accepts a list of invoice states. A state may be a
    plain dict or a model-like object exposing ``model_dump()``, ``dict()``
    or plain public attributes; models are never reconstructed.
    """

    def __init__(self):
        """Create a stateless agent; there is nothing to configure."""
        pass

    # ---- Internal: object -> plain dict helpers ----
    @staticmethod
    def _to_plain_dict(obj: Any) -> Dict[str, Any]:
        """Return a shallow plain-dict view of ``obj``.

        ``None`` becomes ``{}``; dicts are shallow-copied; model-like objects
        are converted via ``model_dump()``, then ``dict()``, and finally by
        reading every public, non-callable attribute.
        """
        if obj is None:
            return {}
        if isinstance(obj, dict):
            return dict(obj)
        if hasattr(obj, "model_dump"):
            return obj.model_dump(exclude_none=False)
        if hasattr(obj, "dict"):
            return obj.dict()
        # best effort: read each public attribute once and keep the data ones
        plain: Dict[str, Any] = {}
        for attr in dir(obj):
            if attr.startswith("_"):
                continue
            value = getattr(obj, attr)
            if not callable(value):
                plain[attr] = value
        return plain

    @staticmethod
    def _extract_risk_score(risk_src: Any) -> Any:
        """Read ``risk_score`` from a dict or model-like risk assessment."""
        if isinstance(risk_src, dict):
            return risk_src.get("risk_score") or 0.0
        if hasattr(risk_src, "model_dump"):
            try:
                return risk_src.model_dump().get("risk_score", 0.0)
            except Exception:
                return 0.0
        return getattr(risk_src, "risk_score", 0.0)

    def _state_to_row(self, state: Any) -> Dict[str, Any]:
        """Flatten one invoice state into the row dict used by the DataFrame."""
        raw = self._to_plain_dict(state)
        # invoice_data may be None, a dict, or a model-like object
        inv = self._to_plain_dict(raw.get("invoice_data"))

        total = inv.get("total") or inv.get("amount") or raw.get("total") or 0.0
        # risk lives under risk_assessment (dict or model-like)
        risk_score = self._extract_risk_score(raw.get("risk_assessment") or {})
        # dates: prefer due_date, then invoice_date; values may be str or datetime
        due = (
            inv.get("due_date")
            or inv.get("invoice_date")
            or raw.get("due_date")
            or raw.get("invoice_date")
        )
        vendor = (
            inv.get("customer_name")
            or inv.get("vendor_name")
            or raw.get("vendor")
            or raw.get("customer_name")
            or "Unknown"
        )
        return {
            "file_name": inv.get("file_name") or raw.get("file_name") or "unknown",
            "due_date": due,
            "invoice_date": inv.get("invoice_date") or raw.get("invoice_date"),
            "total": total,
            "vendor": vendor,
            "risk_score": risk_score,
            "status": raw.get("overall_status") or inv.get("status") or "unknown",
        }

    # ---- Internal: normalize input states -> DataFrame ----
    def _normalize_states_to_df(self, states: List[Union[dict, object]]) -> pd.DataFrame:
        """
        Accepts list of dicts or model instances (``None`` is treated as empty).
        Produces a cleaned DataFrame with columns:
        ['file_name','due_date','invoice_date','total','vendor','risk_score','status','date']
        where 'date' is due_date, falling back to invoice_date.
        Returns an empty, column-less DataFrame when no state could be read.
        States that cannot be read are skipped and logged as warnings.
        """
        if states is None:
            states = []
        rows = []
        for state in states:
            try:
                rows.append(self._state_to_row(state))
            except Exception:
                # best effort: one bad state must not abort the whole batch,
                # but leave a trace so missing invoices can be diagnosed
                logging.getLogger(__name__).warning(
                    "Skipping malformed invoice state of type %s",
                    type(state).__name__,
                    exc_info=True,
                )
        df = pd.DataFrame(rows)
        if df.empty:
            return df
        # coerce and normalize; unparseable values become NaT / 0.0
        df["due_date"] = pd.to_datetime(df["due_date"], errors="coerce")
        df["invoice_date"] = pd.to_datetime(df["invoice_date"], errors="coerce")
        # if due_date missing, fallback to invoice_date
        df["date"] = df["due_date"].fillna(df["invoice_date"])
        df["total"] = pd.to_numeric(df["total"], errors="coerce").fillna(0.0)
        df["risk_score"] = pd.to_numeric(df["risk_score"], errors="coerce").fillna(0.0)
        df["vendor"] = df["vendor"].fillna("Unknown")
        return df

    # ---- Public: predict monthly cashflow and return a plotly chart ----
    def predict_cashflow(self, states: List[Union[dict, object]], months: int = 6) -> Dict[str, Any]:
        """
        Produces a monthly historical spend + simple forecast for `months` into the future.
        The forecast repeats the mean of the observed monthly totals; months
        with no invoices are not zero-filled before averaging.
        A non-positive `months` yields an empty forecast instead of an error.
        Returns:
        {
            "chart": plotly_figure,
            "average_monthly_spend": float,
            "total_forecast": float,
            "forecast_values": {month_str: float, ...},
            "historical": pandas.Series,
            "forecast_start_month": str (None when the forecast is empty),
            "forecast_end_month": str (None when the forecast is empty)
        }
        or {"message": "No data to forecast", "chart": None} when no state has a usable date.
        """
        df = self._normalize_states_to_df(states)
        if df.empty or df["date"].dropna().empty:
            return {"message": "No data to forecast", "chart": None}

        # clamp so months <= 0 gives an empty forecast rather than an exception
        horizon = max(int(months), 0)

        # create monthly buckets (period start); copy so the column assignment
        # below never writes into a view of the normalized frame
        df = df.dropna(subset=["date"]).copy()
        df["month"] = df["date"].dt.to_period("M").dt.to_timestamp()
        monthly_hist = df.groupby("month")["total"].sum().sort_index()

        # compute average monthly spend from available historical months
        average_month = float(monthly_hist.mean()) if not monthly_hist.empty else 0.0

        # forecast starts the month after the last historical month
        last_hist_month = monthly_hist.index.max()
        if pd.isnull(last_hist_month):
            start_month = pd.Timestamp.now().to_period("M").to_timestamp()
        else:
            start_month = (last_hist_month + pd.offsets.MonthBegin(1)).normalize()
        forecast_index = pd.date_range(start=start_month, periods=horizon, freq="MS")

        # simple forecast: repeat the historical mean (interpretable and safe)
        forecast_vals = [average_month] * len(forecast_index)

        fig = go.Figure()
        # historical - solid line
        if not monthly_hist.empty:
            fig.add_trace(go.Scatter(
                x=monthly_hist.index,
                y=monthly_hist.values,
                mode="lines+markers",
                name="Historical Spend",
                line=dict(dash="solid"),
            ))
        # forecast - dashed line
        if len(forecast_index) > 0:
            fig.add_trace(go.Scatter(
                x=forecast_index,
                y=forecast_vals,
                mode="lines+markers",
                name="Forecast",
                line=dict(dash="dash"),
                marker=dict(symbol="circle-open"),
            ))
        fig.update_layout(
            title="Monthly Spend (Historical + Forecast)",
            xaxis_title="Month",
            yaxis_title="Total Spend (USD)",
            hovermode="x unified",
            template="plotly_dark",
        )

        month_labels = [d.strftime("%Y-%m") for d in forecast_index]
        # explicit dtype keeps an empty forecast numeric
        forecast_series = pd.Series(forecast_vals, index=month_labels, dtype="float64")
        total_forecast = float(forecast_series.sum())
        return {
            "chart": fig,
            "average_monthly_spend": round(average_month, 2),
            "total_forecast": round(total_forecast, 2),
            "forecast_values": forecast_series.to_dict(),
            "historical": monthly_hist,
            "forecast_start_month": month_labels[0] if month_labels else None,
            "forecast_end_month": month_labels[-1] if month_labels else None,
        }

    # ---- Public: detect anomalies on sanitized data ----
    def detect_anomalies(
        self,
        states: List[Union[dict, object]],
        spend_multiplier: float = 2.0,
        risk_threshold: float = 0.7,
    ) -> pd.DataFrame:
        """
        Returns DataFrame of anomalies:
        - total > spend_multiplier * mean(total)
        - OR risk_score >= risk_threshold
        Columns returned: ['file_name','invoice_date','vendor','total','risk_score','anomaly_reason']
        ('invoice_date' holds the normalized date: due_date, else invoice_date).
        'anomaly_reason' is "High Spend" when the spend rule fires (even if the
        risk rule fires too), otherwise "High Risk".
        Returns an empty, column-less DataFrame when there is nothing to report.
        """
        df = self._normalize_states_to_df(states)
        if df.empty:
            return pd.DataFrame()
        spend_limit = df["total"].mean() * spend_multiplier
        flagged = (df["total"] > spend_limit) | (df["risk_score"] >= risk_threshold)
        anomalies = df.loc[flagged, ["file_name", "date", "vendor", "total", "risk_score"]].copy()
        if anomalies.empty:
            return pd.DataFrame()
        anomalies = anomalies.rename(columns={"date": "invoice_date"})
        anomalies["anomaly_reason"] = [
            "High Spend" if total > spend_limit else "High Risk"
            for total in anomalies["total"]
        ]
        return anomalies.reset_index(drop=True)