from datetime import datetime, timedelta
import os
import threading
import time

import pandas as pd
import tushare as ts

# Tushare API token, read once from the environment at import time.
_TUSHARE_TOKEN = os.environ.get("TUSHARE_TOKEN", "").strip()
# Total number of attempts per Tushare call (clamped to at least 1).
_TS_RETRY_COUNT = max(1, int(os.environ.get("TS_RETRY_COUNT", "3")))
# Base delay in seconds for the exponential backoff between attempts.
_TS_RETRY_BASE_SLEEP = float(os.environ.get("TS_RETRY_BASE_SLEEP", "0.8"))
# Fixed number of extra calendar days added to every date-range query. A pure
# multiplicative buffer (N * 2 or N * 3 days) is too small for short windows
# that straddle a multi-day exchange closure.
_HOLIDAY_PAD_DAYS = 20

# Lazily created Tushare pro client, shared by all callers in this process.
_PRO = None
_PRO_LOCK = threading.Lock()


def _get_pro():
    """Return the cached Tushare pro client, creating it on first use.

    The lock ensures that concurrent first calls create only one client.

    Raises:
        RuntimeError: if the ``TUSHARE_TOKEN`` environment variable is empty.
    """
    global _PRO
    with _PRO_LOCK:
        if _PRO is not None:
            return _PRO
        if not _TUSHARE_TOKEN:
            raise RuntimeError("TUSHARE_TOKEN is not set")
        ts.set_token(_TUSHARE_TOKEN)
        _PRO = ts.pro_api()
        return _PRO


def _is_six_digit_code(text: str) -> bool:
    """Return True if *text* is exactly six ASCII digits."""
    # str.isdigit() alone also accepts non-ASCII digits (e.g. full-width
    # "６"), which would yield a ts_code Tushare cannot resolve.
    return len(text) == 6 and text.isascii() and text.isdigit()


def _normalize_to_ts_code(raw_symbol: str) -> str:
    """
    Normalize user `symbol` input to Tushare `ts_code`.

    Accepted examples:
      - "603777"    -> "603777.SH"
      - "300065"    -> "300065.SZ"
      - "430047"    -> "430047.BJ"
      - "920000"    -> "920000.BJ"
      - "000063.SZ" -> "000063.SZ"

    Input is whitespace-stripped and upper-cased, so "000063.sz" is accepted.

    Raises:
        ValueError: if the symbol is not six ASCII digits (optionally with an
            SH/SZ/BJ suffix), or if no exchange can be inferred from the prefix.
    """
    symbol = raw_symbol.strip().upper()

    if "." in symbol:
        code, market = symbol.split(".", 1)
        if not _is_six_digit_code(code) or market not in {"SH", "SZ", "BJ"}:
            raise ValueError(
                f"Invalid symbol {raw_symbol!r}; expected e.g. '603777' or '000063.SZ'."
            )
        return f"{code}.{market}"

    if not _is_six_digit_code(symbol):
        raise ValueError(
            f"Invalid symbol {raw_symbol!r}; expected 6 digits or code with suffix."
        )

    if symbol.startswith("6"):
        market = "SH"
    elif symbol.startswith(("0", "3")):
        market = "SZ"
    # NOTE(review): "92" covers the Beijing Stock Exchange's newer 920xxx code
    # segment. Other "9" prefixes (e.g. B-shares) are still rejected.
    elif symbol.startswith(("4", "8", "92")):
        market = "BJ"
    else:
        raise ValueError(f"Cannot infer exchange suffix for symbol={raw_symbol!r}")
    return f"{symbol}.{market}"


def _retry_call(fn, *, call_name: str):
    """Call ``fn()`` up to ``_TS_RETRY_COUNT`` times with exponential backoff.

    After failed attempt ``k``, sleeps ``_TS_RETRY_BASE_SLEEP * 2 ** (k - 1)``
    seconds. There is no sleep after the final attempt.

    Args:
        fn: zero-argument callable performing the remote call.
        call_name: label used in the error message.

    Returns:
        Whatever ``fn()`` returns on the first successful attempt.

    Raises:
        RuntimeError: once all attempts have raised. The last exception is
            chained as ``__cause__``.
    """
    last_exc: Exception | None = None
    for attempt in range(1, _TS_RETRY_COUNT + 1):
        try:
            return fn()
        except Exception as exc:  # pragma: no cover - external IO
            last_exc = exc
            if attempt >= _TS_RETRY_COUNT:
                break
            time.sleep(_TS_RETRY_BASE_SLEEP * (2 ** (attempt - 1)))
    raise RuntimeError(
        f"Tushare call failed after {_TS_RETRY_COUNT} attempts ({call_name}): {last_exc}"
    ) from last_exc


