kronos-api / data_fetcher.py
fengwm
更新 README 和代码,数据源切换为 Tushare,优化请求重试机制
23022b9
from datetime import datetime, timedelta
import os
import threading
import time
import pandas as pd
import tushare as ts
# Tushare API token, read once at import time; surrounding whitespace is dropped.
_TUSHARE_TOKEN = os.environ.get("TUSHARE_TOKEN", "").strip()

# Total number of attempts per Tushare call (the first try included); never below 1.
_TS_RETRY_COUNT = max(1, int(os.environ.get("TS_RETRY_COUNT", "3")))

# Delay in seconds before the first retry; the retry helper doubles it for each
# further retry. Clamped at zero because time.sleep() rejects negative
# durations, which would otherwise replace the real Tushare error with an
# unrelated ValueError raised from inside the retry handler.
_TS_RETRY_BASE_SLEEP = max(0.0, float(os.environ.get("TS_RETRY_BASE_SLEEP", "0.8")))

# Shared Tushare client, created on first use, and the lock guarding its creation.
_PRO = None
_PRO_LOCK = threading.Lock()
def _get_pro():
    """
    Return the module-wide Tushare Pro client, building it on first use.

    The check and the creation both happen while holding `_PRO_LOCK`, so
    concurrent callers end up sharing one client object. On first use the
    token is registered through `ts.set_token` before `ts.pro_api()` is called.

    Raises:
        RuntimeError: no client exists yet and `TUSHARE_TOKEN` is empty.
    """
    global _PRO
    with _PRO_LOCK:
        if _PRO is None:
            # Fail before touching Tushare when there is no token to register.
            if not _TUSHARE_TOKEN:
                raise RuntimeError("TUSHARE_TOKEN is not set")
            ts.set_token(_TUSHARE_TOKEN)
            _PRO = ts.pro_api()
        return _PRO
def _normalize_to_ts_code(raw_symbol: str) -> str:
    """
    Normalize user `symbol` input to Tushare `ts_code`.

    The input is stripped and upper-cased. A symbol that already carries a
    suffix is validated and returned as-is; a bare 6-digit code gets its
    exchange suffix inferred from the leading digit(s).

    Accepted examples:
    - "603777"    -> "603777.SH"
    - "300065"    -> "300065.SZ"
    - "430047"    -> "430047.BJ"
    - "920002"    -> "920002.BJ"
    - "000063.SZ" -> "000063.SZ"
    - "000063.sz" -> "000063.SZ"

    Raises:
        ValueError: the code part is not exactly six ASCII digits, the suffix
            is not one of SH/SZ/BJ, or no exchange can be inferred from the
            leading digit(s) of a bare code.
    """
    symbol = raw_symbol.strip().upper()

    def _is_six_digit_code(text: str) -> bool:
        # str.isdigit() alone also accepts non-ASCII digits (e.g. full-width
        # "６"), which must not be forwarded to Tushare as a code.
        return len(text) == 6 and text.isascii() and text.isdigit()

    if "." in symbol:
        code, market = symbol.split(".", 1)
        if not _is_six_digit_code(code) or market not in {"SH", "SZ", "BJ"}:
            raise ValueError(
                f"Invalid symbol {raw_symbol!r}; expected e.g. '603777' or '000063.SZ'."
            )
        # An explicit suffix is trusted; it is not cross-checked against the code.
        return f"{code}.{market}"

    if not _is_six_digit_code(symbol):
        raise ValueError(
            f"Invalid symbol {raw_symbol!r}; expected 6 digits or code with suffix."
        )

    if symbol.startswith("6"):
        market = "SH"
    elif symbol.startswith(("0", "3")):
        market = "SZ"
    elif symbol.startswith(("4", "8", "92")):
        # "92" covers the 920xxx Beijing Stock Exchange code segment.
        market = "BJ"
    else:
        raise ValueError(f"Cannot infer exchange suffix for symbol={raw_symbol!r}")
    return f"{symbol}.{market}"
def _retry_call(fn, *, call_name: str):
    """
    Call `fn()` up to `_TS_RETRY_COUNT` times and return its first successful result.

    Any `Exception` raised by `fn` counts as a failed attempt. Between failed
    attempts the function sleeps `_TS_RETRY_BASE_SLEEP` seconds, doubling the
    delay after every failure; there is no sleep after the final attempt.

    Args:
        fn: zero-argument callable performing the Tushare request.
        call_name: label for the call, used only in the error message.

    Raises:
        RuntimeError: every attempt failed; the last exception is chained as
            the cause.
    """
    failure: Exception | None = None
    for tries_so_far in range(_TS_RETRY_COUNT):
        try:
            return fn()
        except Exception as exc:  # pragma: no cover - external IO
            failure = exc
            # Back off only when another attempt is still to come.
            if tries_so_far + 1 < _TS_RETRY_COUNT:
                time.sleep(_TS_RETRY_BASE_SLEEP * (2 ** tries_so_far))
    raise RuntimeError(
        f"Tushare call failed after {_TS_RETRY_COUNT} attempts ({call_name}): {failure}"
    ) from failure
