Spaces:
Running
Running
File size: 4,839 Bytes
72a9562 72f5e10 2a8a0f5 72f5e10 72a9562 23022b9 72a9562 23022b9 72a9562 23022b9 2a8a0f5 23022b9 2a8a0f5 23022b9 2a8a0f5 23022b9 2a8a0f5 23022b9 2a8a0f5 23022b9 2a8a0f5 23022b9 2a8a0f5 23022b9 72f5e10 23022b9 72f5e10 23022b9 72f5e10 23022b9 72f5e10 23022b9 72f5e10 72a9562 2a8a0f5 72a9562 23022b9 72a9562 23022b9 72f5e10 23022b9 72f5e10 23022b9 72a9562 23022b9 72a9562 23022b9 72a9562 23022b9 72a9562 23022b9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | from datetime import datetime, timedelta
import os
import threading
import time
import pandas as pd
import tushare as ts
# --- Tushare configuration, read from the environment once at import time ---

# API token with surrounding whitespace removed; empty string when unset.
_TUSHARE_TOKEN = os.getenv("TUSHARE_TOKEN", "").strip()

# Total number of attempts per remote call; clamped so it is never below 1.
_TS_RETRY_COUNT = max(int(os.getenv("TS_RETRY_COUNT", "3")), 1)

# Base delay in seconds for the exponential backoff between attempts.
_TS_RETRY_BASE_SLEEP = float(os.getenv("TS_RETRY_BASE_SLEEP", "0.8"))

# Lock guarding creation of the shared client, and the lazily created
# client itself (None until the first successful initialisation).
_PRO_LOCK = threading.Lock()
_PRO = None
def _get_pro():
    """
    Return the shared Tushare pro API client, creating it on first use.

    The whole check-and-create sequence runs while holding `_PRO_LOCK`, so
    the client is built at most once. The result is cached in the
    module-level `_PRO`.

    Raises:
        RuntimeError: if no client exists yet and `_TUSHARE_TOKEN` is empty.
    """
    global _PRO
    with _PRO_LOCK:
        if _PRO is None:
            # First call: a token is mandatory before the client can be built.
            if not _TUSHARE_TOKEN:
                raise RuntimeError("TUSHARE_TOKEN is not set")
            ts.set_token(_TUSHARE_TOKEN)
            _PRO = ts.pro_api()
        return _PRO
def _normalize_to_ts_code(raw_symbol: str) -> str:
    """
    Normalize user `symbol` input to Tushare `ts_code`.

    Surrounding whitespace is ignored and the input is upper-cased. When an
    exchange suffix is present it is validated and kept; otherwise the
    suffix is inferred from the first digit of the code.

    Accepted examples:
    - "603777" -> "603777.SH"
    - "300065" -> "300065.SZ"
    - "430047" -> "430047.BJ"
    - "000063.SZ" -> "000063.SZ"

    Raises:
        ValueError: if the code is not exactly six ASCII digits, the suffix
            is not one of SH/SZ/BJ, or no suffix can be inferred.
    """
    symbol = raw_symbol.strip().upper()
    code, dot, market = symbol.partition(".")

    # str.isdigit() alone also accepts non-ASCII digits (full-width forms,
    # superscripts, ...), which would otherwise be passed through verbatim.
    code_ok = len(code) == 6 and code.isascii() and code.isdigit()

    if dot:
        # Explicit suffix: validate both halves, nothing to infer.
        if not code_ok or market not in {"SH", "SZ", "BJ"}:
            raise ValueError(
                f"Invalid symbol {raw_symbol!r}; expected e.g. '603777' or '000063.SZ'."
            )
        return f"{code}.{market}"

    if not code_ok:
        raise ValueError(
            f"Invalid symbol {raw_symbol!r}; expected 6 digits or code with suffix."
        )

    # Leading digit(s) -> exchange suffix.
    prefix_to_market = (
        (("6",), "SH"),
        (("0", "3"), "SZ"),
        (("4", "8"), "BJ"),
    )
    for prefixes, inferred in prefix_to_market:
        if code.startswith(prefixes):
            return f"{code}.{inferred}"
    raise ValueError(f"Cannot infer exchange suffix for symbol={raw_symbol!r}")
def _retry_call(fn, *, call_name: str):
    """
    Call `fn()` and return its result, retrying on any `Exception`.

    Up to `_TS_RETRY_COUNT` attempts are made. After each failed attempt
    except the last, the function sleeps `_TS_RETRY_BASE_SLEEP * 2**k`
    seconds, where k is the zero-based index of the attempt that failed.

    Args:
        fn: zero-argument callable to invoke.
        call_name: label included in the final error message.

    Raises:
        RuntimeError: when every attempt failed; chained to the last error.
    """
    failure: Exception | None = None
    for tries_done in range(_TS_RETRY_COUNT):
        try:
            return fn()
        except Exception as exc:  # pragma: no cover - external IO
            failure = exc
        # Only reached after a failed attempt.
        if tries_done + 1 >= _TS_RETRY_COUNT:
            break
        backoff = _TS_RETRY_BASE_SLEEP * (2 ** tries_done)
        time.sleep(backoff)
    raise RuntimeError(
        f"Tushare call failed after {_TS_RETRY_COUNT} attempts ({call_name}): {failure}"
    ) from failure
