Spaces:
Running
Running
Commit ·
0b97f6a
0
Parent(s):
Initial release
Browse files- .gitattributes +35 -0
- .gitignore +18 -0
- .pre-commit-config.yaml +53 -0
- Dockerfile +37 -0
- Makefile +13 -0
- README.md +99 -0
- app.py +69 -0
- pyproject.toml +13 -0
- requirements.txt +27 -0
- src/about.py +119 -0
- src/display.egg-info/PKG-INFO +3 -0
- src/display.egg-info/SOURCES.txt +14 -0
- src/display.egg-info/dependency_links.txt +1 -0
- src/display.egg-info/top_level.txt +6 -0
- src/display/css_html_js.py +169 -0
- src/display/formatting.py +39 -0
- src/display/utils.py +169 -0
- src/hf_config.py +241 -0
- src/leaderboard.py +1085 -0
- src/tab.py +1370 -0
- src/utils.py +635 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
scale-hf-logo.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
auto_evals/
|
| 2 |
+
venv/
|
| 3 |
+
__pycache__/
|
| 4 |
+
.env
|
| 5 |
+
.ipynb_checkpoints
|
| 6 |
+
*ipynb
|
| 7 |
+
.idea
|
| 8 |
+
.vscode/
|
| 9 |
+
|
| 10 |
+
eval-queue/
|
| 11 |
+
eval-results/
|
| 12 |
+
eval-queue-bk/
|
| 13 |
+
eval-results-bk/
|
| 14 |
+
logs/
|
| 15 |
+
utils.py
|
| 16 |
+
css_html_js.py
|
| 17 |
+
formatting.py
|
| 18 |
+
run_local.sh
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
default_language_version:
|
| 16 |
+
python: python3
|
| 17 |
+
|
| 18 |
+
ci:
|
| 19 |
+
autofix_prs: true
|
| 20 |
+
autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
|
| 21 |
+
autoupdate_schedule: quarterly
|
| 22 |
+
|
| 23 |
+
repos:
|
| 24 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 25 |
+
rev: v4.3.0
|
| 26 |
+
hooks:
|
| 27 |
+
- id: check-yaml
|
| 28 |
+
- id: check-case-conflict
|
| 29 |
+
- id: detect-private-key
|
| 30 |
+
- id: check-added-large-files
|
| 31 |
+
args: ['--maxkb=1000']
|
| 32 |
+
- id: requirements-txt-fixer
|
| 33 |
+
- id: end-of-file-fixer
|
| 34 |
+
- id: trailing-whitespace
|
| 35 |
+
|
| 36 |
+
- repo: https://github.com/PyCQA/isort
|
| 37 |
+
rev: 5.12.0
|
| 38 |
+
hooks:
|
| 39 |
+
- id: isort
|
| 40 |
+
name: Format imports
|
| 41 |
+
|
| 42 |
+
- repo: https://github.com/psf/black
|
| 43 |
+
rev: 22.12.0
|
| 44 |
+
hooks:
|
| 45 |
+
- id: black
|
| 46 |
+
name: Format code
|
| 47 |
+
additional_dependencies: ['click==8.0.2']
|
| 48 |
+
|
| 49 |
+
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
| 50 |
+
# Ruff version.
|
| 51 |
+
rev: 'v0.0.267'
|
| 52 |
+
hooks:
|
| 53 |
+
- id: ruff
|
Dockerfile
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install git for pip install from GitHub
|
| 6 |
+
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
| 7 |
+
|
| 8 |
+
# Copy requirements first for better caching
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
|
| 11 |
+
# Install Python dependencies
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 13 |
+
pip install --no-cache-dir -r requirements.txt
|
| 14 |
+
|
| 15 |
+
# Copy application code
|
| 16 |
+
COPY . .
|
| 17 |
+
|
| 18 |
+
# Create startup script that installs timebench at runtime (when secrets are available)
|
| 19 |
+
RUN echo '#!/bin/bash\n\
|
| 20 |
+
if [ -n "$GITHUB_TOKEN" ]; then\n\
|
| 21 |
+
echo "Installing timebench from private GitHub repo..."\n\
|
| 22 |
+
pip install --no-cache-dir git+https://oauth2:${GITHUB_TOKEN}@github.com/zqiao11/TIME.git\n\
|
| 23 |
+
else\n\
|
| 24 |
+
echo "Installing timebench from public GitHub repo..."\n\
|
| 25 |
+
pip install --no-cache-dir git+https://github.com/zqiao11/TIME.git\n\
|
| 26 |
+
fi\n\
|
| 27 |
+
exec python app.py\n' > /app/start.sh && chmod +x /app/start.sh
|
| 28 |
+
|
| 29 |
+
# Expose Gradio default port
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
|
| 32 |
+
# Set environment variables
|
| 33 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
| 34 |
+
ENV GRADIO_SERVER_PORT="7860"
|
| 35 |
+
|
| 36 |
+
# Run the startup script
|
| 37 |
+
CMD ["/app/start.sh"]
|
Makefile
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: style format
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
style:
|
| 5 |
+
python -m black --line-length 119 .
|
| 6 |
+
python -m isort .
|
| 7 |
+
ruff check --fix .
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
quality:
|
| 11 |
+
python -m black --check --line-length 119 .
|
| 12 |
+
python -m isort --check-only .
|
| 13 |
+
ruff check .
|
README.md
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: TIME Benchmark Leaderboard
|
| 3 |
+
emoji: 🥇
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: true
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
short_description: 'TIME: A Benchmark for Time Series Forecasting'
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# TIME Benchmark Leaderboard
|
| 13 |
+
|
| 14 |
+
A unified benchmark for time series probabilistic forecasting with multiple granularity evaluation.
|
| 15 |
+
|
| 16 |
+
## Features
|
| 17 |
+
|
| 18 |
+
- **Overall Performance**: Aggregated metrics across all datasets and horizons
|
| 19 |
+
- **Dataset-level Analysis**: Performance breakdown by individual datasets
|
| 20 |
+
- **Window-level Visualization**: Detailed test window analysis with prediction visualization
|
| 21 |
+
|
| 22 |
+
## Configuration
|
| 23 |
+
|
| 24 |
+
### Environment Variables
|
| 25 |
+
|
| 26 |
+
The app reads data from HuggingFace Hub. Configure the following environment variables:
|
| 27 |
+
|
| 28 |
+
| Variable | Description | Default |
|
| 29 |
+
|----------|-------------|---------|
|
| 30 |
+
| `HF_TOKEN` | HuggingFace API token (required for private datasets) | None |
|
| 31 |
+
| `HF_REPO_ID` | Dataset repository ID | `TIME-benchmark/TIME-1.0` |
|
| 32 |
+
| `USE_HF_HUB` | Use HF Hub (`true`) or local files (`false`) | `true` |
|
| 33 |
+
| `HF_CACHE_DIR` | Custom cache directory for downloads | `~/.cache/huggingface` |
|
| 34 |
+
|
| 35 |
+
### For HuggingFace Space Deployment
|
| 36 |
+
|
| 37 |
+
#### 快速部署(推荐)
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
# 1. 复制 timebench 模块到 leaderboard_app
|
| 41 |
+
cd /home/eee/qzz/TIME
|
| 42 |
+
cp -r src/timebench leaderboard_app/
|
| 43 |
+
|
| 44 |
+
# 2. 进入 leaderboard_app 目录
|
| 45 |
+
cd leaderboard_app
|
| 46 |
+
|
| 47 |
+
# 3. 运行部署脚本
|
| 48 |
+
chmod +x deploy.sh
|
| 49 |
+
./deploy.sh YOUR_USERNAME YOUR_SPACE_NAME
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
#### 手动部署
|
| 53 |
+
|
| 54 |
+
详细步骤请参考 [DEPLOY.md](DEPLOY.md)
|
| 55 |
+
|
| 56 |
+
**重要**: 部署前需要:
|
| 57 |
+
1. 创建 HuggingFace Space: https://huggingface.co/new-space
|
| 58 |
+
2. 在 Space Settings → Repository secrets 中添加 `HF_TOKEN`
|
| 59 |
+
3. 确保数据已上传到 `TIME-benchmark/TIME-1.0` Dataset
|
| 60 |
+
|
| 61 |
+
### For Local Development
|
| 62 |
+
|
| 63 |
+
Set `USE_HF_HUB=false` to use local data:
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
export USE_HF_HUB=false
|
| 67 |
+
python app.py
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Installation
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
pip install -r requirements.txt
|
| 74 |
+
python app.py
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Data Structure
|
| 78 |
+
|
| 79 |
+
The app expects the following data structure in the HuggingFace Dataset:
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
HF_REPO/
|
| 83 |
+
├── data/
|
| 84 |
+
│ └── hf_dataset/ # Time series datasets
|
| 85 |
+
│ ├── ECDC_COVID/
|
| 86 |
+
│ ├── Australia_Solar/
|
| 87 |
+
│ └── ...
|
| 88 |
+
├── output/
|
| 89 |
+
│ └── results/ # Model evaluation results
|
| 90 |
+
│ ├── moirai_small/
|
| 91 |
+
│ ├── chronos_base/
|
| 92 |
+
│ └── ...
|
| 93 |
+
└── config/
|
| 94 |
+
└── datasets.yaml # Dataset configurations
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
## Ethical Considerations
|
| 98 |
+
|
| 99 |
+
This release is for research purposes only in support of an academic paper. Our models, datasets, and code are not specifically designed or evaluated for all downstream purposes. We strongly recommend users evaluate and address potential concerns related to accuracy, safety, and fairness before deploying this model.
|
app.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
# Add project root and src directory to Python path to enable imports from timebench
|
| 5 |
+
# Get the directory containing this file (leaderboard_app/)
|
| 6 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 7 |
+
|
| 8 |
+
# Try multiple paths for timebench import:
|
| 9 |
+
# 1. Current directory (if timebench was copied to leaderboard_app/)
|
| 10 |
+
# 2. Parent directory's src (for local development: TIME/src/)
|
| 11 |
+
# 3. Parent's parent's src (if running from leaderboard_app/)
|
| 12 |
+
|
| 13 |
+
# Add current directory first (for Space deployment)
|
| 14 |
+
if current_dir not in sys.path:
|
| 15 |
+
sys.path.insert(0, current_dir)
|
| 16 |
+
|
| 17 |
+
# Add parent directory's src (for local development)
|
| 18 |
+
project_root = os.path.dirname(current_dir)
|
| 19 |
+
if project_root not in sys.path:
|
| 20 |
+
sys.path.insert(0, project_root)
|
| 21 |
+
|
| 22 |
+
src_dir = os.path.join(project_root, "src")
|
| 23 |
+
if src_dir not in sys.path and os.path.exists(src_dir):
|
| 24 |
+
sys.path.insert(0, src_dir)
|
| 25 |
+
|
| 26 |
+
import gradio as gr
|
| 27 |
+
from src.display.css_html_js import custom_css
|
| 28 |
+
from src.about import TITLE, INTRODUCTION_TEXT
|
| 29 |
+
from src.tab import init_overall_tab, init_per_window_tab, init_per_dataset_tab, init_per_pattern_tab
|
| 30 |
+
|
| 31 |
+
# Custom head content for responsive design
|
| 32 |
+
custom_head = """
|
| 33 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=5.0, user-scalable=yes">
|
| 34 |
+
<style>
|
| 35 |
+
/* 响应式设计:让页面自动适配不同屏幕尺寸 */
|
| 36 |
+
html {
|
| 37 |
+
width: 100%;
|
| 38 |
+
max-width: 100%;
|
| 39 |
+
}
|
| 40 |
+
body {
|
| 41 |
+
width: 100%;
|
| 42 |
+
max-width: 100%;
|
| 43 |
+
}
|
| 44 |
+
</style>
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
with gr.Blocks(css=custom_css, head=custom_head) as demo:
|
| 48 |
+
gr.HTML(TITLE)
|
| 49 |
+
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
| 50 |
+
|
| 51 |
+
with gr.Tabs(elem_id="custom-tabs") as tabs:
|
| 52 |
+
with gr.Tab("🏅 Overall Performance", id=0):
|
| 53 |
+
init_overall_tab()
|
| 54 |
+
|
| 55 |
+
with gr.Tab("🏅 Per Dataset", id=1):
|
| 56 |
+
init_per_dataset_tab(demo)
|
| 57 |
+
|
| 58 |
+
with gr.Tab("🏅 Per Test Window", id=3):
|
| 59 |
+
init_per_window_tab(demo)
|
| 60 |
+
|
| 61 |
+
with gr.Tab("🏅 Per Pattern", id=4):
|
| 62 |
+
init_per_pattern_tab(demo)
|
| 63 |
+
|
| 64 |
+
# with gr.Tab("📂 Archive", id=5):
|
| 65 |
+
# init_archive_tab(demo)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
demo.launch()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.ruff]
|
| 2 |
+
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
|
| 3 |
+
select = ["E", "F"]
|
| 4 |
+
ignore = ["E501"] # line too long (black is taking care of this)
|
| 5 |
+
line-length = 119
|
| 6 |
+
fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
|
| 7 |
+
|
| 8 |
+
[tool.isort]
|
| 9 |
+
profile = "black"
|
| 10 |
+
line_length = 119
|
| 11 |
+
|
| 12 |
+
[tool.black]
|
| 13 |
+
line-length = 119
|
requirements.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# TIME Leaderboard Dependencies
|
| 3 |
+
# =============================================================================
|
| 4 |
+
|
| 5 |
+
# Install timebench from public GitHub repo (always fetch latest from main)
|
| 6 |
+
timebench @ git+https://github.com/zqiao11/TIME.git@main
|
| 7 |
+
|
| 8 |
+
# Core dependencies - pinned to match local working environment
|
| 9 |
+
gradio==5.50.0
|
| 10 |
+
gradio_leaderboard==0.0.14
|
| 11 |
+
gradio_client==1.14.0
|
| 12 |
+
huggingface-hub==0.36.0
|
| 13 |
+
datasets==2.17.1
|
| 14 |
+
APScheduler
|
| 15 |
+
matplotlib
|
| 16 |
+
numpy==1.26.4
|
| 17 |
+
plotly==6.5.0
|
| 18 |
+
pandas==2.3.3
|
| 19 |
+
python-dateutil
|
| 20 |
+
python-dotenv
|
| 21 |
+
tqdm
|
| 22 |
+
pyarrow
|
| 23 |
+
pyyaml
|
| 24 |
+
scipy==1.11.4
|
| 25 |
+
|
| 26 |
+
# Note: gluonts and other timebench dependencies are automatically installed
|
| 27 |
+
# via the timebench package
|
src/about.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
# Import HuggingFace Hub configuration
|
| 7 |
+
from src.hf_config import get_results_root, get_config_root, get_features_root, initialize_data
|
| 8 |
+
|
| 9 |
+
from src.utils import (
|
| 10 |
+
get_all_datasets_results, get_all_domains_and_freq, get_all_variates_results,
|
| 11 |
+
get_dataset_choices, get_dataset_display_map, compute_ranks,
|
| 12 |
+
load_features, load_all_features, binarize_features
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# =============================================================================
|
| 17 |
+
# Initialize data from HuggingFace Hub (or local for development)
|
| 18 |
+
# =============================================================================
|
| 19 |
+
print("🚀 Starting TIME Leaderboard initialization...")
|
| 20 |
+
|
| 21 |
+
# Download/cache results and config from HuggingFace Hub
|
| 22 |
+
RESULTS_ROOT, CONFIG_ROOT = initialize_data()
|
| 23 |
+
|
| 24 |
+
# Get features root (local or HF)
|
| 25 |
+
FEATURES_ROOT = get_features_root()
|
| 26 |
+
|
| 27 |
+
# Get list of all models from results directory
|
| 28 |
+
ALL_MODELS = []
|
| 29 |
+
if RESULTS_ROOT.exists():
|
| 30 |
+
ALL_MODELS = [p.name for p in RESULTS_ROOT.iterdir() if p.is_dir()]
|
| 31 |
+
print(f"📊 Found {len(ALL_MODELS)} models: {ALL_MODELS}")
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------
|
| 34 |
+
# Get dataset choices from TIME results (with smart display names)
|
| 35 |
+
DATASET_CHOICES, DATASET_DISPLAY_TO_ID, DATASET_ID_TO_DISPLAY = get_dataset_choices(str(RESULTS_ROOT))
|
| 36 |
+
print(f"📁 Found {len(DATASET_CHOICES)} dataset configurations")
|
| 37 |
+
|
| 38 |
+
# === Load data once at startup ===
|
| 39 |
+
DATASETS_DF = get_all_datasets_results(root_dir=str(RESULTS_ROOT))
|
| 40 |
+
if not DATASETS_DF.empty:
|
| 41 |
+
# Use dataset_id (dataset/freq) for ranking to correctly handle multi-freq datasets
|
| 42 |
+
DATASETS_DF = compute_ranks(DATASETS_DF, groupby_cols=['dataset_id', "horizon"]) # Rows: 每一行是1个独立的实验 num_model x num_dataset_id x num_horizons
|
| 43 |
+
print(f"✅ Loaded {len(DATASETS_DF)} dataset results")
|
| 44 |
+
|
| 45 |
+
# === Load variate-level results for pattern-based leaderboard ===
|
| 46 |
+
print("📊 Loading variate-level results...")
|
| 47 |
+
VARIATES_DF = get_all_variates_results(root_dir=str(RESULTS_ROOT))
|
| 48 |
+
if not VARIATES_DF.empty:
|
| 49 |
+
# Compute ranks per (dataset_id, series_name, variate_name, horizon)
|
| 50 |
+
VARIATES_DF = compute_ranks(VARIATES_DF, groupby_cols=['dataset_id', 'series_name', 'variate_name', 'horizon'])
|
| 51 |
+
print(f"✅ Loaded {len(VARIATES_DF)} variate-level results")
|
| 52 |
+
else:
|
| 53 |
+
print("⚠️ No variate-level results found")
|
| 54 |
+
|
| 55 |
+
# === Load features for pattern-based filtering ===
|
| 56 |
+
print("📊 Loading features...")
|
| 57 |
+
FEATURES_DF = load_all_features(features_root=str(FEATURES_ROOT), split="test")
|
| 58 |
+
if not FEATURES_DF.empty:
|
| 59 |
+
print(f"✅ Loaded {len(FEATURES_DF)} variate features")
|
| 60 |
+
else:
|
| 61 |
+
print("⚠️ No features found")
|
| 62 |
+
|
| 63 |
+
# Columns to exclude from binarization
|
| 64 |
+
BINARIZE_EXCLUDE = [
|
| 65 |
+
'dataset_id', 'series_name', 'variate_name', 'unique_id',
|
| 66 |
+
'mean', 'std', 'length',
|
| 67 |
+
'period1', 'period2', 'period3',
|
| 68 |
+
'p_strength1', 'p_strength2', 'p_strength3',
|
| 69 |
+
'missing_rate',
|
| 70 |
+
# Meta features are already 0/1, handle separately
|
| 71 |
+
'is_random_walk', 'has_spike_presence',
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
# Binarize numeric features by median
|
| 75 |
+
FEATURES_BOOL_DF = pd.DataFrame()
|
| 76 |
+
if not FEATURES_DF.empty:
|
| 77 |
+
FEATURES_BOOL_DF = binarize_features(FEATURES_DF, exclude=BINARIZE_EXCLUDE)
|
| 78 |
+
print(f"✅ Binarized features for {len(FEATURES_BOOL_DF)} variates")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if not DATASETS_DF.empty:
|
| 82 |
+
OVERALL_TABLE_COLUMNS = ["model", "MASE", "CRPS", "MASE_rank", "CRPS_rank"]
|
| 83 |
+
else:
|
| 84 |
+
OVERALL_TABLE_COLUMNS = ["model", "MASE", "CRPS"]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
ALL_HORIZONS = ['short', 'medium', 'long']
|
| 88 |
+
|
| 89 |
+
# Pattern mapping: UI pattern name -> feature column name
|
| 90 |
+
PATTERN_MAP = {
|
| 91 |
+
# Trend patterns
|
| 92 |
+
"T_strength": "trend_strength",
|
| 93 |
+
"T_linearity": "linearity",
|
| 94 |
+
"T_curvature": "curvature",
|
| 95 |
+
# Seasonal patterns
|
| 96 |
+
"S_strength": "seasonal_strength",
|
| 97 |
+
"S_complexity": "seasonal_entropy",
|
| 98 |
+
"S_corr": "seasonal_corr",
|
| 99 |
+
# Residual patterns
|
| 100 |
+
"R_diff1_ACF1": "e_diff1_acf1",
|
| 101 |
+
"R_ACF1": "e_acf1",
|
| 102 |
+
# Meta patterns
|
| 103 |
+
"stationarity": "is_random_walk", # Note: stationarity = NOT is_random_walk
|
| 104 |
+
"outlier_presence": "has_spike_presence",
|
| 105 |
+
"complexity": "x_entropy", # High entropy = low predictability/high noise
|
| 106 |
+
}
|
| 107 |
+
# ---------------------------------------------------
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Your leaderboard name
|
| 111 |
+
TITLE = """<h1 align="center" id="space-title"> It's TIME</h1>"""
|
| 112 |
+
|
| 113 |
+
# What does your leaderboard evaluate?
|
| 114 |
+
INTRODUCTION_TEXT = """
|
| 115 |
+
TIME introduces a unified benchmark for time series probabilistic forecasting that supports evaluation at **multiple granularities**, ranging from overall performance across datasets to dataset-level, variate-level, and even individual test windows (with visualization). Beyond conventional analysis, the benchmark enables **pattern-driven, cross-dataset benchmarking** by grouping variates with similar temporal features, where patterns are defined based on groups of tsfeatures that capture properties such as trend, seasonality, and stationarity, offering a more systematic understanding of model behavior. For data and results, please refer to 🤗 [dataset](https://huggingface.co/datasets/TIME-benchmark/TIME-1.0/tree/main).
|
| 116 |
+
"""
|
| 117 |
+
# An integrated archive further enriches the platform by providing structural tsfeatures and statistical descriptors of all variates,
|
| 118 |
+
# ensuring both comprehensive evaluation and transparent interpretability across diverse forecasting scenarios
|
| 119 |
+
print("✅ TIME Leaderboard initialization complete!")
|
src/display.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: display
|
| 3 |
+
Version: 0.0.0
|
src/display.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
src/about.py
|
| 4 |
+
src/hf_config.py
|
| 5 |
+
src/leaderboard.py
|
| 6 |
+
src/tab.py
|
| 7 |
+
src/utils.py
|
| 8 |
+
src/display/css_html_js.py
|
| 9 |
+
src/display/formatting.py
|
| 10 |
+
src/display/utils.py
|
| 11 |
+
src/display.egg-info/PKG-INFO
|
| 12 |
+
src/display.egg-info/SOURCES.txt
|
| 13 |
+
src/display.egg-info/dependency_links.txt
|
| 14 |
+
src/display.egg-info/top_level.txt
|
src/display.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/display.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
about
|
| 2 |
+
display
|
| 3 |
+
hf_config
|
| 4 |
+
leaderboard
|
| 5 |
+
tab
|
| 6 |
+
utils
|
src/display/css_html_js.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
custom_css = """
|
| 2 |
+
|
| 3 |
+
/* ========== 响应式布局 ========== */
|
| 4 |
+
/* 移除固定宽度,让Gradio自动适配不同屏幕尺寸 */
|
| 5 |
+
.gradio-container {
|
| 6 |
+
width: 100% !important;
|
| 7 |
+
max-width: 100% !important;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
/* 主体内容区域自适应 */
|
| 11 |
+
.main, .contain {
|
| 12 |
+
width: 100% !important;
|
| 13 |
+
max-width: 100% !important;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
/* Tab 内容区域自适应 */
|
| 17 |
+
.tabitem {
|
| 18 |
+
width: 100% !important;
|
| 19 |
+
max-width: 100% !important;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
/* Plot 组件自适应,但保持最小可读宽度 */
|
| 23 |
+
.js-plotly-plot, .plotly {
|
| 24 |
+
width: 100% !important;
|
| 25 |
+
max-width: 100% !important;
|
| 26 |
+
min-width: 300px !important; /* 保持最小可读宽度 */
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
/* ========== 原有样式 ========== */
|
| 30 |
+
.markdown-text {
|
| 31 |
+
font-size: 20px !important;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
/* 只影响 Tabs 按钮 */
|
| 35 |
+
#custom-tabs [role="tab"] {
|
| 36 |
+
font-size: 20px;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/* ✅ 只影响表格 */
|
| 40 |
+
.custom-table table thead th {
|
| 41 |
+
font-size: 16px;
|
| 42 |
+
font-weight: 600; /* 想要普通就改成 400 */
|
| 43 |
+
text-align: center;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
.custom-table table tbody td {
|
| 47 |
+
font-size: 14px;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
/* 响应式表格布局 */
|
| 51 |
+
.custom-table table {
|
| 52 |
+
table-layout: auto; /* 使用自动布局,让表格自适应 */
|
| 53 |
+
width: 100%; /* 占满容器 */
|
| 54 |
+
min-width: 100%; /* 确保至少占满容器 */
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
/* 表格容器允许横向滚动(当内容过宽时) */
|
| 58 |
+
.custom-table {
|
| 59 |
+
overflow-x: auto; /* 当表格内容过宽时,允许横向滚动 */
|
| 60 |
+
width: 100%;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
/* 为不同列设置合适的宽度(使用相对单位,更灵活) */
|
| 64 |
+
.custom-table table th:nth-child(1),
|
| 65 |
+
.custom-table table td:nth-child(1) {
|
| 66 |
+
min-width: 150px; /* model 列最小宽度 */
|
| 67 |
+
max-width: 250px; /* 最大宽度限制 */
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
/* 指标列(MASE, CRPS, MAE, MSE) */
|
| 71 |
+
.custom-table table th:nth-child(2),
|
| 72 |
+
.custom-table table td:nth-child(2),
|
| 73 |
+
.custom-table table th:nth-child(3),
|
| 74 |
+
.custom-table table td:nth-child(3),
|
| 75 |
+
.custom-table table th:nth-child(4),
|
| 76 |
+
.custom-table table td:nth-child(4),
|
| 77 |
+
.custom-table table th:nth-child(5),
|
| 78 |
+
.custom-table table td:nth-child(5) {
|
| 79 |
+
min-width: 80px; /* 原始指标列最小宽度 */
|
| 80 |
+
max-width: 120px;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/* 归一化指标列(MASE_norm, CRPS_norm, MAE_norm, MSE_norm) */
|
| 84 |
+
.custom-table table th:nth-child(6),
|
| 85 |
+
.custom-table table td:nth-child(6),
|
| 86 |
+
.custom-table table th:nth-child(7),
|
| 87 |
+
.custom-table table td:nth-child(7),
|
| 88 |
+
.custom-table table th:nth-child(8),
|
| 89 |
+
.custom-table table td:nth-child(8),
|
| 90 |
+
.custom-table table th:nth-child(9),
|
| 91 |
+
.custom-table table td:nth-child(9) {
|
| 92 |
+
min-width: 100px; /* 归一化指标列最小宽度 */
|
| 93 |
+
max-width: 150px;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
/* 排名列(MASE_rank, CRPS_rank) */
|
| 97 |
+
.custom-table table th:nth-child(10),
|
| 98 |
+
.custom-table table td:nth-child(10),
|
| 99 |
+
.custom-table table th:nth-child(11),
|
| 100 |
+
.custom-table table td:nth-child(11) {
|
| 101 |
+
min-width: 80px; /* 排名列最小宽度 */
|
| 102 |
+
max-width: 120px;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
#archive-table table thead th { font-size: 14px; font-weight: 400}
|
| 107 |
+
#archive-table table {
|
| 108 |
+
table-layout: fixed; /* 强制固定布局 */
|
| 109 |
+
width: 100%; /* 占满容器 */
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
#archive-table table th:nth-child(1),
|
| 113 |
+
#archive-table table td:nth-child(1) {
|
| 114 |
+
width: 160px !important; /* dataset */
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
#archive-table table th:nth-child(2),
|
| 118 |
+
#archive-table table td:nth-child(2) {
|
| 119 |
+
width: 100px !important; /* variate_name */
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
#archive-table table th:nth-child(3),
|
| 123 |
+
#archive-table table td:nth-child(3) {
|
| 124 |
+
width: 60px !important; /* freq */
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
#archive-table table th:nth-child(4),
|
| 128 |
+
#archive-table table td:nth-child(4) {
|
| 129 |
+
width: 100px !important; /* domain */
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
/* 后面的特征列 */
|
| 133 |
+
#archive-table table th:nth-child(n+5),
|
| 134 |
+
#archive-table table td:nth-child(n+5) {
|
| 135 |
+
width: 120px !important;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
#citation-button span {
|
| 142 |
+
font-size: 14px !important;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
#citation-button textarea {
|
| 146 |
+
font-size: 16px !important;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
#citation-button > label > button {
|
| 150 |
+
margin: 6px;
|
| 151 |
+
transform: scale(1.3);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
#search-bar-table-box > div:first-child {
|
| 155 |
+
background: none;
|
| 156 |
+
border: none;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
#search-bar {
|
| 160 |
+
padding: 0px;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
# ToDO: markdown-text不好使...
|
| 166 |
+
# archive-table table thead th { font-size: 14px; font-weight: 400}
|
| 167 |
+
|
| 168 |
+
# /* 让表格遵守列宽、并能横向滚动 */
|
| 169 |
+
# #
|
src/display/formatting.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def model_hyperlink(model_link, code_link, model_name):
    """Render a model name as an HTML snippet, optionally linked.

    Args:
        model_link: URL of the model page; if empty, the bare name is returned.
        code_link: URL of the code repository; if empty, no "(code)" suffix is added.
        model_name: Display name of the model.

    Returns:
        str: the bare name, a dotted-underline link, or "<link> (<code link>)".
    """
    if model_link == "":
        return model_name
    model_url = f'<a target="_blank" href="{model_link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
    if code_link == "":
        return model_url
    code_url = f'<a target="_blank" href="{code_link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">code</a>'
    return f"{model_url} ({code_url})"


def make_clickable_model(model_name):
    """Link a model name to its HuggingFace Hub page.

    Fixed: model_hyperlink takes (model_link, code_link, model_name); the
    original call passed only two arguments and raised TypeError. No code
    link is available here, so an empty code_link is passed.
    """
    link = f"https://huggingface.co/{model_name}"
    return model_hyperlink(link, "", model_name)
| 20 |
+
|
| 21 |
+
|
def styled_error(error):
    """Wrap *error* in a centered, large, red HTML paragraph."""
    style = "color: red; font-size: 20px; text-align: center;"
    return f"<p style='{style}'>{error}</p>"
| 24 |
+
|
| 25 |
+
|
def styled_warning(warn):
    """Wrap *warn* in a centered, large, orange HTML paragraph."""
    style = "color: orange; font-size: 20px; text-align: center;"
    return f"<p style='{style}'>{warn}</p>"
| 28 |
+
|
| 29 |
+
|
def styled_message(message):
    """Wrap *message* in a centered, large, green HTML paragraph."""
    style = "color: green; font-size: 20px; text-align: center;"
    return f"<p style='{style}'>{message}</p>"
| 32 |
+
|
| 33 |
+
|
def has_no_nan_values(df, columns):
    """Row-wise mask: True where none of *columns* is NaN in that row of *df*."""
    return ~df[columns].isna().any(axis=1)
| 36 |
+
|
| 37 |
+
|
def has_nan_values(df, columns):
    """Row-wise mask: True where at least one of *columns* is NaN in that row of *df*."""
    return ~df[columns].notna().all(axis=1)
src/display/utils.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, make_dataclass
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
from src.about import Tasks
|
| 7 |
+
|
def fields(raw_class):
    """Return the values of all non-dunder attributes of *raw_class*, in definition order."""
    return [
        value
        for attr, value in raw_class.__dict__.items()
        if attr[:2] != "__" and attr[-2:] != "__"
    ]
| 10 |
+
|
| 11 |
+
|
# These classes are for user facing column names,
# to avoid having to change them all around the code
# when a modif is needed
@dataclass
class ColumnContent:
    # Metadata for one leaderboard table column.
    name: str  # column header shown in the UI
    type: str  # column datatype label ("str", "number", "markdown", "bool") — presumably consumed by the Gradio table; confirm
    displayed_by_default: bool  # shown before the user customizes visible columns
    hidden: bool = False  # excluded from the column selector entirely
    never_hidden: bool = False  # pinned: the user cannot hide this column
| 22 |
+
|
## Leaderboard columns
# Identifier columns: displayed by default and pinned (never hidden).
_ARCHIVE_KEY_COLUMNS = [
    ("dataset", "markdown"),
    ("unique_id", "str"),
    ("freq", "str"),
    ("domain", "str"),
]

# Numeric feature columns as (field_name, display_name) pairs; hidden by default.
_ARCHIVE_FEATURE_COLUMNS = [
    # Raw features
    ("x_acf1", "x_acf1"),
    ("x_acf10", "x_acf10"),
    ("lumpiness", "lumpiness"),
    ("stability", "stability"),
    ("hurst", "hurst"),
    ("entropy", "entropy"),
    # Trend features
    ("trend", "trend_strength"),
    ("trend_crossing_point_ratio", "trend_xpoint_ratio"),
    ("trend_stability", "trend_stability"),
    ("trend_lumpiness", "trend_lumpiness"),
    ("trend_hurst", "trend_hurst"),
    ("trend_entropy", "trend_entropy"),
    # e_* component features (labelled "Seasonal features" in the original;
    # NOTE(review): e_* usually denotes the remainder/error component — confirm)
    ("e_acf1", "e_acf1"),
    ("e_acf10", "e_acf10"),
    ("e_entropy", "e_entropy"),
    ("e_hurst", "e_hurst"),
    ("e_lumpiness", "e_lumpiness"),
    ("e_outlier_ratio", "e_outlier_ratio"),
    # seasonal_* component features (labelled "Remainder features" in the original)
    ("seasonal_strength", "seasonal_strength"),
    ("seasonality_corr", "seasonality_corr"),
    ("seasonal_stability", "seasonal_stability"),
    ("seasonal_lumpiness", "seasonal_lumpiness"),
    ("seasonal_hurst", "seasonal_hurst"),
    ("seasonal_entropy", "seasonal_entropy"),
    # Statistics
    ("mean", "mean"),
    ("std", "std"),
    ("missing_rate", "missing_rate"),
    ("length", "length"),
    ("period1", "period1"),
    ("period2", "period2"),
    ("period3", "period3"),
    ("p_strength1", "p_strength1"),
    ("p_strength2", "p_strength2"),
    ("p_strength3", "p_strength3"),
]

# Each entry is [field_name, annotation, default ColumnContent] in the shape
# make_dataclass expects; entry order fixes the column order downstream.
archive_info_dict = []
for _field, _gr_type in _ARCHIVE_KEY_COLUMNS:
    archive_info_dict.append(
        [_field, ColumnContent, ColumnContent(_field, _gr_type, True, never_hidden=True)]
    )
for _field, _display in _ARCHIVE_FEATURE_COLUMNS:
    archive_info_dict.append(
        [_field, ColumnContent, ColumnContent(_display, "number", False, False)]
    )

ArchiveInfoColumn = make_dataclass("ArchiveInfoColumn", archive_info_dict, frozen=True)
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
# Model-info table columns, in display order.
model_info_dict = [
    # Init column for the model properties
    ["model_type_symbol", ColumnContent, ColumnContent("T", "str", True, never_hidden=True)],
    ["model", ColumnContent, ColumnContent("model", "markdown", True, never_hidden=True)],
    # Model information
    ["model_type", ColumnContent, ColumnContent("Type", "str", False, True)],
    ["precision", ColumnContent, ColumnContent("Precision", "str", False, True)],
    ["license", ColumnContent, ColumnContent("Hub License", "str", False, True)],
    ["params", ColumnContent, ColumnContent("#Params (B)", "number", False, True)],
    ["likes", ColumnContent, ColumnContent("Hub ❤️", "number", False, True)],
    ["still_on_hub", ColumnContent, ColumnContent("Available on the hub", "bool", False)],
    ["org", ColumnContent, ColumnContent("Organization", "str", True, hidden=False)],
    ["testdata_leakage", ColumnContent, ColumnContent("TestData Leakage", "str", True, hidden=False)],
]

# We use make dataclass to dynamically fill the scores from Tasks
ModelInfoColumn = make_dataclass("ModelInfoColumn", model_info_dict, frozen=True)
| 96 |
+
|
## For the queue columns in the submission tab
@dataclass(frozen=True)
class EvalQueueColumn:  # Queue column
    """Columns of the pending/running/finished evaluation-queue tables."""
    model = ColumnContent("model", "markdown", True)
    revision = ColumnContent("revision", "str", True)
    private = ColumnContent("private", "bool", True)
    precision = ColumnContent("precision", "str", True)
    # Fixed: the third positional argument is displayed_by_default (a bool);
    # the original passed the string "Original" (a weight-type value, likely a
    # copy-paste slip). "Original" was truthy, so the visible behavior is kept.
    weight_type = ColumnContent("weight_type", "str", True)
    status = ColumnContent("status", "str", True)
| 106 |
+
|
## All the model information that we might need
@dataclass
class ModelDetails:
    """Display metadata for a model category / weight type / precision."""
    name: str
    display_name: str = ""
    symbol: str = ""  # emoji


class ModelType(Enum):
    """Model categories shown on the leaderboard; each value holds name + emoji."""
    PT = ModelDetails(name="🟢 pretrained", symbol="🟢")
    ZT = ModelDetails(name="🔴 zero-shot", symbol="🔴")
    FT = ModelDetails(name="🟣 fine-tuned", symbol="🟣")
    AG = ModelDetails(name="🟡 agentic", symbol="🟡")
    DL = ModelDetails(name="🔷 deep-learning", symbol="🔷")
    ST = ModelDetails(name="🔶 statistical", symbol="🔶")

    Unknown = ModelDetails(name="", symbol="?")

    def to_str(self, separator=" "):
        """Render as "<symbol><separator><name>" (the name already embeds the symbol)."""
        return f"{self.value.symbol}{separator}{self.value.name}"

    @staticmethod
    def from_str(type):
        """Parse a type string (keyword or emoji) back into a ModelType.

        Fixed: the emoji checks now match the symbols declared above. The
        original tested 🔶 for FT, 🟦 for DL and 🟣 for ST, so round-tripping
        e.g. ST.to_str() ("🔶 statistical") wrongly returned FT and the DL
        emoji was never recognized.
        """
        if "fine-tuned" in type or "🟣" in type:
            return ModelType.FT
        if "pretrained" in type or "🟢" in type:
            return ModelType.PT
        if "zero-shot" in type or "🔴" in type:
            return ModelType.ZT
        if "agentic" in type or "🟡" in type:
            return ModelType.AG
        if "deep-learning" in type or "🔷" in type:
            return ModelType.DL
        if "statistical" in type or "🔶" in type:
            return ModelType.ST
        return ModelType.Unknown
| 144 |
+
|
class WeightType(Enum):
    # How the submitted weights relate to a base model (names mirror the
    # submission form; exact semantics presumed — confirm against the submit tab).
    Adapter = ModelDetails("Adapter")
    Original = ModelDetails("Original")
    Delta = ModelDetails("Delta")
| 149 |
+
|
class Precision(Enum):
    """Numeric precisions a model can be evaluated in; value is display metadata."""
    float16 = ModelDetails("float16")
    bfloat16 = ModelDetails("bfloat16")
    Unknown = ModelDetails("?")

    @staticmethod
    def from_str(precision):
        """Map a precision string (e.g. "torch.float16") to a Precision member.

        Fixed: marked @staticmethod for consistency with ModelType.from_str.
        As a bare function, calling it on a member
        (Precision.float16.from_str(x)) would have bound the member itself
        as `precision` and silently returned Unknown.
        """
        if precision in ["torch.float16", "float16"]:
            return Precision.float16
        if precision in ["torch.bfloat16", "bfloat16"]:
            return Precision.bfloat16
        return Precision.Unknown
|
| 161 |
+
|
# Column selection
# Model-info columns eligible for display (hidden ones filtered out).
MODEL_INFO_COLS = [c.name for c in fields(ModelInfoColumn) if not c.hidden]

# Evaluation-queue table: column names and their datatype labels, in declaration order.
EVAL_COLS = [c.name for c in fields(EvalQueueColumn)]
EVAL_TYPES = [c.type for c in fields(EvalQueueColumn)]

# One result column per benchmark task defined in src.about.Tasks.
BENCHMARK_COLS = [t.value.col_name for t in Tasks]
| 169 |
+
|
src/hf_config.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
HuggingFace Hub Configuration and Helper Functions

This module provides configuration and utilities for loading data from HuggingFace Hub.
The data is cached locally after first download, so subsequent accesses are fast.
"""

