import pandas as pd
import numpy as np
import torch
import ast
import os
import json
from typing import Iterator, Tuple
from datasets import load_dataset
from data_processor import AlphaDataProcessor
import gc


class StreamingDataLoader:
    """
    Streams training data directly from HuggingFace Datasets without downloading
    the whole repository. Files are fetched one at a time, processed in chunks,
    and deleted immediately; a small tail buffer is carried between chunks so
    rolling-window statistics in the processor stay continuous.
    """

    def __init__(self,
                 repo_id: str = "gionuibk/hyperliquid-data",
                 model_type: str = "deeplob",
                 batch_size: int = 32,
                 chunk_size: int = 500,    # Reduced to ensure frequent yields
                 buffer_size: int = 200,   # Reduced buffer
                 coin: str = "ETH"):       # Filter by Symbol (CRITICAL FIX)
        """
        Args:
            repo_id: HF Dataset ID
            model_type: 'deeplob', 'trm' or 'lstm'
            batch_size: Training batch size
            chunk_size: Rows per processing chunk
            buffer_size: Overlap size to maintain rolling stats continuity
            coin: Symbol to filter (e.g. "ETH", "BTC")
        """
        self.repo_id = repo_id
        self.model_type = model_type
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.buffer_size = buffer_size
        self.coin = coin
        self.processor = AlphaDataProcessor()

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Yields batches of (X, y) tensors from the stream.
        """
        print(f"📡 Connecting to HF Dataset Stream: {self.repo_id} (Filter: {self.coin})")
        token = os.environ.get("HF_TOKEN")

        try:
            # MANUAL LOADING Mode (Bypassing datasets library due to Arrow/Parquet errors)
            from huggingface_hub import HfApi, hf_hub_download
            api = HfApi(token=token)

            # 1. List files with retry logic (transient Hub errors are common)
            print("🔍 Listing files...")

            def list_files_with_retry(retries=3):
                import time
                for i in range(retries):
                    try:
                        return api.list_repo_files(repo_id=self.repo_id, repo_type="dataset")
                    except Exception as e:
                        if i == retries - 1:
                            # FIX: bare `raise` preserves the original traceback
                            # (was `raise e`, which rewrites it).
                            raise
                        print(f"⚠️ List files failed (Attempt {i+1}/{retries}). Retrying in 2s...\nError: {e}")
                        time.sleep(2)

            files = list_files_with_retry()

            if self.model_type == "lstm":
                # Use Bar Data for LSTM (Support both v1 'data/bar/' and v2 'data/candles/')
                target_files = [
                    f for f in files
                    if (f.startswith("data/bar/") or f.startswith("data/candles/"))
                    and f.endswith(".parquet")
                    and self.coin in f  # STRICT FILTER
                ]
                print(f"📂 Found {len(target_files)} Bar/Candle files for LSTM (Symbol: {self.coin}).")
            else:
                # Use L2 Snapshots for DeepLOB/TRM (Support both v1 'order_book_snapshot' and v2 'l2book')
                target_files = [
                    f for f in files
                    if ("order_book_snapshot" in f or "l2book" in f)
                    and f.endswith(".parquet")
                    and self.coin in f  # STRICT FILTER
                ]
                print(f"📂 Found {len(target_files)} Snapshot/L2Book files for {self.model_type} (Symbol: {self.coin}).")

            if not target_files:
                raise RuntimeError(f"No valid training files found for {self.model_type} in {self.repo_id}")

            # Buffer for rolling operations
            buffer_df = pd.DataFrame()
            chunk_rows = []
            total_loaded_rows = 0

            for file_path in target_files:
                try:
                    print(f"⬇️ Downloading {file_path}...")
                    # Download to temp dir to avoid cache filling
                    temp_dir = "./temp_data"
                    os.makedirs(temp_dir, exist_ok=True)
                    local_path = hf_hub_download(
                        repo_id=self.repo_id,
                        filename=file_path,
                        repo_type="dataset",
                        token=token,
                        local_dir=temp_dir,
                        local_dir_use_symlinks=False,
                        force_download=True  # Ensure we have a fresh copy to delete later
                    )

                    print(f"📖 Reading {file_path}...")
                    # Read parquet directly using pandas (robust). FIX: catch
                    # Exception instead of BaseException so Ctrl-C / SystemExit
                    # still propagate; OSError is a subclass of Exception.
                    try:
                        df = pd.read_parquet(local_path)
                    except Exception as e:
                        print(f"⚠️ Parquet Read Failed for {file_path}: {e}")
                        continue

                    rows_in_file = len(df)
                    print(f"✅ Loaded {rows_in_file} rows from {file_path}")
                    # FIX: was incremented twice, double-counting every file.
                    total_loaded_rows += rows_in_file

                    # Iterate rows in the dataframe
                    for i, row in df.iterrows():
                        # Parse L2 columns (Support both nested lists and flat columns)
                        if 'bids' in row and isinstance(row['bids'], str):
                            try:
                                row['bids'] = ast.literal_eval(row['bids'])
                            except (ValueError, SyntaxError):
                                # Malformed literal: keep the raw string.
                                pass
                        if 'asks' in row and isinstance(row['asks'], str):
                            try:
                                row['asks'] = ast.literal_eval(row['asks'])
                            except (ValueError, SyntaxError):
                                pass

                        # Handle Flat Format (bid_px_1, bid_sz_1, ...)
                        if 'bids' not in row and 'bid_px_1' in row:
                            bids = []
                            asks = []
                            for level in range(1, 21):  # Support up to 20 levels
                                if f'bid_px_{level}' in row:
                                    bids.append([row[f'bid_px_{level}'], row[f'bid_sz_{level}']])
                                if f'ask_px_{level}' in row:
                                    asks.append([row[f'ask_px_{level}'], row[f'ask_sz_{level}']])
                            row['bids'] = bids
                            row['asks'] = asks

                        # Pandas iterrows returns (index, Series); append as dict for processing
                        chunk_rows.append(row.to_dict())

                        if len(chunk_rows) >= self.chunk_size:
                            # Process and yield chunk
                            yield from self._process_chunk(chunk_rows, buffer_df)

                            # Update Buffer from new chunk (tail provides rolling context)
                            new_df = pd.DataFrame(chunk_rows)
                            buffer_df = new_df.tail(self.buffer_size)
                            chunk_rows = []
                            gc.collect()

                except Exception as e:
                    print(f"⚠️ Failed to process file {file_path}: {e}")
                finally:
                    # CRITICAL: Clean up file immediately to save disk space
                    if 'local_path' in locals() and os.path.exists(local_path):
                        try:
                            # Verify it's a file before removing (safety)
                            if os.path.isfile(local_path):
                                os.remove(local_path)
                        except OSError:
                            # Best-effort cleanup; never fail the stream over it.
                            pass

            # Process remaining rows after all files
            if len(chunk_rows) > 0:
                print(f"🧹 Processing final residual chunk ({len(chunk_rows)} rows)...")
                yield from self._process_chunk(chunk_rows, buffer_df)

        except Exception as e:
            print(f"⚠️ Manual Loading Error: {e}")
            import traceback
            traceback.print_exc()

    def _process_chunk(self, chunk_rows, buffer_df):
        """
        Convert a list of row dicts (prefixed with the rolling buffer) into
        model tensors and yield shuffled, full-size (X, y) batches.
        Trailing partial batches are dropped.
        """
        new_df = pd.DataFrame(chunk_rows)

        # Merge with buffer (previous context)
        if not buffer_df.empty:
            combined_df = pd.concat([buffer_df, new_df])
        else:
            combined_df = new_df

        # Process according to model type
        if self.model_type == "deeplob":
            X, y = self.processor.get_deeplob_tensors_from_df(combined_df)
        elif self.model_type == "trm":
            X, y = self.processor.get_trm_tensors_from_df(combined_df)
        elif self.model_type == "lstm":
            X, y = self.processor.get_lstm_tensors_from_df(combined_df)
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Yield batches
        if len(X) > 0:
            dataset_size = len(X)
            indices = torch.randperm(dataset_size)
            X = X[indices]
            y = y[indices]

            for k in range(0, dataset_size, self.batch_size):
                batch_X = X[k:k + self.batch_size]
                batch_y = y[k:k + self.batch_size]
                # Only yield complete batches (partial tail is discarded)
                if len(batch_X) == self.batch_size:
                    yield batch_X, batch_y

    def get_sample_batch(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the first batch from the stream; raise if the stream is empty."""
        for batch_X, batch_y in self:
            return batch_X, batch_y
        raise RuntimeError("Stream empty or failed")