Spaces:
Sleeping
Sleeping
| """ | |
| Data loading utilities for the ARB-MAX 15-minute trainer Space. | |
| All downloads are idempotent via huggingface_hub.hf_hub_download which | |
| caches under `cache_dir`. Nothing is uploaded or deleted from here. | |
| """ | |
| from __future__ import annotations | |
| import time | |
| from pathlib import Path | |
| from typing import Iterable, List, Optional | |
| import numpy as np | |
| import pandas as pd | |
| import polars as pl | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import HfHubHTTPError | |
# Hugging Face repository id passed as `repo_id` (with repo_type "dataset") by
# the loaders in this module.
DATASET_REPO_ID = "commanderzee/15m-crypto"
| # --------------------------------------------------------------------------- | |
| # Idempotent download helper with retry/backoff | |
| # --------------------------------------------------------------------------- | |
def _download_once(
    repo_id: str,
    path_in_repo: str,
    repo_type: str,
    hf_token: str,
    cache_dir: Path,
    max_attempts: int = 6,
) -> Path:
    """Download one file from the Hub and return its local path.

    The call is delegated to ``hf_hub_download`` with ``cache_dir`` as its
    cache directory, so the directory is created here if it does not exist.

    Retry policy:
      * ``HfHubHTTPError`` with status 429 or 5xx -> retried.
      * ``HfHubHTTPError`` with any other (or no) status -> re-raised at once.
      * Any other exception -> retried (best-effort; kept deliberately broad).

    The wait between attempts is ``min(60, 2 ** attempt)`` seconds. No wait
    is performed after the final attempt, since nothing follows it.

    Raises:
        HfHubHTTPError: for non-retryable HTTP failures.
        RuntimeError: when all ``max_attempts`` attempts failed; the last
            underlying error is attached as ``__cause__``.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    last_err: Optional[Exception] = None
    for attempt in range(max_attempts):
        try:
            local = hf_hub_download(
                repo_id=repo_id,
                filename=path_in_repo,
                repo_type=repo_type,
                token=hf_token,
                cache_dir=str(cache_dir),
            )
            return Path(local)
        except HfHubHTTPError as e:  # type: ignore[attr-defined]
            status = getattr(getattr(e, "response", None), "status_code", None)
            retryable = status is not None and (status == 429 or 500 <= status < 600)
            if not retryable:
                raise
            last_err = e
        except Exception as e:  # noqa: BLE001
            last_err = e

        # Back off only when another attempt will actually follow.
        if attempt + 1 < max_attempts:
            # Exponent is capped so a huge max_attempts cannot overflow the float.
            time.sleep(min(60.0, 2.0 ** min(attempt, 6)))

    raise RuntimeError(
        f"Failed to download {repo_id}:{path_in_repo} after {max_attempts} attempts"
    ) from last_err
| # --------------------------------------------------------------------------- | |
| # Markets index | |
| # --------------------------------------------------------------------------- | |
def load_markets_index(asset: str, hf_token: str, cache_dir: Path) -> pl.DataFrame:
    """Load ``markets_index.parquet`` and reduce it to one asset's markets.

    Filters applied (the last two only when the column exists):
      - ``slug`` starts with ``f"{asset}-updown-15m-"`` (asset lower-cased)
      - ``status == "resolved"``
      - ``book_snapshot_5_from`` is neither null nor the empty string

    A ``slug_ts`` Int64 column is added, parsed from the text after the last
    ``-`` in ``slug``. The result is sorted by ``slug_ts`` ascending.

    Raises:
        AssertionError: if ``end_date_us // 1_000_000 - slug_ts`` is not 900
            for some row. This is raised explicitly rather than via an
            ``assert`` statement so the check still runs under ``python -O``.
    """
    asset = asset.lower()
    local = _download_once(
        DATASET_REPO_ID, "markets_index.parquet", "dataset", hf_token, cache_dir
    )
    df = pl.read_parquet(str(local))

    prefix = f"{asset}-updown-15m-"
    df = df.filter(pl.col("slug").str.starts_with(prefix))
    if "status" in df.columns:
        df = df.filter(pl.col("status") == "resolved")
    if "book_snapshot_5_from" in df.columns:
        df = df.filter(
            pl.col("book_snapshot_5_from").is_not_null()
            & (pl.col("book_snapshot_5_from") != "")
        )

    df = df.with_columns(
        pl.col("slug")
        .str.split("-")
        .list.last()
        .cast(pl.Int64)
        .alias("slug_ts")
    )

    # ---- CRITICAL INVARIANT: slug_ts is WINDOW START, window length = 900s ----
    # NOTE(review): rows where either operand is null presumably compare to
    # null and are skipped by `.all()` -- confirm whether that is acceptable.
    end_s_minus_start = (pl.col("end_date_us") // 1_000_000) - pl.col("slug_ts")
    invariant_ok = df.select((end_s_minus_start == 900).all().alias("ok")).item()
    if not invariant_ok:
        raise AssertionError(
            "Schema drift: end_date_us//1e6 - slug_ts != 900 for some rows. "
            "slug_ts is supposed to be the window START (seconds, UTC); window is "
            "[slug_ts, slug_ts+900). Do NOT join [slug_ts-900, slug_ts) — this "
            "would shift everything by 15 minutes and silently train on garbage."
        )

    return df.sort("slug_ts")
| # --------------------------------------------------------------------------- | |
| # OHLCV | |
| # --------------------------------------------------------------------------- | |
# Columns that `load_ohlcv` retains from the raw OHLCV parquet, in this order.
# Names that are absent from the file are skipped rather than treated as errors.
_OHLCV_KEEP = [
    "open_time",
    "open",
    "high",
    "low",
    "close",
    "volume",
    "trades",
    "quote_volume",
    "taker_buy_base",
    "taker_buy_quote",
]
def load_ohlcv(asset: str, hf_token: str, cache_dir: Path) -> pl.DataFrame:
    """Return the asset's ``ohlcv_1s`` parquet, trimmed and time-ordered.

    The file ``ohlcv_1s/<asset>.parquet`` (asset lower-cased) is fetched from
    the dataset repo, reduced to the ``_OHLCV_KEEP`` columns that are actually
    present, and sorted ascending by ``open_time``.
    """
    asset = asset.lower()
    parquet_path = _download_once(
        DATASET_REPO_ID,
        f"ohlcv_1s/{asset}.parquet",
        "dataset",
        hf_token,
        cache_dir,
    )
    raw = pl.read_parquet(str(parquet_path))
    retained = [name for name in _OHLCV_KEEP if name in raw.columns]
    return raw.select(retained).sort("open_time")
| # --------------------------------------------------------------------------- | |
| # Orderbook (filtered to the slugs actually in the markets frame) | |
| # --------------------------------------------------------------------------- | |
# Identifying columns read alongside every order-book snapshot.
_OB_BASE_COLS = ["timestamp_us", "slug", "outcome"]
# Five book levels per side, indexed 0..4: all bid columns first, then all asks.
_OB_PX_COLS = [
    *(f"bid_price_{i}" for i in range(5)),
    *(f"ask_price_{i}" for i in range(5)),
]
_OB_SZ_COLS = [
    *(f"bid_size_{i}" for i in range(5)),
    *(f"ask_size_{i}" for i in range(5)),
]
def _orderbook_local_path(asset: str, hf_token: str, cache_dir: Path) -> Path:
    """Fetch ``book_snapshot_5/<asset>.parquet`` (asset lower-cased) and return its local path."""
    return _download_once(
        repo_id=DATASET_REPO_ID,
        path_in_repo=f"book_snapshot_5/{asset.lower()}.parquet",
        repo_type="dataset",
        hf_token=hf_token,
        cache_dir=cache_dir,
    )
def _orderbook_lazy(local_path: Path, slug_list: list) -> "pl.LazyFrame":
    """Build a lazy scan of the order-book parquet restricted to ``slug_list``.

    Only the base / price / size columns that exist in the file are selected.
    Price columns are cast to Float32 and size columns to Float64, both with
    ``strict=False``. Nothing is collected here.
    """
    scan = pl.scan_parquet(str(local_path))
    on_disk = scan.collect_schema().names()
    selected = [
        name
        for name in _OB_BASE_COLS + _OB_PX_COLS + _OB_SZ_COLS
        if name in on_disk
    ]
    filtered = scan.select(selected).filter(pl.col("slug").is_in(slug_list))

    # Prices first, then sizes -- the same expression order as before.
    dtype_plan = ((_OB_PX_COLS, pl.Float32), (_OB_SZ_COLS, pl.Float64))
    cast_exprs = [
        pl.col(name).cast(dtype, strict=False).alias(name)
        for group, dtype in dtype_plan
        for name in group
        if name in selected
    ]
    if not cast_exprs:
        return filtered
    return filtered.with_columns(cast_exprs)
def iter_orderbook_batches(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    slugs: Iterable[str],
    batch_size: int = 500,
):
    """DEPRECATED: yield ``(frame, slug_batch)`` for consecutive slug batches.

    Each batch of at most ``batch_size`` slugs is run through a
    scan -> filter -> collect pass over the order-book parquet. Non-empty
    frames are sorted by (slug, outcome, timestamp_us); empty frames are
    yielded unsorted.

    The original notes state that each pass reads the full 37 GB parquet,
    because an ``is_in`` filter gets no row-group pushdown. This function is
    kept only for backwards-compatible callers; use
    ``iter_orderbook_slug_pairs`` instead.
    """
    parquet_path = _orderbook_local_path(asset.lower(), hf_token, cache_dir)
    all_slugs = list(slugs)
    chunks = (
        all_slugs[lo : lo + batch_size]
        for lo in range(0, len(all_slugs), batch_size)
    )
    for chunk in chunks:
        frame = _orderbook_lazy(parquet_path, chunk).collect()
        if frame.height > 0:
            frame = frame.sort(["slug", "outcome", "timestamp_us"])
        yield frame, chunk
def _arrow_rg_to_polars(tbl) -> "pl.DataFrame":
    """Turn an arrow Table into a polars DataFrame with numeric book columns.

    Price columns that are present are cast to Float32 and size columns to
    Float64, both with ``strict=False``. The original notes say sizes are
    stored as strings.
    """
    frame = pl.from_arrow(tbl)
    present = frame.columns
    price_casts = [
        pl.col(name).cast(pl.Float32, strict=False).alias(name)
        for name in _OB_PX_COLS
        if name in present
    ]
    size_casts = [
        pl.col(name).cast(pl.Float64, strict=False).alias(name)
        for name in _OB_SZ_COLS
        if name in present
    ]
    conversions = price_casts + size_casts
    return frame.with_columns(conversions) if conversions else frame
def iter_orderbook_slug_pairs(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    wanted_slugs: Iterable[str],
):
    """Stream ``(slug, ob_up, ob_dn)`` tuples straight from parquet row groups.

    Row groups are visited in file order. The slug and outcome of a row group
    are taken from its FIRST row only, so this relies on the layout described
    by the original notes: the seeder wrote each (slug, outcome) pair with a
    single ``ParquetWriter.write_table()`` call, giving exactly one pair per
    row group, with a slug's row groups adjacent. A slug whose row groups are
    not contiguous would be yielded more than once.

    Only slugs in ``wanted_slugs`` are yielded. Each side is a polars
    DataFrame sorted by ``timestamp_us``, or an empty ``pl.DataFrame()`` when
    no "Up" / "Down" row group was seen for that slug. Row groups with any
    other outcome value are ignored.

    Differences from the previous implementation (outputs are unchanged):
      * an empty ``wanted_slugs`` returns before the parquet is downloaded;
      * row groups of unwanted slugs are no longer buffered;
      * the parquet file handle is closed when the generator finishes or is
        closed early.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    wanted = set(wanted_slugs)
    if not wanted:
        # Nothing can be yielded, so skip the (potentially huge) download.
        return

    local = _orderbook_local_path(asset.lower(), hf_token, cache_dir)

    def _side_frame(tables: list) -> "pl.DataFrame":
        # Merge one outcome's buffered row groups into a time-sorted frame.
        if not tables:
            return pl.DataFrame()
        merged = tables[0] if len(tables) == 1 else pa.concat_tables(tables)
        return _arrow_rg_to_polars(merged).sort("timestamp_us")

    pf = pq.ParquetFile(str(local))
    try:
        avail_cols = pf.schema.names
        cols = [
            c for c in _OB_BASE_COLS + _OB_PX_COLS + _OB_SZ_COLS if c in avail_cols
        ]

        current_slug: Optional[str] = None
        ob_up_tbls: list = []
        ob_dn_tbls: list = []

        for rg_idx in range(pf.num_row_groups):
            rg_tbl = pf.read_row_group(rg_idx, columns=cols)
            if rg_tbl.num_rows == 0:
                continue
            slug_val = rg_tbl.column("slug")[0].as_py()
            outcome_val = rg_tbl.column("outcome")[0].as_py()

            if current_slug is None:
                current_slug = slug_val
            elif slug_val != current_slug:
                # Slug boundary: flush the finished slug, then start afresh.
                if current_slug in wanted:
                    yield current_slug, _side_frame(ob_up_tbls), _side_frame(ob_dn_tbls)
                ob_up_tbls = []
                ob_dn_tbls = []
                current_slug = slug_val

            if slug_val not in wanted:
                # Never emitted, so do not hold its tables in memory.
                continue
            if outcome_val == "Up":
                ob_up_tbls.append(rg_tbl)
            elif outcome_val == "Down":
                ob_dn_tbls.append(rg_tbl)

        if current_slug is not None and current_slug in wanted:
            yield current_slug, _side_frame(ob_up_tbls), _side_frame(ob_dn_tbls)
    finally:
        # Release the handle even if the consumer abandons the generator.
        pf.close()