import os
from functools import lru_cache
from pathlib import Path

from huggingface_hub import snapshot_download

# =============================================================================
# Configuration
# =============================================================================

# HuggingFace Dataset repository ID
HF_REPO_ID = os.environ.get("HF_REPO_ID", "TIME-benchmark/TIME-1.0")

# HuggingFace token (set via environment variable for security)
# In HuggingFace Space, set this in Settings -> Repository secrets
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Whether to use HuggingFace Hub (True) or local files (False)
# Set to False for local development with local data
# NOTE(review): only the exact string "true" (case-insensitive) enables the Hub;
# e.g. USE_HF_HUB=1 silently falls back to local paths.
USE_HF_HUB = os.environ.get("USE_HF_HUB", "true").lower() == "true"

# Local cache directory for HF Hub downloads
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", None)  # None uses default ~/.cache/huggingface

# Local data paths (used when USE_HF_HUB=false)
# Set these environment variables to specify custom local paths
LOCAL_RESULTS_PATH = os.environ.get("LOCAL_RESULTS_PATH", None)  # Path to output/results
LOCAL_FEATURES_PATH = os.environ.get("LOCAL_FEATURES_PATH", None)  # Path to output/features
LOCAL_CONFIG_PATH = os.environ.get("LOCAL_CONFIG_PATH", None)  # Path to config directory
LOCAL_DATASETS_PATH = os.environ.get("LOCAL_DATASETS_PATH", None)  # Path to data/hf_dataset

# =============================================================================
# Helper Functions
# =============================================================================
| 42 |
+
|
@lru_cache(maxsize=1)
def download_results_snapshot() -> Path:
    """
    Download the results directory from HuggingFace Hub.
    Memoized via lru_cache, so the snapshot is fetched at most once per process.

    Returns:
        Path: Local path to the downloaded results directory
    """
    if not USE_HF_HUB:
        # Local development: LOCAL_RESULTS_PATH env var wins, then the relative
        # checkout layout, then a hard-coded dev-machine path.
        if LOCAL_RESULTS_PATH:
            local_path = Path(LOCAL_RESULTS_PATH)
        else:
            local_path = Path("../output/results")
            if not local_path.exists():
                local_path = Path("/home/eee/qzz/TIME/output/results")
            if not local_path.exists():
                print(f"⚠️ Warning: Local results path does not exist: {local_path}")
        return local_path

    print(f"📥 Downloading results from HuggingFace Hub: {HF_REPO_ID}")

    snapshot_root = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        allow_patterns=["output/results/**"],
        cache_dir=HF_CACHE_DIR,
    )

    results_path = Path(snapshot_root) / "output" / "results"
    print(f"✅ Results cached at: {results_path}")
    return results_path
