Spaces:
Running
Running
fix: torch 2.6, chronos inputs=, timesfm fallback, yfinance curl_cffi+retry
Browse files- Dockerfile +2 -1
- app.py +54 -15
Dockerfile
CHANGED
|
@@ -14,7 +14,7 @@ WORKDIR /home/user/app
|
|
| 14 |
|
| 15 |
# CPU-only torch from PyTorch index (skip CUDA wheels, ~200MB vs ~2GB)
|
| 16 |
RUN pip install --user --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
|
| 17 |
-
torch==2.
|
| 18 |
|
| 19 |
# App deps — all pinned to known-compatible versions
|
| 20 |
RUN pip install --user --no-cache-dir \
|
|
@@ -23,6 +23,7 @@ RUN pip install --user --no-cache-dir \
|
|
| 23 |
"numpy>=1.26,<2.3" \
|
| 24 |
"pandas>=2.1" \
|
| 25 |
"yfinance>=0.2.50" \
|
|
|
|
| 26 |
"websockets>=13.0" \
|
| 27 |
"einops>=0.7" \
|
| 28 |
"safetensors>=0.4" \
|
|
|
|
| 14 |
|
| 15 |
# CPU-only torch from PyTorch index (skip CUDA wheels, ~200MB vs ~2GB)
|
| 16 |
RUN pip install --user --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
|
| 17 |
+
torch==2.6.0
|
| 18 |
|
| 19 |
# App deps — all pinned to known-compatible versions
|
| 20 |
RUN pip install --user --no-cache-dir \
|
|
|
|
| 23 |
"numpy>=1.26,<2.3" \
|
| 24 |
"pandas>=2.1" \
|
| 25 |
"yfinance>=0.2.50" \
|
| 26 |
+
"curl_cffi>=0.7" \
|
| 27 |
"websockets>=13.0" \
|
| 28 |
"einops>=0.7" \
|
| 29 |
"safetensors>=0.4" \
|
app.py
CHANGED
|
@@ -22,20 +22,50 @@ import gradio as gr
|
|
| 22 |
# -----------------------------------------------------------------------------
|
| 23 |
|
| 24 |
def _load_ohlc(symbol: str, lookback_days: int) -> pd.DataFrame:
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
df
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
return pd.DataFrame()
|
| 35 |
-
df["timestamps"] = pd.to_datetime(df["date"])
|
| 36 |
-
keep = ["timestamps", "open", "high", "low", "close", "volume"]
|
| 37 |
-
df = df[[c for c in keep if c in df.columns]].dropna().tail(lookback_days).reset_index(drop=True)
|
| 38 |
-
return df
|
| 39 |
|
| 40 |
|
| 41 |
def _direction(pct: float) -> int:
|
|
@@ -130,7 +160,7 @@ def forecast_chronos(symbol: str, lookback_days: int = 180, pred_days: int = 30)
|
|
| 130 |
return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
|
| 131 |
context = torch.tensor(kdf["close"].values, dtype=torch.float32)
|
| 132 |
quantiles, mean = _get_chronos().predict_quantiles(
|
| 133 |
-
|
| 134 |
)
|
| 135 |
median = quantiles[0, :, 1].cpu().numpy()
|
| 136 |
low = quantiles[0, :, 0].cpu().numpy()
|
|
@@ -163,7 +193,16 @@ def _get_timesfm():
|
|
| 163 |
global _timesfm
|
| 164 |
if _timesfm is None:
|
| 165 |
import timesfm
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
return _timesfm
|
| 168 |
|
| 169 |
|
|
|
|
| 22 |
# -----------------------------------------------------------------------------
|
| 23 |
|
| 24 |
def _load_ohlc(symbol: str, lookback_days: int) -> pd.DataFrame:
|
| 25 |
+
def _tidy(df):
|
| 26 |
+
if df is None or df.empty:
|
| 27 |
+
return pd.DataFrame()
|
| 28 |
+
if isinstance(df.columns, pd.MultiIndex):
|
| 29 |
+
df.columns = df.columns.get_level_values(0)
|
| 30 |
+
df = df.reset_index()
|
| 31 |
+
df.columns = [str(c).lower() for c in df.columns]
|
| 32 |
+
if "date" not in df.columns and "datetime" in df.columns:
|
| 33 |
+
df["date"] = df["datetime"]
|
| 34 |
+
if "date" not in df.columns:
|
| 35 |
+
return pd.DataFrame()
|
| 36 |
+
df["timestamps"] = pd.to_datetime(df["date"])
|
| 37 |
+
keep = ["timestamps", "open", "high", "low", "close", "volume"]
|
| 38 |
+
df = df[[c for c in keep if c in df.columns]].dropna().tail(lookback_days).reset_index(drop=True)
|
| 39 |
+
return df
|
| 40 |
+
|
| 41 |
+
# Primary: yf.download. Fallback: yf.Ticker().history(). Uses curl_cffi chrome impersonation if available.
|
| 42 |
+
session = None
|
| 43 |
+
try:
|
| 44 |
+
from curl_cffi import requests as cureq
|
| 45 |
+
session = cureq.Session(impersonate="chrome")
|
| 46 |
+
except Exception:
|
| 47 |
+
session = None
|
| 48 |
+
|
| 49 |
+
period = f"{lookback_days + 10}d"
|
| 50 |
+
try:
|
| 51 |
+
df = yf.download(symbol, period=period, interval="1d",
|
| 52 |
+
progress=False, auto_adjust=False,
|
| 53 |
+
session=session) if session else \
|
| 54 |
+
yf.download(symbol, period=period, interval="1d",
|
| 55 |
+
progress=False, auto_adjust=False)
|
| 56 |
+
out = _tidy(df)
|
| 57 |
+
if len(out) >= 32:
|
| 58 |
+
return out
|
| 59 |
+
except Exception:
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
# Fallback path
|
| 63 |
+
try:
|
| 64 |
+
t = yf.Ticker(symbol, session=session) if session else yf.Ticker(symbol)
|
| 65 |
+
df = t.history(period=period, interval="1d", auto_adjust=False)
|
| 66 |
+
return _tidy(df)
|
| 67 |
+
except Exception:
|
| 68 |
return pd.DataFrame()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
def _direction(pct: float) -> int:
|
|
|
|
| 160 |
return {"status": "error", "error": f"insufficient data for {symbol}", "n_lookback": len(kdf)}
|
| 161 |
context = torch.tensor(kdf["close"].values, dtype=torch.float32)
|
| 162 |
quantiles, mean = _get_chronos().predict_quantiles(
|
| 163 |
+
inputs=context, prediction_length=pred_days, quantile_levels=[0.1, 0.5, 0.9],
|
| 164 |
)
|
| 165 |
median = quantiles[0, :, 1].cpu().numpy()
|
| 166 |
low = quantiles[0, :, 0].cpu().numpy()
|
|
|
|
| 193 |
global _timesfm
|
| 194 |
if _timesfm is None:
|
| 195 |
import timesfm
|
| 196 |
+
# TimesFM package API varies across versions. Try the 2.5-specific class first,
|
| 197 |
+
# then fall back to the generic TimesFm(hparams, checkpoint) constructor.
|
| 198 |
+
cls = getattr(timesfm, "TimesFm_2p5_200M_torch", None)
|
| 199 |
+
if cls is not None:
|
| 200 |
+
_timesfm = cls.from_pretrained(TIMESFM_MODEL_ID)
|
| 201 |
+
else:
|
| 202 |
+
_timesfm = timesfm.TimesFm(
|
| 203 |
+
hparams=timesfm.TimesFmHparams(backend="torch"),
|
| 204 |
+
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=TIMESFM_MODEL_ID),
|
| 205 |
+
)
|
| 206 |
return _timesfm
|
| 207 |
|
| 208 |
|