Spaces:
Sleeping
Sleeping
| """ | |
| data_loader.py — yfinance market data download with batching, caching, | |
| and granular progress callbacks for the UI progress bar. | |
| """ | |
| import gc | |
| import time | |
| import logging | |
| import hashlib | |
| import pickle | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| from typing import Callable, Optional | |
| import numpy as np | |
| import pandas as pd | |
logger = logging.getLogger("SniperData")

# On-disk location of the per-ticker pickle cache; created at import time so
# the cache helpers can write to it without checking for the directory.
DATA_CACHE_DIR = Path("/tmp") / "sniper_data_cache"
DATA_CACHE_DIR.mkdir(exist_ok=True, parents=True)

# Download tuning knobs.
BATCH_SIZE = 50    # tickers requested per yfinance call
MAX_RETRIES = 3    # attempts per batch before giving up on it
RETRY_DELAY = 5    # base back-off in seconds, multiplied by the attempt number
| # --------------------------------------------------------------------------- | |
| # Cache helpers | |
| # --------------------------------------------------------------------------- | |
def _cache_key(tickers: list, start: str, end: str) -> str:
    """Return a 16-hex-character fingerprint of a ticker set and date range.

    The tickers are sorted before hashing, so the key does not depend on the
    order in which they are supplied.
    """
    symbols = "|".join(sorted(tickers))
    payload = f"{symbols}|{start}|{end}".encode()
    # MD5 is used purely as a compact fingerprint here, not for security.
    return hashlib.md5(payload).hexdigest()[:16]
def _cache_path(ticker: str, start: str, end: str) -> Path:
    """Return the pickle file path used to cache one ticker over one date range.

    The file name is the ticker (with ``^`` rewritten to ``IDX_``) followed by
    an 8-hex-character MD5 prefix of ticker + start + end.
    """
    digest = hashlib.md5(f"{ticker}{start}{end}".encode()).hexdigest()
    filename = "{}_{}.pkl".format(ticker.replace("^", "IDX_"), digest[:8])
    return DATA_CACHE_DIR / filename
def _load_cached(ticker: str, start: str, end: str) -> Optional[pd.DataFrame]:
    """Return the cached DataFrame for (ticker, start, end), or None on a miss.

    A cache file that cannot be unpickled, or that does not contain a
    DataFrame, is treated as a miss: it is logged, removed on a best-effort
    basis, and None is returned so the caller re-downloads the data.
    """
    p = _cache_path(ticker, start, end)
    try:
        with open(p, "rb") as fh:
            # SECURITY NOTE: pickle.load can execute arbitrary code embedded in
            # the file. This is only safe while the cache directory is not
            # writable by untrusted users or processes.
            obj = pickle.load(fh)
    except FileNotFoundError:
        # Ordinary cache miss.
        return None
    except Exception as e:
        logger.warning(f"Cache read failed for {ticker}, discarding entry: {e}")
        obj = None

    if isinstance(obj, pd.DataFrame):
        return obj

    if obj is not None:
        logger.warning(
            f"Cache entry for {ticker} holds {type(obj).__name__}, not a DataFrame; discarding"
        )
    # Removal is best-effort: failing to delete a bad entry must not break the
    # download path that is about to replace it.
    try:
        p.unlink(missing_ok=True)
    except OSError as e:
        logger.warning(f"Could not remove bad cache entry {p}: {e}")
    return None
def _save_cached(ticker: str, start: str, end: str, df: pd.DataFrame):
    """Pickle ``df`` into the cache slot for (ticker, start, end).

    Write failures are logged as warnings and otherwise ignored, so a broken
    cache never interrupts a download.
    """
    target = _cache_path(ticker, start, end)
    try:
        with target.open("wb") as sink:
            pickle.dump(df, sink)
    except Exception as e:
        logger.warning(f"Cache write failed for {ticker}: {e}")
| # --------------------------------------------------------------------------- | |
| # Download | |
| # --------------------------------------------------------------------------- | |
def _extract_ticker_frame(raw: pd.DataFrame, ticker: str, batch_len: int) -> pd.DataFrame:
    """Slice one ticker's rows out of a batch result and drop rows with no Close."""
    if isinstance(raw.columns, pd.MultiIndex):
        # With group_by="ticker" the symbol is the outer column level. Newer
        # yfinance releases can return MultiIndex columns even for a
        # single-symbol request, so decide by column type, not batch size.
        frame = raw[ticker]
    elif batch_len == 1:
        # Flat columns: the whole frame belongs to the only requested symbol.
        frame = raw
    else:
        raise KeyError(f"{ticker} not found in flat-column batch result")
    return frame.dropna(subset=["Close"]).copy()