| 78 |
+
|
| 79 |
+
|
@lru_cache(maxsize=1)
def download_datasets_snapshot() -> Path:
    """
    Download the hf_dataset directory from HuggingFace Hub.
    Memoized via lru_cache, so the snapshot is fetched at most once per process.

    Returns:
        Path: Local path to the downloaded hf_dataset directory
    """
    if not USE_HF_HUB:
        # Local development: LOCAL_DATASETS_PATH env var wins, then the relative
        # checkout layout, then a hard-coded dev-machine path.
        if LOCAL_DATASETS_PATH:
            local_path = Path(LOCAL_DATASETS_PATH)
        else:
            local_path = Path("../data/hf_dataset")
            if not local_path.exists():
                local_path = Path("/home/eee/qzz/TIME/data/hf_dataset")
            if not local_path.exists():
                print(f"⚠️ Warning: Local datasets path does not exist: {local_path}")
        return local_path

    print(f"📥 Downloading datasets from HuggingFace Hub: {HF_REPO_ID}")

    snapshot_root = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        allow_patterns=["data/hf_dataset/**"],
        cache_dir=HF_CACHE_DIR,
    )

    datasets_path = Path(snapshot_root) / "data" / "hf_dataset"
    print(f"✅ Datasets cached at: {datasets_path}")
    return datasets_path
| 115 |
+
|
| 116 |
+
|
def download_config_snapshot() -> Path:
    """
    Get the config directory from the installed timebench package.

    The config (datasets.yaml) is bundled with the timebench package,
    so no download is needed - we just use the installed package's config.

    Returns:
        Path: Local path to the config directory

    Raises:
        FileNotFoundError: when neither the package config nor a local
            config directory can be located.
    """
    # Try to get config from installed timebench package
    try:
        from timebench.evaluation.data import DEFAULT_CONFIG_PATH
        config_path = DEFAULT_CONFIG_PATH.parent  # Get the config directory
        if config_path.exists():
            return config_path
    except ImportError as err:
        # Fixed: the original interpolated the ImportError *class* into the
        # message; report the actual exception instead (and drop the dead pass).
        print(f"❌ ImportError: {err}, using local config")

    # Fallback: Local development path
    # Priority: 1) LOCAL_CONFIG_PATH env var, 2) ../config, 3) /home/eee/qzz/TIME/config
    if LOCAL_CONFIG_PATH:
        local_path = Path(LOCAL_CONFIG_PATH)
    else:
        local_path = Path("../config")
        if not local_path.exists():
            local_path = Path("/home/eee/qzz/TIME/config")

    if local_path.exists():
        print(f"📁 Using local config: {local_path}")
        return local_path

    raise FileNotFoundError(
        "Config directory not found. Please ensure timebench is installed, "
        "set USE_HF_HUB=false for local development, "
        "or set LOCAL_CONFIG_PATH environment variable to point to your config directory."
    )
| 156 |
+
|
| 157 |
+
|
# Thin public aliases over the cached snapshot helpers; kept so callers do not
# depend on the caching implementation.
def get_results_root() -> Path:
    """Get the root path for results (handles both HF Hub and local)."""
    return download_results_snapshot()


def get_datasets_root() -> Path:
    """Get the root path for hf_dataset (handles both HF Hub and local)."""
    return download_datasets_snapshot()


def get_config_root() -> Path:
    """Get the root path for config (handles both HF Hub and local)."""
    return download_config_snapshot()
| 171 |
+
|
| 172 |
+
|
@lru_cache(maxsize=1)
def get_features_root() -> Path:
    """
    Get the root path for features (handles both HF Hub and local).

    Features are stored at output/features/{dataset}/{freq}/test.csv

    Fixed: now memoized with lru_cache like the other snapshot helpers; the
    original re-triggered snapshot_download on every call.

    Returns:
        Path: Local path to the features directory
    """
    if not USE_HF_HUB:
        # Local development path resolution, mirroring the other helpers:
        # 1) LOCAL_FEATURES_PATH env var, 2) ../output/features, 3) absolute dev path
        if LOCAL_FEATURES_PATH:
            local_path = Path(LOCAL_FEATURES_PATH)
        else:
            local_path = Path("../output/features")
            if not local_path.exists():
                local_path = Path("/home/eee/qzz/TIME/output/features")
            if not local_path.exists():
                print(f"⚠️ Warning: Local features path does not exist: {local_path}")
        return local_path

    # For HF Hub, features are in the same repo as results
    print(f"📥 Downloading features from HuggingFace Hub: {HF_REPO_ID}")

    local_dir = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        allow_patterns=["output/features/**"],
        cache_dir=HF_CACHE_DIR,
    )

    features_path = Path(local_dir) / "output" / "features"
    print(f"✅ Features cached at: {features_path}")
    return features_path


def clear_cache():
    """Clear the cached snapshots to force re-download on next access."""
    download_results_snapshot.cache_clear()
    download_datasets_snapshot.cache_clear()
    # Keep in sync with get_features_root's cache (added above).
    get_features_root.cache_clear()
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# =============================================================================
|
| 218 |
+
# Initialization - Download data at module import
|
| 219 |
+
# =============================================================================
|
| 220 |
+
|
def initialize_data():
    """
    Initialize by downloading all necessary data.
    Call this at app startup to pre-download data.

    Returns:
        tuple: (results_root, config_root) as Paths.
    """
    print("🚀 Initializing TIME Leaderboard data...")

    # Download results (required for leaderboard)
    results_root = get_results_root()
    print(f"  Results: {results_root}")

    # Download config (required for dataset settings)
    config_root = get_config_root()
    print(f"  Config: {config_root}")

    # Note: Datasets are downloaded on-demand when visualization is needed
    # to reduce initial load time

    print("✅ Initialization complete!")
    return results_root, config_root
| 241 |
+
|
src/leaderboard.py
ADDED
|
@@ -0,0 +1,1085 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import json
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from typing import List, Tuple, Optional
|
| 8 |
+
from scipy import stats
|
| 9 |
+
from src.about import DATASETS_DF, OVERALL_TABLE_COLUMNS, ALL_MODELS, RESULTS_ROOT, DATASET_DISPLAY_TO_ID
|
| 10 |
+
from src.hf_config import get_datasets_root, get_config_root
|
| 11 |
+
from src.utils import normalize_by_seasonal_naive
|
| 12 |
+
# FEATURES_DF, FEATURES_BOOL_DF, VARIATES_DF VARIATE_COLUMNS
|
| 13 |
+
import ast
|
| 14 |
+
from timebench.evaluation.data import Dataset, get_dataset_settings, load_dataset_config
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
def resolve_dataset_id(display_name: str) -> str:
    """
    Map a UI display name to its canonical dataset_id.

    Args:
        display_name: Either a display_name from the UI dropdown or an
            already-resolved dataset_id.

    Returns:
        dataset_id in "dataset/freq" format. Values absent from the
        display-name mapping are assumed to already be dataset_ids and are
        returned unchanged.
    """
    # Single dict lookup with the input itself as the fallback.
    return DATASET_DISPLAY_TO_ID.get(display_name, display_name)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def find_dataset_term_path(results_root, model_name, display_name):
    """
    Find the dataset/freq results path matching *display_name* for a model.

    Path structure on disk: results/{model_name}/{dataset}/{freq}/{horizon}/.
    A freq directory counts as valid when at least one of the "short",
    "medium" or "long" horizon subdirectories contains a config.json.

    Args:
        results_root: Root directory for results.
        model_name: Model name (first path component under results_root).
        display_name: Display name from the UI (e.g. "Traffic" or
            "Traffic/15T") or a dataset_id directly.

    Returns:
        The "dataset/freq" path string, or None when nothing matches.
    """
    base = os.path.join(results_root, model_name)
    if not os.path.exists(base):
        return None

    dataset_id = resolve_dataset_id(display_name)

    def _has_horizon_config(path):
        # A usable freq directory exposes config.json under some horizon.
        return any(
            os.path.exists(os.path.join(path, horizon, "config.json"))
            for horizon in ("short", "medium", "long")
        )

    if "/" in dataset_id:
        # Fast path: dataset_id already encodes dataset/freq.
        name, freq = dataset_id.split("/", 1)
        candidate = os.path.join(base, name, freq)
        if os.path.isdir(candidate) and _has_horizon_config(candidate):
            return dataset_id
        return None

    # Legacy fallback: only a dataset name was given — return the first
    # freq directory (in os.listdir order) that has horizon configs.
    for entry in os.listdir(base):
        entry_path = os.path.join(base, entry)
        if not os.path.isdir(entry_path) or entry != dataset_id:
            continue
        for freq in os.listdir(entry_path):
            freq_path = os.path.join(entry_path, freq)
            if os.path.isdir(freq_path) and _has_horizon_config(freq_path):
                return f"{entry}/{freq}"

    return None
|
| 86 |
+
|
| 87 |
+
def load_test_windows(display_name, horizon, model_name="moirai_small", series=None, variate=None, window_id=None, parse_series=False):
    """
    Load test window results from TIME NPZ files.

    Reads metrics.npz for one (model, dataset, horizon) combination and
    flattens the (num_series, num_windows, num_variates) metric arrays into
    one row per (series, window, variate), with optional filtering.

    Args:
        display_name: Dataset display name from UI (will be converted to dataset_id)
        horizon: Horizon name (short, medium, long)
        model_name: Model name whose results directory is read
        series: Optional series name (string) or index (int/string) to filter
        variate: Optional variate name (string) or index (int/string) to filter
        window_id: Optional window index to filter (compared via int(window_id))
        parse_series: If True, include label and quantile predictions as lists.
            Currently a no-op — ground truth is not stored in the NPZ files.

    Returns:
        pd.DataFrame with columns: series_name, variate_name, window_id,
        MASE, CRPS, MAE, MSE, model, series_idx, variate_idx — or None when
        the dataset path cannot be resolved or all rows are filtered out.
    """
    results_root = str(RESULTS_ROOT)

    # Find the dataset_term directory (handles display_name -> dataset_id conversion)
    dataset_term = find_dataset_term_path(results_root, model_name, display_name)
    if dataset_term is None:
        return None

    horizon_dir = os.path.join(results_root, model_name, dataset_term, horizon)
    metrics_path = os.path.join(horizon_dir, "metrics.npz")

    # Load data
    # NOTE(review): np.load raises if metrics.npz is missing for this horizon,
    # while callers only guard against a None return — confirm the file always
    # exists once find_dataset_term_path succeeds for at least one horizon.
    metrics = np.load(metrics_path)

    # Get array shapes
    mase_arr = metrics["MASE"]  # (num_series, num_windows, num_variates)
    num_series, num_windows, num_variates = mase_arr.shape

    # Load Dataset to get actual names; on any failure we fall back to
    # numeric string indices below.
    series_names = None
    variate_names = None
    try:
        # Use HF config to get dataset root (handles both local and HF Hub)
        hf_dataset_root = str(get_datasets_root())

        if os.path.exists(hf_dataset_root):
            # Use HF config to get config root
            config_root = get_config_root()
            config_path = config_root / "datasets.yaml"

            config = load_dataset_config(config_path) if config_path.exists() else {}
            settings = get_dataset_settings(dataset_term, horizon, config)

            prediction_length = settings.get("prediction_length")
            test_length = settings.get("test_length")

            # Load dataset with storage_path parameter
            dataset_obj = Dataset(
                name=dataset_term,
                term=horizon,
                prediction_length=prediction_length,
                test_length=test_length,
                storage_path=hf_dataset_root,
            )

            # Get series names (item_id)
            if "item_id" in dataset_obj.hf_dataset.column_names:
                series_names = dataset_obj.hf_dataset["item_id"]
            else:
                series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                                for i in range(len(dataset_obj.hf_dataset))]

            # Get variate names
            variate_names = dataset_obj.get_variate_names()
            if variate_names is None:
                # Univariate mode: variate names are same as series names
                variate_names = series_names
    except Exception as e:
        # Best-effort only: names are cosmetic, metrics still load.
        print(f"Error loading Dataset for names: {e}")

    # Create name to index mappings
    series_name_to_idx = {}
    variate_name_to_idx = {}
    if series_names is not None:
        series_name_to_idx = {name: idx for idx, name in enumerate(series_names)}
    if variate_names is not None:
        variate_name_to_idx = {name: idx for idx, name in enumerate(variate_names)}

    # Convert series and variate names to indices if they are names
    series_idx_filter = None
    if series is not None:
        if series in series_name_to_idx:
            series_idx_filter = series_name_to_idx[series]
        else:
            # Try as index
            try:
                series_idx_filter = int(series)
            except ValueError:
                # Unknown name and not an integer: no series filter applied.
                pass

    variate_idx_filter = None
    if variate is not None:
        if variate in variate_name_to_idx:
            variate_idx_filter = variate_name_to_idx[variate]
        else:
            # Try as index
            try:
                variate_idx_filter = int(variate)
            except ValueError:
                # Unknown name and not an integer: no variate filter applied.
                pass

    # Build DataFrame row by row
    rows = []
    for series_idx in range(num_series):
        # Filter by series if specified
        if series_idx_filter is not None:
            if series_idx != series_idx_filter:
                continue

        series_name = series_names[series_idx] if series_names is not None else str(series_idx)

        for window_idx in range(num_windows):
            # Filter by window_id if specified
            if window_id is not None:
                if window_idx != int(window_id):
                    continue

            for variate_idx in range(num_variates):
                # Filter by variate if specified
                if variate_idx_filter is not None:
                    if variate_idx != variate_idx_filter:
                        continue

                variate_name = variate_names[variate_idx] if variate_names is not None else str(variate_idx)

                row = {
                    "series_name": series_name,
                    "variate_name": variate_name,
                    "window_id": window_idx,
                    "MASE": float(mase_arr[series_idx, window_idx, variate_idx]),
                    "CRPS": float(metrics["CRPS"][series_idx, window_idx, variate_idx]),
                    "MAE": float(metrics["MAE"][series_idx, window_idx, variate_idx]),
                    "MSE": float(metrics["MSE"][series_idx, window_idx, variate_idx]),
                    "model": model_name,
                    "series_idx": series_idx,  # Keep for reference
                    "variate_idx": variate_idx,  # Keep for reference
                }

                # Add label and quantiles if requested
                if parse_series:
                    # Get ground truth from predictions (we need to load it separately)
                    # For now, we'll need to compute it or load from a separate source
                    # TIME doesn't save ground_truth in predictions.npz, so we'll skip for now
                    # TODO: Need to handle ground truth loading
                    pass

                rows.append(row)

    if not rows:
        return None

    df = pd.DataFrame(rows)

    # If parse_series, we would need ground truth data
    # For now, return without series data
    return df
