"""
Data loading utilities for the ARB-MAX 15-minute trainer Space.

All downloads are idempotent via huggingface_hub.hf_hub_download which caches
under `cache_dir`. Nothing is uploaded or deleted from here.
"""

from __future__ import annotations

import time
import warnings
from pathlib import Path
from typing import Iterable, List, Optional

import numpy as np
import pandas as pd
import polars as pl
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError

DATASET_REPO_ID = "commanderzee/15m-crypto"


# ---------------------------------------------------------------------------
# Idempotent download helper with retry/backoff
# ---------------------------------------------------------------------------
def _download_once(
    repo_id: str,
    path_in_repo: str,
    repo_type: str,
    hf_token: str,
    cache_dir: Path,
    max_attempts: int = 6,
) -> Path:
    """Download a file (idempotent). Returns local path.

    Uses hf_hub_download's internal cache — subsequent calls are near-free.

    Retry policy (exponential backoff 1s, 2s, 4s, ... capped at 60s, applied
    only *between* attempts):
      - HfHubHTTPError with status 429 or 5xx  -> retried.
      - any other HfHubHTTPError (or one without a response status)
                                                -> re-raised immediately.
      - any other Exception (e.g. connection errors) -> retried.

    Raises:
        HfHubHTTPError: for non-retryable HTTP errors.
        RuntimeError: after `max_attempts` failed attempts; the last
            underlying error is chained as ``__cause__``.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    last_err: Optional[Exception] = None
    for attempt in range(max_attempts):
        try:
            local = hf_hub_download(
                repo_id=repo_id,
                filename=path_in_repo,
                repo_type=repo_type,
                token=hf_token,
                cache_dir=str(cache_dir),
            )
            return Path(local)
        except HfHubHTTPError as e:  # type: ignore[attr-defined]
            status = getattr(getattr(e, "response", None), "status_code", None)
            retryable = status is not None and (status == 429 or 500 <= status < 600)
            if not retryable:
                raise
            last_err = e
        except Exception as e:  # noqa: BLE001
            last_err = e
        # Sleep only if another attempt follows; sleeping after the final
        # failure would just delay the RuntimeError below.
        if attempt + 1 < max_attempts:
            time.sleep(min(60.0, 2.0 ** attempt))

    raise RuntimeError(
        f"Failed to download {repo_id}:{path_in_repo} after {max_attempts} attempts"
    ) from last_err


# ---------------------------------------------------------------------------
# Markets index
# ---------------------------------------------------------------------------
def load_markets_index(asset: str, hf_token: str, cache_dir: Path) -> pl.DataFrame:
    """Load and filter the markets_index.parquet for a single asset.

    Filters:
      - slug startswith f"{asset}-updown-15m-"
      - book_snapshot_5_from is non-empty / non-null (when the column exists)
      - status == "resolved" (when the column exists)

    Also:
      - Extracts slug_ts = int(slug.rsplit('-', 1)[1]) as Int64 seconds.
      - Enforces the invariant (end_date_us // 1e6 - slug_ts) == 900 for
        every row, treating a null end_date_us as a violation.
      - Sorts by slug_ts ascending.

    Raises:
        AssertionError: if the window-length invariant fails for any row.
            This is an explicit raise, so it also fires under `python -O`.
    """
    asset = asset.lower()
    local = _download_once(
        DATASET_REPO_ID, "markets_index.parquet", "dataset", hf_token, cache_dir
    )
    df = pl.read_parquet(str(local))

    prefix = f"{asset}-updown-15m-"
    df = df.filter(pl.col("slug").str.starts_with(prefix))
    if "status" in df.columns:
        df = df.filter(pl.col("status") == "resolved")
    if "book_snapshot_5_from" in df.columns:
        df = df.filter(
            pl.col("book_snapshot_5_from").is_not_null()
            & (pl.col("book_snapshot_5_from") != "")
        )

    df = df.with_columns(
        pl.col("slug").str.split("-").list.last().cast(pl.Int64).alias("slug_ts")
    )

    # ---- CRITICAL INVARIANT: slug_ts is WINDOW START, window length = 900s ----
    window_len_s = (pl.col("end_date_us") // 1_000_000) - pl.col("slug_ts")
    # fill_null(False): a row whose window length cannot be computed counts as
    # a violation. A plain Expr.all() would skip nulls and pass silently.
    n_bad = df.select((~(window_len_s == 900).fill_null(False)).sum()).item()
    if n_bad:
        raise AssertionError(
            "Schema drift: end_date_us//1e6 - slug_ts != 900 for some rows. "
            "slug_ts is supposed to be the window START (seconds, UTC); window is "
            "[slug_ts, slug_ts+900). Do NOT join [slug_ts-900, slug_ts) — this "
            "would shift everything by 15 minutes and silently train on garbage. "
            f"({n_bad} of {df.height} rows affected)"
        )

    return df.sort("slug_ts")


# ---------------------------------------------------------------------------
# OHLCV
# ---------------------------------------------------------------------------
_OHLCV_KEEP = [
    "open_time", "open", "high", "low", "close", "volume",
    "trades", "quote_volume", "taker_buy_base", "taker_buy_quote",
]


def load_ohlcv(asset: str, hf_token: str, cache_dir: Path) -> pl.DataFrame:
    """Load Binance 1s OHLCV klines for the asset.

    Keeps only the `_OHLCV_KEEP` columns that are present and sorts by
    `open_time` ascending.
    """
    asset = asset.lower()
    local = _download_once(
        DATASET_REPO_ID, f"ohlcv_1s/{asset}.parquet", "dataset", hf_token, cache_dir
    )
    df = pl.read_parquet(str(local))
    keep = [c for c in _OHLCV_KEEP if c in df.columns]
    return df.select(keep).sort("open_time")


# ---------------------------------------------------------------------------
# Orderbook (filtered to the slugs actually in the markets frame)
# ---------------------------------------------------------------------------
_OB_BASE_COLS = ["timestamp_us", "slug", "outcome"]
_OB_PX_COLS = [f"bid_price_{i}" for i in range(5)] + [f"ask_price_{i}" for i in range(5)]
_OB_SZ_COLS = [f"bid_size_{i}" for i in range(5)] + [f"ask_size_{i}" for i in range(5)]


def _ob_cast_exprs(columns: Iterable[str]) -> List[pl.Expr]:
    """Return cast expressions for the orderbook level columns in `columns`.

    Prices -> Float32, sizes -> Float64. Uses `strict=False`, so values that
    fail to cast become null instead of raising.
    """
    present = set(columns)
    exprs = [
        pl.col(c).cast(pl.Float32, strict=False).alias(c)
        for c in _OB_PX_COLS
        if c in present
    ]
    exprs += [
        pl.col(c).cast(pl.Float64, strict=False).alias(c)
        for c in _OB_SZ_COLS
        if c in present
    ]
    return exprs


def _orderbook_local_path(asset: str, hf_token: str, cache_dir: Path) -> Path:
    """Download (or reuse the cached) orderbook parquet for `asset`."""
    return _download_once(
        DATASET_REPO_ID,
        f"book_snapshot_5/{asset.lower()}.parquet",
        "dataset",
        hf_token,
        cache_dir,
    )


def _orderbook_lazy(local_path: Path, slug_list: list) -> "pl.LazyFrame":
    """Lazy scan of the orderbook parquet restricted to `slug_list`.

    Only the base/price/size columns present in the file are selected.
    """
    lf = pl.scan_parquet(str(local_path))
    avail_cols = lf.collect_schema().names()
    cols = [c for c in _OB_BASE_COLS + _OB_PX_COLS + _OB_SZ_COLS if c in avail_cols]
    lf = lf.select(cols).filter(pl.col("slug").is_in(slug_list))
    casts = _ob_cast_exprs(cols)
    if casts:
        lf = lf.with_columns(casts)
    return lf


def iter_orderbook_batches(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    slugs: Iterable[str],
    batch_size: int = 500,
):
    """DEPRECATED: polars scan-filter-collect reads the full 37 GB parquet even
    when filtering to a small slug list (is_in doesn't do row-group pushdown).

    Kept for backwards-compat callers; use `iter_orderbook_slug_pairs` instead.
    Emits a DeprecationWarning when iteration starts.

    Yields:
        (df, batch): `df` holds the rows for the slugs in `batch`, sorted by
        (slug, outcome, timestamp_us) when non-empty.
    """
    warnings.warn(
        "iter_orderbook_batches is deprecated; use iter_orderbook_slug_pairs",
        DeprecationWarning,
        stacklevel=2,
    )
    asset = asset.lower()
    local = _orderbook_local_path(asset, hf_token, cache_dir)
    slug_list = list(slugs)
    for start in range(0, len(slug_list), batch_size):
        batch = slug_list[start : start + batch_size]
        df = _orderbook_lazy(local, batch).collect()
        if len(df) > 0:
            df = df.sort(["slug", "outcome", "timestamp_us"])
        yield df, batch


def _arrow_rg_to_polars(tbl) -> "pl.DataFrame":
    """Convert an arrow row-group Table to a polars DataFrame with the right
    dtypes: prices → Float32, sizes → Float64 (strings in storage)."""
    df = pl.from_arrow(tbl)
    casts = _ob_cast_exprs(df.columns)
    return df.with_columns(casts) if casts else df


def iter_orderbook_slug_pairs(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    wanted_slugs: Iterable[str],
):
    """Stream (slug, ob_up, ob_dn) tuples directly from parquet row groups.

    The seeder wrote each (slug, outcome) intermediate via a single
    `ParquetWriter.write_table()` call → each row group in the final parquet
    contains exactly one (slug, outcome) pair. We iterate row groups in file
    order, grouping Down+Up pairs per slug, and yield only slugs in
    `wanted_slugs`.

    To decide whether a row group matters, only its `slug` and `outcome`
    columns are read. The full level columns are fetched only for wanted
    slugs with outcome "Up" or "Down", so unwanted slugs cost little I/O and
    are never buffered.

    Peak memory: ~2 row groups (~5 MB for BTC) regardless of asset size.
    Works for the BTC 37 GB parquet on a 32 GB Space.

    Yields:
        (slug, ob_up, ob_dn): polars frames sorted by timestamp_us. A side with
        no row groups is an empty `pl.DataFrame()`.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    asset = asset.lower()
    local = _orderbook_local_path(asset, hf_token, cache_dir)
    wanted = set(wanted_slugs)
    if not wanted:
        return

    pf = pq.ParquetFile(str(local))
    meta = pf.metadata
    avail_cols = pf.schema.names
    cols = [c for c in _OB_BASE_COLS + _OB_PX_COLS + _OB_SZ_COLS if c in avail_cols]

    def _side_frame(tbls: list) -> pl.DataFrame:
        # Concatenate this side's row groups (usually exactly one) and sort.
        if not tbls:
            return pl.DataFrame()
        tbl = tbls[0] if len(tbls) == 1 else pa.concat_tables(tbls)
        return _arrow_rg_to_polars(tbl).sort("timestamp_us")

    def _emit(slug, up_tbls, dn_tbls):
        if slug not in wanted:
            return None
        return slug, _side_frame(up_tbls), _side_frame(dn_tbls)

    current_slug: Optional[str] = None
    ob_up_tbls: list = []
    ob_dn_tbls: list = []

    for rg_idx in range(pf.num_row_groups):
        if meta.row_group(rg_idx).num_rows == 0:
            continue
        # Cheap key read: one row group == one (slug, outcome) pair, so the
        # first row identifies it.
        key_tbl = pf.read_row_group(rg_idx, columns=["slug", "outcome"])
        slug_val = key_tbl.column("slug")[0].as_py()
        outcome_val = key_tbl.column("outcome")[0].as_py()

        if slug_val != current_slug:
            # Slug boundary: flush the previous slug (if any) before switching.
            if current_slug is not None:
                res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
                if res is not None:
                    yield res
            ob_up_tbls = []
            ob_dn_tbls = []
            current_slug = slug_val

        if slug_val not in wanted or outcome_val not in ("Up", "Down"):
            continue  # never buffered, so skip the full-column read

        rg_tbl = pf.read_row_group(rg_idx, columns=cols)
        if outcome_val == "Up":
            ob_up_tbls.append(rg_tbl)
        else:
            ob_dn_tbls.append(rg_tbl)

    if current_slug is not None:
        res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
        if res is not None:
            yield res


def load_orderbook_filtered(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    slugs: Iterable[str],
) -> pl.DataFrame:
    """Materialize the entire filtered orderbook into memory.

    WARNING: for large assets (BTC ~37 GB parquet), this will OOM a 32 GB
    Space. Prefer `iter_orderbook_slug_pairs` for any production use.
    Kept as a thin convenience wrapper for small asset tests / partial runs.

    Returns rows sorted by (slug, outcome, timestamp_us) when non-empty.
    """
    local = _orderbook_local_path(asset, hf_token, cache_dir)
    df = _orderbook_lazy(local, list(slugs)).collect()
    if len(df) > 0:
        df = df.sort(["slug", "outcome", "timestamp_us"])
    return df


# ---------------------------------------------------------------------------
# Window frame builder: 900 rows per window
# ---------------------------------------------------------------------------
_OB_RENAME_MAP = {
    # 0->1, 1->2, ... 4->5 ; spec wants 1..5
    **{f"bid_price_{i}": f"bid_px_{i+1}" for i in range(5)},
    **{f"bid_size_{i}": f"bid_sz_{i+1}" for i in range(5)},
    **{f"ask_price_{i}": f"ask_px_{i+1}" for i in range(5)},
    **{f"ask_size_{i}": f"ask_sz_{i+1}" for i in range(5)},
}


def _forward_fill_side_to_900(
    side_df: pl.DataFrame, slug_ts: int, prefix: str
) -> pd.DataFrame:
    """Forward-fill a single outcome's snapshots onto a 900-row per-second grid.

    Rule: for each tick t in 0..899, pick the latest snapshot with
    `timestamp_us <= (slug_ts + t + 1) * 1e6 - 1` (snapshot as of end of
    second). Any tick that precedes the first snapshot reuses the earliest
    available snapshot. If there are no snapshots at all, this side's columns
    are NaN. Level columns missing from `side_df` are NaN.

    Snapshots need not arrive sorted; they are ordered by timestamp_us here.

    Returns:
        pandas DataFrame with `tick` plus f"{prefix}_{bid|ask}_{px|sz}_{1..5}".
    """
    snap_cols = list(_OB_RENAME_MAP.values())
    out_cols = [f"{prefix}_{c}" for c in snap_cols]

    grid = pd.DataFrame({"tick": np.arange(900, dtype=np.int64)})
    for c in out_cols:
        grid[c] = np.nan
    if side_df.is_empty():
        return grid

    # normalize column names to px_1..5 / sz_1..5
    pdf = side_df.to_pandas().rename(columns=_OB_RENAME_MAP)
    ts = pdf["timestamp_us"].to_numpy()
    # reindex: absent level columns become NaN instead of raising KeyError.
    values = pdf.reindex(columns=snap_cols).to_numpy(dtype=np.float64)

    # searchsorted requires ascending timestamps; stable-sort only if needed.
    if len(ts) > 1 and np.any(ts[1:] < ts[:-1]):
        order = np.argsort(ts, kind="stable")
        ts = ts[order]
        values = values[order]

    # boundary timestamps: end-of-second in μs = (slug_ts + t + 1)*1e6 - 1
    boundaries = ((slug_ts + np.arange(900, dtype=np.int64) + 1) * 1_000_000) - 1
    # searchsorted(..., "right") - 1 = index of the last snapshot <= boundary
    idx = np.searchsorted(ts, boundaries, side="right") - 1
    # NOTE(review): clamping to the earliest snapshot back-fills every tick
    # before the first snapshot with a *future* book state (look-ahead).
    # Behavior kept as-is; confirm this is acceptable for training.
    idx = np.maximum(idx, 0)

    picked = values[idx]
    for j, c in enumerate(out_cols):
        grid[c] = picked[:, j]
    return grid


def build_window_frame(
    slug: str,
    slug_ts: int,
    ob_up: pl.DataFrame,
    ob_dn: pl.DataFrame,
    ohlcv: pl.DataFrame,
) -> pd.DataFrame:
    """Build a 900-row pandas DF for a single 15-minute window.

    Columns produced:
        tick (0..899)
        open, high, low, close, volume, trades
        pm_up_bid_px_{1..5}, pm_up_bid_sz_{1..5}, pm_up_ask_px_{1..5}, pm_up_ask_sz_{1..5}
        pm_dn_bid_px_{1..5}, pm_dn_bid_sz_{1..5}, pm_dn_ask_px_{1..5}, pm_dn_ask_sz_{1..5}
        slug, slug_ts

    OHLCV rows are matched on open_time (ms) == (slug_ts + tick) * 1000.
    Seconds without a kline are NaN. Duplicate klines for the same second are
    de-duplicated (first kept), so the frame always has exactly 900 rows.
    """
    win_ms_start = slug_ts * 1000
    win_ms_end = (slug_ts + 900) * 1000  # exclusive
    oh = ohlcv.filter(
        (pl.col("open_time") >= win_ms_start) & (pl.col("open_time") < win_ms_end)
    ).to_pandas()

    grid = pd.DataFrame({"tick": np.arange(900, dtype=np.int64)})
    grid["open_time"] = (slug_ts + grid["tick"].to_numpy()) * 1000

    oh_cols_wanted = ["open", "high", "low", "close", "volume", "trades"]
    if not oh.empty:
        oh_small = oh[["open_time"] + [c for c in oh_cols_wanted if c in oh.columns]]
        # A duplicated kline would fan out the left merge and break the
        # 900-row contract. Keep the first occurrence, which is also the row
        # get_window_label reads.
        oh_small = oh_small.drop_duplicates(subset="open_time", keep="first")
        grid = grid.merge(oh_small, on="open_time", how="left")
    for c in oh_cols_wanted:
        if c not in grid.columns:
            grid[c] = np.nan

    # Orderbook — forward-filled per side
    up_grid = _forward_fill_side_to_900(ob_up, slug_ts, "pm_up")
    dn_grid = _forward_fill_side_to_900(ob_dn, slug_ts, "pm_dn")
    grid = grid.merge(up_grid, on="tick", how="left")
    grid = grid.merge(dn_grid, on="tick", how="left")

    grid = grid.drop(columns=["open_time"])
    grid["slug"] = slug
    grid["slug_ts"] = slug_ts
    return grid


# ---------------------------------------------------------------------------
# Label (spot outcome, not arb pnl)
# ---------------------------------------------------------------------------
def get_window_label(slug_ts: int, ohlcv: pl.DataFrame) -> Optional[int]:
    """Return 1 if close_last > open_first over this 15-min window, else 0.

    `open_first` is the `open` of the kline at second slug_ts, and
    `close_last` is the `close` of the kline at second slug_ts + 899.

    Returns None if either kline is missing or its value is null/NaN.
    """
    open_ms = slug_ts * 1000
    close_ms = (slug_ts + 899) * 1000
    first = ohlcv.filter(pl.col("open_time") == open_ms).get_column("open")
    last = ohlcv.filter(pl.col("open_time") == close_ms).get_column("close")
    if len(first) == 0 or len(last) == 0:
        return None

    o_raw, c_raw = first[0], last[0]
    if o_raw is None or c_raw is None:
        return None
    o = float(o_raw)
    c = float(c_raw)
    if np.isnan(o) or np.isnan(c):
        return None
    return int(c > o)


__all__ = [
    "DATASET_REPO_ID",
    "_download_once",
    "load_markets_index",
    "load_ohlcv",
    "load_orderbook_filtered",
    "iter_orderbook_slug_pairs",
    "iter_orderbook_batches",
    "build_window_frame",
    "get_window_label",
]