def load_orderbook_filtered(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    slugs: Iterable[str],
) -> pl.DataFrame:
    """Collect every order-book row for ``slugs`` into a single DataFrame.

    The result is sorted by (slug, outcome, timestamp_us) when non-empty.

    WARNING: the whole filtered result is held in memory. The original notes
    warn that for large assets (BTC, ~37 GB parquet) this will OOM a 32 GB
    Space. Prefer the streaming ``iter_orderbook_slug_pairs`` for production
    use; this wrapper is a convenience for small assets and partial runs.
    """
    parquet_path = _orderbook_local_path(asset, hf_token, cache_dir)
    frame = _orderbook_lazy(parquet_path, list(slugs)).collect()
    if frame.height == 0:
        return frame
    return frame.sort(["slug", "outcome", "timestamp_us"])
| # --------------------------------------------------------------------------- | |
| # Window frame builder: 900 rows per window | |
| # --------------------------------------------------------------------------- | |
_OB_RENAME_MAP = {
    # Source levels are 0-based (…_0..…_4); the window frame uses 1-based
    # names (…_1..…_5). Insertion order defines the output column order.
    **{f"bid_price_{i}": f"bid_px_{i+1}" for i in range(5)},
    **{f"bid_size_{i}": f"bid_sz_{i+1}" for i in range(5)},
    **{f"ask_price_{i}": f"ask_px_{i+1}" for i in range(5)},
    **{f"ask_size_{i}": f"ask_sz_{i+1}" for i in range(5)},
}


