zqiao11 committed on
Commit
0b97f6a
·
0 Parent(s):

Initial release

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ scale-hf-logo.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ auto_evals/
2
+ venv/
3
+ __pycache__/
4
+ .env
5
+ .ipynb_checkpoints
6
+ *ipynb
7
+ .idea
8
+ .vscode/
9
+
10
+ eval-queue/
11
+ eval-results/
12
+ eval-queue-bk/
13
+ eval-results-bk/
14
+ logs/
15
+ utils.py
16
+ css_html_js.py
17
+ formatting.py
18
+ run_local.sh
.pre-commit-config.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ default_language_version:
16
+ python: python3
17
+
18
+ ci:
19
+ autofix_prs: true
20
+ autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
21
+ autoupdate_schedule: quarterly
22
+
23
+ repos:
24
+ - repo: https://github.com/pre-commit/pre-commit-hooks
25
+ rev: v4.3.0
26
+ hooks:
27
+ - id: check-yaml
28
+ - id: check-case-conflict
29
+ - id: detect-private-key
30
+ - id: check-added-large-files
31
+ args: ['--maxkb=1000']
32
+ - id: requirements-txt-fixer
33
+ - id: end-of-file-fixer
34
+ - id: trailing-whitespace
35
+
36
+ - repo: https://github.com/PyCQA/isort
37
+ rev: 5.12.0
38
+ hooks:
39
+ - id: isort
40
+ name: Format imports
41
+
42
+ - repo: https://github.com/psf/black
43
+ rev: 22.12.0
44
+ hooks:
45
+ - id: black
46
+ name: Format code
47
+ additional_dependencies: ['click==8.0.2']
48
+
49
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
50
+ # Ruff version.
51
+ rev: 'v0.0.267'
52
+ hooks:
53
+ - id: ruff
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install git for pip install from GitHub
6
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
7
+
8
+ # Copy requirements first for better caching
9
+ COPY requirements.txt .
10
+
11
+ # Install Python dependencies
12
+ RUN pip install --no-cache-dir --upgrade pip && \
13
+ pip install --no-cache-dir -r requirements.txt
14
+
15
+ # Copy application code
16
+ COPY . .
17
+
18
+ # Create startup script that installs timebench at runtime (when secrets are available)
19
+ RUN echo '#!/bin/bash\n\
20
+ if [ -n "$GITHUB_TOKEN" ]; then\n\
21
+ echo "Installing timebench from private GitHub repo..."\n\
22
+ pip install --no-cache-dir git+https://oauth2:${GITHUB_TOKEN}@github.com/zqiao11/TIME.git\n\
23
+ else\n\
24
+ echo "Installing timebench from public GitHub repo..."\n\
25
+ pip install --no-cache-dir git+https://github.com/zqiao11/TIME.git\n\
26
+ fi\n\
27
+ exec python app.py\n' > /app/start.sh && chmod +x /app/start.sh
28
+
29
+ # Expose Gradio default port
30
+ EXPOSE 7860
31
+
32
+ # Set environment variables
33
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
34
+ ENV GRADIO_SERVER_PORT="7860"
35
+
36
+ # Run the startup script
37
+ CMD ["/app/start.sh"]
Makefile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: style format
2
+
3
+
4
+ style:
5
+ python -m black --line-length 119 .
6
+ python -m isort .
7
+ ruff check --fix .
8
+
9
+
10
+ quality:
11
+ python -m black --check --line-length 119 .
12
+ python -m isort --check-only .
13
+ ruff check .
README.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TIME Benchmark Leaderboard
3
+ emoji: 🥇
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: true
8
+ license: apache-2.0
9
+ short_description: 'TIME: A Benchmark for Time Series Forecasting'
10
+ ---
11
+
12
+ # TIME Benchmark Leaderboard
13
+
14
+ A unified benchmark for time series probabilistic forecasting with multiple granularity evaluation.
15
+
16
+ ## Features
17
+
18
+ - **Overall Performance**: Aggregated metrics across all datasets and horizons
19
+ - **Dataset-level Analysis**: Performance breakdown by individual datasets
20
+ - **Window-level Visualization**: Detailed test window analysis with prediction visualization
21
+
22
+ ## Configuration
23
+
24
+ ### Environment Variables
25
+
26
+ The app reads data from HuggingFace Hub. Configure the following environment variables:
27
+
28
+ | Variable | Description | Default |
29
+ |----------|-------------|---------|
30
+ | `HF_TOKEN` | HuggingFace API token (required for private datasets) | None |
31
+ | `HF_REPO_ID` | Dataset repository ID | `TIME-benchmark/TIME-1.0` |
32
+ | `USE_HF_HUB` | Use HF Hub (`true`) or local files (`false`) | `true` |
33
+ | `HF_CACHE_DIR` | Custom cache directory for downloads | `~/.cache/huggingface` |
34
+
35
+ ### For HuggingFace Space Deployment
36
+
37
+ #### 快速部署(推荐)
38
+
39
+ ```bash
40
+ # 1. 复制 timebench 模块到 leaderboard_app
41
+ cd /home/eee/qzz/TIME
42
+ cp -r src/timebench leaderboard_app/
43
+
44
+ # 2. 进入 leaderboard_app 目录
45
+ cd leaderboard_app
46
+
47
+ # 3. 运行部署脚本
48
+ chmod +x deploy.sh
49
+ ./deploy.sh YOUR_USERNAME YOUR_SPACE_NAME
50
+ ```
51
+
52
+ #### 手动部署
53
+
54
+ 详细步骤请参考 [DEPLOY.md](DEPLOY.md)
55
+
56
+ **重要**: 部署前需要:
57
+ 1. 创建 HuggingFace Space: https://huggingface.co/new-space
58
+ 2. 在 Space Settings → Repository secrets 中添加 `HF_TOKEN`
59
+ 3. 确保数据已上传到 `TIME-benchmark/TIME-1.0` Dataset
60
+
61
+ ### For Local Development
62
+
63
+ Set `USE_HF_HUB=false` to use local data:
64
+
65
+ ```bash
66
+ export USE_HF_HUB=false
67
+ python app.py
68
+ ```
69
+
70
+ ## Installation
71
+
72
+ ```bash
73
+ pip install -r requirements.txt
74
+ python app.py
75
+ ```
76
+
77
+ ## Data Structure
78
+
79
+ The app expects the following data structure in the HuggingFace Dataset:
80
+
81
+ ```
82
+ HF_REPO/
83
+ ├── data/
84
+ │ └── hf_dataset/ # Time series datasets
85
+ │ ├── ECDC_COVID/
86
+ │ ├── Australia_Solar/
87
+ │ └── ...
88
+ ├── output/
89
+ │ └── results/ # Model evaluation results
90
+ │ ├── moirai_small/
91
+ │ ├── chronos_base/
92
+ │ └── ...
93
+ └── config/
94
+ └── datasets.yaml # Dataset configurations
95
+ ```
96
+
97
+ ## Ethical Considerations
98
+
99
+ This release is for research purposes only in support of an academic paper. Our models, datasets, and code are not specifically designed or evaluated for all downstream purposes. We strongly recommend that users evaluate and address potential concerns related to accuracy, safety, and fairness before deploying these models.
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ # Add project root and src directory to Python path to enable imports from timebench
5
+ # Get the directory containing this file (leaderboard_app/)
6
+ current_dir = os.path.dirname(os.path.abspath(__file__))
7
+
8
+ # Try multiple paths for timebench import:
9
+ # 1. Current directory (if timebench was copied to leaderboard_app/)
10
+ # 2. Parent directory's src (for local development: TIME/src/)
11
+ # 3. Parent's parent's src (if running from leaderboard_app/)
12
+
13
+ # Add current directory first (for Space deployment)
14
+ if current_dir not in sys.path:
15
+ sys.path.insert(0, current_dir)
16
+
17
+ # Add parent directory's src (for local development)
18
+ project_root = os.path.dirname(current_dir)
19
+ if project_root not in sys.path:
20
+ sys.path.insert(0, project_root)
21
+
22
+ src_dir = os.path.join(project_root, "src")
23
+ if src_dir not in sys.path and os.path.exists(src_dir):
24
+ sys.path.insert(0, src_dir)
25
+
26
+ import gradio as gr
27
+ from src.display.css_html_js import custom_css
28
+ from src.about import TITLE, INTRODUCTION_TEXT
29
+ from src.tab import init_overall_tab, init_per_window_tab, init_per_dataset_tab, init_per_pattern_tab
30
+
31
+ # Custom head content for responsive design
32
+ custom_head = """
33
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=5.0, user-scalable=yes">
34
+ <style>
35
+ /* 响应式设计:让页面自动适配不同屏幕尺寸 */
36
+ html {
37
+ width: 100%;
38
+ max-width: 100%;
39
+ }
40
+ body {
41
+ width: 100%;
42
+ max-width: 100%;
43
+ }
44
+ </style>
45
+ """
46
+
47
+ with gr.Blocks(css=custom_css, head=custom_head) as demo:
48
+ gr.HTML(TITLE)
49
+ gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
50
+
51
+ with gr.Tabs(elem_id="custom-tabs") as tabs:
52
+ with gr.Tab("🏅 Overall Performance", id=0):
53
+ init_overall_tab()
54
+
55
+ with gr.Tab("🏅 Per Dataset", id=1):
56
+ init_per_dataset_tab(demo)
57
+
58
+ with gr.Tab("🏅 Per Test Window", id=3):
59
+ init_per_window_tab(demo)
60
+
61
+ with gr.Tab("🏅 Per Pattern", id=4):
62
+ init_per_pattern_tab(demo)
63
+
64
+ # with gr.Tab("📂 Archive", id=5):
65
+ # init_archive_tab(demo)
66
+
67
+
68
+ if __name__ == "__main__":
69
+ demo.launch()
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.ruff]
2
+ # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
3
+ select = ["E", "F"]
4
+ ignore = ["E501"] # line too long (black is taking care of this)
5
+ line-length = 119
6
+ fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
7
+
8
+ [tool.isort]
9
+ profile = "black"
10
+ line_length = 119
11
+
12
+ [tool.black]
13
+ line-length = 119
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # TIME Leaderboard Dependencies
3
+ # =============================================================================
4
+
5
+ # Install timebench from public GitHub repo (always fetch latest from main)
6
+ timebench @ git+https://github.com/zqiao11/TIME.git@main
7
+
8
+ # Core dependencies - pinned to match local working environment
9
+ gradio==5.50.0
10
+ gradio_leaderboard==0.0.14
11
+ gradio_client==1.14.0
12
+ huggingface-hub==0.36.0
13
+ datasets==2.17.1
14
+ APScheduler
15
+ matplotlib
16
+ numpy==1.26.4
17
+ plotly==6.5.0
18
+ pandas==2.3.3
19
+ python-dateutil
20
+ python-dotenv
21
+ tqdm
22
+ pyarrow
23
+ pyyaml
24
+ scipy==1.11.4
25
+
26
+ # Note: gluonts and other timebench dependencies are automatically installed
27
+ # via the timebench package
src/about.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from pathlib import Path
4
+ import pandas as pd
5
+
6
+ # Import HuggingFace Hub configuration
7
+ from src.hf_config import get_results_root, get_config_root, get_features_root, initialize_data
8
+
9
+ from src.utils import (
10
+ get_all_datasets_results, get_all_domains_and_freq, get_all_variates_results,
11
+ get_dataset_choices, get_dataset_display_map, compute_ranks,
12
+ load_features, load_all_features, binarize_features
13
+ )
14
+
15
+
16
+ # =============================================================================
17
+ # Initialize data from HuggingFace Hub (or local for development)
18
+ # =============================================================================
19
+ print("🚀 Starting TIME Leaderboard initialization...")
20
+
21
+ # Download/cache results and config from HuggingFace Hub
22
+ RESULTS_ROOT, CONFIG_ROOT = initialize_data()
23
+
24
+ # Get features root (local or HF)
25
+ FEATURES_ROOT = get_features_root()
26
+
27
+ # Get list of all models from results directory
28
+ ALL_MODELS = []
29
+ if RESULTS_ROOT.exists():
30
+ ALL_MODELS = [p.name for p in RESULTS_ROOT.iterdir() if p.is_dir()]
31
+ print(f"📊 Found {len(ALL_MODELS)} models: {ALL_MODELS}")
32
+
33
+ # ---------------------------------------------------
34
+ # Get dataset choices from TIME results (with smart display names)
35
+ DATASET_CHOICES, DATASET_DISPLAY_TO_ID, DATASET_ID_TO_DISPLAY = get_dataset_choices(str(RESULTS_ROOT))
36
+ print(f"📁 Found {len(DATASET_CHOICES)} dataset configurations")
37
+
38
+ # === Load data once at startup ===
39
+ DATASETS_DF = get_all_datasets_results(root_dir=str(RESULTS_ROOT))
40
+ if not DATASETS_DF.empty:
41
+ # Use dataset_id (dataset/freq) for ranking to correctly handle multi-freq datasets
42
+ DATASETS_DF = compute_ranks(DATASETS_DF, groupby_cols=['dataset_id', "horizon"]) # Rows: 每一行是1个独立的实验 num_model x num_dataset_id x num_horizons
43
+ print(f"✅ Loaded {len(DATASETS_DF)} dataset results")
44
+
45
+ # === Load variate-level results for pattern-based leaderboard ===
46
+ print("📊 Loading variate-level results...")
47
+ VARIATES_DF = get_all_variates_results(root_dir=str(RESULTS_ROOT))
48
+ if not VARIATES_DF.empty:
49
+ # Compute ranks per (dataset_id, series_name, variate_name, horizon)
50
+ VARIATES_DF = compute_ranks(VARIATES_DF, groupby_cols=['dataset_id', 'series_name', 'variate_name', 'horizon'])
51
+ print(f"✅ Loaded {len(VARIATES_DF)} variate-level results")
52
+ else:
53
+ print("⚠️ No variate-level results found")
54
+
55
+ # === Load features for pattern-based filtering ===
56
+ print("📊 Loading features...")
57
+ FEATURES_DF = load_all_features(features_root=str(FEATURES_ROOT), split="test")
58
+ if not FEATURES_DF.empty:
59
+ print(f"✅ Loaded {len(FEATURES_DF)} variate features")
60
+ else:
61
+ print("⚠️ No features found")
62
+
63
+ # Columns to exclude from binarization
64
+ BINARIZE_EXCLUDE = [
65
+ 'dataset_id', 'series_name', 'variate_name', 'unique_id',
66
+ 'mean', 'std', 'length',
67
+ 'period1', 'period2', 'period3',
68
+ 'p_strength1', 'p_strength2', 'p_strength3',
69
+ 'missing_rate',
70
+ # Meta features are already 0/1, handle separately
71
+ 'is_random_walk', 'has_spike_presence',
72
+ ]
73
+
74
+ # Binarize numeric features by median
75
+ FEATURES_BOOL_DF = pd.DataFrame()
76
+ if not FEATURES_DF.empty:
77
+ FEATURES_BOOL_DF = binarize_features(FEATURES_DF, exclude=BINARIZE_EXCLUDE)
78
+ print(f"✅ Binarized features for {len(FEATURES_BOOL_DF)} variates")
79
+
80
+
81
+ if not DATASETS_DF.empty:
82
+ OVERALL_TABLE_COLUMNS = ["model", "MASE", "CRPS", "MASE_rank", "CRPS_rank"]
83
+ else:
84
+ OVERALL_TABLE_COLUMNS = ["model", "MASE", "CRPS"]
85
+
86
+
87
+ ALL_HORIZONS = ['short', 'medium', 'long']
88
+
89
+ # Pattern mapping: UI pattern name -> feature column name
90
+ PATTERN_MAP = {
91
+ # Trend patterns
92
+ "T_strength": "trend_strength",
93
+ "T_linearity": "linearity",
94
+ "T_curvature": "curvature",
95
+ # Seasonal patterns
96
+ "S_strength": "seasonal_strength",
97
+ "S_complexity": "seasonal_entropy",
98
+ "S_corr": "seasonal_corr",
99
+ # Residual patterns
100
+ "R_diff1_ACF1": "e_diff1_acf1",
101
+ "R_ACF1": "e_acf1",
102
+ # Meta patterns
103
+ "stationarity": "is_random_walk", # Note: stationarity = NOT is_random_walk
104
+ "outlier_presence": "has_spike_presence",
105
+ "complexity": "x_entropy", # High entropy = low predictability/high noise
106
+ }
107
+ # ---------------------------------------------------
108
+
109
+
110
+ # Your leaderboard name
111
+ TITLE = """<h1 align="center" id="space-title"> It's TIME</h1>"""
112
+
113
+ # What does your leaderboard evaluate?
114
+ INTRODUCTION_TEXT = """
115
+ TIME introduces a unified benchmark for time series probabilistic forecasting that supports evaluation at **multiple granularities**, ranging from overall performance across datasets to dataset-level, variate-level, and even individual test windows (with visualization). Beyond conventional analysis, the benchmark enables **pattern-driven, cross-dataset benchmarking** by grouping variates with similar temporal features, where patterns are defined based on groups of tsfeatures that capture properties such as trend, seasonality, and stationarity, offering a more systematic understanding of model behavior. For data and results, please refer to 🤗 [dataset](https://huggingface.co/datasets/TIME-benchmark/TIME-1.0/tree/main).
116
+ """
117
+ # An integrated archive further enriches the platform by providing structural tsfeatures and statistical descriptors of all variates,
118
+ # ensuring both comprehensive evaluation and transparent interpretability across diverse forecasting scenarios
119
+ print("✅ TIME Leaderboard initialization complete!")
src/display.egg-info/PKG-INFO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: display
3
+ Version: 0.0.0
src/display.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ src/about.py
4
+ src/hf_config.py
5
+ src/leaderboard.py
6
+ src/tab.py
7
+ src/utils.py
8
+ src/display/css_html_js.py
9
+ src/display/formatting.py
10
+ src/display/utils.py
11
+ src/display.egg-info/PKG-INFO
12
+ src/display.egg-info/SOURCES.txt
13
+ src/display.egg-info/dependency_links.txt
14
+ src/display.egg-info/top_level.txt
src/display.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/display.egg-info/top_level.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ about
2
+ display
3
+ hf_config
4
+ leaderboard
5
+ tab
6
+ utils
src/display/css_html_js.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ custom_css = """
2
+
3
+ /* ========== 响应式布局 ========== */
4
+ /* 移除固定宽度,让Gradio自动适配不同屏幕尺寸 */
5
+ .gradio-container {
6
+ width: 100% !important;
7
+ max-width: 100% !important;
8
+ }
9
+
10
+ /* 主体内容区域自适应 */
11
+ .main, .contain {
12
+ width: 100% !important;
13
+ max-width: 100% !important;
14
+ }
15
+
16
+ /* Tab 内容区域自适应 */
17
+ .tabitem {
18
+ width: 100% !important;
19
+ max-width: 100% !important;
20
+ }
21
+
22
+ /* Plot 组件自适应,但保持最小可读宽度 */
23
+ .js-plotly-plot, .plotly {
24
+ width: 100% !important;
25
+ max-width: 100% !important;
26
+ min-width: 300px !important; /* 保持最小可读宽度 */
27
+ }
28
+
29
+ /* ========== 原有样式 ========== */
30
+ .markdown-text {
31
+ font-size: 20px !important;
32
+ }
33
+
34
+ /* 只影响 Tabs 按钮 */
35
+ #custom-tabs [role="tab"] {
36
+ font-size: 20px;
37
+ }
38
+
39
+ /* ✅ 只影响表格 */
40
+ .custom-table table thead th {
41
+ font-size: 16px;
42
+ font-weight: 600; /* 想要普通就改成 400 */
43
+ text-align: center;
44
+ }
45
+
46
+ .custom-table table tbody td {
47
+ font-size: 14px;
48
+ }
49
+
50
+ /* 响应式表格布局 */
51
+ .custom-table table {
52
+ table-layout: auto; /* 使用自动布局,让表格自适应 */
53
+ width: 100%; /* 占满容器 */
54
+ min-width: 100%; /* 确保至少占满容器 */
55
+ }
56
+
57
+ /* 表格容器允许横向滚动(当内容过宽时) */
58
+ .custom-table {
59
+ overflow-x: auto; /* 当表格内容过宽时,允许横向滚动 */
60
+ width: 100%;
61
+ }
62
+
63
+ /* 为不同列设置合适的宽度(使用相对单位,更灵活) */
64
+ .custom-table table th:nth-child(1),
65
+ .custom-table table td:nth-child(1) {
66
+ min-width: 150px; /* model 列最小宽度 */
67
+ max-width: 250px; /* 最大宽度限制 */
68
+ }
69
+
70
+ /* 指标列(MASE, CRPS, MAE, MSE) */
71
+ .custom-table table th:nth-child(2),
72
+ .custom-table table td:nth-child(2),
73
+ .custom-table table th:nth-child(3),
74
+ .custom-table table td:nth-child(3),
75
+ .custom-table table th:nth-child(4),
76
+ .custom-table table td:nth-child(4),
77
+ .custom-table table th:nth-child(5),
78
+ .custom-table table td:nth-child(5) {
79
+ min-width: 80px; /* 原始指标列最小宽度 */
80
+ max-width: 120px;
81
+ }
82
+
83
+ /* 归一化指标列(MASE_norm, CRPS_norm, MAE_norm, MSE_norm) */
84
+ .custom-table table th:nth-child(6),
85
+ .custom-table table td:nth-child(6),
86
+ .custom-table table th:nth-child(7),
87
+ .custom-table table td:nth-child(7),
88
+ .custom-table table th:nth-child(8),
89
+ .custom-table table td:nth-child(8),
90
+ .custom-table table th:nth-child(9),
91
+ .custom-table table td:nth-child(9) {
92
+ min-width: 100px; /* 归一化指标列最小宽度 */
93
+ max-width: 150px;
94
+ }
95
+
96
+ /* 排名列(MASE_rank, CRPS_rank) */
97
+ .custom-table table th:nth-child(10),
98
+ .custom-table table td:nth-child(10),
99
+ .custom-table table th:nth-child(11),
100
+ .custom-table table td:nth-child(11) {
101
+ min-width: 80px; /* 排名列最小宽度 */
102
+ max-width: 120px;
103
+ }
104
+
105
+
106
+ #archive-table table thead th { font-size: 14px; font-weight: 400}
107
+ #archive-table table {
108
+ table-layout: fixed; /* 强制固定布局 */
109
+ width: 100%; /* 占满容器 */
110
+ }
111
+
112
+ #archive-table table th:nth-child(1),
113
+ #archive-table table td:nth-child(1) {
114
+ width: 160px !important; /* dataset */
115
+ }
116
+
117
+ #archive-table table th:nth-child(2),
118
+ #archive-table table td:nth-child(2) {
119
+ width: 100px !important; /* variate_name */
120
+ }
121
+
122
+ #archive-table table th:nth-child(3),
123
+ #archive-table table td:nth-child(3) {
124
+ width: 60px !important; /* freq */
125
+ }
126
+
127
+ #archive-table table th:nth-child(4),
128
+ #archive-table table td:nth-child(4) {
129
+ width: 100px !important; /* domain */
130
+ }
131
+
132
+ /* 后面的特征列 */
133
+ #archive-table table th:nth-child(n+5),
134
+ #archive-table table td:nth-child(n+5) {
135
+ width: 120px !important;
136
+ }
137
+
138
+
139
+
140
+
141
+ #citation-button span {
142
+ font-size: 14px !important;
143
+ }
144
+
145
+ #citation-button textarea {
146
+ font-size: 16px !important;
147
+ }
148
+
149
+ #citation-button > label > button {
150
+ margin: 6px;
151
+ transform: scale(1.3);
152
+ }
153
+
154
+ #search-bar-table-box > div:first-child {
155
+ background: none;
156
+ border: none;
157
+ }
158
+
159
+ #search-bar {
160
+ padding: 0px;
161
+ }
162
+
163
+ """
164
+
165
+ # ToDO: markdown-text不好使...
166
+ # archive-table table thead th { font-size: 14px; font-weight: 400}
167
+
168
+ # /* 让表格遵守列宽、并能横向滚动 */
169
+ # #
src/display/formatting.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def model_hyperlink(model_link, code_link, model_name):
2
+ if model_link == "":
3
+ return model_name
4
+ # return f'<a target="_blank">{model_name}</a>'
5
+ # return f'<a target="_blank" href="{link}" rel="noopener noreferrer">{model_name}</a>'
6
+ else:
7
+ model_url = f'<a target="_blank" href="{model_link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
8
+ if code_link == "":
9
+ return model_url
10
+ else:
11
+ code_url = f'<a target="_blank" href="{code_link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">code</a>'
12
+ return f"{model_url} ({code_url})"
13
+ # return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a> | ' \
14
+ # f'<a target="_blank" href="https://www.google.com" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}_link2</a>'
15
+
16
+
17
+ def make_clickable_model(model_name):
18
+ link = f"https://huggingface.co/{model_name}"
19
+ return model_hyperlink(link, model_name)
20
+
21
+
22
+ def styled_error(error):
23
+ return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
24
+
25
+
26
+ def styled_warning(warn):
27
+ return f"<p style='color: orange; font-size: 20px; text-align: center;'>{warn}</p>"
28
+
29
+
30
+ def styled_message(message):
31
+ return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
32
+
33
+
34
+ def has_no_nan_values(df, columns):
35
+ return df[columns].notna().all(axis=1)
36
+
37
+
38
+ def has_nan_values(df, columns):
39
+ return df[columns].isna().any(axis=1)
src/display/utils.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, make_dataclass
2
+ from enum import Enum
3
+
4
+ import pandas as pd
5
+
6
+ from src.about import Tasks
7
+
8
+ def fields(raw_class):
9
+ return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
10
+
11
+
12
+ # These classes are for user facing column names,
13
+ # to avoid having to change them all around the code
14
+ # when a modif is needed
15
+ @dataclass
16
+ class ColumnContent:
17
+ name: str
18
+ type: str
19
+ displayed_by_default: bool
20
+ hidden: bool = False
21
+ never_hidden: bool = False
22
+
23
+ ## Leaderboard columns
24
+ archive_info_dict = []
25
+
26
+ archive_info_dict.append(["dataset", ColumnContent, ColumnContent("dataset", "markdown", True, never_hidden=True)])
27
+ archive_info_dict.append(["unique_id", ColumnContent, ColumnContent("unique_id", "str", True, never_hidden=True)])
28
+ archive_info_dict.append(["freq", ColumnContent, ColumnContent("freq", "str", True, never_hidden=True)])
29
+ archive_info_dict.append(["domain", ColumnContent, ColumnContent("domain", "str", True, never_hidden=True)])
30
+ # Raw features
31
+ archive_info_dict.append(["x_acf1", ColumnContent, ColumnContent("x_acf1", "number", False, False)])
32
+ archive_info_dict.append(["x_acf10", ColumnContent, ColumnContent("x_acf10", "number", False, False)])
33
+ archive_info_dict.append(["lumpiness", ColumnContent, ColumnContent("lumpiness", "number", False, False)])
34
+ archive_info_dict.append(["stability", ColumnContent, ColumnContent("stability", "number", False, False)])
35
+ archive_info_dict.append(["hurst", ColumnContent, ColumnContent("hurst", "number", False, False)])
36
+ archive_info_dict.append(["entropy", ColumnContent, ColumnContent("entropy", "number", False, False)])
37
+ # Trend features
38
+ archive_info_dict.append(["trend", ColumnContent, ColumnContent("trend_strength", "number", False, False)])
39
+ archive_info_dict.append(["trend_crossing_point_ratio", ColumnContent, ColumnContent("trend_xpoint_ratio", "number", False, False)])
40
+ archive_info_dict.append(["trend_stability", ColumnContent, ColumnContent("trend_stability", "number", False, False)])
41
+ archive_info_dict.append(["trend_lumpiness", ColumnContent, ColumnContent("trend_lumpiness", "number", False, False)])
42
+ archive_info_dict.append(["trend_hurst", ColumnContent, ColumnContent("trend_hurst", "number", False, False)])
43
+ archive_info_dict.append(["trend_entropy", ColumnContent, ColumnContent("trend_entropy", "number", False, False)])
44
+ # Seasonal features
45
+ archive_info_dict.append(["e_acf1", ColumnContent, ColumnContent("e_acf1", "number", False, False)])
46
+ archive_info_dict.append(["e_acf10", ColumnContent, ColumnContent("e_acf10", "number", False, False)])
47
+ archive_info_dict.append(["e_entropy", ColumnContent, ColumnContent("e_entropy", "number", False, False)])
48
+ archive_info_dict.append(["e_hurst", ColumnContent, ColumnContent("e_hurst", "number", False, False)])
49
+ archive_info_dict.append(["e_lumpiness", ColumnContent, ColumnContent("e_lumpiness", "number", False, False)])
50
+ archive_info_dict.append(["e_outlier_ratio", ColumnContent, ColumnContent("e_outlier_ratio", "number", False, False)])
51
+ # Remainder features
52
+ archive_info_dict.append(["seasonal_strength", ColumnContent, ColumnContent("seasonal_strength", "number", False, False)])
53
+ archive_info_dict.append(["seasonality_corr", ColumnContent, ColumnContent("seasonality_corr", "number", False, False)])
54
+ archive_info_dict.append(["seasonal_stability", ColumnContent, ColumnContent("seasonal_stability", "number", False, False)])
55
+ archive_info_dict.append(["seasonal_lumpiness", ColumnContent, ColumnContent("seasonal_lumpiness", "number", False, False)])
56
+ archive_info_dict.append(["seasonal_hurst", ColumnContent, ColumnContent("seasonal_hurst", "number", False, False)])
57
+ archive_info_dict.append(["seasonal_entropy", ColumnContent, ColumnContent("seasonal_entropy", "number", False, False)])
58
+ # Statistics
59
+ archive_info_dict.append(["mean", ColumnContent, ColumnContent("mean", "number", False, False)])
60
+ archive_info_dict.append(["std", ColumnContent, ColumnContent("std", "number", False, False)])
61
+ archive_info_dict.append(["missing_rate", ColumnContent, ColumnContent("missing_rate", "number", False, False)])
62
+ archive_info_dict.append(["length", ColumnContent, ColumnContent("length", "number", False, False)])
63
+ archive_info_dict.append(["period1", ColumnContent, ColumnContent("period1", "number", False, False)])
64
+ archive_info_dict.append(["period2", ColumnContent, ColumnContent("period2", "number", False, False)])
65
+ archive_info_dict.append(["period3", ColumnContent, ColumnContent("period3", "number", False, False)])
66
+ archive_info_dict.append(["p_strength1", ColumnContent, ColumnContent("p_strength1", "number", False, False)])
67
+ archive_info_dict.append(["p_strength2", ColumnContent, ColumnContent("p_strength2", "number", False, False)])
68
+ archive_info_dict.append(["p_strength3", ColumnContent, ColumnContent("p_strength3", "number", False, False)])
69
+
70
+ ArchiveInfoColumn = make_dataclass("ArchiveInfoColumn", archive_info_dict, frozen=True)
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
# Column spec for the model-properties table: [field_name, type, ColumnContent default].
model_info_dict = [
    # Always-visible identification columns.
    ["model_type_symbol", ColumnContent, ColumnContent("T", "str", True, never_hidden=True)],
    ["model", ColumnContent, ColumnContent("model", "markdown", True, never_hidden=True)],
    # Model information (hidden by default).
    ["model_type", ColumnContent, ColumnContent("Type", "str", False, True)],
    ["precision", ColumnContent, ColumnContent("Precision", "str", False, True)],
    ["license", ColumnContent, ColumnContent("Hub License", "str", False, True)],
    ["params", ColumnContent, ColumnContent("#Params (B)", "number", False, True)],
    ["likes", ColumnContent, ColumnContent("Hub ❤️", "number", False, True)],
    ["still_on_hub", ColumnContent, ColumnContent("Available on the hub", "bool", False)],
    ["org", ColumnContent, ColumnContent("Organization", "str", True, hidden=False)],
    ["testdata_leakage", ColumnContent, ColumnContent("TestData Leakage", "str", True, hidden=False)],
]

