fengwm commited on
Commit
23022b9
·
1 Parent(s): 72f5e10

更新 README 和代码,数据源切换为 Tushare,优化请求重试机制

Browse files
Files changed (3) hide show
  1. README.md +5 -2
  2. data_fetcher.py +89 -109
  3. requirements.txt +1 -1
README.md CHANGED
@@ -11,7 +11,7 @@ pinned: false
11
 
12
  基于[清华大学 Kronos 金融 K 线基础大模型](https://arxiv.org/abs/2508.02739)的 A 股概率预测 REST API。
13
 
14
- - **数据源**:AkShare 东方财富 A 股日线,前复权(qfq)
15
  - **推理方式**:蒙特卡洛分批采样,输出预测方向、置信度及 95% 交易区间
16
  - **异步任务**:POST 提交 → 返回 `task_id` → GET 轮询结果
17
 
@@ -20,7 +20,7 @@ pinned: false
20
  ## 更新说明(2026-03-16)
21
 
22
  - 请求/响应主字段由 `ts_code` 统一为 `symbol`
23
- - 数据源切换 AkShare 东财日线接口(`stock_zh_a_hist`,前复权 `qfq`)
24
  - 方向概率 `direction.probability` 定义为“预测区间内看涨概率(0–1)”
25
  - 推理路径改为分批采样(可通过 `MC_BATCH_SIZE` 调整批大小)
26
  - 新增阶段耗时日志(`fetch/calendar/infer/build/cache/total`)
@@ -311,8 +311,11 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
311
 
312
  | 环境变量 | 默认值 | 说明 |
313
  |---|---|---|
 
314
  | `KRONOS_DIR` | `/app/Kronos` | Kronos 源码目录 |
315
  | `MC_BATCH_SIZE` | `8` | 蒙特卡洛分批采样大小(越大通常越快,但占用显存/内存更高) |
 
 
316
 
317
  ---
318
 
 
11
 
12
  基于[清华大学 Kronos 金融 K 线基础大模型](https://arxiv.org/abs/2508.02739)的 A 股概率预测 REST API。
13
 
14
+ - **数据源**:Tushare Pro A 股日线,前复权(qfq)
15
  - **推理方式**:蒙特卡洛分批采样,输出预测方向、置信度及 95% 交易区间
16
  - **异步任务**:POST 提交 → 返回 `task_id` → GET 轮询结果
17
 
 
20
  ## 更新说明(2026-03-16)
21
 
22
  - 请求/响应主字段由 `ts_code` 统一为 `symbol`
23
+ - 数据源切换 Tushare Pro(`pro_bar`,前复权 `qfq`)
24
  - 方向概率 `direction.probability` 定义为“预测区间内看涨概率(0–1)”
25
  - 推理路径改为分批采样(可通过 `MC_BATCH_SIZE` 调整批大小)
26
  - 新增阶段耗时日志(`fetch/calendar/infer/build/cache/total`)
 
311
 
312
  | 环境变量 | 默认值 | 说明 |
313
  |---|---|---|
314
+ | `TUSHARE_TOKEN` | — | Tushare 访问令牌(必填) |
315
  | `KRONOS_DIR` | `/app/Kronos` | Kronos 源码目录 |
316
  | `MC_BATCH_SIZE` | `8` | 蒙特卡洛分批采样大小(越大通常越快,但占用显存/内存更高) |
317
+ | `TS_RETRY_COUNT` | `3` | Tushare 请求重试次数 |
318
+ | `TS_RETRY_BASE_SLEEP` | `0.8` | Tushare 重试退避基准秒数(指数退避) |
319
 
320
  ---
321
 
data_fetcher.py CHANGED
@@ -3,88 +3,80 @@ import os
3
  import threading
4
  import time
5
 
6
- import akshare as ak
7
  import pandas as pd
 
8
 
9
- _TRADE_CALENDAR_CACHE: pd.DatetimeIndex | None = None
10
- _TRADE_CALENDAR_CACHED_AT: datetime | None = None
11
- _TRADE_CALENDAR_CACHE_TTL = timedelta(hours=12)
12
- _TRADE_CALENDAR_LOCK = threading.Lock()
13
- _AK_RETRY_COUNT = max(1, int(os.environ.get("AK_RETRY_COUNT", "3")))
14
- _AK_RETRY_BASE_SLEEP = float(os.environ.get("AK_RETRY_BASE_SLEEP", "0.8"))
15
- _AK_TIMEOUT = float(os.environ.get("AK_TIMEOUT", "15"))
16
 
 
 
17
 
18
- def _normalize_symbol(raw_symbol: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  """
20
- Convert user input into the 6-digit stock code expected by
21
- `ak.stock_zh_a_hist`.
22
 
23
  Accepted examples:
24
- - "603777"
25
- - "600900.SH"
26
- - "000063.SZ"
 
27
  """
28
  symbol = raw_symbol.strip().upper()
29
  if "." in symbol:
30
- symbol = symbol.split(".", 1)[0]
 
 
 
 
 
 
31
  if len(symbol) != 6 or not symbol.isdigit():
32
  raise ValueError(
33
- f"Invalid stock code {raw_symbol!r}; expected 6 digits like '603777' "
34
- "or Tushare-style code like '600900.SH'."
35
  )
36
- return symbol
37
 
 
 
 
 
 
 
 
 
38
 
39
- def _retry_ak_call(fn, *, call_name: str):
40
- """
41
- Retry wrapper for AkShare calls to handle transient network disconnects.
42
- """
43
  last_exc: Exception | None = None
44
- for attempt in range(1, _AK_RETRY_COUNT + 1):
45
  try:
46
  return fn()
47
  except Exception as exc: # pragma: no cover - external IO
48
  last_exc = exc
49
- if attempt >= _AK_RETRY_COUNT:
50
  break
51
- time.sleep(_AK_RETRY_BASE_SLEEP * (2 ** (attempt - 1)))
52
  raise RuntimeError(
53
- f"AkShare call failed after {_AK_RETRY_COUNT} attempts ({call_name}): {last_exc}"
54
  ) from last_exc
55
 
56
 
57
- def _get_trade_calendar_cached() -> pd.DatetimeIndex:
58
- """
59
- Fetch and cache exchange trading dates in-process to avoid repeated
60
- network calls on each request.
61
- """
62
- global _TRADE_CALENDAR_CACHE, _TRADE_CALENDAR_CACHED_AT
63
-
64
- now = datetime.now()
65
- with _TRADE_CALENDAR_LOCK:
66
- if (
67
- _TRADE_CALENDAR_CACHE is not None
68
- and _TRADE_CALENDAR_CACHED_AT is not None
69
- and (now - _TRADE_CALENDAR_CACHED_AT) < _TRADE_CALENDAR_CACHE_TTL
70
- ):
71
- return _TRADE_CALENDAR_CACHE
72
-
73
- cal = _retry_ak_call(
74
- lambda: ak.tool_trade_date_hist_sina(),
75
- call_name="tool_trade_date_hist_sina",
76
- )
77
- cal_col = "trade_date" if "trade_date" in cal.columns else "日期"
78
- all_dates = pd.to_datetime(cal[cal_col]).sort_values().drop_duplicates()
79
- cached = pd.DatetimeIndex(all_dates)
80
-
81
- with _TRADE_CALENDAR_LOCK:
82
- _TRADE_CALENDAR_CACHE = cached
83
- _TRADE_CALENDAR_CACHED_AT = now
84
-
85
- return cached
86
-
87
-
88
  def fetch_stock_data(
89
  symbol: str, lookback: int
90
  ) -> tuple[pd.DataFrame, pd.Series, str]:
@@ -94,55 +86,36 @@ def fetch_stock_data(
94
  x_timestamp : pd.Series[datetime], aligned to x_df
95
  last_trade_date: str "YYYYMMDD", the most recent bar date
96
  """
97
- normalized_symbol = _normalize_symbol(symbol)
98
  end_date = datetime.today().strftime("%Y%m%d")
99
- # 4x buffer to account for weekends/long holidays.
100
- start_date = (datetime.today() - timedelta(days=lookback * 4)).strftime("%Y%m%d")
101
-
102
- df = _retry_ak_call(
103
- lambda: ak.stock_zh_a_hist(
104
- symbol=normalized_symbol,
105
- period="daily",
 
106
  start_date=start_date,
107
  end_date=end_date,
108
- adjust="qfq",
109
- timeout=_AK_TIMEOUT,
110
  ),
111
- call_name=f"stock_zh_a_hist(symbol={normalized_symbol})",
112
  )
113
 
114
  if df is None or df.empty:
115
- raise ValueError(f"No data returned for symbol={symbol!r}")
116
-
117
- df = df.rename(
118
- columns={
119
- "日期": "trade_date",
120
- "开盘": "open",
121
- "最高": "high",
122
- "最低": "low",
123
- "收盘": "close",
124
- "成交量": "volume",
125
- "成交额": "amount",
126
- }
127
- )
128
- required_cols = ["trade_date", "open", "high", "low", "close", "volume", "amount"]
129
- missing = [c for c in required_cols if c not in df.columns]
130
- if missing:
131
- raise ValueError(f"AkShare response missing columns: {missing}")
132
-
133
- df["trade_date"] = pd.to_datetime(df["trade_date"])
134
- for col in ["open", "high", "low", "close", "volume", "amount"]:
135
- df[col] = pd.to_numeric(df[col], errors="coerce")
136
- df = df.dropna(subset=["trade_date", "open", "high", "low", "close", "volume", "amount"])
137
  df = df.sort_values("trade_date").reset_index(drop=True)
138
- df["timestamps"] = df["trade_date"]
 
139
 
140
  # Keep the most recent `lookback` bars
141
  df = df.tail(lookback).reset_index(drop=True)
142
 
143
  x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
144
  x_timestamp = df["timestamps"].copy()
145
- last_trade_date = df["trade_date"].iloc[-1].strftime("%Y%m%d")
146
 
147
  return x_df, x_timestamp, last_trade_date
148
 
@@ -153,20 +126,27 @@ def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
153
  follow `last_trade_date` (format: YYYYMMDD).
154
  """
155
  last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
156
- dates: list[pd.Timestamp] = []
157
-
158
- # Prefer real exchange trade dates from AkShare.
159
- try:
160
- all_dates = _get_trade_calendar_cached()
161
- dates.extend([d for d in all_dates if d > pd.Timestamp(last_dt)][:pred_len])
162
- except Exception:
163
- # If calendar fetch fails, fall back to weekday-based dates.
164
- pass
165
-
166
- candidate = last_dt + timedelta(days=1)
167
- while len(dates) < pred_len:
168
- if candidate.weekday() < 5:
169
- dates.append(pd.Timestamp(candidate))
170
- candidate += timedelta(days=1)
171
-
172
- return pd.Series(pd.DatetimeIndex(dates[:pred_len]))
 
 
 
 
 
 
 
 
3
  import threading
4
  import time
5
 
 
6
  import pandas as pd
7
+ import tushare as ts
8
 
9
+ _TUSHARE_TOKEN = os.environ.get("TUSHARE_TOKEN", "").strip()
10
+ _TS_RETRY_COUNT = max(1, int(os.environ.get("TS_RETRY_COUNT", "3")))
11
+ _TS_RETRY_BASE_SLEEP = float(os.environ.get("TS_RETRY_BASE_SLEEP", "0.8"))
 
 
 
 
12
 
13
+ _PRO = None
14
+ _PRO_LOCK = threading.Lock()
15
 
16
+
17
+ def _get_pro():
18
+ global _PRO
19
+ with _PRO_LOCK:
20
+ if _PRO is not None:
21
+ return _PRO
22
+ if not _TUSHARE_TOKEN:
23
+ raise RuntimeError("TUSHARE_TOKEN is not set")
24
+ ts.set_token(_TUSHARE_TOKEN)
25
+ _PRO = ts.pro_api()
26
+ return _PRO
27
+
28
+
29
+ def _normalize_to_ts_code(raw_symbol: str) -> str:
30
  """
31
+ Normalize user `symbol` input to Tushare `ts_code`.
 
32
 
33
  Accepted examples:
34
+ - "603777" -> "603777.SH"
35
+ - "300065" -> "300065.SZ"
36
+ - "430047" -> "430047.BJ"
37
+ - "000063.SZ" -> "000063.SZ"
38
  """
39
  symbol = raw_symbol.strip().upper()
40
  if "." in symbol:
41
+ code, market = symbol.split(".", 1)
42
+ if len(code) != 6 or not code.isdigit() or market not in {"SH", "SZ", "BJ"}:
43
+ raise ValueError(
44
+ f"Invalid symbol {raw_symbol!r}; expected e.g. '603777' or '000063.SZ'."
45
+ )
46
+ return f"{code}.{market}"
47
+
48
  if len(symbol) != 6 or not symbol.isdigit():
49
  raise ValueError(
50
+ f"Invalid symbol {raw_symbol!r}; expected 6 digits or code with suffix."
 
51
  )
 
52
 
53
+ if symbol.startswith("6"):
54
+ market = "SH"
55
+ elif symbol.startswith(("0", "3")):
56
+ market = "SZ"
57
+ elif symbol.startswith(("4", "8")):
58
+ market = "BJ"
59
+ else:
60
+ raise ValueError(f"Cannot infer exchange suffix for symbol={raw_symbol!r}")
61
 
62
+ return f"{symbol}.{market}"
63
+
64
+
65
+ def _retry_call(fn, *, call_name: str):
66
  last_exc: Exception | None = None
67
+ for attempt in range(1, _TS_RETRY_COUNT + 1):
68
  try:
69
  return fn()
70
  except Exception as exc: # pragma: no cover - external IO
71
  last_exc = exc
72
+ if attempt >= _TS_RETRY_COUNT:
73
  break
74
+ time.sleep(_TS_RETRY_BASE_SLEEP * (2 ** (attempt - 1)))
75
  raise RuntimeError(
76
+ f"Tushare call failed after {_TS_RETRY_COUNT} attempts ({call_name}): {last_exc}"
77
  ) from last_exc
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def fetch_stock_data(
81
  symbol: str, lookback: int
82
  ) -> tuple[pd.DataFrame, pd.Series, str]:
 
86
  x_timestamp : pd.Series[datetime], aligned to x_df
87
  last_trade_date: str "YYYYMMDD", the most recent bar date
88
  """
89
+ ts_code = _normalize_to_ts_code(symbol)
90
  end_date = datetime.today().strftime("%Y%m%d")
91
+ # 2x buffer to account for weekends/holidays.
92
+ start_date = (datetime.today() - timedelta(days=lookback * 2)).strftime("%Y%m%d")
93
+
94
+ pro = _get_pro()
95
+ df = _retry_call(
96
+ lambda: ts.pro_bar(
97
+ ts_code=ts_code,
98
+ adj="qfq",
99
  start_date=start_date,
100
  end_date=end_date,
101
+ asset="E",
 
102
  ),
103
+ call_name=f"pro_bar(ts_code={ts_code})",
104
  )
105
 
106
  if df is None or df.empty:
107
+ raise ValueError(f"No data returned for symbol={symbol!r} (ts_code={ts_code})")
108
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  df = df.sort_values("trade_date").reset_index(drop=True)
110
+ df = df.rename(columns={"vol": "volume"})
111
+ df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
112
 
113
  # Keep the most recent `lookback` bars
114
  df = df.tail(lookback).reset_index(drop=True)
115
 
116
  x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
117
  x_timestamp = df["timestamps"].copy()
118
+ last_trade_date = str(df["trade_date"].iloc[-1])
119
 
120
  return x_df, x_timestamp, last_trade_date
121
 
 
126
  follow `last_trade_date` (format: YYYYMMDD).
127
  """
128
  last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
129
+ # 3x buffer so we always have enough dates even over a long holiday
130
+ end_dt = last_dt + timedelta(days=pred_len * 3)
131
+
132
+ pro = _get_pro()
133
+ cal = _retry_call(
134
+ lambda: pro.trade_cal(
135
+ exchange="SSE",
136
+ start_date=(last_dt + timedelta(days=1)).strftime("%Y%m%d"),
137
+ end_date=end_dt.strftime("%Y%m%d"),
138
+ is_open="1",
139
+ ),
140
+ call_name="trade_cal(exchange=SSE)",
141
+ )
142
+
143
+ cal = cal.sort_values("cal_date")
144
+ dates = pd.to_datetime(cal["cal_date"].values[:pred_len], format="%Y%m%d")
145
+
146
+ if len(dates) < pred_len:
147
+ raise ValueError(
148
+ f"Could only obtain {len(dates)} future trading dates; "
149
+ f"increase buffer or check Tushare calendar coverage."
150
+ )
151
+
152
+ return pd.Series(dates)
requirements.txt CHANGED
@@ -9,4 +9,4 @@ huggingface_hub==0.33.1
9
  matplotlib==3.9.3
10
  tqdm==4.67.1
11
  safetensors==0.6.2
12
- akshare
 
9
  matplotlib==3.9.3
10
  tqdm==4.67.1
11
  safetensors==0.6.2
12
+ tushare