xianqiu committed on
Commit
5145926
·
0 Parent(s):

Initial deployment: Kronos BTC Forecast API (xianqiu/qlang)

Files changed (11)
  1. .gitattributes +35 -0
  2. DEPLOYMENT.md +217 -0
  3. Dockerfile +40 -0
  4. README.md +33 -0
  5. app.py +776 -0
  6. client.py +410 -0
  7. model/__init__.py +17 -0
  8. model/kronos.py +589 -0
  9. model/module.py +580 -0
  10. models/predictor/README.md +10 -0
  11. models/predictor/config.json +13 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
DEPLOYMENT.md ADDED
@@ -0,0 +1,217 @@
+ # HuggingFace Space Deployment Guide
+
+ This guide describes how to deploy the Kronos BTC prediction API to HuggingFace Spaces.
+
+ ## Prerequisites
+
+ ### 1. Create a HuggingFace account
+
+ If you do not have an account yet, register at https://huggingface.co/join.
+
+ ### 2. Install the HuggingFace CLI
+
+ ```bash
+ pip install huggingface_hub
+ huggingface-cli login
+ ```
+
+ ## Method 1: Deploy via Git (recommended)
+
+ ### 1. Create a new Space
+
+ Visit https://huggingface.co/new-space to create a new Space:
+
+ - **Space name**: `kronos-btc-predictor` (or any name you like)
+ - **License**: MIT
+ - **SDK**: Docker
+ - **Hardware**: CPU basic (free)
+
+ ### 2. Clone the Space repository
+
+ ```bash
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/kronos-btc-predictor
+ cd kronos-btc-predictor
+ ```
+
+ ### 3. Copy files
+
+ ```bash
+ # Copy all files into the Space repository
+ cp -r /path/to/hf_space/* .
+
+ # The file structure should be:
+ # ├── app.py
+ # ├── requirements.txt
+ # ├── README.md
+ # ├── client.py
+ # ├── Dockerfile          # must be created
+ # ├── model/
+ # │   ├── __init__.py
+ # │   ├── kronos.py
+ # │   └── module.py
+ # └── models/
+ #     ├── tokenizer/
+ #     │   ├── config.json
+ #     │   └── model.safetensors
+ #     └── predictor/
+ #         ├── config.json
+ #         └── model.safetensors
+ ```
+
+ ### 4. Create the Dockerfile
+
+ ```dockerfile
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ # Install dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Expose port
+ EXPOSE 7860
+
+ # Start the service
+ CMD ["python", "app.py"]
+ ```
+
+ ### 5. Push to HuggingFace
+
+ ```bash
+ git add .
+ git commit -m "Initial deployment"
+ git push
+ ```
+
+ ### 6. Wait for the build
+
+ The Space builds and deploys automatically. You can follow the build logs on the Space page.
+
+ Once the build finishes, the API is available at:
+ ```
+ https://YOUR_USERNAME-kronos-btc-predictor.hf.space
+ ```
+
+ ## Method 2: Upload via the web interface
+
+ ### 1. Create a Space
+
+ Visit https://huggingface.co/new-space:
+ - SDK: Docker
+ - Hardware: CPU basic
+
+ ### 2. Upload files
+
+ On the Space page, open the "Files" tab, then choose "Add file" -> "Upload files":
+
+ Upload the following files one by one:
+ - `app.py`
+ - `requirements.txt`
+ - `Dockerfile`
+ - `model/__init__.py`
+ - `model/kronos.py`
+ - `model/module.py`
+ - `models/tokenizer/config.json`
+ - `models/tokenizer/model.safetensors`
+ - `models/predictor/config.json`
+ - `models/predictor/model.safetensors`
+
+ ## Verify the deployment
+
+ ### 1. Health check
+
+ ```bash
+ curl https://YOUR_USERNAME-kronos-btc-predictor.hf.space/health
+ ```
+
+ Expected response:
+ ```json
+ {
+   "status": "healthy",
+   "model_loaded": true,
+   "model_version": "iter5 (converged)",
+   "device": "cpu"
+ }
+ ```
+
+ ### 2. API documentation
+
+ Open the Swagger UI:
+ ```
+ https://YOUR_USERNAME-kronos-btc-predictor.hf.space/docs
+ ```
+
+ ### 3. Test a prediction
+
+ ```python
+ from client import KronosClient
+
+ client = KronosClient("https://YOUR_USERNAME-kronos-btc-predictor.hf.space")
+ health = client.health()
+ print(f"Status: {health.status}")
+ ```
+
+ ## Configure a custom domain
+
+ 1. Find "Custom domain" in the Space settings
+ 2. Enter your domain (for example `api.yourdomain.com`)
+ 3. Configure a DNS CNAME record pointing to HuggingFace
+
+ ## Notes
+
+ ### Free tier limits
+
+ - **CPU**: 2 vCPU
+ - **Memory**: 16GB RAM
+ - **Storage**: 50GB
+ - **Requests**: no hard limit, but rate limiting applies
+ - **Cold start**: the Space sleeps when inactive; the first request afterwards takes about 30-60 seconds
+
+ ### Performance optimization
+
+ 1. **Reduce n_paths**: use 10-20 paths instead of 30-100
+ 2. **Reduce pred_len**: use 12-24 instead of 72
+ 3. **Keep warm**: send periodic health-check requests to prevent the Space from sleeping
+
+ ### Security recommendations
+
+ 1. Do not hard-code API keys in the code
+ 2. Store sensitive values in HuggingFace Secrets
+ 3. Consider adding request rate limiting
+
+ ## Upgrading to Pro
+
+ If you need better performance, you can upgrade to HuggingFace Pro:
+
+ - **CPU upgrade**: faster CPU
+ - **GPU**: T4 GPU (paid)
+ - **Never sleeps**: the Space stays running
+
+ See https://huggingface.co/pricing for details.
+
+ ## Troubleshooting
+
+ ### Build failures
+
+ 1. Check version compatibility in `requirements.txt`
+ 2. Make sure all files have been uploaded
+ 3. Look for error messages in the build logs
+
+ ### Model loading failures
+
+ 1. Confirm that the `models/` directory structure is correct
+ 2. Check the `config.json` and `model.safetensors` files
+
+ ### Request timeouts
+
+ 1. Reduce the `n_paths` and `pred_len` parameters
+ 2. Check the size of the input data
+ 3. Consider upgrading to better hardware
+
+ ## Contact support
+
+ If you run into problems, open an Issue in the project repository.
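The cold-start and keep-warm notes in DEPLOYMENT.md can be sketched as a small polling helper. This is an illustration only and not part of the commit: `wait_until_healthy`, `http_fetch`, and their parameters are hypothetical names, and the response shape is assumed to match the `/health` JSON example in the guide, which the deployed Space may not actually serve.

```python
import json
import time
import urllib.request
from typing import Callable


def http_fetch(url: str, timeout: float = 10.0) -> str:
    """Return the body of a GET request; raises OSError subclasses on failure."""
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return resp.read().decode("utf-8")


def wait_until_healthy(
    url: str,
    fetch: Callable[[str], str] = http_fetch,
    attempts: int = 12,
    delay: float = 5.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll `url` until it reports a healthy, loaded model or attempts run out."""
    for attempt in range(attempts):
        try:
            body = json.loads(fetch(url))
            # Assumed response shape: {"status": "healthy", "model_loaded": true, ...}
            if (
                isinstance(body, dict)
                and body.get("status") == "healthy"
                and body.get("model_loaded")
            ):
                return True
        except (OSError, ValueError):
            # Connection errors and invalid JSON are expected while the Space wakes up.
            pass
        if attempt < attempts - 1:
            sleep(delay)
    return False
```

With the defaults this waits for up to about a minute (12 attempts, 5 seconds apart), which roughly covers the 30-60 second cold start mentioned in the guide. The `fetch` and `sleep` parameters are injectable so the loop can be exercised without a network.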
Dockerfile ADDED
@@ -0,0 +1,40 @@
+ # Kronos BTC Prediction API - Docker Image
+ # Optimized for HuggingFace Spaces
+
+ FROM python:3.10-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first for better caching
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create non-root user for security
+ RUN useradd -m -u 1000 user
+ USER user
+
+ # Set environment variables
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONUNBUFFERED=1
+
+ # Expose port (HuggingFace Spaces uses 7860)
+ EXPOSE 7860
+
+ # Health check (raise_for_status makes non-2xx responses count as unhealthy)
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+     CMD python -c "import httpx; httpx.get('http://localhost:7860/health', timeout=5).raise_for_status()" || exit 1
+
+ # Start the application
+ CMD ["python", "app.py"]
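A Docker `HEALTHCHECK` command signals its result through its exit code: 0 for healthy, 1 for unhealthy. As a hedged sketch of a probe that inspects the response body as well as the transport result, the script below maps a status code and body to that contract. `health_exit_code` and `main` are hypothetical names, and the `"status": "healthy"` field is assumed from the example response in DEPLOYMENT.md.

```python
import json
import urllib.request


def health_exit_code(status_code: int, body_text: str) -> int:
    """Map an HTTP status and body to a HEALTHCHECK exit code (0 healthy, 1 unhealthy)."""
    if status_code != 200:
        return 1
    try:
        body = json.loads(body_text)
    except ValueError:
        return 1
    if not isinstance(body, dict):
        return 1
    # Assumed field: the service reports {"status": "healthy"} when ready.
    return 0 if body.get("status") == "healthy" else 1


def main(url: str = "http://localhost:7860/health") -> int:
    """Probe the local service once and return the exit code to report."""
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return health_exit_code(resp.status, resp.read().decode("utf-8"))
    except OSError:
        # Connection refused, timeouts, and HTTP error statuses all land here.
        return 1
```

A probe script would end with `sys.exit(main())`. Keeping the decision in a pure function means the health criteria can be unit-tested without starting the container.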
README.md ADDED
@@ -0,0 +1,33 @@
+ ---
+ title: Kronos BTC Forecast
+ emoji: 📈
+ colorFrom: blue
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 5.9.1
+ python_version: "3.10"
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ # Kronos BTC/USDT Forecast API
+
+ Probabilistic BTC/USDT price forecasting using the [Kronos](https://github.com/shiyu-coder/Kronos) foundation model.
+
+ ## API Usage
+
+ ```python
+ from gradio_client import Client
+
+ client = Client("xianqiu/qlang")
+
+ # Get BTC/USDT 24-hour forecast
+ plot, result = client.predict(api_name="/predict")
+ print(result)
+ ```
+
+ ## Model
+
+ - **Model:** Kronos-mini (4.1M params)
+ - **Paper:** [arXiv:2508.02739](https://arxiv.org/abs/2508.02739)
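The `result` returned by the README example is a plain dictionary. As a minimal sketch of consuming it, the helper below condenses it into a few headline numbers. The field names (`last_close`, `predictions.mean`, `predictions.min`, `predictions.max`, `upside_probability`, `error`) are the ones `app.py` in this commit builds; `summarize_forecast` itself is a hypothetical helper, and the sample values in the usage lines are made up for illustration.

```python
from typing import Any, Dict


def summarize_forecast(result: Dict[str, Any]) -> Dict[str, float]:
    """Condense a forecast result dict into a few headline numbers."""
    if "error" in result:
        # app.py returns {"error": ..., "timestamp": ...} when prediction fails.
        raise ValueError(f"forecast failed: {result['error']}")
    last_close = float(result["last_close"])
    preds = result["predictions"]
    final_mean = float(preds["mean"][-1])
    return {
        # Percentage change from the last observed close to the final mean forecast.
        "expected_return_pct": (final_mean / last_close - 1.0) * 100.0,
        # Lowest and highest values across all sampled paths and all steps.
        "range_low": float(min(preds["min"])),
        "range_high": float(max(preds["max"])),
        "upside_probability_pct": float(result["upside_probability"]),
    }


# Usage with made-up numbers (not real model output):
sample = {
    "last_close": 100.0,
    "upside_probability": 60.0,
    "predictions": {"mean": [101.0, 102.0], "min": [99.0, 98.0], "max": [103.0, 105.0]},
}
summary = summarize_forecast(sample)
```

Raising on the `error` key keeps a failed forecast from being mistaken for an empty one, since the endpoint reports failures inside the JSON payload rather than as an exception.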
app.py ADDED
@@ -0,0 +1,776 @@
+ """
+ Kronos API Server - Hugging Face Space
+
+ Provides API endpoints for BTC/USDT price forecasting using the Kronos model.
+
+ API Usage:
+     from gradio_client import Client
+     client = Client("xianqiu/qlang")
+
+     # Fast API (no plot)
+     result = client.predict(align_to_hour=True, api_name="/predict_api")
+
+     # With plot
+     plot, result = client.predict(align_to_hour=True, api_name="/predict")
+ """
+
+ import os
+ import json
+ import time
+ from datetime import datetime, timezone, timedelta
+
+ import gradio as gr
+ import numpy as np
+ import pandas as pd
+ import torch
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+
+ from model import Kronos, KronosTokenizer, KronosPredictor
+
+ # === Configuration ===
+ CONFIG = {
+     "SYMBOL": "BTCUSDT",
+     "INTERVAL": "1h",
+     "HIST_POINTS": 360,
+     "PRED_HORIZON": 24,
+     "N_PREDICTIONS": 30,
+     "VOL_WINDOW": 24,
+     "TEMPERATURE": 1.0,
+     "TOP_P": 0.95,
+ }
+
+ # Global model instance
+ predictor = None
+
+
+ def load_model():
+     """Load Kronos model and tokenizer."""
+     global predictor
+     if predictor is not None:
+         return predictor
+
+     print("Loading Kronos model...")
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+     tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-2k")
+     model = Kronos.from_pretrained("NeoQuasar/Kronos-mini")
+
+     tokenizer.eval()
+     model.eval()
+
+     predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
+     print(f"Model loaded on {device}")
+
+     return predictor
+
+
+ def fetch_binance_data():
+     """Fetch K-line data using Binance public REST API."""
+     import requests
+
+     symbol = "BTCUSDT"
+     interval = "1h"
+     limit = CONFIG["HIST_POINTS"] + CONFIG["VOL_WINDOW"]
+
+     # Try multiple Binance API endpoints
+     endpoints = [
+         "https://api.binance.com/api/v3/klines",
+         "https://api1.binance.com/api/v3/klines",
+         "https://api2.binance.com/api/v3/klines",
+         "https://api3.binance.com/api/v3/klines",
+         "https://data-api.binance.vision/api/v3/klines",  # Data API endpoint
+     ]
+
+     ohlcv = None
+     last_error = None
+
+     for endpoint in endpoints:
+         try:
+             url = f"{endpoint}?symbol={symbol}&interval={interval}&limit={limit}"
+             response = requests.get(url, timeout=30)
+             response.raise_for_status()
+             ohlcv = response.json()
+             break
+         except Exception as e:
+             last_error = e
+             continue
+
+     if ohlcv is None:
+         # Fallback to ccxt with OKX
+         try:
+             import ccxt
+             exchange = ccxt.okx({'enableRateLimit': True})
+             raw_ohlcv = exchange.fetch_ohlcv("BTC/USDT", "1h", limit=limit)
+             # Convert ccxt format to binance format
+             ohlcv = [[d[0], d[1], d[2], d[3], d[4], d[5], d[0], 0, 0, 0, 0, 0] for d in raw_ohlcv]
+         except Exception as e:
+             raise Exception(f"Failed to fetch data from all sources. Last error: {last_error}, ccxt error: {e}")
+
+     # Parse Binance format: [open_time, open, high, low, close, volume, close_time, quote_volume, ...]
+     df = pd.DataFrame(ohlcv, columns=[
+         'open_time', 'open', 'high', 'low', 'close', 'volume', 'close_time',
+         'quote_asset_volume', 'number_of_trades', 'taker_buy_base_asset_volume',
+         'taker_buy_quote_asset_volume', 'ignore'
+     ])
+
+     df['timestamps'] = pd.to_datetime(df['open_time'], unit='ms')
+     df['amount'] = pd.to_numeric(df['quote_asset_volume'])
+
+     for col in ['open', 'high', 'low', 'close', 'volume']:
+         df[col] = pd.to_numeric(df[col])
+
+     df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]
+
+     return df
+
+
+ def make_prediction(df, pred_model):
+     """Generate probabilistic forecasts."""
+     last_timestamp = df['timestamps'].max()
+     start_new_range = last_timestamp + pd.Timedelta(hours=1)
+     new_timestamps_index = pd.date_range(
+         start=start_new_range,
+         periods=CONFIG["PRED_HORIZON"],
+         freq='h'
+     )
+     y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp')
+     x_timestamp = df['timestamps']
+     x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]
+
+     with torch.no_grad():
+         close_preds, volume_preds = pred_model.predict(
+             df=x_df,
+             x_timestamp=x_timestamp,
+             y_timestamp=y_timestamp,
+             pred_len=CONFIG["PRED_HORIZON"],
+             T=CONFIG["TEMPERATURE"],
+             top_p=CONFIG["TOP_P"],
+             sample_count=CONFIG["N_PREDICTIONS"],
+             verbose=False
+         )
+
+     return close_preds, volume_preds, y_timestamp
+
+
+ def make_prediction_detail(df, pred_model):
+     """Generate probabilistic forecasts with full OHLCV output."""
+     last_timestamp = df['timestamps'].max()
+     start_new_range = last_timestamp + pd.Timedelta(hours=1)
+     new_timestamps_index = pd.date_range(
+         start=start_new_range,
+         periods=CONFIG["PRED_HORIZON"],
+         freq='h'
+     )
+     y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp')
+     x_timestamp = df['timestamps']
+     x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]
+
+     with torch.no_grad():
+         preds_dict = pred_model.predict_detail(
+             df=x_df,
+             x_timestamp=x_timestamp,
+             y_timestamp=y_timestamp,
+             pred_len=CONFIG["PRED_HORIZON"],
+             T=CONFIG["TEMPERATURE"],
+             top_p=CONFIG["TOP_P"],
+             sample_count=CONFIG["N_PREDICTIONS"],
+             verbose=False
+         )
+
+     return preds_dict, y_timestamp
+
+
+ def calculate_metrics(hist_df, close_preds_df):
+     """Calculate upside and volatility metrics."""
+     last_close = hist_df['close'].iloc[-1]
+
+     # Upside Probability
+     final_hour_preds = close_preds_df.iloc[-1]
+     upside_prob = float((final_hour_preds > last_close).mean())
+
+     # Volatility Amplification
+     hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
+     historical_vol = hist_log_returns.iloc[-CONFIG["VOL_WINDOW"]:].std()
+
+     amplification_count = 0
+     for col in close_preds_df.columns:
+         full_sequence = pd.concat([pd.Series([last_close]), close_preds_df[col]]).reset_index(drop=True)
+         pred_log_returns = np.log(full_sequence / full_sequence.shift(1))
+         predicted_vol = pred_log_returns.std()
+         if predicted_vol > historical_vol:
+             amplification_count += 1
+
+     vol_amp_prob = amplification_count / len(close_preds_df.columns)
+
+     return upside_prob, vol_amp_prob
+
+
+ def create_plot(hist_df, close_preds_df, volume_preds_df):
+     """Create forecast visualization."""
+     fig, (ax1, ax2) = plt.subplots(
+         2, 1, figsize=(15, 10), sharex=True,
+         gridspec_kw={'height_ratios': [3, 1]}
+     )
+
+     hist_time = hist_df['timestamps']
+     last_hist_time = hist_time.iloc[-1]
+     pred_time = pd.to_datetime([last_hist_time + timedelta(hours=i + 1) for i in range(len(close_preds_df))])
+
+     ax1.plot(hist_time, hist_df['close'], color='royalblue', label='Historical Price', linewidth=1.5)
+     mean_preds = close_preds_df.mean(axis=1)
+     ax1.plot(pred_time, mean_preds, color='darkorange', linestyle='-', label='Mean Forecast', linewidth=2)
+     ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1),
+                      color='darkorange', alpha=0.2, label='Forecast Range')
+     ax1.set_title(f'{CONFIG["SYMBOL"]} 24-Hour Price Forecast (Kronos)', fontsize=16, weight='bold')
+     ax1.set_ylabel('Price (USDT)')
+     ax1.legend()
+     ax1.grid(True, linestyle='--', alpha=0.7)
+
+     ax2.bar(hist_time, hist_df['volume'], color='skyblue', label='Historical Volume', width=0.03)
+     ax2.bar(pred_time, volume_preds_df.mean(axis=1), color='sandybrown', label='Forecast Volume', width=0.03)
+     ax2.set_ylabel('Volume')
+     ax2.set_xlabel('Time (UTC)')
+     ax2.legend()
+     ax2.grid(True, linestyle='--', alpha=0.7)
+
+     separator_time = hist_time.iloc[-1] + timedelta(minutes=30)
+     for ax in [ax1, ax2]:
+         ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5)
+         ax.tick_params(axis='x', rotation=30)
+
+     fig.tight_layout()
+     return fig
+
+
+ def predict_btc(align_to_hour: bool = True):
+     """
+     Main prediction function with plot (for UI).
+
+     Args:
+         align_to_hour: If True, use data up to the last completed hour (aligned with official demo).
+             If False, use all available data including the current incomplete hour.
+
+     Returns:
+         tuple: (plot_figure, result_dict)
+     """
+     fig, result = _do_prediction(align_to_hour=align_to_hour, include_plot=True)
+     return fig, result
+
+
+ def predict_btc_api(align_to_hour: bool = True):
+     """
+     API-only prediction (no plot, faster response).
+
+     Args:
+         align_to_hour: If True, use data up to the last completed hour (aligned with official demo).
+             If False, use all available data including the current incomplete hour.
+
+     Returns:
+         dict: Prediction result without plot
+     """
+     _, result = _do_prediction(align_to_hour=align_to_hour, include_plot=False)
+     return result
+
+
+ def predict_btc_detail(align_to_hour: bool = True):
+     """
+     Detailed prediction API returning all Monte Carlo sample paths.
+
+     Args:
+         align_to_hour: If True, use data up to the last completed hour (aligned with official demo).
+             If False, use all available data including the current incomplete hour.
+
+     Returns:
+         dict: Prediction result with all Monte Carlo sample paths
+     """
+     _, result = _do_prediction_detail(align_to_hour=align_to_hour)
+     return result
+
+
+ def _do_prediction(align_to_hour: bool = True, include_plot: bool = True):
+     """
+     Internal prediction function.
+
+     Args:
+         align_to_hour: If True, use data up to the last completed hour.
+         include_plot: If True, generate plot (slower). If False, skip plot (faster).
+
+     Returns:
+         tuple: (plot_figure or None, result_dict)
+     """
+     try:
+         sample_count = CONFIG["N_PREDICTIONS"]
+
+         print(f"[Predict] align_to_hour={align_to_hour}, include_plot={include_plot}")
+         start_time = time.time()
+
+         # Load model
+         pred_model = load_model()
+
+         # Fetch data
+         df_full = fetch_binance_data()
+
+         # Choose data based on alignment mode
+         if align_to_hour:
+             # Exclude the last (incomplete) bar - aligned with official demo
+             df_for_model = df_full.iloc[:-1]
+             data_mode = "hourly_aligned"
+         else:
+             # Use all data including current incomplete bar
+             df_for_model = df_full
+             data_mode = "realtime"
+
+         # Make predictions
+         close_preds, volume_preds, pred_timestamps = make_prediction(
+             df_for_model, pred_model
+         )
+
+         # Calculate metrics
+         hist_df_for_metrics = df_for_model.tail(CONFIG["VOL_WINDOW"])
+         upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds)
+
+         # Create plot only if requested
+         fig = None
+         if include_plot:
+             hist_df_for_plot = df_for_model.tail(CONFIG["HIST_POINTS"])
+             fig = create_plot(hist_df_for_plot, close_preds, volume_preds)
+
+         # Prepare result
+         last_close = float(df_for_model['close'].iloc[-1])
+         last_timestamp = df_for_model['timestamps'].iloc[-1]
+         mean_preds = close_preds.mean(axis=1).tolist()
+         min_preds = close_preds.min(axis=1).tolist()
+         max_preds = close_preds.max(axis=1).tolist()
+
+         elapsed = time.time() - start_time
+
+         result = {
+             "timestamp": datetime.now(timezone.utc).isoformat(),
+             "symbol": CONFIG["SYMBOL"],
+             "last_close": last_close,
+             "last_data_timestamp": last_timestamp.isoformat(),
+             "data_mode": data_mode,
+             "upside_probability": round(upside_prob * 100, 1),
+             "volatility_amplification": round(vol_amp_prob * 100, 1),
+             "prediction_horizon_hours": CONFIG["PRED_HORIZON"],
+             "sample_count": sample_count,
+             "inference_time_seconds": round(elapsed, 1),
+             "predictions": {
+                 "timestamps": [t.isoformat() for t in pred_timestamps],
+                 "mean": mean_preds,
+                 "min": min_preds,
+                 "max": max_preds,
+             },
+             "model": {
+                 "name": "Kronos-mini",
+                 "tokenizer": "Kronos-Tokenizer-2k",
+                 "temperature": CONFIG["TEMPERATURE"],
+                 "top_p": CONFIG["TOP_P"],
+             }
+         }
+
+         print(f"[Done] Prediction completed in {elapsed:.1f}s (plot={include_plot})")
+         return fig, result
+
+     except Exception as e:
+         error_result = {
+             "error": str(e),
+             "timestamp": datetime.now(timezone.utc).isoformat()
+         }
+         return None, error_result
+
+
+ def _do_prediction_detail(align_to_hour: bool = True):
+     """
+     Internal prediction function that returns all Monte Carlo sample paths.
+
+     Args:
+         align_to_hour: If True, use data up to the last completed hour.
+
+     Returns:
+         tuple: (None, result_dict with all sample paths)
+     """
+     try:
+         sample_count = CONFIG["N_PREDICTIONS"]
+
+         print(f"[Predict Detail] align_to_hour={align_to_hour}")
+         start_time = time.time()
+
+         # Load model
+         pred_model = load_model()
+
+         # Fetch data
+         df_full = fetch_binance_data()
+
+         # Choose data based on alignment mode
+         if align_to_hour:
+             # Exclude the last (incomplete) bar - aligned with official demo
+             df_for_model = df_full.iloc[:-1]
+             data_mode = "hourly_aligned"
+         else:
+             # Use all data including current incomplete bar
+             df_for_model = df_full
+             data_mode = "realtime"
+
+         # Make predictions with full OHLCV output
+         preds_dict, pred_timestamps = make_prediction_detail(
+             df_for_model, pred_model
+         )
+
+         # Extract close predictions for metrics calculation
+         close_preds = preds_dict['close']
+
+         # Calculate metrics
+         hist_df_for_metrics = df_for_model.tail(CONFIG["VOL_WINDOW"])
+         upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds)
+
+         # Prepare result
+         last_close = float(df_for_model['close'].iloc[-1])
+         last_timestamp = df_for_model['timestamps'].iloc[-1]
+
+         # Summary statistics for close price
+         mean_preds = close_preds.mean(axis=1).tolist()
+         min_preds = close_preds.min(axis=1).tolist()
+         max_preds = close_preds.max(axis=1).tolist()
+
+         # Prepare all sample paths for OHLCV (each column is a sample path)
+         all_samples = {}
+         for price_type in ['open', 'high', 'low', 'close', 'volume']:
+             price_df = preds_dict[price_type]
+             samples = {}
+             for col in price_df.columns:
+                 samples[col] = price_df[col].tolist()
+             all_samples[price_type] = samples
+
+         elapsed = time.time() - start_time
+
+         result = {
+             "timestamp": datetime.now(timezone.utc).isoformat(),
+             "symbol": CONFIG["SYMBOL"],
+             "last_close": last_close,
+             "last_data_timestamp": last_timestamp.isoformat(),
+             "data_mode": data_mode,
+             "upside_probability": round(upside_prob * 100, 1),
+             "volatility_amplification": round(vol_amp_prob * 100, 1),
+             "prediction_horizon_hours": CONFIG["PRED_HORIZON"],
+             "sample_count": sample_count,
+             "inference_time_seconds": round(elapsed, 1),
+             "predictions": {
+                 "timestamps": [t.isoformat() for t in pred_timestamps],
+                 "mean": mean_preds,
+                 "min": min_preds,
+                 "max": max_preds,
+             },
+             "all_samples": all_samples,
+             "model": {
+                 "name": "Kronos-mini",
+                 "tokenizer": "Kronos-Tokenizer-2k",
+                 "temperature": CONFIG["TEMPERATURE"],
+                 "top_p": CONFIG["TOP_P"],
+             }
+         }
+
+         print(f"[Done] Detail prediction completed in {elapsed:.1f}s")
+         return None, result
+
+     except Exception as e:
+         error_result = {
+             "error": str(e),
+             "timestamp": datetime.now(timezone.utc).isoformat()
+         }
+         return None, error_result
+
+
+ def predict_custom(
+     hist_data_json: str,
+     pred_horizon: int = 24,
+     sample_count: int = 30,
+     temperature: float = 1.0,
+     top_p: float = 0.95
+ ):
+     """
+     Custom prediction with user-provided data.
+
+     Args:
+         hist_data_json: JSON string with format:
+             {
+                 "timestamps": ["2024-01-01T00:00:00", ...],
+                 "open": [100.0, ...],
+                 "high": [101.0, ...],
+                 "low": [99.0, ...],
+                 "close": [100.5, ...],
+                 "volume": [1000.0, ...], # optional
+                 "amount": [100000.0, ...] # optional
+             }
+         pred_horizon: Number of future steps to predict, at the input data's frequency (1-48)
+         sample_count: Number of Monte Carlo samples (1-100)
+         temperature: Sampling temperature (0.1-2.0)
+         top_p: Nucleus sampling probability (0.1-1.0)
+
+     Returns:
+         JSON string with predictions
+     """
+     try:
+         pred_model = load_model()
+
+         # Parse input
+         data = json.loads(hist_data_json)
+         df = pd.DataFrame(data)
+         df['timestamps'] = pd.to_datetime(df['timestamps'])
+
+         # Ensure required columns
+         for col in ['open', 'high', 'low', 'close']:
+             if col not in df.columns:
+                 raise ValueError(f"Missing required column: {col}")
+             df[col] = pd.to_numeric(df[col])
+
+         if 'volume' not in df.columns:
+             df['volume'] = 0.0
+         if 'amount' not in df.columns:
+             df['amount'] = 0.0
+
+         # Validate parameters
+         pred_horizon = max(1, min(48, pred_horizon))
+         sample_count = max(1, min(100, sample_count))
+         temperature = max(0.1, min(2.0, temperature))
+         top_p = max(0.1, min(1.0, top_p))
+
+         # Prepare timestamps
+         last_timestamp = df['timestamps'].max()
+         freq = pd.infer_freq(df['timestamps'])
+         if freq is None:
+             freq = 'h'
+
+         y_timestamp = pd.Series(
+             pd.date_range(start=last_timestamp, periods=pred_horizon + 1, freq=freq)[1:]  # first step is one `freq` after the last bar
+         )
+         x_timestamp = df['timestamps']
+         x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]
+
+         # Predict
+         with torch.no_grad():
+             close_preds, volume_preds = pred_model.predict(
+                 df=x_df,
+                 x_timestamp=x_timestamp,
+                 y_timestamp=y_timestamp,
+                 pred_len=pred_horizon,
+                 T=temperature,
+                 top_p=top_p,
+                 sample_count=sample_count,
+                 verbose=False
+             )
+
+         # Calculate metrics
+         last_close = float(df['close'].iloc[-1])
+         final_hour_preds = close_preds.iloc[-1]
+         upside_prob = float((final_hour_preds > last_close).mean())
+
+         result = {
+             "timestamp": datetime.now(timezone.utc).isoformat(),
+             "last_close": last_close,
+             "upside_probability": round(upside_prob * 100, 1),
+             "prediction_horizon": pred_horizon,
+             "sample_count": sample_count,
+             "predictions": {
+                 "timestamps": [t.isoformat() for t in y_timestamp],
+                 "mean": close_preds.mean(axis=1).tolist(),
+                 "min": close_preds.min(axis=1).tolist(),
+                 "max": close_preds.max(axis=1).tolist(),
+                 "volume_mean": volume_preds.mean(axis=1).tolist(),
+             },
+             "parameters": {
+                 "temperature": temperature,
+                 "top_p": top_p,
+             }
+         }
+
+         return json.dumps(result, indent=2)
+
+     except Exception as e:
+         return json.dumps({"error": str(e)}, indent=2)
+
+
+ # === Gradio Interface ===
+ with gr.Blocks(title="Kronos BTC Forecast API") as demo:
+     gr.Markdown("""
+     # Kronos: BTC/USDT Price Forecast API
+
+     This Space provides an API for probabilistic BTC/USDT price forecasting using the
+     [Kronos](https://github.com/shiyu-coder/Kronos) foundation model.
+
+     ## Quick Start (Python)
+
+     ```python
+     from gradio_client import Client
+
+     client = Client("xianqiu/qlang")
+
+     # Fast API call (no plot, recommended)
+     result = client.predict(align_to_hour=True, api_name="/predict_api")
+     print(result)
+
+     # With plot (slower)
+     plot, result = client.predict(align_to_hour=True, api_name="/predict")
+
+     # Detail API - returns all Monte Carlo sample paths with full OHLCV
+     result = client.predict(align_to_hour=True, api_name="/predict_all")
+     print(result["all_samples"]["open"])    # All 30 open price prediction paths
+     print(result["all_samples"]["high"])    # All 30 high price prediction paths
+     print(result["all_samples"]["low"])     # All 30 low price prediction paths
+     print(result["all_samples"]["close"])   # All 30 close price prediction paths
+     print(result["all_samples"]["volume"])  # All 30 volume prediction paths
+     ```
+
+     ## API Endpoints
+
+     - `/predict_api` - **Recommended**: JSON-only response (faster, no plot)
+     - `/predict` - With plot (for visualization)
+     - `/predict_all` - Returns all Monte Carlo sample paths with full OHLCV (for detailed analysis)
+     - `/predict_custom` - Custom OHLCV data prediction
+
+     ## Data Mode
+
+     - **Hourly Aligned (default)**: Uses data up to the last completed hour, matching the official Kronos demo
+     - **Realtime**: Uses all available data including the current incomplete hour
+     """)
+
+     with gr.Tab("BTC/USDT Forecast"):
+         gr.Markdown("""
+         Generate 24-hour BTC/USDT price forecast.
+
+         **Data Mode:**
+         - **Hourly Aligned**: Use data up to last completed hour (matches official demo for comparison)
+         - **Realtime**: Use all available data including current incomplete hour
+         """)
+
+         align_checkbox = gr.Checkbox(
+             label="Align to Hour (match official demo)",
+             value=True,
+             info="If checked, excludes current incomplete hour for consistency with official demo"
+         )
+         predict_btn = gr.Button("Generate Forecast", variant="primary")
+
+         with gr.Row():
+             plot_output = gr.Plot(label="Forecast Chart")
+
+         json_output = gr.JSON(label="Prediction Result")
+
+         # UI button - with plot
+         predict_btn.click(
+             fn=predict_btc,
+             inputs=[align_checkbox],
+             outputs=[plot_output, json_output],
+             api_name="predict"
+         )
+
+     with gr.Tab("API Only (Fast)"):
+         gr.Markdown("""
+         **Fast API endpoint** - Returns JSON only, no plot generation.
+
+         Use this for programmatic access when you don't need the chart.
+         """)
+
+         api_align_checkbox = gr.Checkbox(
+             label="Align to Hour (match official demo)",
+             value=True
+         )
+         api_btn = gr.Button("Get Prediction (API)", variant="primary")
+         api_json_output = gr.JSON(label="Prediction Result")
+
+         api_btn.click(
+             fn=predict_btc_api,
+             inputs=[api_align_checkbox],
+             outputs=[api_json_output],
+             api_name="predict_api"
+         )
+
+     with gr.Tab("Detail API (All Samples)"):
+         gr.Markdown("""
+         **Detail API endpoint** - Returns all Monte Carlo sample paths with full OHLCV data.
+
+         Use this for detailed analysis when you need all individual prediction paths, not just summary statistics (mean/min/max).
694
+
695
+ **Response includes:**
696
+ - `predictions`: Summary statistics for close price (mean, min, max)
697
+ - `all_samples.open`: All open price prediction paths (pred-1, pred-2, ..., pred-N)
698
+ - `all_samples.high`: All high price prediction paths
699
+ - `all_samples.low`: All low price prediction paths
700
+ - `all_samples.close`: All close price prediction paths
701
+ - `all_samples.volume`: All volume prediction paths
702
+ """)
703
+
704
+ detail_align_checkbox = gr.Checkbox(
705
+ label="Align to Hour (match official demo)",
706
+ value=True
707
+ )
708
+ detail_btn = gr.Button("Get Detail Prediction", variant="primary")
709
+ detail_json_output = gr.JSON(label="Detail Prediction Result")
710
+
711
+ detail_btn.click(
712
+ fn=predict_btc_detail,
713
+ inputs=[detail_align_checkbox],
714
+ outputs=[detail_json_output],
715
+ api_name="predict_all"
716
+ )
717
+
718
+ with gr.Tab("Custom Prediction"):
719
+ gr.Markdown("""
720
+ Provide your own OHLCV data for prediction.
721
+
722
+ **Input Format:**
723
+ ```json
724
+ {
725
+ "timestamps": ["2024-01-01T00:00:00", "2024-01-01T01:00:00", ...],
726
+ "open": [100.0, 101.0, ...],
727
+ "high": [101.0, 102.0, ...],
728
+ "low": [99.0, 100.0, ...],
729
+ "close": [100.5, 101.5, ...],
730
+ "volume": [1000.0, 1100.0, ...],
731
+ "amount": [100000.0, 110000.0, ...]
732
+ }
733
+ ```
734
+ """)
735
+
736
+ with gr.Row():
737
+ with gr.Column():
738
+ data_input = gr.Textbox(
739
+ label="Historical Data (JSON)",
740
+ placeholder='{"timestamps": [...], "open": [...], "high": [...], "low": [...], "close": [...]}',
741
+ lines=10
742
+ )
743
+
744
+ with gr.Row():
745
+ horizon_input = gr.Slider(1, 48, value=24, step=1, label="Prediction Horizon (hours)")
746
+ samples_input = gr.Slider(1, 100, value=30, step=1, label="Sample Count")
747
+
748
+ with gr.Row():
749
+ temp_input = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
750
+ topp_input = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
751
+
752
+ custom_btn = gr.Button("Predict", variant="primary")
753
+
754
+ with gr.Column():
755
+ custom_output = gr.JSON(label="Prediction Result")
756
+
757
+ custom_btn.click(
758
+ fn=predict_custom,
759
+ inputs=[data_input, horizon_input, samples_input, temp_input, topp_input],
760
+ outputs=custom_output,
761
+ api_name="predict_custom"
762
+ )
763
+
764
+ gr.Markdown("""
765
+ ---
766
+ **Model:** Kronos-mini (4.1M params) | **Paper:** [arXiv:2508.02739](https://arxiv.org/abs/2508.02739)
767
+ """)
768
+
769
+
770
+ # Pre-load model on startup
771
+ print("Pre-loading model...")
772
+ load_model()
773
+ print("Model ready!")
774
+
775
+ if __name__ == "__main__":
776
+ demo.launch()
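A minimal sketch of preparing a request for the `/predict_custom` endpoint described above, assuming the caller starts from a list of per-candle dicts. The helper names `build_custom_payload` and `call_predict_custom` are hypothetical and not part of this commit. The positional argument order mirrors the `inputs=[data_input, horizon_input, samples_input, temp_input, topp_input]` list wired to `custom_btn.click`, and the shape of the returned value is not shown in this diff, so it is not assumed here.

```python
import json
from datetime import datetime, timedelta

# Column names taken from the "Input Format" block of the Custom Prediction tab.
OHLCV_KEYS = ("open", "high", "low", "close", "volume", "amount")


def build_custom_payload(candles, start, step_hours=1):
    """Turn per-candle dicts into the column-oriented JSON string the textbox expects.

    `candles` is a list of dicts holding every key in OHLCV_KEYS; timestamps are
    generated from `start` at a fixed step (an assumption made for illustration).
    """
    payload = {
        "timestamps": [
            (start + timedelta(hours=i * step_hours)).isoformat()
            for i in range(len(candles))
        ]
    }
    for key in OHLCV_KEYS:
        payload[key] = [float(candle[key]) for candle in candles]
    return json.dumps(payload)


def call_predict_custom(payload_json, space="xianqiu/qlang"):
    """Send the payload to the Space (needs network access and gradio_client)."""
    from gradio_client import Client  # imported lazily so the builder works offline

    client = Client(space)
    # Arguments: data, horizon (hours), sample count, temperature, top-p.
    return client.predict(payload_json, 24, 30, 1.0, 0.95, api_name="/predict_custom")
```

Keeping the payload builder separate from the network call makes the JSON layout easy to check locally before any request is sent.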
client.py ADDED
@@ -0,0 +1,410 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Kronos BTC forecast API test client
4
+
5
+ Run it directly to verify that the HuggingFace Space API is working.
6
+
7
+ Usage:
8
+ # Test the health check
9
+ python client.py health
10
+
11
+ # Test the prediction API
12
+ python client.py predict
13
+
14
+ # Test the trading signal API
15
+ python client.py signal
16
+
17
+ # Run all tests
18
+ python client.py all
19
+
20
+ # Use a custom URL
21
+ python client.py all --url https://your-space.hf.space
22
+ """
23
+
24
+ import argparse
25
+ import json
26
+ import sys
27
+ import time
28
+ from datetime import datetime, timedelta
29
+ from typing import List, Dict, Any, Optional
30
+
31
+ import requests
32
+
33
+
34
+ # ==================== Configuration ====================
35
+
36
+ DEFAULT_API_URL = "https://xianqiu-qlang.hf.space"
37
+
38
+ # Binance API
39
+ BINANCE_API = "https://api.binance.com/api/v3/klines"
40
+
41
+
42
+ # ==================== Helper functions ====================
43
+
44
+ def fetch_btc_data(symbol: str = "BTCUSDT", interval: str = "1h", limit: int = 200) -> List[Dict]:
45
+ """
46
+ Fetch BTC candlestick (K-line) data from Binance
47
+
48
+ Args:
49
+ symbol: Trading pair
50
+ interval: Candlestick interval (1h, 4h, 1d, etc.)
51
+ limit: Number of candles to fetch (max 1000)
52
+
53
+ Returns:
54
+ List of OHLCV records
55
+ """
56
+ print(f"[Binance] Fetching {symbol} {interval} candlestick data (latest {limit} candles)...")
57
+
58
+ params = {
59
+ "symbol": symbol,
60
+ "interval": interval,
61
+ "limit": limit
62
+ }
63
+
64
+ try:
65
+ response = requests.get(BINANCE_API, params=params, timeout=10)
66
+ response.raise_for_status()
67
+ data = response.json()
68
+ except requests.exceptions.RequestException as e:
69
+ print(f"[Error] Cannot connect to the Binance API: {e}")
70
+ print("[Info] Falling back to mock data...")
71
+ return generate_mock_data(limit)
72
+
73
+ ohlcv_list = []
74
+ for item in data:
75
+ ohlcv_list.append({
76
+ "timestamp": datetime.utcfromtimestamp(item[0] / 1000).isoformat(), # Open time is epoch milliseconds; keep it in UTC like the mock data
77
+ "open": float(item[1]),
78
+ "high": float(item[2]),
79
+ "low": float(item[3]),
80
+ "close": float(item[4]),
81
+ "volume": float(item[5]),
82
+ "amount": float(item[7]) # Quote asset volume
83
+ })
84
+
85
+ print(f"[OK] Fetched {len(ohlcv_list)} records")
86
+ print(f" Time range: {ohlcv_list[0]['timestamp']} ~ {ohlcv_list[-1]['timestamp']}")
87
+ print(f" Current price: ${ohlcv_list[-1]['close']:,.2f}")
88
+
89
+ return ohlcv_list
90
+
91
+
92
+ def generate_mock_data(n: int = 200) -> List[Dict]:
93
+ """Generate mock candlestick data (used when the Binance API is unavailable)"""
94
+ import random
95
+
96
+ base_price = 100000.0
97
+ data = []
98
+ current_time = datetime.utcnow() - timedelta(hours=n)
99
+
100
+ for i in range(n):
101
+ change = random.gauss(0, 0.01) # 1% standard deviation
102
+ base_price *= (1 + change)
103
+
104
+ high = base_price * (1 + random.random() * 0.005)
105
+ low = base_price * (1 - random.random() * 0.005)
106
+ close = random.uniform(low, high)
107
+
108
+ data.append({
109
+ "timestamp": current_time.isoformat(),
110
+ "open": round(base_price, 2),
111
+ "high": round(high, 2),
112
+ "low": round(low, 2),
113
+ "close": round(close, 2),
114
+ "volume": round(random.uniform(100, 1000), 2),
115
+ "amount": round(random.uniform(1000000, 10000000), 2)
116
+ })
117
+
118
+ current_time += timedelta(hours=1)
119
+
120
+ return data
121
+
122
+
123
+ def print_json(data: Any, title: str = None):
124
+ """Pretty-print JSON"""
125
+ if title:
126
+ print(f"\n{'='*60}")
127
+ print(f" {title}")
128
+ print(f"{'='*60}")
129
+ print(json.dumps(data, indent=2, ensure_ascii=False))
130
+
131
+
132
+ # ==================== API 测试函数 ====================
133
+
134
+ def test_health(base_url: str) -> bool:
135
+ """Test the health check API"""
136
+ print("\n" + "="*60)
137
+ print(" TEST: /health")
138
+ print("="*60)
139
+
140
+ url = f"{base_url}/health"
141
+ print(f"[Request] GET {url}")
142
+
143
+ try:
144
+ start = time.time()
145
+ response = requests.get(url, timeout=30)
146
+ elapsed = time.time() - start
147
+
148
+ print(f"[Response] Status: {response.status_code} ({elapsed:.2f}s)")
149
+
150
+ if response.status_code == 200:
151
+ data = response.json()
152
+ print(f"\n[Result]")
153
+ print(f" Status: {data.get('status', 'N/A')}")
154
+ print(f" Model Loaded: {data.get('model_loaded', 'N/A')}")
155
+ print(f" Model Version: {data.get('model_version', 'N/A')}")
156
+ print(f" Device: {data.get('device', 'N/A')}")
157
+ print(f" Timestamp: {data.get('timestamp', 'N/A')}")
158
+ return True
159
+ else:
160
+ print(f"[Error] {response.text}")
161
+ return False
162
+
163
+ except requests.exceptions.RequestException as e:
164
+ print(f"[Error] Request failed: {e}")
165
+ return False
166
+
167
+
168
+ def test_predict(base_url: str, data: List[Dict] = None) -> bool:
169
+ """Test the prediction API"""
170
+ print("\n" + "="*60)
171
+ print(" TEST: /predict")
172
+ print("="*60)
173
+
174
+ # Fetch data
175
+ if data is None:
176
+ data = fetch_btc_data(limit=200)
177
+
178
+ url = f"{base_url}/predict"
179
+ payload = {
180
+ "data": data,
181
+ "pred_len": 24,
182
+ "n_paths": 30,
183
+ "temperature": 1.0,
184
+ "top_p": 0.9
185
+ }
186
+
187
+ print(f"\n[Request] POST {url}")
188
+ print(f" Data points: {len(data)}")
189
+ print(f" Forecast horizon: {payload['pred_len']} hours")
190
+ print(f" Monte Carlo: {payload['n_paths']} paths")
191
+
192
+ try:
193
+ start = time.time()
194
+ response = requests.post(url, json=payload, timeout=120)
195
+ elapsed = time.time() - start
196
+
197
+ print(f"\n[Response] Status: {response.status_code} ({elapsed:.2f}s)")
198
+
199
+ if response.status_code == 200:
200
+ result = response.json()
201
+ print(f"\n[Result]")
202
+ print(f" Current price: ${result.get('current_price', 0):,.2f}")
203
+ print(f" Mean forecast: ${result.get('mean_forecast', 0):,.2f}")
204
+ print(f" Forecast range: ${result.get('min_forecast', 0):,.2f} ~ ${result.get('max_forecast', 0):,.2f}")
205
+ print(f" Upside probability: {result.get('upside_probability', 0)*100:.1f}%")
206
+ print(f" Expected return: {result.get('expected_return', 0)*100:.2f}%")
207
+ print(f" Volatility amplification: {result.get('volatility_amplification', 0):.2f}x")
208
+ print(f" Confidence: {result.get('confidence', 0)*100:.1f}%")
209
+ print(f" Forecast points: {len(result.get('forecast_prices', []))}")
210
+
211
+ # Show a subset of the forecast prices
212
+ prices = result.get('forecast_prices', [])
213
+ if prices:
214
+ print(f"\n Forecast price trend (every 6 hours):")
215
+ for i in range(0, len(prices), 6):
216
+ print(f" +{i}h: ${prices[i]:,.2f}")
217
+
218
+ return True
219
+ elif response.status_code == 503:
220
+ print(f"[Warning] Model not loaded; please retry later")
221
+ print(f" Response: {response.text}")
222
+ return False
223
+ else:
224
+ print(f"[Error] {response.text}")
225
+ return False
226
+
227
+ except requests.exceptions.Timeout:
228
+ print(f"[Error] Request timed out (>120s)")
229
+ return False
230
+ except requests.exceptions.RequestException as e:
231
+ print(f"[Error] Request failed: {e}")
232
+ return False
233
+
234
+
235
+ def test_signal(base_url: str, data: List[Dict] = None) -> bool:
236
+ """Test the trading signal API"""
237
+ print("\n" + "="*60)
238
+ print(" TEST: /signal")
239
+ print("="*60)
240
+
241
+ # Fetch data
242
+ if data is None:
243
+ data = fetch_btc_data(limit=200)
244
+
245
+ url = f"{base_url}/signal"
246
+ payload = {
247
+ "data": data,
248
+ "buy_threshold": 0.58,
249
+ "sell_threshold": 0.42,
250
+ "stop_loss": 0.03,
251
+ "take_profit": 0.08,
252
+ "n_paths": 30
253
+ }
254
+
255
+ print(f"\n[Request] POST {url}")
256
+ print(f" Data points: {len(data)}")
257
+ print(f" Buy threshold: {payload['buy_threshold']}")
258
+ print(f" Sell threshold: {payload['sell_threshold']}")
259
+ print(f" Stop-loss ratio: {payload['stop_loss']*100:.1f}%")
260
+ print(f" Take-profit ratio: {payload['take_profit']*100:.1f}%")
261
+
262
+ try:
263
+ start = time.time()
264
+ response = requests.post(url, json=payload, timeout=120)
265
+ elapsed = time.time() - start
266
+
267
+ print(f"\n[Response] Status: {response.status_code} ({elapsed:.2f}s)")
268
+
269
+ if response.status_code == 200:
270
+ result = response.json()
271
+ signal = result.get('signal', 'N/A')
272
+
273
+ # Signal icons
274
+ signal_icons = {
275
+ 'STRONG_BUY': '[++]',
276
+ 'BUY': '[+]',
277
+ 'HOLD': '[=]',
278
+ 'SELL': '[-]',
279
+ 'STRONG_SELL': '[--]'
280
+ }
281
+
282
+ print(f"\n[Result]")
283
+ print(f" Signal: {signal_icons.get(signal, '')} {signal}")
284
+ print(f" Confidence: {result.get('confidence', 0)*100:.1f}%")
285
+ print(f" Current price: ${result.get('current_price', 0):,.2f}")
286
+ print(f" Target price: ${result.get('target_price', 0):,.2f}")
287
+ print(f" Stop-loss price: ${result.get('stop_loss_price', 0):,.2f}")
288
+ print(f" Take-profit price: ${result.get('take_profit_price', 0):,.2f}")
289
+ print(f" Upside probability: {result.get('upside_probability', 0)*100:.1f}%")
290
+ print(f" Expected return: {result.get('expected_return', 0)*100:.2f}%")
291
+ print(f" Suggested position size: {result.get('suggested_position_size', 0)*100:.1f}%")
292
+ print(f" Reason: {result.get('reason', 'N/A')}")
293
+
294
+ return True
295
+ elif response.status_code == 503:
296
+ print(f"[Warning] Model not loaded; please retry later")
297
+ print(f" Response: {response.text}")
298
+ return False
299
+ else:
300
+ print(f"[Error] {response.text}")
301
+ return False
302
+
303
+ except requests.exceptions.Timeout:
304
+ print(f"[Error] Request timed out (>120s)")
305
+ return False
306
+ except requests.exceptions.RequestException as e:
307
+ print(f"[Error] Request failed: {e}")
308
+ return False
309
+
310
+
311
+ def run_all_tests(base_url: str) -> bool:
312
+ """Run all tests"""
313
+ print("\n" + "#"*60)
314
+ print("#")
315
+ print("# Kronos BTC Forecast API Tests")
316
+ print(f"# URL: {base_url}")
317
+ print(f"# Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
318
+ print("#")
319
+ print("#"*60)
320
+
321
+ results = {}
322
+
323
+ # 1. Health check
324
+ results['health'] = test_health(base_url)
325
+
326
+ if not results['health']:
327
+ print("\n[Warning] Health check failed; the API may not be running")
328
+ print(" Please check whether the HuggingFace Space is running")
329
+ return False
330
+
331
+ # 2. Fetch data (fetched once and shared by both tests)
332
+ print("\n" + "-"*60)
333
+ data = fetch_btc_data(limit=200)
334
+
335
+ # 3. Prediction test
336
+ results['predict'] = test_predict(base_url, data)
337
+
338
+ # 4. Signal test
339
+ results['signal'] = test_signal(base_url, data)
340
+
341
+ # Summary
342
+ print("\n" + "="*60)
343
+ print(" Test results summary")
344
+ print("="*60)
345
+
346
+ for test_name, passed in results.items():
347
+ status = "PASS" if passed else "FAIL"
348
+ icon = "[OK]" if passed else "[X]"
349
+ print(f" {icon} {test_name}: {status}")
350
+
351
+ all_passed = all(results.values())
352
+
353
+ print("\n" + "-"*60)
354
+ if all_passed:
355
+ print(" All tests passed!")
356
+ else:
357
+ print(" Some tests failed; please check the API status")
358
+ print("-"*60)
359
+
360
+ return all_passed
361
+
362
+
363
+ # ==================== Main ====================
364
+
365
+ def main():
366
+ parser = argparse.ArgumentParser(
367
+ description="Kronos BTC forecast API test client",
368
+ formatter_class=argparse.RawDescriptionHelpFormatter,
369
+ epilog="""
370
+ Examples:
371
+ python client.py health # Test the health check
372
+ python client.py predict # Test the prediction API
373
+ python client.py signal # Test the trading signal API
374
+ python client.py all # Run all tests
375
+ python client.py all --url http://localhost:7860 # Test a local server
376
+ """
377
+ )
378
+
379
+ parser.add_argument(
380
+ "command",
381
+ choices=["health", "predict", "signal", "all"],
382
+ help="Test command to run"
383
+ )
384
+
385
+ parser.add_argument(
386
+ "--url",
387
+ default=DEFAULT_API_URL,
388
+ help=f"API URL (default: {DEFAULT_API_URL})"
389
+ )
390
+
391
+ args = parser.parse_args()
392
+
393
+ # Run the selected test
394
+ if args.command == "health":
395
+ success = test_health(args.url)
396
+ elif args.command == "predict":
397
+ success = test_predict(args.url)
398
+ elif args.command == "signal":
399
+ success = test_signal(args.url)
400
+ elif args.command == "all":
401
+ success = run_all_tests(args.url)
402
+ else:
403
+ parser.print_help()
404
+ sys.exit(1)
405
+
406
+ sys.exit(0 if success else 1)
407
+
408
+
409
+ if __name__ == "__main__":
410
+ main()
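The "Hourly Aligned" data mode described in app.py uses data only up to the last completed hour, while `fetch_btc_data` above requests the latest candles, which normally include the hour still in progress. The following is a minimal sketch of that filtering step on the client side, assuming each record's `timestamp` is the candle's open time in ISO-8601 form and is interpreted as UTC when it carries no offset. The function name `drop_incomplete_candle` is hypothetical and does not appear in this commit.

```python
from datetime import datetime, timedelta, timezone


def drop_incomplete_candle(candles, now, interval=timedelta(hours=1)):
    """Keep only candles whose full interval has elapsed by `now`.

    `candles` is a list of dicts with an ISO-8601 'timestamp' (open time);
    naive timestamps are treated as UTC, which is an assumption of this sketch.
    """
    completed = []
    for candle in candles:
        opened = datetime.fromisoformat(candle["timestamp"])
        if opened.tzinfo is None:
            opened = opened.replace(tzinfo=timezone.utc)
        # A candle is complete once its close time (open + interval) has passed.
        if opened + interval <= now:
            completed.append(candle)
    return completed
```

Passing `now` in explicitly, rather than reading the clock inside the function, keeps the behavior reproducible when comparing against another forecast made at a known time.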
model/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .kronos import KronosTokenizer, Kronos, KronosPredictor
2
+
3
+ model_dict = {
4
+ 'kronos_tokenizer': KronosTokenizer,
5
+ 'kronos': Kronos,
6
+ 'kronos_predictor': KronosPredictor
7
+ }
8
+
9
+
10
+ def get_model_class(model_name):
11
+ if model_name in model_dict:
12
+ return model_dict[model_name]
13
+ else:
14
+ # Carry the message on the exception so callers and logs see the same detail
15
+ raise NotImplementedError(f"Model {model_name} not found in model_dict")
16
+
17
+
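`model/__init__.py` above is a plain name-to-class registry. The sketch below shows the same pattern in isolation, using stand-in classes so it runs without PyTorch. `KronosStub`, `TokenizerStub`, and `MODEL_REGISTRY` are hypothetical names chosen for illustration, and listing the available keys in the error message is an assumption about what is useful to callers, not something the commit does.

```python
class TokenizerStub:
    """Stand-in for a tokenizer class (illustration only)."""


class KronosStub:
    """Stand-in for a model class (illustration only)."""


# Maps a configuration string to the class that should be instantiated.
MODEL_REGISTRY = {
    "kronos_tokenizer": TokenizerStub,
    "kronos": KronosStub,
}


def get_model_class(model_name):
    """Return the registered class, or raise with the list of known names."""
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        known = ", ".join(sorted(MODEL_REGISTRY))
        raise NotImplementedError(
            f"Model {model_name!r} not found; available: {known}"
        ) from None


model_cls = get_model_class("kronos")
model = model_cls()  # the caller decides which constructor arguments to pass
```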
model/kronos.py ADDED
@@ -0,0 +1,589 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ import sys
6
+
7
+ from tqdm import trange
8
+
9
+ sys.path.append("../")
10
+ from model.module import *
11
+
12
+
13
+ class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
14
+ """
15
+ KronosTokenizer module for tokenizing input data using a hybrid quantization approach.
16
+
17
+ This tokenizer utilizes a combination of encoder and decoder Transformer blocks
18
+ along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data.
19
+
20
+ Args:
21
+ d_in (int): Input dimension.
22
+ d_model (int): Model dimension.
23
+ n_heads (int): Number of attention heads.
24
+ ff_dim (int): Feed-forward dimension.
25
+ n_enc_layers (int): Number of encoder layers.
26
+ n_dec_layers (int): Number of decoder layers.
27
+ ffn_dropout_p (float): Dropout probability for feed-forward networks.
28
+ attn_dropout_p (float): Dropout probability for attention mechanisms.
29
+ resid_dropout_p (float): Dropout probability for residual connections.
30
+ s1_bits (int): Number of bits for the pre token in BSQuantizer.
31
+ s2_bits (int): Number of bits for the post token in BSQuantizer.
32
+ beta (float): Beta parameter for BSQuantizer.
33
+ gamma0 (float): Gamma0 parameter for BSQuantizer.
34
+ gamma (float): Gamma parameter for BSQuantizer.
35
+ zeta (float): Zeta parameter for BSQuantizer.
36
+ group_size (int): Group size parameter for BSQuantizer.
37
+
38
+ """
39
+
40
+ def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
41
+
42
+ super().__init__()
43
+ self.d_in = d_in
44
+ self.d_model = d_model
45
+ self.n_heads = n_heads
46
+ self.ff_dim = ff_dim
47
+ self.enc_layers = n_enc_layers
48
+ self.dec_layers = n_dec_layers
49
+ self.ffn_dropout_p = ffn_dropout_p
50
+ self.attn_dropout_p = attn_dropout_p
51
+ self.resid_dropout_p = resid_dropout_p
52
+
53
+ self.s1_bits = s1_bits
54
+ self.s2_bits = s2_bits
55
+ self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization
56
+ self.embed = nn.Linear(self.d_in, self.d_model)
57
+ self.head = nn.Linear(self.d_model, self.d_in)
58
+
59
+ # Encoder Transformer Blocks
60
+ self.encoder = nn.ModuleList([
61
+ TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
62
+ for _ in range(self.enc_layers - 1)
63
+ ])
64
+ # Decoder Transformer Blocks
65
+ self.decoder = nn.ModuleList([
66
+ TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
67
+ for _ in range(self.dec_layers - 1)
68
+ ])
69
+ self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization
70
+ self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits)
71
+ self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook)
72
+ self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module
73
+
74
+ def forward(self, x):
75
+ """
76
+ Forward pass of the KronosTokenizer.
77
+
78
+ Args:
79
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
80
+
81
+ Returns:
82
+ tuple: A tuple containing:
83
+ - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively,
84
+ both of shape (batch_size, seq_len, d_in).
85
+ - torch.Tensor: bsq_loss - Loss from the BSQuantizer.
86
+ - torch.Tensor: quantized - Quantized representation from BSQuantizer.
87
+ - torch.Tensor: z_indices - Indices from the BSQuantizer.
88
+ """
89
+ z = self.embed(x)
90
+
91
+ for layer in self.encoder:
92
+ z = layer(z)
93
+
94
+ z = self.quant_embed(z) # (B, T, codebook)
95
+
96
+ bsq_loss, quantized, z_indices = self.tokenizer(z)
97
+
98
+ quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits)
99
+ z_pre = self.post_quant_embed_pre(quantized_pre)
100
+
101
+ z = self.post_quant_embed(quantized)
102
+
103
+ # Decoder layers (for pre part - s1 bits)
104
+ for layer in self.decoder:
105
+ z_pre = layer(z_pre)
106
+ z_pre = self.head(z_pre)
107
+
108
+ # Decoder layers (for full codebook)
109
+ for layer in self.decoder:
110
+ z = layer(z)
111
+ z = self.head(z)
112
+
113
+ return (z_pre, z), bsq_loss, quantized, z_indices
114
+
115
+ def indices_to_bits(self, x, half=False):
116
+ """
117
+ Converts indices to bit representations and scales them.
118
+
119
+ Args:
120
+ x (torch.Tensor): Indices tensor.
121
+ half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False.
122
+
123
+ Returns:
124
+ torch.Tensor: Bit representation tensor.
125
+ """
126
+ if half:
127
+ x1 = x[0] # Assuming x is a tuple of indices if half is True
128
+ x2 = x[1]
129
+ mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction
130
+ x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half
131
+ x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half
132
+ x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations
133
+ else:
134
+ mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction
135
+ x = (x.unsqueeze(-1) & mask) != 0 # Extract bits
136
+
137
+ x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1)
138
+ q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor
139
+ x = x * q_scale
140
+ return x
141
+
142
+ def encode(self, x, half=False):
143
+ """
144
+ Encodes the input data into quantized indices.
145
+
146
+ Args:
147
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
148
+ half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False.
149
+
150
+ Returns:
151
+ torch.Tensor: Quantized indices from BSQuantizer.
152
+ """
153
+ z = self.embed(x)
154
+ for layer in self.encoder:
155
+ z = layer(z)
156
+ z = self.quant_embed(z)
157
+
158
+ bsq_loss, quantized, z_indices = self.tokenizer(z, half)
159
+ return z_indices
160
+
161
+ def decode(self, x, half=False):
162
+ """
163
+ Decodes quantized indices back to the input data space.
164
+
165
+ Args:
166
+ x (torch.Tensor): Quantized indices tensor.
167
+ half (bool, optional): Whether the indices were generated with half quantization. Defaults to False.
168
+
169
+ Returns:
170
+ torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in).
171
+ """
172
+ quantized = self.indices_to_bits(x, half)
173
+ z = self.post_quant_embed(quantized)
174
+ for layer in self.decoder:
175
+ z = layer(z)
176
+ z = self.head(z)
177
+ return z
178
+
179
+
180
+ class Kronos(nn.Module, PyTorchModelHubMixin):
181
+ """
182
+ Kronos Model.
183
+
184
+ Args:
185
+ s1_bits (int): Number of bits for pre tokens.
186
+ s2_bits (int): Number of bits for post tokens.
187
+ n_layers (int): Number of Transformer blocks.
188
+ d_model (int): Dimension of the model's embeddings and hidden states.
189
+ n_heads (int): Number of attention heads in the MultiheadAttention layers.
190
+ ff_dim (int): Dimension of the feedforward network in the Transformer blocks.
191
+ ffn_dropout_p (float): Dropout probability for the feedforward network.
192
+ attn_dropout_p (float): Dropout probability for the attention layers.
193
+ resid_dropout_p (float): Dropout probability for residual connections.
194
+ token_dropout_p (float): Dropout probability for token embeddings.
195
+ learn_te (bool): Whether to use learnable temporal embeddings.
196
+ """
197
+
198
+ def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te):
199
+ super().__init__()
200
+ self.s1_bits = s1_bits
201
+ self.s2_bits = s2_bits
202
+ self.n_layers = n_layers
203
+ self.d_model = d_model
204
+ self.n_heads = n_heads
205
+ self.learn_te = learn_te
206
+ self.ff_dim = ff_dim
207
+ self.ffn_dropout_p = ffn_dropout_p
208
+ self.attn_dropout_p = attn_dropout_p
209
+ self.resid_dropout_p = resid_dropout_p
210
+ self.token_dropout_p = token_dropout_p
211
+
212
+ self.s1_vocab_size = 2 ** self.s1_bits
213
+ self.token_drop = nn.Dropout(self.token_dropout_p)
214
+ self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model)
215
+ self.time_emb = TemporalEmbedding(self.d_model, self.learn_te)
216
+ self.transformer = nn.ModuleList([
217
+ TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
218
+ for _ in range(self.n_layers)
219
+ ])
220
+ self.norm = RMSNorm(self.d_model)
221
+ self.dep_layer = DependencyAwareLayer(self.d_model)
222
+ self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model)
223
+ self.apply(self._init_weights)
224
+
225
+ def _init_weights(self, module):
226
+
227
+ if isinstance(module, nn.Linear):
228
+ nn.init.xavier_normal_(module.weight)
229
+ if module.bias is not None:
230
+ nn.init.zeros_(module.bias)
231
+ elif isinstance(module, nn.Embedding):
232
+ nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5)
233
+ elif isinstance(module, nn.LayerNorm):
234
+ nn.init.ones_(module.weight)
235
+ nn.init.zeros_(module.bias)
236
+ elif isinstance(module, RMSNorm):
237
+ nn.init.ones_(module.weight)
238
+
239
+ def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None):
240
+ """
241
+ Args:
242
+ s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
243
+ s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
244
+ stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
245
+ padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
246
+ use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False.
247
+ s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.
248
+
249
+ Returns:
250
+ Tuple[torch.Tensor, torch.Tensor]:
251
+ - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
252
+ - s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
253
+ """
254
+ x = self.embedding([s1_ids, s2_ids])
255
+ if stamp is not None:
256
+ time_embedding = self.time_emb(stamp)
257
+ x = x + time_embedding
258
+ x = self.token_drop(x)
259
+
260
+ for layer in self.transformer:
261
+ x = layer(x, key_padding_mask=padding_mask)
262
+
263
+ x = self.norm(x)
264
+
265
+ s1_logits = self.head(x)
266
+
267
+ if use_teacher_forcing:
268
+ sibling_embed = self.embedding.emb_s1(s1_targets)
269
+ else:
270
+ s1_probs = F.softmax(s1_logits.detach(), dim=-1)
271
+ sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape)
272
+ sibling_embed = self.embedding.emb_s1(sample_s1_ids)
273
+
274
+ x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings
275
+ s2_logits = self.head.cond_forward(x2)
276
+ return s1_logits, s2_logits
277
+
278
+ def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
279
+ """
280
+ Decodes only the s1 tokens.
281
+
282
+ This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
283
+ and the context representation from the Transformer, which can be used for subsequent s2 decoding.
284
+
285
+ Args:
286
+ s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
287
+ s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
288
+ stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
289
+ padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
290
+
291
+ Returns:
292
+ Tuple[torch.Tensor, torch.Tensor]:
293
+ - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
294
+ - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
295
+ """
296
+ x = self.embedding([s1_ids, s2_ids])
297
+ if stamp is not None:
298
+ time_embedding = self.time_emb(stamp)
299
+ x = x + time_embedding
300
+ x = self.token_drop(x)
301
+
302
+ for layer in self.transformer:
303
+ x = layer(x, key_padding_mask=padding_mask)
304
+
305
+ x = self.norm(x)
306
+
307
+ s1_logits = self.head(x)
308
+ return s1_logits, x
309
+
310
+ def decode_s2(self, context, s1_ids, padding_mask=None):
311
+ """
312
+ Decodes the s2 tokens, conditioned on the context and s1 tokens.
313
+
314
+ This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
315
+ and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.
316
+
317
+ Args:
318
+ context (torch.Tensor): Context representation from the transformer (output of decode_s1).
319
+ Shape: [batch_size, seq_len, d_model]
320
+ s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
321
+ padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
322
+
323
+ Returns:
324
+ torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
325
+ """
326
+ sibling_embed = self.embedding.emb_s1(s1_ids)
327
+ x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
328
+ return self.head.cond_forward(x2)
329
+
330
+
331
+ def top_k_top_p_filtering(
332
+ logits,
333
+ top_k: int = 0,
334
+ top_p: float = 1.0,
335
+ filter_value: float = -float("Inf"),
336
+ min_tokens_to_keep: int = 1,
337
+ ):
338
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
339
+ Args:
340
+ logits: logits distribution shape (batch size, vocabulary size)
341
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
342
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
343
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
344
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
345
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
346
+ """
347
+ if top_k > 0:
348
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
349
+ # Remove all tokens with a probability less than the last token of the top-k
350
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
351
+ logits[indices_to_remove] = filter_value
352
+ return logits
353
+
354
+ if top_p < 1.0:
355
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
356
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
357
+
358
+ # Remove tokens with cumulative probability above the threshold (tokens flagged 0 are kept)
359
+ sorted_indices_to_remove = cumulative_probs > top_p
360
+ if min_tokens_to_keep > 1:
361
+ # Never drop the first min_tokens_to_keep tokens (the right shift below additionally keeps the token that crosses the threshold)
362
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
363
+ # Shift the indices to the right to keep also the first token above the threshold
364
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
365
+ sorted_indices_to_remove[..., 0] = 0
366
+
367
+ # scatter sorted tensors to original indexing
368
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
369
+ logits[indices_to_remove] = filter_value
370
+ return logits
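The nucleus branch above is easy to lose track of across the sort, shift, and scatter steps. The following is a minimal sketch of the same keep/drop rule for a single list of logits, written in plain Python without torch. `nucleus_filter` is a hypothetical helper introduced only for illustration; it is not part of this commit.

```python
import math


def nucleus_filter(logits, top_p=0.9, filter_value=float("-inf")):
    """Keep the most probable tokens until their cumulative mass exceeds top_p."""
    # Softmax over plain floats (subtract the max for numerical stability).
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Visit tokens from most to least probable, accumulating probability mass.
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)  # the token that crosses the threshold is kept as well
        cumulative += probs[i]
        if cumulative > top_p:
            break

    # Everything outside the nucleus is replaced by filter_value.
    return [v if i in kept else filter_value for i, v in enumerate(logits)]


filtered = nucleus_filter([2.0, 1.0, 0.0, -1.0], top_p=0.9)
narrow = nucleus_filter([2.0, 1.0, 0.0, -1.0], top_p=0.5)
```

Keeping the token that crosses the threshold mirrors the "shift the indices to the right" step in the tensor version, so at least one token always survives.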
371
+
372
+
373
+ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
374
+ logits = logits / temperature
375
+ top_k, top_p = (top_k or 0), (1.0 if top_p is None else top_p)  # None disables the corresponding filter
375
+ if top_k > 0 or top_p < 1.0:
377
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
378
+
379
+ probs = F.softmax(logits, dim=-1)
380
+
381
+ if not sample_logits:
382
+ _, x = torch.topk(probs, k=1, dim=-1)  # greedy pick; the top_k argument is an int, not a function
383
+ else:
384
+ x = torch.multinomial(probs, num_samples=1)
385
+
386
+ return x
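To make the roles of `temperature` and `sample_logits` in `sample_from_logits` concrete, here is a small framework-free sketch. The names `softmax` and `sample_token` and the `greedy` and `rng` parameters are assumptions made for this example; the repository itself operates on torch tensors.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature=1.0, greedy=False, rng=random):
    """Return one token index, either the argmax or a multinomial draw."""
    # Dividing by the temperature sharpens (T < 1) or flattens (T > 1) the distribution.
    probs = softmax([v / temperature for v in logits])
    if greedy:
        return max(range(len(probs)), key=probs.__getitem__)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]
sharp = softmax([v / 0.5 for v in logits])  # lower temperature
plain = softmax(logits)                     # temperature 1.0
```

Tokens whose logits were set to `-inf` by a previous filtering step receive probability zero here, so they can never be drawn.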
387
+
388
+
389
+ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
390
+ with torch.no_grad():
391
+ batch_size = x.size(0)
392
+ initial_seq_len = x.size(1)
393
+ x = torch.clip(x, -clip, clip)
394
+
395
+ device = x.device
396
+ x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
397
+ x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
398
+ y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
399
+
400
+ x_token = tokenizer.encode(x, half=True)
401
+
402
+ def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
403
+
404
+ if current_seq_len <= max_context - pred_step:
405
+ return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
406
+ else:
407
+ start_idx = max_context - pred_step
408
+ return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)
409
+
410
+ if verbose:
411
+ ran = trange
412
+ else:
413
+ ran = range
414
+ for i in ran(pred_len):
415
+ current_seq_len = initial_seq_len + i
416
+
417
+ if current_seq_len <= max_context:
418
+ input_tokens = x_token
419
+ else:
420
+ input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
421
+
422
+ current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)
423
+
424
+ s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
425
+ s1_logits = s1_logits[:, -1, :]
426
+ sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
427
+
428
+ s2_logits = model.decode_s2(context, sample_pre)
429
+ s2_logits = s2_logits[:, -1, :]
430
+ sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
431
+
432
+ x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
433
+ x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
434
+
435
+ input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
436
+ z = tokenizer.decode(input_tokens, half=True)
437
+ z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
438
+ preds = z.cpu().numpy()
439
+ # preds = np.mean(preds, axis=1)
440
+
441
+ return preds
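The loop in `auto_regressive_inference` keeps appending sampled tokens while feeding the model only the most recent `max_context` positions. A stripped-down sketch of that rolling-window pattern, assuming a hypothetical `next_token` callback in place of the `decode_s1` / `decode_s2` calls:

```python
def rolling_decode(history, max_context, pred_len, next_token):
    """Autoregressively extend `history`, showing the model at most `max_context` tokens."""
    tokens = list(history)
    for _ in range(pred_len):
        window = tokens[-max_context:]      # truncate to the latest max_context tokens
        tokens.append(next_token(window))   # append the newly generated token
    return tokens[-pred_len:]               # only the generated continuation


# Toy "model": the next token is the sum of the visible window.
generated = rolling_decode([1, 2, 3], max_context=2, pred_len=3, next_token=sum)
```

In the real function the timestamp features are sliced in the same way, so that the stamp tensor always has the same length as the token window.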
442
+
443
+
444
+ def calc_time_stamps(x_timestamp):
445
+ time_df = pd.DataFrame()
446
+ time_df['minute'] = x_timestamp.dt.minute
447
+ time_df['hour'] = x_timestamp.dt.hour
448
+ time_df['weekday'] = x_timestamp.dt.weekday
449
+ time_df['day'] = x_timestamp.dt.day
450
+ time_df['month'] = x_timestamp.dt.month
451
+ return time_df
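For readers without pandas at hand, the five calendar features produced by `calc_time_stamps` can be reproduced for a single timestamp with the standard library. `time_features` is a hypothetical name used only in this sketch; the column order follows `time_cols` in `KronosPredictor`.

```python
from datetime import datetime


def time_features(ts):
    """Return [minute, hour, weekday, day, month] for one timestamp (Monday is weekday 0)."""
    return [ts.minute, ts.hour, ts.weekday(), ts.day, ts.month]


features = time_features(datetime(2024, 1, 15, 9, 30))
```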
452
+
453
+
454
+ class KronosPredictor:
455
+
456
+ def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
457
+ self.tokenizer = tokenizer
458
+ self.model = model
459
+ self.max_context = max_context
460
+ self.clip = clip
461
+ self.price_cols = ['open', 'high', 'low', 'close']
462
+ self.vol_col = 'volume'
463
+ self.amt_vol = 'amount'
464
+ self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']
465
+ self.device = device
466
+
467
+ self.tokenizer = self.tokenizer.to(self.device)
468
+ self.model = self.model.to(self.device)
469
+
470
+ def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
471
+
472
+ x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
473
+ x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
474
+ y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)
475
+
476
+ preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len,
477
+ self.clip, T, top_k, top_p, sample_count, verbose)
478
+ preds = preds[:, :, -pred_len:, :]
479
+ return preds
480
+
481
+ def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
482
+
483
+ if not isinstance(df, pd.DataFrame):
484
+ raise ValueError("Input must be a pandas DataFrame.")
485
+
486
+ if not all(col in df.columns for col in self.price_cols):
487
+ raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")
488
+
489
+ df = df.copy()
490
+ if self.vol_col not in df.columns:
491
+ df[self.vol_col] = 0.0 # Fill missing volume with zeros
492
+ df[self.amt_vol] = 0.0 # Fill missing amount with zeros
493
+ if self.amt_vol not in df.columns and self.vol_col in df.columns:
494
+ df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
495
+
496
+ if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
497
+ raise ValueError("Input DataFrame contains NaN values in price or volume columns.")
498
+
499
+ x_time_df = calc_time_stamps(x_timestamp)
500
+ y_time_df = calc_time_stamps(y_timestamp)
501
+
502
+ x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
503
+ x_stamp = x_time_df.values.astype(np.float32)
504
+ y_stamp = y_time_df.values.astype(np.float32)
505
+
506
+ x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
507
+
508
+ x = (x - x_mean) / (x_std + 1e-5)
509
+ x = np.clip(x, -self.clip, self.clip)
510
+
511
+ x = x[np.newaxis, :]
512
+ x_stamp = x_stamp[np.newaxis, :]
513
+ y_stamp = y_stamp[np.newaxis, :]
514
+
515
+ preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose)
516
+
517
+ preds = preds.squeeze(0)
518
+ preds = preds * (x_std[np.newaxis, :] + 1e-5) + x_mean[np.newaxis, :]
519
+
520
+ close_preds = preds[:, :, 3].swapaxes(0, 1)
521
+ volume_preds = preds[:, :, 4].swapaxes(0, 1)
522
+
523
+ close_df = pd.DataFrame(close_preds, columns=[f"pred-{i+1}" for i in range(sample_count)], index=y_timestamp)
524
+ volume_df = pd.DataFrame(volume_preds, columns=[f"pred-{i + 1}" for i in range(sample_count)], index=y_timestamp)
525
+
526
+ return close_df, volume_df
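`predict` standardises each input column, clips it, and later inverts the scaling on the model output. A minimal sketch of that round trip for one column, using only the standard library; `normalize` and `denormalize` are illustrative names, and the population standard deviation is used because `np.std` defaults to it.

```python
import statistics


def normalize(values, clip=5.0, eps=1e-5):
    """Z-score a column and clip it to [-clip, clip]; also return the statistics."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # population std, like np.std with default ddof
    scaled = [max(-clip, min(clip, (v - mean) / (std + eps))) for v in values]
    return scaled, mean, std


def denormalize(scaled, mean, std, eps=1e-5):
    """Invert the z-score with the same statistics and epsilon."""
    return [s * (std + eps) + mean for s in scaled]


closes = [100.0, 102.0, 101.0, 105.0]
scaled, mean, std = normalize(closes)
restored = denormalize(scaled, mean, std)
```

The inversion is exact only for values that were not clipped; anything beyond `clip` standard deviations is saturated before it reaches the model.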
527
+
528
+ def predict_detail(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
529
+ """
530
+ Predict with full OHLCV output for all Monte Carlo samples.
531
+
532
+ Returns:
533
+ dict: Dictionary containing DataFrames for each price component:
534
+ - 'open': DataFrame with open price predictions
535
+ - 'high': DataFrame with high price predictions
536
+ - 'low': DataFrame with low price predictions
537
+ - 'close': DataFrame with close price predictions
538
+ - 'volume': DataFrame with volume predictions
539
+ """
540
+ if not isinstance(df, pd.DataFrame):
541
+ raise ValueError("Input must be a pandas DataFrame.")
542
+
543
+ if not all(col in df.columns for col in self.price_cols):
544
+ raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")
545
+
546
+ df = df.copy()
547
+ if self.vol_col not in df.columns:
548
+ df[self.vol_col] = 0.0
549
+ df[self.amt_vol] = 0.0
550
+ if self.amt_vol not in df.columns and self.vol_col in df.columns:
551
+ df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
552
+
553
+ if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
554
+ raise ValueError("Input DataFrame contains NaN values in price or volume columns.")
555
+
556
+ x_time_df = calc_time_stamps(x_timestamp)
557
+ y_time_df = calc_time_stamps(y_timestamp)
558
+
559
+ x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
560
+ x_stamp = x_time_df.values.astype(np.float32)
561
+ y_stamp = y_time_df.values.astype(np.float32)
562
+
563
+ x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
564
+
565
+ x = (x - x_mean) / (x_std + 1e-5)
566
+ x = np.clip(x, -self.clip, self.clip)
567
+
568
+ x = x[np.newaxis, :]
569
+ x_stamp = x_stamp[np.newaxis, :]
570
+ y_stamp = y_stamp[np.newaxis, :]
571
+
572
+ preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose)
573
+
574
+ preds = preds.squeeze(0)
575
+ preds = preds * (x_std[np.newaxis, :] + 1e-5) + x_mean[np.newaxis, :]
576
+
577
+ # Extract all OHLCV components: [sample_count, pred_len, 6]
578
+ # Columns: open(0), high(1), low(2), close(3), volume(4), amount(5)
579
+ col_names = [f"pred-{i+1}" for i in range(sample_count)]
580
+
581
+ result = {
582
+ 'open': pd.DataFrame(preds[:, :, 0].swapaxes(0, 1), columns=col_names, index=y_timestamp),
583
+ 'high': pd.DataFrame(preds[:, :, 1].swapaxes(0, 1), columns=col_names, index=y_timestamp),
584
+ 'low': pd.DataFrame(preds[:, :, 2].swapaxes(0, 1), columns=col_names, index=y_timestamp),
585
+ 'close': pd.DataFrame(preds[:, :, 3].swapaxes(0, 1), columns=col_names, index=y_timestamp),
586
+ 'volume': pd.DataFrame(preds[:, :, 4].swapaxes(0, 1), columns=col_names, index=y_timestamp),
587
+ }
588
+
589
+ return result
model/module.py ADDED
@@ -0,0 +1,580 @@
1
+ import math
2
+
3
+ from einops import rearrange, reduce
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.autograd import Function
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class DifferentiableEntropyFunction(Function):
11
+ @staticmethod
12
+ def forward(ctx, zq, basis, K, eps):
13
+ zb = (zq + 1) / 2
14
+ zi = ((zb * basis).sum(-1)).to(torch.int64)
15
+ cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
16
+ 0,
17
+ zi.flatten(),
18
+ torch.ones_like(zi.flatten()).to(zq.dtype),
19
+ 'sum')
20
+ prob = (cnt + eps) / (cnt + eps).sum()
21
+ H = -(prob * torch.log(prob)).sum()
22
+ ctx.save_for_backward(zq, zi, prob)
23
+ ctx.K = K
24
+ return H
25
+
26
+ @staticmethod
27
+ def backward(ctx, grad_output):
28
+ zq, zi, prob = ctx.saved_tensors
29
+ grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
30
+ reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
31
+ grad_input = reord_grad.unsqueeze(-1) * zq
32
+ return grad_input, None, None, None  # one gradient per forward input: zq, basis, K, eps
33
+
34
+
35
+ def codebook_entropy(zq, basis, K, eps=1e-4):
36
+ return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
37
+
38
+
39
+ class BinarySphericalQuantizer(nn.Module):
40
+ def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
41
+ input_format='bchw',
42
+ soft_entropy=True, group_size=9,
43
+ persample_entropy_compute='analytical',
44
+ cb_entropy_compute='group',
45
+ l2_norm=True,
46
+ inv_temperature=1):
47
+ """
48
+ Paper link: https://arxiv.org/pdf/2406.07548.pdf
49
+ Here we use the official implementation of the BinarySphericalQuantizer.
50
+ """
51
+ super().__init__()
52
+ self.embed_dim = embed_dim
53
+ self.beta = beta # loss weight for commit loss
54
+ self.gamma0 = gamma0 # loss weight for the per-sample entropy term
54
+ self.gamma = gamma # loss weight for the codebook entropy term
56
+ self.zeta = zeta # loss weight for entire entropy penalty
57
+ self.input_format = input_format
58
+ assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
59
+ self.num_groups = self.embed_dim // group_size
60
+ self.group_size = group_size
61
+ assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
62
+ assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
63
+ self.persample_entropy_compute = persample_entropy_compute
64
+ self.cb_entropy_compute = cb_entropy_compute
65
+ self.l2_norm = l2_norm
66
+ self.inv_temperature = inv_temperature
67
+
68
+ self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
69
+ self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
70
+
71
+ self.num_dimensions = 2 ** embed_dim
72
+ self.bits_per_index = embed_dim
73
+
74
+ # we only need to keep the codebook portion up to the group size
75
+ # because we approximate the H loss with this subcode
76
+ group_codes = torch.arange(2 ** self.group_size)
77
+ group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
78
+ self.register_buffer('group_codebook', group_codebook, persistent=False)
79
+
80
+ self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
81
+
82
+ def quantize(self, z):
83
+ assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
84
+
85
+ zhat = torch.where(z > 0,
86
+ torch.tensor(1, dtype=z.dtype, device=z.device),
87
+ torch.tensor(-1, dtype=z.dtype, device=z.device))
88
+ return z + (zhat - z).detach()
89
+
90
+ def forward(self, z):
91
+ # if self.input_format == 'bchw':
92
+ # z = rearrange(z, 'b c h w -> b h w c')
93
+ zq = self.quantize(z)
94
+
95
+ indices = self.codes_to_indexes(zq.detach())
96
+ group_indices = self.codes_to_group_indexes(zq.detach())
97
+ if not self.training:
98
+ used_codes = torch.unique(indices, return_counts=False)
99
+ else:
100
+ used_codes = None
101
+
102
+ q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
103
+
104
+ if self.soft_entropy:
105
+ persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
106
+ entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
107
+ else:
108
+ zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
109
+ persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
110
+ cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
111
+ entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
+ avg_prob = None  # not computed on the hard-entropy path; avoids a NameError in the returned dict
112
+
113
+ zq = zq * q_scale
114
+
115
+ # commit loss
116
+ commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
117
+
118
+ # if self.input_format == 'bchw':
119
+ # zq = rearrange(zq, 'b h w c -> b c h w')
120
+
121
+ return (
122
+ zq,
123
+ commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
124
+ {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
125
+ "avg_prob": avg_prob}
126
+ )
127
+
128
+ def soft_entropy_loss(self, z):
129
+ # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
130
+ # the sub-code is the last group_size bits of the full code
131
+ group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
132
+ divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
133
+
134
+ # we calculate the distance between the divided_z and the codebook for each subgroup
135
+ distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
136
+ prob = (-distance * self.inv_temperature).softmax(dim=-1)
137
+ if self.persample_entropy_compute == 'analytical':
138
+ if self.l2_norm:
139
+ p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
140
+ else:
141
+ p = torch.sigmoid(-4 * z * self.inv_temperature)
142
+ prob = torch.stack([p, 1 - p], dim=-1)
143
+ per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
144
+ else:
145
+ per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
146
+
147
+ # macro average of the probability of each subgroup
148
+ avg_prob = reduce(prob, '... g d ->g d', 'mean')
149
+ codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
150
+
151
+ # the approximation of the entropy is the sum of the entropy of each subgroup
152
+ return per_sample_entropy, codebook_entropy.sum(), avg_prob
153
+
154
+ def get_hard_per_sample_entropy(self, zb_by_sample):
155
+ probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
156
+ persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
157
+ persample_entropy = persample_entropy.sum(-1)
158
+ return persample_entropy.mean()
159
+
160
+ def codes_to_indexes(self, zhat):
161
+ """Converts a `code` to an index in the codebook.
162
+ Args:
163
+ zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
164
+ """
165
+ assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
166
+ return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
167
+
168
+ def codes_to_group_indexes(self, zhat):
169
+ """Converts a `code` to a list of indexes (in groups) in the codebook.
170
+ Args:
171
+ zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
172
+ """
173
+ zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
174
+ return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
175
+
176
+ def indexes_to_codes(self, indices):
177
+ """Inverse of `codes_to_indexes`."""
178
+ indices = indices.unsqueeze(-1)
179
+ codes_non_centered = torch.remainder(
180
+ torch.floor_divide(indices, self.basis), 2
181
+ )
182
+ return codes_non_centered * 2 - 1
183
+
184
+ def group_indexes_to_codes(self, group_indices):
185
+ """Inverse of `codes_to_group_indexes`."""
186
+ group_indices = group_indices.unsqueeze(-1)
187
+ codes_non_centered = torch.remainder(
188
+ torch.floor_divide(group_indices, self.group_basis), 2
189
+ )
190
+ codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
191
+ return codes_non_centered * 2 - 1
192
+
193
+ def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
194
+ if normalize:
195
+ probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
196
+ else:
197
+ probs = count
198
+ H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
199
+ return H
200
+
201
+ def get_group_codebook_entry(self, group_indices):
202
+ z_q = self.group_indexes_to_codes(group_indices)
203
+ q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
204
+ z_q = z_q * q_scale
205
+ if self.input_format == 'bchw':
206
+ h = w = int(z_q.shape[1] ** 0.5)  # square grid; unpacking a single int would raise TypeError
207
+ assert h * w == z_q.shape[1], 'Invalid sequence length'
208
+ z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
209
+ return z_q
210
+
211
+ def get_codebook_entry(self, indices):
212
+ z_q = self.indexes_to_codes(indices)
213
+ q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
214
+ z_q = z_q * q_scale
215
+ if self.input_format == 'bchw':
216
+ h = w = int(z_q.shape[1] ** 0.5)  # square grid; unpacking a single int would raise TypeError
217
+ assert h * w == z_q.shape[1], 'Invalid sequence length'
218
+ z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
219
+ return z_q
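`codes_to_indexes` and `indexes_to_codes` map between a {-1, +1} code vector and an integer, with the first dimension acting as the most significant bit (the `basis` buffer is `2 ** (embed_dim - 1), ..., 2 ** 0`). A plain-Python sketch of that round trip, using the hypothetical helper names `code_to_index` and `index_to_code`:

```python
def code_to_index(code):
    """Map a list of -1/+1 entries to an integer, first entry = most significant bit."""
    index = 0
    for c in code:
        bit = (c + 1) // 2       # -1 -> 0, +1 -> 1
        index = index * 2 + bit  # shift left and append the next bit
    return index


def index_to_code(index, n_bits):
    """Inverse mapping: recover the -1/+1 code from the integer index."""
    return [2 * ((index >> shift) & 1) - 1 for shift in range(n_bits - 1, -1, -1)]


idx = code_to_index([1, -1, 1])
code = index_to_code(idx, 3)
```

Note that `BSQuantizer.bits_to_indices` further down uses the opposite bit order (`2 ** 0` first), so indices from the two helpers are not interchangeable.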
220
+
221
+
222
+ class BSQuantizer(nn.Module):
223
+
224
+ def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
225
+ super().__init__()
226
+ self.codebook_dim = s1_bits + s2_bits
227
+ self.s1_bits = s1_bits
228
+ self.s2_bits = s2_bits
229
+ self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)
230
+
231
+ def bits_to_indices(self, bits):
232
+ bits = (bits >= 0).to(torch.long)
233
+ indices = 2 ** torch.arange(
234
+ 0,
235
+ bits.shape[-1],
236
+ 1,
237
+ dtype=torch.long,
238
+ device=bits.device,
239
+ )
240
+ return (bits * indices).sum(-1)
241
+
242
+ def forward(self, z, half=False):
243
+ z = F.normalize(z, dim=-1)
244
+ quantized, bsq_loss, metrics = self.bsq(z)
245
+ if half:
246
+ q_pre = quantized[:, :, :self.s1_bits]
247
+ q_post = quantized[:, :, self.s1_bits:]
248
+ z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
249
+ else:
250
+ z_indices = self.bits_to_indices(quantized)
251
+ return bsq_loss, quantized, z_indices
252
+
253
+
254
+ class RMSNorm(torch.nn.Module):
255
+ def __init__(self, dim: int, eps: float = 1e-5):
256
+ super().__init__()
257
+ self.eps = eps
258
+ self.weight = nn.Parameter(torch.ones(dim))
259
+
260
+ def _norm(self, x):
261
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
262
+
263
+ def forward(self, x):
264
+ output = self._norm(x.float()).type_as(x)
265
+ return output * self.weight
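As a sanity check on `RMSNorm`, the same computation for a single vector can be written without torch. This is only a sketch; `rms_norm` is an illustrative name and the weight list stands in for the learnable `self.weight` parameter.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """Scale x so that its root mean square is about 1, then apply a per-dimension gain."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


normed = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
```

Unlike LayerNorm, no mean is subtracted, so the direction of the vector is preserved and only its scale changes.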
266
+
267
+
268
+ class FeedForward(nn.Module):
269
+ def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
270
+ super().__init__()
271
+
272
+ self.w1 = nn.Linear(d_model, ff_dim, bias=False)
273
+ self.w3 = nn.Linear(d_model, ff_dim, bias=False)
274
+ self.w2 = nn.Linear(ff_dim, d_model, bias=False)
275
+ self.ffn_dropout = nn.Dropout(ffn_dropout_p)
276
+
277
+ def forward(self, x):
278
+ return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
279
+
280
+
281
+ class RotaryPositionalEmbedding(nn.Module):
282
+ def __init__(self, dim):
283
+ super().__init__()
284
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
285
+ self.register_buffer("inv_freq", inv_freq)
286
+ self.seq_len_cached = None
287
+ self.cos_cached = None
288
+ self.sin_cached = None
289
+
290
+ def _update_cos_sin_cache(self, x, seq_len):
291
+ if seq_len != self.seq_len_cached:
292
+ self.seq_len_cached = seq_len
293
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
294
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
295
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
296
+ self.cos_cached = emb.cos()[None, None, :, :]
297
+ self.sin_cached = emb.sin()[None, None, :, :]
298
+ return self.cos_cached, self.sin_cached
299
+
300
+ def forward(self, q, k):
301
+ cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
302
+ return (
303
+ (q * cos) + (self._rotate_half(q) * sin),
304
+ (k * cos) + (self._rotate_half(k) * sin),
305
+ )
306
+
307
+ def _rotate_half(self, x):
308
+ x1, x2 = x.chunk(2, dim=-1)
309
+ return torch.cat((-x2, x1), dim=-1)
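The rotate-half formulation in `RotaryPositionalEmbedding` pairs dimension `i` with dimension `i + dim/2` and rotates each pair by an angle proportional to the position. A minimal single-vector sketch in plain Python, assuming an even `dim`; `rope` and `dot` are hypothetical helpers for this example.

```python
import math


def rope(vec, position, base=10000.0):
    """Apply rotary position embedding to one vector at one position."""
    dim = len(vec)
    half = dim // 2
    inv_freq = [1.0 / base ** (2 * i / dim) for i in range(half)]
    angles = [position * f for f in inv_freq] * 2  # same angles for both halves
    rotated = [-v for v in vec[half:]] + vec[:half]  # rotate_half: (-x2, x1)
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(vec, rotated, angles)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
score_a = dot(rope(q, 3), rope(k, 1))  # positions 3 and 1
score_b = dot(rope(q, 5), rope(k, 3))  # positions 5 and 3, same offset
```

The two scores agree because the attention logit depends only on the relative offset between query and key positions, which is the property RoPE is designed to provide.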
310
+
311
+
312
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor:
313
+ L, S = query.size(-2), key.size(-2)
314
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
315
+ attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
316
+
317
+ if is_causal:
318
+ assert attn_mask is None
319
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device)
320
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
321
+ attn_bias = attn_bias.to(query.dtype)  # .to() is not in-place; keep the result
322
+
323
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
324
+ attn_weight += attn_bias
325
+
326
+ if attn_mask is not None:
327
+ attn_mask_bias = torch.zeros_like(attn_weight)
328
+ if attn_mask.dtype == torch.bool:
329
+ attn_mask_bias.masked_fill_(attn_mask, float("-inf"))
330
+ else:
331
+ attn_mask_bias += attn_mask
332
+ attn_weight += attn_mask_bias
333
+
334
+ attn_weight = torch.softmax(attn_weight, dim=-1)
335
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
336
+ return attn_weight @ value
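The causal path of `scaled_dot_product_attention` can be illustrated without tensors. The sketch below handles one head and no padding mask, with vectors as Python lists; `causal_attention` is an illustrative name and not part of the repository.

```python
import math


def causal_attention(q, k, v):
    """q, k, v: lists of vectors, one per position. Returns one output vector per position."""
    scale = 1.0 / math.sqrt(len(q[0]))
    outputs = []
    for i, qi in enumerate(q):
        # Position i may only attend to positions 0..i (lower-triangular mask).
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        outputs.append([
            sum(w / total * v[j][d] for j, w in enumerate(weights))
            for d in range(len(v[0]))
        ])
    return outputs


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = causal_attention(q, k, v)
```

One detail of the function above worth keeping in mind: a boolean `attn_mask` is applied with `masked_fill_`, so `True` marks positions to exclude (padding), not positions to attend to.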
337
+
338
+
339
+ class MultiHeadAttentionWithRoPE(nn.Module):
340
+ def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
341
+ super().__init__()
342
+ self.d_model = d_model
343
+ self.n_heads = n_heads
344
+ self.head_dim = d_model // n_heads
345
+
346
+ self.q_proj = nn.Linear(d_model, d_model)
347
+ self.k_proj = nn.Linear(d_model, d_model)
348
+ self.v_proj = nn.Linear(d_model, d_model)
349
+ self.out_proj = nn.Linear(d_model, d_model)
350
+ self.rotary = RotaryPositionalEmbedding(self.head_dim)
351
+ self.attn_dropout_p = attn_dropout_p
352
+ self.resid_dropout = nn.Dropout(resid_dropout_p)
353
+
354
+ def forward(self, x, key_padding_mask=None):
355
+ batch_size, seq_len, _ = x.shape
356
+
357
+ q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
358
+ k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
359
+ v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
360
+
361
+ q, k = self.rotary(q, k)
362
+
363
+ if key_padding_mask is not None:
364
+ attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len]
365
+ attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len]
366
+ else:
367
+ attn_mask = None
368
+
369
+ attn_output = scaled_dot_product_attention(
370
+ q, k, v,
371
+ attn_mask=attn_mask,
372
+ dropout_p=self.attn_dropout_p,
373
+ is_causal=True,
374
+ training=self.training
375
+ )
376
+
377
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
378
+ return self.resid_dropout(self.out_proj(attn_output))
379
+
380
+
381
+ class MultiHeadCrossAttentionWithRoPE(nn.Module):
382
+ def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
383
+ super().__init__()
384
+ self.d_model = d_model
385
+ self.n_heads = n_heads
386
+ self.head_dim = d_model // n_heads
387
+
388
+ self.q_proj = nn.Linear(d_model, d_model)
389
+ self.k_proj = nn.Linear(d_model, d_model)
390
+ self.v_proj = nn.Linear(d_model, d_model)
391
+ self.out_proj = nn.Linear(d_model, d_model)
392
+ self.rotary = RotaryPositionalEmbedding(self.head_dim)
393
+ self.attn_dropout_p = attn_dropout_p
394
+ self.resid_dropout = nn.Dropout(resid_dropout)
395
+
396
+ def forward(self, query, key, value, key_padding_mask=None):
397
+ batch_size, q_len, _ = query.shape
398
+ _, seq_len, _ = key.shape
399
+
400
+ q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
401
+ k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
402
+ v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
403
+
404
+ q, k = self.rotary(q, k)
405
+
406
+ if key_padding_mask is not None:
407
+ attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
408
+ attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
409
+ else:
410
+ attn_mask = None
411
+
412
+ is_causal_flag = self.training
413
+
414
+ attn_output = scaled_dot_product_attention(
415
+ q, k, v,
416
+ attn_mask=attn_mask,
417
+ dropout_p=self.attn_dropout_p,
418
+ is_causal=is_causal_flag,
419
+ training=self.training
420
+ )
421
+
422
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
423
+ return self.resid_dropout(self.out_proj(attn_output))
424
+
425
+
426
+ class HierarchicalEmbedding(nn.Module):
427
+ def __init__(self, s1_bits, s2_bits, d_model=256):
428
+ super().__init__()
429
+ self.s1_bits = s1_bits
430
+ self.s2_bits = s2_bits
431
+
432
+ vocab_s1 = 2 ** s1_bits
433
+ vocab_s2 = 2 ** s2_bits
434
+
435
+ self.emb_s1 = nn.Embedding(vocab_s1, d_model)
436
+ self.emb_s2 = nn.Embedding(vocab_s2, d_model)
437
+ self.d_model = d_model
438
+ self.fusion_proj = nn.Linear(d_model * 2, d_model)
439
+
440
+ nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
441
+ nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
442
+
443
+ def forward(self, token_ids):
444
+ """Inputs:
445
+ token_ids: a (s1_ids, s2_ids) pair, each [batch_size, seq_len], or a single tensor of composite token IDs
446
+ Output: [batch_size, seq_len, d_model]
447
+ """
448
+ if isinstance(token_ids, tuple) or isinstance(token_ids, list):
449
+ s1_ids, s2_ids = token_ids
450
+ else:
451
+ s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)
452
+ s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
453
+ s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
454
+ return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
455
+
456
+
+ class DependencyAwareLayer(nn.Module):
+     def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
+         super().__init__()
+         self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
+         self.norm = RMSNorm(d_model)
+
+     def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
+         """hidden_states: [batch, seq_len, d_model]
+         sibling_embed: Embedding from another subtoken
+         """
+         attn_out = self.cross_attn(
+             query=sibling_embed,
+             key=hidden_states,
+             value=hidden_states,
+             key_padding_mask=key_padding_mask
+         )
+         return self.norm(hidden_states + attn_out)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
+         super().__init__()
+         self.norm1 = RMSNorm(d_model)
+         self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
+         self.norm2 = RMSNorm(d_model)
+         self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
+
+     def forward(self, x, key_padding_mask=None):
+         residual = x
+         x = self.norm1(x)
+         attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
+         x = residual + attn_out
+
+         residual = x
+         x = self.norm2(x)
+         ffn_out = self.ffn(x)
+         x = residual + ffn_out
+         return x
+
+
+ class DualHead(nn.Module):
+     def __init__(self, s1_bits, s2_bits, d_model):
+         super().__init__()
+         self.vocab_s1 = 2 ** s1_bits
+         self.vocab_s2 = 2 ** s2_bits
+         self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
+         self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
+
+     def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
+         if padding_mask is not None:
+             valid_mask = (padding_mask == 0)
+             s1_logits = s1_logits[valid_mask]
+             s2_logits = s2_logits[valid_mask]
+             s1_targets = s1_targets[valid_mask]
+             s2_targets = s2_targets[valid_mask]
+             ce_s1 = F.cross_entropy(s1_logits, s1_targets)
+             ce_s2 = F.cross_entropy(s2_logits, s2_targets)
+         else:
+             ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
+             ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
+         ce_loss = (ce_s1 + ce_s2) / 2
+         return ce_loss, ce_s1, ce_s2
+
+     def forward(self, x):
+         return self.proj_s1(x)
+
+     def cond_forward(self, x2):
+         return self.proj_s2(x2)
+
+
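Review note: `compute_loss` drops padded positions (mask value nonzero), takes the mean cross-entropy per head, and averages the two heads with equal weight. The dependency-free sketch below mirrors that arithmetic on plain Python lists so the behaviour can be checked by hand. The names `cross_entropy` and `dual_loss` are illustrative and not part of the module, and the sketch assumes at least one unpadded position.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one position: log-sum-exp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def dual_loss(s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
    """Illustrative equivalent of DualHead.compute_loss on nested lists."""
    n = len(s1_targets)
    # Keep positions whose mask entry is 0 (or all positions if no mask).
    keep = [i for i in range(n) if padding_mask is None or padding_mask[i] == 0]
    ce_s1 = sum(cross_entropy(s1_logits[i], s1_targets[i]) for i in keep) / len(keep)
    ce_s2 = sum(cross_entropy(s2_logits[i], s2_targets[i]) for i in keep) / len(keep)
    return (ce_s1 + ce_s2) / 2, ce_s1, ce_s2


# Uniform logits over two classes give a loss of ln(2) for every head.
uniform = [[0.0, 0.0], [0.0, 0.0]]
loss_uniform, _, _ = dual_loss(uniform, uniform, [0, 1], [1, 0])

# The second position is padded, so its (badly wrong) logits are ignored.
loss_masked, ce1_masked, _ = dual_loss(
    [[2.0, 0.0], [100.0, 0.0]], uniform, [0, 1], [1, 0], padding_mask=[0, 1]
)
```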
+ class FixedEmbedding(nn.Module):
+     def __init__(self, c_in, d_model):
+         super(FixedEmbedding, self).__init__()
+
+         w = torch.zeros(c_in, d_model).float()
+         w.requires_grad = False
+
+         position = torch.arange(0, c_in).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
+
+         w[:, 0::2] = torch.sin(position * div_term)
+         w[:, 1::2] = torch.cos(position * div_term)
+
+         self.emb = nn.Embedding(c_in, d_model)
+         self.emb.weight = nn.Parameter(w, requires_grad=False)
+
+     def forward(self, x):
+         return self.emb(x).detach()
+
+
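Review note: the table built in `FixedEmbedding.__init__` is the standard sinusoidal encoding, with sine in the even columns and cosine in the odd columns, both sharing the frequency `exp(-i * ln(10000) / d_model)` for even column index `i`. A pure-Python sketch of the same table follows, for checking values by hand. `sinusoid_table` is an illustrative name, and the guard for odd `d_model` is an addition that the class above does not have.

```python
import math


def sinusoid_table(c_in, d_model):
    """Build a c_in x d_model sinusoidal table as nested lists."""
    table = [[0.0] * d_model for _ in range(c_in)]
    for pos in range(c_in):
        for i in range(0, d_model, 2):
            # Same frequency term as div_term in FixedEmbedding.
            angle = pos * math.exp(-i * math.log(10000.0) / d_model)
            table[pos][i] = math.sin(angle)
            if i + 1 < d_model:  # guard added here for odd d_model
                table[pos][i + 1] = math.cos(angle)
    return table


# Example: a table for 24 hour-of-day indices with 4 channels.
hour_table = sinusoid_table(24, 4)
```

Row 0 alternates 0 and 1 (sin 0 and cos 0), which is a quick sanity check when comparing against the tensor version.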
+ class TemporalEmbedding(nn.Module):
+     def __init__(self, d_model, learn_pe):
+         super(TemporalEmbedding, self).__init__()
+
+         minute_size = 60
+         hour_size = 24
+         weekday_size = 7
+         day_size = 32
+         month_size = 13
+
+         Embed = FixedEmbedding if not learn_pe else nn.Embedding
+         self.minute_embed = Embed(minute_size, d_model)
+         self.hour_embed = Embed(hour_size, d_model)
+         self.weekday_embed = Embed(weekday_size, d_model)
+         self.day_embed = Embed(day_size, d_model)
+         self.month_embed = Embed(month_size, d_model)
+
+     def forward(self, x):
+         x = x.long()
+
+         minute_x = self.minute_embed(x[:, :, 0])
+         hour_x = self.hour_embed(x[:, :, 1])
+         weekday_x = self.weekday_embed(x[:, :, 2])
+         day_x = self.day_embed(x[:, :, 3])
+         month_x = self.month_embed(x[:, :, 4])
+
+         return hour_x + weekday_x + day_x + month_x + minute_x
+
+
+
+
+
+
+
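Review note: `TemporalEmbedding.forward` reads the last axis of `x` in the fixed order minute, hour, weekday, day, month. The table sizes (32 for day, 13 for month) suggest that day and month are passed 1-based without shifting. The sketch below shows one way a caller might build such a feature row from a timestamp. The helper name `time_features` and the Monday=0 weekday convention are assumptions; the diff does not state which weekday convention the model was trained with.

```python
from datetime import datetime


def time_features(ts):
    """Return [minute, hour, weekday, day, month] for one timestamp.

    Column order matches the indices 0..4 read in TemporalEmbedding.forward.
    Weekday uses datetime's convention (Monday=0), which is an assumption.
    """
    return [ts.minute, ts.hour, ts.weekday(), ts.day, ts.month]


# 2024-01-15 09:30 was a Monday.
row = time_features(datetime(2024, 1, 15, 9, 30))

# Every feature must stay below its embedding table size.
sizes = [60, 24, 7, 32, 13]
in_range = all(0 <= value < size for value, size in zip(row, sizes))
```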
models/predictor/README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ tags:
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
+ ---
+
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+ - Code: [More Information Needed]
+ - Paper: [More Information Needed]
+ - Docs: [More Information Needed]
models/predictor/config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "attn_dropout_p": 0.0,
+   "d_model": 256,
+   "ff_dim": 512,
+   "ffn_dropout_p": 0.2,
+   "learn_te": true,
+   "n_heads": 4,
+   "n_layers": 4,
+   "resid_dropout_p": 0.2,
+   "s1_bits": 10,
+   "s2_bits": 10,
+   "token_dropout_p": 0.0
+ }
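Review note: a few sizes follow directly from this config and the classes in `model/module.py`. Each subtoken vocabulary is `2 ** bits`, as in `HierarchicalEmbedding` and `DualHead`. The per-head width is presumably `d_model // n_heads`, which is an assumption here since the attention constructor is not part of this hunk. A small sketch that derives them from the JSON above:

```python
import json

# Config copied verbatim from models/predictor/config.json in this commit.
CONFIG_TEXT = """
{
  "attn_dropout_p": 0.0,
  "d_model": 256,
  "ff_dim": 512,
  "ffn_dropout_p": 0.2,
  "learn_te": true,
  "n_heads": 4,
  "n_layers": 4,
  "resid_dropout_p": 0.2,
  "s1_bits": 10,
  "s2_bits": 10,
  "token_dropout_p": 0.0
}
"""

config = json.loads(CONFIG_TEXT)

# Vocabulary sizes as computed in HierarchicalEmbedding / DualHead.
vocab_s1 = 2 ** config["s1_bits"]
vocab_s2 = 2 ** config["s2_bits"]

# Assumed per-head width; requires d_model to be divisible by n_heads.
assert config["d_model"] % config["n_heads"] == 0
head_dim = config["d_model"] // config["n_heads"]

print(vocab_s1, vocab_s2, head_dim)  # prints: 1024 1024 64
```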