# We use make_dataclass to dynamically fill the scores from Tasks.
ModelInfoColumn = make_dataclass("ModelInfoColumn", model_info_dict, frozen=True)
96
+
97
+ ## For the queue columns in the submission tab
98
@dataclass(frozen=True)
class EvalQueueColumn:  # Queue column
    # NOTE: attributes are intentionally left unannotated, so @dataclass does
    # NOT treat them as dataclass fields — they remain plain class attributes.
    # Presumably they are read via the project's own `fields()` helper rather
    # than dataclasses.fields (which would return nothing here) — TODO confirm.
    # NOTE(review): weight_type passes the string "Original" where the other
    # columns pass a bool as the third ColumnContent argument — confirm intended.
    model = ColumnContent("model", "markdown", True)
    revision = ColumnContent("revision", "str", True)
    private = ColumnContent("private", "bool", True)
    precision = ColumnContent("precision", "str", True)
    weight_type = ColumnContent("weight_type", "str", "Original")
    status = ColumnContent("status", "str", True)
106
+
107
+ ## All the model information that we might need
108
@dataclass
class ModelDetails:
    """Display metadata for one model category (used as Enum values below)."""

    name: str
    display_name: str = ""
    symbol: str = ""  # emoji
113
+
114
+
115
class ModelType(Enum):
    """Model categories shown on the leaderboard.

    Each member's value is a ModelDetails carrying the display name and the
    emoji symbol used in the UI.
    """

    PT = ModelDetails(name="🟢 pretrained", symbol="🟢")
    ZT = ModelDetails(name="🔴 zero-shot", symbol="🔴")
    FT = ModelDetails(name="🟣 fine-tuned", symbol="🟣")
    AG = ModelDetails(name="🟡 agentic", symbol="🟡")
    DL = ModelDetails(name="🔷 deep-learning", symbol="🔷")
    ST = ModelDetails(name="🔶 statistical", symbol="🔶")

    Unknown = ModelDetails(name="", symbol="?")

    def to_str(self, separator=" "):
        """Return "<symbol><separator><name>" for UI display."""
        return f"{self.value.symbol}{separator}{self.value.name}"

    @staticmethod
    def from_str(type):
        """Parse a free-form type string (keyword or emoji) into a ModelType.

        Returns ModelType.Unknown when nothing matches.

        Bug fix: the emoji fallbacks now match the symbols declared on the
        members above — previously FT matched "🔶" (ST's symbol), DL matched
        "🟦" (an emoji no member uses), and ST matched "🟣" (FT's symbol), so
        emoji-only inputs were misclassified.
        """
        if "fine-tuned" in type or "🟣" in type:
            return ModelType.FT
        if "pretrained" in type or "🟢" in type:
            return ModelType.PT
        if "zero-shot" in type or "🔴" in type:
            return ModelType.ZT
        if "agentic" in type or "🟡" in type:
            return ModelType.AG
        if "deep-learning" in type or "🔷" in type:
            return ModelType.DL
        if "statistical" in type or "🔶" in type:
            return ModelType.ST
        return ModelType.Unknown
144
+
145
class WeightType(Enum):
    """How a submitted checkpoint's weights relate to the base model."""

    Adapter = ModelDetails("Adapter")
    Original = ModelDetails("Original")
    Delta = ModelDetails("Delta")
149
+
150
class Precision(Enum):
    """Numeric precision of model weights."""

    float16 = ModelDetails("float16")
    bfloat16 = ModelDetails("bfloat16")
    Unknown = ModelDetails("?")

    @staticmethod
    def from_str(precision):
        """Map a precision string (optionally "torch."-prefixed) to a member.

        Marked @staticmethod for consistency with ModelType.from_str, and so
        the helper also behaves correctly when accessed through an instance
        (previously a plain function, it would have received the member as
        its argument when called on one).
        """
        if precision in ["torch.float16", "float16"]:
            return Precision.float16
        if precision in ["torch.bfloat16", "bfloat16"]:
            return Precision.bfloat16
        return Precision.Unknown
161
+
162
# Column selection
# Model-info columns shown by default (hidden ones excluded).
# NOTE(review): `fields` here is presumably the project's own helper that
# yields the ColumnContent defaults — dataclasses.fields would not expose
# `.hidden` and would see no fields on the unannotated EvalQueueColumn.
MODEL_INFO_COLS = [c.name for c in fields(ModelInfoColumn) if not c.hidden]

# Column names and gradio column types for the evaluation-queue table.
EVAL_COLS = [c.name for c in fields(EvalQueueColumn)]
EVAL_TYPES = [c.type for c in fields(EvalQueueColumn)]

