Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import pickle | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
# Pre-computed forecasts for the three demo datasets live here.
# Each directory is looked up next to this module first (inside backend/);
# if it is not there, the same-named directory one level up (project root) is used.
_HERE = Path(__file__).parent


def _resolve_data_dir(name: str) -> Path:
    """Return ``_HERE / name`` when it exists, otherwise ``_HERE.parent / name``."""
    candidate = _HERE / name
    if candidate.exists():
        return candidate
    return _HERE.parent / name


CACHE_DIR = _resolve_data_dir("demo_cache")
DEMO_DIR = _resolve_data_dir("demo_data")

# In-memory copy of the cached payloads, keyed by dataset name; filled by load_all().
_CACHE: dict[str, dict] = {}
# ─── Public API ───────────────────────────────────────────────────────────────
def load_all() -> None:
    """
    Called once at FastAPI startup after the model is loaded.

    Reads every ``*.pkl`` file in CACHE_DIR into the in-memory cache, keyed by
    the file stem (``bakery.pkl`` -> ``"bakery"``). CACHE_DIR is created if it
    does not exist yet.

    Loading is best-effort: a file that cannot be opened or unpickled is
    skipped and a warning is logged, so startup never fails because of a bad
    cache file — the demo still works for that dataset, just slower.

    SECURITY: ``pickle.load`` can execute arbitrary code. CACHE_DIR must only
    ever contain files written by ``build_all()``, never user-supplied data.
    """
    import logging

    logger = logging.getLogger(__name__)
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # sorted() makes the load order (and any warnings) deterministic.
    for pkl_file in sorted(CACHE_DIR.glob("*.pkl")):
        try:
            with open(pkl_file, "rb") as f:
                _CACHE[pkl_file.stem] = pickle.load(f)
        except Exception as exc:  # deliberate best-effort boundary: log, don't crash
            logger.warning(
                "Skipping unreadable demo cache file %s: %s", pkl_file, exc
            )
def get(key: str) -> dict | None:
    """Return the cached payload stored under *key*, or ``None`` when absent."""
    try:
        return _CACHE[key]
    except KeyError:
        return None


def has(key: str) -> bool:
    """Report whether a payload for *key* has been loaded into the cache."""
    found = key in _CACHE
    return found
# ─── Pre-compute script ───────────────────────────────────────────────────────
def build_all(horizon: int = 4) -> None:
    """
    Generates and saves pre-computed forecasts for all demo datasets.
    Requires the model to already be loaded (call forecaster.load_model() first).

    For each demo CSV found in DEMO_DIR the pipeline ingest -> prepare_series ->
    run_forecast -> calibrate -> confidence -> decision -> build_detector is run,
    and the resulting payload is pickled to ``CACHE_DIR/<key>.pkl``. A missing
    CSV is skipped; a failure in one dataset is reported (with traceback) and
    does not stop the remaining datasets.

    Args:
        horizon: number of future periods to forecast for every dataset.
            Defaults to 4, the value that used to be hard-coded.
    """
    import traceback

    import forecaster
    from preprocessor import ingest, prepare_series
    from calibrator import calibrate
    from detector import build_detector
    from confidence import compute as confidence_score
    from decision import get_decision

    history_points = 52  # trailing observations shipped with each payload

    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    demos = {
        "bakery": ("bakery_sales.csv", "date", "weekly_sales_inr"),
        "crop": ("crop_prices_sample.csv", "date", "wheat_price_inr_per_quintal"),
        "m5": ("walmart_m5_sample.csv", "date", "FOODS_1"),
    }
    for key, (filename, date_col, value_col) in demos.items():
        csv_path = DEMO_DIR / filename
        if not csv_path.exists():
            print(f" skip {key} β {csv_path} not found")
            continue
        try:
            print(f" building {key}...")
            file_bytes = csv_path.read_bytes()
            session_data = ingest(file_bytes, filename)
            session_id = session_data["session_id"]
            prepared = prepare_series(session_id, date_col, value_col)
            series = prepared["series"]
            frequency = prepared["frequency"]

            result = forecaster.run_forecast(series, horizon=horizon, frequency=frequency)
            cal = calibrate(series, result["low"], result["high"])
            cal_low = cal["calibrated_low"]
            cal_high = cal["calibrated_high"]
            hist_std = float(np.std(series))
            score, label = confidence_score(cal_low, cal_high, hist_std)

            # % change from the last observation to the first forecast step;
            # the epsilon keeps a zero last value from dividing by zero.
            last_val = float(series[-1])
            first_fc = float(result["median"][0])
            trend_pct = ((first_fc - last_val) / (last_val + 1e-9)) * 100

            decision = get_decision(
                trend_pct=trend_pct,
                confidence=score,
                cusum_alert="NONE",  # no CUSUM check is run at build time
                is_financial=prepared["is_financial"],
                is_intermittent=prepared["is_intermittent"],
            )
            detector = build_detector(series)
            dates = prepared["dates"]
            fc_dates = _future_dates(dates[-1], horizon, frequency)

            payload = {
                "forecast": [
                    {
                        "date": fc_dates[i],
                        "low": float(cal_low[i]),
                        "median": float(result["median"][i]),
                        "high": float(cal_high[i]),
                    }
                    for i in range(horizon)
                ],
                "baseline": [
                    {"date": fc_dates[i], "value": float(result["baseline"][i])}
                    for i in range(horizon)
                ],
                "baseline_type": result["baseline_type"],
                "confidence_score": score,
                "confidence_label": label,
                "decision": decision,
                "trend_pct": round(trend_pct, 2),
                "fallback_bands": cal["fallback"],
                "is_financial": prepared["is_financial"],
                "is_intermittent": prepared["is_intermittent"],
                "history_dates": dates[-history_points:],
                "history_values": [float(v) for v in series[-history_points:]],
                "series_name": value_col.replace("_", " "),
                "frequency": frequency,
                # Underscore-prefixed keys are internal state kept alongside
                # the display fields.
                "_detector": detector,
                "_alpha": cal["alpha"],
                "_hist_std": hist_std,
                "_series": series,
            }
            pkl_path = CACHE_DIR / f"{key}.pkl"
            _write_pickle_atomically(payload, pkl_path)
            print(f" β {key} β {pkl_path.name} "
                  f"(conf={score}, trend={trend_pct:+.1f}%)")
        except Exception as e:  # per-dataset boundary: report and move on
            print(f" β {key} failed: {e}")
            traceback.print_exc()
    print("Cache build complete.")


def _write_pickle_atomically(obj: object, path: Path) -> None:
    """Pickle *obj* to *path* through a temp file, then rename it into place.

    If pickling or writing fails, no truncated ``.pkl`` is left behind for
    ``load_all()`` to pick up. The temp name ends in ``.tmp``, so the
    ``*.pkl`` glob never matches it.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        tmp_path.replace(path)  # atomic rename (os.replace semantics)
    finally:
        tmp_path.unlink(missing_ok=True)
| def _future_dates(last_date_str: str, horizon: int, frequency: str) -> list[str]: | |
| last = pd.Timestamp(last_date_str) | |
| freq_map = {"hourly": "h", "daily": "D", "weekly": "W", "monthly": "MS", "quarterly": "QS", "annually": "YS"} | |
| offset = freq_map.get(frequency, "W") | |
| dates = pd.date_range(start=last, periods=horizon + 1, freq=offset)[1:] | |
| return [d.strftime("%Y-%m-%d") for d in dates] | |
# ─── CLI entry point ──────────────────────────────────────────────────────────
def _main() -> None:
    """Load the forecasting model, then rebuild every demo cache file."""
    import forecaster

    print("Loading Chronos-Bolt-Small...")
    forecaster.load_model()
    print("Model ready. Building demo cache...\n")
    build_all()


if __name__ == "__main__":
    _main()