def download_ticker_batch(
    tickers: list[str],
    start: str,
    end: str,
    progress_cb: Optional[Callable[[str, float], None]] = None,
) -> dict[str, pd.DataFrame]:
    """
    Download OHLCV data for a list of tickers (+ ^VIX, ^GSPC) in batches.

    Tickers already present in the on-disk cache for this (start, end) range
    are loaded from it; the rest are fetched with yfinance in groups of
    BATCH_SIZE, each group retried up to MAX_RETRIES times with a linearly
    growing delay. Freshly downloaded frames are written back to the cache.

    Args:
        tickers: Symbols to load. Duplicates are removed, order is preserved.
        start: Start date passed through to ``yf.download``.
        end: End date passed through to ``yf.download``.
        progress_cb: Optional ``progress_cb(message, fraction_done)`` hook.
            Fractions emitted here run from 0.02 to 0.35.

    Returns:
        dict {ticker: DataFrame} holding only tickers with at least one row
        whose Close is not NaN. Tickers that fail to download are omitted.
    """
    # Imported here rather than at module top, so yfinance is only required
    # when a download is actually requested.
    import yfinance as yf

    # dict.fromkeys de-duplicates while keeping first-seen order.
    all_tickers = list(dict.fromkeys([*tickers, "^VIX", "^GSPC"]))
    results: dict[str, pd.DataFrame] = {}
    to_download: list[str] = []

    # Check cache first
    for ticker in all_tickers:
        cached = _load_cached(ticker, start, end)
        if cached is not None:
            results[ticker] = cached
        else:
            to_download.append(ticker)

    cached_count = len(results)
    n_pending = len(to_download)
    if cached_count > 0 and progress_cb:
        progress_cb(f"Loaded {cached_count} tickers from cache. Downloading {n_pending} new...", 0.02)

    n_batches = max(1, (n_pending + BATCH_SIZE - 1) // BATCH_SIZE)
    for batch_i, batch_start in enumerate(range(0, n_pending, BATCH_SIZE)):
        batch = to_download[batch_start: batch_start + BATCH_SIZE]
        if progress_cb:
            progress_cb(
                f"Downloading batch {batch_i + 1}/{n_batches} ({len(batch)} tickers)...",
                0.05 + 0.30 * (batch_start / n_pending),
            )

        for attempt in range(MAX_RETRIES):
            try:
                raw = yf.download(
                    batch,
                    start=start,
                    end=end,
                    group_by="ticker",
                    auto_adjust=True,
                    threads=True,
                    progress=False,
                )
            except Exception as e:
                logger.warning(f"Batch {batch_i+1} attempt {attempt+1} failed: {e}")
                if attempt < MAX_RETRIES - 1:
                    # Linear back-off: RETRY_DELAY, 2*RETRY_DELAY, ...
                    time.sleep(RETRY_DELAY * (attempt + 1))
                continue

            for ticker in batch:
                # One bad symbol must not discard the rest of the batch, so
                # failures are logged per ticker and the loop continues.
                try:
                    df = _extract_ticker_frame(raw, ticker, len(batch))
                    if len(df) > 0:
                        df.index = pd.to_datetime(df.index)
                        results[ticker] = df
                        _save_cached(ticker, start, end, df)
                except Exception as e:
                    logger.warning(f"No usable data extracted for {ticker}: {e}")
            break
        else:
            # Every attempt raised; the batch's tickers are absent from results.
            logger.error(
                f"Batch {batch_i+1} failed after {MAX_RETRIES} attempts; "
                f"skipping {len(batch)} tickers"
            )

    if progress_cb:
        progress_cb(f"Download complete. {len(results)} tickers loaded.", 0.35)
    gc.collect()
    return results
| # --------------------------------------------------------------------------- | |
| # Quality filter | |
| # --------------------------------------------------------------------------- | |
def filter_ticker_data(
    ticker_data: dict[str, pd.DataFrame],
    min_history_days: int = 252,
    min_dollar_volume: float = 2_000_000,
    progress_cb: Callable = None,
) -> dict[str, pd.DataFrame]:
    """Drop tickers that have too little history or trade too thinly.

    Symbols starting with ``^`` are always kept. Any other ticker must have at
    least ``min_history_days`` rows, and the latest value of its 20-row
    rolling mean of Close * Volume must reach ``min_dollar_volume``. If
    anything was dropped and ``progress_cb`` is given, it is called once with
    a summary message and the fraction 0.37.
    """
    kept: dict[str, pd.DataFrame] = {}
    rejected: list[tuple[str, str]] = []

    for symbol, frame in ticker_data.items():
        # Index series bypass both quality checks.
        if symbol.startswith("^"):
            kept[symbol] = frame
            continue

        n_rows = len(frame)
        if n_rows < min_history_days:
            rejected.append((symbol, f"only {n_rows} days"))
            continue

        try:
            dollar_volume = frame["Close"] * frame["Volume"]
            smoothed = dollar_volume.rolling(20).mean().dropna()
            illiquid = len(smoothed) == 0 or smoothed.iloc[-1] < min_dollar_volume
        except Exception:
            # Covers e.g. a missing Close/Volume column.
            rejected.append((symbol, "volume check failed"))
            continue

        if illiquid:
            rejected.append((symbol, "low dollar volume"))
        else:
            kept[symbol] = frame

    if progress_cb and rejected:
        progress_cb(f"Filtered: {len(kept)} tickers kept, {len(rejected)} skipped.", 0.37)
    return kept
| # --------------------------------------------------------------------------- | |
| # Extract VIX / SPX convenience helpers | |
| # --------------------------------------------------------------------------- | |
def extract_market_series(ticker_data: dict) -> tuple:
    """Return ``(vix_close, sp500_close)`` taken from ``ticker_data``.

    Each element is the ``Close`` column of the ``^VIX`` / ``^GSPC`` entry, or
    None when that entry is missing or is an empty frame.
    """
    def _close_or_none(symbol):
        frame = ticker_data.get(symbol)
        if frame is None or frame.empty:
            return None
        return frame["Close"]

    return _close_or_none("^VIX"), _close_or_none("^GSPC")
| # --------------------------------------------------------------------------- | |
| # Regime detection (for live routing in backtester) | |
| # --------------------------------------------------------------------------- | |
def get_current_regime(date, sp500_data: pd.Series, vix_data: pd.Series,
                       sma_period: int = 200, vix_threshold: float = 20.0) -> tuple[int, int]:
    """
    Returns (market_regime, vix_regime) as (0/1, 0/1).

    market_regime: 1=bull (S&P value above its ``sma_period`` simple moving
        average), 0=bear
    vix_regime: 1=high VIX (above ``vix_threshold``), 0=low VIX

    Both lookups use the last observation at or before ``date`` (forward
    fill). A date before the first observation, or an incomplete / NaN moving
    average window, yields 0 for that regime. If a series is None or the
    lookup raises, the fallbacks are market_regime=1 and vix_regime=0.
    """
    mkt = 1  # fallback when S&P data is missing or unusable
    if sp500_data is not None:
        try:
            pos = sp500_data.index.get_indexer([date], method="ffill")[0]
            if pos < 0:
                # Date precedes all observations.
                mkt = 0
            else:
                # Only the trailing window is needed for the SMA at `pos`;
                # averaging it directly avoids a rolling mean over the whole
                # series on every call.
                window = sp500_data.iloc[max(0, pos - sma_period + 1): pos + 1]
                # Same rule as a rolling mean: the window must be full and
                # free of NaN, otherwise there is no SMA and the regime is 0.
                complete = len(window) == sma_period and not window.isna().any()
                mkt = int(complete and bool(sp500_data.iloc[pos] > window.mean()))
        except Exception:
            mkt = 1

    vix = 0  # fallback when VIX data is missing or unusable
    if vix_data is not None:
        try:
            pos = vix_data.index.get_indexer([date], method="ffill")[0]
            vix = int(pos >= 0 and bool(vix_data.iloc[pos] > vix_threshold))
        except Exception:
            vix = 0

    return mkt, vix