# One benchmark column per task.
BENCHMARK_COLS = [t.value.col_name for t in Tasks]
169
+
src/hf_config.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Hub Configuration and Helper Functions
3
+
4
+ This module provides configuration and utilities for loading data from HuggingFace Hub.
5
+ The data is cached locally after first download, so subsequent accesses are fast.
6
+ """
7
+
8
+ import os
9
+ from functools import lru_cache
10
+ from pathlib import Path
11
+
12
+ from huggingface_hub import snapshot_download
13
+
14
+ # =============================================================================
15
+ # Configuration
16
+ # =============================================================================
17
+
18
+ # HuggingFace Dataset repository ID
19
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "TIME-benchmark/TIME-1.0")
20
+
21
+ # HuggingFace token (set via environment variable for security)
22
+ # In HuggingFace Space, set this in Settings -> Repository secrets
23
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
24
+
25
+ # Whether to use HuggingFace Hub (True) or local files (False)
26
+ # Set to False for local development with local data
27
+ USE_HF_HUB = os.environ.get("USE_HF_HUB", "true").lower() == "true"
28
+
29
+ # Local cache directory for HF Hub downloads
30
+ HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", None) # None uses default ~/.cache/huggingface
31
+
32
+ # Local data paths (used when USE_HF_HUB=false)
33
+ # Set these environment variables to specify custom local paths
34
+ LOCAL_RESULTS_PATH = os.environ.get("LOCAL_RESULTS_PATH", None) # Path to output/results
35
+ LOCAL_FEATURES_PATH = os.environ.get("LOCAL_FEATURES_PATH", None) # Path to output/features
36
+ LOCAL_CONFIG_PATH = os.environ.get("LOCAL_CONFIG_PATH", None) # Path to config directory
37
+ LOCAL_DATASETS_PATH = os.environ.get("LOCAL_DATASETS_PATH", None) # Path to data/hf_dataset
38
+
39
+ # =============================================================================
40
+ # Helper Functions
41
+ # =============================================================================
42
+
43
@lru_cache(maxsize=1)
def download_results_snapshot() -> Path:
    """Return the local path of the results directory.

    With USE_HF_HUB enabled, the `output/results` subtree of HF_REPO_ID is
    downloaded once (lru_cache plus the hub's disk cache) and the cached
    path is returned. Otherwise a local development path is resolved.

    Returns:
        Path: Local path to the results directory.
    """
    if not USE_HF_HUB:
        # Local development resolution order: explicit env override first,
        # then the relative path, then a hard-coded machine-specific fallback.
        if LOCAL_RESULTS_PATH:
            results_dir = Path(LOCAL_RESULTS_PATH)
        else:
            results_dir = Path("../output/results")
            if not results_dir.exists():
                results_dir = Path("/home/eee/qzz/TIME/output/results")
        if not results_dir.exists():
            print(f"⚠️ Warning: Local results path does not exist: {results_dir}")
        return results_dir

    print(f"📥 Downloading results from HuggingFace Hub: {HF_REPO_ID}")

    snapshot_root = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        allow_patterns=["output/results/**"],
        cache_dir=HF_CACHE_DIR,
    )

    results_path = Path(snapshot_root) / "output" / "results"
    print(f"✅ Results cached at: {results_path}")
    return results_path
78
+
79
+
80
@lru_cache(maxsize=1)
def download_datasets_snapshot() -> Path:
    """Return the local path of the hf_dataset directory.

    With USE_HF_HUB enabled, the `data/hf_dataset` subtree of HF_REPO_ID is
    downloaded once (lru_cache plus the hub's disk cache) and the cached
    path is returned. Otherwise a local development path is resolved.

    Returns:
        Path: Local path to the hf_dataset directory.
    """
    if not USE_HF_HUB:
        # Local development resolution order: explicit env override first,
        # then the relative path, then a hard-coded machine-specific fallback.
        if LOCAL_DATASETS_PATH:
            datasets_dir = Path(LOCAL_DATASETS_PATH)
        else:
            datasets_dir = Path("../data/hf_dataset")
            if not datasets_dir.exists():
                datasets_dir = Path("/home/eee/qzz/TIME/data/hf_dataset")
        if not datasets_dir.exists():
            print(f"⚠️ Warning: Local datasets path does not exist: {datasets_dir}")
        return datasets_dir

    print(f"📥 Downloading datasets from HuggingFace Hub: {HF_REPO_ID}")

    snapshot_root = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        allow_patterns=["data/hf_dataset/**"],
        cache_dir=HF_CACHE_DIR,
    )

    datasets_path = Path(snapshot_root) / "data" / "hf_dataset"
    print(f"✅ Datasets cached at: {datasets_path}")
    return datasets_path
115
+
116
+
117
def download_config_snapshot() -> Path:
    """
    Get the config directory containing datasets.yaml.

    The config is bundled with the timebench package, so no download is
    needed when the package is installed. Resolution order:
      1. Config shipped inside the installed `timebench` package.
      2. LOCAL_CONFIG_PATH environment variable.
      3. ../config, then the hard-coded development fallback.

    Returns:
        Path: Local path to the config directory.

    Raises:
        FileNotFoundError: If no config directory can be located.
    """
    # Preferred source: config bundled with the installed timebench package.
    try:
        from timebench.evaluation.data import DEFAULT_CONFIG_PATH
        config_path = DEFAULT_CONFIG_PATH.parent  # Get the config directory
        if config_path.exists():
            return config_path
    except ImportError as err:
        # Bug fix: previously this printed the ImportError *class* object
        # (f"{ImportError}") instead of the caught exception instance, so the
        # log never showed why the import failed.
        print(f"❌ ImportError: {err}, using local config")

    # Fallback: local development paths.
    # Priority: 1) LOCAL_CONFIG_PATH env var, 2) ../config, 3) hard-coded path
    if LOCAL_CONFIG_PATH:
        local_path = Path(LOCAL_CONFIG_PATH)
    else:
        local_path = Path("../config")
        if not local_path.exists():
            local_path = Path("/home/eee/qzz/TIME/config")

    if local_path.exists():
        print(f"📁 Using local config: {local_path}")
        return local_path

    raise FileNotFoundError(
        "Config directory not found. Please ensure timebench is installed, "
        "set USE_HF_HUB=false for local development, "
        "or set LOCAL_CONFIG_PATH environment variable to point to your config directory."
    )
156
+
157
+
158
def get_results_root() -> Path:
    """Get the root path for results (handles both HF Hub and local)."""
    # Thin public alias; the underlying call is memoized via lru_cache.
    return download_results_snapshot()
161
+
162
+
163
def get_datasets_root() -> Path:
    """Get the root path for hf_dataset (handles both HF Hub and local)."""
    # Thin public alias; the underlying call is memoized via lru_cache.
    return download_datasets_snapshot()
166
+
167
+
168
def get_config_root() -> Path:
    """Get the root path for config (handles both HF Hub and local)."""
    # Thin public alias; resolution happens in download_config_snapshot.
    return download_config_snapshot()
171
+
172
+
173
def get_features_root() -> Path:
    """
    Get the root path for features (handles both HF Hub and local).

    Features are stored at output/features/{dataset}/{freq}/test.csv, in the
    same repository as the results.

    NOTE(review): unlike the results/datasets snapshots this function is not
    wrapped in lru_cache, so the hub check runs on every call (the hub's own
    disk cache keeps repeats cheap) — confirm this is intended.

    Returns:
        Path: Local path to the features directory.
    """
    if not USE_HF_HUB:
        # Local development resolution order: explicit env override first,
        # then the relative path, then a hard-coded machine-specific fallback.
        if LOCAL_FEATURES_PATH:
            features_dir = Path(LOCAL_FEATURES_PATH)
        else:
            features_dir = Path("../output/features")
            if not features_dir.exists():
                features_dir = Path("/home/eee/qzz/TIME/output/features")
        if not features_dir.exists():
            print(f"⚠️ Warning: Local features path does not exist: {features_dir}")
        return features_dir

    # For HF Hub, features are in the same repo as results
    print(f"📥 Downloading features from HuggingFace Hub: {HF_REPO_ID}")

    snapshot_root = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        allow_patterns=["output/features/**"],
        cache_dir=HF_CACHE_DIR,
    )

    features_path = Path(snapshot_root) / "output" / "features"
    print(f"✅ Features cached at: {features_path}")
    return features_path
209
+
210
+
211
def clear_cache():
    """Clear the LRU cache to force re-download on next access."""
    # Only the results and datasets snapshots carry @lru_cache; config and
    # features resolution have no in-process cache to clear.
    download_results_snapshot.cache_clear()
    download_datasets_snapshot.cache_clear()
215
+
216
+
217
+ # =============================================================================
218
+ # Initialization - Download data at module import
219
+ # =============================================================================
220
+
221
def initialize_data():
    """
    Initialize by downloading all necessary data.
    Call this at app startup to pre-download data.

    Returns:
        tuple[Path, Path]: (results_root, config_root).
    """
    print("🚀 Initializing TIME Leaderboard data...")

    # Results are mandatory for every leaderboard table.
    results_root = get_results_root()
    print(f" Results: {results_root}")

    # Config is required to resolve per-dataset settings.
    config_root = get_config_root()
    print(f" Config: {config_root}")

    # Datasets themselves are fetched lazily when visualization needs them,
    # which keeps the initial load fast.

    print("✅ Initialization complete!")
    return results_root, config_root
241
+
src/leaderboard.py ADDED
@@ -0,0 +1,1085 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import pandas as pd
4
+ import numpy as np
5
+ import json
6
+ import gradio as gr
7
+ from typing import List, Tuple, Optional
8
+ from scipy import stats
9
+ from src.about import DATASETS_DF, OVERALL_TABLE_COLUMNS, ALL_MODELS, RESULTS_ROOT, DATASET_DISPLAY_TO_ID
10
+ from src.hf_config import get_datasets_root, get_config_root
11
+ from src.utils import normalize_by_seasonal_naive
12
+ # FEATURES_DF, FEATURES_BOOL_DF, VARIATES_DF VARIATE_COLUMNS
13
+ import ast
14
+ from timebench.evaluation.data import Dataset, get_dataset_settings, load_dataset_config
15
+ from pathlib import Path
16
+
17
def resolve_dataset_id(display_name: str) -> str:
    """
    Convert a display name to dataset_id.

    Args:
        display_name: Either a display_name from UI or a dataset_id directly

    Returns:
        dataset_id in format "dataset/freq"
    """
    # Known display names resolve through the mapping table; anything else is
    # assumed to already be a dataset_id and is passed through unchanged.
    return DATASET_DISPLAY_TO_ID.get(display_name, display_name)
32
+
33
+
34
def find_dataset_term_path(results_root, model_name, display_name):
    """
    Find the dataset_term path ("dataset/freq") for a model's results.

    Path structure: results/{model_name}/{dataset}/{freq}/{horizon}/

    Args:
        results_root: Root directory for results
        model_name: Model name
        display_name: UI display name ("Traffic" or "Traffic/15T") or a
            dataset_id directly

    Returns:
        "dataset/freq" string (e.g. "Traffic/15T"), or None when no matching
        results directory exists.
    """
    model_dir = os.path.join(results_root, model_name)
    if not os.path.exists(model_dir):
        return None

    # Resolve UI display name to a dataset_id.
    dataset_id = resolve_dataset_id(display_name)

    def _has_any_horizon(freq_path):
        # A freq directory counts as valid results when at least one horizon
        # subdirectory contains a config.json.
        return any(
            os.path.exists(os.path.join(freq_path, h, "config.json"))
            for h in ("short", "medium", "long")
        )

    if "/" in dataset_id:
        # Fully-qualified id ("dataset/freq"): check the directory directly.
        dataset_name, freq = dataset_id.split("/", 1)
        freq_path = os.path.join(model_dir, dataset_name, freq)
        if os.path.isdir(freq_path) and _has_any_horizon(freq_path):
            return dataset_id
        return None

    # Legacy fallback: bare dataset name — return the first freq with results.
    for entry in os.listdir(model_dir):
        dataset_dir = os.path.join(model_dir, entry)
        if not os.path.isdir(dataset_dir) or entry != dataset_id:
            continue
        for freq_dir in os.listdir(dataset_dir):
            freq_path = os.path.join(dataset_dir, freq_dir)
            if os.path.isdir(freq_path) and _has_any_horizon(freq_path):
                return f"{entry}/{freq_dir}"

    return None
86
+
87
def load_test_windows(display_name, horizon, model_name="moirai_small", series=None, variate=None, window_id=None, parse_series=False):
    """
    Load test window results from TIME NPZ files.

    Args:
        display_name: Dataset display name from UI (will be converted to dataset_id)
        horizon: Horizon name (short, medium, long)
        model_name: Model name
        series: Optional series name (string) or index (int/string) to filter
        variate: Optional variate name (string) or index (int/string) to filter
        window_id: Optional window_id to filter
        parse_series: If True, include label and quantile predictions as lists
            (currently a no-op — see TODO in the body)

    Returns:
        pd.DataFrame with columns: series_name, variate_name, window_id, MASE, CRPS, MAE, MSE,
        model, series_idx, variate_idx; or None when the dataset/model has no results
        or every row was filtered out.
    """
    results_root = str(RESULTS_ROOT)

    # Find the dataset_term directory (handles display_name -> dataset_id conversion)
    dataset_term = find_dataset_term_path(results_root, model_name, display_name)
    if dataset_term is None:
        return None

    horizon_dir = os.path.join(results_root, model_name, dataset_term, horizon)
    metrics_path = os.path.join(horizon_dir, "metrics.npz")

    # Load the per-window metric arrays.
    metrics = np.load(metrics_path)

    # All metric arrays share this shape; MASE is used to read it.
    mase_arr = metrics["MASE"]  # (num_series, num_windows, num_variates)
    num_series, num_windows, num_variates = mase_arr.shape

    # Load Dataset to get actual names; on any failure we fall back to
    # stringified indices below.
    series_names = None
    variate_names = None
    try:
        # Use HF config to get dataset root (handles both local and HF Hub)
        hf_dataset_root = str(get_datasets_root())

        if os.path.exists(hf_dataset_root):
            # Use HF config to get config root
            config_root = get_config_root()
            config_path = config_root / "datasets.yaml"

            config = load_dataset_config(config_path) if config_path.exists() else {}
            settings = get_dataset_settings(dataset_term, horizon, config)

            prediction_length = settings.get("prediction_length")
            test_length = settings.get("test_length")

            # Load dataset with storage_path parameter
            dataset_obj = Dataset(
                name=dataset_term,
                term=horizon,
                prediction_length=prediction_length,
                test_length=test_length,
                storage_path=hf_dataset_root,
            )

            # Get series names (item_id); fall back to per-row lookup when the
            # column is absent.
            if "item_id" in dataset_obj.hf_dataset.column_names:
                series_names = dataset_obj.hf_dataset["item_id"]
            else:
                series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                                for i in range(len(dataset_obj.hf_dataset))]

            # Get variate names
            variate_names = dataset_obj.get_variate_names()
            if variate_names is None:
                # Univariate mode: variate names are same as series names
                variate_names = series_names
    except Exception as e:
        print(f"Error loading Dataset for names: {e}")

    # Create name to index mappings
    series_name_to_idx = {}
    variate_name_to_idx = {}
    if series_names is not None:
        series_name_to_idx = {name: idx for idx, name in enumerate(series_names)}
    if variate_names is not None:
        variate_name_to_idx = {name: idx for idx, name in enumerate(variate_names)}

    # Convert series and variate filters to indices: names take precedence,
    # then a numeric string/int; anything else leaves the filter disabled.
    series_idx_filter = None
    if series is not None:
        if series in series_name_to_idx:
            series_idx_filter = series_name_to_idx[series]
        else:
            # Try as index
            try:
                series_idx_filter = int(series)
            except ValueError:
                pass

    variate_idx_filter = None
    if variate is not None:
        if variate in variate_name_to_idx:
            variate_idx_filter = variate_name_to_idx[variate]
        else:
            # Try as index
            try:
                variate_idx_filter = int(variate)
            except ValueError:
                pass

    # Build DataFrame row by row
    rows = []
    for series_idx in range(num_series):
        # Filter by series if specified
        if series_idx_filter is not None:
            if series_idx != series_idx_filter:
                continue

        series_name = series_names[series_idx] if series_names is not None else str(series_idx)

        for window_idx in range(num_windows):
            # Filter by window_id if specified
            if window_id is not None:
                if window_idx != int(window_id):
                    continue

            for variate_idx in range(num_variates):
                # Filter by variate if specified
                if variate_idx_filter is not None:
                    if variate_idx != variate_idx_filter:
                        continue

                variate_name = variate_names[variate_idx] if variate_names is not None else str(variate_idx)

                row = {
                    "series_name": series_name,
                    "variate_name": variate_name,
                    "window_id": window_idx,
                    "MASE": float(mase_arr[series_idx, window_idx, variate_idx]),
                    "CRPS": float(metrics["CRPS"][series_idx, window_idx, variate_idx]),
                    "MAE": float(metrics["MAE"][series_idx, window_idx, variate_idx]),
                    "MSE": float(metrics["MSE"][series_idx, window_idx, variate_idx]),
                    "model": model_name,
                    "series_idx": series_idx,  # Keep for reference
                    "variate_idx": variate_idx,  # Keep for reference
                }

                # Add label and quantiles if requested
                if parse_series:
                    # Get ground truth from predictions (we need to load it separately)
                    # TIME doesn't save ground_truth in predictions.npz, so we skip for now
                    # TODO: Need to handle ground truth loading
                    pass

                rows.append(row)

    if not rows:
        return None

    df = pd.DataFrame(rows)

    # If parse_series, we would need ground truth data
    # For now, return without series data
    return df
249
+
250
def get_overall_leaderboard(df_datasets: pd.DataFrame, metric: str = "MASE") -> pd.DataFrame:
    """
    Compute overall leaderboard across datasets by normalizing metrics by Seasonal Naive
    and aggregating with geometric mean.

    Args:
        df_datasets (pd.DataFrame): Dataset-level results, must include
            ["model", "dataset_id", "horizon", "MASE", "CRPS", "MASE_rank", "CRPS_rank"].
        metric (str): Metric to use for sorting. Defaults to "MASE".

    Returns:
        pd.DataFrame: Leaderboard with:
            - MASE (norm.), CRPS (norm.): Geometric mean of Seasonal Naive-normalized values
            - MASE_rank, CRPS_rank: Average rank across configurations (from original data)
            Sorted by the chosen metric (ascending — lower is better).
    """
    # Empty input or an unknown sort metric yields an empty leaderboard.
    if df_datasets.empty:
        return pd.DataFrame(columns=OVERALL_TABLE_COLUMNS)

    if metric not in df_datasets.columns:
        return pd.DataFrame(columns=OVERALL_TABLE_COLUMNS)

    # Step 1: Normalize MASE and CRPS by Seasonal Naive per (dataset_id, horizon)
    df_normalized = normalize_by_seasonal_naive(
        df_datasets,
        baseline_model="seasonal_naive",
        metrics=["MASE", "CRPS"],
        groupby_cols=["dataset_id", "horizon"],
    )

    if df_normalized.empty:
        # Fall back to original behavior if normalization fails: arithmetic
        # mean over raw metrics, still presented under the "(norm.)" names.
        print("[get_overall_leaderboard] Warning: normalization failed, using arithmetic mean")
        leaderboard = (
            df_datasets.groupby(["model"])
            .mean(numeric_only=True)
            .reset_index()
        )
        # Rename columns: MASE -> MASE (norm.), CRPS -> CRPS (norm.)
        if "MASE" in leaderboard.columns:
            leaderboard = leaderboard.rename(columns={"MASE": "MASE (norm.)"})
        if "CRPS" in leaderboard.columns:
            leaderboard = leaderboard.rename(columns={"CRPS": "CRPS (norm.)"})

        # Adjust metric name for sorting
        sort_metric = metric
        if metric == "MASE":
            sort_metric = "MASE (norm.)"
        elif metric == "CRPS":
            sort_metric = "CRPS (norm.)"

        if sort_metric in leaderboard.columns:
            leaderboard = leaderboard.sort_values(by=sort_metric, ascending=True).reset_index(drop=True)
        else:
            leaderboard = leaderboard.sort_values(by=metric, ascending=True).reset_index(drop=True)

        # Define column order
        col_order = ["model", "MASE (norm.)", "CRPS (norm.)", "MASE_rank", "CRPS_rank"]
        col_order = [col for col in col_order if col in leaderboard.columns]
        leaderboard = leaderboard[col_order]
        leaderboard = leaderboard.round(3)
        return leaderboard

    # Step 2: Aggregate normalized MASE and CRPS with geometric mean
    # Filter out NaN values for geometric mean computation
    def gmean_with_nan(x):
        """Compute geometric mean, ignoring NaN values."""
        valid = x.dropna()
        if len(valid) == 0:
            return np.nan
        return stats.gmean(valid)

    normalized_metrics = (
        df_normalized.groupby("model")[["MASE", "CRPS"]]
        .agg(gmean_with_nan)
        .reset_index()
    )

    # Rename columns: MASE -> MASE (norm.), CRPS -> CRPS (norm.)
    normalized_metrics = normalized_metrics.rename(columns={
        "MASE": "MASE (norm.)",
        "CRPS": "CRPS (norm.)"
    })

    # Step 3: Compute average ranks from original data (pre-normalized)
    # Ranks should be computed on original metrics, which is already done in about.py
    if "MASE_rank" in df_datasets.columns and "CRPS_rank" in df_datasets.columns:
        # Use the same configurations that were used in normalization
        # (only those with Seasonal Naive baseline) so the rank average covers
        # exactly the tasks that contributed to the normalized metrics.
        df_with_baseline = df_datasets[
            df_datasets.set_index(["dataset_id", "horizon"]).index.isin(
                df_normalized.set_index(["dataset_id", "horizon"]).index.unique()
            )
        ]
        avg_ranks = (
            df_with_baseline.groupby("model")[["MASE_rank", "CRPS_rank"]]
            .mean()
            .reset_index()
        )
        # Merge normalized metrics with average ranks
        leaderboard = normalized_metrics.merge(avg_ranks, on="model", how="left")
    else:
        leaderboard = normalized_metrics

    # Step 4: Sort by chosen metric (adjust metric name if needed)
    sort_metric = metric
    if metric == "MASE":
        sort_metric = "MASE (norm.)"
    elif metric == "CRPS":
        sort_metric = "CRPS (norm.)"

    if sort_metric in leaderboard.columns:
        leaderboard = leaderboard.sort_values(by=sort_metric, ascending=True).reset_index(drop=True)
    else:
        # Fallback: sort by first available metric
        leaderboard = leaderboard.sort_values(by=leaderboard.columns[1], ascending=True).reset_index(drop=True)

    # Step 5: Select and order columns
    col_order = ["model", "MASE (norm.)", "CRPS (norm.)", "MASE_rank", "CRPS_rank"]
    col_order = [col for col in col_order if col in leaderboard.columns]
    leaderboard = leaderboard[col_order]
    leaderboard = leaderboard.round(3)

    return leaderboard
374
+
375
+
376
+ def get_dataset_leaderboard(
377
+ display_name: str,
378
+ horizons: List[str],
379
+ metric: str = "MASE"
380
+ ) -> Tuple[str, pd.DataFrame]:
381
+ """
382
+ Return leaderboard for a specific dataset, averaged over the specified horizons.
383
+
384
+ Returns both original metrics and Seasonal Naive-normalized metrics in a single table.
385
+
386
+ Args:
387
+ display_name (str): The dataset display name selected by the user (from UI dropdown).
388
+ Will be converted to dataset_id for filtering.
389
+ horizons (List[str]): List of horizons to include (e.g., ["short", "medium"]).
390
+ If None, all horizons are used.
391
+ metric (str): The metric used for sorting. Defaults to "MASE".
392
+
393
+ Returns:
394
+ tuple:
395
+ str: A message string to display in the UI ("" if no error).
396
+ pd.DataFrame: Dataframe containing leaderboard with columns:
397
+ - model
398
+ - MASE, CRPS, MAE, MSE (original, arithmetic mean)
399
+ - MASE_norm, CRPS_norm, MAE_norm, MSE_norm (normalized, geometric mean)
400
+ - MASE_rank, CRPS_rank (average of per-task ranks)
401
+ """
402
+ if DATASETS_DF.empty:
403
+ return "No dataset results are available. Please check your results folder.", pd.DataFrame(columns=["model"])
404
+
405
+ # Convert display_name to dataset_id for filtering
406
+ dataset_id = resolve_dataset_id(display_name)
407
+
408
+ # Filter by dataset_id
409
+ df_filtered = DATASETS_DF[DATASETS_DF["dataset_id"] == dataset_id].copy()
410
+ if df_filtered.empty:
411
+ return f"No results found for dataset '{display_name}'.", pd.DataFrame(columns=["model"])
412
+
413
+ # Filter by horizon
414
+ if horizons is None or len(horizons) == 0:
415
+ horizons = df_filtered["horizon"].unique().tolist()
416
+ df_filtered = df_filtered[df_filtered["horizon"].isin(horizons)]
417
+ if df_filtered.empty:
418
+ return f"No results found for dataset '{display_name}' with horizons {horizons}.", pd.DataFrame(columns=["model"])
419
+
420
+ # Get dataset information (series count, variate count, freq)
421
+ dataset_info_msg = ""
422
+ try:
423
+ # Parse dataset_id to get freq
424
+ if "/" in dataset_id:
425
+ _, freq = dataset_id.split("/", 1)
426
+ else:
427
+ freq = "unknown"
428
+
429
+ # Load Dataset to get series and variate counts
430
+ hf_dataset_root = str(get_datasets_root())
431
+ config_root = get_config_root()
432
+ config_path = config_root / "datasets.yaml"
433
+
434
+ if os.path.exists(hf_dataset_root) and config_path.exists():
435
+ config = load_dataset_config(config_path)
436
+ settings = get_dataset_settings(dataset_id, horizons[0] if horizons else "short", config)
437
+
438
+ dataset_obj = Dataset(
439
+ name=dataset_id,
440
+ term=horizons[0] if horizons else "short",
441
+ prediction_length=settings.get("prediction_length"),
442
+ test_length=settings.get("test_length"),
443
+ storage_path=hf_dataset_root,
444
+ )
445
+
446
+ # Get series count
447
+ if "item_id" in dataset_obj.hf_dataset.column_names:
448
+ series_names = dataset_obj.hf_dataset["item_id"]
449
+ # Convert to list if it's an array/Series to avoid ambiguity in boolean check
450
+ if isinstance(series_names, (np.ndarray, pd.Series)):
451
+ series_names = list(series_names)
452
+ elif not isinstance(series_names, list):
453
+ series_names = list(series_names) if hasattr(series_names, '__iter__') else [series_names]
454
+ # Use len() check instead of boolean check to avoid ambiguity
455
+ if len(series_names) > 0:
456
+ num_series = len(set(series_names))
457
+ else:
458
+ num_series = len(dataset_obj.hf_dataset)
459
+ else:
460
+ num_series = len(dataset_obj.hf_dataset)
461
+
462
+ # Get variate count
463
+ variate_names = dataset_obj.get_variate_names()
464
+ if variate_names is not None:
465
+ num_variates = len(variate_names)
466
+ else:
467
+ # UTS: each series is one variate
468
+ num_variates = 1
469
+
470
+ dataset_info_msg = f"📊 Dataset Info: {num_series} series, {num_variates} variates, freq={freq}"
471
+ except Exception as e:
472
+ print(f"Error getting dataset info: {e}")
473
+ # If we can't get info, try to extract freq from dataset_id
474
+ if "/" in dataset_id:
475
+ _, freq = dataset_id.split("/", 1)
476
+ dataset_info_msg = f"📊 Dataset Info: freq={freq}"
477
+
478
+ metrics_list = ["MASE", "CRPS", "MAE", "MSE"]
479
+
480
+ # === Step 1: Compute original metrics (arithmetic mean) ===
481
+ original_df = (
482
+ df_filtered.groupby("model")[metrics_list]
483
+ .mean()
484
+ .reset_index()
485
+ )
486
+
487
+ # === Step 2: Compute normalized metrics (geometric mean of Seasonal Naive-normalized) ===
488
+ df_normalized = normalize_by_seasonal_naive(
489
+ df_filtered,
490
+ baseline_model="seasonal_naive",
491
+ metrics=metrics_list,
492
+ groupby_cols=["dataset_id", "horizon"],
493
+ )
494
+
495
+ # Helper function for geometric mean with NaN handling
496
+ def gmean_with_nan(x):
497
+ valid = x.dropna()
498
+ if len(valid) == 0:
499
+ return np.nan
500
+ return stats.gmean(valid)
501
+
502
+ if not df_normalized.empty:
503
+ normalized_df = (
504
+ df_normalized.groupby("model")[metrics_list]
505
+ .agg(gmean_with_nan)
506
+ .reset_index()
507
+ )
508
+ # Rename columns to * (norm.)
509
+ normalized_df = normalized_df.rename(columns={
510
+ "MASE": "MASE (norm.)",
511
+ "CRPS": "CRPS (norm.)",
512
+ "MAE": "MAE (norm.)",
513
+ "MSE": "MSE (norm.)",
514
+ })
515
+ else:
516
+ # If normalization fails, create empty normalized columns
517
+ normalized_df = original_df[["model"]].copy()
518
+ for col in ["MASE (norm.)", "CRPS (norm.)", "MAE (norm.)", "MSE (norm.)"]:
519
+ normalized_df[col] = np.nan
520
+
521
+ # Rename original columns to * (raw)
522
+ original_df = original_df.rename(columns={
523
+ "MASE": "MASE (raw)",
524
+ "CRPS": "CRPS (raw)",
525
+ "MAE": "MAE (raw)",
526
+ "MSE": "MSE (raw)",
527
+ })
528
+
529
+ # === Step 3: Compute average ranks from pre-computed per-task ranks ===
530
+ if "MASE_rank" in df_filtered.columns and "CRPS_rank" in df_filtered.columns:
531
+ # Use only configurations that have Seasonal Naive baseline (for consistency)
532
+ if not df_normalized.empty:
533
+ df_with_baseline = df_filtered[
534
+ df_filtered.set_index(["dataset_id", "horizon"]).index.isin(
535
+ df_normalized.set_index(["dataset_id", "horizon"]).index.unique()
536
+ )
537
+ ]
538
+ else:
539
+ df_with_baseline = df_filtered
540
+
541
+ ranks_df = (
542
+ df_with_baseline.groupby("model")[["MASE_rank", "CRPS_rank"]]
543
+ .mean()
544
+ .reset_index()
545
+ )
546
+ else:
547
+ ranks_df = original_df[["model"]].copy()
548
+ ranks_df["MASE_rank"] = np.nan
549
+ ranks_df["CRPS_rank"] = np.nan
550
+
551
+ # === Step 4: Combine all into one DataFrame ===
552
+ agg_df = original_df.merge(normalized_df, on="model", how="left")
553
+ agg_df = agg_df.merge(ranks_df, on="model", how="left")
554
+
555
+ # Sort by MASE (norm.) as requested
556
+ if "MASE (norm.)" in agg_df.columns:
557
+ agg_df = agg_df.sort_values(by="MASE (norm.)", ascending=True).reset_index(drop=True)
558
+ elif "MASE (raw)" in agg_df.columns:
559
+ # Fallback to MASE (raw) if normalized version not available
560
+ agg_df = agg_df.sort_values(by="MASE (raw)", ascending=True).reset_index(drop=True)
561
+ else:
562
+ # Final fallback: sort by first available metric column
563
+ if len(agg_df.columns) > 1:
564
+ agg_df = agg_df.sort_values(by=agg_df.columns[1], ascending=True).reset_index(drop=True)
565
+
566
+ # Define column order: model, * (norm.), * (raw), *_rank
567
+ cols_order = ["model",
568
+ "MASE (norm.)", "CRPS (norm.)", "MAE (norm.)", "MSE (norm.)",
569
+ "MASE (raw)", "CRPS (raw)", "MAE (raw)", "MSE (raw)",
570
+ "MASE_rank", "CRPS_rank"]
571
+ cols_to_return = [col for col in cols_order if col in agg_df.columns]
572
+
573
+ agg_df = agg_df[cols_to_return].round(3)
574
+
575
+ return dataset_info_msg, agg_df
576
+
577
+
578
+ def get_dataset_multilevel_leaderboard(display_name, series, variate, horizons, metric: str = "MASE"):
579
+ """
580
+ Get leaderboard based on dataset, series, and variate selections.
581
+
582
+ Logic:
583
+ 0. If only dataset selected (series="---", variate="---"): return dataset-level results
584
+ with both original and normalized metrics
585
+ 1. If series/variate selected: return only original metrics (MASE, CRPS, MAE, MSE)
586
+
587
+ Args:
588
+ display_name: Dataset display name from UI (will be converted to dataset_id)
589
+ series: Series name or "---" if not selected
590
+ variate: Variate name or "---" if not selected
591
+ horizons: List of horizons to include
592
+ metric: Metric for sorting
593
+
594
+ Returns:
595
+ tuple: (message, DataFrame)
596
+ """
597
+ # Case 0: Only dataset selected - return both original and normalized metrics
598
+ if (series is None or series == "---" or series == "") and (variate is None or variate == "---" or variate == ""):
599
+ return get_dataset_leaderboard(display_name, horizons, metric)
600
+
601
+ # Case 1: Series/Variate selected - return only original metrics
602
+ # Determine if dataset is UTS or MTS by checking if variate dropdown is enabled
603
+ results_root = str(RESULTS_ROOT)
604
+ model_name = ALL_MODELS[0]
605
+
606
+ dataset_term = find_dataset_term_path(results_root, model_name, display_name)
607
+ if dataset_term is None:
608
+ return f"Dataset '{display_name}' not found.", pd.DataFrame(columns=["model", "MASE", "CRPS", "MAE", "MSE"])
609
+
610
+ # Check if dataset is UTS or MTS
611
+ is_uts = False
612
+ try:
613
+ hf_dataset_root = str(get_datasets_root())
614
+
615
+ if os.path.exists(hf_dataset_root):
616
+ config_root = get_config_root()
617
+ config_path = config_root / "datasets.yaml"
618
+
619
+ config = load_dataset_config(config_path) if config_path.exists() else {}
620
+ settings = get_dataset_settings(dataset_term, horizons[0] if horizons else "short", config)
621
+
622
+ dataset_obj = Dataset(
623
+ name=dataset_term,
624
+ term=horizons[0] if horizons else "short",
625
+ prediction_length=settings.get("prediction_length"),
626
+ test_length=settings.get("test_length"),
627
+ storage_path=hf_dataset_root,
628
+ )
629
+
630
+ variate_names = dataset_obj.get_variate_names()
631
+ is_uts = (variate_names is None)
632
+ except Exception as e:
633
+ print(f"Error checking UTS/MTS: {e}")
634
+
635
+ # Collect data from all models and horizons
636
+ df_all = []
637
+ for model in ALL_MODELS:
638
+ for horizon in horizons:
639
+ series_filter = None if (series == "---" or series == "") else series
640
+ variate_filter = None if (variate == "---" or variate == "") else variate
641
+
642
+ if is_uts:
643
+ variate_filter = None
644
+
645
+ df_model = load_test_windows(
646
+ display_name, horizon, model,
647
+ series=series_filter,
648
+ variate=variate_filter,
649
+ window_id=None
650
+ )
651
+
652
+ if df_model is not None and not df_model.empty:
653
+ df_all.append(df_model)
654
+
655
+ if not df_all:
656
+ return f"⚠️ No results found for the selected filters.", pd.DataFrame(columns=["model", "MASE", "CRPS", "MAE", "MSE"])
657
+
658
+ # Combine all data
659
+ df_combined = pd.concat(df_all, ignore_index=True)
660
+
661
+ metrics_list = ["MASE", "CRPS", "MAE", "MSE"]
662
+
663
+ # Simple arithmetic mean across all windows (no normalization for series/variate level)
664
+ leaderboard = (
665
+ df_combined.groupby("model")[metrics_list]
666
+ .mean()
667
+ .reset_index()
668
+ )
669
+
670
+ leaderboard = leaderboard.round(3)
671
+
672
+ if metric not in leaderboard.columns:
673
+ return f"Metric '{metric}' not found.", pd.DataFrame(columns=["model"])
674
+
675
+ # Sort by metric
676
+ leaderboard = leaderboard.sort_values(by=metric, ascending=True).reset_index(drop=True)
677
+
678
+ return "", leaderboard
679
+
680
+
681
+
682
+ def get_window_leaderboard(display_name, series, variate, window_id, horizon, metric: str = "MASE"):
683
+ """
684
+ Get leaderboard for a specific test window.
685
+
686
+ Args:
687
+ display_name: Dataset display name from UI (will be converted to dataset_id)
688
+ series: Series name or index
689
+ variate: Variate name or index
690
+ window_id: Window index
691
+ horizon: Horizon name
692
+ metric: Metric for sorting
693
+ """
694
+ df_all = []
695
+ for model in ALL_MODELS:
696
+ df_model = load_test_windows(display_name, horizon, model, series=series, variate=variate, window_id=window_id)
697
+ if df_model is not None and not df_model.empty:
698
+ df_all.append(df_model)
699
+
700
+ if not df_all:
701
+ # Return empty DataFrame with expected columns if no data found
702
+ return pd.DataFrame(columns=["model", "MASE", "CRPS", "MAE", "MSE"])
703
+
704
+ df_all = pd.concat(df_all, ignore_index=True)
705
+
706
+ # metrics DataFrame
707
+ metrics_cols = ["model", "MASE", "CRPS", "MAE", "MSE"]
708
+ leaderboard = df_all[metrics_cols].reset_index(drop=True)
709
+
710
+ if metric not in leaderboard.columns:
711
+ return pd.DataFrame(columns=["model"])
712
+
713
+ # Round numeric columns to 3 decimal places
714
+ numeric_cols = leaderboard.select_dtypes(include=[np.number]).columns
715
+ for col in numeric_cols:
716
+ leaderboard[col] = leaderboard[col].round(3)
717
+
718
+ return leaderboard
719
+
720
+
721
+ def get_pattern_leaderboard(
722
+ pattern_filters: dict[str, int],
723
+ selected_horizons: list[str],
724
+ ) -> tuple[str, pd.DataFrame]:
725
+ """
726
+ Filter variates by selected patterns and compute average metrics per model.
727
+
728
+ Uses FEATURES_BOOL_DF (binarized features) to filter variates,
729
+ then joins with VARIATES_DF to get metrics.
730
+
731
+ Args:
732
+ pattern_filters: Dict mapping pattern names to required values.
733
+ - {feature_name: required_value} where required_value is 0 or 1.
734
+ - Features with "Any" selection are not included in the dict.
735
+ - Example: {"T_strength": 1, "S_strength": 0} means:
736
+ - T_strength must be 1 (has the feature)
737
+ - S_strength must be 0 (does not have the feature)
738
+ selected_horizons: List of horizons to include (e.g., ["short", "medium"])
739
+
740
+ Returns:
741
+ tuple: (message, leaderboard_df)
742
+ - message: Status message with matching count
743
+ - leaderboard_df: DataFrame with model metrics, sorted by MASE
744
+ """
745
+ from src.about import VARIATES_DF, FEATURES_DF, FEATURES_BOOL_DF, PATTERN_MAP
746
+
747
+ # Check if data is available
748
+ if VARIATES_DF.empty:
749
+ return "⚠️ No variate-level results available.", pd.DataFrame(columns=["model", "MASE", "CRPS"])
750
+
751
+ if FEATURES_DF.empty or FEATURES_BOOL_DF.empty:
752
+ return "⚠️ No features data available.", pd.DataFrame(columns=["model", "MASE", "CRPS"])
753
+
754
+ if not selected_horizons:
755
+ return "ℹ️ Please select at least one horizon.", pd.DataFrame(columns=["model", "MASE", "CRPS"])
756
+
757
+ # === Step 1. Apply pattern filters ===
758
+ # Start with all variates
759
+ mask = pd.Series(True, index=FEATURES_BOOL_DF.index)
760
+
761
+ # If no pattern filters (all "Any"), use all variates
762
+ if pattern_filters:
763
+ for pattern, required_value in pattern_filters.items():
764
+ # Map UI pattern name to feature column name
765
+ feature_col = PATTERN_MAP.get(pattern, pattern)
766
+
767
+ if feature_col not in FEATURES_BOOL_DF.columns:
768
+ return f"⚠️ Pattern '{pattern}' (column '{feature_col}') not found in features.", pd.DataFrame(columns=["model", "MASE", "CRPS"])
769
+
770
+ # Special handling for "stationarity" pattern
771
+ # stationarity = NOT is_random_walk
772
+ # When user selects "Has stationarity" (required_value=1), we want is_random_walk == 0
773
+ # When user selects "Not stationarity" (required_value=0), we want is_random_walk == 1
774
+ if pattern == "stationarity":
775
+ mask &= (FEATURES_BOOL_DF[feature_col] == (1 - required_value))
776
+ else:
777
+ mask &= (FEATURES_BOOL_DF[feature_col] == required_value)
778
+
779
+ # Get matching variates
780
+ matched_features = FEATURES_DF[mask].copy()
781
+
782
+ if matched_features.empty:
783
+ # Build debug info for empty results
784
+ debug_info = []
785
+ for pattern, required_value in pattern_filters.items():
786
+ feature_col = PATTERN_MAP.get(pattern, pattern)
787
+ if feature_col in FEATURES_BOOL_DF.columns:
788
+ value_counts = FEATURES_BOOL_DF[feature_col].value_counts().to_dict()
789
+ debug_info.append(f"{pattern} ({feature_col}): {value_counts}")
790
+ debug_msg = "; ".join(debug_info) if debug_info else "No debug info available"
791
+ return f"⚠️ No variates match the selected patterns.\n📊 Feature distribution: {debug_msg}", pd.DataFrame(columns=["model", "MASE", "CRPS"])
792
+
793
+ # === Step 2. Join with VARIATES_DF to get metrics ===
794
+ # Join strategy:
795
+ # - For multivariate (is_uts=False): use full join on (dataset_id, series_name, variate_name)
796
+ # - For univariate (is_uts=True): determine which FEATURES_DF field matches VARIATES_DF series_name
797
+
798
+ # Check if join columns exist
799
+ base_join_cols = ["dataset_id", "series_name", "variate_name"]
800
+ for col in base_join_cols:
801
+ if col not in matched_features.columns:
802
+ return f"⚠️ Column '{col}' not found in features.", pd.DataFrame(columns=["model", "MASE", "CRPS"])
803
+ if col not in VARIATES_DF.columns:
804
+ return f"⚠️ Column '{col}' not found in variates results.", pd.DataFrame(columns=["model", "MASE", "CRPS"])
805
+
806
+ # Select only join columns from features (to avoid column conflicts)
807
+ features_keys = matched_features[base_join_cols].drop_duplicates()
808
+
809
+ # Group by dataset_id and is_uts, then perform appropriate join
810
+ merged_list = []
811
+
812
+ for dataset_id in features_keys["dataset_id"].unique():
813
+ # Get dataset-specific data
814
+ dataset_features = features_keys[features_keys["dataset_id"] == dataset_id]
815
+ dataset_variates = VARIATES_DF[VARIATES_DF["dataset_id"] == dataset_id]
816
+
817
+ if dataset_variates.empty or dataset_features.empty:
818
+ continue
819
+
820
+ # Check is_uts for this dataset (should be consistent across all rows)
821
+ is_uts_values = dataset_variates["is_uts"].unique()
822
+ if len(is_uts_values) > 1:
823
+ print(f"⚠️ Warning: {dataset_id} has inconsistent is_uts values: {is_uts_values}")
824
+ is_uts = is_uts_values[0] if len(is_uts_values) > 0 else False
825
+
826
+ # Initialize dataset_merged to ensure it's always defined
827
+ dataset_merged = pd.DataFrame()
828
+
829
+ if not is_uts:
830
+ # Multivariate: use full join on (dataset_id, series_name, variate_name)
831
+ join_cols = ["dataset_id", "series_name", "variate_name"]
832
+ dataset_features_keys = dataset_features[join_cols].drop_duplicates()
833
+ dataset_merged = dataset_variates.merge(
834
+ dataset_features_keys,
835
+ on=join_cols,
836
+ how="inner"
837
+ )
838
+ else:
839
+ # Univariate: determine which FEATURES_DF field matches VARIATES_DF series_name
840
+ # Get unique series_name values from VARIATES_DF for this dataset
841
+ variates_series_names = set(dataset_variates["series_name"].unique())
842
+
843
+ # Check which FEATURES_DF field matches VARIATES_DF series_name
844
+ features_series_names = set(dataset_features["series_name"].unique())
845
+ features_variate_names = set(dataset_features["variate_name"].unique())
846
+
847
+ features_series_match = len(variates_series_names & features_series_names)
848
+ features_variate_match = len(variates_series_names & features_variate_names)
849
+
850
+ if features_series_match > features_variate_match:
851
+ # FEATURES_DF series_name matches VARIATES_DF series_name
852
+ join_cols = ["dataset_id", "series_name"]
853
+ dataset_features_keys = dataset_features[join_cols].drop_duplicates()
854
+ dataset_merged = dataset_variates.merge(
855
+ dataset_features_keys,
856
+ on=join_cols,
857
+ how="inner"
858
+ )
859
+ elif features_variate_match > features_series_match:
860
+ # FEATURES_DF variate_name matches VARIATES_DF series_name
861
+ # Create mapping: use FEATURES_DF variate_name to match VARIATES_DF series_name
862
+ dataset_features_keys = dataset_features[["dataset_id", "variate_name"]].drop_duplicates()
863
+ # Rename variate_name to series_name for join
864
+ dataset_features_keys = dataset_features_keys.rename(columns={"variate_name": "series_name"})
865
+ dataset_merged = dataset_variates.merge(
866
+ dataset_features_keys,
867
+ on=["dataset_id", "series_name"],
868
+ how="inner"
869
+ )
870
+ else:
871
+ # Both match equally or neither matches - try series_name first
872
+ if features_series_match > 0:
873
+ join_cols = ["dataset_id", "series_name"]
874
+ dataset_features_keys = dataset_features[join_cols].drop_duplicates()
875
+ dataset_merged = dataset_variates.merge(
876
+ dataset_features_keys,
877
+ on=join_cols,
878
+ how="inner"
879
+ )
880
+ else:
881
+ # No match found, skip this dataset
882
+ print(f"⚠️ Warning: {dataset_id} (UTS) - no matching field found between FEATURES_DF and VARIATES_DF series_name")
883
+ print(f" VARIATES_DF series_names: {sorted(list(variates_series_names))[:5]}")
884
+ print(f" FEATURES_DF series_names: {sorted(list(features_series_names))[:5]}")
885
+ print(f" FEATURES_DF variate_names: {sorted(list(features_variate_names))[:5]}")
886
+ continue
887
+
888
+ if not dataset_merged.empty:
889
+ merged_list.append(dataset_merged)
890
+
891
+ # Combine all merged results
892
+ if merged_list:
893
+ merged = pd.concat(merged_list, ignore_index=True)
894
+ else:
895
+ merged = pd.DataFrame(columns=VARIATES_DF.columns)
896
+
897
+ # === Step 3. Apply horizon filter ===
898
+ merged = merged[merged["horizon"].isin(selected_horizons)]
899
+
900
+ if merged.empty:
901
+ return f"⚠️ No results for selected horizons: {selected_horizons}", pd.DataFrame(columns=["model", "MASE", "CRPS"])
902
+
903
+ # === Step 4. Aggregate by model ===
904
+ metric_cols = ["MASE", "CRPS"]
905
+ available_metrics = [col for col in metric_cols if col in merged.columns]
906
+
907
+ # 4a. Original metrics: arithmetic mean across all matching variates and horizons
908
+ original_df = (
909
+ merged.groupby("model")[available_metrics]
910
+ .mean()
911
+ .reset_index()
912
+ )
913
+
914
+ # 4b. Normalized metrics: normalize by Seasonal Naive at (dataset_id, series_name, variate_name, horizon) level
915
+ # then aggregate with geometric mean
916
+ df_normalized = normalize_by_seasonal_naive(
917
+ merged,
918
+ baseline_model="seasonal_naive",
919
+ metrics=available_metrics,
920
+ groupby_cols=["dataset_id", "series_name", "variate_name", "horizon"],
921
+ )
922
+
923
+ # Helper function for geometric mean with NaN handling
924
+ def gmean_with_nan(x):
925
+ valid = x.dropna()
926
+ if len(valid) == 0:
927
+ return np.nan
928
+ return stats.gmean(valid)
929
+
930
+ if not df_normalized.empty:
931
+ normalized_df = (
932
+ df_normalized.groupby("model")[available_metrics]
933
+ .agg(gmean_with_nan)
934
+ .reset_index()
935
+ )
936
+ # Rename columns to *_norm
937
+ rename_map = {col: f"{col}_norm" for col in available_metrics}
938
+ normalized_df = normalized_df.rename(columns=rename_map)
939
+ else:
940
+ # If normalization fails, create empty normalized columns
941
+ normalized_df = original_df[["model"]].copy()
942
+ for col in available_metrics:
943
+ normalized_df[f"{col}_norm"] = np.nan
944
+
945
+ # Combine original and normalized metrics
946
+ leaderboard = original_df.merge(normalized_df, on="model", how="left")
947
+
948
+ # Rename columns for better clarity
949
+ rename_map = {}
950
+ if "MASE" in leaderboard.columns:
951
+ rename_map["MASE"] = "MASE (raw)"
952
+ if "CRPS" in leaderboard.columns:
953
+ rename_map["CRPS"] = "CRPS (raw)"
954
+ if "MASE_norm" in leaderboard.columns:
955
+ rename_map["MASE_norm"] = "MASE (norm.)"
956
+ if "CRPS_norm" in leaderboard.columns:
957
+ rename_map["CRPS_norm"] = "CRPS (norm.)"
958
+
959
+ if rename_map:
960
+ leaderboard = leaderboard.rename(columns=rename_map)
961
+
962
+ # Sort by MASE (norm.) if available, otherwise by MASE (raw)
963
+ if "MASE (norm.)" in leaderboard.columns:
964
+ leaderboard = leaderboard.sort_values(by="MASE (norm.)", ascending=True).reset_index(drop=True)
965
+ elif "MASE (raw)" in leaderboard.columns:
966
+ leaderboard = leaderboard.sort_values(by="MASE (raw)", ascending=True).reset_index(drop=True)
967
+
968
+ # Round numeric columns to 3 decimal places
969
+ numeric_cols = leaderboard.select_dtypes(include=[np.number]).columns
970
+ for col in numeric_cols:
971
+ # Round to 3 decimal places and ensure proper formatting
972
+ leaderboard[col] = leaderboard[col].round(3)
973
+ # Convert to float64 to ensure consistent display
974
+ leaderboard[col] = leaderboard[col].astype('float64')
975
+
976
+ # Reorder columns: model, MASE (norm.), CRPS (norm.), MASE (raw), CRPS (raw)
977
+ col_order = ["model", "MASE (norm.)", "CRPS (norm.)", "MASE (raw)", "CRPS (raw)"]
978
+ col_order = [col for col in col_order if col in leaderboard.columns]
979
+ leaderboard = leaderboard[col_order]
980
+
981
+ # === Step 5. Build message ===
982
+ num_variates = len(features_keys)
983
+ num_results = len(merged)
984
+ num_models = leaderboard["model"].nunique()
985
+
986
+ # Count by dataset
987
+ dataset_counts = features_keys["dataset_id"].value_counts().to_dict()
988
+ dataset_msg = ", ".join([f"{ds}: {cnt}" for ds, cnt in list(dataset_counts.items())[:5]])
989
+ if len(dataset_counts) > 5:
990
+ dataset_msg += f", ... ({len(dataset_counts)} datasets total)"
991
+
992
+ # Build pattern description
993
+ if pattern_filters:
994
+ pattern_desc = ", ".join([
995
+ f"{p}={v}" for p, v in pattern_filters.items()
996
+ ])
997
+ filter_msg = f"🔍 Filters: {pattern_desc}"
998
+ else:
999
+ filter_msg = "🔍 Filters: All N/A (no filtering, all variates included)"
1000
+
1001
+ msg = (
1002
+ f"✨ {num_variates} variates matched across {len(dataset_counts)} datasets.\n"
1003
+ f"📊 {num_results} results from {num_models} models.\n"
1004
+ f"{filter_msg}\n"
1005
+ f"📁 Datasets: {dataset_msg}"
1006
+ )
1007
+
1008
+ return msg, leaderboard
1009
+
1010
+
1011
+ # def get_archive_results(dataset_name: str, selected_patterns: list[str], variate_name: str):
1012
+ # """
1013
+ # Return variates filtered by dataset, patterns, and variate name.
1014
+ # """
1015
+ # if FEATURES_BOOL_DF.empty:
1016
+ # return pd.DataFrame(columns=["dataset", "variate_name"])
1017
+
1018
+ # df = FEATURES_DF.copy()
1019
+ # df_bool = FEATURES_BOOL_DF.copy()
1020
+
1021
+ # # Dataset filter
1022
+ # if dataset_name and dataset_name != "All":
1023
+ # df = df[df["dataset"] == dataset_name]
1024
+
1025
+ # # Pattern filter
1026
+ # if selected_patterns:
1027
+ # mask = pd.Series(True, index=df_bool.index)
1028
+ # for pattern in selected_patterns:
1029
+ # if pattern in df_bool.columns:
1030
+ # mask &= df_bool[pattern] == 1
1031
+ # df = df[mask]
1032
+
1033
+ # # Variate filter
1034
+ # if variate_name and variate_name != "All":
1035
+ # df = df[df["variate_name"] == variate_name]
1036
+
1037
+ # if df.empty:
1038
+ # return pd.DataFrame(columns=FEATURES_BOOL_DF.columns) # ToDO: columns换成完整的
1039
+
1040
+ # # --- Add freq & domain from YAML ---
1041
+ # freq_map, domain_map = {}, {}
1042
+ # for ds in df["dataset"].unique():
1043
+ # yaml_path = os.path.join("conf", "data", f"{ds}.yaml")
1044
+ # if os.path.exists(yaml_path):
1045
+ # with open(yaml_path, "r") as f:
1046
+ # meta = yaml.safe_load(f)
1047
+ # freq_map[ds] = meta.get("freq", None)
1048
+ # domain_map[ds] = meta.get("domain", None)
1049
+ # else:
1050
+ # freq_map[ds] = None
1051
+ # domain_map[ds] = None
1052
+
1053
+ # df["freq"] = df["dataset"].map(freq_map)
1054
+ # df["domain"] = df["dataset"].map(domain_map)
1055
+
1056
+ # # Select useful columns
1057
+ # base_cols = ["dataset", "variate_name", "freq", "domain"]
1058
+ # tsfeature_cols = [c for c in df.columns if c not in base_cols+['unitroot_pp', 'unitroot_kpss']]
1059
+
1060
+ # # === Rename feature columns ===
1061
+ # renamed_cols = {}
1062
+ # for col in tsfeature_cols:
1063
+ # if col == 'trend':
1064
+ # renamed_cols['trend'] = "T_strength"
1065
+ # elif col.startswith("trend"):
1066
+ # renamed_cols[col] = col.replace("trend", "T", 1)
1067
+ # elif col.startswith("seasonality"):
1068
+ # renamed_cols[col] = col.replace("seasonality", "S", 1)
1069
+ # elif col.startswith("seasonal"):
1070
+ # renamed_cols[col] = col.replace("seasonal", "S", 1)
1071
+ # df = df.rename(columns=renamed_cols)
1072
+
1073
+ # # === Apply new column names ===
1074
+ # tsfeature_cols = [renamed_cols.get(c, c) for c in tsfeature_cols]
1075
+
1076
+ # global_cols = ["x_acf1", "x_acf10", "lumpiness", "stability", "hurst", "entropy"]
1077
+ # t_cols = [c for c in tsfeature_cols if c.startswith("T_")]
1078
+ # s_cols = [c for c in tsfeature_cols if c.startswith("S_")]
1079
+ # e_cols = [c for c in tsfeature_cols if c.startswith("e_")]
1080
+ # stats_cols = [c for c in tsfeature_cols if c not in global_cols+ t_cols + s_cols + e_cols]
1081
+
1082
+ # ordered_cols = base_cols + global_cols + t_cols + s_cols + e_cols + stats_cols
1083
+
1084
+
1085
+ # return df[ordered_cols].round(3)
src/tab.py ADDED
@@ -0,0 +1,1370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import plotly.graph_objects as go
4
+
5
+
6
+
7
+ # Add project root and src directory to Python path to enable imports from timebench
8
+ # Get the directory containing this file (leaderboard_app/src/)
9
+ current_dir = os.path.dirname(os.path.abspath(__file__))
10
+ # Get leaderboard_app directory
11
+ leaderboard_app_dir = os.path.dirname(current_dir)
12
+
13
+ # Try multiple paths for timebench import:
14
+ # 1. Current leaderboard_app directory (if timebench was copied to leaderboard_app/)
15
+ # 2. Parent directory's src (for local development: TIME/src/)
16
+
17
+ # Add current leaderboard_app directory first (for Space deployment)
18
+ if leaderboard_app_dir not in sys.path:
19
+ sys.path.insert(0, leaderboard_app_dir)
20
+
21
+ # Get project root directory (TIME/) - for local development
22
+ project_root = os.path.dirname(leaderboard_app_dir)
23
+ if project_root not in sys.path:
24
+ sys.path.insert(0, project_root)
25
+
26
+ src_dir = os.path.join(project_root, "src")
27
+ if src_dir not in sys.path and os.path.exists(src_dir):
28
+ sys.path.insert(0, src_dir)
29
+
30
+ import json
31
+ import gradio as gr
32
+ from src.about import DATASET_CHOICES, ALL_MODELS, RESULTS_ROOT, FEATURES_DF, FEATURES_BOOL_DF, PATTERN_MAP
33
+ from src.leaderboard import (get_overall_leaderboard, get_dataset_multilevel_leaderboard,
34
+ get_window_leaderboard, get_pattern_leaderboard, resolve_dataset_id)
35
+ from src.about import DATASETS_DF, ALL_HORIZONS
36
+ from src.hf_config import get_datasets_root, get_config_root
37
+ import numpy as np
38
+ import pandas as pd
39
+ from pathlib import Path
40
+ import ast
41
+ import matplotlib
42
+ matplotlib.use('Agg') # Use non-interactive backend for Gradio
43
+ import yaml
44
+ import tempfile
45
+
46
+ from timebench.evaluation.data import Dataset, get_dataset_settings, load_dataset_config
47
+ from src.leaderboard import find_dataset_term_path
48
+
49
+
50
+ def export_dataframe_to_csv(df, filename_prefix="leaderboard"):
51
+ """Export a DataFrame to a temporary CSV file and return the path for download.
52
+
53
+ Args:
54
+ df: pandas DataFrame to export
55
+ filename_prefix: prefix for the temporary file name
56
+
57
+ Returns:
58
+ str: path to the temporary CSV file, or None if df is empty
59
+ """
60
+ if df is None or (hasattr(df, 'empty') and df.empty):
61
+ return None
62
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, prefix=f"{filename_prefix}_") as f:
63
+ df.to_csv(f, index=False)
64
+ return f.name
65
+
66
+
67
+ # def update_variate_choices(dataset_name: str, selected_patterns: list[str]):
68
+ # """
69
+ # Dynamically update the variate dropdown choices based on dataset + patterns.
70
+ # """
71
+ # if dataset_name == "All":
72
+ # return gr.Dropdown(choices=["All"], value="All", interactive=False)
73
+
74
+ # # Filter features by dataset
75
+ # df = FEATURES_BOOL_DF[FEATURES_BOOL_DF["dataset"] == dataset_name]
76
+
77
+ # # Apply pattern filters if provided
78
+ # if selected_patterns:
79
+ # mask = pd.Series(True, index=df.index)
80
+ # for pattern in selected_patterns:
81
+ # if pattern in df.columns:
82
+ # mask &= df[pattern] == 1
83
+ # df = df[mask]
84
+
85
+ # variates = sorted(df["variate_name"].unique().tolist())
86
+ # if not variates:
87
+ # return gr.Dropdown(choices=["All"], value="All", interactive=False)
88
+
89
+ # return gr.Dropdown(choices=["All"] + variates, value="All", interactive=True)
90
+
91
+
92
# # Update the Variate selection dropdown
93
+ # def update_variate_choices_groups(dataset_name, t, s, r, g):
94
+ # selected_patterns = (t or []) + (s or []) + (r or []) + (g or [])
95
+ # return update_variate_choices(dataset_name, selected_patterns)
96
+
97
+
98
########################## Dataset Tab ##########################
def update_series_and_variate(display_name):
    """Update the series and variate dropdown choices for the merged Dataset tab.

    Args:
        display_name: Dataset display name from UI dropdown (will be resolved to dataset_id)

    Returns:
        tuple: (series gr.Dropdown, variate gr.Dropdown) component updates.
    """
    # Use first available model to get data
    model_name = ALL_MODELS[0]
    # Find dataset_term (handles display_name -> dataset_id conversion)
    results_root = str(RESULTS_ROOT)
    dataset_term = find_dataset_term_path(results_root, model_name, display_name)

    if dataset_term is None:
        print(f"Error: dataset_term is None for display_name={display_name}, model_name={model_name}")
        return (
            gr.Dropdown(choices=["---"], value="---", label="Select Series", interactive=True),
            gr.Dropdown(choices=["---"], value="---", label="Select Variate", interactive=True),
        )

    # Load Dataset to get actual series and variate names
    # Use HF config to get dataset root (handles both local and HF Hub)
    hf_dataset_root = str(get_datasets_root())

    # Use HF config to get config root
    config_root = get_config_root()
    config_path = config_root / "datasets.yaml"

    # The horizon does not affect the series/variate names, so "short" is used directly.
    # FIX: guard against a missing datasets.yaml instead of crashing — the sibling
    # functions update_series_variate_and_window / plot_window_series already do this.
    config = load_dataset_config(config_path) if config_path.exists() else {}
    settings = get_dataset_settings(dataset_term, "short", config)

    prediction_length = settings.get("prediction_length")
    test_length = settings.get("test_length")

    dataset_obj = Dataset(
        name=dataset_term,
        term="short",
        prediction_length=prediction_length,
        test_length=test_length,
        storage_path=hf_dataset_root,  # Pass storage path directly
    )

    # Get series names
    if "item_id" in dataset_obj.hf_dataset.column_names:
        series_names = dataset_obj.hf_dataset["item_id"]
    else:
        series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                        for i in range(len(dataset_obj.hf_dataset))]

    series_list = ["---"] + [str(name) for name in series_names]

    # Get variate names
    variate_names = dataset_obj.get_variate_names()

    if variate_names is None:
        # UTS mode: variate dropdown should be disabled
        return (
            gr.Dropdown(choices=series_list, value="---", label="Select Series", interactive=True),
            gr.Dropdown(choices=["---"], value="---", label="Select Variate", interactive=False),
        )
    else:
        # MTS mode: both dropdowns are enabled
        variates_list = ["---"] + [str(name) for name in variate_names]
        return (
            gr.Dropdown(choices=series_list, value="---", label="Select Series", interactive=True),
            gr.Dropdown(choices=variates_list, value="---", label="Select Variate", interactive=True),
        )
