# pulseai-api / cache.py
# r-bansal's picture
# fix: include demo data and cache for HF Spaces deployment
# a30bf5b
from __future__ import annotations
import os
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
# Pre-computed forecasts for the three demo datasets live here.
# Check local (inside backend/) first, then project root (parent.parent)
_HERE = Path(__file__).parent
CACHE_DIR = _HERE / "demo_cache" if (_HERE / "demo_cache").exists() else _HERE.parent / "demo_cache"
DEMO_DIR = _HERE / "demo_data" if (_HERE / "demo_data").exists() else _HERE.parent / "demo_data"
_CACHE: dict[str, dict] = {} # in-memory after first load
# ─── Public API ───────────────────────────────────────────────────────────────
def load_all() -> None:
    """Load every pre-computed demo forecast from ``CACHE_DIR`` into memory.

    Called once at FastAPI startup after the model is loaded. Each
    ``<key>.pkl`` file in the cache directory is unpickled and stored in the
    module-level ``_CACHE`` under ``<key>`` (the file stem).

    A file that cannot be read or unpickled is skipped with a logged warning
    instead of being swallowed silently — the demo still works without it,
    just slower — so one corrupt file never blocks startup.
    """
    import logging  # function-scope import keeps this block self-contained

    log = logging.getLogger(__name__)

    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # sorted() makes the load order (and the order of any warnings) stable.
    for pkl_file in sorted(CACHE_DIR.glob("*.pkl")):
        try:
            with pkl_file.open("rb") as f:
                # SECURITY: unpickling can execute arbitrary code. These files
                # are expected to be the ones written by build_all(); never
                # point CACHE_DIR at data from an untrusted source.
                payload = pickle.load(f)
        except Exception:
            # Deliberately broad: a damaged pickle can raise many unrelated
            # exception types. Best-effort loading is kept, but the failure
            # is now visible in the logs.
            log.warning(
                "Skipping unreadable demo cache file %s", pkl_file, exc_info=True
            )
            continue
        _CACHE[pkl_file.stem] = payload
def get(key: str) -> dict | None:
    """Return the in-memory payload stored under ``key``, or ``None`` if absent."""
    if key in _CACHE:
        return _CACHE[key]
    return None
def has(key: str) -> bool:
    """Report whether a payload for ``key`` is currently held in memory."""
    is_loaded = key in _CACHE
    return is_loaded