def fetch_stock_data(
    symbol: str, lookback: int
) -> tuple[pd.DataFrame, pd.Series, str]:
    """
    Fetch the most recent `lookback` daily bars for `symbol` from Tushare.

    Bars are requested through `ts.pro_bar` with `adj="qfq"` and `asset="E"`,
    sorted ascending by `trade_date`, and trimmed to the last `lookback` rows.
    Fewer than `lookback` rows are returned when Tushare has fewer bars inside
    the requested date window.

    Args:
        symbol: 6-digit code, optionally with an exchange suffix
            (e.g. "603777" or "000063.SZ").
        lookback: number of most recent bars wanted; must be at least 1.

    Returns:
        x_df           : DataFrame with columns [open, high, low, close, volume, amount]
        x_timestamp    : pd.Series[datetime], aligned to x_df
        last_trade_date: str "YYYYMMDD", the most recent bar date

    Raises:
        ValueError: `lookback` is below 1, `symbol` cannot be normalized, or
            Tushare returned no rows.
        RuntimeError: the Tushare token is missing or the request kept failing.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be a positive integer, got {lookback!r}")

    ts_code = _normalize_to_ts_code(symbol)

    # Read the clock once so both ends of the window come from the same day.
    today = datetime.today()
    end_date = today.strftime("%Y%m%d")
    # 2x calendar days absorbs weekends; the fixed slack keeps a small
    # `lookback` from yielding an empty window across a long market closure.
    start_date = (today - timedelta(days=lookback * 2 + 15)).strftime("%Y%m%d")

    pro = _get_pro()
    df = _retry_call(
        lambda: ts.pro_bar(
            ts_code=ts_code,
            api=pro,  # reuse the shared client instead of building a new one per call
            adj="qfq",
            start_date=start_date,
            end_date=end_date,
            asset="E",
        ),
        call_name=f"pro_bar(ts_code={ts_code})",
    )
    if df is None or df.empty:
        raise ValueError(f"No data returned for symbol={symbol!r} (ts_code={ts_code})")

    # "YYYYMMDD" strings sort chronologically, so a plain sort gives oldest-first.
    df = df.sort_values("trade_date").reset_index(drop=True)
    df = df.rename(columns={"vol": "volume"})
    df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")

    # Keep the most recent `lookback` bars
    df = df.tail(lookback).reset_index(drop=True)

    x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
    x_timestamp = df["timestamps"].copy()
    last_trade_date = str(df["trade_date"].iloc[-1])
    return x_df, x_timestamp, last_trade_date
def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
    """
    Return a pd.Series of `pred_len` future SSE trading dates (datetime) that
    follow `last_trade_date` (format: YYYYMMDD).

    The dates come from Tushare's `trade_cal` for exchange "SSE", restricted
    to open days, starting the day after `last_trade_date`.

    Args:
        last_trade_date: last known trading day, "YYYYMMDD".
        pred_len: number of future trading dates wanted; 0 yields an empty
            Series without contacting Tushare.

    Raises:
        ValueError: `last_trade_date` is not "YYYYMMDD", `pred_len` is
            negative, or the calendar returned fewer than `pred_len` open days.
        RuntimeError: the Tushare token is missing or the request kept failing.
    """
    last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
    if pred_len < 0:
        raise ValueError(f"pred_len must be non-negative, got {pred_len!r}")
    if pred_len == 0:
        # Nothing to look up; mirror the datetime Series returned otherwise.
        return pd.Series(pd.to_datetime([], format="%Y%m%d"))

    start_dt = last_dt + timedelta(days=1)
    # 3x calendar days absorbs weekends; the fixed slack keeps a small
    # `pred_len` from coming up short when a long holiday follows immediately.
    end_dt = last_dt + timedelta(days=pred_len * 3 + 15)

    pro = _get_pro()
    cal = _retry_call(
        lambda: pro.trade_cal(
            exchange="SSE",
            start_date=start_dt.strftime("%Y%m%d"),
            end_date=end_dt.strftime("%Y%m%d"),
            is_open="1",
        ),
        call_name="trade_cal(exchange=SSE)",
    )
    if cal is None or cal.empty:
        # Report an empty response the same way as a short one instead of
        # failing later on a missing attribute or column.
        raise ValueError(
            "Could only obtain 0 future trading dates; "
            "increase buffer or check Tushare calendar coverage."
        )

    cal = cal.sort_values("cal_date")
    dates = pd.to_datetime(cal["cal_date"].values[:pred_len], format="%Y%m%d")
    if len(dates) < pred_len:
        raise ValueError(
            f"Could only obtain {len(dates)} future trading dates; "
            "increase buffer or check Tushare calendar coverage."
        )
    return pd.Series(dates)