Spaces:
Sleeping
Sleeping
stream orderbook row-groups via pyarrow (fix OOM, peak 5MB/slug)
Browse files- data_loader.py +96 -4
- train.py +45 -50
data_loader.py
CHANGED
|
@@ -199,10 +199,9 @@ def iter_orderbook_batches(
|
|
| 199 |
slugs: Iterable[str],
|
| 200 |
batch_size: int = 500,
|
| 201 |
):
|
| 202 |
-
"""
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
loading the full asset parquet (~37 GB for BTC) into RAM.
|
| 206 |
"""
|
| 207 |
asset = asset.lower()
|
| 208 |
local = _orderbook_local_path(asset, hf_token, cache_dir)
|
|
@@ -216,6 +215,99 @@ def iter_orderbook_batches(
|
|
| 216 |
yield df, batch
|
| 217 |
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
def load_orderbook_filtered(
|
| 220 |
asset: str,
|
| 221 |
hf_token: str,
|
|
|
|
| 199 |
slugs: Iterable[str],
|
| 200 |
batch_size: int = 500,
|
| 201 |
):
|
| 202 |
+
"""DEPRECATED: polars scan-filter-collect reads the full 37 GB parquet even
|
| 203 |
+
when filtering to a small slug list (is_in doesn't do row-group pushdown).
|
| 204 |
+
Kept for backwards-compat callers; use `iter_orderbook_slug_pairs` instead.
|
|
|
|
| 205 |
"""
|
| 206 |
asset = asset.lower()
|
| 207 |
local = _orderbook_local_path(asset, hf_token, cache_dir)
|
|
|
|
| 215 |
yield df, batch
|
| 216 |
|
| 217 |
|
| 218 |
+
def _arrow_rg_to_polars(tbl) -> "pl.DataFrame":
|
| 219 |
+
"""Convert an arrow row-group Table to a polars DataFrame with the right
|
| 220 |
+
dtypes: prices → Float32, sizes → Float64 (strings in storage)."""
|
| 221 |
+
df = pl.from_arrow(tbl)
|
| 222 |
+
casts = []
|
| 223 |
+
for c in _OB_PX_COLS:
|
| 224 |
+
if c in df.columns:
|
| 225 |
+
casts.append(pl.col(c).cast(pl.Float32, strict=False).alias(c))
|
| 226 |
+
for c in _OB_SZ_COLS:
|
| 227 |
+
if c in df.columns:
|
| 228 |
+
casts.append(pl.col(c).cast(pl.Float64, strict=False).alias(c))
|
| 229 |
+
if casts:
|
| 230 |
+
df = df.with_columns(casts)
|
| 231 |
+
return df
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def iter_orderbook_slug_pairs(
|
| 235 |
+
asset: str,
|
| 236 |
+
hf_token: str,
|
| 237 |
+
cache_dir: Path,
|
| 238 |
+
wanted_slugs: Iterable[str],
|
| 239 |
+
):
|
| 240 |
+
"""Stream (slug, ob_up, ob_dn) tuples directly from parquet row groups.
|
| 241 |
+
|
| 242 |
+
The seeder wrote each (slug, outcome) intermediate via a single
|
| 243 |
+
`ParquetWriter.write_table()` call → each row group in the final parquet
|
| 244 |
+
contains exactly one (slug, outcome) pair. We iterate row groups in file
|
| 245 |
+
order, grouping Down+Up pairs per slug, and yield only slugs in
|
| 246 |
+
`wanted_slugs`.
|
| 247 |
+
|
| 248 |
+
Peak memory: ~2 row groups (~5 MB for BTC) regardless of asset size.
|
| 249 |
+
Works for the BTC 37 GB parquet on a 32 GB Space.
|
| 250 |
+
"""
|
| 251 |
+
import pyarrow.parquet as pq
|
| 252 |
+
|
| 253 |
+
asset = asset.lower()
|
| 254 |
+
local = _orderbook_local_path(asset, hf_token, cache_dir)
|
| 255 |
+
wanted = set(wanted_slugs)
|
| 256 |
+
if not wanted:
|
| 257 |
+
return
|
| 258 |
+
|
| 259 |
+
pf = pq.ParquetFile(str(local))
|
| 260 |
+
avail_cols = pf.schema.names
|
| 261 |
+
cols = [c for c in _OB_BASE_COLS + _OB_PX_COLS + _OB_SZ_COLS if c in avail_cols]
|
| 262 |
+
|
| 263 |
+
current_slug: Optional[str] = None
|
| 264 |
+
ob_up_tbls: list = []
|
| 265 |
+
ob_dn_tbls: list = []
|
| 266 |
+
|
| 267 |
+
def _emit(slug, up_tbls, dn_tbls):
|
| 268 |
+
if slug not in wanted:
|
| 269 |
+
return None
|
| 270 |
+
if up_tbls:
|
| 271 |
+
up_tbl = up_tbls[0] if len(up_tbls) == 1 else __import__("pyarrow").concat_tables(up_tbls)
|
| 272 |
+
ob_up = _arrow_rg_to_polars(up_tbl).sort("timestamp_us")
|
| 273 |
+
else:
|
| 274 |
+
ob_up = pl.DataFrame()
|
| 275 |
+
if dn_tbls:
|
| 276 |
+
dn_tbl = dn_tbls[0] if len(dn_tbls) == 1 else __import__("pyarrow").concat_tables(dn_tbls)
|
| 277 |
+
ob_dn = _arrow_rg_to_polars(dn_tbl).sort("timestamp_us")
|
| 278 |
+
else:
|
| 279 |
+
ob_dn = pl.DataFrame()
|
| 280 |
+
return slug, ob_up, ob_dn
|
| 281 |
+
|
| 282 |
+
for rg_idx in range(pf.num_row_groups):
|
| 283 |
+
rg_tbl = pf.read_row_group(rg_idx, columns=cols)
|
| 284 |
+
if rg_tbl.num_rows == 0:
|
| 285 |
+
continue
|
| 286 |
+
slug_val = rg_tbl.column("slug")[0].as_py()
|
| 287 |
+
outcome_val = rg_tbl.column("outcome")[0].as_py()
|
| 288 |
+
|
| 289 |
+
if current_slug is None:
|
| 290 |
+
current_slug = slug_val
|
| 291 |
+
|
| 292 |
+
if slug_val != current_slug:
|
| 293 |
+
res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
|
| 294 |
+
if res is not None:
|
| 295 |
+
yield res
|
| 296 |
+
ob_up_tbls = []
|
| 297 |
+
ob_dn_tbls = []
|
| 298 |
+
current_slug = slug_val
|
| 299 |
+
|
| 300 |
+
if outcome_val == "Up":
|
| 301 |
+
ob_up_tbls.append(rg_tbl)
|
| 302 |
+
elif outcome_val == "Down":
|
| 303 |
+
ob_dn_tbls.append(rg_tbl)
|
| 304 |
+
|
| 305 |
+
if current_slug is not None:
|
| 306 |
+
res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
|
| 307 |
+
if res is not None:
|
| 308 |
+
yield res
|
| 309 |
+
|
| 310 |
+
|
| 311 |
def load_orderbook_filtered(
|
| 312 |
asset: str,
|
| 313 |
hf_token: str,
|
train.py
CHANGED
|
@@ -123,7 +123,7 @@ def _build_training_dataset(
|
|
| 123 |
import gc
|
| 124 |
import polars as pl # local import to keep module import-light
|
| 125 |
|
| 126 |
-
from data_loader import
|
| 127 |
|
| 128 |
log(f"[data] loading markets_index for {asset}")
|
| 129 |
markets = load_markets_index(asset, hf_token, cache_dir)
|
|
@@ -135,64 +135,59 @@ def _build_training_dataset(
|
|
| 135 |
slugs = markets["slug"].to_list()
|
| 136 |
slug_ts_list = markets["slug_ts"].to_list()
|
| 137 |
slug_ts_map = dict(zip(slugs, [int(t) for t in slug_ts_list]))
|
|
|
|
| 138 |
|
| 139 |
-
log(f"[data] streaming book_snapshot_5
|
| 140 |
-
f"
|
| 141 |
|
| 142 |
rows: List[Dict] = []
|
| 143 |
built = 0
|
| 144 |
skipped = 0
|
|
|
|
| 145 |
|
| 146 |
-
for
|
| 147 |
-
|
| 148 |
):
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
if
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
elif outcome_val == "Down":
|
| 157 |
-
ob_by_slug_dn[slug_val] = sub_sorted
|
| 158 |
-
|
| 159 |
-
for slug in batch_slugs:
|
| 160 |
-
slug_ts = slug_ts_map[slug]
|
| 161 |
-
try:
|
| 162 |
-
spot = get_window_label(slug_ts, ohlcv)
|
| 163 |
-
if spot is None:
|
| 164 |
-
skipped += 1
|
| 165 |
-
continue
|
| 166 |
-
ob_up = ob_by_slug_up.get(slug, pl.DataFrame())
|
| 167 |
-
ob_dn = ob_by_slug_dn.get(slug, pl.DataFrame())
|
| 168 |
-
wf = build_window_frame(slug, slug_ts, ob_up, ob_dn, ohlcv)
|
| 169 |
-
feats = _extract_all(wf, at_tick=120)
|
| 170 |
-
arb = compute_optimal_arb(wf, spot_label=spot)
|
| 171 |
-
row = {
|
| 172 |
-
"slug": slug,
|
| 173 |
-
"slug_ts": slug_ts,
|
| 174 |
-
"spot_label": spot,
|
| 175 |
-
"optimal_pnl": arb["optimal_pnl"],
|
| 176 |
-
"binary_label": arb["binary_label"],
|
| 177 |
-
}
|
| 178 |
-
for name, v in zip(ALL_FEATURES, feats):
|
| 179 |
-
row[name] = float(v)
|
| 180 |
-
rows.append(row)
|
| 181 |
-
built += 1
|
| 182 |
-
except Exception as e: # noqa: BLE001
|
| 183 |
skipped += 1
|
| 184 |
-
if skipped <= 3:
|
| 185 |
-
import traceback
|
| 186 |
-
log(f"[data] window error slug={slug}: {e!r}\n"
|
| 187 |
-
f"{traceback.format_exc()}")
|
| 188 |
-
else:
|
| 189 |
-
log(f"[data] window error slug={slug}: {e!r}")
|
| 190 |
continue
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
df = pd.DataFrame(rows)
|
| 198 |
if len(df) == 0:
|
|
|
|
| 123 |
import gc
|
| 124 |
import polars as pl # local import to keep module import-light
|
| 125 |
|
| 126 |
+
from data_loader import iter_orderbook_slug_pairs
|
| 127 |
|
| 128 |
log(f"[data] loading markets_index for {asset}")
|
| 129 |
markets = load_markets_index(asset, hf_token, cache_dir)
|
|
|
|
| 135 |
slugs = markets["slug"].to_list()
|
| 136 |
slug_ts_list = markets["slug_ts"].to_list()
|
| 137 |
slug_ts_map = dict(zip(slugs, [int(t) for t in slug_ts_list]))
|
| 138 |
+
wanted = set(slugs)
|
| 139 |
|
| 140 |
+
log(f"[data] streaming book_snapshot_5 row-groups (~{len(slugs)} slugs, "
|
| 141 |
+
f"peak ~5 MB per slug)")
|
| 142 |
|
| 143 |
rows: List[Dict] = []
|
| 144 |
built = 0
|
| 145 |
skipped = 0
|
| 146 |
+
processed = 0
|
| 147 |
|
| 148 |
+
for slug, ob_up, ob_dn in iter_orderbook_slug_pairs(
|
| 149 |
+
asset, hf_token, cache_dir, wanted
|
| 150 |
):
|
| 151 |
+
processed += 1
|
| 152 |
+
slug_ts = slug_ts_map.get(slug)
|
| 153 |
+
if slug_ts is None:
|
| 154 |
+
continue
|
| 155 |
+
try:
|
| 156 |
+
spot = get_window_label(slug_ts, ohlcv)
|
| 157 |
+
if spot is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
skipped += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
continue
|
| 160 |
+
wf = build_window_frame(slug, slug_ts, ob_up, ob_dn, ohlcv)
|
| 161 |
+
feats = _extract_all(wf, at_tick=120)
|
| 162 |
+
arb = compute_optimal_arb(wf, spot_label=spot)
|
| 163 |
+
row = {
|
| 164 |
+
"slug": slug,
|
| 165 |
+
"slug_ts": slug_ts,
|
| 166 |
+
"spot_label": spot,
|
| 167 |
+
"optimal_pnl": arb["optimal_pnl"],
|
| 168 |
+
"binary_label": arb["binary_label"],
|
| 169 |
+
}
|
| 170 |
+
for name, v in zip(ALL_FEATURES, feats):
|
| 171 |
+
row[name] = float(v)
|
| 172 |
+
rows.append(row)
|
| 173 |
+
built += 1
|
| 174 |
+
except Exception as e: # noqa: BLE001
|
| 175 |
+
skipped += 1
|
| 176 |
+
if skipped <= 3:
|
| 177 |
+
import traceback
|
| 178 |
+
log(f"[data] window error slug={slug}: {e!r}\n"
|
| 179 |
+
f"{traceback.format_exc()}")
|
| 180 |
+
else:
|
| 181 |
+
log(f"[data] window error slug={slug}: {e!r}")
|
| 182 |
+
|
| 183 |
+
# free the row-group tables before the next one
|
| 184 |
+
del ob_up, ob_dn
|
| 185 |
+
if processed % 1000 == 0:
|
| 186 |
+
gc.collect()
|
| 187 |
+
log(f"[data] processed={processed} built={built} skipped={skipped}")
|
| 188 |
+
|
| 189 |
+
gc.collect()
|
| 190 |
+
log(f"[data] done streaming. processed={processed} built={built} skipped={skipped}")
|
| 191 |
|
| 192 |
df = pd.DataFrame(rows)
|
| 193 |
if len(df) == 0:
|