# ─── Pre-compute script ───────────────────────────────────────────────────────
def build_all(horizon: int = 4) -> None:
    """Generate and save pre-computed forecasts for all demo datasets.

    Requires the model to already be loaded (call ``forecaster.load_model()``
    first). For each demo CSV found in ``DEMO_DIR`` the file is passed through
    ``ingest`` / ``prepare_series``, forecast with ``forecaster.run_forecast``,
    calibrated, scored, and the resulting payload is pickled to
    ``CACHE_DIR/<key>.pkl``.

    A missing CSV is skipped, and a dataset that raises is reported on stdout
    without aborting the remaining datasets.

    Args:
        horizon: Number of future periods to forecast for every dataset.
            Defaults to 4, the value that used to be hard-coded in three
            separate places.
    """
    # Project-local modules are imported at call time, as before. The unused
    # ``from baseline import select_and_run`` import was dropped.
    import forecaster
    from calibrator import calibrate
    from confidence import compute as confidence_score
    from decision import get_decision
    from detector import build_detector
    from preprocessor import ingest, prepare_series

    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # cache key -> (CSV filename, date column, value column)
    demos = {
        "bakery": ("bakery_sales.csv", "date", "weekly_sales_inr"),
        "crop": ("crop_prices_sample.csv", "date", "wheat_price_inr_per_quintal"),
        "m5": ("walmart_m5_sample.csv", "date", "FOODS_1"),
    }

    for key, (filename, date_col, value_col) in demos.items():
        csv_path = DEMO_DIR / filename
        if not csv_path.exists():
            print(f" skip {key} — {csv_path} not found")
            continue

        # Per-dataset boundary: one failing demo must not stop the others.
        try:
            print(f" building {key}...")

            session_data = ingest(csv_path.read_bytes(), filename)
            session_id = session_data["session_id"]
            prepared = prepare_series(session_id, date_col, value_col)
            series = prepared["series"]
            frequency = prepared["frequency"]

            result = forecaster.run_forecast(
                series, horizon=horizon, frequency=frequency
            )

            cal = calibrate(series, result["low"], result["high"])
            cal_low = cal["calibrated_low"]
            cal_high = cal["calibrated_high"]

            hist_std = float(np.std(series))
            score, label = confidence_score(cal_low, cal_high, hist_std)

            # Percentage change from the last observation to the first
            # forecast step; the epsilon avoids dividing by exactly zero.
            last_val = float(series[-1])
            first_fc = float(result["median"][0])
            trend_pct = ((first_fc - last_val) / (last_val + 1e-9)) * 100

            decision = get_decision(
                trend_pct=trend_pct,
                confidence=score,
                cusum_alert="NONE",
                is_financial=prepared["is_financial"],
                is_intermittent=prepared["is_intermittent"],
            )

            detector = build_detector(series)

            dates = prepared["dates"]
            fc_dates = _future_dates(dates[-1], horizon, frequency)

            payload = {
                "forecast": [
                    {
                        "date": fc_dates[i],
                        "low": float(cal_low[i]),
                        "median": float(result["median"][i]),
                        "high": float(cal_high[i]),
                    }
                    for i in range(horizon)
                ],
                "baseline": [
                    {"date": fc_dates[i], "value": float(result["baseline"][i])}
                    for i in range(horizon)
                ],
                "baseline_type": result["baseline_type"],
                "confidence_score": score,
                "confidence_label": label,
                "decision": decision,
                "trend_pct": round(trend_pct, 2),
                "fallback_bands": cal["fallback"],
                "is_financial": prepared["is_financial"],
                "is_intermittent": prepared["is_intermittent"],
                # Only the most recent 52 observations are kept as history.
                "history_dates": dates[-52:],
                "history_values": [float(v) for v in series[-52:]],
                "series_name": value_col.replace("_", " "),
                "frequency": frequency,
                # Underscore-prefixed entries hold raw objects alongside the
                # display fields (presumably internal state for later
                # updates — confirm against the consumers of this cache).
                "_detector": detector,
                "_alpha": cal["alpha"],
                "_hist_std": hist_std,
                "_series": series,
            }

            pkl_path = CACHE_DIR / f"{key}.pkl"
            with open(pkl_path, "wb") as f:
                pickle.dump(payload, f)

            print(f" ✓ {key} → {pkl_path.name} "
                  f"(conf={score}, trend={trend_pct:+.1f}%)")
        except Exception as e:
            print(f" ✗ {key} failed: {e}")

    print("Cache build complete.")
def _future_dates(last_date_str: str, horizon: int, frequency: str) -> list[str]:
last = pd.Timestamp(last_date_str)
freq_map = {"hourly": "h", "daily": "D", "weekly": "W", "monthly": "MS", "quarterly": "QS", "annually": "YS"}
offset = freq_map.get(frequency, "W")
dates = pd.date_range(start=last, periods=horizon + 1, freq=offset)[1:]
return [d.strftime("%Y-%m-%d") for d in dates]
# ─── CLI entry point ──────────────────────────────────────────────────────────
def _main() -> None:
    """CLI driver: load the forecasting model, then rebuild the demo cache."""
    # Imported here because it is only needed when this file runs as a script.
    import forecaster as model_backend

    print("Loading Chronos-Bolt-Small...")
    model_backend.load_model()
    print("Model ready. Building demo cache...\n")
    build_all()


if __name__ == "__main__":
    _main()