# NautilusTrainer / streaming_loader.py
# Uploaded by gionuibk with huggingface_hub (commit 4b94e4d, verified)
import pandas as pd
import numpy as np
import torch
import ast
import os
import json
from typing import Iterator, Tuple
from datasets import load_dataset
from data_processor import AlphaDataProcessor
import gc
class StreamingDataLoader:
    """
    Streams training data directly from a HuggingFace Dataset repo without
    downloading the whole repository up front.

    Parquet files are fetched one at a time into a temp directory, parsed
    into row chunks, processed into (X, y) tensor batches, and deleted
    immediately to bound disk usage. A small tail buffer of previous rows is
    carried between chunks so rolling-window statistics stay continuous
    across chunk boundaries.
    """

    def __init__(self,
                 repo_id: str = "gionuibk/hyperliquid-data",
                 model_type: str = "deeplob",
                 batch_size: int = 32,
                 chunk_size: int = 500,   # Reduced to ensure frequent yields
                 buffer_size: int = 200,  # Reduced buffer
                 coin: str = "ETH"):      # Filter by Symbol (CRITICAL FIX)
        """
        Args:
            repo_id: HF Dataset ID
            model_type: 'deeplob', 'trm' or 'lstm'
            batch_size: Training batch size
            chunk_size: Rows per processing chunk
            buffer_size: Overlap size to maintain rolling stats continuity
            coin: Symbol to filter (e.g. "ETH", "BTC")
        """
        self.repo_id = repo_id
        self.model_type = model_type
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.buffer_size = buffer_size
        self.coin = coin
        self.processor = AlphaDataProcessor()

    def _list_files_with_retry(self, api, retries: int = 3):
        """List repo files, retrying transient hub/network errors with a 2s backoff.

        Raises the last error (with its original traceback) after *retries* attempts.
        """
        import time
        for attempt in range(retries):
            try:
                return api.list_repo_files(repo_id=self.repo_id, repo_type="dataset")
            except Exception as e:
                if attempt == retries - 1:
                    raise  # bare raise preserves the original traceback
                print(f"⚠️ List files failed (Attempt {attempt+1}/{retries}). Retrying in 2s... Error: {e}")
                time.sleep(2)

    def _select_target_files(self, files):
        """Pick the parquet files that match this loader's model type and coin.

        NOTE(review): the coin filter is a substring test, so e.g. "BTC" would
        also match files containing "WBTC" — confirm repo naming makes this safe.
        """
        if self.model_type == "lstm":
            # Use Bar Data for LSTM (Support both v1 'data/bar/' and v2 'data/candles/')
            target_files = [
                f for f in files
                if (f.startswith("data/bar/") or f.startswith("data/candles/"))
                and f.endswith(".parquet")
                and self.coin in f  # STRICT FILTER
            ]
            print(f"📂 Found {len(target_files)} Bar/Candle files for LSTM (Symbol: {self.coin}).")
        else:
            # Use L2 Snapshots for DeepLOB/TRM (Support both v1 'order_book_snapshot' and v2 'l2book')
            target_files = [
                f for f in files
                if ("order_book_snapshot" in f or "l2book" in f)
                and f.endswith(".parquet")
                and self.coin in f  # STRICT FILTER
            ]
            print(f"📂 Found {len(target_files)} Snapshot/L2Book files for {self.model_type} (Symbol: {self.coin}).")
        return target_files

    @staticmethod
    def _normalize_row(row) -> dict:
        """Return *row* (a pandas Series) as a dict with 'bids'/'asks' as nested lists.

        Supports two on-disk layouts:
          * nested: 'bids'/'asks' columns holding stringified "[[px, sz], ...]" lists
          * flat:   bid_px_1/bid_sz_1/... columns, up to 20 levels
        """
        if 'bids' in row and isinstance(row['bids'], str):
            try:
                row['bids'] = ast.literal_eval(row['bids'])
            except (ValueError, SyntaxError):
                pass  # leave the raw string; downstream processor decides
        if 'asks' in row and isinstance(row['asks'], str):
            try:
                row['asks'] = ast.literal_eval(row['asks'])
            except (ValueError, SyntaxError):
                pass
        # Handle Flat Format (bid_px_1, bid_sz_1, ...)
        if 'bids' not in row and 'bid_px_1' in row:
            bids = []
            asks = []
            for level in range(1, 21):  # Support up to 20 levels
                if f'bid_px_{level}' in row:
                    bids.append([row[f'bid_px_{level}'], row[f'bid_sz_{level}']])
                if f'ask_px_{level}' in row:
                    asks.append([row[f'ask_px_{level}'], row[f'ask_sz_{level}']])
            row['bids'] = bids
            row['asks'] = asks
        # iterrows() yields (index, Series); downstream wants a plain dict
        return row.to_dict()

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Yields batches of (X, y) tensors from the stream.
        """
        print(f"📡 Connecting to HF Dataset Stream: {self.repo_id} (Filter: {self.coin})")
        token = os.environ.get("HF_TOKEN")
        try:
            # MANUAL LOADING Mode (Bypassing datasets library due to Arrow/Parquet errors)
            from huggingface_hub import HfApi, hf_hub_download
            api = HfApi(token=token)

            # 1. List files with Retry Logic
            print("🔍 Listing files...")
            files = self._list_files_with_retry(api)
            target_files = self._select_target_files(files)
            if not target_files:
                raise RuntimeError(f"No valid training files found for {self.model_type} in {self.repo_id}")

            # Buffer for rolling operations
            buffer_df = pd.DataFrame()
            chunk_rows = []
            total_loaded_rows = 0

            # Download to temp dir to avoid cache filling
            temp_dir = "./temp_data"
            os.makedirs(temp_dir, exist_ok=True)

            for file_path in target_files:
                # Reset per file so the finally-clause never deletes a stale
                # path left over from a previous iteration's download.
                local_path = None
                try:
                    print(f"⬇️ Downloading {file_path}...")
                    local_path = hf_hub_download(
                        repo_id=self.repo_id,
                        filename=file_path,
                        repo_type="dataset",
                        token=token,
                        local_dir=temp_dir,
                        local_dir_use_symlinks=False,
                        force_download=True  # Ensure we have a fresh copy to delete later
                    )
                    print(f"📖 Reading {file_path}...")
                    # Read parquet directly using pandas (robust). OSError is an
                    # Exception subclass, so Exception is broad enough; BaseException
                    # would also swallow KeyboardInterrupt/SystemExit.
                    try:
                        df = pd.read_parquet(local_path)
                    except Exception as e:
                        print(f"⚠️ Parquet Read Failed for {file_path}: {e}")
                        continue

                    rows_in_file = len(df)
                    print(f"✅ Loaded {rows_in_file} rows from {file_path}")
                    # FIX: was accumulated twice, double-counting every file.
                    total_loaded_rows += rows_in_file

                    for _, row in df.iterrows():
                        chunk_rows.append(self._normalize_row(row))
                        if len(chunk_rows) >= self.chunk_size:
                            # Process and yield chunk
                            yield from self._process_chunk(chunk_rows, buffer_df)
                            # Keep the tail of this chunk as rolling context
                            buffer_df = pd.DataFrame(chunk_rows).tail(self.buffer_size)
                            chunk_rows = []
                            gc.collect()
                except Exception as e:
                    print(f"⚠️ Failed to process file {file_path}: {e}")
                finally:
                    # CRITICAL: Clean up file immediately to save disk space
                    if local_path is not None and os.path.isfile(local_path):
                        try:
                            os.remove(local_path)
                        except OSError:
                            pass  # best-effort cleanup; never mask the real error

            # Process remaining rows after all files
            if chunk_rows:
                print(f"🧹 Processing final residual chunk ({len(chunk_rows)} rows)...")
                yield from self._process_chunk(chunk_rows, buffer_df)
        except Exception as e:
            print(f"⚠️ Manual Loading Error: {e}")
            import traceback
            traceback.print_exc()

    def _process_chunk(self, chunk_rows, buffer_df):
        """Process one chunk of row dicts into shuffled (X, y) batches.

        The previous chunk's tail (*buffer_df*) is prepended so the processor's
        rolling statistics have continuous context. Trailing partial batches
        (len < batch_size) are dropped.

        Raises:
            ValueError: if ``self.model_type`` is not a known model.
        """
        new_df = pd.DataFrame(chunk_rows)
        # Merge with buffer (previous context)
        combined_df = new_df if buffer_df.empty else pd.concat([buffer_df, new_df])

        # Dispatch to the processor matching the model type
        if self.model_type == "deeplob":
            X, y = self.processor.get_deeplob_tensors_from_df(combined_df)
        elif self.model_type == "trm":
            X, y = self.processor.get_trm_tensors_from_df(combined_df)
        elif self.model_type == "lstm":
            X, y = self.processor.get_lstm_tensors_from_df(combined_df)
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Shuffle and yield full batches only
        if len(X) > 0:
            dataset_size = len(X)
            indices = torch.randperm(dataset_size)
            X = X[indices]
            y = y[indices]
            for k in range(0, dataset_size, self.batch_size):
                batch_X = X[k:k + self.batch_size]
                batch_y = y[k:k + self.batch_size]
                if len(batch_X) == self.batch_size:
                    yield batch_X, batch_y

    def get_sample_batch(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the first (X, y) batch from the stream; raise if the stream is empty."""
        for batch_X, batch_y in self:
            return batch_X, batch_y
        raise RuntimeError("Stream empty or failed")