File size: 6,525 Bytes
13a5236
 
 
 
 
 
 
 
 
 
 
a30bf5b
 
 
 
13a5236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e96dd08
13a5236
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from __future__ import annotations

import os
import pickle
from pathlib import Path

import numpy as np
import pandas as pd

# Pre-computed forecasts for the three demo datasets are kept on disk in
# CACHE_DIR; the demo CSV files they are built from are read from DEMO_DIR.

# Each directory is looked up next to this module first (inside backend/);
# when it is not there, the project-root location (parent.parent) is used.
_HERE = Path(__file__).parent

CACHE_DIR = _HERE / "demo_cache"
if not CACHE_DIR.exists():
    CACHE_DIR = _HERE.parent / "demo_cache"

DEMO_DIR = _HERE / "demo_data"
if not DEMO_DIR.exists():
    DEMO_DIR = _HERE.parent / "demo_data"

# In-memory copy of the cached payloads, keyed by dataset name; empty until
# the first load.
_CACHE: dict[str, dict] = {}


# ─── Public API ───────────────────────────────────────────────────────────────

def load_all() -> None:
    """
    Called once at FastAPI startup after the model is loaded.
    Reads all .pkl files from demo_cache/ into memory, keyed by file stem.
    Missing files are simply absent from the cache — demo still works, just slower.

    A file that cannot be unpickled, or whose content is not a dict, is
    skipped with a printed warning so a broken cache entry is visible in the
    startup output instead of disappearing silently. Loading never raises
    because of a bad cache file.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    for pkl_file in CACHE_DIR.glob("*.pkl"):
        key = pkl_file.stem
        try:
            # NOTE(security): unpickling runs arbitrary code embedded in the
            # file. Only files written by build_all() belong in CACHE_DIR;
            # never point this at a directory that untrusted users can write.
            with open(pkl_file, "rb") as f:
                payload = pickle.load(f)
        except Exception as exc:
            # Deliberately broad: a stale or truncated pickle can fail with
            # many unrelated exception types, and none should stop startup.
            print(f"  skip {key} — could not load {pkl_file.name}: "
                  f"{type(exc).__name__}: {exc}")
            continue

        if not isinstance(payload, dict):
            # Every consumer of the cache expects a dict payload.
            print(f"  skip {key} — {pkl_file.name} holds "
                  f"{type(payload).__name__}, expected dict")
            continue

        _CACHE[key] = payload


def get(key: str) -> dict | None:
    """Return the payload cached in memory under `key`, or None if there is none."""
    try:
        return _CACHE[key]
    except KeyError:
        return None


def has(key: str) -> bool:
    """Tell whether a payload for `key` is currently held in the in-memory cache."""
    is_cached = key in _CACHE
    return is_cached


# ─── Pre-compute script ───────────────────────────────────────────────────────

def build_all(horizon: int = 4) -> None:
    """
    Generates and saves pre-computed forecasts for all demo datasets.
    Requires the model to already be loaded (call forecaster.load_model() first).

    For every demo CSV found in DEMO_DIR the series is ingested, forecast,
    calibrated and scored, and the resulting payload is pickled to
    CACHE_DIR/<key>.pkl. A missing CSV is skipped, and a failure in one
    dataset is reported without aborting the remaining ones.

    Args:
        horizon: Number of future periods to forecast for each dataset.
            Defaults to 4, the value that used to be hard-coded.
    """
    # Project-local imports stay inside the function so that importing this
    # module (e.g. only to call load_all/get/has) does not pull them in.
    import forecaster
    from preprocessor import ingest, prepare_series
    from calibrator import calibrate
    from detector import build_detector
    from confidence import compute as confidence_score
    from decision import get_decision

    # Number of most recent observations stored in the payload as history.
    history_window = 52

    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # key -> (CSV file name, date column, value column)
    demos = {
        "bakery":   ("bakery_sales.csv",         "date", "weekly_sales_inr"),
        "crop":     ("crop_prices_sample.csv",    "date", "wheat_price_inr_per_quintal"),
        "m5":       ("walmart_m5_sample.csv",     "date", "FOODS_1"),
    }

    for key, (filename, date_col, value_col) in demos.items():
        csv_path = DEMO_DIR / filename
        if not csv_path.exists():
            print(f"  skip {key} — {csv_path} not found")
            continue

        try:
            print(f"  building {key}...")

            session_data = ingest(csv_path.read_bytes(), filename)
            session_id   = session_data["session_id"]
            prepared     = prepare_series(session_id, date_col, value_col)
            series       = prepared["series"]
            frequency    = prepared["frequency"]

            result = forecaster.run_forecast(series, horizon=horizon, frequency=frequency)

            cal = calibrate(series, result["low"], result["high"])
            cal_low  = cal["calibrated_low"]
            cal_high = cal["calibrated_high"]

            hist_std = float(np.std(series))
            score, label = confidence_score(cal_low, cal_high, hist_std)

            # Relative change from the last observation to the first forecast
            # step; the tiny epsilon guards against division by zero.
            last_val = float(series[-1])
            first_fc = float(result["median"][0])
            trend_pct = ((first_fc - last_val) / (last_val + 1e-9)) * 100

            decision = get_decision(
                trend_pct=trend_pct,
                confidence=score,
                cusum_alert="NONE",
                is_financial=prepared["is_financial"],
                is_intermittent=prepared["is_intermittent"],
            )

            detector  = build_detector(series)
            dates     = prepared["dates"]
            fc_dates  = _future_dates(dates[-1], horizon, frequency)
            steps     = range(horizon)

            # Keys starting with "_" carry internal state (detector, raw
            # series, calibration alpha) next to the presentable fields.
            payload = {
                "forecast": [
                    {
                        "date":   fc_dates[i],
                        "low":    float(cal_low[i]),
                        "median": float(result["median"][i]),
                        "high":   float(cal_high[i]),
                    }
                    for i in steps
                ],
                "baseline": [
                    {"date": fc_dates[i], "value": float(result["baseline"][i])}
                    for i in steps
                ],
                "baseline_type":    result["baseline_type"],
                "confidence_score": score,
                "confidence_label": label,
                "decision":         decision,
                "trend_pct":        round(trend_pct, 2),
                "fallback_bands":   cal["fallback"],
                "is_financial":     prepared["is_financial"],
                "is_intermittent":  prepared["is_intermittent"],
                "history_dates":    dates[-history_window:],
                "history_values":   [float(v) for v in series[-history_window:]],
                "series_name":      value_col.replace("_", " "),
                "frequency":        frequency,
                "_detector":        detector,
                "_alpha":           cal["alpha"],
                "_hist_std":        hist_std,
                "_series":          series,
            }

            pkl_path = CACHE_DIR / f"{key}.pkl"
            with open(pkl_path, "wb") as f:
                pickle.dump(payload, f)

            print(f"  ✓ {key} → {pkl_path.name}  "
                  f"(conf={score}, trend={trend_pct:+.1f}%)")

        except Exception as exc:
            # Top-level boundary of a best-effort batch job: report the
            # failure (with its type, since e.g. a bare KeyError message is
            # cryptic) and carry on with the next dataset.
            print(f"  ✗ {key} failed: {type(exc).__name__}: {exc}")

    print("Cache build complete.")


def _future_dates(last_date_str: str, horizon: int, frequency: str) -> list[str]:
    last = pd.Timestamp(last_date_str)
    freq_map = {"hourly": "h", "daily": "D", "weekly": "W", "monthly": "MS", "quarterly": "QS", "annually": "YS"}
    offset   = freq_map.get(frequency, "W")
    dates    = pd.date_range(start=last, periods=horizon + 1, freq=offset)[1:]
    return [d.strftime("%Y-%m-%d") for d in dates]


# ─── CLI entry point ──────────────────────────────────────────────────────────

def _main() -> None:
    """Load the forecasting model, then rebuild every demo cache file."""
    import forecaster

    print("Loading Chronos-Bolt-Small...")
    forecaster.load_model()
    print("Model ready. Building demo cache...\n")
    build_all()


if __name__ == "__main__":
    _main()