def _forward_fill_side_to_900(
    side_df: pl.DataFrame, slug_ts: int, prefix: str
) -> pd.DataFrame:
    """Forward-fill one outcome's snapshots onto a 900-row per-second grid.

    For each tick t in 0..899 the row holds the latest snapshot with
    ``timestamp_us <= (slug_ts + t + 1) * 1_000_000 - 1``, i.e. the book as
    of the end of that second. Ticks that precede every snapshot reuse the
    earliest snapshot. If ``side_df`` is empty, all value columns are NaN.

    Returns a DataFrame with ``tick`` followed by ``f"{prefix}_{name}"`` for
    every renamed level column, in ``_OB_RENAME_MAP`` order.

    Robustness relative to the previous implementation:
      * snapshots are sorted by timestamp here (stable sort), because
        ``searchsorted`` gives meaningless indices on unsorted input;
      * level columns missing from ``side_df`` become NaN instead of raising
        ``KeyError``, matching the loaders, which tolerate missing columns.
    """
    snap_cols = list(_OB_RENAME_MAP.values())
    out_cols = [f"{prefix}_{c}" for c in snap_cols]
    ticks = np.arange(900, dtype=np.int64)

    if side_df.is_empty():
        grid = pd.DataFrame(
            np.full((len(ticks), len(out_cols)), np.nan), columns=out_cols
        )
        grid.insert(0, "tick", ticks)
        return grid

    # Normalize column names to px_1..5 / sz_1..5.
    pdf = side_df.to_pandas().rename(columns=_OB_RENAME_MAP)

    ts = pdf["timestamp_us"].to_numpy()
    # A stable sort keeps the original order among equal timestamps, so input
    # that is already sorted produces exactly the same result as before.
    order = np.argsort(ts, kind="stable")
    ts = ts[order]
    values = pdf.reindex(columns=snap_cols).to_numpy()[order]

    # End-of-second boundary in microseconds for every tick.
    boundaries = ((slug_ts + ticks + 1) * 1_000_000) - 1
    # searchsorted(..., "right") counts snapshots with ts <= boundary, so the
    # latest valid snapshot sits at count - 1.
    idx = np.searchsorted(ts, boundaries, side="right") - 1
    # Nothing at or before the boundary -> fall back to the earliest snapshot.
    idx = np.maximum(idx, 0)

    grid = pd.DataFrame(values[idx], columns=out_cols)
    grid.insert(0, "tick", ticks)
    return grid
