File size: 4,839 Bytes
72a9562
72f5e10
2a8a0f5
72f5e10
72a9562
 
23022b9
72a9562
23022b9
 
 
72a9562
23022b9
 
2a8a0f5
23022b9
 
 
 
 
 
 
 
 
 
 
 
 
 
2a8a0f5
23022b9
2a8a0f5
 
23022b9
 
 
 
2a8a0f5
 
 
23022b9
 
 
 
 
 
 
2a8a0f5
 
23022b9
2a8a0f5
 
23022b9
 
 
 
 
 
 
 
2a8a0f5
23022b9
 
 
 
72f5e10
23022b9
72f5e10
 
 
 
23022b9
72f5e10
23022b9
72f5e10
23022b9
72f5e10
 
 
72a9562
2a8a0f5
72a9562
 
 
 
 
 
 
23022b9
72a9562
23022b9
 
 
 
 
 
 
 
72f5e10
 
23022b9
72f5e10
23022b9
72a9562
 
 
23022b9
 
72a9562
23022b9
 
72a9562
 
 
 
 
 
23022b9
72a9562
 
 
 
 
 
 
 
 
 
23022b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from datetime import datetime, timedelta
import os
import threading
import time

import pandas as pd
import tushare as ts

# Tushare API token, read once at import time from the environment.
_TUSHARE_TOKEN = os.environ.get("TUSHARE_TOKEN", "").strip()
# Total attempts per Tushare call (including the first); clamped to >= 1 so
# at least one call is always made.
_TS_RETRY_COUNT = max(1, int(os.environ.get("TS_RETRY_COUNT", "3")))
# Base delay in seconds for exponential backoff between attempts. Clamped to
# be non-negative: time.sleep() raises ValueError on a negative duration,
# which would otherwise surface as a confusing config error mid-retry.
_TS_RETRY_BASE_SLEEP = max(0.0, float(os.environ.get("TS_RETRY_BASE_SLEEP", "0.8")))

# Lazily created Tushare Pro client shared by the module; creation is
# guarded by _PRO_LOCK.
_PRO = None
_PRO_LOCK = threading.Lock()


def _get_pro():
    """
    Return the shared Tushare Pro client, building it on first use.

    The check-and-create step runs under `_PRO_LOCK`, so concurrent first
    callers share one client; later calls just return the cached `_PRO`.

    Raises:
        RuntimeError: if the client still has to be created and no
            TUSHARE_TOKEN was configured.
    """
    global _PRO
    with _PRO_LOCK:
        client = _PRO
        if client is None:
            if not _TUSHARE_TOKEN:
                raise RuntimeError("TUSHARE_TOKEN is not set")
            ts.set_token(_TUSHARE_TOKEN)
            client = _PRO = ts.pro_api()
        return client


def _normalize_to_ts_code(raw_symbol: str) -> str:
    """
    Normalize user `symbol` input to Tushare `ts_code`.

    Surrounding whitespace is stripped and the input is upper-cased, so the
    exchange suffix is case-insensitive.

    Accepted examples:
      - "603777"     -> "603777.SH"
      - "300065"     -> "300065.SZ"
      - "430047"     -> "430047.BJ"
      - "000063.SZ"  -> "000063.SZ"
      - "000063.sz"  -> "000063.SZ"

    For a bare 6-digit code the exchange is inferred from the first digit:
    6 -> SH, 0/3 -> SZ, 4/8 -> BJ; any other first digit is rejected.

    Raises:
        ValueError: if the input is not 6 ASCII digits (optionally followed
            by ".SH", ".SZ" or ".BJ"), or the exchange cannot be inferred.
    """

    def _is_valid_code(code: str) -> bool:
        # str.isdigit() alone also accepts non-ASCII digits (e.g. full-width
        # "６"), which would produce a ts_code Tushare cannot match.
        return len(code) == 6 and code.isascii() and code.isdigit()

    symbol = raw_symbol.strip().upper()
    if "." in symbol:
        code, market = symbol.split(".", 1)
        if not _is_valid_code(code) or market not in {"SH", "SZ", "BJ"}:
            raise ValueError(
                f"Invalid symbol {raw_symbol!r}; expected e.g. '603777' or '000063.SZ'."
            )
        return f"{code}.{market}"

    if not _is_valid_code(symbol):
        raise ValueError(
            f"Invalid symbol {raw_symbol!r}; expected 6 digits or code with suffix."
        )

    if symbol.startswith("6"):
        market = "SH"
    elif symbol.startswith(("0", "3")):
        market = "SZ"
    elif symbol.startswith(("4", "8")):
        market = "BJ"
    else:
        raise ValueError(f"Cannot infer exchange suffix for symbol={raw_symbol!r}")

    return f"{symbol}.{market}"


def _retry_call(fn, *, call_name: str):
    """
    Call `fn()` up to `_TS_RETRY_COUNT` times and return its first result.

    Any `Exception` from `fn` triggers another attempt. Between attempts the
    caller sleeps `_TS_RETRY_BASE_SLEEP * 2**k` seconds (k = 0, 1, ...); no
    sleep follows the final attempt.

    Raises:
        RuntimeError: once every attempt has failed; the message includes
            `call_name` and the last error, which is chained as `__cause__`.
    """
    total = _TS_RETRY_COUNT
    failure: Exception | None = None
    for idx in range(total):
        try:
            return fn()
        except Exception as err:  # pragma: no cover - external IO
            failure = err
            has_more = idx + 1 < total
            if has_more:
                time.sleep(_TS_RETRY_BASE_SLEEP * 2 ** idx)
    raise RuntimeError(
        f"Tushare call failed after {total} attempts ({call_name}): {failure}"
    ) from failure


def fetch_stock_data(
    symbol: str, lookback: int
) -> tuple[pd.DataFrame, pd.Series, str]:
    """
    Fetch the most recent `lookback` forward-adjusted (adj="qfq") daily bars
    for one stock from Tushare.

    Args:
        symbol: 6-digit code, optionally suffixed with ".SH"/".SZ"/".BJ".
        lookback: maximum number of most recent bars to keep; must be >= 1.

    Returns:
        x_df         : DataFrame with columns [open, high, low, close, volume, amount]
        x_timestamp  : pd.Series[datetime], aligned to x_df
        last_trade_date: str "YYYYMMDD", the most recent bar date

    Rows are in ascending date order with a 0-based index. Fewer than
    `lookback` rows are returned if the fetched window holds fewer bars.

    Raises:
        ValueError: if `lookback` < 1, `symbol` is invalid, or Tushare
            returns no data.
        RuntimeError: if the token is missing or the Tushare call keeps
            failing after retries.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback!r}")

    ts_code = _normalize_to_ts_code(symbol)

    # Read the clock once so start/end cannot straddle midnight.
    today = datetime.today()
    end_date = today.strftime("%Y%m%d")
    # ~2 calendar days per trading day covers weekends; the fixed pad covers
    # multi-day exchange holidays that would otherwise leave small lookbacks
    # with too few (or zero) bars.
    holiday_pad_days = 15
    start_date = (
        today - timedelta(days=lookback * 2 + holiday_pad_days)
    ).strftime("%Y%m%d")

    pro = _get_pro()
    df = _retry_call(
        # Hand pro_bar the shared client via its `api` parameter.
        lambda: ts.pro_bar(
            ts_code=ts_code,
            api=pro,
            adj="qfq",
            start_date=start_date,
            end_date=end_date,
            asset="E",
        ),
        call_name=f"pro_bar(ts_code={ts_code})",
    )

    if df is None or df.empty:
        raise ValueError(f"No data returned for symbol={symbol!r} (ts_code={ts_code})")

    df = df.sort_values("trade_date").reset_index(drop=True)
    df = df.rename(columns={"vol": "volume"})
    df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")

    # Keep the most recent `lookback` bars.
    df = df.tail(lookback).reset_index(drop=True)

    x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
    x_timestamp = df["timestamps"].copy()
    last_trade_date = str(df["trade_date"].iloc[-1])

    return x_df, x_timestamp, last_trade_date


def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
    """
    Return a pd.Series of `pred_len` future SSE trading dates (datetime) that
    follow `last_trade_date` (format: YYYYMMDD).

    Dates come from Tushare's SSE calendar (`trade_cal` filtered to
    is_open="1"), start the day after `last_trade_date`, and are in
    ascending order.

    Args:
        last_trade_date: "YYYYMMDD" string, e.g. the last bar date from
            fetch_stock_data.
        pred_len: number of trading dates wanted; 0 yields an empty Series
            without querying Tushare.

    Raises:
        ValueError: if `last_trade_date` is malformed, `pred_len` is
            negative, or the calendar yields fewer than `pred_len` dates.
        RuntimeError: if the token is missing or the Tushare call keeps
            failing after retries.
    """
    last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
    if pred_len < 0:
        raise ValueError(f"pred_len must be >= 0, got {pred_len!r}")
    if pred_len == 0:
        return pd.Series([], dtype="datetime64[ns]")

    # 3 calendar days per trading day, plus a fixed pad so a long exchange
    # holiday right after last_trade_date still leaves enough open days when
    # pred_len is small.
    holiday_pad_days = 15
    end_dt = last_dt + timedelta(days=pred_len * 3 + holiday_pad_days)

    pro = _get_pro()
    cal = _retry_call(
        lambda: pro.trade_cal(
            exchange="SSE",
            start_date=(last_dt + timedelta(days=1)).strftime("%Y%m%d"),
            end_date=end_dt.strftime("%Y%m%d"),
            is_open="1",
        ),
        call_name="trade_cal(exchange=SSE)",
    )

    # An empty/None response falls through to the shortfall error below
    # instead of crashing on the missing "cal_date" column.
    if cal is None or cal.empty:
        open_days = []
    else:
        open_days = cal.sort_values("cal_date")["cal_date"].tolist()[:pred_len]
    dates = pd.to_datetime(open_days, format="%Y%m%d")

    if len(dates) < pred_len:
        raise ValueError(
            f"Could only obtain {len(dates)} future trading dates; "
            f"increase buffer or check Tushare calendar coverage."
        )

    return pd.Series(dates)