Gil Stetler committed on
Commit
67dad62
·
1 Parent(s): 6a68c5b

test on previous data with mse

Browse files
Files changed (1) hide show
  1. app.py +48 -30
app.py CHANGED
@@ -1,6 +1,4 @@
1
  # app.py
2
- import os
3
- import io
4
  import numpy as np
5
  import pandas as pd
6
  import torch
@@ -12,68 +10,88 @@ from chronos import ChronosPipeline
12
 
13
  MODEL_ID = "amazon/chronos-t5-large"
14
  PREDICTION_LENGTH = 12
15
- NUM_SAMPLES = 100 # adjust if you want tighter/faster
16
 
17
- # Choose a sensible dtype/device for Space hardware
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
20
 
21
- # Load once at startup (HF Spaces cache model weights between runs)
22
  pipe = ChronosPipeline.from_pretrained(
23
  MODEL_ID,
24
- device_map="auto", # uses GPU if available
25
  torch_dtype=dtype,
26
  )
27
 
28
- # Small helper to run the full demo pipeline
29
- def run_forecast():
30
- # 1) Load example data (univariate)
31
  df = pd.read_csv(
32
  "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
33
  )
34
  y = df["#Passengers"].astype(float).to_numpy()
 
35
 
36
- # 2) Forecast with Chronos
37
- context = torch.tensor(y, dtype=torch.float32)
 
 
 
 
 
 
 
 
38
  fcst = pipe.predict(context, prediction_length=PREDICTION_LENGTH, num_samples=NUM_SAMPLES) # [1, S, H]
39
  samples = fcst[0].cpu().numpy() # (S, H)
40
 
41
- # 3) Summaries
42
  low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
43
 
44
- # 4) Plot history + forecast
45
- fig = plt.figure(figsize=(8, 4))
46
- x_hist = np.arange(len(y))
47
- x_fcst = np.arange(len(y), len(y) + PREDICTION_LENGTH)
48
- plt.plot(x_hist, y, label="history")
49
- plt.plot(x_fcst, median, label="median")
 
 
 
 
 
 
50
  plt.fill_between(x_fcst, low, high, alpha=0.3, label="80% interval")
51
- plt.title("Chronos-T5-Large Forecast")
52
  plt.xlabel("time")
53
  plt.ylabel("#Passengers")
54
- plt.legend()
55
  plt.tight_layout()
56
 
57
- # Also return the raw curves if you want to inspect/download
58
  out_json = {
59
- "prediction_length": PREDICTION_LENGTH,
60
  "num_samples": int(NUM_SAMPLES),
 
61
  "median": median.tolist(),
62
  "p10": low.tolist(),
63
  "p90": high.tolist(),
 
64
  }
65
- return fig, out_json
66
 
67
- with gr.Blocks(title="Chronos-T5-Large AirPassengers Demo") as demo:
 
 
 
 
68
  gr.Markdown(
69
- "## Chronos-T5-Large (zero-shot forecasting)\n"
70
- "Click **Run forecast** to compute on the server (CPU/GPU of this Space)."
 
71
  )
72
- run_btn = gr.Button("Run forecast", variant="primary")
73
- plot = gr.Plot(label="Forecast")
74
- meta = gr.JSON(label="Forecast summary (median, p10, p90)")
 
75
 
76
- run_btn.click(fn=run_forecast, inputs=None, outputs=[plot, meta])
77
 
78
  if __name__ == "__main__":
79
  demo.launch()
 
1
  # app.py
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import torch
 
10
 
11
  MODEL_ID = "amazon/chronos-t5-large"
12
  PREDICTION_LENGTH = 12
13
+ NUM_SAMPLES = 100 # increase for smoother quantiles (slower)
14
 
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
17
 
18
+ # Load once at startup (HF Spaces caches the model weights between runs)
19
  pipe = ChronosPipeline.from_pretrained(
20
  MODEL_ID,
21
+ device_map="auto",
22
  torch_dtype=dtype,
23
  )
24
 
25
+ def run_forecast_and_evaluate():
26
+ # 1) Load univariate example data
 
27
  df = pd.read_csv(
28
  "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
29
  )
30
  y = df["#Passengers"].astype(float).to_numpy()
31
+ n = len(y)
32
 
33
+ if n <= PREDICTION_LENGTH + 5:
34
+ raise gr.Error("Time series too short for a holdout evaluation.")
35
+
36
+ # 2) Train/forecast split:
37
+ # Use all but the last PREDICTION_LENGTH points as context (train),
38
+ # and compare forecast to the real last PREDICTION_LENGTH points (test).
39
+ y_train = y[: n - PREDICTION_LENGTH]
40
+ y_test = y[n - PREDICTION_LENGTH :]
41
+
42
+ context = torch.tensor(y_train, dtype=torch.float32)
43
  fcst = pipe.predict(context, prediction_length=PREDICTION_LENGTH, num_samples=NUM_SAMPLES) # [1, S, H]
44
  samples = fcst[0].cpu().numpy() # (S, H)
45
 
46
+ # 3) Summaries & metrics
47
  low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
48
 
49
+ # "mean standard error" is ambiguous; report the mean squared error (MSE) and its square root (RMSE) instead:
50
+ mse = float(np.mean((median - y_test) ** 2))
51
+ rmse = float(np.sqrt(mse))
52
+
53
+ # 4) Plot: full history + forecast horizon vs ground truth
54
+ fig = plt.figure(figsize=(9, 4))
55
+ x_hist = np.arange(len(y_train))
56
+ x_fcst = np.arange(len(y_train), len(y_train) + PREDICTION_LENGTH)
57
+
58
+ plt.plot(x_hist, y_train, label="history")
59
+ plt.plot(x_fcst, y_test, label="actual (holdout)")
60
+ plt.plot(x_fcst, median, linestyle="--", label="forecast (median)")
61
  plt.fill_between(x_fcst, low, high, alpha=0.3, label="80% interval")
62
+ plt.title("Chronos-T5-Large • Holdout Evaluation")
63
  plt.xlabel("time")
64
  plt.ylabel("#Passengers")
65
+ plt.legend(loc="best")
66
  plt.tight_layout()
67
 
68
+ # JSON payload for inspection/download
69
  out_json = {
70
+ "prediction_length": int(PREDICTION_LENGTH),
71
  "num_samples": int(NUM_SAMPLES),
72
+ "metrics": {"MSE": mse, "RMSE": rmse},
73
  "median": median.tolist(),
74
  "p10": low.tolist(),
75
  "p90": high.tolist(),
76
+ "actual": y_test.tolist(),
77
  }
 
78
 
79
+ # Metrics text to display prominently
80
+ metrics_md = f"**MSE:** {mse:.3f}  **RMSE:** {rmse:.3f}"
81
+ return fig, out_json, metrics_md
82
+
83
+ with gr.Blocks(title="Chronos-T5-Large • Holdout Demo") as demo:
84
  gr.Markdown(
85
+ "## Chronos-T5-Large (zero-shot forecasting) — Holdout Evaluation\n"
86
+ "Click **Run** to forecast the last 12 months from AirPassengers and compare to the true values.\n"
87
+ "Computation runs on this Space's server hardware."
88
  )
89
+ run_btn = gr.Button("Run", variant="primary")
90
+ plot = gr.Plot(label="Forecast vs Actual (holdout)")
91
+ meta = gr.JSON(label="Data & Metrics")
92
+ metrics = gr.Markdown(label="Metrics")
93
 
94
+ run_btn.click(run_forecast_and_evaluate, inputs=None, outputs=[plot, meta, metrics])
95
 
96
  if __name__ == "__main__":
97
  demo.launch()