|
| 249 |
+
|
| 250 |
+
def get_overall_leaderboard(df_datasets: pd.DataFrame, metric: str = "MASE") -> pd.DataFrame:
    """
    Compute overall leaderboard across datasets by normalizing metrics by Seasonal Naive
    and aggregating with geometric mean.

    Args:
        df_datasets (pd.DataFrame): Dataset-level results, must include
            ["model", "dataset_id", "horizon", "MASE", "CRPS", "MASE_rank", "CRPS_rank"].
        metric (str): Metric to use for sorting. Defaults to "MASE".

    Returns:
        pd.DataFrame: Leaderboard with:
            - MASE (norm.), CRPS (norm.): Geometric mean of Seasonal Naive-normalized values
            - MASE_rank, CRPS_rank: Average rank across configurations (from original data)
            Sorted ascending by the chosen metric (lower is better).
    """
    # Guard clauses: no data at all, or the requested metric column is absent.
    if df_datasets.empty:
        return pd.DataFrame(columns=OVERALL_TABLE_COLUMNS)

    if metric not in df_datasets.columns:
        return pd.DataFrame(columns=OVERALL_TABLE_COLUMNS)

    # Step 1: Normalize MASE and CRPS by Seasonal Naive per (dataset_id, horizon)
    df_normalized = normalize_by_seasonal_naive(
        df_datasets,
        baseline_model="seasonal_naive",
        metrics=["MASE", "CRPS"],
        groupby_cols=["dataset_id", "horizon"],
    )

    if df_normalized.empty:
        # Fall back to original behavior if normalization fails
        # (e.g. no seasonal_naive baseline rows): plain arithmetic mean.
        print("[get_overall_leaderboard] Warning: normalization failed, using arithmetic mean")
        leaderboard = (
            df_datasets.groupby(["model"])
            .mean(numeric_only=True)
            .reset_index()
        )
        # Rename columns: MASE -> MASE (norm.), CRPS -> CRPS (norm.)
        # NOTE(review): in this fallback the values are NOT actually normalized
        # despite the "(norm.)" label — kept for UI column consistency.
        if "MASE" in leaderboard.columns:
            leaderboard = leaderboard.rename(columns={"MASE": "MASE (norm.)"})
        if "CRPS" in leaderboard.columns:
            leaderboard = leaderboard.rename(columns={"CRPS": "CRPS (norm.)"})

        # Adjust metric name for sorting (the columns were just renamed)
        sort_metric = metric
        if metric == "MASE":
            sort_metric = "MASE (norm.)"
        elif metric == "CRPS":
            sort_metric = "CRPS (norm.)"

        if sort_metric in leaderboard.columns:
            leaderboard = leaderboard.sort_values(by=sort_metric, ascending=True).reset_index(drop=True)
        else:
            leaderboard = leaderboard.sort_values(by=metric, ascending=True).reset_index(drop=True)

        # Define column order
        col_order = ["model", "MASE (norm.)", "CRPS (norm.)", "MASE_rank", "CRPS_rank"]
        col_order = [col for col in col_order if col in leaderboard.columns]
        leaderboard = leaderboard[col_order]
        leaderboard = leaderboard.round(3)
        return leaderboard

    # Step 2: Aggregate normalized MASE and CRPS with geometric mean
    # Filter out NaN values for geometric mean computation
    def gmean_with_nan(x):
        """Compute geometric mean, ignoring NaN values."""
        valid = x.dropna()
        if len(valid) == 0:
            return np.nan
        return stats.gmean(valid)

    normalized_metrics = (
        df_normalized.groupby("model")[["MASE", "CRPS"]]
        .agg(gmean_with_nan)
        .reset_index()
    )

    # Rename columns: MASE -> MASE (norm.), CRPS -> CRPS (norm.)
    normalized_metrics = normalized_metrics.rename(columns={
        "MASE": "MASE (norm.)",
        "CRPS": "CRPS (norm.)"
    })

    # Step 3: Compute average ranks from original data (pre-normalized)
    # Ranks should be computed on original metrics, which is already done in about.py
    if "MASE_rank" in df_datasets.columns and "CRPS_rank" in df_datasets.columns:
        # Use the same configurations that were used in normalization
        # (only those with Seasonal Naive baseline), so ranks and normalized
        # metrics are averaged over the same set of (dataset, horizon) tasks.
        df_with_baseline = df_datasets[
            df_datasets.set_index(["dataset_id", "horizon"]).index.isin(
                df_normalized.set_index(["dataset_id", "horizon"]).index.unique()
            )
        ]
        avg_ranks = (
            df_with_baseline.groupby("model")[["MASE_rank", "CRPS_rank"]]
            .mean()
            .reset_index()
        )
        # Merge normalized metrics with average ranks
        leaderboard = normalized_metrics.merge(avg_ranks, on="model", how="left")
    else:
        leaderboard = normalized_metrics

    # Step 4: Sort by chosen metric (adjust metric name if needed)
    sort_metric = metric
    if metric == "MASE":
        sort_metric = "MASE (norm.)"
    elif metric == "CRPS":
        sort_metric = "CRPS (norm.)"

    if sort_metric in leaderboard.columns:
        leaderboard = leaderboard.sort_values(by=sort_metric, ascending=True).reset_index(drop=True)
    else:
        # Fallback: sort by first available metric
        leaderboard = leaderboard.sort_values(by=leaderboard.columns[1], ascending=True).reset_index(drop=True)

    # Step 5: Select and order columns
    col_order = ["model", "MASE (norm.)", "CRPS (norm.)", "MASE_rank", "CRPS_rank"]
    col_order = [col for col in col_order if col in leaderboard.columns]
    leaderboard = leaderboard[col_order]
    leaderboard = leaderboard.round(3)

    return leaderboard
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def get_dataset_leaderboard(
    display_name: str,
    horizons: List[str],
    metric: str = "MASE"
) -> Tuple[str, pd.DataFrame]:
    """
    Return leaderboard for a specific dataset, averaged over the specified horizons.

    Returns both original metrics and Seasonal Naive-normalized metrics in a single table.

    Args:
        display_name (str): The dataset display name selected by the user (from UI dropdown).
            Will be converted to dataset_id for filtering.
        horizons (List[str]): List of horizons to include (e.g., ["short", "medium"]).
            If None or empty, all horizons are used.
        metric (str): The metric used for sorting. Defaults to "MASE".
            NOTE(review): `metric` is currently not used — the table is always
            sorted by "MASE (norm.)" (or "MASE (raw)" as fallback); confirm
            whether this is intentional.

    Returns:
        tuple:
            str: A message string to display in the UI ("" if no error,
                dataset info string on success).
            pd.DataFrame: Dataframe containing leaderboard with columns:
                - model
                - MASE/CRPS/MAE/MSE (norm.) (geometric mean of normalized values)
                - MASE/CRPS/MAE/MSE (raw) (arithmetic mean of original values)
                - MASE_rank, CRPS_rank (average of per-task ranks)
    """
    if DATASETS_DF.empty:
        return "No dataset results are available. Please check your results folder.", pd.DataFrame(columns=["model"])

    # Convert display_name to dataset_id for filtering
    dataset_id = resolve_dataset_id(display_name)

    # Filter by dataset_id
    df_filtered = DATASETS_DF[DATASETS_DF["dataset_id"] == dataset_id].copy()
    if df_filtered.empty:
        return f"No results found for dataset '{display_name}'.", pd.DataFrame(columns=["model"])

    # Filter by horizon (default: every horizon present for this dataset)
    if horizons is None or len(horizons) == 0:
        horizons = df_filtered["horizon"].unique().tolist()
    df_filtered = df_filtered[df_filtered["horizon"].isin(horizons)]
    if df_filtered.empty:
        return f"No results found for dataset '{display_name}' with horizons {horizons}.", pd.DataFrame(columns=["model"])

    # Get dataset information (series count, variate count, freq) — best
    # effort only; failures degrade to a freq-only (or empty) info string.
    dataset_info_msg = ""
    try:
        # Parse dataset_id to get freq
        if "/" in dataset_id:
            _, freq = dataset_id.split("/", 1)
        else:
            freq = "unknown"

        # Load Dataset to get series and variate counts
        hf_dataset_root = str(get_datasets_root())
        config_root = get_config_root()
        config_path = config_root / "datasets.yaml"

        if os.path.exists(hf_dataset_root) and config_path.exists():
            config = load_dataset_config(config_path)
            settings = get_dataset_settings(dataset_id, horizons[0] if horizons else "short", config)

            dataset_obj = Dataset(
                name=dataset_id,
                term=horizons[0] if horizons else "short",
                prediction_length=settings.get("prediction_length"),
                test_length=settings.get("test_length"),
                storage_path=hf_dataset_root,
            )

            # Get series count
            if "item_id" in dataset_obj.hf_dataset.column_names:
                series_names = dataset_obj.hf_dataset["item_id"]
                # Convert to list if it's an array/Series to avoid ambiguity in boolean check
                if isinstance(series_names, (np.ndarray, pd.Series)):
                    series_names = list(series_names)
                elif not isinstance(series_names, list):
                    series_names = list(series_names) if hasattr(series_names, '__iter__') else [series_names]
                # Use len() check instead of boolean check to avoid ambiguity
                if len(series_names) > 0:
                    # Unique item_ids = number of distinct series
                    num_series = len(set(series_names))
                else:
                    num_series = len(dataset_obj.hf_dataset)
            else:
                num_series = len(dataset_obj.hf_dataset)

            # Get variate count
            variate_names = dataset_obj.get_variate_names()
            if variate_names is not None:
                num_variates = len(variate_names)
            else:
                # UTS: each series is one variate
                num_variates = 1

            dataset_info_msg = f"📊 Dataset Info: {num_series} series, {num_variates} variates, freq={freq}"
    except Exception as e:
        print(f"Error getting dataset info: {e}")
        # If we can't get info, try to extract freq from dataset_id
        if "/" in dataset_id:
            _, freq = dataset_id.split("/", 1)
            dataset_info_msg = f"📊 Dataset Info: freq={freq}"

    metrics_list = ["MASE", "CRPS", "MAE", "MSE"]

    # === Step 1: Compute original metrics (arithmetic mean) ===
    original_df = (
        df_filtered.groupby("model")[metrics_list]
        .mean()
        .reset_index()
    )

    # === Step 2: Compute normalized metrics (geometric mean of Seasonal Naive-normalized) ===
    df_normalized = normalize_by_seasonal_naive(
        df_filtered,
        baseline_model="seasonal_naive",
        metrics=metrics_list,
        groupby_cols=["dataset_id", "horizon"],
    )

    # Helper function for geometric mean with NaN handling
    def gmean_with_nan(x):
        valid = x.dropna()
        if len(valid) == 0:
            return np.nan
        return stats.gmean(valid)

    if not df_normalized.empty:
        normalized_df = (
            df_normalized.groupby("model")[metrics_list]
            .agg(gmean_with_nan)
            .reset_index()
        )
        # Rename columns to * (norm.)
        normalized_df = normalized_df.rename(columns={
            "MASE": "MASE (norm.)",
            "CRPS": "CRPS (norm.)",
            "MAE": "MAE (norm.)",
            "MSE": "MSE (norm.)",
        })
    else:
        # If normalization fails, create empty normalized columns
        normalized_df = original_df[["model"]].copy()
        for col in ["MASE (norm.)", "CRPS (norm.)", "MAE (norm.)", "MSE (norm.)"]:
            normalized_df[col] = np.nan

    # Rename original columns to * (raw)
    original_df = original_df.rename(columns={
        "MASE": "MASE (raw)",
        "CRPS": "CRPS (raw)",
        "MAE": "MAE (raw)",
        "MSE": "MSE (raw)",
    })

    # === Step 3: Compute average ranks from pre-computed per-task ranks ===
    if "MASE_rank" in df_filtered.columns and "CRPS_rank" in df_filtered.columns:
        # Use only configurations that have Seasonal Naive baseline (for consistency)
        if not df_normalized.empty:
            df_with_baseline = df_filtered[
                df_filtered.set_index(["dataset_id", "horizon"]).index.isin(
                    df_normalized.set_index(["dataset_id", "horizon"]).index.unique()
                )
            ]
        else:
            df_with_baseline = df_filtered

        ranks_df = (
            df_with_baseline.groupby("model")[["MASE_rank", "CRPS_rank"]]
            .mean()
            .reset_index()
        )
    else:
        # No per-task ranks available: emit NaN rank columns.
        ranks_df = original_df[["model"]].copy()
        ranks_df["MASE_rank"] = np.nan
        ranks_df["CRPS_rank"] = np.nan

    # === Step 4: Combine all into one DataFrame ===
    agg_df = original_df.merge(normalized_df, on="model", how="left")
    agg_df = agg_df.merge(ranks_df, on="model", how="left")

    # Sort by MASE (norm.) as requested
    if "MASE (norm.)" in agg_df.columns:
        agg_df = agg_df.sort_values(by="MASE (norm.)", ascending=True).reset_index(drop=True)
    elif "MASE (raw)" in agg_df.columns:
        # Fallback to MASE (raw) if normalized version not available
        agg_df = agg_df.sort_values(by="MASE (raw)", ascending=True).reset_index(drop=True)
    else:
        # Final fallback: sort by first available metric column
        if len(agg_df.columns) > 1:
            agg_df = agg_df.sort_values(by=agg_df.columns[1], ascending=True).reset_index(drop=True)

    # Define column order: model, * (norm.), * (raw), *_rank
    cols_order = ["model",
                  "MASE (norm.)", "CRPS (norm.)", "MAE (norm.)", "MSE (norm.)",
                  "MASE (raw)", "CRPS (raw)", "MAE (raw)", "MSE (raw)",
                  "MASE_rank", "CRPS_rank"]
    cols_to_return = [col for col in cols_order if col in agg_df.columns]

    agg_df = agg_df[cols_to_return].round(3)

    return dataset_info_msg, agg_df
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def get_dataset_multilevel_leaderboard(display_name, series, variate, horizons, metric: str = "MASE"):
    """
    Leaderboard driven by dataset / series / variate selections.

    Behavior:
      * Dataset only (series and variate both unset): delegate to
        get_dataset_leaderboard, which reports raw and normalized metrics.
      * Series and/or variate set: aggregate the raw per-window metrics
        (MASE, CRPS, MAE, MSE) with an arithmetic mean per model.

    Args:
        display_name: Dataset display name from the UI (converted to dataset_id).
        series: Series name, or "---"/empty/None when unset.
        variate: Variate name, or "---"/empty/None when unset.
        horizons: Horizons to include.
        metric: Column used for the ascending sort.

    Returns:
        tuple: (status message, leaderboard DataFrame)
    """
    def _unset(value):
        # The UI dropdowns use "---" (or an empty value) as the "no selection" marker.
        return value is None or value == "---" or value == ""

    # Case 0: Only dataset selected - return both original and normalized metrics
    if _unset(series) and _unset(variate):
        return get_dataset_leaderboard(display_name, horizons, metric)

    # Case 1: Series/Variate selected - return only original metrics
    results_root = str(RESULTS_ROOT)
    probe_model = ALL_MODELS[0]

    dataset_term = find_dataset_term_path(results_root, probe_model, display_name)
    if dataset_term is None:
        return f"Dataset '{display_name}' not found.", pd.DataFrame(columns=["model", "MASE", "CRPS", "MAE", "MSE"])

    # Determine UTS vs MTS: univariate datasets expose no variate names, in
    # which case any variate filter must be dropped.
    first_horizon = horizons[0] if horizons else "short"
    is_uts = False
    try:
        hf_dataset_root = str(get_datasets_root())

        if os.path.exists(hf_dataset_root):
            config_path = get_config_root() / "datasets.yaml"
            config = load_dataset_config(config_path) if config_path.exists() else {}
            settings = get_dataset_settings(dataset_term, first_horizon, config)

            probe_dataset = Dataset(
                name=dataset_term,
                term=first_horizon,
                prediction_length=settings.get("prediction_length"),
                test_length=settings.get("test_length"),
                storage_path=hf_dataset_root,
            )

            is_uts = probe_dataset.get_variate_names() is None
    except Exception as e:
        print(f"Error checking UTS/MTS: {e}")

    # Normalize the UI selections into filters once, then fan out over all
    # models and the requested horizons.
    series_filter = None if _unset(series) else series
    variate_filter = None if (_unset(variate) or is_uts) else variate

    frames = []
    for model in ALL_MODELS:
        for horizon in horizons:
            window_df = load_test_windows(
                display_name, horizon, model,
                series=series_filter,
                variate=variate_filter,
                window_id=None
            )
            if window_df is not None and not window_df.empty:
                frames.append(window_df)

    if not frames:
        return f"⚠️ No results found for the selected filters.", pd.DataFrame(columns=["model", "MASE", "CRPS", "MAE", "MSE"])

    combined = pd.concat(frames, ignore_index=True)

    metric_columns = ["MASE", "CRPS", "MAE", "MSE"]

    # Series/variate level: plain arithmetic mean across windows (no
    # seasonal-naive normalization at this granularity).
    leaderboard = (
        combined.groupby("model")[metric_columns]
        .mean()
        .reset_index()
        .round(3)
    )

    if metric not in leaderboard.columns:
        return f"Metric '{metric}' not found.", pd.DataFrame(columns=["model"])

    leaderboard = leaderboard.sort_values(by=metric, ascending=True).reset_index(drop=True)

    return "", leaderboard
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def get_window_leaderboard(display_name, series, variate, window_id, horizon, metric: str = "MASE"):
    """
    Get leaderboard for a specific test window.

    Collects the single-window metrics for every model and returns one row
    per model with the raw MASE, CRPS, MAE and MSE values.

    Args:
        display_name: Dataset display name from UI (will be converted to dataset_id)
        series: Series name or index
        variate: Variate name or index
        window_id: Window index
        horizon: Horizon name (short, medium, long)
        metric: Metric for sorting (ascending; lower is better)

    Returns:
        pd.DataFrame with columns ["model", "MASE", "CRPS", "MAE", "MSE"]
        sorted ascending by `metric`; empty frame when no data matches.
    """
    df_all = []
    for model in ALL_MODELS:
        df_model = load_test_windows(display_name, horizon, model, series=series, variate=variate, window_id=window_id)
        if df_model is not None and not df_model.empty:
            df_all.append(df_model)

    if not df_all:
        # Return empty DataFrame with expected columns if no data found
        return pd.DataFrame(columns=["model", "MASE", "CRPS", "MAE", "MSE"])

    df_all = pd.concat(df_all, ignore_index=True)

    # metrics DataFrame
    metrics_cols = ["model", "MASE", "CRPS", "MAE", "MSE"]
    leaderboard = df_all[metrics_cols].reset_index(drop=True)

    if metric not in leaderboard.columns:
        return pd.DataFrame(columns=["model"])

    # FIX: `metric` was documented and validated as the sort key but the table
    # was never actually sorted by it — sort ascending like the other views.
    leaderboard = leaderboard.sort_values(by=metric, ascending=True).reset_index(drop=True)

    # Round numeric columns to 3 decimal places
    numeric_cols = leaderboard.select_dtypes(include=[np.number]).columns
    for col in numeric_cols:
        leaderboard[col] = leaderboard[col].round(3)

    return leaderboard
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def get_pattern_leaderboard(
    pattern_filters: dict[str, int],
    selected_horizons: list[str],
) -> tuple[str, pd.DataFrame]:
    """
    Filter variates by selected patterns and compute average metrics per model.

    Uses FEATURES_BOOL_DF (binarized features) to filter variates,
    then joins with VARIATES_DF to get metrics.

    Args:
        pattern_filters: Dict mapping pattern names to required values.
            - {feature_name: required_value} where required_value is 0 or 1.
            - Features with "Any" selection are not included in the dict.
            - Example: {"T_strength": 1, "S_strength": 0} means:
              - T_strength must be 1 (has the feature)
              - S_strength must be 0 (does not have the feature)
        selected_horizons: List of horizons to include (e.g., ["short", "medium"])

    Returns:
        tuple: (message, leaderboard_df)
            - message: Status message with matching count
            - leaderboard_df: DataFrame with model metrics, sorted by MASE
              (normalized MASE when available, raw MASE otherwise); all
              numeric columns rounded to 3 decimals.
    """
    # Imported locally to avoid a circular import at module load time and to
    # pick up the module-level DataFrames after they have been populated.
    from src.about import VARIATES_DF, FEATURES_DF, FEATURES_BOOL_DF, PATTERN_MAP

    # Check if data is available — each guard returns early with a UI-facing message.
    if VARIATES_DF.empty:
        return "⚠️ No variate-level results available.", pd.DataFrame(columns=["model", "MASE", "CRPS"])

    if FEATURES_DF.empty or FEATURES_BOOL_DF.empty:
        return "⚠️ No features data available.", pd.DataFrame(columns=["model", "MASE", "CRPS"])

    if not selected_horizons:
        return "ℹ️ Please select at least one horizon.", pd.DataFrame(columns=["model", "MASE", "CRPS"])

    # === Step 1. Apply pattern filters ===
    # Start with all variates (mask is row-aligned with FEATURES_BOOL_DF;
    # NOTE(review): assumes FEATURES_DF and FEATURES_BOOL_DF share an index — verify upstream).
    mask = pd.Series(True, index=FEATURES_BOOL_DF.index)

    # If no pattern filters (all "Any"), use all variates
    if pattern_filters:
        for pattern, required_value in pattern_filters.items():
            # Map UI pattern name to feature column name
            feature_col = PATTERN_MAP.get(pattern, pattern)

            if feature_col not in FEATURES_BOOL_DF.columns:
                return f"⚠️ Pattern '{pattern}' (column '{feature_col}') not found in features.", pd.DataFrame(columns=["model", "MASE", "CRPS"])

            # Special handling for "stationarity" pattern
            # stationarity = NOT is_random_walk
            # When user selects "Has stationarity" (required_value=1), we want is_random_walk == 0
            # When user selects "Not stationarity" (required_value=0), we want is_random_walk == 1
            if pattern == "stationarity":
                mask &= (FEATURES_BOOL_DF[feature_col] == (1 - required_value))
            else:
                mask &= (FEATURES_BOOL_DF[feature_col] == required_value)

    # Get matching variates
    matched_features = FEATURES_DF[mask].copy()

    if matched_features.empty:
        # Build debug info for empty results (per-filter value distribution,
        # shown to the user so they can see why nothing matched)
        debug_info = []
        for pattern, required_value in pattern_filters.items():
            feature_col = PATTERN_MAP.get(pattern, pattern)
            if feature_col in FEATURES_BOOL_DF.columns:
                value_counts = FEATURES_BOOL_DF[feature_col].value_counts().to_dict()
                debug_info.append(f"{pattern} ({feature_col}): {value_counts}")
        debug_msg = "; ".join(debug_info) if debug_info else "No debug info available"
        return f"⚠️ No variates match the selected patterns.\n📊 Feature distribution: {debug_msg}", pd.DataFrame(columns=["model", "MASE", "CRPS"])

    # === Step 2. Join with VARIATES_DF to get metrics ===
    # Join strategy:
    # - For multivariate (is_uts=False): use full join on (dataset_id, series_name, variate_name)
    # - For univariate (is_uts=True): determine which FEATURES_DF field matches VARIATES_DF series_name

    # Check if join columns exist
    base_join_cols = ["dataset_id", "series_name", "variate_name"]
    for col in base_join_cols:
        if col not in matched_features.columns:
            return f"⚠️ Column '{col}' not found in features.", pd.DataFrame(columns=["model", "MASE", "CRPS"])
        if col not in VARIATES_DF.columns:
            return f"⚠️ Column '{col}' not found in variates results.", pd.DataFrame(columns=["model", "MASE", "CRPS"])

    # Select only join columns from features (to avoid column conflicts)
    features_keys = matched_features[base_join_cols].drop_duplicates()

    # Group by dataset_id and is_uts, then perform appropriate join
    merged_list = []

    for dataset_id in features_keys["dataset_id"].unique():
        # Get dataset-specific data
        dataset_features = features_keys[features_keys["dataset_id"] == dataset_id]
        dataset_variates = VARIATES_DF[VARIATES_DF["dataset_id"] == dataset_id]

        if dataset_variates.empty or dataset_features.empty:
            continue

        # Check is_uts for this dataset (should be consistent across all rows)
        is_uts_values = dataset_variates["is_uts"].unique()
        if len(is_uts_values) > 1:
            print(f"⚠️ Warning: {dataset_id} has inconsistent is_uts values: {is_uts_values}")
        is_uts = is_uts_values[0] if len(is_uts_values) > 0 else False

        # Initialize dataset_merged to ensure it's always defined
        dataset_merged = pd.DataFrame()

        if not is_uts:
            # Multivariate: use full join on (dataset_id, series_name, variate_name)
            join_cols = ["dataset_id", "series_name", "variate_name"]
            dataset_features_keys = dataset_features[join_cols].drop_duplicates()
            dataset_merged = dataset_variates.merge(
                dataset_features_keys,
                on=join_cols,
                how="inner"
            )
        else:
            # Univariate: determine which FEATURES_DF field matches VARIATES_DF series_name
            # Get unique series_name values from VARIATES_DF for this dataset
            variates_series_names = set(dataset_variates["series_name"].unique())

            # Check which FEATURES_DF field matches VARIATES_DF series_name
            features_series_names = set(dataset_features["series_name"].unique())
            features_variate_names = set(dataset_features["variate_name"].unique())

            # Pick the field with the larger overlap as the join key.
            features_series_match = len(variates_series_names & features_series_names)
            features_variate_match = len(variates_series_names & features_variate_names)

            if features_series_match > features_variate_match:
                # FEATURES_DF series_name matches VARIATES_DF series_name
                join_cols = ["dataset_id", "series_name"]
                dataset_features_keys = dataset_features[join_cols].drop_duplicates()
                dataset_merged = dataset_variates.merge(
                    dataset_features_keys,
                    on=join_cols,
                    how="inner"
                )
            elif features_variate_match > features_series_match:
                # FEATURES_DF variate_name matches VARIATES_DF series_name
                # Create mapping: use FEATURES_DF variate_name to match VARIATES_DF series_name
                dataset_features_keys = dataset_features[["dataset_id", "variate_name"]].drop_duplicates()
                # Rename variate_name to series_name for join
                dataset_features_keys = dataset_features_keys.rename(columns={"variate_name": "series_name"})
                dataset_merged = dataset_variates.merge(
                    dataset_features_keys,
                    on=["dataset_id", "series_name"],
                    how="inner"
                )
            else:
                # Both match equally or neither matches - try series_name first
                if features_series_match > 0:
                    join_cols = ["dataset_id", "series_name"]
                    dataset_features_keys = dataset_features[join_cols].drop_duplicates()
                    dataset_merged = dataset_variates.merge(
                        dataset_features_keys,
                        on=join_cols,
                        how="inner"
                    )
                else:
                    # No match found, skip this dataset
                    print(f"⚠️ Warning: {dataset_id} (UTS) - no matching field found between FEATURES_DF and VARIATES_DF series_name")
                    print(f" VARIATES_DF series_names: {sorted(list(variates_series_names))[:5]}")
                    print(f" FEATURES_DF series_names: {sorted(list(features_series_names))[:5]}")
                    print(f" FEATURES_DF variate_names: {sorted(list(features_variate_names))[:5]}")
                    continue

        if not dataset_merged.empty:
            merged_list.append(dataset_merged)

    # Combine all merged results
    if merged_list:
        merged = pd.concat(merged_list, ignore_index=True)
    else:
        merged = pd.DataFrame(columns=VARIATES_DF.columns)

    # === Step 3. Apply horizon filter ===
    merged = merged[merged["horizon"].isin(selected_horizons)]

    if merged.empty:
        return f"⚠️ No results for selected horizons: {selected_horizons}", pd.DataFrame(columns=["model", "MASE", "CRPS"])

    # === Step 4. Aggregate by model ===
    metric_cols = ["MASE", "CRPS"]
    available_metrics = [col for col in metric_cols if col in merged.columns]

    # 4a. Original metrics: arithmetic mean across all matching variates and horizons
    original_df = (
        merged.groupby("model")[available_metrics]
        .mean()
        .reset_index()
    )

    # 4b. Normalized metrics: normalize by Seasonal Naive at (dataset_id, series_name, variate_name, horizon) level
    # then aggregate with geometric mean
    df_normalized = normalize_by_seasonal_naive(
        merged,
        baseline_model="seasonal_naive",
        metrics=available_metrics,
        groupby_cols=["dataset_id", "series_name", "variate_name", "horizon"],
    )

    # Helper function for geometric mean with NaN handling
    def gmean_with_nan(x: pd.Series) -> float:
        # gmean of the non-NaN entries; NaN when the whole group is NaN
        # (stats.gmean itself would propagate NaN / fail on empty input).
        valid = x.dropna()
        if len(valid) == 0:
            return np.nan
        return stats.gmean(valid)

    if not df_normalized.empty:
        normalized_df = (
            df_normalized.groupby("model")[available_metrics]
            .agg(gmean_with_nan)
            .reset_index()
        )
        # Rename columns to *_norm
        rename_map = {col: f"{col}_norm" for col in available_metrics}
        normalized_df = normalized_df.rename(columns=rename_map)
    else:
        # If normalization fails, create empty normalized columns
        normalized_df = original_df[["model"]].copy()
        for col in available_metrics:
            normalized_df[f"{col}_norm"] = np.nan

    # Combine original and normalized metrics
    leaderboard = original_df.merge(normalized_df, on="model", how="left")

    # Rename columns for better clarity
    rename_map = {}
    if "MASE" in leaderboard.columns:
        rename_map["MASE"] = "MASE (raw)"
    if "CRPS" in leaderboard.columns:
        rename_map["CRPS"] = "CRPS (raw)"
    if "MASE_norm" in leaderboard.columns:
        rename_map["MASE_norm"] = "MASE (norm.)"
    if "CRPS_norm" in leaderboard.columns:
        rename_map["CRPS_norm"] = "CRPS (norm.)"

    if rename_map:
        leaderboard = leaderboard.rename(columns=rename_map)

    # Sort by MASE (norm.) if available, otherwise by MASE (raw)
    if "MASE (norm.)" in leaderboard.columns:
        leaderboard = leaderboard.sort_values(by="MASE (norm.)", ascending=True).reset_index(drop=True)
    elif "MASE (raw)" in leaderboard.columns:
        leaderboard = leaderboard.sort_values(by="MASE (raw)", ascending=True).reset_index(drop=True)

    # Round numeric columns to 3 decimal places
    numeric_cols = leaderboard.select_dtypes(include=[np.number]).columns
    for col in numeric_cols:
        # Round to 3 decimal places and ensure proper formatting
        leaderboard[col] = leaderboard[col].round(3)
        # Convert to float64 to ensure consistent display
        leaderboard[col] = leaderboard[col].astype('float64')

    # Reorder columns: model, MASE (norm.), CRPS (norm.), MASE (raw), CRPS (raw)
    col_order = ["model", "MASE (norm.)", "CRPS (norm.)", "MASE (raw)", "CRPS (raw)"]
    col_order = [col for col in col_order if col in leaderboard.columns]
    leaderboard = leaderboard[col_order]

    # === Step 5. Build message ===
    num_variates = len(features_keys)
    num_results = len(merged)
    num_models = leaderboard["model"].nunique()

    # Count by dataset (show at most 5 datasets explicitly)
    dataset_counts = features_keys["dataset_id"].value_counts().to_dict()
    dataset_msg = ", ".join([f"{ds}: {cnt}" for ds, cnt in list(dataset_counts.items())[:5]])
    if len(dataset_counts) > 5:
        dataset_msg += f", ... ({len(dataset_counts)} datasets total)"

    # Build pattern description
    if pattern_filters:
        pattern_desc = ", ".join([
            f"{p}={v}" for p, v in pattern_filters.items()
        ])
        filter_msg = f"🔍 Filters: {pattern_desc}"
    else:
        filter_msg = "🔍 Filters: All N/A (no filtering, all variates included)"

    msg = (
        f"✨ {num_variates} variates matched across {len(dataset_counts)} datasets.\n"
        f"📊 {num_results} results from {num_models} models.\n"
        f"{filter_msg}\n"
        f"📁 Datasets: {dataset_msg}"
    )

    return msg, leaderboard
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
# def get_archive_results(dataset_name: str, selected_patterns: list[str], variate_name: str):
|
| 1012 |
+
# """
|
| 1013 |
+
# Return variates filtered by dataset, patterns, and variate name.
|
| 1014 |
+
# """
|
| 1015 |
+
# if FEATURES_BOOL_DF.empty:
|
| 1016 |
+
# return pd.DataFrame(columns=["dataset", "variate_name"])
|
| 1017 |
+
|
| 1018 |
+
# df = FEATURES_DF.copy()
|
| 1019 |
+
# df_bool = FEATURES_BOOL_DF.copy()
|
| 1020 |
+
|
| 1021 |
+
# # Dataset filter
|
| 1022 |
+
# if dataset_name and dataset_name != "All":
|
| 1023 |
+
# df = df[df["dataset"] == dataset_name]
|
| 1024 |
+
|
| 1025 |
+
# # Pattern filter
|
| 1026 |
+
# if selected_patterns:
|
| 1027 |
+
# mask = pd.Series(True, index=df_bool.index)
|
| 1028 |
+
# for pattern in selected_patterns:
|
| 1029 |
+
# if pattern in df_bool.columns:
|
| 1030 |
+
# mask &= df_bool[pattern] == 1
|
| 1031 |
+
# df = df[mask]
|
| 1032 |
+
|
| 1033 |
+
# # Variate filter
|
| 1034 |
+
# if variate_name and variate_name != "All":
|
| 1035 |
+
# df = df[df["variate_name"] == variate_name]
|
| 1036 |
+
|
| 1037 |
+
# if df.empty:
|
| 1038 |
+
# return pd.DataFrame(columns=FEATURES_BOOL_DF.columns)  # TODO: replace with the complete column set
|
| 1039 |
+
|
| 1040 |
+
# # --- Add freq & domain from YAML ---
|
| 1041 |
+
# freq_map, domain_map = {}, {}
|
| 1042 |
+
# for ds in df["dataset"].unique():
|
| 1043 |
+
# yaml_path = os.path.join("conf", "data", f"{ds}.yaml")
|
| 1044 |
+
# if os.path.exists(yaml_path):
|
| 1045 |
+
# with open(yaml_path, "r") as f:
|
| 1046 |
+
# meta = yaml.safe_load(f)
|
| 1047 |
+
# freq_map[ds] = meta.get("freq", None)
|
| 1048 |
+
# domain_map[ds] = meta.get("domain", None)
|
| 1049 |
+
# else:
|
| 1050 |
+
# freq_map[ds] = None
|
| 1051 |
+
# domain_map[ds] = None
|
| 1052 |
+
|
| 1053 |
+
# df["freq"] = df["dataset"].map(freq_map)
|
| 1054 |
+
# df["domain"] = df["dataset"].map(domain_map)
|
| 1055 |
+
|
| 1056 |
+
# # Select useful columns
|
| 1057 |
+
# base_cols = ["dataset", "variate_name", "freq", "domain"]
|
| 1058 |
+
# tsfeature_cols = [c for c in df.columns if c not in base_cols+['unitroot_pp', 'unitroot_kpss']]
|
| 1059 |
+
|
| 1060 |
+
# # === Rename feature columns ===
|
| 1061 |
+
# renamed_cols = {}
|
| 1062 |
+
# for col in tsfeature_cols:
|
| 1063 |
+
# if col == 'trend':
|
| 1064 |
+
# renamed_cols['trend'] = "T_strength"
|
| 1065 |
+
# elif col.startswith("trend"):
|
| 1066 |
+
# renamed_cols[col] = col.replace("trend", "T", 1)
|
| 1067 |
+
# elif col.startswith("seasonality"):
|
| 1068 |
+
# renamed_cols[col] = col.replace("seasonality", "S", 1)
|
| 1069 |
+
# elif col.startswith("seasonal"):
|
| 1070 |
+
# renamed_cols[col] = col.replace("seasonal", "S", 1)
|
| 1071 |
+
# df = df.rename(columns=renamed_cols)
|
| 1072 |
+
|
| 1073 |
+
# # === Apply new column names ===
|
| 1074 |
+
# tsfeature_cols = [renamed_cols.get(c, c) for c in tsfeature_cols]
|
| 1075 |
+
|
| 1076 |
+
# global_cols = ["x_acf1", "x_acf10", "lumpiness", "stability", "hurst", "entropy"]
|
| 1077 |
+
# t_cols = [c for c in tsfeature_cols if c.startswith("T_")]
|
| 1078 |
+
# s_cols = [c for c in tsfeature_cols if c.startswith("S_")]
|
| 1079 |
+
# e_cols = [c for c in tsfeature_cols if c.startswith("e_")]
|
| 1080 |
+
# stats_cols = [c for c in tsfeature_cols if c not in global_cols+ t_cols + s_cols + e_cols]
|
| 1081 |
+
|
| 1082 |
+
# ordered_cols = base_cols + global_cols + t_cols + s_cols + e_cols + stats_cols
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
# return df[ordered_cols].round(3)
|
src/tab.py
ADDED
|
@@ -0,0 +1,1370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import plotly.graph_objects as go
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Add project root and src directory to Python path to enable imports from timebench.
# Order matters: later inserts take precedence since each uses insert(0, ...).
# Get the directory containing this file (leaderboard_app/src/)
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get leaderboard_app directory
leaderboard_app_dir = os.path.dirname(current_dir)