168
+
169
+
170
########################## Window Tab ##########################
def get_available_horizons(display_name):
    """Return the horizons available for a dataset.

    Args:
        display_name: Dataset display name from the UI dropdown

    Returns:
        list: available horizons, e.g. ["short", "medium", "long"] or ["short"]
    """
    # Without metadata we cannot narrow anything down — offer everything.
    if DATASETS_DF.empty:
        return ALL_HORIZONS

    # Resolve display_name to dataset_id, then select that dataset's rows.
    dataset_id = resolve_dataset_id(display_name)
    rows = DATASETS_DF[DATASETS_DF["dataset_id"] == dataset_id]

    # Unknown dataset: fall back to every horizon.
    if rows.empty:
        return ALL_HORIZONS

    # Keep canonical ordering (short, medium, long) by filtering ALL_HORIZONS
    # against the horizons actually present for this dataset.
    present = set(rows["horizon"].unique().tolist())
    ordered = [h for h in ALL_HORIZONS if h in present]

    return ordered if ordered else ["short"]
201
+
202
+
203
def update_horizon_choices(display_name):
    """Refresh the horizon Radio component's choices/value for the chosen dataset.

    Args:
        display_name: Dataset display name from the UI dropdown

    Returns:
        gr.Radio: component update carrying the filtered choices and a valid value
    """
    available = get_available_horizons(display_name)

    # Prefer "short" when offered; otherwise pick the first available horizon
    # (defaulting to "short" if the list is somehow empty).
    if "short" in available:
        selected = "short"
    elif available:
        selected = available[0]
    else:
        selected = "short"

    # Preserve canonical ordering by filtering ALL_HORIZONS.
    filtered = [h for h in ALL_HORIZONS if h in available]

    return gr.Radio(choices=filtered, value=selected)
