Spaces: Running

fengwm committed
Commit · 2a8a0f5
Parent(s): 292fb60

Update README: switch data source to AkShare, unify the request field to `symbol`, optimize the cache mechanism, add performance logging

Browse files
- README.md +42 -15
- app.py +50 -15
- data_fetcher.py +103 -38
- predictor.py +63 -14
- requirements.txt +1 -1
README.md
CHANGED
@@ -11,12 +11,22 @@ pinned: false
 
 A REST API for probabilistic A-share forecasting, based on the [Tsinghua Kronos financial K-line foundation model](https://arxiv.org/abs/2508.02739).
 
-- **Data source**: …
-- **Inference**: Monte Carlo …
+- **Data source**: AkShare Eastmoney A-share daily bars, forward-adjusted (qfq)
+- **Inference**: batched Monte Carlo sampling; outputs predicted direction, confidence, and a 95% trading range
 - **Async tasks**: POST submit → returns `task_id` → GET poll for the result
 
 ---
 
+## Update notes (2026-03-16)
+
+- Primary request/response field unified from `ts_code` to `symbol`
+- Data source switched to the AkShare Eastmoney daily endpoint (`stock_zh_a_hist`, forward-adjusted `qfq`)
+- Direction probability `direction.probability` is defined as "bullish probability within the prediction horizon (0–1)"
+- Inference now samples in batches (batch size adjustable via `MC_BATCH_SIZE`)
+- Added per-stage timing logs (`fetch/calendar/infer/build/cache/total`)
+
+---
+
 ## Model info
 
 | Item | Value |

@@ -38,7 +48,7 @@ pinned: false
 curl -X POST "https://yingfeng64-kronos-api.hf.space/api/v1/predict" \
   -H "Content-Type: application/json" \
   -d '{
-    "…
+    "symbol": "000063.SZ",
     "lookback": 512,
     "pred_len": 5,
     "sample_count": 30,

@@ -66,7 +76,7 @@ import time, requests
 BASE = "https://yingfeng64-kronos-api.hf.space"
 
 resp = requests.post(f"{BASE}/api/v1/predict", json={
-    "…
+    "symbol": "000063.SZ",
     "lookback": 512,
     "pred_len": 5,
     "sample_count": 30,

@@ -92,7 +102,7 @@ print(r["result"])
 
 | Field | Type | Default | Range | Description |
 |---|---|---|---|---|
-| `…
+| `symbol` | string | — | — | A-share code; accepts `"603777"` or `"000063.SZ"` |
 | `lookback` | int | 512 | 20–512 | Number of historical K-line bars to look back |
 | `pred_len` | int | 5 | 1–60 | Number of future trading days to predict (≤ 30 recommended) |
 | `sample_count` | int | 30 | 1–100 | Number of Monte Carlo (MC) samples |

@@ -103,7 +113,7 @@ print(r["result"])
 
 ```json
 {
-  "…
+  "symbol": "000063.SZ",
   "base_date": "2026-03-13",
   "pred_len": 5,
   "confidence": 95,

@@ -176,9 +186,9 @@ curl -X POST "https://yingfeng64-kronos-api.hf.space/api/v1/predict/batch" \
   -H "Content-Type: application/json" \
   -d '{
     "requests": [
-      {"…
-      {"…
-      {"…
+      {"symbol": "000063", "pred_len": 5, "sample_count": 30},
+      {"symbol": "600900", "pred_len": 5, "sample_count": 30},
+      {"symbol": "000001", "pred_len": 5, "sample_count": 30}
     ]
   }'
 ```

@@ -227,11 +237,11 @@ curl -X POST "https://yingfeng64-kronos-api.hf.space/api/v1/predict/batch" \
 
 | Parameter | Description |
 |---|---|
-| `…
+| `symbol` (optional) | Return only that stock's cache entries; omit to return all |
 
 ```bash
 # Query one stock
-curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache?…
+curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache?symbol=000063"
 
 # Query all
 curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"

@@ -242,7 +252,7 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
   "count": 1,
   "entries": [
     {
-      "…
+      "symbol": "000063",
       "lookback": 512,
       "pred_len": 5,
       "sample_count": 30,

@@ -271,8 +281,8 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
 | Field | Meaning |
 |---|---|
 | `base_date` | Date of the last historical K-line bar the prediction is based on |
-| `direction.signal` | `"bullish"` / `"bearish"`, …
-| `direction.probability` | Bullish probability (0–1) |
+| `direction.signal` | `"bullish"` / `"bearish"`, decided by `direction.probability >= 0.5` |
+| `direction.probability` | Bullish probability within the prediction horizon (0–1) |
 | `trading_low` | q2.5 quantile of that day's predicted low (lower edge of the 95% trading range) |
 | `trading_high` | q97.5 quantile of that day's predicted high (upper edge of the 95% trading range) |
 | `uncertainty` | `(trading_high − trading_low) / last_close`, dimensionless uncertainty |

@@ -284,7 +294,7 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
 
 ## Cache mechanism
 
-The cache key is the `(…
+The cache key is the six-tuple `(symbol, lookback, pred_len, sample_count, mode, include_volume)`; entries expire at the next A-share trading-day close (15:00 CST).
 
 | Request time (CST) | Cache expiry |
 |---|---|

@@ -297,6 +307,23 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
 
 ---
 
+## Runtime configuration
+
+| Env var | Default | Description |
+|---|---|---|
+| `KRONOS_DIR` | `/app/Kronos` | Kronos source directory |
+| `MC_BATCH_SIZE` | `8` | Monte Carlo sampling batch size (larger is usually faster but uses more GPU/host memory) |
+
+---
+
+## Observability
+
+The service logs per-stage prediction timings at `INFO`, for example:
+
+```text
+Task <task_id> timing symbol=300065.SZ fetch=...ms calendar=...ms infer=...ms build=...ms cache=...ms total=...ms
+```
+
 ## Performance reference
 
 | Environment | Single sample | 30-sample MC total |
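The response-field definitions above can be exercised with toy numbers. This is a hypothetical sketch, not a real API response: the values are invented, and the `direction.signal` rule follows the `>= 0.5` threshold stated in the field table.

```python
# Hypothetical band values (not from a real response) showing how the
# documented output fields relate to each other.
trading_low, trading_high = 9.80, 10.60   # 95% trading range edges for one day
probability = 0.62                        # direction.probability
last_close = 10.20                        # close of the last historical bar

# direction.signal per the >= 0.5 rule in the field table
signal = "bullish" if probability >= 0.5 else "bearish"
# dimensionless uncertainty per the field table
uncertainty = (trading_high - trading_low) / last_close

print(signal)                   # → bullish
print(round(uncertainty, 4))    # → 0.0784
```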
app.py
CHANGED
@@ -15,6 +15,7 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
 from datetime import datetime, time, timedelta, timezone
+from time import perf_counter
 from typing import Literal, List
 
 import pandas as pd

@@ -38,8 +39,8 @@ def _next_cache_expiry() -> datetime:
     Return the UTC datetime of the NEXT A-share market close (15:00 CST on a
     weekday), which is when new candle data becomes available and the cache
     should be invalidated.
-    Chinese public holidays are intentionally ignored: on those days …
-    …
+    Chinese public holidays are intentionally ignored: on those days market
+    data does not advance, so a cache hit is harmless.
     """
     now_cst = datetime.now(_CST)
     today_close = now_cst.replace(hour=15, minute=0, second=0, microsecond=0)

@@ -58,13 +59,13 @@ def _next_cache_expiry() -> datetime:
 
 
 # ── Result cache ──────────────────────────────────────────────────────────────
-# key   : (…
+# key   : (symbol, lookback, pred_len, sample_count, mode, include_volume)
 # value : {"result": dict, "expires_at": datetime(UTC), "cached_at": datetime(UTC)}
 _cache: dict[tuple, dict] = {}
 
 
 def _cache_key(req: "PredictRequest") -> tuple:
-    return (req.…
+    return (req.symbol, req.lookback, req.pred_len,
             req.sample_count, req.mode, req.include_volume)

@@ -84,7 +85,7 @@ def _set_cache(req: "PredictRequest", result: dict) -> None:
     }
     logger.info(
         "Cached %s, expires at %s CST",
-        req.…
+        req.symbol,
         _cache[_cache_key(req)]["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
     )

@@ -124,7 +125,11 @@ app.add_middleware(
 
 # ── Request / Response schemas ────────────────────────────────────────────────
 class PredictRequest(BaseModel):
-    …
+    symbol: str = Field(
+        ...,
+        examples=["603777", "600900.SH"],
+        description="A-share code; accepts a 6-digit code or one with a market suffix (e.g. 600900.SH)",
+    )
     lookback: int = Field(
         default=512,
         ge=20,

@@ -180,7 +185,7 @@ def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
         bands.append(band)
 
     result: dict = {
-        "…
+        "symbol": req.symbol,
         "base_date": base_date,
         "pred_len": req.pred_len,
         "confidence": 95,

@@ -214,12 +219,18 @@ def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
 
 # ── Background task ───────────────────────────────────────────────────────────
 def _run_prediction(task_id: str, req: PredictRequest) -> None:
+    t_total_start = perf_counter()
     try:
         # ── Cache check ───────────────────────────────────────────────────────
         cache_entry = _get_cached(req)
         if cache_entry is not None:
-            …
-            …
+            total_ms = (perf_counter() - t_total_start) * 1000
+            logger.info(
+                "Cache hit for %s (expires %s CST, total=%.1fms)",
+                req.symbol,
+                cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
+                total_ms,
+            )
             _tasks[task_id] = {
                 "status": "done",
                 "result": {**cache_entry["result"], "cached": True,

@@ -229,26 +240,37 @@ def _run_prediction(task_id: str, req: PredictRequest) -> None:
             return
 
         # ── Full inference ────────────────────────────────────────────────────
+        t_fetch_start = perf_counter()
         x_df, x_timestamp, last_trade_date = data_fetcher.fetch_stock_data(
-            req.…
+            req.symbol, req.lookback
         )
+        fetch_ms = (perf_counter() - t_fetch_start) * 1000
+
+        t_calendar_start = perf_counter()
         y_timestamp = data_fetcher.get_future_trading_dates(last_trade_date, req.pred_len)
+        calendar_ms = (perf_counter() - t_calendar_start) * 1000
 
+        t_infer_start = perf_counter()
         pred_mean, ci, trading_low, trading_high, direction_prob, last_close = (
             pred_module.run_mc_prediction(
                 x_df, x_timestamp, y_timestamp, req.pred_len, req.sample_count
            )
         )
+        infer_ms = (perf_counter() - t_infer_start) * 1000
 
+        t_build_start = perf_counter()
         base_date = str(pd.to_datetime(last_trade_date, format="%Y%m%d").date())
         result = _build_response(
             req, base_date, pred_mean, ci,
             trading_low, trading_high, direction_prob, last_close, y_timestamp,
         )
+        build_ms = (perf_counter() - t_build_start) * 1000
 
         # ── Store in cache ────────────────────────────────────────────────────
+        t_cache_start = perf_counter()
         _set_cache(req, result)
         cache_entry = _cache[_cache_key(req)]
+        cache_ms = (perf_counter() - t_cache_start) * 1000
 
         _tasks[task_id] = {
             "status": "done",

@@ -256,8 +278,21 @@ def _run_prediction(task_id: str, req: PredictRequest) -> None:
             "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")},
             "error": None,
         }
+        total_ms = (perf_counter() - t_total_start) * 1000
+        logger.info(
+            "Task %s timing symbol=%s fetch=%.1fms calendar=%.1fms infer=%.1fms build=%.1fms cache=%.1fms total=%.1fms",
+            task_id,
+            req.symbol,
+            fetch_ms,
+            calendar_ms,
+            infer_ms,
+            build_ms,
+            cache_ms,
+            total_ms,
+        )
     except Exception as exc:
-        …
+        total_ms = (perf_counter() - t_total_start) * 1000
+        logger.exception("Task %s failed after %.1fms", task_id, total_ms)
         _tasks[task_id] = {"status": "failed", "result": None, "error": str(exc)}

@@ -297,22 +332,22 @@ async def get_predict_result(task_id: str):
 
 
 @app.get("/api/v1/cache", summary="View cache status")
-async def get_cache(…
+async def get_cache(symbol: str | None = None):
     """
     List valid cache entries and their expiry times.
 
     - No parameters: return all entries
-    - `?…
+    - `?symbol=000063.SZ`: return only that stock's parameter combinations
     """
     now_utc = datetime.now(timezone.utc)
     entries = []
     for key, entry in _cache.items():
-        if …
+        if symbol and key[0] != symbol:
             continue
         remaining = (entry["expires_at"] - now_utc).total_seconds()
         if remaining > 0:
             entries.append({
-                "…
+                "symbol": key[0],
                 "lookback": key[1],
                 "pred_len": key[2],
                 "sample_count": key[3],
data_fetcher.py
CHANGED
@@ -1,19 +1,66 @@
-import os
 from datetime import datetime, timedelta
+import threading
 
+import akshare as ak
 import pandas as pd
-import tushare as ts
 
-…
-…
-…)
+_TRADE_CALENDAR_CACHE: pd.DatetimeIndex | None = None
+_TRADE_CALENDAR_CACHED_AT: datetime | None = None
+_TRADE_CALENDAR_CACHE_TTL = timedelta(hours=12)
+_TRADE_CALENDAR_LOCK = threading.Lock()
 
-…
-…
+
+def _normalize_symbol(raw_symbol: str) -> str:
+    """
+    Convert user input into the 6-digit stock code expected by
+    `ak.stock_zh_a_hist`.
+
+    Accepted examples:
+      - "603777"
+      - "600900.SH"
+      - "000063.SZ"
+    """
+    symbol = raw_symbol.strip().upper()
+    if "." in symbol:
+        symbol = symbol.split(".", 1)[0]
+    if len(symbol) != 6 or not symbol.isdigit():
+        raise ValueError(
+            f"Invalid stock code {raw_symbol!r}; expected 6 digits like '603777' "
+            "or Tushare-style code like '600900.SH'."
+        )
+    return symbol
+
+
+def _get_trade_calendar_cached() -> pd.DatetimeIndex:
+    """
+    Fetch and cache exchange trading dates in-process to avoid repeated
+    network calls on each request.
+    """
+    global _TRADE_CALENDAR_CACHE, _TRADE_CALENDAR_CACHED_AT
+
+    now = datetime.now()
+    with _TRADE_CALENDAR_LOCK:
+        if (
+            _TRADE_CALENDAR_CACHE is not None
+            and _TRADE_CALENDAR_CACHED_AT is not None
+            and (now - _TRADE_CALENDAR_CACHED_AT) < _TRADE_CALENDAR_CACHE_TTL
+        ):
+            return _TRADE_CALENDAR_CACHE
+
+    cal = ak.tool_trade_date_hist_sina()
+    cal_col = "trade_date" if "trade_date" in cal.columns else "日期"
+    all_dates = pd.to_datetime(cal[cal_col]).sort_values().drop_duplicates()
+    cached = pd.DatetimeIndex(all_dates)
+
+    with _TRADE_CALENDAR_LOCK:
+        _TRADE_CALENDAR_CACHE = cached
+        _TRADE_CALENDAR_CACHED_AT = now
+
+    return cached
 
 
 def fetch_stock_data(
-    …
+    symbol: str, lookback: int
 ) -> tuple[pd.DataFrame, pd.Series, str]:
     """
     Returns:

@@ -21,31 +68,51 @@ def fetch_stock_data(
         x_timestamp    : pd.Series[datetime], aligned to x_df
         last_trade_date: str "YYYYMMDD", the most recent bar date
     """
+    normalized_symbol = _normalize_symbol(symbol)
     end_date = datetime.today().strftime("%Y%m%d")
-    # …
-    start_date = (datetime.today() - timedelta(days=lookback * …
+    # 4x buffer to account for weekends/long holidays.
+    start_date = (datetime.today() - timedelta(days=lookback * 4)).strftime("%Y%m%d")
 
-    df = …
-    …
-    …
+    df = ak.stock_zh_a_hist(
+        symbol=normalized_symbol,
+        period="daily",
         start_date=start_date,
         end_date=end_date,
-        …
+        adjust="qfq",
     )
 
     if df is None or df.empty:
-        raise ValueError(f"No data returned for …
-        …
+        raise ValueError(f"No data returned for symbol={symbol!r}")
+
+    df = df.rename(
+        columns={
+            "日期": "trade_date",
+            "开盘": "open",
+            "最高": "high",
+            "最低": "low",
+            "收盘": "close",
+            "成交量": "volume",
+            "成交额": "amount",
+        }
+    )
+    required_cols = ["trade_date", "open", "high", "low", "close", "volume", "amount"]
+    missing = [c for c in required_cols if c not in df.columns]
+    if missing:
+        raise ValueError(f"AkShare response missing columns: {missing}")
+
+    df["trade_date"] = pd.to_datetime(df["trade_date"])
+    for col in ["open", "high", "low", "close", "volume", "amount"]:
+        df[col] = pd.to_numeric(df[col], errors="coerce")
+    df = df.dropna(subset=["trade_date", "open", "high", "low", "close", "volume", "amount"])
     df = df.sort_values("trade_date").reset_index(drop=True)
-    df = df…
-    df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
+    df["timestamps"] = df["trade_date"]
 
     # Keep the most recent `lookback` bars
     df = df.tail(lookback).reset_index(drop=True)
 
     x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
     x_timestamp = df["timestamps"].copy()
-    last_trade_date = df["trade_date"].iloc[-1]
+    last_trade_date = df["trade_date"].iloc[-1].strftime("%Y%m%d")
 
     return x_df, x_timestamp, last_trade_date

@@ -56,22 +123,20 @@ def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
     follow `last_trade_date` (format: YYYYMMDD).
     """
     last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
-    …
-    return pd.Series(dates)
+    dates: list[pd.Timestamp] = []
+
+    # Prefer real exchange trade dates from AkShare.
+    try:
+        all_dates = _get_trade_calendar_cached()
+        dates.extend([d for d in all_dates if d > pd.Timestamp(last_dt)][:pred_len])
+    except Exception:
+        # If calendar fetch fails, fall back to weekday-based dates.
+        pass
+
+    candidate = last_dt + timedelta(days=1)
+    while len(dates) < pred_len:
+        if candidate.weekday() < 5:
+            dates.append(pd.Timestamp(candidate))
+        candidate += timedelta(days=1)
+
+    return pd.Series(pd.DatetimeIndex(dates[:pred_len]))
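The `_normalize_symbol` rule added above (strip an optional market suffix, keep a validated 6-digit code) can be exercised in isolation. A minimal standalone re-statement of the same logic, for illustration:

```python
def normalize_symbol(raw: str) -> str:
    """Strip an optional '.SH'/'.SZ' style suffix; return the 6-digit code."""
    code = raw.strip().upper().split(".", 1)[0]
    if len(code) != 6 or not code.isdigit():
        raise ValueError(f"Invalid stock code {raw!r}")
    return code

print(normalize_symbol("600900.SH"))  # → 600900
print(normalize_symbol(" 000063.sz "))  # → 000063
```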
predictor.py
CHANGED
@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
 KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos")
 MODEL_ID = "NeoQuasar/Kronos-base"
 TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
+MC_BATCH_SIZE = max(1, int(os.environ.get("MC_BATCH_SIZE", "8")))
 
 
 # ── Bootstrap Kronos source ──────────────────────────────────────────────────

@@ -66,6 +67,37 @@ def get_predictor() -> KronosPredictor:
     return _predictor
 
 
+def _split_batched_output(
+    pred_output,
+    expected_count: int,
+    pred_len: int,
+) -> list[pd.DataFrame]:
+    """
+    Normalize predictor output into `expected_count` DataFrame samples.
+    Supports single-sample DataFrame and common batched return shapes.
+    """
+    if isinstance(pred_output, pd.DataFrame):
+        if expected_count == 1:
+            return [pred_output]
+        if isinstance(pred_output.index, pd.MultiIndex):
+            grouped = [g.droplevel(0) for _, g in pred_output.groupby(level=0, sort=False)]
+            if len(grouped) == expected_count:
+                return grouped
+        if len(pred_output) == expected_count * pred_len:
+            return [
+                pred_output.iloc[i * pred_len:(i + 1) * pred_len].copy()
+                for i in range(expected_count)
+            ]
+    if isinstance(pred_output, (list, tuple)):
+        if len(pred_output) == expected_count and all(
+            isinstance(item, pd.DataFrame) for item in pred_output
+        ):
+            return list(pred_output)
+        if expected_count == 1 and len(pred_output) == 1 and isinstance(pred_output[0], pd.DataFrame):
+            return [pred_output[0]]
+    raise ValueError("Unsupported predict() output format for batched sampling")
+
+
 # ── Monte-Carlo prediction ────────────────────────────────────────────────────
 def run_mc_prediction(
     x_df: pd.DataFrame,

@@ -83,45 +115,62 @@
     ci           : dict[field]["low"/"high"] → ndarray(pred_len,), 95% CI
     trading_low  : ndarray(pred_len,), q2.5 of predicted_low
     trading_high : ndarray(pred_len,), q97.5 of predicted_high
-    direction_prob : float ∈ [0,1], …
+    direction_prob : float ∈ [0,1], horizon-level bullish probability
     last_close   : float, closing price of the last historical bar
     """
     predictor = get_predictor()
-    samples = []
+    samples: list[pd.DataFrame] = []
+    supports_batched_sampling = True
+    remaining = sample_count
 
-    …
+    while remaining > 0:
+        batch_n = min(remaining, MC_BATCH_SIZE if supports_batched_sampling else 1)
         with _infer_lock:
-            …
+            pred_output = predictor.predict(
                 df=x_df,
                 x_timestamp=x_timestamp,
                 y_timestamp=y_timestamp,
                 pred_len=pred_len,
                 T=0.8,
                 top_p=0.9,
-                sample_count=…
+                sample_count=batch_n,
                 verbose=False,
             )
-        …
+        try:
+            batch_samples = _split_batched_output(pred_output, batch_n, pred_len)
+        except ValueError:
+            if batch_n > 1:
+                # Fallback for predictor implementations that do not support
+                # returning per-sample outputs for sample_count>1.
+                supports_batched_sampling = False
+                continue
+            raise
+        samples.extend(batch_samples)
+        remaining -= batch_n
 
     pred_mean = pd.concat(samples).groupby(level=0).mean()
-    …
-    …
-    …
+    stacked = {
+        field: np.stack([s[field].values for s in samples])  # (sample_count, pred_len)
+        for field in ["open", "high", "low", "close", "volume"]
+    }
 
     alpha = 2.5  # → 95 % CI
     ci = {
         field: {
-            "low": np.percentile(…
-            "high": np.percentile(…
+            "low": np.percentile(stacked[field], alpha, axis=0),
+            "high": np.percentile(stacked[field], 100 - alpha, axis=0),
         }
-        for field in …
+        for field in stacked
     }
 
     trading_low = ci["low"]["low"]    # q2.5 of the predicted daily low
     trading_high = ci["high"]["high"] # q97.5 of the predicted daily high
 
     last_close = float(x_df["close"].iloc[-1])
-    …
-    …
+    close_paths = stacked["close"]  # (sample_count, pred_len)
+    # Use all future points to estimate horizon bullish probability.
+    bull_count = int((close_paths > last_close).sum())
+    total_points = int(close_paths.size)
+    direction_prob = bull_count / total_points
 
     return pred_mean, ci, trading_low, trading_high, direction_prob, last_close
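The confidence-interval and direction-probability math in `run_mc_prediction` reduces to `np.percentile` over the sample axis plus a threshold count. A toy sketch with synthetic close paths standing in for Monte Carlo samples (shapes and seed are illustrative, not the model's output):

```python
import numpy as np

rng = np.random.default_rng(0)
sample_count, pred_len = 30, 5
last_close = 10.0

# Synthetic "close" paths: random walks starting from last_close.
close_paths = last_close + rng.normal(0, 0.2, size=(sample_count, pred_len)).cumsum(axis=1)

alpha = 2.5  # → 95% CI, as in run_mc_prediction
low = np.percentile(close_paths, alpha, axis=0)          # (pred_len,)
high = np.percentile(close_paths, 100 - alpha, axis=0)   # (pred_len,)

# Fraction of all future points above last_close → horizon bullish probability.
direction_prob = (close_paths > last_close).sum() / close_paths.size

print(low.shape == (pred_len,))          # → True
print(0.0 <= direction_prob <= 1.0)      # → True
```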
requirements.txt
CHANGED
@@ -9,4 +9,4 @@ huggingface_hub==0.33.1
 matplotlib==3.9.3
 tqdm==4.67.1
 safetensors==0.6.2
-…
+akshare