Spaces:
Running
Running
| from datetime import datetime, timedelta | |
| import os | |
| import threading | |
| import time | |
| import pandas as pd | |
| import tushare as ts | |
# --- Configuration, read once from the environment at import time ---------

# Tushare API token; surrounding whitespace is stripped, empty when unset.
_TUSHARE_TOKEN = os.getenv("TUSHARE_TOKEN", "").strip()

# Delay (seconds) before the first retry; later retries double it.
_TS_RETRY_BASE_SLEEP = float(os.getenv("TS_RETRY_BASE_SLEEP", "0.8"))

# Total number of attempts per Tushare call; clamped so it is never below 1.
_TS_RETRY_COUNT = max(int(os.getenv("TS_RETRY_COUNT", "3")), 1)

# --- Lazily created Tushare client ------------------------------------------

# Guards creation of the shared client in `_get_pro`.
_PRO_LOCK = threading.Lock()

# Cached client instance; stays None until `_get_pro` is first called.
_PRO = None
def _get_pro():
    """
    Return the shared Tushare pro client, creating it on first use.

    Creation happens while holding `_PRO_LOCK`; the resulting client is
    stored in the module-level `_PRO` and handed back on every later call.

    Raises:
        RuntimeError: if `_TUSHARE_TOKEN` is empty when the client has to be
            created.
    """
    global _PRO
    with _PRO_LOCK:
        if _PRO is None:
            if not _TUSHARE_TOKEN:
                raise RuntimeError("TUSHARE_TOKEN is not set")
            # Register the token with tushare before building the client.
            ts.set_token(_TUSHARE_TOKEN)
            _PRO = ts.pro_api()
        return _PRO
def _normalize_to_ts_code(raw_symbol: str) -> str:
    """
    Normalize user `symbol` input to Tushare `ts_code`.

    Surrounding whitespace is ignored and the input is upper-cased, so the
    exchange suffix is case-insensitive. The numeric part must be exactly six
    ASCII digits. An explicit suffix is kept as given; for a bare code the
    suffix is inferred from its leading digit(s).

    Accepted examples:
    - "603777"    -> "603777.SH"
    - "300065"    -> "300065.SZ"
    - "430047"    -> "430047.BJ"
    - "920002"    -> "920002.BJ"
    - "000063.SZ" -> "000063.SZ"

    Raises:
        ValueError: if the code is not six ASCII digits, the suffix is not one
            of SH/SZ/BJ, or no exchange can be inferred for a bare code.
    """
    symbol = raw_symbol.strip().upper()
    code, dot, market = symbol.partition(".")

    # str.isdigit() on its own also accepts non-ASCII digits (full-width,
    # superscripts, ...), which would leak into the returned ts_code.
    is_six_digits = len(code) == 6 and code.isascii() and code.isdigit()

    if dot:
        # Explicit suffix: validate both halves and keep the caller's choice.
        if not is_six_digits or market not in {"SH", "SZ", "BJ"}:
            raise ValueError(
                f"Invalid symbol {raw_symbol!r}; expected e.g. '603777' or '000063.SZ'."
            )
        return f"{code}.{market}"

    if not is_six_digits:
        raise ValueError(
            f"Invalid symbol {raw_symbol!r}; expected 6 digits or code with suffix."
        )

    if code.startswith("6"):
        market = "SH"
    elif code.startswith(("0", "3")):
        market = "SZ"
    elif code.startswith(("4", "8", "92")):
        # "92" covers the Beijing exchange's 920xxx code segment; other
        # 9-prefixed codes remain ambiguous and are rejected below.
        market = "BJ"
    else:
        raise ValueError(f"Cannot infer exchange suffix for symbol={raw_symbol!r}")
    return f"{code}.{market}"
def _retry_call(fn, *, call_name: str):
    """
    Invoke `fn()` and return its result, retrying when it raises.

    Up to `_TS_RETRY_COUNT` attempts are made. After a failed attempt that is
    not the last one, the function sleeps `_TS_RETRY_BASE_SLEEP` seconds,
    doubled for each failure so far, before trying again.

    Args:
        fn: zero-argument callable performing the actual request.
        call_name: label included in the error message on final failure.

    Raises:
        RuntimeError: once every attempt has failed; chained to the exception
            raised by the last attempt.
    """
    failure: Exception | None = None
    tries_done = 0
    while tries_done < _TS_RETRY_COUNT:
        tries_done += 1
        try:
            return fn()
        except Exception as err:  # pragma: no cover - external IO
            failure = err
            if tries_done < _TS_RETRY_COUNT:
                # Exponential backoff: base, 2*base, 4*base, ...
                time.sleep(_TS_RETRY_BASE_SLEEP * (2 ** (tries_done - 1)))
    raise RuntimeError(
        f"Tushare call failed after {_TS_RETRY_COUNT} attempts ({call_name}): {failure}"
    ) from failure