def build_window_frame(
    slug: str,
    slug_ts: int,
    ob_up: pl.DataFrame,
    ob_dn: pl.DataFrame,
    ohlcv: pl.DataFrame,
) -> pd.DataFrame:
    """Assemble the per-second frame for one 15-minute window.

    Output columns, in order:
        tick (0..899)
        open, high, low, close, volume, trades   (NaN where no kline matched)
        pm_up_{bid,ask}_{px,sz}_{1..5}
        pm_dn_{bid,ask}_{px,sz}_{1..5}
        slug, slug_ts

    Klines are matched on ``open_time == (slug_ts + tick) * 1000``. Each
    order-book side is forward-filled by ``_forward_fill_side_to_900``.
    """
    kline_fields = ["open", "high", "low", "close", "volume", "trades"]

    # Klines whose open_time lies in [start_ms, stop_ms).
    start_ms = slug_ts * 1000
    stop_ms = (slug_ts + 900) * 1000
    in_window = (pl.col("open_time") >= start_ms) & (pl.col("open_time") < stop_ms)
    klines = ohlcv.filter(in_window).to_pandas()

    frame = pd.DataFrame({"tick": np.arange(900, dtype=np.int64)})
    frame["open_time"] = (slug_ts + frame["tick"].to_numpy()) * 1000

    if not klines.empty:
        usable = ["open_time", *(f for f in kline_fields if f in klines.columns)]
        frame = frame.merge(klines[usable], on="open_time", how="left")
    # Any kline field that did not arrive via the merge becomes an all-NaN column.
    for field in kline_fields:
        if field not in frame.columns:
            frame[field] = np.nan

    # Order book: one forward-filled grid per outcome, joined on tick.
    for side_frame, label in ((ob_up, "pm_up"), (ob_dn, "pm_dn")):
        side_grid = _forward_fill_side_to_900(side_frame, slug_ts, label)
        frame = frame.merge(side_grid, on="tick", how="left")

    frame = frame.drop(columns=["open_time"])
    return frame.assign(slug=slug, slug_ts=slug_ts)
