Spaces:
Running
Running
fengwm commited on
Commit ·
23022b9
1
Parent(s): 72f5e10
更新 README 和代码,数据源切换为 Tushare,优化请求重试机制
Browse files- README.md +5 -2
- data_fetcher.py +89 -109
- requirements.txt +1 -1
README.md
CHANGED
|
@@ -11,7 +11,7 @@ pinned: false
|
|
| 11 |
|
| 12 |
基于[清华大学 Kronos 金融 K 线基础大模型](https://arxiv.org/abs/2508.02739)的 A 股概率预测 REST API。
|
| 13 |
|
| 14 |
-
- **数据源**:
|
| 15 |
- **推理方式**:蒙特卡洛分批采样,输出预测方向、置信度及 95% 交易区间
|
| 16 |
- **异步任务**:POST 提交 → 返回 `task_id` → GET 轮询结果
|
| 17 |
|
|
@@ -20,7 +20,7 @@ pinned: false
|
|
| 20 |
## 更新说明(2026-03-16)
|
| 21 |
|
| 22 |
- 请求/响应主字段由 `ts_code` 统一为 `symbol`
|
| 23 |
-
- 数据源切换
|
| 24 |
- 方向概率 `direction.probability` 定义为“预测区间内看涨概率(0–1)”
|
| 25 |
- 推理路径改为分批采样(可通过 `MC_BATCH_SIZE` 调整批大小)
|
| 26 |
- 新增阶段耗时日志(`fetch/calendar/infer/build/cache/total`)
|
|
@@ -311,8 +311,11 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
|
|
| 311 |
|
| 312 |
| 环境变量 | 默认值 | 说明 |
|
| 313 |
|---|---|---|
|
|
|
|
| 314 |
| `KRONOS_DIR` | `/app/Kronos` | Kronos 源码目录 |
|
| 315 |
| `MC_BATCH_SIZE` | `8` | 蒙特卡洛分批采样大小(越大通常越快,但占用显存/内存更高) |
|
|
|
|
|
|
|
| 316 |
|
| 317 |
---
|
| 318 |
|
|
|
|
| 11 |
|
| 12 |
基于[清华大学 Kronos 金融 K 线基础大模型](https://arxiv.org/abs/2508.02739)的 A 股概率预测 REST API。
|
| 13 |
|
| 14 |
+
- **数据源**:Tushare Pro A 股日线,前复权(qfq)
|
| 15 |
- **推理方式**:蒙特卡洛分批采样,输出预测方向、置信度及 95% 交易区间
|
| 16 |
- **异步任务**:POST 提交 → 返回 `task_id` → GET 轮询结果
|
| 17 |
|
|
|
|
| 20 |
## 更新说明(2026-03-16)
|
| 21 |
|
| 22 |
- 请求/响应主字段由 `ts_code` 统一为 `symbol`
|
| 23 |
+
- 数据源切换回 Tushare Pro(`pro_bar`,前复权 `qfq`)
|
| 24 |
- 方向概率 `direction.probability` 定义为“预测区间内看涨概率(0–1)”
|
| 25 |
- 推理路径改为分批采样(可通过 `MC_BATCH_SIZE` 调整批大小)
|
| 26 |
- 新增阶段耗时日志(`fetch/calendar/infer/build/cache/total`)
|
|
|
|
| 311 |
|
| 312 |
| 环境变量 | 默认值 | 说明 |
|
| 313 |
|---|---|---|
|
| 314 |
+
| `TUSHARE_TOKEN` | — | Tushare 访问令牌(必填) |
|
| 315 |
| `KRONOS_DIR` | `/app/Kronos` | Kronos 源码目录 |
|
| 316 |
| `MC_BATCH_SIZE` | `8` | 蒙特卡洛分批采样大小(越大通常越快,但占用显存/内存更高) |
|
| 317 |
+
| `TS_RETRY_COUNT` | `3` | Tushare 请求重试次数 |
|
| 318 |
+
| `TS_RETRY_BASE_SLEEP` | `0.8` | Tushare 重试退避基准秒数(指数退避) |
|
| 319 |
|
| 320 |
---
|
| 321 |
|
data_fetcher.py
CHANGED
|
@@ -3,88 +3,80 @@ import os
|
|
| 3 |
import threading
|
| 4 |
import time
|
| 5 |
|
| 6 |
-
import akshare as ak
|
| 7 |
import pandas as pd
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
_TRADE_CALENDAR_LOCK = threading.Lock()
|
| 13 |
-
_AK_RETRY_COUNT = max(1, int(os.environ.get("AK_RETRY_COUNT", "3")))
|
| 14 |
-
_AK_RETRY_BASE_SLEEP = float(os.environ.get("AK_RETRY_BASE_SLEEP", "0.8"))
|
| 15 |
-
_AK_TIMEOUT = float(os.environ.get("AK_TIMEOUT", "15"))
|
| 16 |
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
-
|
| 21 |
-
`ak.stock_zh_a_hist`.
|
| 22 |
|
| 23 |
Accepted examples:
|
| 24 |
-
- "603777"
|
| 25 |
-
- "
|
| 26 |
-
- "
|
|
|
|
| 27 |
"""
|
| 28 |
symbol = raw_symbol.strip().upper()
|
| 29 |
if "." in symbol:
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
if len(symbol) != 6 or not symbol.isdigit():
|
| 32 |
raise ValueError(
|
| 33 |
-
f"Invalid
|
| 34 |
-
"or Tushare-style code like '600900.SH'."
|
| 35 |
)
|
| 36 |
-
return symbol
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
last_exc: Exception | None = None
|
| 44 |
-
for attempt in range(1,
|
| 45 |
try:
|
| 46 |
return fn()
|
| 47 |
except Exception as exc: # pragma: no cover - external IO
|
| 48 |
last_exc = exc
|
| 49 |
-
if attempt >=
|
| 50 |
break
|
| 51 |
-
time.sleep(
|
| 52 |
raise RuntimeError(
|
| 53 |
-
f"
|
| 54 |
) from last_exc
|
| 55 |
|
| 56 |
|
| 57 |
-
def _get_trade_calendar_cached() -> pd.DatetimeIndex:
|
| 58 |
-
"""
|
| 59 |
-
Fetch and cache exchange trading dates in-process to avoid repeated
|
| 60 |
-
network calls on each request.
|
| 61 |
-
"""
|
| 62 |
-
global _TRADE_CALENDAR_CACHE, _TRADE_CALENDAR_CACHED_AT
|
| 63 |
-
|
| 64 |
-
now = datetime.now()
|
| 65 |
-
with _TRADE_CALENDAR_LOCK:
|
| 66 |
-
if (
|
| 67 |
-
_TRADE_CALENDAR_CACHE is not None
|
| 68 |
-
and _TRADE_CALENDAR_CACHED_AT is not None
|
| 69 |
-
and (now - _TRADE_CALENDAR_CACHED_AT) < _TRADE_CALENDAR_CACHE_TTL
|
| 70 |
-
):
|
| 71 |
-
return _TRADE_CALENDAR_CACHE
|
| 72 |
-
|
| 73 |
-
cal = _retry_ak_call(
|
| 74 |
-
lambda: ak.tool_trade_date_hist_sina(),
|
| 75 |
-
call_name="tool_trade_date_hist_sina",
|
| 76 |
-
)
|
| 77 |
-
cal_col = "trade_date" if "trade_date" in cal.columns else "日期"
|
| 78 |
-
all_dates = pd.to_datetime(cal[cal_col]).sort_values().drop_duplicates()
|
| 79 |
-
cached = pd.DatetimeIndex(all_dates)
|
| 80 |
-
|
| 81 |
-
with _TRADE_CALENDAR_LOCK:
|
| 82 |
-
_TRADE_CALENDAR_CACHE = cached
|
| 83 |
-
_TRADE_CALENDAR_CACHED_AT = now
|
| 84 |
-
|
| 85 |
-
return cached
|
| 86 |
-
|
| 87 |
-
|
| 88 |
def fetch_stock_data(
|
| 89 |
symbol: str, lookback: int
|
| 90 |
) -> tuple[pd.DataFrame, pd.Series, str]:
|
|
@@ -94,55 +86,36 @@ def fetch_stock_data(
|
|
| 94 |
x_timestamp : pd.Series[datetime], aligned to x_df
|
| 95 |
last_trade_date: str "YYYYMMDD", the most recent bar date
|
| 96 |
"""
|
| 97 |
-
|
| 98 |
end_date = datetime.today().strftime("%Y%m%d")
|
| 99 |
-
#
|
| 100 |
-
start_date = (datetime.today() - timedelta(days=lookback *
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
| 106 |
start_date=start_date,
|
| 107 |
end_date=end_date,
|
| 108 |
-
|
| 109 |
-
timeout=_AK_TIMEOUT,
|
| 110 |
),
|
| 111 |
-
call_name=f"
|
| 112 |
)
|
| 113 |
|
| 114 |
if df is None or df.empty:
|
| 115 |
-
raise ValueError(f"No data returned for symbol={symbol!r}")
|
| 116 |
-
|
| 117 |
-
df = df.rename(
|
| 118 |
-
columns={
|
| 119 |
-
"日期": "trade_date",
|
| 120 |
-
"开盘": "open",
|
| 121 |
-
"最高": "high",
|
| 122 |
-
"最低": "low",
|
| 123 |
-
"收盘": "close",
|
| 124 |
-
"成交量": "volume",
|
| 125 |
-
"成交额": "amount",
|
| 126 |
-
}
|
| 127 |
-
)
|
| 128 |
-
required_cols = ["trade_date", "open", "high", "low", "close", "volume", "amount"]
|
| 129 |
-
missing = [c for c in required_cols if c not in df.columns]
|
| 130 |
-
if missing:
|
| 131 |
-
raise ValueError(f"AkShare response missing columns: {missing}")
|
| 132 |
-
|
| 133 |
-
df["trade_date"] = pd.to_datetime(df["trade_date"])
|
| 134 |
-
for col in ["open", "high", "low", "close", "volume", "amount"]:
|
| 135 |
-
df[col] = pd.to_numeric(df[col], errors="coerce")
|
| 136 |
-
df = df.dropna(subset=["trade_date", "open", "high", "low", "close", "volume", "amount"])
|
| 137 |
df = df.sort_values("trade_date").reset_index(drop=True)
|
| 138 |
-
df
|
|
|
|
| 139 |
|
| 140 |
# Keep the most recent `lookback` bars
|
| 141 |
df = df.tail(lookback).reset_index(drop=True)
|
| 142 |
|
| 143 |
x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
|
| 144 |
x_timestamp = df["timestamps"].copy()
|
| 145 |
-
last_trade_date = df["trade_date"].iloc[-1]
|
| 146 |
|
| 147 |
return x_df, x_timestamp, last_trade_date
|
| 148 |
|
|
@@ -153,20 +126,27 @@ def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
|
|
| 153 |
follow `last_trade_date` (format: YYYYMMDD).
|
| 154 |
"""
|
| 155 |
last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
|
| 156 |
-
dates
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import threading
|
| 4 |
import time
|
| 5 |
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
+
import tushare as ts
|
| 8 |
|
| 9 |
+
_TUSHARE_TOKEN = os.environ.get("TUSHARE_TOKEN", "").strip()
|
| 10 |
+
_TS_RETRY_COUNT = max(1, int(os.environ.get("TS_RETRY_COUNT", "3")))
|
| 11 |
+
_TS_RETRY_BASE_SLEEP = float(os.environ.get("TS_RETRY_BASE_SLEEP", "0.8"))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
_PRO = None
|
| 14 |
+
_PRO_LOCK = threading.Lock()
|
| 15 |
|
| 16 |
+
|
| 17 |
+
def _get_pro():
    """Return the process-wide Tushare pro-api client, creating it on first use.

    The whole check-and-create sequence runs under ``_PRO_LOCK``, so
    concurrent callers always share a single client instance.

    Raises:
        RuntimeError: if the ``TUSHARE_TOKEN`` environment variable was empty.
    """
    global _PRO
    with _PRO_LOCK:
        if _PRO is None:
            # First caller initializes the client; later callers reuse it.
            if not _TUSHARE_TOKEN:
                raise RuntimeError("TUSHARE_TOKEN is not set")
            ts.set_token(_TUSHARE_TOKEN)
            _PRO = ts.pro_api()
        return _PRO
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _normalize_to_ts_code(raw_symbol: str) -> str:
|
| 30 |
"""
|
| 31 |
+
Normalize user `symbol` input to Tushare `ts_code`.
|
|
|
|
| 32 |
|
| 33 |
Accepted examples:
|
| 34 |
+
- "603777" -> "603777.SH"
|
| 35 |
+
- "300065" -> "300065.SZ"
|
| 36 |
+
- "430047" -> "430047.BJ"
|
| 37 |
+
- "000063.SZ" -> "000063.SZ"
|
| 38 |
"""
|
| 39 |
symbol = raw_symbol.strip().upper()
|
| 40 |
if "." in symbol:
|
| 41 |
+
code, market = symbol.split(".", 1)
|
| 42 |
+
if len(code) != 6 or not code.isdigit() or market not in {"SH", "SZ", "BJ"}:
|
| 43 |
+
raise ValueError(
|
| 44 |
+
f"Invalid symbol {raw_symbol!r}; expected e.g. '603777' or '000063.SZ'."
|
| 45 |
+
)
|
| 46 |
+
return f"{code}.{market}"
|
| 47 |
+
|
| 48 |
if len(symbol) != 6 or not symbol.isdigit():
|
| 49 |
raise ValueError(
|
| 50 |
+
f"Invalid symbol {raw_symbol!r}; expected 6 digits or code with suffix."
|
|
|
|
| 51 |
)
|
|
|
|
| 52 |
|
| 53 |
+
if symbol.startswith("6"):
|
| 54 |
+
market = "SH"
|
| 55 |
+
elif symbol.startswith(("0", "3")):
|
| 56 |
+
market = "SZ"
|
| 57 |
+
elif symbol.startswith(("4", "8")):
|
| 58 |
+
market = "BJ"
|
| 59 |
+
else:
|
| 60 |
+
raise ValueError(f"Cannot infer exchange suffix for symbol={raw_symbol!r}")
|
| 61 |
|
| 62 |
+
return f"{symbol}.{market}"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _retry_call(fn, *, call_name: str):
    """Invoke `fn`, retrying with exponential backoff on any exception.

    Makes up to ``_TS_RETRY_COUNT`` attempts, sleeping
    ``_TS_RETRY_BASE_SLEEP * 2**k`` after the k-th failure (no sleep
    after the final one). If every attempt fails, raises RuntimeError
    chained to the last exception.
    """
    failure: Exception | None = None
    for tries_done in range(_TS_RETRY_COUNT):
        try:
            return fn()
        except Exception as exc:  # pragma: no cover - external IO
            failure = exc
            # Back off only if another attempt remains.
            if tries_done + 1 < _TS_RETRY_COUNT:
                time.sleep(_TS_RETRY_BASE_SLEEP * (2 ** tries_done))
    raise RuntimeError(
        f"Tushare call failed after {_TS_RETRY_COUNT} attempts ({call_name}): {failure}"
    ) from failure
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def fetch_stock_data(
|
| 81 |
symbol: str, lookback: int
|
| 82 |
) -> tuple[pd.DataFrame, pd.Series, str]:
|
|
|
|
| 86 |
x_timestamp : pd.Series[datetime], aligned to x_df
|
| 87 |
last_trade_date: str "YYYYMMDD", the most recent bar date
|
| 88 |
"""
|
| 89 |
+
ts_code = _normalize_to_ts_code(symbol)
|
| 90 |
end_date = datetime.today().strftime("%Y%m%d")
|
| 91 |
+
# 2x buffer to account for weekends/holidays.
|
| 92 |
+
start_date = (datetime.today() - timedelta(days=lookback * 2)).strftime("%Y%m%d")
|
| 93 |
+
|
| 94 |
+
pro = _get_pro()
|
| 95 |
+
df = _retry_call(
|
| 96 |
+
lambda: ts.pro_bar(
|
| 97 |
+
ts_code=ts_code,
|
| 98 |
+
adj="qfq",
|
| 99 |
start_date=start_date,
|
| 100 |
end_date=end_date,
|
| 101 |
+
asset="E",
|
|
|
|
| 102 |
),
|
| 103 |
+
call_name=f"pro_bar(ts_code={ts_code})",
|
| 104 |
)
|
| 105 |
|
| 106 |
if df is None or df.empty:
|
| 107 |
+
raise ValueError(f"No data returned for symbol={symbol!r} (ts_code={ts_code})")
|
| 108 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
df = df.sort_values("trade_date").reset_index(drop=True)
|
| 110 |
+
df = df.rename(columns={"vol": "volume"})
|
| 111 |
+
df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
|
| 112 |
|
| 113 |
# Keep the most recent `lookback` bars
|
| 114 |
df = df.tail(lookback).reset_index(drop=True)
|
| 115 |
|
| 116 |
x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
|
| 117 |
x_timestamp = df["timestamps"].copy()
|
| 118 |
+
last_trade_date = str(df["trade_date"].iloc[-1])
|
| 119 |
|
| 120 |
return x_df, x_timestamp, last_trade_date
|
| 121 |
|
|
|
|
| 126 |
follow `last_trade_date` (format: YYYYMMDD).
|
| 127 |
"""
|
| 128 |
last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
|
| 129 |
+
# 3x buffer so we always have enough dates even over a long holiday
|
| 130 |
+
end_dt = last_dt + timedelta(days=pred_len * 3)
|
| 131 |
+
|
| 132 |
+
pro = _get_pro()
|
| 133 |
+
cal = _retry_call(
|
| 134 |
+
lambda: pro.trade_cal(
|
| 135 |
+
exchange="SSE",
|
| 136 |
+
start_date=(last_dt + timedelta(days=1)).strftime("%Y%m%d"),
|
| 137 |
+
end_date=end_dt.strftime("%Y%m%d"),
|
| 138 |
+
is_open="1",
|
| 139 |
+
),
|
| 140 |
+
call_name="trade_cal(exchange=SSE)",
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
cal = cal.sort_values("cal_date")
|
| 144 |
+
dates = pd.to_datetime(cal["cal_date"].values[:pred_len], format="%Y%m%d")
|
| 145 |
+
|
| 146 |
+
if len(dates) < pred_len:
|
| 147 |
+
raise ValueError(
|
| 148 |
+
f"Could only obtain {len(dates)} future trading dates; "
|
| 149 |
+
f"increase buffer or check Tushare calendar coverage."
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
return pd.Series(dates)
|
requirements.txt
CHANGED
|
@@ -9,4 +9,4 @@ huggingface_hub==0.33.1
|
|
| 9 |
matplotlib==3.9.3
|
| 10 |
tqdm==4.67.1
|
| 11 |
safetensors==0.6.2
|
| 12 |
-
|
|
|
|
| 9 |
matplotlib==3.9.3
|
| 10 |
tqdm==4.67.1
|
| 11 |
safetensors==0.6.2
|
| 12 |
+
tushare
|