222
+
223
+
224
def update_series_variate_and_window(display_name, horizon):
    """Update the series, variate and testing-window dropdowns for a dataset/horizon.

    Loads the actual Dataset to obtain the real series and variate names.

    Args:
        display_name: Dataset display name from UI dropdown (will be resolved to dataset_id)
        horizon: Horizon name (short, medium, long)

    Returns:
        tuple: (series, variate, window) gr.Dropdown component updates.
    """
    # Use first available model to get data
    model_name = ALL_MODELS[0]

    # Find dataset_term (handles display_name -> dataset_id conversion)
    results_root = str(RESULTS_ROOT)
    dataset_term = find_dataset_term_path(results_root, model_name, display_name)

    if dataset_term is None:
        print(f"Error: dataset_term is None for display_name={display_name}, horizon={horizon}, model_name={model_name}")
        return (
            gr.Dropdown(choices=[], value=None, label="Select Series", interactive=False),
            gr.Dropdown(choices=[], value=None, label="Select Variate", interactive=False),
            gr.Dropdown(choices=[], value=None, label="Select Testing Window", interactive=False),
        )

    # FIX: removed dead code — the original unpacked dataset_term into
    # (dataset_name, freq) here, but neither value was used below.

    # Load Dataset to get actual series and variate names
    # Use HF config to get dataset root (handles both local and HF Hub)
    hf_dataset_root = str(get_datasets_root())

    # Use HF config to get config root
    config_root = get_config_root()
    config_path = config_root / "datasets.yaml"

    config = load_dataset_config(config_path) if config_path.exists() else {}
    settings = get_dataset_settings(dataset_term, horizon, config)

    prediction_length = settings.get("prediction_length")
    test_length = settings.get("test_length")

    # Load dataset
    dataset_obj = Dataset(
        name=dataset_term,
        term=horizon,
        prediction_length=prediction_length,
        test_length=test_length,
        storage_path=hf_dataset_root,  # Pass storage path directly
    )

    # Get series names (item_id) from hf_dataset
    if "item_id" in dataset_obj.hf_dataset.column_names:
        series_names = dataset_obj.hf_dataset["item_id"]
    else:
        # Fallback: get from iterating
        series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                        for i in range(len(dataset_obj.hf_dataset))]

    # Get variate names
    variate_names = dataset_obj.get_variate_names()

    # Get window count
    num_windows = dataset_obj.windows
    windows = [str(i) for i in range(num_windows)]

    # Convert to lists and maintain order (no sorting).
    # NOTE(review): assumes the dataset has at least one series and one window
    # (series_list[0] / windows[0] below) — confirm for all published datasets.
    series_list = [str(name) for name in series_names]

    # Handle UTS (Univariate Time Series) vs MTS (Multivariate Time Series)
    if variate_names is None:
        # UTS mode: each series is a single variate, so variate is always 0
        return (
            gr.Dropdown(choices=series_list, value=series_list[0], label="Select Series", interactive=True),
            gr.Dropdown(choices=["0"], value="0", label="Select Variate", interactive=False),
            gr.Dropdown(choices=windows, value=windows[0], label="Select Testing Window", interactive=True),
        )
    else:
        # MTS mode: multiple variates per series
        variates_list = [str(name) for name in variate_names]
        return (
            gr.Dropdown(choices=series_list, value=series_list[0], label="Select Series", interactive=True),
            gr.Dropdown(choices=variates_list, value=variates_list[0], label="Select Variate", interactive=True),
            gr.Dropdown(choices=windows, value=windows[0], label="Select Testing Window", interactive=True),
        )
