Spaces:
Running
on
T4
Running
on
T4
preset forecasting data used only when default forecast length is set, tested, updated README.md
Browse files
README.md
CHANGED
|
@@ -7,4 +7,63 @@ sdk: docker
|
|
| 7 |
app_port: 7860
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
app_port: 7860
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# TiRex – Zero‑Shot Time Series Forecasting App
|
| 11 |
+
|
| 12 |
+
A Gradio‑based interactive web app to perform zero‑shot time series forecasting using the TiRex model. Upload your own CSV/XLSX/Parquet files or choose from built‑in presets, filter series by name, and visualize quantile forecasts over your chosen horizon.
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## 🔍 Features
|
| 17 |
+
|
| 18 |
+
- **Zero‑Shot Forecasting**: Powered by the [`NX-AI/TiRex`](https://huggingface.co/NX-AI/TiRex) model.
|
| 19 |
+
- **Custom Data Upload**: Accepts CSV, XLSX, and Parquet.
|
| 20 |
+
- **Preset Datasets**: Includes `loop.csv`, `air_passangers.csv`, and `ett2.csv` for quick demos.
|
| 21 |
+
- **Interactive Filtering**: Search, check/uncheck, and plot only the series you care about.
|
| 22 |
+
- **Quantile Forecasts**: Displays historical data, median forecast line, and 10–90% quantile shading.
|
| 23 |
+
- **Configurable Horizon**: Slider to set forecast length (1–512 steps).
|
| 24 |
+
- **Automatic Defaults**: Detects best forecast‐length defaults for presets.
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## 📊 Data Format
|
| 29 |
+
|
| 30 |
+
### With Named Series
|
| 31 |
+
```csv
|
| 32 |
+
AAPL,120.5,121.0,119.8,122.1,123.5,...
|
| 33 |
+
AMZN,3300.0,3310.5,3295.2,3305.8,3315.1,...
|
| 34 |
+
GOOGL,2800.1,2795.3,2810.7,2805.2,2820.4,...
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### Without Named Series
|
| 38 |
+
```csv
|
| 39 |
+
120.5,121.0,119.8,122.1,123.5,...
|
| 40 |
+
3300.0,3310.5,3295.2,3305.8,3315.1,...
|
| 41 |
+
2800.1,2795.3,2810.7,2805.2,2820.4,...
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
### Key Rules:
|
| 45 |
+
- **One row per time series**
|
| 46 |
+
- **Consistent naming**: Either all rows have names (first column) or none do
|
| 47 |
+
- **Numeric data**: All values after the optional name column must be numeric
|
| 48 |
+
- **Minimum length**: Time series must have at least `forecast_length + 10` data points
|
| 49 |
+
- **Maximum constraints**: Up to 30 time series and 2048 time steps per series
|
| 50 |
+
|
| 51 |
+
## 🔧 Configuration
|
| 52 |
+
|
| 53 |
+
### Forecast Length
|
| 54 |
+
- **Default**: 64 steps
|
| 55 |
+
- **Range**: 1-512 steps
|
| 56 |
+
- **Auto-adjustment**: Preset datasets have optimized forecast lengths:
|
| 57 |
+
- `loop.csv` and `ett2.csv`: 256 steps
|
| 58 |
+
- `air_passangers.csv`: 48 steps
|
| 59 |
+
|
| 60 |
+
### Model Settings
|
| 61 |
+
- **Device**: CUDA (T4 GPU)
|
| 62 |
+
- **Quantiles**: 10%, 50% (median), 90% prediction intervals
|
| 63 |
+
|
| 64 |
+
## 📈 Output Features
|
| 65 |
+
|
| 66 |
+
- **Historical data**: Blue line showing input time series
|
| 67 |
+
- **Median forecast**: Orange line for point predictions
|
| 68 |
+
- **Uncertainty bands**: Gray shaded area showing 10%-90%
|
| 69 |
+
|
app.py
CHANGED
|
@@ -18,13 +18,13 @@ torch.manual_seed(42)
|
|
| 18 |
model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
|
| 19 |
|
| 20 |
def model_forecast(input_data, forecast_length=256, file_name=None):
|
| 21 |
-
if os.path.basename(file_name) == "loop.csv":
|
| 22 |
_forecast_tensor = torch.load("data/loop_forecast_256.pt")
|
| 23 |
return _forecast_tensor[:,:forecast_length,:]
|
| 24 |
-
elif os.path.basename(file_name) == "ett2.csv":
|
| 25 |
_forecast_tensor = torch.load("data/ett2_forecast_256.pt")
|
| 26 |
return _forecast_tensor[:,:forecast_length,:]
|
| 27 |
-
elif os.path.basename(file_name) == "air_passangers.csv":
|
| 28 |
_forecast_tensor = torch.load("data/air_passengers_forecast_48.pt")
|
| 29 |
return _forecast_tensor[:,:forecast_length,:]
|
| 30 |
else:
|
|
@@ -413,5 +413,5 @@ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
|
|
| 413 |
'''
|
| 414 |
gradio app.py
|
| 415 |
ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
|
| 416 |
-
ssh -L 7860:localhost:7860 compute-permanent-node-
|
| 417 |
'''
|
|
|
|
| 18 |
model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
|
| 19 |
|
| 20 |
def model_forecast(input_data, forecast_length=256, file_name=None):
|
| 21 |
+
if os.path.basename(file_name) == "loop.csv" and forecast_length==256:
|
| 22 |
_forecast_tensor = torch.load("data/loop_forecast_256.pt")
|
| 23 |
return _forecast_tensor[:,:forecast_length,:]
|
| 24 |
+
elif os.path.basename(file_name) == "ett2.csv" and forecast_length==256:
|
| 25 |
_forecast_tensor = torch.load("data/ett2_forecast_256.pt")
|
| 26 |
return _forecast_tensor[:,:forecast_length,:]
|
| 27 |
+
elif os.path.basename(file_name) == "air_passangers.csv"and forecast_length==48:
|
| 28 |
_forecast_tensor = torch.load("data/air_passengers_forecast_48.pt")
|
| 29 |
return _forecast_tensor[:,:forecast_length,:]
|
| 30 |
else:
|
|
|
|
| 413 |
'''
|
| 414 |
gradio app.py
|
| 415 |
ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
|
| 416 |
+
ssh -L 7860:localhost:7860 compute-permanent-node-195
|
| 417 |
'''
|