fengwm committed
Commit 2a8a0f5 · 1 Parent(s): 292fb60

Update README: switch data source to AkShare, unify the request field to `symbol`, improve the cache mechanism, add performance timing logs

Files changed (5)
  1. README.md +42 -15
  2. app.py +50 -15
  3. data_fetcher.py +103 -38
  4. predictor.py +63 -14
  5. requirements.txt +1 -1
README.md CHANGED

@@ -11,12 +11,22 @@ pinned: false
 
 基于[清华大学 Kronos 金融 K 线基础大模型](https://arxiv.org/abs/2508.02739)的 A 股概率预测 REST API。
 
-- **数据源**:Tushare Pro,前复权(qfq)
-- **推理方式**:蒙特卡洛多次采样,输出预测方向、置信度及 95% 交易区间
+- **数据源**:AkShare 东方财富 A 股日线,前复权(qfq)
+- **推理方式**:蒙特卡洛分批采样,输出预测方向、置信度及 95% 交易区间
 - **异步任务**:POST 提交 → 返回 `task_id` → GET 轮询结果
 
 ---
 
+## 更新说明(2026-03-16)
+
+- 请求/响应主字段由 `ts_code` 统一为 `symbol`
+- 数据源切换为 AkShare 东财日线接口(`stock_zh_a_hist`,前复权 `qfq`)
+- 方向概率 `direction.probability` 定义为“预测区间内看涨概率(0–1)”
+- 推理路径改为分批采样(可通过 `MC_BATCH_SIZE` 调整批大小)
+- 新增阶段耗时日志(`fetch/calendar/infer/build/cache/total`)
+
+---
+
 ## 模型信息
 
 | 项目 | 值 |
@@ -38,7 +48,7 @@ pinned: false
 curl -X POST "https://yingfeng64-kronos-api.hf.space/api/v1/predict" \
   -H "Content-Type: application/json" \
   -d '{
-    "ts_code": "000063.SZ",
+    "symbol": "000063.SZ",
     "lookback": 512,
     "pred_len": 5,
     "sample_count": 30,
@@ -66,7 +76,7 @@ import time, requests
 BASE = "https://yingfeng64-kronos-api.hf.space"
 
 resp = requests.post(f"{BASE}/api/v1/predict", json={
-    "ts_code": "000063.SZ",
+    "symbol": "000063.SZ",
     "lookback": 512,
     "pred_len": 5,
     "sample_count": 30,
@@ -92,7 +102,7 @@ print(r["result"])
 
 | 字段 | 类型 | 默认值 | 范围 | 说明 |
 |---|---|---|---|---|
-| `ts_code` | string | — | — | Tushare代码, `"000063.SZ"` |
+| `symbol` | string | — | — | A 股代码,支持 `"603777"` 或 `"000063.SZ"` |
 | `lookback` | int | 512 | 20–512 | 回看历史 K 线根数 |
 | `pred_len` | int | 5 | 1–60 | 预测未来交易日数(建议 ≤ 30) |
 | `sample_count` | int | 30 | 1–100 | MC 蒙特卡洛采样次数 |
@@ -103,7 +113,7 @@ print(r["result"])
 
 ```json
 {
-  "ts_code": "000063.SZ",
+  "symbol": "000063.SZ",
   "base_date": "2026-03-13",
   "pred_len": 5,
   "confidence": 95,
@@ -176,9 +186,9 @@ curl -X POST "https://yingfeng64-kronos-api.hf.space/api/v1/predict/batch" \
   -H "Content-Type: application/json" \
   -d '{
     "requests": [
-      {"ts_code": "000063.SZ", "pred_len": 5, "sample_count": 30},
-      {"ts_code": "600900.SH", "pred_len": 5, "sample_count": 30},
-      {"ts_code": "000001.SZ", "pred_len": 5, "sample_count": 30}
+      {"symbol": "000063", "pred_len": 5, "sample_count": 30},
+      {"symbol": "600900", "pred_len": 5, "sample_count": 30},
+      {"symbol": "000001", "pred_len": 5, "sample_count": 30}
     ]
   }'
 ```
@@ -227,11 +237,11 @@ curl -X POST "https://yingfeng64-kronos-api.hf.space/api/v1/predict/batch" \
 
 | 参数 | 说明 |
 |---|---|
-| `ts_code`(可选) | 只返回该股票的缓存,不传则返回全部 |
+| `symbol`(可选) | 只返回该股票的缓存,不传则返回全部 |
 
 ```bash
 # 查某只股票
-curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache?ts_code=000063.SZ"
+curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache?symbol=000063"
 
 # 查全部
 curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
@@ -242,7 +252,7 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
   "count": 1,
   "entries": [
     {
-      "ts_code": "000063.SZ",
+      "symbol": "000063",
       "lookback": 512,
       "pred_len": 5,
       "sample_count": 30,
@@ -271,8 +281,8 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
 | 字段 | 含义 |
 |---|---|
 | `base_date` | 预测所基于的最后一个历史 K 线日期 |
-| `direction.signal` | `"bullish"` / `"bearish"`,MC 样本中末日收盘价 > 当前收盘的比例决定 |
-| `direction.probability` | 看涨概率(0–1) |
+| `direction.signal` | `"bullish"` / `"bearish"`,由 `direction.probability >= 0.5` 决定 |
+| `direction.probability` | 预测区间内看涨概率(0–1) |
 | `trading_low` | 该日预测最低价的 q2.5 分位数(95% 交易区间下沿) |
 | `trading_high` | 该日预测最高价的 q97.5 分位数(95% 交易区间上沿) |
 | `uncertainty` | `(trading_high − trading_low) / last_close`,无量纲不确定性 |
@@ -284,7 +294,7 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
 
 ## 缓存机制
 
-缓存 key 由 `(ts_code, lookback, pred_len, sample_count, mode, include_volume)` 六元组构成,失效时机为下一个 A 股交易日收盘(15:00 CST)。
+缓存 key 由 `(symbol, lookback, pred_len, sample_count, mode, include_volume)` 六元组构成,失效时机为下一个 A 股交易日收盘(15:00 CST)。
 
 | 请求时间(CST) | 缓存过期时间 |
 |---|---|
@@ -297,6 +307,23 @@ curl "https://yingfeng64-kronos-api.hf.space/api/v1/cache"
 
 ---
 
+## 运行配置
+
+| 环境变量 | 默认值 | 说明 |
+|---|---|---|
+| `KRONOS_DIR` | `/app/Kronos` | Kronos 源码目录 |
+| `MC_BATCH_SIZE` | `8` | 蒙特卡洛分批采样大小(越大通常越快,但占用显存/内存更高) |
+
+---
+
+## 可观测性
+
+服务会在 `INFO` 日志输出预测阶段耗时,示例:
+
+```text
+Task <task_id> timing symbol=300065.SZ fetch=...ms calendar=...ms infer=...ms build=...ms cache=...ms total=...ms
+```
+
 ## 性能参考
 
 | 环境 | 单次采样 | 30 次 MC 总耗时 |
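The README hunks above redefine `direction.probability` as the horizon-level bullish probability and `uncertainty` as the 95% band width scaled by the last close. A minimal NumPy sketch of those two formulas, using synthetic sample paths (the seed, array shapes, and price level here are illustrative assumptions, not values from the service):

```python
import numpy as np

# Simulated Monte-Carlo close paths: (sample_count=30, pred_len=5).
rng = np.random.default_rng(0)
last_close = 10.0
close_paths = last_close * (1 + rng.normal(0.001, 0.02, size=(30, 5)).cumsum(axis=1))

# Horizon-level bullish probability: fraction of ALL sampled future points
# above the last historical close (the definition stated in the README).
direction_prob = float((close_paths > last_close).mean())
signal = "bullish" if direction_prob >= 0.5 else "bearish"

# 95% trading band per day and the dimensionless uncertainty measure.
trading_low = np.percentile(close_paths, 2.5, axis=0)
trading_high = np.percentile(close_paths, 97.5, axis=0)
uncertainty = (trading_high - trading_low) / last_close
```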
app.py CHANGED

@@ -15,6 +15,7 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
 from datetime import datetime, time, timedelta, timezone
+from time import perf_counter
 from typing import Literal, List
 
 import pandas as pd
@@ -38,8 +39,8 @@ def _next_cache_expiry() -> datetime:
     Return the UTC datetime of the NEXT A-share market close (15:00 CST on a
     weekday), which is when new candle data becomes available and the cache
     should be invalidated.
-    Chinese public holidays are intentionally ignored: on those days Tushare
-    returns the same last bar, so a cache hit is harmless.
+    Chinese public holidays are intentionally ignored: on those days market
+    data does not advance, so a cache hit is harmless.
     """
     now_cst = datetime.now(_CST)
     today_close = now_cst.replace(hour=15, minute=0, second=0, microsecond=0)
@@ -58,13 +59,13 @@ def _next_cache_expiry() -> datetime:
 
 
 # ── Result cache ──────────────────────────────────────────────────────────────
-# key   : (ts_code, lookback, pred_len, sample_count, mode, include_volume)
+# key   : (symbol, lookback, pred_len, sample_count, mode, include_volume)
 # value : {"result": dict, "expires_at": datetime(UTC), "cached_at": datetime(UTC)}
 _cache: dict[tuple, dict] = {}
 
 
 def _cache_key(req: "PredictRequest") -> tuple:
-    return (req.ts_code, req.lookback, req.pred_len,
+    return (req.symbol, req.lookback, req.pred_len,
             req.sample_count, req.mode, req.include_volume)
@@ -84,7 +85,7 @@ def _set_cache(req: "PredictRequest", result: dict) -> None:
     }
     logger.info(
         "Cached %s, expires at %s CST",
-        req.ts_code,
+        req.symbol,
         _cache[_cache_key(req)]["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
     )
@@ -124,7 +125,11 @@ app.add_middleware(
 
 # ── Request / Response schemas ────────────────────────────────────────────────
 class PredictRequest(BaseModel):
-    ts_code: str = Field(..., examples=["600900.SH"], description="Tushare 股票代码")
+    symbol: str = Field(
+        ...,
+        examples=["603777", "600900.SH"],
+        description="A 股代码;支持 6 位代码或带市场后缀(如 600900.SH)",
+    )
     lookback: int = Field(
         default=512,
         ge=20,
@@ -180,7 +185,7 @@ def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
         bands.append(band)
 
     result: dict = {
-        "ts_code": req.ts_code,
+        "symbol": req.symbol,
         "base_date": base_date,
         "pred_len": req.pred_len,
         "confidence": 95,
@@ -214,12 +219,18 @@ def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
 
 # ── Background task ───────────────────────────────────────────────────────────
 def _run_prediction(task_id: str, req: PredictRequest) -> None:
+    t_total_start = perf_counter()
     try:
         # ── Cache check ───────────────────────────────────────────────────────
         cache_entry = _get_cached(req)
         if cache_entry is not None:
-            logger.info("Cache hit for %s (expires %s CST)", req.ts_code,
-                        cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"))
+            total_ms = (perf_counter() - t_total_start) * 1000
+            logger.info(
+                "Cache hit for %s (expires %s CST, total=%.1fms)",
+                req.symbol,
+                cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
+                total_ms,
+            )
             _tasks[task_id] = {
                 "status": "done",
                 "result": {**cache_entry["result"], "cached": True,
@@ -229,26 +240,37 @@ def _run_prediction(task_id: str, req: PredictRequest) -> None:
             return
 
         # ── Full inference ────────────────────────────────────────────────────
+        t_fetch_start = perf_counter()
         x_df, x_timestamp, last_trade_date = data_fetcher.fetch_stock_data(
-            req.ts_code, req.lookback
+            req.symbol, req.lookback
         )
+        fetch_ms = (perf_counter() - t_fetch_start) * 1000
+
+        t_calendar_start = perf_counter()
         y_timestamp = data_fetcher.get_future_trading_dates(last_trade_date, req.pred_len)
+        calendar_ms = (perf_counter() - t_calendar_start) * 1000
 
+        t_infer_start = perf_counter()
         pred_mean, ci, trading_low, trading_high, direction_prob, last_close = (
             pred_module.run_mc_prediction(
                 x_df, x_timestamp, y_timestamp, req.pred_len, req.sample_count
            )
        )
+        infer_ms = (perf_counter() - t_infer_start) * 1000
 
+        t_build_start = perf_counter()
         base_date = str(pd.to_datetime(last_trade_date, format="%Y%m%d").date())
         result = _build_response(
             req, base_date, pred_mean, ci,
             trading_low, trading_high, direction_prob, last_close, y_timestamp,
         )
+        build_ms = (perf_counter() - t_build_start) * 1000
 
         # ── Store in cache ────────────────────────────────────────────────────
+        t_cache_start = perf_counter()
         _set_cache(req, result)
         cache_entry = _cache[_cache_key(req)]
+        cache_ms = (perf_counter() - t_cache_start) * 1000
 
         _tasks[task_id] = {
             "status": "done",
@@ -256,8 +278,21 @@ def _run_prediction(task_id: str, req: PredictRequest) -> None:
                        "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")},
             "error": None,
         }
+        total_ms = (perf_counter() - t_total_start) * 1000
+        logger.info(
+            "Task %s timing symbol=%s fetch=%.1fms calendar=%.1fms infer=%.1fms build=%.1fms cache=%.1fms total=%.1fms",
+            task_id,
+            req.symbol,
+            fetch_ms,
+            calendar_ms,
+            infer_ms,
+            build_ms,
+            cache_ms,
+            total_ms,
+        )
     except Exception as exc:
-        logger.exception("Task %s failed", task_id)
+        total_ms = (perf_counter() - t_total_start) * 1000
+        logger.exception("Task %s failed after %.1fms", task_id, total_ms)
         _tasks[task_id] = {"status": "failed", "result": None, "error": str(exc)}
@@ -297,22 +332,22 @@ async def get_predict_result(task_id: str):
 
 
 @app.get("/api/v1/cache", summary="查看缓存状态")
-async def get_cache(ts_code: str | None = None):
+async def get_cache(symbol: str | None = None):
     """
     列出有效的缓存条目及其过期时间。
 
     - 不传参数:返回全部
-    - `?ts_code=000063.SZ`:只返回该股票的所有参数组合
+    - `?symbol=000063.SZ`:只返回该股票的所有参数组合
     """
     now_utc = datetime.now(timezone.utc)
     entries = []
     for key, entry in _cache.items():
-        if ts_code and key[0] != ts_code:
+        if symbol and key[0] != symbol:
             continue
         remaining = (entry["expires_at"] - now_utc).total_seconds()
         if remaining > 0:
             entries.append({
-                "ts_code": key[0],
+                "symbol": key[0],
                 "lookback": key[1],
                 "pred_len": key[2],
                 "sample_count": key[3],
data_fetcher.py CHANGED

@@ -1,19 +1,66 @@
-import os
 from datetime import datetime, timedelta
+import threading
 
+import akshare as ak
 import pandas as pd
-import tushare as ts
 
-TUSHARE_TOKEN = os.environ.get(
-    "TUSHARE_TOKEN",
-)
+_TRADE_CALENDAR_CACHE: pd.DatetimeIndex | None = None
+_TRADE_CALENDAR_CACHED_AT: datetime | None = None
+_TRADE_CALENDAR_CACHE_TTL = timedelta(hours=12)
+_TRADE_CALENDAR_LOCK = threading.Lock()
 
-ts.set_token(TUSHARE_TOKEN)
-_pro = ts.pro_api()
+
+def _normalize_symbol(raw_symbol: str) -> str:
+    """
+    Convert user input into the 6-digit stock code expected by
+    `ak.stock_zh_a_hist`.
+
+    Accepted examples:
+    - "603777"
+    - "600900.SH"
+    - "000063.SZ"
+    """
+    symbol = raw_symbol.strip().upper()
+    if "." in symbol:
+        symbol = symbol.split(".", 1)[0]
+    if len(symbol) != 6 or not symbol.isdigit():
+        raise ValueError(
+            f"Invalid stock code {raw_symbol!r}; expected 6 digits like '603777' "
+            "or Tushare-style code like '600900.SH'."
+        )
+    return symbol
+
+
+def _get_trade_calendar_cached() -> pd.DatetimeIndex:
+    """
+    Fetch and cache exchange trading dates in-process to avoid repeated
+    network calls on each request.
+    """
+    global _TRADE_CALENDAR_CACHE, _TRADE_CALENDAR_CACHED_AT
+
+    now = datetime.now()
+    with _TRADE_CALENDAR_LOCK:
+        if (
+            _TRADE_CALENDAR_CACHE is not None
+            and _TRADE_CALENDAR_CACHED_AT is not None
+            and (now - _TRADE_CALENDAR_CACHED_AT) < _TRADE_CALENDAR_CACHE_TTL
+        ):
+            return _TRADE_CALENDAR_CACHE
+
+    cal = ak.tool_trade_date_hist_sina()
+    cal_col = "trade_date" if "trade_date" in cal.columns else "日期"
+    all_dates = pd.to_datetime(cal[cal_col]).sort_values().drop_duplicates()
+    cached = pd.DatetimeIndex(all_dates)
+
+    with _TRADE_CALENDAR_LOCK:
+        _TRADE_CALENDAR_CACHE = cached
+        _TRADE_CALENDAR_CACHED_AT = now
+
+    return cached
 
 
 def fetch_stock_data(
-    ts_code: str, lookback: int
+    symbol: str, lookback: int
 ) -> tuple[pd.DataFrame, pd.Series, str]:
     """
     Returns:
@@ -21,31 +68,51 @@ def fetch_stock_data(
       x_timestamp : pd.Series[datetime], aligned to x_df
       last_trade_date: str "YYYYMMDD", the most recent bar date
     """
+    normalized_symbol = _normalize_symbol(symbol)
     end_date = datetime.today().strftime("%Y%m%d")
-    # buffer to account for weekends/holidays
-    start_date = (datetime.today() - timedelta(days=lookback * 2)).strftime("%Y%m%d")
+    # 4x buffer to account for weekends/long holidays.
+    start_date = (datetime.today() - timedelta(days=lookback * 4)).strftime("%Y%m%d")
 
-    df = ts.pro_bar(
-        ts_code=ts_code,
-        adj="qfq",
+    df = ak.stock_zh_a_hist(
+        symbol=normalized_symbol,
+        period="daily",
         start_date=start_date,
         end_date=end_date,
-        asset="E",
+        adjust="qfq",
     )
 
     if df is None or df.empty:
-        raise ValueError(f"No data returned for ts_code={ts_code!r}")
-
+        raise ValueError(f"No data returned for symbol={symbol!r}")
+
+    df = df.rename(
+        columns={
+            "日期": "trade_date",
+            "开盘": "open",
+            "最高": "high",
+            "最低": "low",
+            "收盘": "close",
+            "成交量": "volume",
+            "成交额": "amount",
+        }
+    )
+    required_cols = ["trade_date", "open", "high", "low", "close", "volume", "amount"]
+    missing = [c for c in required_cols if c not in df.columns]
+    if missing:
+        raise ValueError(f"AkShare response missing columns: {missing}")
+
+    df["trade_date"] = pd.to_datetime(df["trade_date"])
+    for col in ["open", "high", "low", "close", "volume", "amount"]:
+        df[col] = pd.to_numeric(df[col], errors="coerce")
+    df = df.dropna(subset=["trade_date", "open", "high", "low", "close", "volume", "amount"])
     df = df.sort_values("trade_date").reset_index(drop=True)
-    df = df.rename(columns={"vol": "volume"})
-    df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
+    df["timestamps"] = df["trade_date"]
 
     # Keep the most recent `lookback` bars
    df = df.tail(lookback).reset_index(drop=True)
 
    x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
    x_timestamp = df["timestamps"].copy()
-    last_trade_date = df["trade_date"].iloc[-1]
+    last_trade_date = df["trade_date"].iloc[-1].strftime("%Y%m%d")
 
    return x_df, x_timestamp, last_trade_date
@@ -56,22 +123,20 @@ def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
     follow `last_trade_date` (format: YYYYMMDD).
     """
     last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
-    # 3× buffer so we always have enough dates even over a long holiday
-    end_dt = last_dt + timedelta(days=pred_len * 3)
-
-    cal = _pro.trade_cal(
-        exchange="SSE",
-        start_date=(last_dt + timedelta(days=1)).strftime("%Y%m%d"),
-        end_date=end_dt.strftime("%Y%m%d"),
-        is_open="1",
-    )
-    cal = cal.sort_values("cal_date")
-    dates = pd.to_datetime(cal["cal_date"].values[:pred_len], format="%Y%m%d")
-
-    if len(dates) < pred_len:
-        raise ValueError(
-            f"Could only obtain {len(dates)} future trading dates; "
-            f"increase buffer or check Tushare calendar coverage."
-        )
-
-    return pd.Series(dates)
+    dates: list[pd.Timestamp] = []
+
+    # Prefer real exchange trade dates from AkShare.
+    try:
+        all_dates = _get_trade_calendar_cached()
+        dates.extend([d for d in all_dates if d > pd.Timestamp(last_dt)][:pred_len])
+    except Exception:
+        # If calendar fetch fails, fall back to weekday-based dates.
+        pass
+
+    candidate = last_dt + timedelta(days=1)
+    while len(dates) < pred_len:
+        if candidate.weekday() < 5:
+            dates.append(pd.Timestamp(candidate))
+        candidate += timedelta(days=1)
+
+    return pd.Series(pd.DatetimeIndex(dates[:pred_len]))
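The symbol handling introduced above accepts either a bare 6-digit code or a Tushare-style suffixed code. A self-contained sketch of that normalization rule (mirroring, not importing, the diff's `_normalize_symbol`):

```python
def normalize_symbol(raw: str) -> str:
    """Strip a market suffix like '.SZ'/'.SH' and validate that a 6-digit
    code remains, following the rule shown in the data_fetcher.py diff."""
    symbol = raw.strip().upper()
    if "." in symbol:
        symbol = symbol.split(".", 1)[0]   # drop the exchange suffix
    if len(symbol) != 6 or not symbol.isdigit():
        raise ValueError(f"Invalid stock code {raw!r}")
    return symbol
```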
 
 
predictor.py CHANGED

@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
 KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos")
 MODEL_ID = "NeoQuasar/Kronos-base"
 TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
+MC_BATCH_SIZE = max(1, int(os.environ.get("MC_BATCH_SIZE", "8")))
 
 
 # ── Bootstrap Kronos source ──────────────────────────────────────────────────
@@ -66,6 +67,37 @@ def get_predictor() -> KronosPredictor:
     return _predictor
 
 
+def _split_batched_output(
+    pred_output,
+    expected_count: int,
+    pred_len: int,
+) -> list[pd.DataFrame]:
+    """
+    Normalize predictor output into `expected_count` DataFrame samples.
+    Supports single-sample DataFrame and common batched return shapes.
+    """
+    if isinstance(pred_output, pd.DataFrame):
+        if expected_count == 1:
+            return [pred_output]
+        if isinstance(pred_output.index, pd.MultiIndex):
+            grouped = [g.droplevel(0) for _, g in pred_output.groupby(level=0, sort=False)]
+            if len(grouped) == expected_count:
+                return grouped
+        if len(pred_output) == expected_count * pred_len:
+            return [
+                pred_output.iloc[i * pred_len:(i + 1) * pred_len].copy()
+                for i in range(expected_count)
+            ]
+    if isinstance(pred_output, (list, tuple)):
+        if len(pred_output) == expected_count and all(
+            isinstance(item, pd.DataFrame) for item in pred_output
+        ):
+            return list(pred_output)
+        if expected_count == 1 and len(pred_output) == 1 and isinstance(pred_output[0], pd.DataFrame):
+            return [pred_output[0]]
+    raise ValueError("Unsupported predict() output format for batched sampling")
+
+
 # ── Monte-Carlo prediction ────────────────────────────────────────────────────
 def run_mc_prediction(
     x_df: pd.DataFrame,
@@ -83,45 +115,62 @@ def run_mc_prediction(
       ci : dict[field]["low"/"high"] → ndarray(pred_len,), 95% CI
       trading_low : ndarray(pred_len,), q2.5 of predicted_low
       trading_high : ndarray(pred_len,), q97.5 of predicted_high
-      direction_prob : float ∈ [0,1], fraction of samples where final close > last close
+      direction_prob : float ∈ [0,1], horizon-level bullish probability
       last_close : float, closing price of the last historical bar
     """
     predictor = get_predictor()
-    samples = []
+    samples: list[pd.DataFrame] = []
+    supports_batched_sampling = True
+    remaining = sample_count
 
-    for _ in range(sample_count):
+    while remaining > 0:
+        batch_n = min(remaining, MC_BATCH_SIZE if supports_batched_sampling else 1)
         with _infer_lock:
-            s = predictor.predict(
+            pred_output = predictor.predict(
                 df=x_df,
                 x_timestamp=x_timestamp,
                 y_timestamp=y_timestamp,
                 pred_len=pred_len,
                 T=0.8,
                 top_p=0.9,
-                sample_count=1,
+                sample_count=batch_n,
                 verbose=False,
             )
-        samples.append(s)
+        try:
+            batch_samples = _split_batched_output(pred_output, batch_n, pred_len)
+        except ValueError:
+            if batch_n > 1:
+                # Fallback for predictor implementations that do not support
+                # returning per-sample outputs for sample_count>1.
+                supports_batched_sampling = False
+                continue
+            raise
+        samples.extend(batch_samples)
+        remaining -= batch_n
 
     pred_mean = pd.concat(samples).groupby(level=0).mean()
-
-    def stack(field: str) -> np.ndarray:
-        return np.stack([s[field].values for s in samples])  # (sample_count, pred_len)
+    stacked = {
+        field: np.stack([s[field].values for s in samples])  # (sample_count, pred_len)
+        for field in ["open", "high", "low", "close", "volume"]
+    }
 
     alpha = 2.5  # → 95 % CI
     ci = {
         field: {
-            "low": np.percentile(stack(field), alpha, axis=0),
-            "high": np.percentile(stack(field), 100 - alpha, axis=0),
+            "low": np.percentile(stacked[field], alpha, axis=0),
+            "high": np.percentile(stacked[field], 100 - alpha, axis=0),
         }
-        for field in ["open", "high", "low", "close", "volume"]
+        for field in stacked
     }
 
     trading_low = ci["low"]["low"]     # q2.5 of the predicted daily low
     trading_high = ci["high"]["high"]  # q97.5 of the predicted daily high
 
     last_close = float(x_df["close"].iloc[-1])
-    bull_count = sum(float(s["close"].iloc[-1]) > last_close for s in samples)
-    direction_prob = bull_count / sample_count
+    close_paths = stacked["close"]  # (sample_count, pred_len)
+    # Use all future points to estimate horizon bullish probability.
+    bull_count = int((close_paths > last_close).sum())
+    total_points = int(close_paths.size)
+    direction_prob = bull_count / total_points
 
     return pred_mean, ci, trading_low, trading_high, direction_prob, last_close
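The batched sampling loop above splits `sample_count` Monte-Carlo samples across `predict()` calls of at most `MC_BATCH_SIZE` each. The batch-size arithmetic in isolation (a hypothetical helper for illustration; the real loop also handles the single-sample fallback when batching is unsupported):

```python
def batch_sizes(sample_count: int, batch_size: int) -> list[int]:
    """Sizes of the successive predict() calls used to collect
    `sample_count` samples with batches of at most `batch_size`."""
    sizes: list[int] = []
    remaining = sample_count
    while remaining > 0:
        n = min(remaining, batch_size)  # last batch may be smaller
        sizes.append(n)
        remaining -= n
    return sizes
```

With the defaults from the diff (30 samples, `MC_BATCH_SIZE=8`) this yields three full batches and one remainder batch of 6.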
requirements.txt CHANGED

@@ -9,4 +9,4 @@ huggingface_hub==0.33.1
 matplotlib==3.9.3
 tqdm==4.67.1
 safetensors==0.6.2
-tushare
+akshare