308
+
309
+
310
def plot_window_series(display_name, series, variate, window_id, horizon, selected_quantiles, model):
    """
    Plot time series predictions for a specific window using Plotly for interactive visualization.
    Now includes full time series visualization with test window highlighted.
    Accepts series and variate names (strings) and converts them to indices.

    Reads <results_root>/<model>/<dataset_term>/<horizon>/predictions.npz, which must
    contain the keys "predictions_quantiles" and "quantile_levels" (new format only).

    Args:
        display_name: Dataset display name from UI dropdown (will be resolved to dataset_id)
        series: Series name
        variate: Variate name
        window_id: Window index
        horizon: Horizon name
        selected_quantiles: List of quantile strings to plot
        model: Model name

    Returns:
        tuple: (fig, info_message) where fig is Plotly figure and info_message contains prediction details
    """
    print(f"🔍 plot_window_series called: display_name={display_name}, series={series}, variate={variate}, window_id={window_id}, horizon={horizon}, model={model}")

    # Guard: all selector inputs must be set before anything can be plotted.
    if display_name is None or series is None or variate is None or window_id is None:
        print("❌ Missing parameters")
        fig = go.Figure()
        fig.update_layout(title="Please select all parameters")
        return fig, ""

    results_root = str(RESULTS_ROOT)
    print(f"📁 results_root: {results_root}")
    dataset_term = find_dataset_term_path(results_root, model, display_name)
    print(f"📁 dataset_term: {dataset_term}")
    if dataset_term is None:
        print("❌ Dataset not found")
        fig = go.Figure()
        fig.update_layout(title="Dataset not found")
        return fig, ""

    predictions_path = os.path.join(results_root, model, dataset_term, horizon, "predictions.npz")
    print(f"📁 predictions_path: {predictions_path}, exists: {os.path.exists(predictions_path)}")

    if not os.path.exists(predictions_path):
        print("❌ Predictions file not found")
        fig = go.Figure()
        fig.update_layout(title="Predictions file not found")
        return fig, ""


    predictions = np.load(predictions_path)
    # Load pre-computed quantiles (new format only)
    predictions_quantiles = predictions["predictions_quantiles"]  # (num_series, num_windows, 9, num_variates, prediction_length)
    quantile_levels = predictions["quantile_levels"]  # [0.1, 0.2, ..., 0.9]

    # Load prediction scale factor from config.json (for float16 overflow prevention)
    model_config_path = os.path.join(results_root, model, dataset_term, horizon, "config.json")
    prediction_scale_factor = 1.0
    if os.path.exists(model_config_path):
        with open(model_config_path, "r") as f:
            model_config = json.load(f)
        prediction_scale_factor = model_config.get("prediction_scale_factor", 1.0)
        if prediction_scale_factor != 1.0:
            print(f"📊 Applying inverse scale factor: {prediction_scale_factor}")
            # Undo the storage-time scaling; cast to float32 first to avoid float16 overflow.
            predictions_quantiles = predictions_quantiles.astype(np.float32) * prediction_scale_factor

    # Convert series and variate names to indices
    series_idx = None
    variate_idx = None
    dataset_obj = None

    # Load Dataset to get name-to-index mappings and full time series
    # Use HF config to get dataset root (handles both local and HF Hub)
    hf_dataset_root = str(get_datasets_root())
    print(f"📁 hf_dataset_root: {hf_dataset_root}, exists: {os.path.exists(hf_dataset_root)}")

    # Use HF config to get config root
    config_root = get_config_root()
    config_path_yaml = config_root / "datasets.yaml"
    print(f"📁 config_path_yaml: {config_path_yaml}, exists: {config_path_yaml.exists()}")

    config = load_dataset_config(config_path_yaml) if config_path_yaml.exists() else {}
    settings = get_dataset_settings(dataset_term, horizon, config)
    print(f"⚙️ settings: {settings}")

    prediction_length = settings.get("prediction_length")
    test_length = settings.get("test_length")

    print(f"📥 Loading Dataset: name={dataset_term}, term={horizon}, storage_path={hf_dataset_root}")
    dataset_obj = Dataset(
        name=dataset_term,
        term=horizon,
        prediction_length=prediction_length,
        test_length=test_length,
        storage_path=hf_dataset_root,  # Pass storage path directly
    )
    print(f"✅ Dataset loaded: {len(dataset_obj.hf_dataset)} series")

    # Get frequency from dataset
    dataset_freq = dataset_obj.freq
    print(f"📅 Dataset frequency: {dataset_freq}")

    # Get series names and create mapping
    if "item_id" in dataset_obj.hf_dataset.column_names:
        series_names = dataset_obj.hf_dataset["item_id"]
    else:
        series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                        for i in range(len(dataset_obj.hf_dataset))]

    print(f"📋 series_names: {list(series_names)}")
    series_name_to_idx = {name: idx for idx, name in enumerate(series_names)}
    if series in series_name_to_idx:
        series_idx = series_name_to_idx[series]
        print(f"✅ Found series '{series}' at index {series_idx}")
    else:
        # Fall back to interpreting the selection as a raw integer index.
        series_idx = int(series)
        print(f"⚠️ Series '{series}' not found in names, using int index {series_idx}")

    # Get variate names and create mapping
    variate_names = dataset_obj.get_variate_names()
    print(f"📋 variate_names: {variate_names}")
    if variate_names is not None:
        # MTS mode: multiple variates per series
        variate_name_to_idx = {name: idx for idx, name in enumerate(variate_names)}
        if variate in variate_name_to_idx:
            variate_idx = variate_name_to_idx[variate]
            print(f"✅ Found variate '{variate}' at index {variate_idx}")
        else:
            variate_idx = int(variate)
            print(f"⚠️ Variate '{variate}' not found in names, using int index {variate_idx}")
    else:
        # UTS mode: each series is a single variate, so variate_idx is always 0
        variate_idx = 0
        print(f"ℹ️ UTS mode, variate_idx=0")

    # Defensive fallbacks — both indices are normally assigned above already.
    if series_idx is None:
        series_idx = int(series)
    if variate_idx is None:
        # For UTS mode, variate_idx should be 0
        try:
            variate_idx = int(variate) if variate is not None else 0
        except (ValueError, TypeError):
            variate_idx = 0

    window_idx = int(window_id)

    # Get pre-computed quantiles for this specific series, window, and variate
    quantiles_data = predictions_quantiles[series_idx, window_idx, :, variate_idx, :]  # (9, prediction_length)
    # Prefer the actual stored length over the config value from here on.
    prediction_length = quantiles_data.shape[1]
    # Create mapping from quantile level string to index
    quantile_level_to_idx = {f"{q:.1f}": i for i, q in enumerate(quantile_levels)}

    # Load full time series data
    full_series = None
    train_end_idx = None
    test_window_start_idx = None
    test_window_end_idx = None

    # Get full target time series for this series
    print(f"📊 Getting target for series_idx={series_idx}, variate_idx={variate_idx}")
    full_target = dataset_obj.hf_dataset[series_idx]["target"]
    print(f"📊 full_target shape: {full_target.shape}, dtype: {full_target.dtype}")
    print(f"📊 full_target first 10 values (all variates): {full_target[:, :10] if full_target.ndim > 1 else full_target[:10]}")

    # Get start timestamp for this series and create timestamp array
    series_start = dataset_obj.hf_dataset[series_idx]["start"]
    print(f"📅 Series start timestamp: {series_start}, type: {type(series_start)}")

    # Handle numpy array containing datetime64 (common when reading from HF dataset)
    if isinstance(series_start, np.ndarray):
        # Extract scalar from array
        series_start = series_start.item() if series_start.ndim == 0 else series_start[0]
        print(f"📅 Extracted scalar: {series_start}, type: {type(series_start)}")

    # Convert numpy datetime64 to pandas Timestamp
    if isinstance(series_start, (np.datetime64, str)):
        series_start = pd.Timestamp(series_start)

    # Calculate series length for timestamp creation
    # NOTE(review): assumes MTS targets are laid out (num_variates, series_length) — confirm upstream.
    if full_target.ndim > 1:
        ts_length = full_target.shape[1]
    else:
        ts_length = len(full_target)

    # Create timestamp array for the entire series
    try:
        timestamps = pd.date_range(start=series_start, periods=ts_length, freq=dataset_freq)
        print(f"📅 Created timestamp array: {timestamps[0]} to {timestamps[-1]}")
    except Exception as e:
        # Any failure (bad freq/start) degrades gracefully to integer indices.
        print(f"⚠️ Failed to create timestamps: {e}, falling back to indices")
        timestamps = None

    # Handle multivariate case: extract specific variate
    if full_target.ndim > 1:
        full_series = full_target[variate_idx, :]  # Shape: (series_length,)
    else:
        full_series = full_target  # Shape: (series_length,)
    print(f"📊 full_series shape: {full_series.shape}, min: {full_series.min()}, max: {full_series.max()}, has_nan: {np.isnan(full_series).any()}")

    # Calculate train/test split point
    # Test data starts at: series_length - test_length
    series_length = len(full_series)
    train_end_idx = series_length - test_length

    # Calculate current test window position (windows tile the test region back-to-back).
    test_window_start_idx = train_end_idx + window_idx * prediction_length
    test_window_end_idx = test_window_start_idx + prediction_length

    # Create Plotly figure
    fig = go.Figure()

    # Quantile colors - from light to dark (symmetric pairs share a color)
    quantile_colors = {
        "0.1": "#c6dbef", "0.9": "#c6dbef",  # lightest
        "0.2": "#6baed6", "0.8": "#6baed6",  # light
        "0.3": "#4292c6", "0.7": "#4292c6",  # medium
        "0.4": "#2171b5", "0.6": "#2171b5",  # dark
        "0.5": "#08306b",  # darkest (median)
    }

    # Calculate prediction time steps (overlay on the test window)
    if test_window_start_idx is not None:
        pred_time_steps = np.arange(test_window_start_idx, test_window_end_idx)
    else:
        pred_time_steps = np.arange(prediction_length)

    # Plot full time series if available
    time_steps = np.arange(len(full_series))

    # Use timestamps for x-axis if available
    if timestamps is not None:
        x_full = timestamps
        x_pred = timestamps[pred_time_steps] if test_window_start_idx is not None else timestamps[:prediction_length]
        x_window = timestamps[test_window_start_idx:test_window_end_idx] if test_window_start_idx is not None else None
    else:
        x_full = time_steps
        x_pred = pred_time_steps
        x_window = np.arange(test_window_start_idx, test_window_end_idx) if test_window_start_idx is not None else None

    # Plot full series in light gray
    fig.add_trace(go.Scatter(
        x=x_full,
        y=full_series,
        mode='lines',
        name='Full Time Series',
        line=dict(color='gray', width=1),
        opacity=0.6,
        hovertemplate='Time: %{x}<br>Value: %{y:.4f}<extra></extra>'
    ))

    # Add shapes for regions (training, test, current window)
    if train_end_idx is not None:
        # Training region - use timestamps if available
        x0_train = timestamps[0] if timestamps is not None else 0
        x1_train = timestamps[train_end_idx] if timestamps is not None else train_end_idx
        fig.add_shape(
            type="rect",
            x0=x0_train, x1=x1_train,
            y0=0, y1=1, yref="paper",
            fillcolor="blue", opacity=0.1,
            layer="below", line_width=0,
        )
        # Test region
        test_region_end = len(full_series)
        x0_test = timestamps[train_end_idx] if timestamps is not None else train_end_idx
        x1_test = timestamps[test_region_end-1] if timestamps is not None else test_region_end-1
        fig.add_shape(
            type="rect",
            x0=x0_test, x1=x1_test,
            y0=0, y1=1, yref="paper",
            fillcolor="orange", opacity=0.15,
            layer="below", line_width=0,
        )

    # Highlight current test window
    if test_window_start_idx is not None and test_window_end_idx is not None:
        # Use timestamps for window highlight if available
        x0_window = timestamps[test_window_start_idx] if timestamps is not None else test_window_start_idx
        x1_window = timestamps[test_window_end_idx-1] if timestamps is not None else test_window_end_idx-1
        fig.add_shape(
            type="rect",
            x0=x0_window, x1=x1_window,
            y0=0, y1=1, yref="paper",
            fillcolor="red", opacity=0.2,
            layer="below", line_width=0,
        )

    # Plot the test window portion of full series
    window_series = full_series[test_window_start_idx:test_window_end_idx]
    fig.add_trace(go.Scatter(
        x=x_window,
        y=window_series,
        mode='lines',
        name='Ground Truth (Window)',
        line=dict(color='red', width=2),
        opacity=0.8,
        hovertemplate='Time: %{x}<br>Value: %{y:.4f}<extra></extra>'
    ))

    # Quantile pairs mapping: UI selection -> (low, high) quantile values
    quantile_pair_map = {
        "0.1-0.9": ("0.1", "0.9"),
        "0.2-0.8": ("0.2", "0.8"),
        "0.3-0.7": ("0.3", "0.7"),
        "0.4-0.6": ("0.4", "0.6"),
    }

    # Helper function to get pre-computed quantile values
    def get_quantile_values(q_str):
        return quantiles_data[quantile_level_to_idx[q_str], :]

    # Plot quantile pairs with fill (based on paired selection)
    for pair_str, (q_low_str, q_high_str) in quantile_pair_map.items():
        if pair_str in selected_quantiles:
            quantile_low = get_quantile_values(q_low_str)
            quantile_high = get_quantile_values(q_high_str)
            color = quantile_colors.get(q_low_str, "#2171b5")

            # Add filled area between quantiles
            # (x runs forward along the upper bound, then backward along the lower)
            fig.add_trace(go.Scatter(
                x=list(x_pred) + list(x_pred[::-1]),
                y=list(quantile_high) + list(quantile_low[::-1]),
                fill='toself',
                fillcolor=color,
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=True,
                name=f'Q{q_low_str}-Q{q_high_str}',
                opacity=0.3
            ))

            # Add lower quantile line
            fig.add_trace(go.Scatter(
                x=x_pred,
                y=quantile_low,
                mode='lines',
                name=f'Q{q_low_str}',
                line=dict(color=color, width=1),
                opacity=0.7,
                showlegend=False,
                hovertemplate=f'Time: %{{x}}<br>Q{q_low_str}: %{{y:.4f}}<extra></extra>'
            ))

            # Add upper quantile line
            fig.add_trace(go.Scatter(
                x=x_pred,
                y=quantile_high,
                mode='lines',
                name=f'Q{q_high_str}',
                line=dict(color=color, width=1),
                opacity=0.7,
                showlegend=False,
                hovertemplate=f'Time: %{{x}}<br>Q{q_high_str}: %{{y:.4f}}<extra></extra>'
            ))

    # Plot median (0.5) if selected
    if "0.5" in selected_quantiles:
        quantile_values = get_quantile_values("0.5")
        color = quantile_colors.get("0.5", "#08306b")

        fig.add_trace(go.Scatter(
            x=x_pred,
            y=quantile_values,
            mode='lines+markers',
            name='Median (Q0.5)',
            line=dict(color=color, width=3),
            marker=dict(size=5, symbol='circle'),
            opacity=0.8,
            hovertemplate='Time: %{x}<br>Q0.5: %{y:.4f}<extra></extra>'
        ))

    # Update layout - use autosize for responsive width
    x_axis_title = "Timestamp" if timestamps is not None else "Time Step"
    fig.update_layout(
        title=None,
        xaxis_title=x_axis_title,
        yaxis_title="Value",
        hovermode='x unified',
        autosize=True,  # use automatic width so the chart adapts to its container
        height=400,
        margin=dict(l=60, r=40, t=60, b=60),  # sensible margins around the plot area
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            font=dict(size=14)
        ),
        plot_bgcolor='white',
        xaxis=dict(showgrid=True, gridcolor='lightgray', gridwidth=1),
        yaxis=dict(showgrid=True, gridcolor='lightgray', gridwidth=1)
    )


    # Create info message for prediction window
    if timestamps is not None and test_window_start_idx is not None and test_window_end_idx is not None:
        pred_start_ts = timestamps[test_window_start_idx]
        pred_end_ts = timestamps[test_window_end_idx - 1]  # -1 because end index is exclusive
        # Format with weekday name
        start_str = f"{pred_start_ts.strftime('%Y-%m-%d %H:%M:%S')} ({pred_start_ts.day_name()})"
        end_str = f"{pred_end_ts.strftime('%Y-%m-%d %H:%M:%S')} ({pred_end_ts.day_name()})"
        base_info = (
            f"📊 Prediction Length: {prediction_length}\n"
            f"📅 Prediction Range: {start_str} → {end_str}\n"
            f"🔄 Dataset Frequency: {dataset_freq}"
        )
    else:
        # The dir() check is defensive only: dataset_freq is assigned
        # unconditionally above, so the 'N/A' branch should never trigger here.
        base_info = (
            f"📊 Prediction Length: {prediction_length}\n"
            f"📅 Prediction Range: index {test_window_start_idx} → {test_window_end_idx - 1}\n"
            f"🔄 Dataset Frequency: {dataset_freq if 'dataset_freq' in dir() else 'N/A'}"
        )

    # Get features information for the selected variate
    # Pattern names from init_per_pattern_tab
    pattern_names = [
        "T_strength", "T_linearity",
        "S_strength", "S_corr",
        "R_ACF1",
        "stationarity", "complexity"
    ]

    features_info = ""
    if not FEATURES_DF.empty and not FEATURES_BOOL_DF.empty:
        # Find matching row in features dataframes
        # Try to match by dataset_id, series_name, variate_name
        feature_row_orig = None
        feature_row_bool = None

        # Match by dataset_id first
        features_subset_orig = FEATURES_DF[FEATURES_DF["dataset_id"] == dataset_term]
        features_subset_bool = FEATURES_BOOL_DF[FEATURES_BOOL_DF["dataset_id"] == dataset_term]

        print(f"🔍 Features lookup: dataset_term={dataset_term}, series={series}, variate={variate}")
        print(f"🔍 Features subset size: orig={len(features_subset_orig)}, bool={len(features_subset_bool)}")

        # Try matching by series_name and variate_name (for MTS)
        if not features_subset_orig.empty:
            # Check if series_name matches
            if "series_name" in features_subset_orig.columns:
                series_match_orig = features_subset_orig["series_name"] == series
                if series_match_orig.any():
                    series_matched = features_subset_orig[series_match_orig]
                    print(f"🔍 Found {len(series_matched)} rows with series_name={series}")
                    # Check if variate_name matches
                    if "variate_name" in series_matched.columns:
                        # For UTS, variate might be "0" or 0, try both
                        variate_str = str(variate)
                        variate_match_orig = (series_matched["variate_name"] == variate_str) | (series_matched["variate_name"] == variate)
                        if variate_match_orig.any():
                            feature_row_orig = series_matched[variate_match_orig].iloc[0]
                            print(f"✅ Found feature row by series_name + variate_name")
                            # Find corresponding row in bool dataframe
                            if not features_subset_bool.empty and "series_name" in features_subset_bool.columns and "variate_name" in features_subset_bool.columns:
                                series_match_bool = features_subset_bool["series_name"] == series
                                variate_match_bool = (features_subset_bool["variate_name"] == variate_str) | (features_subset_bool["variate_name"] == variate)
                                bool_matched = features_subset_bool[series_match_bool & variate_match_bool]
                                if not bool_matched.empty:
                                    feature_row_bool = bool_matched.iloc[0]

        # If not found, try matching by series_name only (for UTS cases where variate_name might not match)
        if feature_row_orig is None and not features_subset_orig.empty:
            if "series_name" in features_subset_orig.columns:
                series_match_orig = features_subset_orig["series_name"] == series
                if series_match_orig.any():
                    # For UTS, there might be only one row per series
                    series_matched = features_subset_orig[series_match_orig]
                    if len(series_matched) == 1:
                        feature_row_orig = series_matched.iloc[0]
                        print(f"✅ Found feature row by series_name only (UTS)")
                        # Find corresponding row in bool dataframe
                        if not features_subset_bool.empty and "series_name" in features_subset_bool.columns:
                            series_match_bool = features_subset_bool["series_name"] == series
                            bool_matched = features_subset_bool[series_match_bool]
                            if len(bool_matched) == 1:
                                feature_row_bool = bool_matched.iloc[0]

        # If still not found, try matching by variate_name only (for UTS cases where variate_name == series)
        if feature_row_orig is None and not features_subset_orig.empty:
            if "variate_name" in features_subset_orig.columns:
                variate_match_orig = features_subset_orig["variate_name"] == series  # For UTS, series might be the variate_name
                if variate_match_orig.any():
                    feature_row_orig = features_subset_orig[variate_match_orig].iloc[0]
                    print(f"✅ Found feature row by variate_name (series as variate_name)")
                    # Find corresponding row in bool dataframe
                    if not features_subset_bool.empty and "variate_name" in features_subset_bool.columns:
                        variate_match_bool = features_subset_bool["variate_name"] == series
                        if variate_match_bool.any():
                            feature_row_bool = features_subset_bool[variate_match_bool].iloc[0]

        if feature_row_orig is None:
            print(f"⚠️ Could not find features for dataset_term={dataset_term}, series={series}, variate={variate}")
            if not features_subset_orig.empty:
                print(f" Available series_names: {features_subset_orig['series_name'].unique()[:10] if 'series_name' in features_subset_orig.columns else 'N/A'}")
                print(f" Available variate_names: {features_subset_orig['variate_name'].unique()[:10] if 'variate_name' in features_subset_orig.columns else 'N/A'}")

        if feature_row_orig is not None:
            # Build features display
            features_orig_items = []
            features_bool_items = []

            for pattern_name in pattern_names:
                # Map pattern name to feature column name
                feature_col = PATTERN_MAP.get(pattern_name, pattern_name)

                # Get original value (skip stationarity as it's derived from is_random_walk)
                if pattern_name != "stationarity":
                    if feature_col in feature_row_orig.index:
                        orig_value = feature_row_orig[feature_col]
                        if pd.notna(orig_value):
                            features_orig_items.append(f"{pattern_name}: {orig_value:.3f}")

                # Get binary value
                if feature_row_bool is not None and feature_col in feature_row_bool.index:
                    bool_value = feature_row_bool[feature_col]
                    if pd.notna(bool_value):
                        # Special handling for stationarity (it's inverted)
                        if pattern_name == "stationarity":
                            # stationarity = NOT is_random_walk, so display the inverted value
                            display_bool = 1 - int(bool_value)
                        else:
                            display_bool = int(bool_value)
                        features_bool_items.append(f"{pattern_name}: {display_bool}")

            if features_orig_items or features_bool_items:
                features_info = "\n\n 📝 Features of variate:\n"
                if features_orig_items:
                    features_info += "- Original Values: " + ", ".join(features_orig_items) + "\n"
                if features_bool_items:
                    features_info += "- Binary Values (0/1): " + ", ".join(features_bool_items)

    info_message = base_info + features_info

    print(f"📝 Info message: {info_message}")
    return fig, info_message