# Try multiple paths for timebench import:
# 1. Current leaderboard_app directory (if timebench was copied to leaderboard_app/)
# 2. Parent directory's src (for local development: TIME/src/)

# Add current leaderboard_app directory first (for Space deployment)
if leaderboard_app_dir not in sys.path:
    sys.path.insert(0, leaderboard_app_dir)

# Get project root directory (TIME/) - for local development
project_root = os.path.dirname(leaderboard_app_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# TIME/src/ — only added when it exists (it is absent in a deployed Space)
src_dir = os.path.join(project_root, "src")
if src_dir not in sys.path and os.path.exists(src_dir):
    sys.path.insert(0, src_dir)
|
| 29 |
+
|
| 30 |
+
import json
|
| 31 |
+
import gradio as gr
|
| 32 |
+
from src.about import DATASET_CHOICES, ALL_MODELS, RESULTS_ROOT, FEATURES_DF, FEATURES_BOOL_DF, PATTERN_MAP
|
| 33 |
+
from src.leaderboard import (get_overall_leaderboard, get_dataset_multilevel_leaderboard,
|
| 34 |
+
get_window_leaderboard, get_pattern_leaderboard, resolve_dataset_id)
|
| 35 |
+
from src.about import DATASETS_DF, ALL_HORIZONS
|
| 36 |
+
from src.hf_config import get_datasets_root, get_config_root
|
| 37 |
+
import numpy as np
|
| 38 |
+
import pandas as pd
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
import ast
|
| 41 |
+
import matplotlib
|
| 42 |
+
matplotlib.use('Agg') # Use non-interactive backend for Gradio
|
| 43 |
+
import yaml
|
| 44 |
+
import tempfile
|
| 45 |
+
|
| 46 |
+
from timebench.evaluation.data import Dataset, get_dataset_settings, load_dataset_config
|
| 47 |
+
from src.leaderboard import find_dataset_term_path
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def export_dataframe_to_csv(df, filename_prefix="leaderboard"):
    """Export a DataFrame to a temporary CSV file and return the path for download.

    Args:
        df: pandas DataFrame to export
        filename_prefix: prefix for the temporary file name

    Returns:
        str: path to the temporary CSV file, or None if df is None or empty

    Note:
        delete=False is intentional: the caller (a Gradio download component)
        serves the file after this function returns, so it must outlive the
        handle. Cleanup is left to the OS temp-dir policy.
    """
    if df is None or (hasattr(df, 'empty') and df.empty):
        return None
    # newline='' lets pandas control line endings (avoids \r\r\n rows on
    # Windows); explicit utf-8 keeps non-ASCII model/dataset names intact
    # regardless of the platform's default encoding.
    with tempfile.NamedTemporaryFile(
        mode='w',
        suffix='.csv',
        delete=False,
        prefix=f"{filename_prefix}_",
        newline='',
        encoding='utf-8',
    ) as f:
        df.to_csv(f, index=False)
        return f.name
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# def update_variate_choices(dataset_name: str, selected_patterns: list[str]):
|
| 68 |
+
# """
|
| 69 |
+
# Dynamically update the variate dropdown choices based on dataset + patterns.
|
| 70 |
+
# """
|
| 71 |
+
# if dataset_name == "All":
|
| 72 |
+
# return gr.Dropdown(choices=["All"], value="All", interactive=False)
|
| 73 |
+
|
| 74 |
+
# # Filter features by dataset
|
| 75 |
+
# df = FEATURES_BOOL_DF[FEATURES_BOOL_DF["dataset"] == dataset_name]
|
| 76 |
+
|
| 77 |
+
# # Apply pattern filters if provided
|
| 78 |
+
# if selected_patterns:
|
| 79 |
+
# mask = pd.Series(True, index=df.index)
|
| 80 |
+
# for pattern in selected_patterns:
|
| 81 |
+
# if pattern in df.columns:
|
| 82 |
+
# mask &= df[pattern] == 1
|
| 83 |
+
# df = df[mask]
|
| 84 |
+
|
| 85 |
+
# variates = sorted(df["variate_name"].unique().tolist())
|
| 86 |
+
# if not variates:
|
| 87 |
+
# return gr.Dropdown(choices=["All"], value="All", interactive=False)
|
| 88 |
+
|
| 89 |
+
# return gr.Dropdown(choices=["All"] + variates, value="All", interactive=True)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# # Update the Variate dropdown choices
|
| 93 |
+
# def update_variate_choices_groups(dataset_name, t, s, r, g):
|
| 94 |
+
# selected_patterns = (t or []) + (s or []) + (r or []) + (g or [])
|
| 95 |
+
# return update_variate_choices(dataset_name, selected_patterns)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
########################## Dataset Tab ##########################
|
| 99 |
+
def update_series_and_variate(display_name):
    """Update the series and variate dropdown choices for the merged Dataset tab.

    Args:
        display_name: Dataset display name from the UI dropdown; resolved to a
            dataset_id via ``find_dataset_term_path``.

    Returns:
        tuple: (series_dropdown, variate_dropdown) ``gr.Dropdown`` updates.
        For univariate (UTS) datasets the variate dropdown is disabled.
    """
    # Use the first available model only to locate the dataset on disk; the
    # series/variate names do not depend on which model is chosen.
    model_name = ALL_MODELS[0]
    # Find dataset_term (handles display_name -> dataset_id conversion)
    results_root = str(RESULTS_ROOT)
    dataset_term = find_dataset_term_path(results_root, model_name, display_name)

    if dataset_term is None:
        print(f"Error: dataset_term is None for display_name={display_name}, model_name={model_name}")
        return (
            gr.Dropdown(choices=["---"], value="---", label="Select Series", interactive=True),
            gr.Dropdown(choices=["---"], value="---", label="Select Variate", interactive=True),
        )

    # Load Dataset to get actual series and variate names.
    # Use HF config to get dataset root (handles both local and HF Hub).
    hf_dataset_root = str(get_datasets_root())

    # Use HF config to get config root
    config_root = get_config_root()
    config_path = config_root / "datasets.yaml"

    # The horizon does not affect series/variate names, so "short" is used.
    # FIX: guard on config_path.exists() for consistency with
    # update_series_variate_and_window — previously a missing datasets.yaml
    # would crash this handler.
    config = load_dataset_config(config_path) if config_path.exists() else {}
    settings = get_dataset_settings(dataset_term, "short", config)

    prediction_length = settings.get("prediction_length")
    test_length = settings.get("test_length")

    dataset_obj = Dataset(
        name=dataset_term,
        term="short",
        prediction_length=prediction_length,
        test_length=test_length,
        storage_path=hf_dataset_root,  # Pass storage path directly
    )

    # Get series names (prefer the column; fall back to per-row lookup).
    if "item_id" in dataset_obj.hf_dataset.column_names:
        series_names = dataset_obj.hf_dataset["item_id"]
    else:
        series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                        for i in range(len(dataset_obj.hf_dataset))]

    series_list = ["---"] + [str(name) for name in series_names]

    # Get variate names; None signals a univariate dataset.
    variate_names = dataset_obj.get_variate_names()

    if variate_names is None:
        # UTS mode: variate dropdown should be disabled
        return (
            gr.Dropdown(choices=series_list, value="---", label="Select Series", interactive=True),
            gr.Dropdown(choices=["---"], value="---", label="Select Variate", interactive=False),
        )
    else:
        # MTS mode: both dropdowns are enabled
        variates_list = ["---"] + [str(name) for name in variate_names]
        return (
            gr.Dropdown(choices=series_list, value="---", label="Select Series", interactive=True),
            gr.Dropdown(choices=variates_list, value="---", label="Select Variate", interactive=True),
        )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
