Gil Stetler commited on
Commit
6a68c5b
·
1 Parent(s): 937945c

chronos-t5-large test

Browse files
Files changed (2) hide show
  1. app.py +76 -4
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,7 +1,79 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import io
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
  import gradio as gr
8
+ import matplotlib
9
+ matplotlib.use("Agg") # headless backend for Spaces
10
+ import matplotlib.pyplot as plt
11
+ from chronos import ChronosPipeline
12
 
13
+ MODEL_ID = "amazon/chronos-t5-large"
14
+ PREDICTION_LENGTH = 12
15
+ NUM_SAMPLES = 100 # adjust if you want tighter/faster
16
 
17
+ # Choose a sensible dtype/device for Space hardware
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
20
+
21
+ # Load once at startup (HF Spaces cache model weights between runs)
22
+ pipe = ChronosPipeline.from_pretrained(
23
+ MODEL_ID,
24
+ device_map="auto", # uses GPU if available
25
+ torch_dtype=dtype,
26
+ )
27
+
28
+ # Small helper to run the full demo pipeline
29
+ def run_forecast():
30
+ # 1) Load example data (univariate)
31
+ df = pd.read_csv(
32
+ "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
33
+ )
34
+ y = df["#Passengers"].astype(float).to_numpy()
35
+
36
+ # 2) Forecast with Chronos
37
+ context = torch.tensor(y, dtype=torch.float32)
38
+ fcst = pipe.predict(context, prediction_length=PREDICTION_LENGTH, num_samples=NUM_SAMPLES) # [1, S, H]
39
+ samples = fcst[0].cpu().numpy() # (S, H)
40
+
41
+ # 3) Summaries
42
+ low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
43
+
44
+ # 4) Plot history + forecast
45
+ fig = plt.figure(figsize=(8, 4))
46
+ x_hist = np.arange(len(y))
47
+ x_fcst = np.arange(len(y), len(y) + PREDICTION_LENGTH)
48
+ plt.plot(x_hist, y, label="history")
49
+ plt.plot(x_fcst, median, label="median")
50
+ plt.fill_between(x_fcst, low, high, alpha=0.3, label="80% interval")
51
+ plt.title("Chronos-T5-Large Forecast")
52
+ plt.xlabel("time")
53
+ plt.ylabel("#Passengers")
54
+ plt.legend()
55
+ plt.tight_layout()
56
+
57
+ # Also return the raw curves if you want to inspect/download
58
+ out_json = {
59
+ "prediction_length": PREDICTION_LENGTH,
60
+ "num_samples": int(NUM_SAMPLES),
61
+ "median": median.tolist(),
62
+ "p10": low.tolist(),
63
+ "p90": high.tolist(),
64
+ }
65
+ return fig, out_json
66
+
67
+ with gr.Blocks(title="Chronos-T5-Large • AirPassengers Demo") as demo:
68
+ gr.Markdown(
69
+ "## Chronos-T5-Large (zero-shot forecasting)\n"
70
+ "Click **Run forecast** to compute on the server (CPU/GPU of this Space)."
71
+ )
72
+ run_btn = gr.Button("Run forecast", variant="primary")
73
+ plot = gr.Plot(label="Forecast")
74
+ meta = gr.JSON(label="Forecast summary (median, p10, p90)")
75
+
76
+ run_btn.click(fn=run_forecast, inputs=None, outputs=[plot, meta])
77
+
78
+ if __name__ == "__main__":
79
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0
2
+ chronos-forecasting>=1.5
3
+ torch>=2.2
4
+ pandas>=2.0
5
+ numpy>=1.26
6
+ matplotlib>=3.8