bnb-arb-trainer / data_loader.py
commanderzee's picture
clean import in iter_orderbook_slug_pairs
7f66696 verified
"""
Data loading utilities for the ARB-MAX 15-minute trainer Space.
All downloads are idempotent via huggingface_hub.hf_hub_download which
caches under `cache_dir`. Nothing is uploaded or deleted from here.
"""
from __future__ import annotations
import time
from pathlib import Path
from typing import Iterable, List, Optional
import numpy as np
import pandas as pd
import polars as pl
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
# Hugging Face dataset repo that every loader below downloads from
# (markets_index.parquet, ohlcv_1s/<asset>.parquet, book_snapshot_5/<asset>.parquet).
DATASET_REPO_ID: str = "commanderzee/15m-crypto"
# ---------------------------------------------------------------------------
# Idempotent download helper with retry/backoff
# ---------------------------------------------------------------------------
def _download_once(
    repo_id: str,
    path_in_repo: str,
    repo_type: str,
    hf_token: Optional[str],
    cache_dir: Path,
    max_attempts: int = 6,
) -> Path:
    """Download ``path_in_repo`` from a Hub repo and return its local path.

    Idempotent: ``hf_hub_download`` caches under ``cache_dir``, so repeated
    calls for an already-cached file are near-free.

    Retry policy (up to ``max_attempts`` attempts in total):
      - ``HfHubHTTPError`` with status 429 or 5xx: retried.
      - ``HfHubHTTPError`` with any other status, or with no status: re-raised
        immediately.
      - Any other ``Exception`` (e.g. connection errors): retried.
    Attempt ``k`` (k >= 1) is preceded by a ``min(60, 2**(k-1))`` second
    sleep. No sleep follows the final attempt, so a hard failure is reported
    without an extra backoff delay.

    Raises:
        HfHubHTTPError: on a non-retryable HTTP error.
        RuntimeError: when every attempt failed; chained to the last error.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    last_err: Optional[Exception] = None
    for attempt in range(max_attempts):
        if attempt > 0:
            # Backoff only *between* attempts: 1, 2, 4, ... seconds, capped at 60.
            time.sleep(min(60.0, 2.0 ** (attempt - 1)))
        try:
            local = hf_hub_download(
                repo_id=repo_id,
                filename=path_in_repo,
                repo_type=repo_type,
                token=hf_token,
                cache_dir=str(cache_dir),
            )
            return Path(local)
        except HfHubHTTPError as e:  # type: ignore[attr-defined]
            status = getattr(getattr(e, "response", None), "status_code", None)
            retryable = status is not None and (status == 429 or 500 <= status < 600)
            if not retryable:
                raise
            last_err = e
        except Exception as e:  # noqa: BLE001 - transient network errors are retried
            last_err = e
    raise RuntimeError(
        f"Failed to download {repo_id}:{path_in_repo} after {max_attempts} attempts"
    ) from last_err
# ---------------------------------------------------------------------------
# Markets index
# ---------------------------------------------------------------------------
def load_markets_index(asset: str, hf_token: str, cache_dir: Path) -> pl.DataFrame:
    """Load and filter the markets_index.parquet for a single asset.

    Filters:
      - slug startswith f"{asset}-updown-15m-"
      - book_snapshot_5_from is non-empty / non-null (only if the column exists)
      - status == "resolved" (only if the column exists)
    Also:
      - Extracts slug_ts = int(slug.rsplit('-', 1)[1]) as Int64 seconds.
      - Checks the invariant (end_date_us // 1e6 - slug_ts) == 900 for every
        row. Rows whose end_date_us is null are not checked, because polars
        ``all()`` skips nulls by default.
      - Sorts by slug_ts ascending.

    Raises:
        AssertionError: if the 900-second window invariant is violated. It is
            raised explicitly rather than via ``assert`` so the check still
            runs under ``python -O``.
    """
    asset = asset.lower()
    local = _download_once(
        DATASET_REPO_ID, "markets_index.parquet", "dataset", hf_token, cache_dir
    )
    df = pl.read_parquet(str(local))
    prefix = f"{asset}-updown-15m-"
    df = df.filter(pl.col("slug").str.starts_with(prefix))
    if "status" in df.columns:
        df = df.filter(pl.col("status") == "resolved")
    if "book_snapshot_5_from" in df.columns:
        df = df.filter(
            pl.col("book_snapshot_5_from").is_not_null()
            & (pl.col("book_snapshot_5_from") != "")
        )
    # Last '-'-separated slug segment is the window start in epoch seconds.
    df = df.with_columns(
        pl.col("slug")
        .str.split("-")
        .list.last()
        .cast(pl.Int64)
        .alias("slug_ts")
    )
    # ---- CRITICAL INVARIANT: slug_ts is WINDOW START, window length = 900s ----
    end_s_minus_start = (pl.col("end_date_us") // 1_000_000) - pl.col("slug_ts")
    check = df.select((end_s_minus_start == 900).all().alias("ok")).item()
    if not check:
        # Explicit raise (not `assert`): this guard must survive `python -O`.
        raise AssertionError(
            "Schema drift: end_date_us//1e6 - slug_ts != 900 for some rows. "
            "slug_ts is supposed to be the window START (seconds, UTC); window is "
            "[slug_ts, slug_ts+900). Do NOT join [slug_ts-900, slug_ts) — this "
            "would shift everything by 15 minutes and silently train on garbage."
        )
    return df.sort("slug_ts")
# ---------------------------------------------------------------------------
# OHLCV
# ---------------------------------------------------------------------------
# Kline columns retained by load_ohlcv, in output order; any that are absent
# from the parquet are skipped there rather than raising.
_OHLCV_KEEP = (
    "open_time open high low close volume trades "
    "quote_volume taker_buy_base taker_buy_quote"
).split()
def load_ohlcv(asset: str, hf_token: str, cache_dir: Path) -> pl.DataFrame:
    """Load Binance 1s OHLCV klines for ``asset``, sorted by ``open_time``.

    Only the columns listed in ``_OHLCV_KEEP`` that exist in the file are
    returned, in ``_OHLCV_KEEP`` order. The parquet is scanned lazily so
    polars' projection pushdown skips the unused columns instead of
    materialising the whole file and then dropping them.
    """
    asset = asset.lower()
    local = _download_once(
        DATASET_REPO_ID, f"ohlcv_1s/{asset}.parquet", "dataset", hf_token, cache_dir
    )
    lf = pl.scan_parquet(str(local))
    present = set(lf.collect_schema().names())
    keep = [c for c in _OHLCV_KEEP if c in present]
    return lf.select(keep).sort("open_time").collect()
# ---------------------------------------------------------------------------
# Orderbook (filtered to the slugs actually in the markets frame)
# ---------------------------------------------------------------------------
# Orderbook parquet columns: row identifiers, then bid levels 0-4 followed by
# ask levels 0-4 for prices and (separately) sizes.
_OB_BASE_COLS = ["timestamp_us", "slug", "outcome"]
_OB_PX_COLS = [f"{side}_price_{lvl}" for side in ("bid", "ask") for lvl in range(5)]
_OB_SZ_COLS = [f"{side}_size_{lvl}" for side in ("bid", "ask") for lvl in range(5)]
def _orderbook_local_path(asset: str, hf_token: str, cache_dir: Path) -> Path:
    """Return the local (cached) path of the asset's book_snapshot_5 parquet.

    Downloads via ``_download_once`` on first use; ``asset`` is lower-cased.
    """
    remote_name = f"book_snapshot_5/{asset.lower()}.parquet"
    return _download_once(DATASET_REPO_ID, remote_name, "dataset", hf_token, cache_dir)
def _orderbook_lazy(local_path: Path, slug_list: list) -> "pl.LazyFrame":
    """Build a lazy scan of the orderbook parquet restricted to ``slug_list``.

    Only the base/price/size columns that exist in the file are selected.
    Prices are cast to Float32 and sizes to Float64 (non-strict, so values
    that cannot be parsed become null).
    """
    scan = pl.scan_parquet(str(local_path))
    on_disk = set(scan.collect_schema().names())
    selected = [
        name for name in (*_OB_BASE_COLS, *_OB_PX_COLS, *_OB_SZ_COLS) if name in on_disk
    ]
    scan = scan.select(selected).filter(pl.col("slug").is_in(slug_list))

    target_dtype = {
        **dict.fromkeys(_OB_PX_COLS, pl.Float32),
        **dict.fromkeys(_OB_SZ_COLS, pl.Float64),
    }
    cast_exprs = [
        pl.col(name).cast(target_dtype[name], strict=False).alias(name)
        for name in selected
        if name in target_dtype
    ]
    return scan.with_columns(cast_exprs) if cast_exprs else scan
def iter_orderbook_batches(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    slugs: Iterable[str],
    batch_size: int = 500,
):
    """DEPRECATED: yield ``(orderbook_df, slug_batch)`` per batch of slugs.

    Polars scan-filter-collect reads the full 37 GB parquet even when
    filtering to a small slug list (is_in doesn't do row-group pushdown), and
    this repeats that scan once per batch. Kept for backwards-compat callers;
    use `iter_orderbook_slug_pairs` instead.

    Non-empty frames are sorted by (slug, outcome, timestamp_us). A batch with
    no matching rows is still yielded, as an empty frame.
    """
    parquet_path = _orderbook_local_path(asset.lower(), hf_token, cache_dir)
    all_slugs = list(slugs)
    for offset in range(0, len(all_slugs), batch_size):
        chunk = all_slugs[offset : offset + batch_size]
        frame = _orderbook_lazy(parquet_path, chunk).collect()
        if not frame.is_empty():
            frame = frame.sort(["slug", "outcome", "timestamp_us"])
        yield frame, chunk
def _arrow_rg_to_polars(tbl) -> "pl.DataFrame":
    """Convert a pyarrow row-group Table to polars with orderbook dtypes.

    Prices become Float32 and sizes Float64. The sizes are stored as strings
    in the parquet. Casts are non-strict, so unparsable values become null.
    Columns not present in ``tbl`` are left alone.
    """
    frame = pl.from_arrow(tbl)
    dtype_plan = [(name, pl.Float32) for name in _OB_PX_COLS] + [
        (name, pl.Float64) for name in _OB_SZ_COLS
    ]
    exprs = [
        pl.col(name).cast(dtype, strict=False).alias(name)
        for name, dtype in dtype_plan
        if name in frame.columns
    ]
    return frame.with_columns(exprs) if exprs else frame
def iter_orderbook_slug_pairs(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    wanted_slugs: Iterable[str],
):
    """Stream ``(slug, ob_up, ob_dn)`` tuples directly from parquet row groups.

    Row-group layout:
      - The seeder wrote each (slug, outcome) intermediate via a single
        ``ParquetWriter.write_table()`` call, so each row group is expected to
        contain exactly one (slug, outcome) pair.
      - Only the first row of each row group is inspected to identify it.
      - Row groups are visited in file order, and consecutive row groups with
        the same slug are grouped into one Up/Down pair. A slug whose row
        groups are not contiguous is yielded once per contiguous run.

    Each row group is first probed by reading only its ``slug`` and
    ``outcome`` columns. The price/size columns are read only for row groups
    whose slug is in ``wanted_slugs`` and whose outcome is "Up" or "Down".
    Every other row group would be discarded anyway, so it is never fully
    decoded.

    Yields:
        (slug, ob_up, ob_dn): polars DataFrames sorted by ``timestamp_us``
        (prices Float32, sizes Float64). A side with no row groups is an
        empty ``pl.DataFrame()``. Only slugs in ``wanted_slugs`` are yielded.

    Peak memory: roughly the row groups of one slug (~5 MB for BTC) regardless
    of asset size, so this works for the BTC 37 GB parquet on a 32 GB Space.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    asset = asset.lower()
    local = _orderbook_local_path(asset, hf_token, cache_dir)
    wanted = set(wanted_slugs)
    if not wanted:
        return
    pf = pq.ParquetFile(str(local))
    avail_cols = pf.schema.names
    cols = [c for c in _OB_BASE_COLS + _OB_PX_COLS + _OB_SZ_COLS if c in avail_cols]
    key_cols = ["slug", "outcome"]

    current_slug: Optional[str] = None
    ob_up_tbls: list = []
    ob_dn_tbls: list = []

    def _emit(slug, up_tbls, dn_tbls):
        """Assemble the yielded tuple for ``slug``, or None if it is unwanted."""
        if slug not in wanted:
            return None
        if up_tbls:
            up_tbl = up_tbls[0] if len(up_tbls) == 1 else pa.concat_tables(up_tbls)
            ob_up = _arrow_rg_to_polars(up_tbl).sort("timestamp_us")
        else:
            ob_up = pl.DataFrame()
        if dn_tbls:
            dn_tbl = dn_tbls[0] if len(dn_tbls) == 1 else pa.concat_tables(dn_tbls)
            ob_dn = _arrow_rg_to_polars(dn_tbl).sort("timestamp_us")
        else:
            ob_dn = pl.DataFrame()
        return slug, ob_up, ob_dn

    for rg_idx in range(pf.num_row_groups):
        # Cheap probe: identify the row group from its key columns only.
        key_tbl = pf.read_row_group(rg_idx, columns=key_cols)
        if key_tbl.num_rows == 0:
            continue
        slug_val = key_tbl.column("slug")[0].as_py()
        outcome_val = key_tbl.column("outcome")[0].as_py()
        if current_slug is None:
            current_slug = slug_val
        if slug_val != current_slug:
            res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
            if res is not None:
                yield res
            ob_up_tbls = []
            ob_dn_tbls = []
            current_slug = slug_val
        if slug_val not in wanted or outcome_val not in ("Up", "Down"):
            # _emit would drop this row group; skip decoding the price/size columns.
            continue
        rg_tbl = pf.read_row_group(rg_idx, columns=cols)
        if outcome_val == "Up":
            ob_up_tbls.append(rg_tbl)
        else:
            ob_dn_tbls.append(rg_tbl)
    if current_slug is not None:
        res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
        if res is not None:
            yield res
def load_orderbook_filtered(
    asset: str,
    hf_token: str,
    cache_dir: Path,
    slugs: Iterable[str],
) -> pl.DataFrame:
    """Materialize the entire filtered orderbook for ``slugs`` into memory.

    Both outcomes are returned in one frame, sorted by
    (slug, outcome, timestamp_us) when non-empty. Prices are Float32 and
    sizes Float64.

    WARNING: for large assets (BTC ~37 GB parquet), this will OOM a 32 GB
    Space. Prefer `iter_orderbook_slug_pairs` for any production use. It
    streams one slug at a time from parquet row groups, whereas the
    deprecated `iter_orderbook_batches` still scans the full file per batch.
    Kept as a thin convenience wrapper for small asset tests / partial runs.
    """
    local = _orderbook_local_path(asset, hf_token, cache_dir)
    slug_list = list(slugs)
    df = _orderbook_lazy(local, slug_list).collect()
    if not df.is_empty():
        df = df.sort(["slug", "outcome", "timestamp_us"])
    return df
# ---------------------------------------------------------------------------
# Window frame builder: 900 rows per window
# ---------------------------------------------------------------------------
# Parquet column -> window-frame column. Levels shift from 0-based (file) to
# 1-based (spec). Insertion order (bid px, bid sz, ask px, ask sz; levels
# ascending) defines the output column order of the window frames.
_OB_RENAME_MAP = {
    f"{side}_{long_name}_{lvl}": f"{side}_{short_name}_{lvl + 1}"
    for side in ("bid", "ask")
    for long_name, short_name in (("price", "px"), ("size", "sz"))
    for lvl in range(5)
}
def _forward_fill_side_to_900(
    side_df: pl.DataFrame, slug_ts: int, prefix: str
) -> pd.DataFrame:
    """Forward-fill a single outcome's snapshots onto a 900-row per-second grid.

    Rule: for each tick t in 0..899, pick the latest snapshot with
    ``timestamp_us <= (slug_ts + t + 1) * 1e6 - 1``, i.e. the snapshot as of
    the end of that second. Ticks earlier than the first snapshot reuse the
    earliest available snapshot. If ``side_df`` is empty, all book columns
    are NaN.

    NOTE(review): that fallback copies a snapshot taken *after* a tick into
    that tick (look-ahead). Confirm this is acceptable for training.

    Args:
        side_df: snapshots for one outcome, with parquet column names
            (``bid_price_0`` ...). Snapshots need not be pre-sorted; they are
            sorted by ``timestamp_us`` here if necessary. Price/size columns
            missing from ``side_df`` become NaN, matching how missing OHLCV
            columns are handled.
        slug_ts: window start in epoch seconds.
        prefix: output column prefix, e.g. ``"pm_up"``.

    Returns:
        pandas DataFrame with ``tick`` and ``{prefix}_{bid,ask}_{px,sz}_{1..5}``.
    """
    snap_cols = list(_OB_RENAME_MAP.values())
    out_cols = [f"{prefix}_{c}" for c in snap_cols]
    grid = pd.DataFrame({"tick": np.arange(900, dtype=np.int64)})
    for c in out_cols:
        grid[c] = np.nan
    if side_df.is_empty():
        return grid
    # normalize column names to px_1..5 / sz_1..5
    pdf = side_df.to_pandas().rename(columns=_OB_RENAME_MAP)
    # searchsorted below silently returns wrong indices on unsorted input.
    if not pdf["timestamp_us"].is_monotonic_increasing:
        pdf = pdf.sort_values("timestamp_us", kind="stable")
    # boundary timestamps: end-of-second in μs = (slug_ts + t + 1)*1e6 - 1
    boundaries = ((slug_ts + np.arange(900, dtype=np.int64) + 1) * 1_000_000) - 1
    ts = pdf["timestamp_us"].to_numpy()
    # idx = (number of snapshots with ts <= boundary) - 1, i.e. last valid row.
    idx = np.searchsorted(ts, boundaries, side="right") - 1
    # Ticks before the first snapshot fall back to the earliest snapshot.
    idx = np.maximum(idx, 0)
    # reindex turns columns absent from the parquet into NaN instead of KeyError.
    values = pdf.reindex(columns=snap_cols).to_numpy()
    picked = values[idx]
    for j, c in enumerate(out_cols):
        grid[c] = picked[:, j]
    return grid
def build_window_frame(
    slug: str,
    slug_ts: int,
    ob_up: pl.DataFrame,
    ob_dn: pl.DataFrame,
    ohlcv: pl.DataFrame,
) -> pd.DataFrame:
    """Build a 900-row pandas DF for a single 15-minute window.

    The window covers [slug_ts, slug_ts + 900) seconds. OHLCV rows are joined
    on ``open_time == (slug_ts + tick) * 1000``, which assumes ms timestamps.
    Seconds with no kline get NaN OHLCV values. If ``ohlcv`` holds duplicate
    rows for one second, the first occurrence is used, so the frame always
    has exactly 900 rows.

    Columns produced:
        tick (0..899)
        open, high, low, close, volume, trades
        pm_up_bid_px_{1..5}, pm_up_bid_sz_{1..5},
        pm_up_ask_px_{1..5}, pm_up_ask_sz_{1..5}
        pm_dn_bid_px_{1..5}, pm_dn_bid_sz_{1..5},
        pm_dn_ask_px_{1..5}, pm_dn_ask_sz_{1..5}
        slug, slug_ts
    OHLCV columns absent from ``ohlcv`` are added as NaN after the ones that
    are present.
    """
    # OHLCV slice for this window — join on open_time_ms == (slug_ts+t)*1000
    win_ms_start = slug_ts * 1000
    win_ms_end = (slug_ts + 900) * 1000  # exclusive
    oh = ohlcv.filter(
        (pl.col("open_time") >= win_ms_start) & (pl.col("open_time") < win_ms_end)
    ).to_pandas()
    grid = pd.DataFrame({"tick": np.arange(900, dtype=np.int64)})
    grid["open_time"] = (slug_ts + grid["tick"].to_numpy()) * 1000
    oh_cols_wanted = ["open", "high", "low", "close", "volume", "trades"]
    if not oh.empty:
        oh_small = oh[["open_time"] + [c for c in oh_cols_wanted if c in oh.columns]]
        # A left merge fans out on duplicate right keys. Keep one kline per
        # second (the first, as get_window_label does) to preserve 900 rows.
        oh_small = oh_small.drop_duplicates(subset="open_time", keep="first")
        grid = grid.merge(oh_small, on="open_time", how="left")
    for c in oh_cols_wanted:
        if c not in grid.columns:
            grid[c] = np.nan
    # Orderbook — forward-filled per side
    up_grid = _forward_fill_side_to_900(ob_up, slug_ts, "pm_up")
    dn_grid = _forward_fill_side_to_900(ob_dn, slug_ts, "pm_dn")
    grid = grid.merge(up_grid, on="tick", how="left")
    grid = grid.merge(dn_grid, on="tick", how="left")
    grid = grid.drop(columns=["open_time"])
    grid["slug"] = slug
    grid["slug_ts"] = slug_ts
    return grid
# ---------------------------------------------------------------------------
# Label (spot outcome, not arb pnl)
# ---------------------------------------------------------------------------
def get_window_label(slug_ts: int, ohlcv: pl.DataFrame) -> Optional[int]:
    """Return 1 if close_last > open_first over this 15-min window, else 0.

    ``open_first`` is the ``open`` of the kline with
    ``open_time == slug_ts * 1000``. ``close_last`` is the ``close`` of the
    kline with ``open_time == (slug_ts + 899) * 1000``. With duplicate
    klines, the first match is used. A tie (close == open) is labelled 0.

    Returns None if either the first- or last-second kline is missing, or if
    its price is null/NaN. A NaN would otherwise compare False and silently
    become label 0, and a null would crash ``float()``.
    """
    open_ms = slug_ts * 1000
    close_ms = (slug_ts + 899) * 1000
    first = ohlcv.filter(pl.col("open_time") == open_ms).select("open").to_series()
    last = ohlcv.filter(pl.col("open_time") == close_ms).select("close").to_series()
    if len(first) == 0 or len(last) == 0:
        return None
    o_raw, c_raw = first[0], last[0]
    if o_raw is None or c_raw is None:
        return None
    o = float(o_raw)
    c = float(c_raw)
    if np.isnan(o) or np.isnan(c):
        return None
    return int(c > o)
# Names exported by star-imports. iter_orderbook_slug_pairs is the
# recommended (streaming) orderbook reader, so it is exported alongside the
# in-memory load_orderbook_filtered.
__all__ = [
    "DATASET_REPO_ID",
    "_download_once",
    "load_markets_index",
    "load_ohlcv",
    "load_orderbook_filtered",
    "iter_orderbook_slug_pairs",
    "build_window_frame",
    "get_window_label",
]