########################## Window Tab ##########################
|
| 171 |
+
def get_available_horizons(display_name):
    """Return the horizons available for a dataset.

    Args:
        display_name: Dataset display name from the UI dropdown.

    Returns:
        list: Available horizons in canonical order (a subset of
        ``ALL_HORIZONS``, e.g. ``["short", "medium", "long"]``). Falls back to
        ``ALL_HORIZONS`` when the dataset cannot be found and to ``["short"]``
        when no horizon matches.
    """
    if DATASETS_DF.empty:
        return ALL_HORIZONS

    # Map the UI display name onto the internal dataset identifier.
    dataset_id = resolve_dataset_id(display_name)
    rows = DATASETS_DF[DATASETS_DF["dataset_id"] == dataset_id]

    if rows.empty:
        # Unknown dataset: expose every horizon as a safe fallback.
        return ALL_HORIZONS

    # Filter while preserving the canonical short/medium/long ordering.
    found = set(rows["horizon"].unique().tolist())
    ordered = [h for h in ALL_HORIZONS if h in found]
    return ordered or ["short"]
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def update_horizon_choices(display_name):
    """Refresh the horizon Radio component for the selected dataset.

    Args:
        display_name: Dataset display name from the UI dropdown.

    Returns:
        gr.Radio: update carrying the dataset's available horizons; the value
        defaults to "short" whenever it is available.
    """
    available = get_available_horizons(display_name)

    # Prefer "short"; otherwise fall back to the first available horizon.
    if "short" in available:
        selected = "short"
    elif available:
        selected = available[0]
    else:
        selected = "short"

    # Keep the canonical horizon ordering in the choices list.
    ordered_choices = [h for h in ALL_HORIZONS if h in available]

    return gr.Radio(choices=ordered_choices, value=selected)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def update_series_variate_and_window(display_name, horizon):
    """Update the series, variate, and window dropdowns for a dataset/horizon.

    Loads the dataset via ``Dataset`` to obtain the real series and variate
    names and the number of test windows.

    Args:
        display_name: Dataset display name from the UI dropdown (will be
            resolved to a dataset_id).
        horizon: Horizon name ("short", "medium", "long").

    Returns:
        tuple: (series_dropdown, variate_dropdown, window_dropdown)
        ``gr.Dropdown`` updates. For univariate (UTS) datasets the variate
        dropdown is fixed to "0" and disabled.
    """
    # Use first available model to get data
    model_name = ALL_MODELS[0]

    # Find dataset_term (handles display_name -> dataset_id conversion)
    results_root = str(RESULTS_ROOT)
    dataset_term = find_dataset_term_path(results_root, model_name, display_name)

    if dataset_term is None:
        print(f"Error: dataset_term is None for display_name={display_name}, horizon={horizon}, model_name={model_name}")
        return (
            gr.Dropdown(choices=[], value=None, label="Select Series", interactive=False),
            gr.Dropdown(choices=[], value=None, label="Select Variate", interactive=False),
            gr.Dropdown(choices=[], value=None, label="Select Testing Window", interactive=False),
        )

    # FIX: the previous split of dataset_term into (dataset_name, freq) was
    # never used and raised ValueError for terms without a "/"; removed.

    # Load Dataset to get actual series and variate names.
    # Use HF config to get dataset root (handles both local and HF Hub).
    hf_dataset_root = str(get_datasets_root())

    # Use HF config to get config root
    config_root = get_config_root()
    config_path = config_root / "datasets.yaml"

    config = load_dataset_config(config_path) if config_path.exists() else {}
    settings = get_dataset_settings(dataset_term, horizon, config)

    prediction_length = settings.get("prediction_length")
    test_length = settings.get("test_length")

    # Load dataset
    dataset_obj = Dataset(
        name=dataset_term,
        term=horizon,
        prediction_length=prediction_length,
        test_length=test_length,
        storage_path=hf_dataset_root,  # Pass storage path directly
    )

    # Get series names (item_id) from hf_dataset
    if "item_id" in dataset_obj.hf_dataset.column_names:
        series_names = dataset_obj.hf_dataset["item_id"]
    else:
        # Fallback: get from iterating
        series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                        for i in range(len(dataset_obj.hf_dataset))]

    # Get variate names; None signals a univariate dataset.
    variate_names = dataset_obj.get_variate_names()

    # Get window count
    num_windows = dataset_obj.windows
    windows = [str(i) for i in range(num_windows)]

    # Convert to lists and maintain order (no sorting)
    series_list = [str(name) for name in series_names]

    # Handle UTS (Univariate Time Series) vs MTS (Multivariate Time Series)
    if variate_names is None:
        # UTS mode: each series is a single variate, so variate is always 0
        return (
            gr.Dropdown(choices=series_list, value=series_list[0], label="Select Series", interactive=True),
            gr.Dropdown(choices=["0"], value="0", label="Select Variate", interactive=False),
            gr.Dropdown(choices=windows, value=windows[0], label="Select Testing Window", interactive=True),
        )
    else:
        # MTS mode: multiple variates per series
        variates_list = [str(name) for name in variate_names]
        return (
            gr.Dropdown(choices=series_list, value=series_list[0], label="Select Series", interactive=True),
            gr.Dropdown(choices=variates_list, value=variates_list[0], label="Select Variate", interactive=True),
            gr.Dropdown(choices=windows, value=windows[0], label="Select Testing Window", interactive=True),
        )
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def plot_window_series(display_name, series, variate, window_id, horizon, selected_quantiles, model):
    """
    Plot time series predictions for a specific window using Plotly for interactive visualization.
    Now includes full time series visualization with test window highlighted.
    Accepts series and variate names (strings) and converts them to indices.

    Args:
        display_name: Dataset display name from UI dropdown (will be resolved to dataset_id)
        series: Series name
        variate: Variate name
        window_id: Window index
        horizon: Horizon name
        selected_quantiles: List of quantile strings to plot
        model: Model name

    Returns:
        tuple: (fig, info_message) where fig is Plotly figure and info_message contains prediction details
    """
    print(f"🔍 plot_window_series called: display_name={display_name}, series={series}, variate={variate}, window_id={window_id}, horizon={horizon}, model={model}")

    # Guard clause: an empty figure with a hint title is returned on any
    # missing selection, so the UI never crashes on partial inputs.
    if display_name is None or series is None or variate is None or window_id is None:
        print("❌ Missing parameters")
        fig = go.Figure()
        fig.update_layout(title="Please select all parameters")
        return fig, ""

    results_root = str(RESULTS_ROOT)
    print(f"📁 results_root: {results_root}")
    dataset_term = find_dataset_term_path(results_root, model, display_name)
    print(f"📁 dataset_term: {dataset_term}")
    if dataset_term is None:
        print("❌ Dataset not found")
        fig = go.Figure()
        fig.update_layout(title="Dataset not found")
        return fig, ""

    predictions_path = os.path.join(results_root, model, dataset_term, horizon, "predictions.npz")
    print(f"📁 predictions_path: {predictions_path}, exists: {os.path.exists(predictions_path)}")

    if not os.path.exists(predictions_path):
        print("❌ Predictions file not found")
        fig = go.Figure()
        fig.update_layout(title="Predictions file not found")
        return fig, ""

    predictions = np.load(predictions_path)
    # Load pre-computed quantiles (new format only)
    predictions_quantiles = predictions["predictions_quantiles"]  # (num_series, num_windows, 9, num_variates, prediction_length)
    quantile_levels = predictions["quantile_levels"]  # [0.1, 0.2, ..., 0.9]

    # Load prediction scale factor from config.json (for float16 overflow prevention)
    model_config_path = os.path.join(results_root, model, dataset_term, horizon, "config.json")
    prediction_scale_factor = 1.0
    if os.path.exists(model_config_path):
        with open(model_config_path, "r") as f:
            model_config = json.load(f)
        prediction_scale_factor = model_config.get("prediction_scale_factor", 1.0)
        if prediction_scale_factor != 1.0:
            print(f"📊 Applying inverse scale factor: {prediction_scale_factor}")
            # Cast to float32 first so the multiply does not overflow float16.
            predictions_quantiles = predictions_quantiles.astype(np.float32) * prediction_scale_factor

    # Convert series and variate names to indices
    series_idx = None
    variate_idx = None
    dataset_obj = None

    # Load Dataset to get name-to-index mappings and full time series
    # Use HF config to get dataset root (handles both local and HF Hub)
    hf_dataset_root = str(get_datasets_root())
    print(f"📁 hf_dataset_root: {hf_dataset_root}, exists: {os.path.exists(hf_dataset_root)}")

    # Use HF config to get config root
    config_root = get_config_root()
    config_path_yaml = config_root / "datasets.yaml"
    print(f"📁 config_path_yaml: {config_path_yaml}, exists: {config_path_yaml.exists()}")

    config = load_dataset_config(config_path_yaml) if config_path_yaml.exists() else {}
    settings = get_dataset_settings(dataset_term, horizon, config)
    print(f"⚙️ settings: {settings}")

    prediction_length = settings.get("prediction_length")
    test_length = settings.get("test_length")

    print(f"📥 Loading Dataset: name={dataset_term}, term={horizon}, storage_path={hf_dataset_root}")
    dataset_obj = Dataset(
        name=dataset_term,
        term=horizon,
        prediction_length=prediction_length,
        test_length=test_length,
        storage_path=hf_dataset_root,  # Pass storage path directly
    )
    print(f"✅ Dataset loaded: {len(dataset_obj.hf_dataset)} series")

    # Get frequency from dataset
    dataset_freq = dataset_obj.freq
    print(f"📅 Dataset frequency: {dataset_freq}")

    # Get series names and create mapping
    if "item_id" in dataset_obj.hf_dataset.column_names:
        series_names = dataset_obj.hf_dataset["item_id"]
    else:
        series_names = [dataset_obj.hf_dataset[i].get("item_id", f"item_{i}")
                        for i in range(len(dataset_obj.hf_dataset))]

    print(f"📋 series_names: {list(series_names)}")
    series_name_to_idx = {name: idx for idx, name in enumerate(series_names)}
    if series in series_name_to_idx:
        series_idx = series_name_to_idx[series]
        print(f"✅ Found series '{series}' at index {series_idx}")
    else:
        # Fallback: the UI may pass a bare index string instead of a name.
        series_idx = int(series)
        print(f"⚠️ Series '{series}' not found in names, using int index {series_idx}")

    # Get variate names and create mapping
    variate_names = dataset_obj.get_variate_names()
    print(f"📋 variate_names: {variate_names}")
    if variate_names is not None:
        # MTS mode: multiple variates per series
        variate_name_to_idx = {name: idx for idx, name in enumerate(variate_names)}
        if variate in variate_name_to_idx:
            variate_idx = variate_name_to_idx[variate]
            print(f"✅ Found variate '{variate}' at index {variate_idx}")
        else:
            variate_idx = int(variate)
            print(f"⚠️ Variate '{variate}' not found in names, using int index {variate_idx}")
    else:
        # UTS mode: each series is a single variate, so variate_idx is always 0
        variate_idx = 0
        print(f"ℹ️ UTS mode, variate_idx=0")

    # Defensive re-check: series_idx/variate_idx should already be set above;
    # these branches only fire if the mapping logic was skipped.
    if series_idx is None:
        series_idx = int(series)
    if variate_idx is None:
        # For UTS mode, variate_idx should be 0
        try:
            variate_idx = int(variate) if variate is not None else 0
        except (ValueError, TypeError):
            variate_idx = 0

    window_idx = int(window_id)

    # Get pre-computed quantiles for this specific series, window, and variate
    quantiles_data = predictions_quantiles[series_idx, window_idx, :, variate_idx, :]  # (9, prediction_length)
    prediction_length = quantiles_data.shape[1]
    # Create mapping from quantile level string to index
    quantile_level_to_idx = {f"{q:.1f}": i for i, q in enumerate(quantile_levels)}

    # Load full time series data
    full_series = None
    train_end_idx = None
    test_window_start_idx = None
    test_window_end_idx = None

    # Get full target time series for this series
    print(f"📊 Getting target for series_idx={series_idx}, variate_idx={variate_idx}")
    full_target = dataset_obj.hf_dataset[series_idx]["target"]
    print(f"📊 full_target shape: {full_target.shape}, dtype: {full_target.dtype}")
    print(f"📊 full_target first 10 values (all variates): {full_target[:, :10] if full_target.ndim > 1 else full_target[:10]}")

    # Get start timestamp for this series and create timestamp array
    series_start = dataset_obj.hf_dataset[series_idx]["start"]
    print(f"📅 Series start timestamp: {series_start}, type: {type(series_start)}")

    # Handle numpy array containing datetime64 (common when reading from HF dataset)
    if isinstance(series_start, np.ndarray):
        # Extract scalar from array
        series_start = series_start.item() if series_start.ndim == 0 else series_start[0]
        print(f"📅 Extracted scalar: {series_start}, type: {type(series_start)}")

    # Convert numpy datetime64 to pandas Timestamp
    if isinstance(series_start, (np.datetime64, str)):
        series_start = pd.Timestamp(series_start)

    # Calculate series length for timestamp creation
    # NOTE(review): for 2-D targets the layout is assumed (num_variates, length) — confirm against Dataset.
    if full_target.ndim > 1:
        ts_length = full_target.shape[1]
    else:
        ts_length = len(full_target)

    # Create timestamp array for the entire series
    try:
        timestamps = pd.date_range(start=series_start, periods=ts_length, freq=dataset_freq)
        print(f"📅 Created timestamp array: {timestamps[0]} to {timestamps[-1]}")
    except Exception as e:
        print(f"⚠️ Failed to create timestamps: {e}, falling back to indices")
        timestamps = None

    # Handle multivariate case: extract specific variate
    if full_target.ndim > 1:
        full_series = full_target[variate_idx, :]  # Shape: (series_length,)
    else:
        full_series = full_target  # Shape: (series_length,)
    print(f"📊 full_series shape: {full_series.shape}, min: {full_series.min()}, max: {full_series.max()}, has_nan: {np.isnan(full_series).any()}")

    # Calculate train/test split point
    # Test data starts at: series_length - test_length
    series_length = len(full_series)
    train_end_idx = series_length - test_length

    # Calculate current test window position
    test_window_start_idx = train_end_idx + window_idx * prediction_length
    test_window_end_idx = test_window_start_idx + prediction_length

    # Create Plotly figure
    fig = go.Figure()

    # Quantile colors - from light to dark
    quantile_colors = {
        "0.1": "#c6dbef", "0.9": "#c6dbef",  # lightest
        "0.2": "#6baed6", "0.8": "#6baed6",  # light
        "0.3": "#4292c6", "0.7": "#4292c6",  # medium
        "0.4": "#2171b5", "0.6": "#2171b5",  # dark
        "0.5": "#08306b",  # darkest (median)
    }

    # Calculate prediction time steps (overlay on the test window)
    if test_window_start_idx is not None:
        pred_time_steps = np.arange(test_window_start_idx, test_window_end_idx)
    else:
        pred_time_steps = np.arange(prediction_length)

    # Plot full time series if available
    time_steps = np.arange(len(full_series))

    # Use timestamps for x-axis if available
    if timestamps is not None:
        x_full = timestamps
        x_pred = timestamps[pred_time_steps] if test_window_start_idx is not None else timestamps[:prediction_length]
        x_window = timestamps[test_window_start_idx:test_window_end_idx] if test_window_start_idx is not None else None
    else:
        x_full = time_steps
        x_pred = pred_time_steps
        x_window = np.arange(test_window_start_idx, test_window_end_idx) if test_window_start_idx is not None else None

    # Plot full series in light gray
    fig.add_trace(go.Scatter(
        x=x_full,
        y=full_series,
        mode='lines',
        name='Full Time Series',
        line=dict(color='gray', width=1),
        opacity=0.6,
        hovertemplate='Time: %{x}<br>Value: %{y:.4f}<extra></extra>'
    ))

    # Add shapes for regions (training, test, current window)
    if train_end_idx is not None:
        # Training region - use timestamps if available
        x0_train = timestamps[0] if timestamps is not None else 0
        x1_train = timestamps[train_end_idx] if timestamps is not None else train_end_idx
        fig.add_shape(
            type="rect",
            x0=x0_train, x1=x1_train,
            y0=0, y1=1, yref="paper",
            fillcolor="blue", opacity=0.1,
            layer="below", line_width=0,
        )
        # Test region
        test_region_end = len(full_series)
        x0_test = timestamps[train_end_idx] if timestamps is not None else train_end_idx
        x1_test = timestamps[test_region_end-1] if timestamps is not None else test_region_end-1
        fig.add_shape(
            type="rect",
            x0=x0_test, x1=x1_test,
            y0=0, y1=1, yref="paper",
            fillcolor="orange", opacity=0.15,
            layer="below", line_width=0,
        )

    # Highlight current test window
    if test_window_start_idx is not None and test_window_end_idx is not None:
        # Use timestamps for window highlight if available
        x0_window = timestamps[test_window_start_idx] if timestamps is not None else test_window_start_idx
        x1_window = timestamps[test_window_end_idx-1] if timestamps is not None else test_window_end_idx-1
        fig.add_shape(
            type="rect",
            x0=x0_window, x1=x1_window,
            y0=0, y1=1, yref="paper",
            fillcolor="red", opacity=0.2,
            layer="below", line_width=0,
        )

        # Plot the test window portion of full series
        window_series = full_series[test_window_start_idx:test_window_end_idx]
        fig.add_trace(go.Scatter(
            x=x_window,
            y=window_series,
            mode='lines',
            name='Ground Truth (Window)',
            line=dict(color='red', width=2),
            opacity=0.8,
            hovertemplate='Time: %{x}<br>Value: %{y:.4f}<extra></extra>'
        ))

    # Quantile pairs mapping: UI selection -> (low, high) quantile values
    quantile_pair_map = {
        "0.1-0.9": ("0.1", "0.9"),
        "0.2-0.8": ("0.2", "0.8"),
        "0.3-0.7": ("0.3", "0.7"),
        "0.4-0.6": ("0.4", "0.6"),
    }

    # Helper function to get pre-computed quantile values
    def get_quantile_values(q_str):
        return quantiles_data[quantile_level_to_idx[q_str], :]

    # Plot quantile pairs with fill (based on paired selection)
    for pair_str, (q_low_str, q_high_str) in quantile_pair_map.items():
        if pair_str in selected_quantiles:
            quantile_low = get_quantile_values(q_low_str)
            quantile_high = get_quantile_values(q_high_str)
            color = quantile_colors.get(q_low_str, "#2171b5")

            # Add filled area between quantiles
            fig.add_trace(go.Scatter(
                x=list(x_pred) + list(x_pred[::-1]),
                y=list(quantile_high) + list(quantile_low[::-1]),
                fill='toself',
                fillcolor=color,
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=True,
                name=f'Q{q_low_str}-Q{q_high_str}',
                opacity=0.3
            ))

            # Add lower quantile line
            fig.add_trace(go.Scatter(
                x=x_pred,
                y=quantile_low,
                mode='lines',
                name=f'Q{q_low_str}',
                line=dict(color=color, width=1),
                opacity=0.7,
                showlegend=False,
                hovertemplate=f'Time: %{{x}}<br>Q{q_low_str}: %{{y:.4f}}<extra></extra>'
            ))

            # Add upper quantile line
            fig.add_trace(go.Scatter(
                x=x_pred,
                y=quantile_high,
                mode='lines',
                name=f'Q{q_high_str}',
                line=dict(color=color, width=1),
                opacity=0.7,
                showlegend=False,
                hovertemplate=f'Time: %{{x}}<br>Q{q_high_str}: %{{y:.4f}}<extra></extra>'
            ))

    # Plot median (0.5) if selected
    if "0.5" in selected_quantiles:
        quantile_values = get_quantile_values("0.5")
        color = quantile_colors.get("0.5", "#08306b")

        fig.add_trace(go.Scatter(
            x=x_pred,
            y=quantile_values,
            mode='lines+markers',
            name='Median (Q0.5)',
            line=dict(color=color, width=3),
            marker=dict(size=5, symbol='circle'),
            opacity=0.8,
            hovertemplate='Time: %{x}<br>Q0.5: %{y:.4f}<extra></extra>'
        ))

    # Update layout - use autosize for responsive width
    x_axis_title = "Timestamp" if timestamps is not None else "Time Step"
    fig.update_layout(
        title=None,
        xaxis_title=x_axis_title,
        yaxis_title="Value",
        hovermode='x unified',
        autosize=True,  # auto width so the chart tracks its container size
        height=400,
        margin=dict(l=60, r=40, t=60, b=60),  # sensible default margins
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            font=dict(size=14)
        ),
        plot_bgcolor='white',
        xaxis=dict(showgrid=True, gridcolor='lightgray', gridwidth=1),
        yaxis=dict(showgrid=True, gridcolor='lightgray', gridwidth=1)
    )

    # Create info message for prediction window
    if timestamps is not None and test_window_start_idx is not None and test_window_end_idx is not None:
        pred_start_ts = timestamps[test_window_start_idx]
        pred_end_ts = timestamps[test_window_end_idx - 1]  # -1 because end index is exclusive
        # Format with weekday name
        start_str = f"{pred_start_ts.strftime('%Y-%m-%d %H:%M:%S')} ({pred_start_ts.day_name()})"
        end_str = f"{pred_end_ts.strftime('%Y-%m-%d %H:%M:%S')} ({pred_end_ts.day_name()})"
        base_info = (
            f"📊 Prediction Length: {prediction_length}\n"
            f"📅 Prediction Range: {start_str} → {end_str}\n"
            f"🔄 Dataset Frequency: {dataset_freq}"
        )
    else:
        # NOTE(review): `'dataset_freq' in dir()` checks the local namespace;
        # dataset_freq is always bound by this point, so the fallback branch
        # of that conditional should be unreachable — confirm.
        base_info = (
            f"📊 Prediction Length: {prediction_length}\n"
            f"📅 Prediction Range: index {test_window_start_idx} → {test_window_end_idx - 1}\n"
            f"🔄 Dataset Frequency: {dataset_freq if 'dataset_freq' in dir() else 'N/A'}"
        )

    # Get features information for the selected variate
    # Pattern names from init_per_pattern_tab
    pattern_names = [
        "T_strength", "T_linearity",
        "S_strength", "S_corr",
        "R_ACF1",
        "stationarity", "complexity"
    ]

    features_info = ""
    if not FEATURES_DF.empty and not FEATURES_BOOL_DF.empty:
        # Find matching row in features dataframes
        # Try to match by dataset_id, series_name, variate_name
        feature_row_orig = None
        feature_row_bool = None

        # Match by dataset_id first
        features_subset_orig = FEATURES_DF[FEATURES_DF["dataset_id"] == dataset_term]
        features_subset_bool = FEATURES_BOOL_DF[FEATURES_BOOL_DF["dataset_id"] == dataset_term]

        print(f"🔍 Features lookup: dataset_term={dataset_term}, series={series}, variate={variate}")
        print(f"🔍 Features subset size: orig={len(features_subset_orig)}, bool={len(features_subset_bool)}")

        # Try matching by series_name and variate_name (for MTS)
        if not features_subset_orig.empty:
            # Check if series_name matches
            if "series_name" in features_subset_orig.columns:
                series_match_orig = features_subset_orig["series_name"] == series
                if series_match_orig.any():
                    series_matched = features_subset_orig[series_match_orig]
                    print(f"🔍 Found {len(series_matched)} rows with series_name={series}")
                    # Check if variate_name matches
                    if "variate_name" in series_matched.columns:
                        # For UTS, variate might be "0" or 0, try both
                        variate_str = str(variate)
                        variate_match_orig = (series_matched["variate_name"] == variate_str) | (series_matched["variate_name"] == variate)
                        if variate_match_orig.any():
                            feature_row_orig = series_matched[variate_match_orig].iloc[0]
                            print(f"✅ Found feature row by series_name + variate_name")
                            # Find corresponding row in bool dataframe
                            if not features_subset_bool.empty and "series_name" in features_subset_bool.columns and "variate_name" in features_subset_bool.columns:
                                series_match_bool = features_subset_bool["series_name"] == series
                                variate_match_bool = (features_subset_bool["variate_name"] == variate_str) | (features_subset_bool["variate_name"] == variate)
                                bool_matched = features_subset_bool[series_match_bool & variate_match_bool]
                                if not bool_matched.empty:
                                    feature_row_bool = bool_matched.iloc[0]

        # If not found, try matching by series_name only (for UTS cases where variate_name might not match)
        if feature_row_orig is None and not features_subset_orig.empty:
            if "series_name" in features_subset_orig.columns:
                series_match_orig = features_subset_orig["series_name"] == series
                if series_match_orig.any():
                    # For UTS, there might be only one row per series
                    series_matched = features_subset_orig[series_match_orig]
                    if len(series_matched) == 1:
                        feature_row_orig = series_matched.iloc[0]
                        print(f"✅ Found feature row by series_name only (UTS)")
                        # Find corresponding row in bool dataframe
                        if not features_subset_bool.empty and "series_name" in features_subset_bool.columns:
                            series_match_bool = features_subset_bool["series_name"] == series
                            bool_matched = features_subset_bool[series_match_bool]
                            if len(bool_matched) == 1:
                                feature_row_bool = bool_matched.iloc[0]

        # If still not found, try matching by variate_name only (for UTS cases where variate_name == series)
        if feature_row_orig is None and not features_subset_orig.empty:
            if "variate_name" in features_subset_orig.columns:
                variate_match_orig = features_subset_orig["variate_name"] == series  # For UTS, series might be the variate_name
                if variate_match_orig.any():
                    feature_row_orig = features_subset_orig[variate_match_orig].iloc[0]
                    print(f"✅ Found feature row by variate_name (series as variate_name)")
                    # Find corresponding row in bool dataframe
                    if not features_subset_bool.empty and "variate_name" in features_subset_bool.columns:
                        variate_match_bool = features_subset_bool["variate_name"] == series
                        if variate_match_bool.any():
                            feature_row_bool = features_subset_bool[variate_match_bool].iloc[0]

        if feature_row_orig is None:
            print(f"⚠️ Could not find features for dataset_term={dataset_term}, series={series}, variate={variate}")
            if not features_subset_orig.empty:
                print(f" Available series_names: {features_subset_orig['series_name'].unique()[:10] if 'series_name' in features_subset_orig.columns else 'N/A'}")
                print(f" Available variate_names: {features_subset_orig['variate_name'].unique()[:10] if 'variate_name' in features_subset_orig.columns else 'N/A'}")

        if feature_row_orig is not None:
            # Build features display
            features_orig_items = []
            features_bool_items = []

            for pattern_name in pattern_names:
                # Map pattern name to feature column name
                feature_col = PATTERN_MAP.get(pattern_name, pattern_name)

                # Get original value (skip stationarity as it's derived from is_random_walk)
                if pattern_name != "stationarity":
                    if feature_col in feature_row_orig.index:
                        orig_value = feature_row_orig[feature_col]
                        if pd.notna(orig_value):
                            features_orig_items.append(f"{pattern_name}: {orig_value:.3f}")

                # Get binary value
                if feature_row_bool is not None and feature_col in feature_row_bool.index:
                    bool_value = feature_row_bool[feature_col]
                    if pd.notna(bool_value):
                        # Special handling for stationarity (it's inverted)
                        if pattern_name == "stationarity":
                            # stationarity = NOT is_random_walk, so display the inverted value
                            display_bool = 1 - int(bool_value)
                        else:
                            display_bool = int(bool_value)
                        features_bool_items.append(f"{pattern_name}: {display_bool}")

            if features_orig_items or features_bool_items:
                features_info = "\n\n 📝 Features of variate:\n"
                if features_orig_items:
                    features_info += "- Original Values: " + ", ".join(features_orig_items) + "\n"
                if features_bool_items:
                    features_info += "- Binary Values (0/1): " + ", ".join(features_bool_items)

    info_message = base_info + features_info

    print(f"📝 Info message: {info_message}")
    return fig, info_message
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
def init_overall_tab():
    """Render the "Overall" leaderboard tab.

    Builds a static aggregated leaderboard table (computed once at startup)
    and a CSV-export button/file pair. Reads the module-level DATASETS_DF.
    """
    gr.Markdown(
        """
This tab presents each model's overall performance aggregated across all tasks. A **task** is defined as a specific **(dataset, horizon)** pair. For each task, the result is obtained by averaging the metrics across all its variates.
- **MASE (norm.), CRPS (norm.)**: task-level results are normalized by Seasonal Naive and aggregated by geometric mean.
- **MASE_rank, CRPS_rank**: for each task, models are ranked by the metric; the average rank across all tasks is then reported.
""",
        elem_classes="markdown-text"
    )

    # Static table: computed once when the tab is built, never refreshed.
    overall_table = gr.DataFrame(
        value=get_overall_leaderboard(DATASETS_DF, metric="MASE"),
        elem_classes="custom-table",
        interactive=False
    )

    # CSV Export
    def export_overall_csv():
        # Recompute so the exported CSV matches the displayed leaderboard.
        df = get_overall_leaderboard(DATASETS_DF, metric="MASE")
        return export_dataframe_to_csv(df, filename_prefix="overall_leaderboard")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # First write the CSV to the File component, then reveal the
    # (initially hidden) download widget.
    export_btn.click(
        fn=export_overall_csv,
        inputs=[],
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
def init_per_dataset_tab(demo):
    """Render the "Per Dataset" tab.

    Provides dataset / series / variate drill-down with multi-horizon
    aggregation. `demo` is the enclosing gr.Blocks instance, used to
    populate the table on page load.
    """
    gr.Markdown(
        """
This tab provides flexible analysis at dataset, series, and variate levels.

- **Dataset only**: Shows both Seasonal Naive-normalized metrics (task-level) and original non-normalized metrics, plus average ranks
- **Series/Variate selected**: Shows only original metrics.
- **Horizons**: Select one or more horizons to aggregate results
""",
        elem_classes="markdown-text"
    )

    with gr.Row():
        with gr.Column(scale=1):
            horizons = gr.CheckboxGroup(
                choices=ALL_HORIZONS,
                value=ALL_HORIZONS,
                label="Horizons"
            )

            dataset_dropdown = gr.Dropdown(
                choices=DATASET_CHOICES,
                value=DATASET_CHOICES[0],
                label="Dataset",
                interactive=True
            )

            # Initialize series and variate dropdowns from the default dataset
            # (update_series_and_variate returns pre-populated components).
            initial_dataset = DATASET_CHOICES[0]
            series_dropdown, variate_dropdown = update_series_and_variate(
                initial_dataset
            )

    msg = gr.Textbox(label="Message", interactive=False)
    table = gr.DataFrame(elem_classes="custom-table", interactive=False)

    # Update series and variate dropdowns when dataset changes
    dataset_dropdown.change(
        fn=update_series_and_variate,
        inputs=[dataset_dropdown],
        outputs=[series_dropdown, variate_dropdown],
    )

    # Update leaderboard when any selection changes
    for comp in [dataset_dropdown, series_dropdown, variate_dropdown, horizons]:
        comp.change(
            fn=get_dataset_multilevel_leaderboard,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, horizons],
            outputs=[msg, table]
        )

    # Load on startup
    demo.load(
        fn=get_dataset_multilevel_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, horizons],
        outputs=[msg, table]
    )

    # CSV Export
    def export_dataset_csv(dataset, series, variate, horizons_val):
        # First return value (the status message) is not needed for export.
        _, df = get_dataset_multilevel_leaderboard(dataset, series, variate, horizons_val)
        # Sanitize dataset name for filename (replace / with _)
        safe_dataset_name = dataset.replace("/", "_") if dataset else "unknown"
        return export_dataframe_to_csv(df, filename_prefix=f"dataset_{safe_dataset_name}")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # Write the CSV first, then reveal the hidden download widget.
    export_btn.click(
        fn=export_dataset_csv,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, horizons],
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
def init_per_window_tab(demo):
    """Render the "Per Window" tab.

    Lets the user pick (dataset, series, variate, horizon, test window) and
    shows both window-level metrics and an interactive quantile-forecast
    plot for a chosen model. `demo` is the enclosing gr.Blocks instance,
    used to populate the plot and table on page load.
    """
    gr.Markdown(
        """
This tab enables detailed analysis of model performance at the level of individual testing windows. By selecting a dataset, variate, horizon, and test window, users can examine window-level metrics (MASE, CRPS, MAE, MSE) at fine granularity and visualize the predicted quantiles of a model along with the ground-truth.
- **Interactive Visualization**: Zoom, pan, autoscale and download the plot.
- 🟦 Train Split 🟨 Test Split 🟥 Prediction Window
"""
    )

    # Quantile pairs selectable for the fan chart; "0.5" is the median-only view.
    QUANTILE_PAIR_CHOICES = ["0.1-0.9", "0.2-0.8", "0.3-0.7", "0.4-0.6", "0.5"]
    initial_quantiles = ["0.5"]

    with gr.Row():
        with gr.Column(scale=1):
            # Initialize horizon choices based on first dataset
            initial_dataset = DATASET_CHOICES[0] if DATASET_CHOICES else None
            initial_horizons = get_available_horizons(initial_dataset) if initial_dataset else ALL_HORIZONS
            horizons = gr.Radio(
                choices=initial_horizons,
                # Prefer "short"; otherwise fall back to the first available horizon.
                value="short" if "short" in initial_horizons else (initial_horizons[0] if initial_horizons else "short"),
                label="Horizons"
            )

            # Dropdown for dataset selection
            dataset_dropdown = gr.Dropdown(
                choices=DATASET_CHOICES,
                value=DATASET_CHOICES[0] if DATASET_CHOICES else None,  # default to the first dataset
                label="Dataset",
                interactive=True
            )

            # Initialize series, variate, window dropdowns using function
            series_dropdown, variate_dropdown, window_dropdown = update_series_variate_and_window(
                dataset_dropdown.value, horizons.value
            )

        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(scale=2):
                    quantiles = gr.CheckboxGroup(
                        choices=QUANTILE_PAIR_CHOICES,
                        value=initial_quantiles,
                        label="Select Quantiles for Visualization"
                    )
                with gr.Column(scale=1):
                    model = gr.Dropdown(
                        choices=ALL_MODELS,
                        value=ALL_MODELS[0],
                        label="Select Model for Visualization",
                        interactive=True
                    )
            ts_visualization = gr.Plot()
            # Message box for prediction window info
            prediction_info = gr.Textbox(
                label="Info",
                interactive=False,
                lines=3
            )

    table_window = gr.DataFrame(elem_classes="custom-table", interactive=False)

    # When dataset changes: first update horizon choices, then update dropdowns.
    # The .then() chain guarantees dropdowns are refreshed BEFORE the plot and
    # table are recomputed from them.
    dataset_dropdown.change(
        fn=update_horizon_choices,
        inputs=[dataset_dropdown],
        outputs=[horizons],
    ).then(
        fn=update_series_variate_and_window,
        inputs=[dataset_dropdown, horizons],
        outputs=[series_dropdown, variate_dropdown, window_dropdown],
    ).then(
        # After dropdowns are updated, refresh the visualization and table
        fn=plot_window_series,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
        outputs=[ts_visualization, prediction_info]
    ).then(
        fn=get_window_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=table_window
    )

    # When horizon changes: update dropdowns, then refresh visualization
    horizons.change(
        fn=update_series_variate_and_window,
        inputs=[dataset_dropdown, horizons],
        outputs=[series_dropdown, variate_dropdown, window_dropdown],
    ).then(
        fn=plot_window_series,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
        outputs=[ts_visualization, prediction_info]
    ).then(
        fn=get_window_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=table_window
    )

    # For series, variate, window changes - update visualization and table
    for comp in [series_dropdown, variate_dropdown, window_dropdown]:
        comp.change(
            fn=get_window_leaderboard,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
            outputs=table_window
        )
        comp.change(
            fn=plot_window_series,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
            outputs=[ts_visualization, prediction_info]
        )

    # For quantiles and model changes - only update visualization (no table change needed)
    for comp in [quantiles, model]:
        comp.change(
            fn=plot_window_series,
            inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
            outputs=[ts_visualization, prediction_info]
        )

    # Load initial visualization and table on page load
    demo.load(
        fn=plot_window_series,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
        outputs=[ts_visualization, prediction_info]
    )
    demo.load(
        fn=get_window_leaderboard,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=table_window
    )

    # CSV Export
    def export_window_csv(dataset, series, variate, window, horizon):
        df = get_window_leaderboard(dataset, series, variate, window, horizon)
        return export_dataframe_to_csv(df, filename_prefix="window_leaderboard")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # Write the CSV first, then reveal the hidden download widget.
    export_btn.click(
        fn=export_window_csv,
        inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
|
| 1107 |
+
|
| 1108 |
+
|
| 1109 |
+
def init_per_pattern_tab(demo):
    """Render the "Per Pattern" tab.

    Each time-series pattern is a boolean indicator; the user constrains each
    one to N/A (no filter), =1 or =0 via a Radio group, and the leaderboard is
    recomputed over the intersection of matching variates. `demo` is the
    enclosing gr.Blocks instance, used to populate the table on page load.
    """
    gr.Markdown(
        """
This tab allows you to explore model performance based on **selected patterns**.

Select patterns to filter variates that exhibit those characteristics, then view aggregated model performance.
Each pattern is a **boolean indicator** derived from time series features (binarized by **median** threshold for continuous features).

- **Patterns are intersected**: A variate must exhibit ALL selected patterns to be included.
- **MASE (norm.), CRPS (norm.)**: variate-level results are normalized by Seasonal Naive and aggregated by geometric mean across all matching variates.
- **MASE (raw), CRPS (raw)**: arithmetic mean across all matching variates.
""",
        elem_classes="markdown-text"
    )

    # Define pattern choices for Radio components
    PATTERN_CHOICES = ["N/A", "=1", "=0"]

    with gr.Row():  # TSFeatures
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 📈 Trend Features")
                T_strength = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="T_strength"
                )
                T_linearity = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="T_linearity"
                )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🔄 Seasonal Features")
                S_strength = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="S_strength"
                )
                S_corr = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="S_corr"
                )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🎯 Residual Features")
                R_ACF1 = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="R_ACF1"
                )
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### ⚙️ Global Features")
                stationarity = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="stationarity"
                )
                complexity = gr.Radio(
                    choices=PATTERN_CHOICES, value="N/A", label="complexity"
                )

    # List of all pattern Radio components and their names.
    # NOTE: pattern_radios and pattern_names must stay in the same order —
    # merge_patterns zips them positionally.
    pattern_radios = [
        T_strength, T_linearity,
        S_strength, S_corr,
        R_ACF1,
        stationarity, complexity
    ]
    pattern_names = [
        "T_strength", "T_linearity",
        "S_strength", "S_corr",
        "R_ACF1",
        "stationarity", "complexity"
    ]

    with gr.Row():
        with gr.Column(scale=1):
            horizons = gr.CheckboxGroup(
                choices=ALL_HORIZONS,
                value=ALL_HORIZONS,
                label="Horizons"
            )
        with gr.Column(scale=2):
            msg_pattern = gr.Textbox(label="Status", interactive=False, lines=4)

    table_variates = gr.DataFrame(elem_classes="custom-table", interactive=False)

    def merge_patterns(*radio_values):
        """Convert Radio values to pattern filter dict.

        Args:
            *radio_values: Values from all Radio components in order of pattern_names

        Returns:
            dict: {feature_name: required_value} where required_value is 0 or 1.
                Features with "N/A" are not included in the dict.
        """
        result = {}
        for name, value in zip(pattern_names, radio_values):
            if value == "=1":
                result[name] = 1
            elif value == "=0":
                result[name] = 0
            # "N/A" -> don't include in dict (no filter on this feature)
        return result

    def update_leaderboard(*args):
        """Callback to update the pattern leaderboard.

        Args:
            *args: All Radio values followed by horizons (last argument)
        """
        # Last argument is horizons, rest are pattern radio values
        horizons_val = args[-1]
        radio_values = args[:-1]
        pattern_filters = merge_patterns(*radio_values)
        return get_pattern_leaderboard(pattern_filters, horizons_val)

    # Bind change events for all pattern radios and horizons
    all_inputs = pattern_radios + [horizons]
    for comp in all_inputs:
        comp.change(
            fn=update_leaderboard,
            inputs=all_inputs,
            outputs=[msg_pattern, table_variates]
        )

    # Load initial state
    demo.load(
        fn=update_leaderboard,
        inputs=all_inputs,
        outputs=[msg_pattern, table_variates]
    )

    # CSV Export
    def export_pattern_csv(*args):
        # Last argument is horizons, rest are pattern radio values
        horizons_val = args[-1]
        radio_values = args[:-1]
        pattern_filters = merge_patterns(*radio_values)
        # First return value (the status message) is not needed for export.
        _, df = get_pattern_leaderboard(pattern_filters, horizons_val)
        return export_dataframe_to_csv(df, filename_prefix="pattern_leaderboard")

    with gr.Row():
        export_btn = gr.Button("📥 Export CSV", size="sm")
        export_file = gr.File(label="Download CSV", visible=False)

    # Write the CSV first, then reveal the hidden download widget.
    export_btn.click(
        fn=export_pattern_csv,
        inputs=all_inputs,
        outputs=[export_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        inputs=[],
        outputs=[export_file]
    )
|
| 1257 |
+
|
| 1258 |
+
|
| 1259 |
+
# # ToDO: Now the archive is using different features from the ones in per_pattern tab
|
| 1260 |
+
# def init_archive_tab(demo):
|
| 1261 |
+
# gr.Markdown(
|
| 1262 |
+
# """
|
| 1263 |
+
# This tab provides an interactive archive of the features of time series variates across datasets. You can explore the archive by specifying a dataset, domain, and frequency, and filter variates with the selected structural patterns. Each pattern is a **boolean indicator** showing whether a variate exhibits the pattern, with thresholds derived from the distribution of feature values across the entire dataset. Pattern filters are applied as an **intersection** (a variate must exhibit all selected patterns). Domain and frequency filters are applied as a **union** (a variate may belong to any selected category). The resulting table displays all variates that satisfy the chosen filters, together with their dataset, frequency, domain, and computed feature values. This view makes it possible to identify and group variates that share similar feature profiles.
|
| 1264 |
+
# """
|
| 1265 |
+
# )
|
| 1266 |
+
|
| 1267 |
+
# with gr.Row():
|
| 1268 |
+
# with gr.Column(scale=1):
|
| 1269 |
+
# dataset_dropdown = gr.Dropdown(
|
| 1270 |
+
# choices=["All"] + sorted(FEATURES_BOOL_DF["dataset"].unique().tolist()),
|
| 1271 |
+
# value="All",
|
| 1272 |
+
# label="Select Dataset"
|
| 1273 |
+
# )
|
| 1274 |
+
|
| 1275 |
+
# variate_dropdown = gr.Dropdown(
|
| 1276 |
+
# choices=["All"],
|
| 1277 |
+
# value="All",
|
| 1278 |
+
# label="Select Variate",
|
| 1279 |
+
# interactive=False
|
| 1280 |
+
# )
|
| 1281 |
+
|
| 1282 |
+
# domains = gr.CheckboxGroup(
|
| 1283 |
+
# choices=ALL_DOMAINS,
|
| 1284 |
+
# value=ALL_DOMAINS, # default all checked
|
| 1285 |
+
# label="Domains"
|
| 1286 |
+
# )
|
| 1287 |
+
|
| 1288 |
+
# freqs = gr.CheckboxGroup(
|
| 1289 |
+
# choices=ALL_FREQS,
|
| 1290 |
+
# value=ALL_FREQS, # default: all selected
|
| 1291 |
+
# label="Frequencies"
|
| 1292 |
+
# )
|
| 1293 |
+
|
| 1294 |
+
# with gr.Column(scale=2):
|
| 1295 |
+
# trend_group = gr.CheckboxGroup(
|
| 1296 |
+
# choices=["trend", "trend_stability", "trend_lumpiness", "trend_hurst", "trend_entropy"],
|
| 1297 |
+
# label="Trend Patterns"
|
| 1298 |
+
# )
|
| 1299 |
+
|
| 1300 |
+
# season_group = gr.CheckboxGroup(
|
| 1301 |
+
# choices=["seasonal_strength", "seasonality_corr", "seasonal_stability",
|
| 1302 |
+
# "seasonal_lumpiness", "seasonal_hurst", "seasonal_entropy"],
|
| 1303 |
+
# label="Seasonality Patterns"
|
| 1304 |
+
# )
|
| 1305 |
+
|
| 1306 |
+
# remainder_group = gr.CheckboxGroup(
|
| 1307 |
+
# choices=["e_acf1", "e_acf10",
|
| 1308 |
+
# "e_entropy", "e_hurst", "e_lumpiness", "e_outlier_ratio"],
|
| 1309 |
+
# label="Remainder Patterns"
|
| 1310 |
+
# )
|
| 1311 |
+
|
| 1312 |
+
# global_group = gr.CheckboxGroup(
|
| 1313 |
+
# choices=["x_acf1", "x_acf10", "lumpiness", "stability", "hurst", "entropy"],
|
| 1314 |
+
# label="Global Patterns"
|
| 1315 |
+
# )
|
| 1316 |
+
|
| 1317 |
+
# msg_box = gr.Textbox(
|
| 1318 |
+
# label="Message",
|
| 1319 |
+
# interactive=False
|
| 1320 |
+
# )
|
| 1321 |
+
|
| 1322 |
+
# archive_leaderboard = gr.DataFrame(
|
| 1323 |
+
# elem_classes="custom-table",
|
| 1324 |
+
# elem_id="archive-table",
|
| 1325 |
+
# max_height=600,
|
| 1326 |
+
# interactive=False
|
| 1327 |
+
# )
|
| 1328 |
+
|
| 1329 |
+
# # Bind events
|
| 1330 |
+
# domains.change(
|
| 1331 |
+
# fn=update_dataset_choices,
|
| 1332 |
+
# inputs=[domains, freqs],
|
| 1333 |
+
# outputs=dataset_dropdown
|
| 1334 |
+
# )
|
| 1335 |
+
|
| 1336 |
+
# freqs.change(
|
| 1337 |
+
# fn=update_dataset_choices,
|
| 1338 |
+
# inputs=[domains, freqs],
|
| 1339 |
+
# outputs=dataset_dropdown
|
| 1340 |
+
# )
|
| 1341 |
+
|
| 1342 |
+
# # Change DF
|
| 1343 |
+
# for comp in [dataset_dropdown, trend_group, season_group, remainder_group, global_group]:
|
| 1344 |
+
# comp.change(
|
| 1345 |
+
# fn=update_variate_choices_groups,
|
| 1346 |
+
# inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group],
|
| 1347 |
+
# outputs=variate_dropdown
|
| 1348 |
+
# )
|
| 1349 |
+
# comp.change(
|
| 1350 |
+
# fn=collect_patterns,
|
| 1351 |
+
# inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group,
|
| 1352 |
+
# variate_dropdown, domains, freqs],
|
| 1353 |
+
# outputs=[msg_box, archive_leaderboard]
|
| 1354 |
+
# )
|
| 1355 |
+
|
| 1356 |
+
# for comp in [variate_dropdown, domains, freqs]:
|
| 1357 |
+
# comp.change(
|
| 1358 |
+
# fn=collect_patterns,
|
| 1359 |
+
# inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group,
|
| 1360 |
+
# variate_dropdown, domains, freqs],
|
| 1361 |
+
# outputs=[msg_box, archive_leaderboard]
|
| 1362 |
+
# )
|
| 1363 |
+
|
| 1364 |
+
# # Initial Load
|
| 1365 |
+
# demo.load(
|
| 1366 |
+
# fn=collect_patterns,
|
| 1367 |
+
# inputs=[dataset_dropdown, trend_group, season_group, remainder_group, global_group,
|
| 1368 |
+
# variate_dropdown, domains, freqs],
|
| 1369 |
+
# outputs=[msg_box, archive_leaderboard]
|
| 1370 |
+
# )
|
src/utils.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import json
|
| 5 |
+
from typing import List, Tuple, Optional
|
| 6 |
+
import yaml
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from scipy import stats
|
| 9 |
+
|
| 10 |
+
from timebench.evaluation.data import Dataset, get_dataset_settings, load_dataset_config
|
| 11 |
+
from src.hf_config import get_datasets_root, get_config_root
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_time_results(root_dir, model_name, dataset_with_freq, horizon):
    """
    Load TIME results from NPZ files for a specific model, dataset, and horizon.

    Args:
        root_dir: Root directory containing TIME results (e.g., "output/results")
        model_name: Model name (e.g., "moirai_small")
        dataset_with_freq: Dataset and freq combined (e.g., "Water_Quality_Darwin/15T")
        horizon: Horizon name (e.g., "short", "medium", "long")

    Returns:
        tuple: (metrics_dict, predictions_dict, config_dict) or (None, None, None) if not found.
            metrics_dict / predictions_dict map array names to fully-loaded
            numpy arrays; config_dict is {} when config.json is absent.
    """
    horizon_dir = os.path.join(root_dir, model_name, dataset_with_freq, horizon)
    metrics_path = os.path.join(horizon_dir, "metrics.npz")
    predictions_path = os.path.join(horizon_dir, "predictions.npz")
    config_path = os.path.join(horizon_dir, "config.json")

    # Both metrics and predictions are required; config.json is optional.
    if not os.path.exists(metrics_path) or not os.path.exists(predictions_path):
        return None, None, None

    # np.load on an .npz returns a lazy NpzFile that keeps the zip handle
    # open; use context managers so the handles are closed once all arrays
    # have been materialized into plain dicts (fixes a file-handle leak).
    with np.load(metrics_path) as metrics:
        metrics_dict = {k: metrics[k] for k in metrics.files}
    with np.load(predictions_path) as predictions:
        predictions_dict = {k: predictions[k] for k in predictions.files}

    config_dict = {}
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            config_dict = json.load(f)

    return metrics_dict, predictions_dict, config_dict
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_all_datasets_results(root_dir="output/results"):
    """
    Load dataset-level leaderboard by reading TIME NPZ files and aggregating.

    Args:
        root_dir (str): Path to the TIME results root directory (e.g., "output/results").

    Returns:
        pd.DataFrame: DataFrame containing dataset-level results with columns
            ["model", "dataset", "freq", "dataset_id", "horizon", "MASE", "CRPS", "MAE", "MSE"].
            - dataset: Original dataset name (e.g., "Traffic")
            - freq: Frequency string (e.g., "15T", "1H")
            - dataset_id: Unique identifier as "dataset/freq" (e.g., "Traffic/15T")
            Number of Rows: num_model x num_dataset_freq_combinations x num_horizons
    """
    result_columns = ["model", "dataset", "freq", "dataset_id", "horizon", "MASE", "CRPS", "MAE", "MSE"]

    if not os.path.exists(root_dir):
        print(f"Error: root_dir={root_dir} does not exist")
        return pd.DataFrame(columns=result_columns)

    records = []
    for model_name in os.listdir(root_dir):
        model_path = os.path.join(root_dir, model_name)
        if not os.path.isdir(model_path):
            continue

        for dataset_name in os.listdir(model_path):
            dataset_path = os.path.join(model_path, dataset_name)
            if not os.path.isdir(dataset_path):
                continue

            # Nested layout on disk: model/dataset/freq/horizon/
            for freq_name in os.listdir(dataset_path):
                if not os.path.isdir(os.path.join(dataset_path, freq_name)):
                    continue

                dataset_id = f"{dataset_name}/{freq_name}"
                for horizon_name in ("short", "medium", "long"):
                    metrics_dict, _, _ = load_time_results(
                        root_dir, model_name, dataset_id, horizon_name
                    )
                    if metrics_dict is None:
                        # Missing NPZ files for this (model, dataset, horizon).
                        continue

                    record = {
                        "model": model_name,
                        "dataset": dataset_name,
                        "freq": freq_name,
                        "dataset_id": dataset_id,  # Unique identifier: dataset/freq
                        "horizon": horizon_name,
                    }
                    # Collapse each window/variate-level metric array into one
                    # scalar, ignoring NaNs.
                    for metric_name in ("MASE", "CRPS", "MAE", "MSE"):
                        record[metric_name] = np.nanmean(
                            metrics_dict.get(metric_name, np.array([]))
                        )
                    records.append(record)

    if not records:
        return pd.DataFrame(columns=result_columns)
    return pd.DataFrame(records)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_dataset_display_map(datasets_df: pd.DataFrame) -> Tuple[dict, dict]:
    """
    Build bidirectional display-name mappings for datasets.

    A dataset appearing with a single freq is shown by its bare name
    (e.g., "Australia_Solar"); one appearing with several freqs keeps the
    disambiguating "dataset/freq" form (e.g., "Traffic/15T").

    Args:
        datasets_df: DataFrame with 'dataset', 'freq', 'dataset_id' columns.

    Returns:
        Tuple of:
        - id_to_display: dict mapping dataset_id -> display_name
        - display_to_id: dict mapping display_name -> dataset_id
    """
    if datasets_df.empty:
        return {}, {}

    # How many distinct freqs each dataset name appears with.
    freqs_per_dataset = datasets_df.groupby('dataset')['freq'].nunique()

    id_to_display: dict = {}
    display_to_id: dict = {}

    unique_configs = datasets_df[['dataset', 'freq', 'dataset_id']].drop_duplicates()
    for cfg in unique_configs.itertuples(index=False):
        # Ambiguous datasets keep "dataset/freq"; unambiguous ones show the bare name.
        if freqs_per_dataset[cfg.dataset] > 1:
            label = cfg.dataset_id
        else:
            label = cfg.dataset
        id_to_display[cfg.dataset_id] = label
        display_to_id[label] = cfg.dataset_id

    return id_to_display, display_to_id
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def get_all_variates_results(root_dir: str = "output/results") -> pd.DataFrame:
    """
    Collect all variate-individual-level results from TIME NPZ files.

    Each (series, variate) combination is treated as an independent variate individual.
    Metrics are aggregated only across windows (not across series).
    Uses actual series_names and variate_names from Dataset objects.

    Args:
        root_dir (str): Path to the TIME results root directory (e.g., "output/results").

    Returns:
        pd.DataFrame: DataFrame with columns:
            ["dataset_id", "series_name", "variate_name", "is_uts", "model", "horizon", "MASE", "CRPS", "MAE", "MSE"]
        Number of Rows: num_models x num_datasets x num_horizons x num_series x num_variates
    """
    rows = []

    if not os.path.exists(root_dir):
        print(f"[get_all_variates_results] root_dir={root_dir} does not exist")
        return pd.DataFrame(columns=["dataset_id", "series_name", "variate_name", "is_uts", "model", "horizon", "MASE", "CRPS", "MAE", "MSE"])

    # Cache for dataset info (series_names, variate_names) to avoid repeated loading.
    # Keyed by dataset_id ("dataset/freq"); shared across models/horizons since the
    # underlying Dataset is the same regardless of which model produced results.
    dataset_info_cache = {}

    for model in os.listdir(root_dir):
        model_dir = os.path.join(root_dir, model)
        if not os.path.isdir(model_dir):
            continue

        for dataset in os.listdir(model_dir):
            dataset_dir = os.path.join(model_dir, dataset)
            if not os.path.isdir(dataset_dir):
                continue

            # Nested structure: model/dataset/freq/horizon/
            for freq_dir in os.listdir(dataset_dir):
                freq_path = os.path.join(dataset_dir, freq_dir)
                if not os.path.isdir(freq_path):
                    continue

                dataset_id = f"{dataset}/{freq_dir}"

                # Get series_names and variate_names (use cache)
                if dataset_id not in dataset_info_cache:
                    # Defaults survive if the try-block fails: names stay None and
                    # the fallback "item_{i}" / str(index) labels are used below.
                    series_names = None
                    variate_names = None
                    is_uts = False
                    try:
                        hf_dataset_root = str(get_datasets_root())
                        if os.path.exists(hf_dataset_root):
                            config_root = get_config_root()
                            config_path = config_root / "datasets.yaml"
                            config = load_dataset_config(config_path) if config_path.exists() else {}
                            # "short" is used only to instantiate the Dataset; the
                            # series/variate names do not depend on the horizon.
                            settings = get_dataset_settings(dataset_id, "short", config)

                            dataset_obj = Dataset(
                                name=dataset_id,
                                term="short",
                                prediction_length=settings.get("prediction_length"),
                                test_length=settings.get("test_length"),
                                storage_path=hf_dataset_root,
                            )

                            # Get series names
                            if "item_id" in dataset_obj.hf_dataset.column_names:
                                series_names = list(dataset_obj.hf_dataset["item_id"])
                            else:
                                series_names = [f"item_{i}" for i in range(len(dataset_obj.hf_dataset))]

                            # Get variate names
                            variate_names = dataset_obj.get_variate_names()
                            if variate_names is None:
                                # UTS mode: variate_names = series_names, and is_uts = True
                                is_uts = True
                                variate_names = series_names
                            else:
                                variate_names = list(variate_names)
                    except Exception as e:
                        # Best-effort: log and fall back to index-based labels.
                        print(f"[get_all_variates_results] Error loading Dataset info for {dataset_id}: {e}")

                    dataset_info_cache[dataset_id] = {
                        "series_names": series_names,
                        "variate_names": variate_names,
                        "is_uts": is_uts,
                    }

                info = dataset_info_cache[dataset_id]
                series_names = info["series_names"]
                variate_names = info["variate_names"]
                is_uts = info["is_uts"]

                for horizon in ["short", "medium", "long"]:
                    metrics_dict, _, _ = load_time_results(root_dir, model, dataset_id, horizon)
                    if metrics_dict is None:
                        continue

                    # Get metrics arrays: shape = (num_series, num_windows, num_variates)
                    # NOTE(review): shape assumption comes from the unpack below — TODO
                    # confirm all NPZ producers write 3-D arrays in this layout.
                    mase_arr = metrics_dict.get("MASE", np.array([]))
                    crps_arr = metrics_dict.get("CRPS", np.array([]))
                    mae_arr = metrics_dict.get("MAE", np.array([]))
                    mse_arr = metrics_dict.get("MSE", np.array([]))

                    if mase_arr.size == 0:
                        continue

                    num_series, num_windows, num_variates = mase_arr.shape

                    # Iterate over each (series, variate) combination
                    for series_idx in range(num_series):
                        series_name = series_names[series_idx] if series_names and series_idx < len(series_names) else f"item_{series_idx}"

                        for variate_idx in range(num_variates):
                            # For UTS: variate_name = series_name (since each series is its own variate)
                            if is_uts:
                                variate_name = series_name
                            else:
                                variate_name = variate_names[variate_idx] if variate_names and variate_idx < len(variate_names) else str(variate_idx)

                            # Aggregate only across windows
                            mase = np.nanmean(mase_arr[series_idx, :, variate_idx])
                            crps = np.nanmean(crps_arr[series_idx, :, variate_idx])
                            mae = np.nanmean(mae_arr[series_idx, :, variate_idx])
                            mse = np.nanmean(mse_arr[series_idx, :, variate_idx])

                            # Skip if all values are NaN
                            if np.isnan(mase) and np.isnan(crps):
                                continue

                            rows.append({
                                "dataset_id": dataset_id,
                                "series_name": series_name,
                                "variate_name": variate_name,
                                "is_uts": is_uts,
                                "model": model,
                                "horizon": horizon,
                                "MASE": mase,
                                "CRPS": crps,
                                "MAE": mae,
                                "MSE": mse,
                            })

    if rows:
        return pd.DataFrame(rows)
    else:
        # No results found anywhere: return an empty frame with the full schema.
        return pd.DataFrame(columns=["dataset_id", "series_name", "variate_name", "is_uts", "model", "horizon", "MASE", "CRPS", "MAE", "MSE"])
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_all_domains_and_freq(conf_dir="conf/data", datasets=None):
    """
    Scan per-dataset YAML metadata files and collect all unique domains and freqs.

    Args:
        conf_dir (str): Directory containing "<dataset>.yaml" metadata files.
        datasets (Iterable[str] | None): Dataset names to look up. ``None`` is
            treated as an empty collection (previously this raised TypeError).

    Returns:
        Tuple[List[str], List[str]]: Sorted unique domain names and sorted
        unique freq strings found across the given datasets.
    """
    domains, freqs = set(), set()

    for ds in (datasets or []):
        yaml_path = os.path.join(conf_dir, f"{ds}.yaml")
        if not os.path.exists(yaml_path):
            continue
        with open(yaml_path, "r") as f:
            # safe_load returns None for an empty file; guard so .get works.
            meta = yaml.safe_load(f) or {}
        domain = meta.get("domain")
        freq = meta.get("freq")
        if domain:
            domains.add(domain)
        if freq:
            freqs.add(freq)
    return sorted(domains), sorted(freqs)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def get_dataset_choices(results_root="output/results") -> Tuple[List[str], dict, dict]:
    """
    Get list of available datasets from TIME results with smart display names.

    Datasets that occur with a single freq are displayed by their bare name
    (e.g., "Australia_Solar"); datasets with several freqs are displayed in
    the disambiguating "dataset/freq" form (e.g., "Traffic/15T").

    Args:
        results_root: Path to the TIME results root directory.

    Returns:
        Tuple of:
        - display_names: Sorted list of display names for UI dropdown
        - display_to_id: dict mapping display_name -> dataset_id
        - id_to_display: dict mapping dataset_id -> display_name
    """
    if not os.path.exists(results_root):
        return [], {}, {}

    horizons = ["short", "medium", "long"]
    dataset_freq_pairs = set()  # unique (dataset, freq) tuples across all models

    for model in os.listdir(results_root):
        model_dir = os.path.join(results_root, model)
        if not os.path.isdir(model_dir):
            continue

        for dataset in os.listdir(model_dir):
            dataset_dir = os.path.join(model_dir, dataset)
            if not os.path.isdir(dataset_dir):
                continue

            # Legacy layout puts horizon dirs directly under the dataset dir;
            # the current layout nests them under a freq dir.
            if any(os.path.isdir(os.path.join(dataset_dir, h)) for h in horizons):
                # Legacy: record with an empty freq marker when any horizon
                # actually contains a config.json.
                if any(os.path.exists(os.path.join(dataset_dir, h, "config.json")) for h in horizons):
                    dataset_freq_pairs.add((dataset, ""))
            else:
                # Nested structure: model/dataset/freq/horizon/
                for freq_name in os.listdir(dataset_dir):
                    freq_path = os.path.join(dataset_dir, freq_name)
                    if not os.path.isdir(freq_path):
                        continue
                    if any(os.path.exists(os.path.join(freq_path, h, "config.json")) for h in horizons):
                        dataset_freq_pairs.add((dataset, freq_name))

    if not dataset_freq_pairs:
        return [], {}, {}

    from collections import Counter
    freqs_per_dataset = Counter(ds for ds, _ in dataset_freq_pairs)

    id_to_display = {}
    display_to_id = {}

    for dataset, freq in dataset_freq_pairs:
        dataset_id = f"{dataset}/{freq}" if freq else dataset
        # Only ambiguous datasets keep the "dataset/freq" display form.
        display_name = dataset_id if freqs_per_dataset[dataset] > 1 else dataset
        id_to_display[dataset_id] = display_name
        display_to_id[display_name] = dataset_id

    return sorted(display_to_id.keys()), display_to_id, id_to_display
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def compute_ranks(df: pd.DataFrame, groupby_cols: str | List[str]) -> pd.DataFrame:
    """
    Add per-group MASE/CRPS rank columns to a results dataframe.

    Ranks are computed within each group defined by ``groupby_cols`` with
    ``method="first"`` (ties broken by row order) and ``ascending=True``,
    so the model with the lowest (best) metric gets rank 1.

    Args:
        df (pd.DataFrame): Results with at least ["MASE", "CRPS"] plus the
            grouping columns.
        groupby_cols (str | List[str]): Column name(s) to group by,
            e.g. "dataset" or ["dataset_id", "horizon"].

    Returns:
        pd.DataFrame: Copy of ``df`` with "MASE_rank" and "CRPS_rank" columns
            added. For empty input, an empty frame with columns
            ["model", "MASE_rank", "CRPS_rank"] is returned (legacy shape).
    """
    if isinstance(groupby_cols, str):
        groupby_cols = [groupby_cols]

    if df.empty:
        # NOTE: the empty-input column set intentionally differs from the
        # non-empty return shape; kept for backward compatibility.
        return pd.DataFrame(columns=["model", "MASE_rank", "CRPS_rank"])

    df = df.copy()

    # Lower metric values are better, hence ascending ranks.
    df["MASE_rank"] = df.groupby(groupby_cols)["MASE"].rank(method="first", ascending=True)
    df["CRPS_rank"] = df.groupby(groupby_cols)["CRPS"].rank(method="first", ascending=True)

    return df
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def normalize_by_seasonal_naive(
    df: pd.DataFrame,
    baseline_model: str = "seasonal_naive",
    metrics: List[str] = None,
    groupby_cols: List[str] = None,
) -> pd.DataFrame:
    """
    Normalize metrics by a baseline model within each (dataset_id, horizon) group.

    Every model's metric values are divided by the baseline's values for the
    same group, making the baseline equal to 1.0 everywhere.

    Args:
        df (pd.DataFrame): Dataset-level results including
            ["model", "dataset_id", "horizon", "MASE", "CRPS", ...].
        baseline_model (str): Name of the baseline model. Defaults to "seasonal_naive".
        metrics (List[str]): Metric columns to normalize. Defaults to ["MASE", "CRPS"].
        groupby_cols (List[str]): Grouping columns. Defaults to ["dataset_id", "horizon"].

    Returns:
        pd.DataFrame: Normalized results. Rows belonging to groups without a
        baseline entry are dropped; divisions by zero/NaN baselines yield NaN,
        and any remaining +/-inf values are replaced with NaN.
    """
    metrics = ["MASE", "CRPS"] if metrics is None else metrics
    groupby_cols = ["dataset_id", "horizon"] if groupby_cols is None else groupby_cols

    if df.empty:
        return df.copy()

    if baseline_model not in df["model"].values:
        print(f"[normalize_by_seasonal_naive] Warning: baseline model '{baseline_model}' not found in data")
        return df.copy()

    result = df.copy()

    # Map each group key -> baseline metric values. If the baseline appears
    # more than once per group, the last row wins (dict overwrite).
    baseline_lookup = {}
    for _, base_row in df[df["model"] == baseline_model].iterrows():
        group_key = tuple(base_row[col] for col in groupby_cols)
        baseline_lookup[group_key] = {m: base_row[m] for m in metrics}

    kept_indices = []
    for idx, row in result.iterrows():
        group_key = tuple(row[col] for col in groupby_cols)

        # Drop configurations for which no baseline result exists.
        if group_key not in baseline_lookup:
            continue
        kept_indices.append(idx)

        for m in metrics:
            base_val = baseline_lookup[group_key][m]
            usable = base_val is not None and base_val != 0 and not np.isnan(base_val)
            result.at[idx, m] = (row[m] / base_val) if usable else np.nan

    result = result.loc[kept_indices].copy()

    # Guard against any inf produced by the division.
    for m in metrics:
        result[m] = result[m].replace([np.inf, -np.inf], np.nan)

    return result
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def load_features(root_dir: str = "features", category: str = "public-benchmarks", split: str = "test") -> pd.DataFrame:
    """
    Load time series features for all datasets (legacy function).

    Args:
        root_dir (str): Path to features root directory.
        category (str): Dataset category (e.g., "public-benchmarks").
        split (str): Which split to load ("full" or "test").

    Returns:
        pd.DataFrame: Concatenated DataFrame with a leading "dataset" column
        and "unique_id" renamed to "variate_name". Empty DataFrame when the
        category directory is missing or contains no CSVs.
    """
    base_dir = os.path.join(root_dir, category)
    if not os.path.isdir(base_dir):
        # Previously os.listdir raised FileNotFoundError on a missing directory.
        print(f"[load_features] {base_dir} does not exist")
        return pd.DataFrame()

    all_data = []
    for dataset in os.listdir(base_dir):
        csv_path = os.path.join(base_dir, dataset, f"{split}.csv")
        if not os.path.exists(csv_path):
            continue
        df = pd.read_csv(csv_path)
        df["dataset"] = dataset  # tag each row with its source dataset
        # Move the dataset column to the front.
        df = df[["dataset"] + [c for c in df.columns if c != "dataset"]]
        all_data.append(df)

    if not all_data:
        return pd.DataFrame()

    df = pd.concat(all_data, ignore_index=True)
    if "unique_id" in df.columns:
        df = df.rename(columns={"unique_id": "variate_name"})
    return df
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def load_all_features(features_root: str = "output/features", split: str = "test") -> pd.DataFrame:
|
| 557 |
+
"""
|
| 558 |
+
Load time series features for all datasets from output/features directory.
|
| 559 |
+
|
| 560 |
+
Expected structure: features_root/{dataset}/{freq}/{split}.csv
|
| 561 |
+
Each CSV should have columns: dataset_id, series_name, variate_name, ...features...
|
| 562 |
+
|
| 563 |
+
Args:
|
| 564 |
+
features_root (str): Path to features root directory (e.g., "output/features").
|
| 565 |
+
split (str): Which split to load ("full" or "test").
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
pd.DataFrame: Concatenated DataFrame with all variate features.
|
| 569 |
+
Columns: ["dataset_id", "series_name", "variate_name", "unique_id",
|
| 570 |
+
"is_random_walk", "has_spike_presence", "trend_strength", ...]
|
| 571 |
+
"""
|
| 572 |
+
all_data = []
|
| 573 |
+
|
| 574 |
+
if not os.path.exists(features_root):
|
| 575 |
+
print(f"[load_all_features] features_root={features_root} does not exist")
|
| 576 |
+
return pd.DataFrame()
|
| 577 |
+
|
| 578 |
+
for dataset in os.listdir(features_root):
|
| 579 |
+
dataset_dir = os.path.join(features_root, dataset)
|
| 580 |
+
if not os.path.isdir(dataset_dir):
|
| 581 |
+
continue
|
| 582 |
+
|
| 583 |
+
for freq in os.listdir(dataset_dir):
|
| 584 |
+
freq_dir = os.path.join(dataset_dir, freq)
|
| 585 |
+
if not os.path.isdir(freq_dir):
|
| 586 |
+
continue
|
| 587 |
+
|
| 588 |
+
csv_path = os.path.join(freq_dir, f"{split}.csv")
|
| 589 |
+
if not os.path.exists(csv_path):
|
| 590 |
+
# Fallback: try full.csv if test.csv doesn't exist
|
| 591 |
+
csv_path = os.path.join(freq_dir, "full.csv")
|
| 592 |
+
if os.path.exists(csv_path):
|
| 593 |
+
try:
|
| 594 |
+
df = pd.read_csv(csv_path)
|
| 595 |
+
all_data.append(df)
|
| 596 |
+
except Exception as e:
|
| 597 |
+
print(f"[load_all_features] Error loading {csv_path}: {e}")
|
| 598 |
+
|
| 599 |
+
if all_data:
|
| 600 |
+
features_df = pd.concat(all_data, ignore_index=True)
|
| 601 |
+
print(f"[load_all_features] Loaded {len(features_df)} variate features from {len(all_data)} datasets")
|
| 602 |
+
return features_df
|
| 603 |
+
else:
|
| 604 |
+
print(f"[load_all_features] No features found in {features_root}")
|
| 605 |
+
return pd.DataFrame()
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def binarize_features(df: pd.DataFrame, exclude: list) -> pd.DataFrame:
    """
    Binarize features in df based on their median values.

    Each feature column is mapped to 1 where the value is strictly greater
    than that column's median, else 0. Columns in ``exclude`` are skipped
    and left untouched.

    Args:
        df (pd.DataFrame): Input dataframe with feature values.
        exclude (list): Columns to exclude from binarization.

    Returns:
        pd.DataFrame: A copy of ``df`` where the selected feature columns
            are binarized (0/1).
    """
    # Select target feature columns.
    feature_cols = [col for col in df.columns if col not in exclude]

    # Copy to avoid modifying the caller's dataframe.
    df_binarized = df.copy()

    # Per-column median thresholds (NaNs are ignored by median).
    medians = df[feature_cols].median()

    # Strictly-greater-than comparison: values equal to the median map to 0.
    for col in feature_cols:
        threshold = medians[col]
        df_binarized[col] = (df[col] > threshold).astype(int)

    return df_binarized
|