def fetch_stock_data(
    symbol: str, lookback: int
) -> tuple[pd.DataFrame, pd.Series, str]:
    """
    Fetch the most recent `lookback` forward-adjusted (qfq) daily bars.

    Args:
        symbol: stock code, e.g. "603777" or "000063.SZ" (see
            `_normalize_to_ts_code`).
        lookback: number of most recent bars to return; must be positive.

    Returns:
        x_df           : DataFrame with columns [open, high, low, close, volume, amount]
        x_timestamp    : pd.Series[datetime], aligned to x_df
        last_trade_date: str "YYYYMMDD", the most recent bar date

        Fewer than `lookback` rows are returned if Tushare has fewer bars in
        the queried window (e.g. a recently listed or long-suspended stock).

    Raises:
        ValueError: invalid symbol, non-positive lookback, or no data returned.
        RuntimeError: Tushare token missing, or every retry attempt failed.
    """
    ts_code = _normalize_to_ts_code(symbol)
    if lookback <= 0:
        # tail(0) would make .iloc[-1] raise IndexError, and tail(-n) would
        # silently return the wrong rows.
        raise ValueError(f"lookback must be a positive integer, got {lookback!r}")

    # Read the clock once so start/end cannot straddle midnight.
    today = datetime.today()
    end_date = today.strftime("%Y%m%d")
    # 2x buffer for weekends/holidays, plus fixed padding so short lookbacks
    # survive long exchange closures.
    start_date = (
        today - timedelta(days=lookback * 2 + _HOLIDAY_PAD_DAYS)
    ).strftime("%Y%m%d")

    pro = _get_pro()
    # NOTE(review): some tushare versions of pro_bar catch exceptions
    # internally and return None instead of raising. In that case
    # _retry_call does not retry, and the failure surfaces as the
    # "No data returned" ValueError below. Confirm against the installed
    # version.
    df = _retry_call(
        lambda: ts.pro_bar(
            ts_code=ts_code,
            api=pro,  # reuse the cached client instead of building a new one
            adj="qfq",
            start_date=start_date,
            end_date=end_date,
            asset="E",
        ),
        call_name=f"pro_bar(ts_code={ts_code})",
    )
    if df is None or df.empty:
        raise ValueError(f"No data returned for symbol={symbol!r} (ts_code={ts_code})")

    # trade_date is a "YYYYMMDD" string, so a lexical sort is chronological.
    df = df.sort_values("trade_date").reset_index(drop=True)
    df = df.rename(columns={"vol": "volume"})
    df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")

    # Keep the most recent `lookback` bars
    df = df.tail(lookback).reset_index(drop=True)

    x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
    x_timestamp = df["timestamps"].copy()
    last_trade_date = str(df["trade_date"].iloc[-1])

    return x_df, x_timestamp, last_trade_date


def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
    """
    Return a pd.Series of `pred_len` future SSE trading dates (datetime) that
    follow `last_trade_date` (format: YYYYMMDD).

    Args:
        last_trade_date: "YYYYMMDD" string. Returned dates are strictly after it.
        pred_len: number of future trading dates wanted; must be >= 0.

    Raises:
        ValueError: negative `pred_len`, malformed `last_trade_date`, or fewer
            than `pred_len` open dates found in the queried window.
        RuntimeError: Tushare token missing, or every retry attempt failed.
    """
    if pred_len < 0:
        raise ValueError(f"pred_len must be non-negative, got {pred_len!r}")

    last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
    # 3x buffer, plus fixed padding so small pred_len values still cover a
    # long holiday (a bare pred_len * 3 is only 3 days when pred_len == 1).
    end_dt = last_dt + timedelta(days=pred_len * 3 + _HOLIDAY_PAD_DAYS)

    pro = _get_pro()
    cal = _retry_call(
        lambda: pro.trade_cal(
            exchange="SSE",
            start_date=(last_dt + timedelta(days=1)).strftime("%Y%m%d"),
            end_date=end_dt.strftime("%Y%m%d"),
            is_open="1",
        ),
        call_name="trade_cal(exchange=SSE)",
    )

    # Treat a missing or column-less response as "no dates", so the clear
    # ValueError below fires instead of an AttributeError/KeyError.
    if cal is None or cal.empty or "cal_date" not in cal.columns:
        raw_dates = []
    else:
        # Sort explicitly rather than relying on the API's row order.
        raw_dates = cal.sort_values("cal_date")["cal_date"].to_numpy()[:pred_len]
    dates = pd.to_datetime(raw_dates, format="%Y%m%d")

    if len(dates) < pred_len:
        raise ValueError(
            f"Could only obtain {len(dates)} future trading dates; "
            f"increase buffer or check Tushare calendar coverage."
        )
    return pd.Series(dates)