def fetch_stock_data(
    symbol: str, lookback: int
) -> tuple[pd.DataFrame, pd.Series, str]:
    """
    Fetch the most recent daily bars for one stock via Tushare `pro_bar`
    (requested with `adj="qfq"` and `asset="E"`).

    Args:
        symbol: 6-digit code, with or without exchange suffix
            (e.g. "603777" or "000063.SZ").
        lookback: maximum number of most recent bars to return; must be >= 1.

    Returns:
        x_df           : DataFrame with columns [open, high, low, close, volume, amount]
        x_timestamp    : pd.Series[datetime], aligned to x_df
        last_trade_date: str "YYYYMMDD", the most recent bar date

        Fewer than `lookback` rows are returned when the queried date
        window contains fewer bars.

    Raises:
        ValueError: if `lookback` < 1, `symbol` cannot be normalized, or
            the API returns no rows.
        RuntimeError: if the token is not configured or the remote call
            keeps failing after all retries.
    """
    # Reject bad sizes up front; otherwise an empty tail() would surface
    # later as an opaque IndexError on `.iloc[-1]`.
    if lookback < 1:
        raise ValueError(f"lookback must be a positive integer, got {lookback!r}")

    ts_code = _normalize_to_ts_code(symbol)

    # Take "today" once so start/end cannot straddle midnight.
    today = datetime.today()
    end_date = today.strftime("%Y%m%d")
    # 2x calendar days accounts for weekends/holidays; the fixed pad keeps
    # small lookbacks from falling entirely inside a long market closure.
    holiday_pad_days = 30
    window_days = lookback * 2 + holiday_pad_days
    start_date = (today - timedelta(days=window_days)).strftime("%Y%m%d")

    pro = _get_pro()
    df = _retry_call(
        lambda: ts.pro_bar(
            ts_code=ts_code,
            api=pro,  # use the shared client explicitly instead of global state
            adj="qfq",
            start_date=start_date,
            end_date=end_date,
            asset="E",
        ),
        call_name=f"pro_bar(ts_code={ts_code})",
    )
    if df is None or df.empty:
        raise ValueError(f"No data returned for symbol={symbol!r} (ts_code={ts_code})")

    # Oldest-first ordering, then the uniform "volume" column name.
    df = df.sort_values("trade_date").reset_index(drop=True)
    df = df.rename(columns={"vol": "volume"})
    df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")

    # Keep the most recent `lookback` bars
    df = df.tail(lookback).reset_index(drop=True)

    x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
    x_timestamp = df["timestamps"].copy()
    last_trade_date = str(df["trade_date"].iloc[-1])
    return x_df, x_timestamp, last_trade_date
def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
    """
    Return a pd.Series of `pred_len` future SSE trading dates (datetime) that
    follow `last_trade_date` (format: YYYYMMDD).

    Args:
        last_trade_date: date string "YYYYMMDD"; returned dates are strictly
            after it.
        pred_len: number of trading dates wanted; 0 yields an empty Series
            without contacting Tushare.

    Raises:
        ValueError: if `last_trade_date` is malformed, `pred_len` is
            negative, or the calendar yields fewer than `pred_len` dates.
        RuntimeError: if the token is not configured or the remote call
            keeps failing after all retries.
    """
    last_dt = datetime.strptime(last_trade_date, "%Y%m%d")

    # A negative length would otherwise turn into a from-the-end slice below
    # and silently return the wrong dates.
    if pred_len < 0:
        raise ValueError(f"pred_len must be non-negative, got {pred_len!r}")
    if pred_len == 0:
        return pd.Series(pd.to_datetime([], format="%Y%m%d"))

    # 3x buffer plus a fixed pad, so that even a small `pred_len` still
    # reaches past a long holiday closure.
    holiday_pad_days = 30
    start_dt = last_dt + timedelta(days=1)
    end_dt = last_dt + timedelta(days=pred_len * 3 + holiday_pad_days)

    pro = _get_pro()
    cal = _retry_call(
        lambda: pro.trade_cal(
            exchange="SSE",
            start_date=start_dt.strftime("%Y%m%d"),
            end_date=end_dt.strftime("%Y%m%d"),
            is_open="1",
        ),
        call_name="trade_cal(exchange=SSE)",
    )

    # Treat a missing response like an empty calendar so the caller gets
    # the same ValueError in both cases.
    if cal is None or cal.empty:
        open_days = []
    else:
        open_days = cal.sort_values("cal_date")["cal_date"].values[:pred_len]

    dates = pd.to_datetime(open_days, format="%Y%m%d")
    if len(dates) < pred_len:
        raise ValueError(
            f"Could only obtain {len(dates)} future trading dates; "
            f"increase buffer or check Tushare calendar coverage."
        )
    return pd.Series(dates)
|