| # --------------------------------------------------------------------------- | |
| # Label (spot outcome, not arb pnl) | |
| # --------------------------------------------------------------------------- | |
def get_window_label(slug_ts: int, ohlcv: pl.DataFrame) -> Optional[int]:
    """Return 1 if the window's last close is above its first open, else 0.

    The first open is read from the kline with ``open_time == slug_ts * 1000``
    and the last close from the kline with
    ``open_time == (slug_ts + 899) * 1000``. If several rows share an
    ``open_time``, the first one is used.

    Returns None when either kline is missing, or when its price is null or
    NaN. Previously a null price raised ``TypeError`` and a NaN price
    silently produced label 0.
    """
    open_ms = slug_ts * 1000
    close_ms = (slug_ts + 899) * 1000
    first = ohlcv.filter(pl.col("open_time") == open_ms).select("open").to_series()
    last = ohlcv.filter(pl.col("open_time") == close_ms).select("close").to_series()
    if len(first) == 0 or len(last) == 0:
        return None

    open_raw = first[0]
    close_raw = last[0]
    if open_raw is None or close_raw is None:
        # A kline row exists but carries no price; treat it as missing.
        return None

    o = float(open_raw)
    c = float(close_raw)
    if np.isnan(o) or np.isnan(c):
        # Comparisons against NaN are always False and would be labelled 0.
        return None
    return int(c > o)
# Names exported by `from <module> import *`. The two order-book iterators are
# listed because they are public functions of this module and
# `iter_orderbook_slug_pairs` is the one the docstrings recommend.
# `_download_once` is kept for backwards compatibility.
__all__ = [
    "DATASET_REPO_ID",
    "_download_once",
    "load_markets_index",
    "load_ohlcv",
    "iter_orderbook_batches",
    "iter_orderbook_slug_pairs",
    "load_orderbook_filtered",
    "build_window_frame",
    "get_window_label",
]