842
+
843
+
844
def init_overall_tab():
    """Build the "Overall" leaderboard tab.

    Renders an explanatory markdown header, a non-interactive table produced by
    ``get_overall_leaderboard`` over the module-level ``DATASETS_DF``, and a
    CSV-export button/file pair.

    NOTE: Gradio lays components out in creation order, so the statement order
    here is part of the UI behavior — do not reorder.
    """
    gr.Markdown(
        """
        This tab presents each model's overall performance aggregated across all tasks. A **task** is defined as a specific **(dataset, horizon)** pair. For each task, the result is obtained by averaging the metrics across all its variates.
        - **MASE (norm.), CRPS (norm.)**: task-level results are normalized by Seasonal Naive and aggregated by geometric mean.
        - **MASE_rank, CRPS_rank**: for each task, models are ranked by the metric; the average rank across all tasks is then reported.
        """,
        elem_classes="markdown-text"
    )

    # Static table: computed once at tab construction time (no refresh events).
    overall_table = gr.DataFrame(
        value=get_overall_leaderboard(DATASETS_DF, metric="MASE"),
        elem_classes="custom-table",
        interactive=False
    )

    # CSV Export
    def export_overall_csv():
        # Recompute on click so the export reflects current DATASETS_DF content.
        df = get_overall_leaderboard(DATASETS_DF, metric="MASE")
        return export_dataframe_to_csv(df, filename_prefix="overall_leaderboard")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # First write the CSV into the (hidden) File component, then reveal it.
    export_btn.click(
        fn=export_overall_csv,
        inputs=[],
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
878
+
879
+
880
def init_per_dataset_tab(demo):
    """Build the "Per Dataset" tab: dataset/series/variate drill-down table.

    Args:
        demo: The enclosing ``gr.Blocks`` app; used to attach a ``load`` event
            that populates the table on startup.

    NOTE: Gradio lays components out in creation order; the nesting of
    Row/Column blocks below defines the page layout — do not reorder.
    """
    gr.Markdown(
        """
        This tab provides flexible analysis at dataset, series, and variate levels.

        - **Dataset only**: Shows both Seasonal Naive-normalized metrics (task-level) and original non-normalized metrics, plus average ranks
        - **Series/Variate selected**: Shows only original metrics.
        - **Horizons**: Select one or more horizons to aggregate results
        """,
        elem_classes="markdown-text"
    )

    with gr.Row():
        with gr.Column(scale=1):
            horizons = gr.CheckboxGroup(
                choices=ALL_HORIZONS,
                value=ALL_HORIZONS,
                label="Horizons"
            )

            dataset_dropdown = gr.Dropdown(
                choices=DATASET_CHOICES,
                value=DATASET_CHOICES[0],
                label="Dataset",
                interactive=True
            )

            # Initialize series and variate dropdowns
            # (helper returns pre-built gr.Dropdown components for the default dataset)
            initial_dataset = DATASET_CHOICES[0]
            series_dropdown, variate_dropdown = update_series_and_variate(
                initial_dataset
            )

    msg = gr.Textbox(label="Message", interactive=False)
    table = gr.DataFrame(elem_classes="custom-table", interactive=False)

    # Update series and variate dropdowns when dataset changes
    dataset_dropdown.change(
        fn=update_series_and_variate,
        inputs=[dataset_dropdown],
        outputs=[series_dropdown, variate_dropdown],
    )

    # Update leaderboard when any selection changes
    for comp in [dataset_dropdown, series_dropdown, variate_dropdown, horizons]:
        comp.change(
            fn=get_dataset_multilevel_leaderboard,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, horizons],
            outputs=[msg, table]
        )

    # Load on startup
    demo.load(
        fn=get_dataset_multilevel_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, horizons],
        outputs=[msg, table]
    )

    # CSV Export
    def export_dataset_csv(dataset, series, variate, horizons_val):
        # get_dataset_multilevel_leaderboard returns (message, dataframe);
        # only the dataframe is exported.
        _, df = get_dataset_multilevel_leaderboard(dataset, series, variate, horizons_val)
        # Sanitize dataset name for filename (replace / with _)
        safe_dataset_name = dataset.replace("/", "_") if dataset else "unknown"
        return export_dataframe_to_csv(df, filename_prefix=f"dataset_{safe_dataset_name}")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # Write the CSV into the hidden File component, then reveal it.
    export_btn.click(
        fn=export_dataset_csv,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, horizons],
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
958
+
959
+
960
def init_per_window_tab(demo):
    """Build the "Per Window" tab: window-level metrics plus forecast plots.

    Args:
        demo: The enclosing ``gr.Blocks`` app; used for the startup ``load``
            events that draw the initial plot and table.

    NOTE: component creation order defines the layout, and the ``.then()``
    chains encode a strict update order (horizon choices -> dependent
    dropdowns -> plot -> table). Do not reorder.
    """
    gr.Markdown(
        """
        This tab enables detailed analysis of model performance at the level of individual testing windows. By selecting a dataset, variate, horizon, and test window, users can examine window-level metrics (MASE, CRPS, MAE, MSE) at fine granularity and visualize the predicted quantiles of a model along with the ground-truth.
        - **Interactive Visualization**: Zoom, pan, autoscale and download the plot.
        - 🟦 Train Split 🟨 Test Split 🟥 Prediction Window
        """
    )

    # "lo-hi" quantile bands plus the median on its own.
    QUANTILE_PAIR_CHOICES = ["0.1-0.9", "0.2-0.8", "0.3-0.7", "0.4-0.6", "0.5"]
    initial_quantiles = ["0.5"]

    with gr.Row():
        with gr.Column(scale=1):
            # Initialize horizon choices based on first dataset
            initial_dataset = DATASET_CHOICES[0] if DATASET_CHOICES else None
            initial_horizons = get_available_horizons(initial_dataset) if initial_dataset else ALL_HORIZONS
            horizons = gr.Radio(
                choices=initial_horizons,
                # Prefer "short"; otherwise fall back to the first available horizon.
                value="short" if "short" in initial_horizons else (initial_horizons[0] if initial_horizons else "short"),
                label="Horizons"
            )

            # Dropdown for dataset selection
            dataset_dropdown = gr.Dropdown(
                choices=DATASET_CHOICES,
                value=DATASET_CHOICES[0] if DATASET_CHOICES else None,  # default to the first dataset
                label="Dataset",
                interactive=True
            )

            # Initialize series, variate, window dropdowns using function
            series_dropdown, variate_dropdown, window_dropdown = update_series_variate_and_window(
                dataset_dropdown.value, horizons.value
            )

        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(scale=2):
                    quantiles = gr.CheckboxGroup(
                        choices=QUANTILE_PAIR_CHOICES,
                        value=initial_quantiles,
                        label="Select Quantiles for Visualization"
                    )
                with gr.Column(scale=1):
                    model = gr.Dropdown(
                        choices=ALL_MODELS,
                        value=ALL_MODELS[0],
                        label="Select Model for Visualization",
                        interactive=True
                    )
            ts_visualization = gr.Plot()
            # Message box for prediction window info
            prediction_info = gr.Textbox(
                label="Info",
                interactive=False,
                lines=3
            )

    table_window = gr.DataFrame(elem_classes="custom-table", interactive=False)

    # When dataset changes: first update horizon choices, then update dropdowns
    dataset_dropdown.change(
        fn=update_horizon_choices,
        inputs=[dataset_dropdown],
        outputs=[horizons],
    ).then(
        fn=update_series_variate_and_window,
        inputs=[dataset_dropdown, horizons],
        outputs=[series_dropdown, variate_dropdown, window_dropdown],
    ).then(
        # After dropdowns are updated, refresh the visualization and table
        fn=plot_window_series,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
        outputs=[ts_visualization, prediction_info]
    ).then(
        fn=get_window_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=table_window
    )

    # When horizon changes: update dropdowns, then refresh visualization
    horizons.change(
        fn=update_series_variate_and_window,
        inputs=[dataset_dropdown, horizons],
        outputs=[series_dropdown, variate_dropdown, window_dropdown],
    ).then(
        fn=plot_window_series,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
        outputs=[ts_visualization, prediction_info]
    ).then(
        fn=get_window_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=table_window
    )

    # For series, variate, window changes - update visualization and table
    for comp in [series_dropdown, variate_dropdown, window_dropdown]:
        comp.change(
            fn=get_window_leaderboard,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
            outputs=table_window
        )
        comp.change(
            fn=plot_window_series,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
            outputs=[ts_visualization, prediction_info]
        )

    # For quantiles and model changes - only update visualization (no table change needed)
    for comp in [quantiles, model]:
        comp.change(
            fn=plot_window_series,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
            outputs=[ts_visualization, prediction_info]
        )

    # Load initial visualization and table on page load
    demo.load(
        fn=plot_window_series,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
        outputs=[ts_visualization, prediction_info]
    )
    demo.load(
        fn=get_window_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=table_window
    )

    # CSV Export
    def export_window_csv(dataset, series, variate, window, horizon):
        df = get_window_leaderboard(dataset, series, variate, window, horizon)
        return export_dataframe_to_csv(df, filename_prefix="window_leaderboard")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # Write the CSV into the hidden File component, then reveal it.
    export_btn.click(
        fn=export_window_csv,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
1107
+
1108
+
1109
def init_per_pattern_tab(demo):
    """Build the "Per Pattern" tab: filter variates by boolean time-series
    features and show aggregated model performance over the matching set.

    Args:
        demo: The enclosing ``gr.Blocks`` app; used for the startup ``load``
            event that fills the table with the unfiltered view.

    NOTE: component creation order defines the layout — do not reorder the
    Row/Column blocks below.
    """
    gr.Markdown(
        """
        This tab allows you to explore model performance based on **selected patterns**.

        Select patterns to filter variates that exhibit those characteristics, then view aggregated model performance.
        Each pattern is a **boolean indicator** derived from time series features (binarized by **median** threshold for continuous features).

        - **Patterns are intersected**: A variate must exhibit ALL selected patterns to be included.
        - **MASE (norm.), CRPS (norm.)**: variate-level results are normalized by Seasonal Naive and aggregated by geometric mean across all matching variates.
        - **MASE (raw), CRPS (raw)**: arithmetic mean across all matching variates.
        """,
        elem_classes="markdown-text"
    )

    # Define pattern choices for Radio components.
    # "N/A" = no filter, "=1" = must exhibit the pattern, "=0" = must not.
    PATTERN_CHOICES = ["N/A", "=1", "=0"]

    with gr.Row():  # TSFeatures
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 📈 Trend Features")
                T_strength = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="T_strength"
                )
                T_linearity = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="T_linearity"
                )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🔄 Seasonal Features")
                S_strength = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="S_strength"
                )
                S_corr = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="S_corr"
                )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🎯 Residual Features")
                R_ACF1 = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="R_ACF1"
                )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### ⚙️ Global Features")
                stationarity = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="stationarity"
                )
                complexity = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="complexity"
                )

    # List of all pattern Radio components and their names.
    # pattern_radios and pattern_names MUST stay index-aligned: the callbacks
    # below zip them positionally.
    pattern_radios = [
        T_strength, T_linearity,
        S_strength, S_corr,
        R_ACF1,
        stationarity, complexity
    ]
    pattern_names = [
        "T_strength", "T_linearity",
        "S_strength", "S_corr",
        "R_ACF1",
        "stationarity", "complexity"
    ]

    with gr.Row():
        with gr.Column(scale=1):
            horizons = gr.CheckboxGroup(
                choices=ALL_HORIZONS,
                value=ALL_HORIZONS,
                label="Horizons"
            )
        with gr.Column(scale=2):
            msg_pattern = gr.Textbox(label="Status", interactive=False, lines=4)

    table_variates = gr.DataFrame(elem_classes="custom-table", interactive=False)

    def merge_patterns(*radio_values):
        """Convert Radio values to pattern filter dict.

        Args:
            *radio_values: Values from all Radio components in order of pattern_names

        Returns:
            dict: {feature_name: required_value} where required_value is 0 or 1.
                Features with "N/A" are not included in the dict.
        """
        result = {}
        for name, value in zip(pattern_names, radio_values):
            if value == "=1":
                result[name] = 1
            elif value == "=0":
                result[name] = 0
            # "N/A" -> don't include in dict (no filter on this feature)
        return result

    def update_leaderboard(*args):
        """Callback to update the pattern leaderboard.

        Args:
            *args: All Radio values followed by horizons (last argument)
        """
        # Last argument is horizons, rest are pattern radio values
        horizons_val = args[-1]
        radio_values = args[:-1]
        pattern_filters = merge_patterns(*radio_values)
        return get_pattern_leaderboard(pattern_filters, horizons_val)

    # Bind change events for all pattern radios and horizons
    all_inputs = pattern_radios + [horizons]
    for comp in all_inputs:
        comp.change(
            fn=update_leaderboard,
            inputs=all_inputs,
            outputs=[msg_pattern, table_variates]
        )

    # Load initial state
    demo.load(
        fn=update_leaderboard,
        inputs=all_inputs,
        outputs=[msg_pattern, table_variates]
    )

    # CSV Export
    def export_pattern_csv(*args):
        # Last argument is horizons, rest are pattern radio values
        horizons_val = args[-1]
        radio_values = args[:-1]
        pattern_filters = merge_patterns(*radio_values)
        # get_pattern_leaderboard returns (message, dataframe); export the df.
        _, df = get_pattern_leaderboard(pattern_filters, horizons_val)
        return export_dataframe_to_csv(df, filename_prefix="pattern_leaderboard")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # Write the CSV into the hidden File component, then reveal it.
    export_btn.click(
        fn=export_pattern_csv,
        inputs=all_inputs,
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
1257
+
1258
+
1259
+ # # ToDO: Now the archive is using different features from the ones in per_pattern tab
1260
+ # def init_archive_tab(demo):
1261
+ # gr.Markdown(
1262
+ # """
1263
+ # This tab provides an interactive archive of the features of time series variates across datasets. You can explore the archive by specifying a dataset, domain, and frequency, and filter variates with the selected structural patterns. Each pattern is a **boolean indicator** showing whether a variate exhibits the pattern, with thresholds derived from the distribution of feature values across the entire dataset. Pattern filters are applied as an **intersection** (a variate must exhibit all selected patterns). Domain and frequency filters are applied as a **union** (a variate may belong to any selected category). The resulting table displays all variates that satisfy the chosen filters, together with their dataset, frequency, domain, and computed feature values. This view makes it possible to identify and group variates that share similar feature profiles.
1264
+ # """
1265
+ # )
1266
+
1267
+ # with gr.Row():
1268
+ # with gr.Column(scale=1):
1269
+ # dataset_dropdown = gr.Dropdown(
1270
+ # choices=["All"] + sorted(FEATURES_BOOL_DF["dataset"].unique().tolist()),
1271
+ # value="All",
1272
+ # label="Select Dataset"
1273
+ # )
1274
+
1275
+ # variate_dropdown = gr.Dropdown(
1276
+ # choices=["All"],
1277
+ # value="All",
1278
+ # label="Select Variate",
1279
+ # interactive=False
1280
+ # )
1281
+
1282
+ # domains = gr.CheckboxGroup(
1283
+ # choices=ALL_DOMAINS,
1284
+ # value=ALL_DOMAINS, # default all checked
1285
+ # label="Domains"
1286
+ # )
1287
+
1288
+ # freqs = gr.CheckboxGroup(
1289
+ # choices=ALL_FREQS,
1290
+ # value=ALL_FREQS, # 默认全选
1291
+ # label="Frequencies"
1292
+ # )
1293
+
1294
+ # with gr.Column(scale=2):
1295
+ # trend_group = gr.CheckboxGroup(
1296
+ # choices=["trend", "trend_stability", "trend_lumpiness", "trend_hurst", "trend_entropy"],
1297
+ # label="Trend Patterns"
1298
+ # )
1299
+
1300
+ # season_group = gr.CheckboxGroup(
1301
+ # choices=["seasonal_strength", "seasonality_corr", "seasonal_stability",
1302
+ # "seasonal_lumpiness", "seasonal_hurst", "seasonal_entropy"],
1303
+ # label="Seasonality Patterns"
1304
+ # )
1305
+
1306
+ # remainder_group = gr.CheckboxGroup(
1307
+ # choices=["e_acf1", "e_acf10",
1308
+ # "e_entropy", "e_hurst", "e_lumpiness", "e_outlier_ratio"],
1309
+ # label="Remainder Patterns"
1310
+ # )
1311
+
1312
+ # global_group = gr.CheckboxGroup(
1313
+ # choices=["x_acf1", "x_acf10", "lumpiness", "stability", "hurst", "entropy"],
1314
+ # label="Global Patterns"
1315
+ # )
1316
+
1317
+ # msg_box = gr.Textbox(
1318
+ # label="Message",
1319
+ # interactive=False
1320
+ # )
1321
+
1322
+ # archive_leaderboard = gr.DataFrame(
1323
+ # elem_classes="custom-table",
1324
+ # elem_id="archive-table",
1325
+ # max_height=600,
1326
+ # interactive=False
1327
+ # )
1328
+
1329
+ # # 绑定事件
1330
+ # domains.change(
1331
+ # fn=update_dataset_choices,
1332
+ # inputs=[domains, freqs],
1333
+ # outputs=dataset_dropdown
1334
+ # )
1335
+
1336
+ # freqs.change(
1337
+ # fn=update_dataset_choices,
1338
+ # inputs=[domains, freqs],
1339
+ # outputs=dataset_dropdown
1340
+ # )
1341
+
1342
+ # # Change DF
1343
+ # for comp in [dataset_dropdown, trend_group, season_group, remainder_group, global_group]:
1344
+ # comp.change(
1345
+ # fn=update_variate_choices_groups,
1346
+ # inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group],
1347
+ # outputs=variate_dropdown
1348
+ # )
1349
+ # comp.change(
1350
+ # fn=collect_patterns,
1351
+ # inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group,
1352
+ # variate_dropdown, domains, freqs],
1353
+ # outputs=[msg_box, archive_leaderboard]
1354
+ # )
1355
+
1356
+ # for comp in [variate_dropdown, domains, freqs]:
1357
+ # comp.change(
1358
+ # fn=collect_patterns,
1359
+ # inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group,
1360
+ # variate_dropdown, domains, freqs],
1361
+ # outputs=[msg_box, archive_leaderboard]
1362
+ # )
1363
+
1364
+ # # Initial Load
1365
+ # demo.load(
1366
+ # fn=collect_patterns,
1367
+ # inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group,
1368
+ # variate_dropdown, domains, freqs],
1369
+ # outputs=[msg_box, archive_leaderboard]
1370
+ # )
src/utils.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import json
5
+ from typing import List, Tuple, Optional
6
+ import yaml
7
+ from pathlib import Path
8
+ from scipy import stats
9
+
10
+ from timebench.evaluation.data import Dataset, get_dataset_settings, load_dataset_config
11
+ from src.hf_config import get_datasets_root, get_config_root
12
+
13
+
14
def load_time_results(root_dir, model_name, dataset_with_freq, horizon):
    """
    Load TIME results from NPZ files for a specific model, dataset, and horizon.

    Args:
        root_dir: Root directory containing TIME results (e.g., "output/results")
        model_name: Model name (e.g., "moirai_small")
        dataset_with_freq: Dataset and freq combined (e.g., "Water_Quality_Darwin/15T")
        horizon: Horizon name (e.g., "short", "medium", "long")

    Returns:
        tuple: (metrics_dict, predictions_dict, config_dict), or
            (None, None, None) if either NPZ file is missing. config_dict
            is {} when config.json is absent.
    """
    horizon_dir = os.path.join(root_dir, model_name, dataset_with_freq, horizon)
    metrics_path = os.path.join(horizon_dir, "metrics.npz")
    predictions_path = os.path.join(horizon_dir, "predictions.npz")
    config_path = os.path.join(horizon_dir, "config.json")

    if not os.path.exists(metrics_path) or not os.path.exists(predictions_path):
        return None, None, None

    # np.load on a .npz returns an NpzFile that keeps the underlying zip
    # handle open; use context managers so the files are closed promptly
    # (the original version leaked one handle per call).
    with np.load(metrics_path) as metrics:
        metrics_dict = {k: metrics[k] for k in metrics.files}
    with np.load(predictions_path) as predictions:
        predictions_dict = {k: predictions[k] for k in predictions.files}

    config_dict = {}
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            config_dict = json.load(f)

    return metrics_dict, predictions_dict, config_dict
47
+
48
+
49
def get_all_datasets_results(root_dir="output/results"):
    """
    Build the dataset-level leaderboard table from TIME NPZ result files.

    Walks the results tree laid out as model/dataset/freq/horizon/, loads the
    metric arrays for each leaf via ``load_time_results`` and reduces each
    metric with a NaN-aware mean.

    Args:
        root_dir (str): Path to the TIME results root directory
            (e.g., "output/results").

    Returns:
        pd.DataFrame: One row per (model, dataset/freq, horizon) with columns
            ["model", "dataset", "freq", "dataset_id", "horizon",
             "MASE", "CRPS", "MAE", "MSE"].
            - dataset: Original dataset name (e.g., "Traffic")
            - freq: Frequency string (e.g., "15T", "1H")
            - dataset_id: Unique identifier as "dataset/freq" (e.g., "Traffic/15T")
            Empty (with the same columns) when root_dir is missing or no
            results are found.
    """
    result_columns = ["model", "dataset", "freq", "dataset_id", "horizon", "MASE", "CRPS", "MAE", "MSE"]

    if not os.path.exists(root_dir):
        print(f"Error: root_dir={root_dir} does not exist")
        return pd.DataFrame(columns=result_columns)

    records = []
    for model_name in os.listdir(root_dir):
        model_path = os.path.join(root_dir, model_name)
        if not os.path.isdir(model_path):
            continue

        for ds_name in os.listdir(model_path):
            ds_path = os.path.join(model_path, ds_name)
            if not os.path.isdir(ds_path):
                continue

            # Nested layout: model/dataset/freq/horizon/
            for freq_name in os.listdir(ds_path):
                if not os.path.isdir(os.path.join(ds_path, freq_name)):
                    continue

                dataset_id = f"{ds_name}/{freq_name}"
                for horizon in ("short", "medium", "long"):
                    metrics_dict, _, _ = load_time_results(root_dir, model_name, dataset_id, horizon)
                    if metrics_dict is None:
                        continue

                    record = {
                        "model": model_name,
                        "dataset": ds_name,
                        "freq": freq_name,
                        "dataset_id": dataset_id,  # unique identifier: dataset/freq
                        "horizon": horizon,
                    }
                    # NaN-aware mean over all stored entries of each metric.
                    for metric in ("MASE", "CRPS", "MAE", "MSE"):
                        record[metric] = np.nanmean(metrics_dict.get(metric, np.array([])))
                    records.append(record)

    return pd.DataFrame(records) if records else pd.DataFrame(columns=result_columns)
114
+
115
+
116
def get_dataset_display_map(datasets_df: pd.DataFrame) -> Tuple[dict, dict]:
    """
    Build bidirectional display-name mappings for dataset identifiers.

    A dataset that appears with a single frequency is shown by its bare name
    (e.g., "Australia_Solar"); one that appears with several frequencies keeps
    the disambiguating "dataset/freq" form (e.g., "Traffic/15T").

    Args:
        datasets_df: DataFrame with 'dataset', 'freq', 'dataset_id' columns.

    Returns:
        Tuple of:
        - id_to_display: dict mapping dataset_id -> display_name
        - display_to_id: dict mapping display_name -> dataset_id
    """
    if datasets_df.empty:
        return {}, {}

    # How many distinct frequencies each dataset name is recorded with.
    freqs_per_dataset = datasets_df.groupby('dataset')['freq'].nunique()

    id_to_display = {}
    display_to_id = {}

    distinct_configs = datasets_df[['dataset', 'freq', 'dataset_id']].drop_duplicates()
    for _, cfg in distinct_configs.iterrows():
        ds_id = cfg['dataset_id']
        ds_name = cfg['dataset']
        # Ambiguous (multi-freq) datasets keep the "dataset/freq" id as label.
        display = ds_id if freqs_per_dataset[ds_name] > 1 else ds_name
        id_to_display[ds_id] = display
        display_to_id[display] = ds_id

    return id_to_display, display_to_id
158
+
159
+
160
def get_all_variates_results(root_dir: str = "output/results") -> pd.DataFrame:
    """
    Collect all variate-individual-level results from TIME NPZ files.

    Each (series, variate) combination is treated as an independent variate individual.
    Metrics are aggregated only across windows (not across series).
    Uses actual series_names and variate_names from Dataset objects.

    Args:
        root_dir (str): Path to the TIME results root directory (e.g., "output/results").

    Returns:
        pd.DataFrame: DataFrame with columns:
            ["dataset_id", "series_name", "variate_name", "is_uts", "model", "horizon", "MASE", "CRPS", "MAE", "MSE"]
        Number of Rows: num_models x num_datasets x num_horizons x num_series x num_variates
    """
    rows = []

    if not os.path.exists(root_dir):
        print(f"[get_all_variates_results] root_dir={root_dir} does not exist")
        return pd.DataFrame(columns=["dataset_id", "series_name", "variate_name", "is_uts", "model", "horizon", "MASE", "CRPS", "MAE", "MSE"])

    # Cache for dataset info (series_names, variate_names) to avoid repeated loading
    # (the same dataset_id is visited once per model).
    dataset_info_cache = {}

    for model in os.listdir(root_dir):
        model_dir = os.path.join(root_dir, model)
        if not os.path.isdir(model_dir):
            continue

        for dataset in os.listdir(model_dir):
            dataset_dir = os.path.join(model_dir, dataset)
            if not os.path.isdir(dataset_dir):
                continue

            # Nested structure: model/dataset/freq/horizon/
            for freq_dir in os.listdir(dataset_dir):
                freq_path = os.path.join(dataset_dir, freq_dir)
                if not os.path.isdir(freq_path):
                    continue

                dataset_id = f"{dataset}/{freq_dir}"

                # Get series_names and variate_names (use cache)
                if dataset_id not in dataset_info_cache:
                    series_names = None
                    variate_names = None
                    is_uts = False
                    try:
                        hf_dataset_root = str(get_datasets_root())
                        if os.path.exists(hf_dataset_root):
                            config_root = get_config_root()
                            config_path = config_root / "datasets.yaml"
                            config = load_dataset_config(config_path) if config_path.exists() else {}
                            # NOTE(review): names are read from the "short"-term view of the
                            # dataset; assumes they are identical across horizons — confirm.
                            settings = get_dataset_settings(dataset_id, "short", config)

                            dataset_obj = Dataset(
                                name=dataset_id,
                                term="short",
                                prediction_length=settings.get("prediction_length"),
                                test_length=settings.get("test_length"),
                                storage_path=hf_dataset_root,
                            )

                            # Get series names
                            if "item_id" in dataset_obj.hf_dataset.column_names:
                                series_names = list(dataset_obj.hf_dataset["item_id"])
                            else:
                                series_names = [f"item_{i}" for i in range(len(dataset_obj.hf_dataset))]

                            # Get variate names
                            variate_names = dataset_obj.get_variate_names()
                            if variate_names is None:
                                # UTS mode: variate_names = series_names, and is_uts = True
                                is_uts = True
                                variate_names = series_names
                            else:
                                variate_names = list(variate_names)
                    except Exception as e:
                        # Best-effort: on any loading failure fall back to the
                        # synthetic names below rather than aborting the scan.
                        print(f"[get_all_variates_results] Error loading Dataset info for {dataset_id}: {e}")

                    dataset_info_cache[dataset_id] = {
                        "series_names": series_names,
                        "variate_names": variate_names,
                        "is_uts": is_uts,
                    }

                info = dataset_info_cache[dataset_id]
                series_names = info["series_names"]
                variate_names = info["variate_names"]
                is_uts = info["is_uts"]

                for horizon in ["short", "medium", "long"]:
                    metrics_dict, _, _ = load_time_results(root_dir, model, dataset_id, horizon)
                    if metrics_dict is None:
                        continue

                    # Get metrics arrays: shape = (num_series, num_windows, num_variates)
                    mase_arr = metrics_dict.get("MASE", np.array([]))
                    crps_arr = metrics_dict.get("CRPS", np.array([]))
                    mae_arr = metrics_dict.get("MAE", np.array([]))
                    mse_arr = metrics_dict.get("MSE", np.array([]))

                    if mase_arr.size == 0:
                        continue

                    # NOTE(review): only MASE's presence/shape is validated; the
                    # 3-D indexing below assumes CRPS/MAE/MSE are always stored
                    # alongside MASE with the same shape — confirm against the
                    # writer of metrics.npz.
                    num_series, num_windows, num_variates = mase_arr.shape

                    # Iterate over each (series, variate) combination
                    for series_idx in range(num_series):
                        series_name = series_names[series_idx] if series_names and series_idx < len(series_names) else f"item_{series_idx}"

                        for variate_idx in range(num_variates):
                            # For UTS: variate_name = series_name (since each series is its own variate)
                            if is_uts:
                                variate_name = series_name
                            else:
                                variate_name = variate_names[variate_idx] if variate_names and variate_idx < len(variate_names) else str(variate_idx)

                            # Aggregate only across windows
                            mase = np.nanmean(mase_arr[series_idx, :, variate_idx])
                            crps = np.nanmean(crps_arr[series_idx, :, variate_idx])
                            mae = np.nanmean(mae_arr[series_idx, :, variate_idx])
                            mse = np.nanmean(mse_arr[series_idx, :, variate_idx])

                            # Skip if all values are NaN
                            if np.isnan(mase) and np.isnan(crps):
                                continue

                            rows.append({
                                "dataset_id": dataset_id,
                                "series_name": series_name,
                                "variate_name": variate_name,
                                "is_uts": is_uts,
                                "model": model,
                                "horizon": horizon,
                                "MASE": mase,
                                "CRPS": crps,
                                "MAE": mae,
                                "MSE": mse,
                            })

    if rows:
        return pd.DataFrame(rows)
    else:
        return pd.DataFrame(columns=["dataset_id", "series_name", "variate_name", "is_uts", "model", "horizon", "MASE", "CRPS", "MAE", "MSE"])
306
+
307
+
308
def get_all_domains_and_freq(conf_dir="conf/data", datasets=None):
    """
    Scan per-dataset YAML config files and collect all unique domains and freqs.

    Args:
        conf_dir (str): Directory containing "{dataset}.yaml" metadata files.
        datasets (list | None): Dataset names to look up. If None, nothing is
            scanned and two empty lists are returned.

    Returns:
        Tuple[List[str], List[str]]: Sorted unique domains and sorted unique
        frequencies found across the given datasets.
    """
    domains, freqs = set(), set()

    # Bug fix: the default `datasets=None` used to crash with TypeError when
    # iterated; treat it as "no datasets".
    if datasets is None:
        datasets = []

    for ds in datasets:
        yaml_path = os.path.join(conf_dir, f"{ds}.yaml")
        if os.path.exists(yaml_path):
            with open(yaml_path, "r") as f:
                # safe_load returns None for an empty file; fall back to {}
                # so the .get() calls below are safe.
                meta = yaml.safe_load(f) or {}
            domain = meta.get("domain")
            freq = meta.get("freq")
            if domain:
                domains.add(domain)
            if freq:
                freqs.add(freq)
    return sorted(domains), sorted(freqs)
326
+
327
+
328
def get_dataset_choices(results_root="output/results") -> Tuple[List[str], dict, dict]:
    """
    Discover available datasets in the TIME results tree and build display names.

    A dataset that appears under a single frequency is shown as just
    "dataset" (e.g., "Australia_Solar"); one that appears under several
    frequencies is shown as "dataset/freq" (e.g., "Traffic/15T") so the UI
    entries stay unambiguous.

    Args:
        results_root: Root directory of the TIME results.

    Returns:
        Tuple of:
            - display_names: sorted display names for the UI dropdown
            - display_to_id: dict mapping display_name -> dataset_id
            - id_to_display: dict mapping dataset_id -> display_name
    """
    if not os.path.exists(results_root):
        return [], {}, {}

    horizons = ("short", "medium", "long")
    # Every discovered (dataset, freq) pair; freq == "" for the legacy
    # flat layout where horizon dirs sit directly under the dataset dir.
    pairs = set()

    def _has_config(base):
        # True when any horizon subdir of `base` holds a config.json.
        return any(
            os.path.exists(os.path.join(base, h, "config.json")) for h in horizons
        )

    for model_name in os.listdir(results_root):
        model_path = os.path.join(results_root, model_name)
        if not os.path.isdir(model_path):
            continue

        for ds_name in os.listdir(model_path):
            ds_path = os.path.join(model_path, ds_name)
            if not os.path.isdir(ds_path):
                continue

            if any(os.path.isdir(os.path.join(ds_path, h)) for h in horizons):
                # Legacy structure: model/dataset/horizon/ — record with
                # an empty freq (shouldn't occur in the new layout).
                if _has_config(ds_path):
                    pairs.add((ds_name, ""))
            else:
                # Nested structure: model/dataset/freq/horizon/
                for freq_name in os.listdir(ds_path):
                    freq_path = os.path.join(ds_path, freq_name)
                    if os.path.isdir(freq_path) and _has_config(freq_path):
                        pairs.add((ds_name, freq_name))

    if not pairs:
        return [], {}, {}

    from collections import Counter
    freq_counts = Counter(name for name, _ in pairs)

    id_to_display = {}
    display_to_id = {}
    for ds_name, freq_name in pairs:
        dataset_id = f"{ds_name}/{freq_name}" if freq_name else ds_name
        # Datasets with several freqs keep the freq suffix in the display.
        display_name = dataset_id if freq_counts[ds_name] > 1 else ds_name
        id_to_display[dataset_id] = display_name
        display_to_id[display_name] = dataset_id

    return sorted(display_to_id), display_to_id, id_to_display
415
+
416
+
417
+ def compute_ranks(df: pd.DataFrame, groupby_cols: str | List[str]) -> pd.DataFrame:
418
+ """
419
+ Compute ranks for models across datasets based on MASE and CRPS.
420
+
421
+ Args:
422
+ df (pd.DataFrame): Dataset-level results with columns
423
+ ["model", "dataset", "MASE", "CRPS"].
424
+
425
+ Returns:
426
+ pd.DataFrame: Dataframe with ["model", "MASE_rank", "CRPS_rank"].
427
+ """
428
+ if isinstance(groupby_cols, str):
429
+ groupby_cols = [groupby_cols]
430
+
431
+ if df.empty:
432
+ return pd.DataFrame(columns=["model", "MASE_rank", "CRPS_rank"])
433
+
434
+ df = df.copy()
435
+
436
+ df["MASE_rank"] = df.groupby(groupby_cols)["MASE"].rank(method="first", ascending=True)
437
+ df["CRPS_rank"] = df.groupby(groupby_cols)["CRPS"].rank(method="first", ascending=True)
438
+
439
+ return df
440
+
441
+
442
def normalize_by_seasonal_naive(
    df: pd.DataFrame,
    baseline_model: str = "seasonal_naive",
    metrics: List[str] = None,
    groupby_cols: List[str] = None,
) -> pd.DataFrame:
    """
    Normalize metric columns by the Seasonal Naive baseline within each group.

    For every (dataset_id, horizon) group (configurable via `groupby_cols`),
    each model's metric values are divided by the baseline model's values,
    making the baseline itself equal to 1.0.

    Args:
        df: Results with columns ["model"] + groupby columns + metric columns.
        baseline_model: Name of the baseline model. Defaults to "seasonal_naive".
        metrics: Metric columns to normalize. Defaults to ["MASE", "CRPS"].
        groupby_cols: Grouping columns. Defaults to ["dataset_id", "horizon"].

    Returns:
        Normalized copy of `df`. Rows whose group has no baseline entry are
        dropped; divisions by zero/NaN baselines (and any resulting inf)
        become NaN. If the baseline model is absent entirely, a copy of `df`
        is returned unchanged after printing a warning.
    """
    metrics = ["MASE", "CRPS"] if metrics is None else metrics
    groupby_cols = ["dataset_id", "horizon"] if groupby_cols is None else groupby_cols

    if df.empty:
        return df.copy()

    if baseline_model not in df["model"].values:
        print(f"[normalize_by_seasonal_naive] Warning: baseline model '{baseline_model}' not found in data")
        return df.copy()

    result = df.copy()

    # group key -> {metric: baseline value}. If the baseline appears more
    # than once in a group, the last occurrence wins (unchanged behavior).
    baseline_values = {
        tuple(row[col] for col in groupby_cols): {m: row[m] for m in metrics}
        for _, row in df[df["model"] == baseline_model].iterrows()
    }

    kept = []
    for idx, row in result.iterrows():
        key = tuple(row[col] for col in groupby_cols)
        if key not in baseline_values:
            # No baseline for this configuration: drop the row.
            continue
        kept.append(idx)

        for metric in metrics:
            denom = baseline_values[key][metric]
            if denom is None or denom == 0 or np.isnan(denom):
                # Undefined ratio (zero or missing baseline).
                result.at[idx, metric] = np.nan
            else:
                result.at[idx, metric] = row[metric] / denom

    result = result.loc[kept].copy()

    # Any remaining infinities are treated as missing values.
    for metric in metrics:
        result[metric] = result[metric].replace([np.inf, -np.inf], np.nan)

    return result
520
+
521
+
522
def load_features(root_dir: str = "features", category: str = "public-benchmarks", split: str = "test") -> pd.DataFrame:
    """
    Load time series features for all datasets (legacy function).

    Args:
        root_dir (str): Path to features root directory.
        category (str): Dataset category (e.g., "public-benchmarks").
        split (str): Which split to load ("full" or "test").

    Returns:
        pd.DataFrame: Concatenated DataFrame with a leading "dataset" column
        ("unique_id" is renamed to "variate_name" when present), or an empty
        DataFrame when the directory is missing or contains no split CSVs.
    """
    base_dir = os.path.join(root_dir, category)

    # Robustness fix (consistent with load_all_features): a missing directory
    # yields an empty frame instead of os.listdir raising FileNotFoundError.
    if not os.path.isdir(base_dir):
        return pd.DataFrame()

    all_data = []
    for dataset in os.listdir(base_dir):
        dataset_dir = os.path.join(base_dir, dataset)
        csv_path = os.path.join(dataset_dir, f"{split}.csv")
        if os.path.exists(csv_path):
            df = pd.read_csv(csv_path)
            df["dataset"] = dataset  # add dataset name
            # Move the dataset column to the front.
            cols = ["dataset"] + [c for c in df.columns if c != "dataset"]
            df = df[cols]
            all_data.append(df)

    if all_data:
        df = pd.concat(all_data, ignore_index=True)
        if "unique_id" in df.columns:
            df = df.rename(columns={"unique_id": "variate_name"})
        return df
    return pd.DataFrame()
554
+
555
+
556
def load_all_features(features_root: str = "output/features", split: str = "test") -> pd.DataFrame:
    """
    Load time series features for every dataset under `features_root`.

    Expected layout: features_root/{dataset}/{freq}/{split}.csv, where each
    CSV carries columns such as dataset_id, series_name, variate_name plus
    the feature values. When "{split}.csv" is absent, "full.csv" is tried
    as a fallback.

    Args:
        features_root: Root directory of the per-dataset feature CSVs
            (e.g., "output/features").
        split: Which split to load ("full" or "test").

    Returns:
        pd.DataFrame: One concatenated DataFrame over all found CSVs, or an
        empty DataFrame when the root is missing or nothing loadable exists.
    """
    if not os.path.exists(features_root):
        print(f"[load_all_features] features_root={features_root} does not exist")
        return pd.DataFrame()

    frames = []
    for ds_name in os.listdir(features_root):
        ds_dir = os.path.join(features_root, ds_name)
        if not os.path.isdir(ds_dir):
            continue

        for freq_name in os.listdir(ds_dir):
            freq_dir = os.path.join(ds_dir, freq_name)
            if not os.path.isdir(freq_dir):
                continue

            # Prefer the requested split; fall back to full.csv.
            csv_path = os.path.join(freq_dir, f"{split}.csv")
            if not os.path.exists(csv_path):
                csv_path = os.path.join(freq_dir, "full.csv")
            if not os.path.exists(csv_path):
                continue

            try:
                frames.append(pd.read_csv(csv_path))
            except Exception as e:
                print(f"[load_all_features] Error loading {csv_path}: {e}")

    if not frames:
        print(f"[load_all_features] No features found in {features_root}")
        return pd.DataFrame()

    features_df = pd.concat(frames, ignore_index=True)
    print(f"[load_all_features] Loaded {len(features_df)} variate features from {len(frames)} datasets")
    return features_df
606
+
607
+
608
+
609
def binarize_features(df: pd.DataFrame, exclude: list) -> pd.DataFrame:
    """
    Binarize feature columns against their per-column median.

    Every column not listed in `exclude` is replaced by 1 where the value is
    strictly greater than that column's median, else 0. Excluded columns are
    carried over untouched; the input frame is not modified.

    Args:
        df: Input dataframe with raw feature values.
        exclude: Column names to leave as-is.

    Returns:
        A new dataframe with the selected feature columns binarized (0/1).
    """
    feature_cols = [c for c in df.columns if c not in exclude]

    out = df.copy()
    medians = df[feature_cols].median()

    for col in feature_cols:
        # Strict comparison: values equal to the median map to 0.
        out[col] = (df[col] > medians[col]).astype(int)

    return out