commanderzee committed on
Commit
8be5dd2
·
verified ·
1 Parent(s): 3df0dd9

stream orderbook row-groups via pyarrow (fix OOM, peak 5MB/slug)

Browse files
Files changed (2) hide show
  1. data_loader.py +96 -4
  2. train.py +45 -50
data_loader.py CHANGED
@@ -199,10 +199,9 @@ def iter_orderbook_batches(
199
  slugs: Iterable[str],
200
  batch_size: int = 500,
201
  ):
202
- """Yield polars DataFrames, each containing orderbook rows for up to
203
- `batch_size` slugs. Relies on parquet row-group pushdown of the slug
204
- filter keeps peak memory at O(batch_size * per_slug_bytes) instead of
205
- loading the full asset parquet (~37 GB for BTC) into RAM.
206
  """
207
  asset = asset.lower()
208
  local = _orderbook_local_path(asset, hf_token, cache_dir)
@@ -216,6 +215,99 @@ def iter_orderbook_batches(
216
  yield df, batch
217
 
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  def load_orderbook_filtered(
220
  asset: str,
221
  hf_token: str,
 
199
  slugs: Iterable[str],
200
  batch_size: int = 500,
201
  ):
202
+ """DEPRECATED: polars scan-filter-collect reads the full 37 GB parquet even
203
+ when filtering to a small slug list (is_in doesn't do row-group pushdown).
204
+ Kept for backwards-compat callers; use `iter_orderbook_slug_pairs` instead.
 
205
  """
206
  asset = asset.lower()
207
  local = _orderbook_local_path(asset, hf_token, cache_dir)
 
215
  yield df, batch
216
 
217
 
218
+ def _arrow_rg_to_polars(tbl) -> "pl.DataFrame":
219
+ """Convert an arrow row-group Table to a polars DataFrame with the right
220
+ dtypes: prices → Float32, sizes → Float64 (strings in storage)."""
221
+ df = pl.from_arrow(tbl)
222
+ casts = []
223
+ for c in _OB_PX_COLS:
224
+ if c in df.columns:
225
+ casts.append(pl.col(c).cast(pl.Float32, strict=False).alias(c))
226
+ for c in _OB_SZ_COLS:
227
+ if c in df.columns:
228
+ casts.append(pl.col(c).cast(pl.Float64, strict=False).alias(c))
229
+ if casts:
230
+ df = df.with_columns(casts)
231
+ return df
232
+
233
+
234
+ def iter_orderbook_slug_pairs(
235
+ asset: str,
236
+ hf_token: str,
237
+ cache_dir: Path,
238
+ wanted_slugs: Iterable[str],
239
+ ):
240
+ """Stream (slug, ob_up, ob_dn) tuples directly from parquet row groups.
241
+
242
+ The seeder wrote each (slug, outcome) intermediate via a single
243
+ `ParquetWriter.write_table()` call → each row group in the final parquet
244
+ contains exactly one (slug, outcome) pair. We iterate row groups in file
245
+ order, grouping Down+Up pairs per slug, and yield only slugs in
246
+ `wanted_slugs`.
247
+
248
+ Peak memory: ~2 row groups (~5 MB for BTC) regardless of asset size.
249
+ Works for the BTC 37 GB parquet on a 32 GB Space.
250
+ """
251
+ import pyarrow.parquet as pq
252
+
253
+ asset = asset.lower()
254
+ local = _orderbook_local_path(asset, hf_token, cache_dir)
255
+ wanted = set(wanted_slugs)
256
+ if not wanted:
257
+ return
258
+
259
+ pf = pq.ParquetFile(str(local))
260
+ avail_cols = pf.schema.names
261
+ cols = [c for c in _OB_BASE_COLS + _OB_PX_COLS + _OB_SZ_COLS if c in avail_cols]
262
+
263
+ current_slug: Optional[str] = None
264
+ ob_up_tbls: list = []
265
+ ob_dn_tbls: list = []
266
+
267
+ def _emit(slug, up_tbls, dn_tbls):
268
+ if slug not in wanted:
269
+ return None
270
+ if up_tbls:
271
+ up_tbl = up_tbls[0] if len(up_tbls) == 1 else __import__("pyarrow").concat_tables(up_tbls)
272
+ ob_up = _arrow_rg_to_polars(up_tbl).sort("timestamp_us")
273
+ else:
274
+ ob_up = pl.DataFrame()
275
+ if dn_tbls:
276
+ dn_tbl = dn_tbls[0] if len(dn_tbls) == 1 else __import__("pyarrow").concat_tables(dn_tbls)
277
+ ob_dn = _arrow_rg_to_polars(dn_tbl).sort("timestamp_us")
278
+ else:
279
+ ob_dn = pl.DataFrame()
280
+ return slug, ob_up, ob_dn
281
+
282
+ for rg_idx in range(pf.num_row_groups):
283
+ rg_tbl = pf.read_row_group(rg_idx, columns=cols)
284
+ if rg_tbl.num_rows == 0:
285
+ continue
286
+ slug_val = rg_tbl.column("slug")[0].as_py()
287
+ outcome_val = rg_tbl.column("outcome")[0].as_py()
288
+
289
+ if current_slug is None:
290
+ current_slug = slug_val
291
+
292
+ if slug_val != current_slug:
293
+ res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
294
+ if res is not None:
295
+ yield res
296
+ ob_up_tbls = []
297
+ ob_dn_tbls = []
298
+ current_slug = slug_val
299
+
300
+ if outcome_val == "Up":
301
+ ob_up_tbls.append(rg_tbl)
302
+ elif outcome_val == "Down":
303
+ ob_dn_tbls.append(rg_tbl)
304
+
305
+ if current_slug is not None:
306
+ res = _emit(current_slug, ob_up_tbls, ob_dn_tbls)
307
+ if res is not None:
308
+ yield res
309
+
310
+
311
  def load_orderbook_filtered(
312
  asset: str,
313
  hf_token: str,
train.py CHANGED
@@ -123,7 +123,7 @@ def _build_training_dataset(
123
  import gc
124
  import polars as pl # local import to keep module import-light
125
 
126
- from data_loader import iter_orderbook_batches
127
 
128
  log(f"[data] loading markets_index for {asset}")
129
  markets = load_markets_index(asset, hf_token, cache_dir)
@@ -135,64 +135,59 @@ def _build_training_dataset(
135
  slugs = markets["slug"].to_list()
136
  slug_ts_list = markets["slug_ts"].to_list()
137
  slug_ts_map = dict(zip(slugs, [int(t) for t in slug_ts_list]))
 
138
 
139
- log(f"[data] streaming book_snapshot_5 in batches of {ob_batch_size} "
140
- f"(~{(len(slugs) + ob_batch_size - 1) // ob_batch_size} batches)")
141
 
142
  rows: List[Dict] = []
143
  built = 0
144
  skipped = 0
 
145
 
146
- for batch_idx, (ob_batch, batch_slugs) in enumerate(
147
- iter_orderbook_batches(asset, hf_token, cache_dir, slugs, batch_size=ob_batch_size)
148
  ):
149
- ob_by_slug_up: Dict[str, pl.DataFrame] = {}
150
- ob_by_slug_dn: Dict[str, pl.DataFrame] = {}
151
- if len(ob_batch) > 0:
152
- for (slug_val, outcome_val), sub in ob_batch.group_by(["slug", "outcome"]):
153
- sub_sorted = sub.sort("timestamp_us")
154
- if outcome_val == "Up":
155
- ob_by_slug_up[slug_val] = sub_sorted
156
- elif outcome_val == "Down":
157
- ob_by_slug_dn[slug_val] = sub_sorted
158
-
159
- for slug in batch_slugs:
160
- slug_ts = slug_ts_map[slug]
161
- try:
162
- spot = get_window_label(slug_ts, ohlcv)
163
- if spot is None:
164
- skipped += 1
165
- continue
166
- ob_up = ob_by_slug_up.get(slug, pl.DataFrame())
167
- ob_dn = ob_by_slug_dn.get(slug, pl.DataFrame())
168
- wf = build_window_frame(slug, slug_ts, ob_up, ob_dn, ohlcv)
169
- feats = _extract_all(wf, at_tick=120)
170
- arb = compute_optimal_arb(wf, spot_label=spot)
171
- row = {
172
- "slug": slug,
173
- "slug_ts": slug_ts,
174
- "spot_label": spot,
175
- "optimal_pnl": arb["optimal_pnl"],
176
- "binary_label": arb["binary_label"],
177
- }
178
- for name, v in zip(ALL_FEATURES, feats):
179
- row[name] = float(v)
180
- rows.append(row)
181
- built += 1
182
- except Exception as e: # noqa: BLE001
183
  skipped += 1
184
- if skipped <= 3:
185
- import traceback
186
- log(f"[data] window error slug={slug}: {e!r}\n"
187
- f"{traceback.format_exc()}")
188
- else:
189
- log(f"[data] window error slug={slug}: {e!r}")
190
  continue
191
-
192
- # free polars memory before advancing to next batch
193
- del ob_batch, ob_by_slug_up, ob_by_slug_dn
194
- gc.collect()
195
- log(f"[data] batch {batch_idx + 1} done; built={built} skipped={skipped}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  df = pd.DataFrame(rows)
198
  if len(df) == 0:
 
123
  import gc
124
  import polars as pl # local import to keep module import-light
125
 
126
+ from data_loader import iter_orderbook_slug_pairs
127
 
128
  log(f"[data] loading markets_index for {asset}")
129
  markets = load_markets_index(asset, hf_token, cache_dir)
 
135
  slugs = markets["slug"].to_list()
136
  slug_ts_list = markets["slug_ts"].to_list()
137
  slug_ts_map = dict(zip(slugs, [int(t) for t in slug_ts_list]))
138
+ wanted = set(slugs)
139
 
140
+ log(f"[data] streaming book_snapshot_5 row-groups (~{len(slugs)} slugs, "
141
+ f"peak ~5 MB per slug)")
142
 
143
  rows: List[Dict] = []
144
  built = 0
145
  skipped = 0
146
+ processed = 0
147
 
148
+ for slug, ob_up, ob_dn in iter_orderbook_slug_pairs(
149
+ asset, hf_token, cache_dir, wanted
150
  ):
151
+ processed += 1
152
+ slug_ts = slug_ts_map.get(slug)
153
+ if slug_ts is None:
154
+ continue
155
+ try:
156
+ spot = get_window_label(slug_ts, ohlcv)
157
+ if spot is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  skipped += 1
 
 
 
 
 
 
159
  continue
160
+ wf = build_window_frame(slug, slug_ts, ob_up, ob_dn, ohlcv)
161
+ feats = _extract_all(wf, at_tick=120)
162
+ arb = compute_optimal_arb(wf, spot_label=spot)
163
+ row = {
164
+ "slug": slug,
165
+ "slug_ts": slug_ts,
166
+ "spot_label": spot,
167
+ "optimal_pnl": arb["optimal_pnl"],
168
+ "binary_label": arb["binary_label"],
169
+ }
170
+ for name, v in zip(ALL_FEATURES, feats):
171
+ row[name] = float(v)
172
+ rows.append(row)
173
+ built += 1
174
+ except Exception as e: # noqa: BLE001
175
+ skipped += 1
176
+ if skipped <= 3:
177
+ import traceback
178
+ log(f"[data] window error slug={slug}: {e!r}\n"
179
+ f"{traceback.format_exc()}")
180
+ else:
181
+ log(f"[data] window error slug={slug}: {e!r}")
182
+
183
+ # free the row-group tables before the next one
184
+ del ob_up, ob_dn
185
+ if processed % 1000 == 0:
186
+ gc.collect()
187
+ log(f"[data] processed={processed} built={built} skipped={skipped}")
188
+
189
+ gc.collect()
190
+ log(f"[data] done streaming. processed={processed} built={built} skipped={skipped}")
191
 
192
  df = pd.DataFrame(rows)
193
  if len(df) == 0: