vmjn committed on
Commit
af666bf
·
verified ·
1 Parent(s): 3e07fe0

fix: torch 2.6, chronos inputs=, timesfm fallback, yfinance curl_cffi session + Ticker.history() fallback

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -1
  2. app.py +54 -15
Dockerfile CHANGED
@@ -14,7 +14,7 @@ WORKDIR /home/user/app
14
 
15
  # CPU-only torch from PyTorch index (skip CUDA wheels, ~200MB vs ~2GB)
16
  RUN pip install --user --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
17
- torch==2.4.1
18
 
19
  # App deps — all pinned to known-compatible versions
20
  RUN pip install --user --no-cache-dir \
@@ -23,6 +23,7 @@ RUN pip install --user --no-cache-dir \
23
  "numpy>=1.26,<2.3" \
24
  "pandas>=2.1" \
25
  "yfinance>=0.2.50" \
 
26
  "websockets>=13.0" \
27
  "einops>=0.7" \
28
  "safetensors>=0.4" \
 
14
 
15
  # CPU-only torch from PyTorch index (skip CUDA wheels, ~200MB vs ~2GB)
16
  RUN pip install --user --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
17
+ torch==2.6.0
18
 
19
  # App deps — all pinned to known-compatible versions
20
  RUN pip install --user --no-cache-dir \
 
23
  "numpy>=1.26,<2.3" \
24
  "pandas>=2.1" \
25
  "yfinance>=0.2.50" \
26
+ "curl_cffi>=0.7" \
27
  "websockets>=13.0" \
28
  "einops>=0.7" \
29
  "safetensors>=0.4" \
app.py CHANGED
@@ -22,20 +22,50 @@ import gradio as gr
22
  # -----------------------------------------------------------------------------
23
 
24
  def _load_ohlc(symbol: str, lookback_days: int) -> pd.DataFrame:
25
- df = yf.download(symbol, period=f"{lookback_days + 10}d",
26
- interval="1d", progress=False, auto_adjust=False)
27
- if df is None or df.empty:
28
- return pd.DataFrame()
29
- if isinstance(df.columns, pd.MultiIndex):
30
- df.columns = df.columns.get_level_values(0)
31
- df = df.reset_index()
32
- df.columns = [str(c).lower() for c in df.columns]
33
- if "date" not in df.columns:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return pd.DataFrame()
35
- df["timestamps"] = pd.to_datetime(df["date"])
36
- keep = ["timestamps", "open", "high", "low", "close", "volume"]
37
- df = df[[c for c in keep if c in df.columns]].dropna().tail(lookback_days).reset_index(drop=True)
38
- return df
39
 
40
 
41
  def _direction(pct: float) -> int:
@@ -130,7 +160,7 @@ def forecast_chronos(symbol: str, lookback_days: int = 180, pred_days: int = 30)
130
  return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
131
  context = torch.tensor(kdf["close"].values, dtype=torch.float32)
132
  quantiles, mean = _get_chronos().predict_quantiles(
133
- context=context, prediction_length=pred_days, quantile_levels=[0.1, 0.5, 0.9],
134
  )
135
  median = quantiles[0, :, 1].cpu().numpy()
136
  low = quantiles[0, :, 0].cpu().numpy()
@@ -163,7 +193,16 @@ def _get_timesfm():
163
  global _timesfm
164
  if _timesfm is None:
165
  import timesfm
166
- _timesfm = timesfm.TimesFm_2p5_200M_torch.from_pretrained(TIMESFM_MODEL_ID)
 
 
 
 
 
 
 
 
 
167
  return _timesfm
168
 
169
 
 
22
  # -----------------------------------------------------------------------------
23
 
24
  def _load_ohlc(symbol: str, lookback_days: int) -> pd.DataFrame:
25
+ def _tidy(df):
26
+ if df is None or df.empty:
27
+ return pd.DataFrame()
28
+ if isinstance(df.columns, pd.MultiIndex):
29
+ df.columns = df.columns.get_level_values(0)
30
+ df = df.reset_index()
31
+ df.columns = [str(c).lower() for c in df.columns]
32
+ if "date" not in df.columns and "datetime" in df.columns:
33
+ df["date"] = df["datetime"]
34
+ if "date" not in df.columns:
35
+ return pd.DataFrame()
36
+ df["timestamps"] = pd.to_datetime(df["date"])
37
+ keep = ["timestamps", "open", "high", "low", "close", "volume"]
38
+ df = df[[c for c in keep if c in df.columns]].dropna().tail(lookback_days).reset_index(drop=True)
39
+ return df
40
+
41
+ # Primary: yf.download; fallback: yf.Ticker().history(). Both use a curl_cffi session impersonating Chrome when curl_cffi can be imported.
42
+ session = None
43
+ try:
44
+ from curl_cffi import requests as cureq
45
+ session = cureq.Session(impersonate="chrome")
46
+ except Exception:
47
+ session = None
48
+
49
+ period = f"{lookback_days + 10}d"
50
+ try:
51
+ df = yf.download(symbol, period=period, interval="1d",
52
+ progress=False, auto_adjust=False,
53
+ session=session) if session else \
54
+ yf.download(symbol, period=period, interval="1d",
55
+ progress=False, auto_adjust=False)
56
+ out = _tidy(df)
57
+ if len(out) >= 32:
58
+ return out
59
+ except Exception:
60
+ pass
61
+
62
+ # Fallback path
63
+ try:
64
+ t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)
65
+ df = t.history(period=period, interval="1d", auto_adjust=False)
66
+ return _tidy(df)
67
+ except Exception:
68
  return pd.DataFrame()
 
 
 
 
69
 
70
 
71
  def _direction(pct: float) -> int:
 
160
  return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
161
  context = torch.tensor(kdf["close"].values, dtype=torch.float32)
162
  quantiles, mean = _get_chronos().predict_quantiles(
163
+ inputs=context, prediction_length=pred_days, quantile_levels=[0.1, 0.5, 0.9],
164
  )
165
  median = quantiles[0, :, 1].cpu().numpy()
166
  low = quantiles[0, :, 0].cpu().numpy()
 
193
  global _timesfm
194
  if _timesfm is None:
195
  import timesfm
196
+ # TimesFM package API varies across versions. Try the 2.5-specific class first,
197
+ # then fall back to the generic TimesFm(hparams, checkpoint) constructor.
198
+ cls = getattr(timesfm, "TimesFm_2p5_200M_torch", None)
199
+ if cls is not None:
200
+ _timesfm = cls.from_pretrained(TIMESFM_MODEL_ID)
201
+ else:
202
+ _timesfm = timesfm.TimesFm(
203
+ hparams=timesfm.TimesFmHparams(backend="torch"),
204
+ checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=TIMESFM_MODEL_ID),
205
+ )
206
  return _timesfm
207
 
208