Spaces:
Sleeping
Sleeping
Gil Stetler
commited on
Commit
·
67dad62
1
Parent(s):
6a68c5b
test on previous data with mse
Browse files
app.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
# app.py
|
| 2 |
-
import os
|
| 3 |
-
import io
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
| 6 |
import torch
|
|
@@ -12,68 +10,88 @@ from chronos import ChronosPipeline
|
|
| 12 |
|
| 13 |
MODEL_ID = "amazon/chronos-t5-large"
|
| 14 |
PREDICTION_LENGTH = 12
|
| 15 |
-
NUM_SAMPLES = 100 #
|
| 16 |
|
| 17 |
-
# Choose a sensible dtype/device for Space hardware
|
| 18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 20 |
|
| 21 |
-
# Load once at startup (HF Spaces cache
|
| 22 |
pipe = ChronosPipeline.from_pretrained(
|
| 23 |
MODEL_ID,
|
| 24 |
-
device_map="auto",
|
| 25 |
torch_dtype=dtype,
|
| 26 |
)
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# 1) Load example data (univariate)
|
| 31 |
df = pd.read_csv(
|
| 32 |
"https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
|
| 33 |
)
|
| 34 |
y = df["#Passengers"].astype(float).to_numpy()
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
fcst = pipe.predict(context, prediction_length=PREDICTION_LENGTH, num_samples=NUM_SAMPLES) # [1, S, H]
|
| 39 |
samples = fcst[0].cpu().numpy() # (S, H)
|
| 40 |
|
| 41 |
-
# 3) Summaries
|
| 42 |
low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
plt.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
plt.fill_between(x_fcst, low, high, alpha=0.3, label="80% interval")
|
| 51 |
-
plt.title("Chronos-T5-Large
|
| 52 |
plt.xlabel("time")
|
| 53 |
plt.ylabel("#Passengers")
|
| 54 |
-
plt.legend()
|
| 55 |
plt.tight_layout()
|
| 56 |
|
| 57 |
-
#
|
| 58 |
out_json = {
|
| 59 |
-
"prediction_length": PREDICTION_LENGTH,
|
| 60 |
"num_samples": int(NUM_SAMPLES),
|
|
|
|
| 61 |
"median": median.tolist(),
|
| 62 |
"p10": low.tolist(),
|
| 63 |
"p90": high.tolist(),
|
|
|
|
| 64 |
}
|
| 65 |
-
return fig, out_json
|
| 66 |
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
gr.Markdown(
|
| 69 |
-
"## Chronos-T5-Large (zero-shot forecasting)\n"
|
| 70 |
-
"Click **Run
|
|
|
|
| 71 |
)
|
| 72 |
-
run_btn = gr.Button("Run
|
| 73 |
-
plot = gr.Plot(label="Forecast")
|
| 74 |
-
meta = gr.JSON(label="
|
|
|
|
| 75 |
|
| 76 |
-
run_btn.click(
|
| 77 |
|
| 78 |
if __name__ == "__main__":
|
| 79 |
demo.launch()
|
|
|
|
| 1 |
# app.py
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
import torch
|
|
|
|
| 10 |
|
| 11 |
MODEL_ID = "amazon/chronos-t5-large"
|
| 12 |
PREDICTION_LENGTH = 12
|
| 13 |
+
NUM_SAMPLES = 100 # increase for smoother quantiles (slower)
|
| 14 |
|
|
|
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 17 |
|
| 18 |
+
# Load once at startup (HF Spaces cache between runs)
|
| 19 |
pipe = ChronosPipeline.from_pretrained(
|
| 20 |
MODEL_ID,
|
| 21 |
+
device_map="auto",
|
| 22 |
torch_dtype=dtype,
|
| 23 |
)
|
| 24 |
|
| 25 |
+
def run_forecast_and_evaluate():
|
| 26 |
+
# 1) Load univariate example data
|
|
|
|
| 27 |
df = pd.read_csv(
|
| 28 |
"https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
|
| 29 |
)
|
| 30 |
y = df["#Passengers"].astype(float).to_numpy()
|
| 31 |
+
n = len(y)
|
| 32 |
|
| 33 |
+
if n <= PREDICTION_LENGTH + 5:
|
| 34 |
+
raise gr.Error("Time series too short for a holdout evaluation.")
|
| 35 |
+
|
| 36 |
+
# 2) Train/forecast split:
|
| 37 |
+
# Use all but the last PREDICTION_LENGTH points as context (train),
|
| 38 |
+
# and compare forecast to the real last PREDICTION_LENGTH points (test).
|
| 39 |
+
y_train = y[: n - PREDICTION_LENGTH]
|
| 40 |
+
y_test = y[n - PREDICTION_LENGTH :]
|
| 41 |
+
|
| 42 |
+
context = torch.tensor(y_train, dtype=torch.float32)
|
| 43 |
fcst = pipe.predict(context, prediction_length=PREDICTION_LENGTH, num_samples=NUM_SAMPLES) # [1, S, H]
|
| 44 |
samples = fcst[0].cpu().numpy() # (S, H)
|
| 45 |
|
| 46 |
+
# 3) Summaries & metrics
|
| 47 |
low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
|
| 48 |
|
| 49 |
+
# "mean standard error" is ambiguous; commonly MSE + RMSE are reported:
|
| 50 |
+
mse = float(np.mean((median - y_test) ** 2))
|
| 51 |
+
rmse = float(np.sqrt(mse))
|
| 52 |
+
|
| 53 |
+
# 4) Plot: full history + forecast horizon vs ground truth
|
| 54 |
+
fig = plt.figure(figsize=(9, 4))
|
| 55 |
+
x_hist = np.arange(len(y_train))
|
| 56 |
+
x_fcst = np.arange(len(y_train), len(y_train) + PREDICTION_LENGTH)
|
| 57 |
+
|
| 58 |
+
plt.plot(x_hist, y_train, label="history")
|
| 59 |
+
plt.plot(x_fcst, y_test, label="actual (holdout)")
|
| 60 |
+
plt.plot(x_fcst, median, linestyle="--", label="forecast (median)")
|
| 61 |
plt.fill_between(x_fcst, low, high, alpha=0.3, label="80% interval")
|
| 62 |
+
plt.title("Chronos-T5-Large • Holdout Evaluation")
|
| 63 |
plt.xlabel("time")
|
| 64 |
plt.ylabel("#Passengers")
|
| 65 |
+
plt.legend(loc="best")
|
| 66 |
plt.tight_layout()
|
| 67 |
|
| 68 |
+
# JSON payload for inspection/download
|
| 69 |
out_json = {
|
| 70 |
+
"prediction_length": int(PREDICTION_LENGTH),
|
| 71 |
"num_samples": int(NUM_SAMPLES),
|
| 72 |
+
"metrics": {"MSE": mse, "RMSE": rmse},
|
| 73 |
"median": median.tolist(),
|
| 74 |
"p10": low.tolist(),
|
| 75 |
"p90": high.tolist(),
|
| 76 |
+
"actual": y_test.tolist(),
|
| 77 |
}
|
|
|
|
| 78 |
|
| 79 |
+
# Metrics text to display prominently
|
| 80 |
+
metrics_md = f"**MSE:** {mse:.3f} **RMSE:** {rmse:.3f}"
|
| 81 |
+
return fig, out_json, metrics_md
|
| 82 |
+
|
| 83 |
+
with gr.Blocks(title="Chronos-T5-Large • Holdout Demo") as demo:
|
| 84 |
gr.Markdown(
|
| 85 |
+
"## Chronos-T5-Large (zero-shot forecasting) — Holdout Evaluation\n"
|
| 86 |
+
"Click **Run** to forecast the last 12 months from AirPassengers and compare to the true values.\n"
|
| 87 |
+
"Computation runs on this Space's server hardware."
|
| 88 |
)
|
| 89 |
+
run_btn = gr.Button("Run", variant="primary")
|
| 90 |
+
plot = gr.Plot(label="Forecast vs Actual (holdout)")
|
| 91 |
+
meta = gr.JSON(label="Data & Metrics")
|
| 92 |
+
metrics = gr.Markdown(label="Metrics")
|
| 93 |
|
| 94 |
+
run_btn.click(run_forecast_and_evaluate, inputs=None, outputs=[plot, meta, metrics])
|
| 95 |
|
| 96 |
if __name__ == "__main__":
|
| 97 |
demo.launch()
|