def fetch_stock_data(
    symbol: str, lookback: int
) -> tuple[pd.DataFrame, pd.Series, str]:
    """
    Fetch the most recent daily bars for `symbol` from Tushare (`adj="qfq"`).

    Args:
        symbol: stock code accepted by `_normalize_to_ts_code`, e.g. "603777"
            or "000063.SZ".
        lookback: number of most recent bars to keep; must be >= 1. Fewer bars
            are returned if the requested window contains fewer.

    Returns:
        x_df           : DataFrame with columns [open, high, low, close, volume, amount]
        x_timestamp    : pd.Series[datetime], aligned to x_df
        last_trade_date: str "YYYYMMDD", the most recent bar date

    Raises:
        ValueError: if `lookback` < 1, `symbol` cannot be normalized, or
            Tushare returns no rows.
        RuntimeError: if the Tushare client cannot be created or the request
            still fails after all retries.
    """
    # Reject this early; otherwise tail(lookback) yields an empty frame and
    # the final .iloc[-1] fails with an unhelpful IndexError.
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback!r}")

    ts_code = _normalize_to_ts_code(symbol)

    # Sample the clock once so both ends of the window share the same "today".
    today = datetime.today()
    end_date = today.strftime("%Y%m%d")
    # 2x buffer to account for weekends/holidays, plus a fixed pad so that a
    # small `lookback` still spans a multi-day market closure.
    start_date = (today - timedelta(days=lookback * 2 + 30)).strftime("%Y%m%d")

    pro = _get_pro()
    df = _retry_call(
        lambda: ts.pro_bar(
            ts_code=ts_code,
            api=pro,  # reuse the cached client instead of building a new one
            adj="qfq",
            start_date=start_date,
            end_date=end_date,
            asset="E",
        ),
        call_name=f"pro_bar(ts_code={ts_code})",
    )
    if df is None or df.empty:
        raise ValueError(f"No data returned for symbol={symbol!r} (ts_code={ts_code})")

    # Oldest bar first, so that tail() below keeps the newest bars.
    df = df.sort_values("trade_date").reset_index(drop=True)
    df = df.rename(columns={"vol": "volume"})
    df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")

    # Keep the most recent `lookback` bars
    df = df.tail(lookback).reset_index(drop=True)

    x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
    x_timestamp = df["timestamps"].copy()
    last_trade_date = str(df["trade_date"].iloc[-1])
    return x_df, x_timestamp, last_trade_date
def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
    """
    Return a pd.Series of `pred_len` future SSE trading dates (datetime) that
    follow `last_trade_date` (format: YYYYMMDD).

    Args:
        last_trade_date: date of the last known bar, "YYYYMMDD".
        pred_len: number of future trading dates wanted; must be >= 0. For 0
            an empty datetime Series is returned without querying Tushare.

    Raises:
        ValueError: if `last_trade_date` is not in YYYYMMDD format, `pred_len`
            is negative, or the calendar yields fewer than `pred_len` dates.
        RuntimeError: if the Tushare client cannot be created or the request
            still fails after all retries.
    """
    last_dt = datetime.strptime(last_trade_date, "%Y%m%d")

    # A negative length would make the `[:pred_len]` slice below drop dates
    # from the end and slip past the length check.
    if pred_len < 0:
        raise ValueError(f"pred_len must be >= 0, got {pred_len!r}")
    if pred_len == 0:
        return pd.Series(pd.DatetimeIndex([]))

    start_date = (last_dt + timedelta(days=1)).strftime("%Y%m%d")
    # 3x buffer plus a fixed pad: a purely proportional window is too short
    # for small `pred_len` when a long holiday follows `last_trade_date`.
    end_date = (last_dt + timedelta(days=pred_len * 3 + 30)).strftime("%Y%m%d")

    pro = _get_pro()
    cal = _retry_call(
        lambda: pro.trade_cal(
            exchange="SSE",
            start_date=start_date,
            end_date=end_date,
            is_open="1",
        ),
        call_name="trade_cal(exchange=SSE)",
    )

    if cal is None or cal.empty:
        # Treat a missing calendar like one with no open days so the caller
        # gets the ValueError below rather than an AttributeError.
        open_days = []
    else:
        open_days = cal.sort_values("cal_date")["cal_date"].values[:pred_len]

    dates = pd.to_datetime(open_days, format="%Y%m%d")
    if len(dates) < pred_len:
        raise ValueError(
            f"Could only obtain {len(dates)} future trading dates; "
            f"increase buffer or check Tushare calendar coverage."
        )
    return pd.Series(dates)