Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import ast | |
| import os | |
| import json | |
| from typing import Iterator, Tuple | |
| from datasets import load_dataset | |
| from data_processor import AlphaDataProcessor | |
| import gc | |
class StreamingDataLoader:
    """
    Streams training data directly from HuggingFace Datasets without keeping
    the whole dataset on disk: parquet files are downloaded one at a time,
    processed in chunks, and deleted immediately. A small tail buffer of rows
    is carried between chunks so rolling-window operations in the processor
    stay continuous across chunk boundaries.
    """

    def __init__(self,
                 repo_id: str = "gionuibk/hyperliquid-data",
                 model_type: str = "deeplob",
                 batch_size: int = 32,
                 chunk_size: int = 500,   # Small to ensure frequent yields
                 buffer_size: int = 200,  # Rolling-context overlap carried between chunks
                 coin: str = "ETH"):      # Filter by Symbol (CRITICAL FIX)
        """
        Args:
            repo_id: HF Dataset ID.
            model_type: 'deeplob', 'trm' or 'lstm'.
            batch_size: Training batch size.
            chunk_size: Rows per processing chunk.
            buffer_size: Overlap size to maintain rolling stats continuity.
            coin: Symbol to filter (e.g. "ETH", "BTC").
        """
        self.repo_id = repo_id
        self.model_type = model_type
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.buffer_size = buffer_size
        self.coin = coin
        self.processor = AlphaDataProcessor()

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Yields batches of (X, y) tensors from the stream.
        """
        print(f"📡 Connecting to HF Dataset Stream: {self.repo_id} (Filter: {self.coin})")
        token = os.environ.get("HF_TOKEN")
        try:
            # MANUAL LOADING Mode (Bypassing datasets library due to Arrow/Parquet errors)
            from huggingface_hub import HfApi, hf_hub_download
            api = HfApi(token=token)

            # 1. List files with Retry Logic
            print("🔍 Listing files...")

            def list_files_with_retry(retries=3):
                import time
                for i in range(retries):
                    try:
                        return api.list_repo_files(repo_id=self.repo_id, repo_type="dataset")
                    except Exception as e:
                        # Re-raise on the final attempt; otherwise back off briefly.
                        if i == retries - 1:
                            raise
                        print(f"⚠️ List files failed (Attempt {i+1}/{retries}). Retrying in 2s... Error: {e}")
                        time.sleep(2)

            files = list_files_with_retry()

            if self.model_type == "lstm":
                # Use Bar Data for LSTM (Support both v1 'data/bar/' and v2 'data/candles/')
                target_files = [
                    f for f in files
                    if (f.startswith("data/bar/") or f.startswith("data/candles/"))
                    and f.endswith(".parquet")
                    and self.coin in f  # STRICT FILTER (NOTE(review): substring match — "ETH" also matches e.g. "WETH")
                ]
                print(f"📂 Found {len(target_files)} Bar/Candle files for LSTM (Symbol: {self.coin}).")
            else:
                # Use L2 Snapshots for DeepLOB/TRM (Support both v1 'order_book_snapshot' and v2 'l2book')
                target_files = [
                    f for f in files
                    if ("order_book_snapshot" in f or "l2book" in f)
                    and f.endswith(".parquet")
                    and self.coin in f  # STRICT FILTER
                ]
                print(f"📂 Found {len(target_files)} Snapshot/L2Book files for {self.model_type} (Symbol: {self.coin}).")

            if not target_files:
                raise RuntimeError(f"No valid training files found for {self.model_type} in {self.repo_id}")

            # Buffer for rolling operations
            buffer_df = pd.DataFrame()
            chunk_rows = []
            total_loaded_rows = 0

            for file_path in target_files:
                # FIX: reset per file so the cleanup in `finally` can never act
                # on a stale path from a previous iteration (the original used
                # `'local_path' in locals()`, which stays True forever).
                local_path = None
                try:
                    print(f"⬇️ Downloading {file_path}...")
                    # Download to temp dir to avoid cache filling
                    temp_dir = "./temp_data"
                    os.makedirs(temp_dir, exist_ok=True)
                    local_path = hf_hub_download(
                        repo_id=self.repo_id,
                        filename=file_path,
                        repo_type="dataset",
                        token=token,
                        local_dir=temp_dir,
                        local_dir_use_symlinks=False,
                        force_download=True  # Ensure we have a fresh copy to delete later
                    )

                    print(f"📖 Reading {file_path}...")
                    # Read parquet directly using pandas (robust)
                    try:
                        df = pd.read_parquet(local_path)
                    except Exception as e:
                        # FIX: was `except BaseException`, which also swallowed
                        # KeyboardInterrupt/SystemExit. OSError is an Exception,
                        # so corrupt/unreadable files are still skipped.
                        print(f"⚠️ Parquet Read Failed for {file_path}: {e}")
                        continue

                    rows_in_file = len(df)
                    print(f"✅ Loaded {rows_in_file} rows from {file_path}")
                    # FIX: the original incremented this twice per file,
                    # double-counting every loaded row.
                    total_loaded_rows += rows_in_file

                    # Iterate rows in the dataframe
                    for _, row in df.iterrows():
                        # Parse L2 columns (Support both nested lists and flat columns)
                        if 'bids' in row and isinstance(row['bids'], str):
                            try:
                                row['bids'] = ast.literal_eval(row['bids'])
                            except (ValueError, SyntaxError):
                                pass  # Leave unparseable strings as-is
                        if 'asks' in row and isinstance(row['asks'], str):
                            try:
                                row['asks'] = ast.literal_eval(row['asks'])
                            except (ValueError, SyntaxError):
                                pass

                        # Handle Flat Format (bid_px_1, bid_sz_1, ...)
                        if 'bids' not in row and 'bid_px_1' in row:
                            bids = []
                            asks = []
                            for level in range(1, 21):  # Support up to 20 levels
                                if f'bid_px_{level}' in row:
                                    bids.append([row[f'bid_px_{level}'], row[f'bid_sz_{level}']])
                                if f'ask_px_{level}' in row:
                                    asks.append([row[f'ask_px_{level}'], row[f'ask_sz_{level}']])
                            row['bids'] = bids
                            row['asks'] = asks

                        # Pandas iterrows returns (index, Series); store as dict for processing
                        chunk_rows.append(row.to_dict())

                        if len(chunk_rows) >= self.chunk_size:
                            # Process and yield chunk
                            yield from self._process_chunk(chunk_rows, buffer_df)
                            # Update Buffer from new chunk
                            new_df = pd.DataFrame(chunk_rows)
                            buffer_df = new_df.tail(self.buffer_size)
                            chunk_rows = []
                            gc.collect()

                except Exception as e:
                    print(f"⚠️ Failed to process file {file_path}: {e}")
                finally:
                    # CRITICAL: Clean up file immediately to save disk space
                    if local_path and os.path.exists(local_path):
                        try:
                            # Verify it's a file before removing (safety)
                            if os.path.isfile(local_path):
                                os.remove(local_path)
                        except OSError:
                            pass  # Best-effort cleanup; never mask the real error

            # Process remaining rows after all files
            if chunk_rows:
                print(f"🧹 Processing final residual chunk ({len(chunk_rows)} rows)...")
                yield from self._process_chunk(chunk_rows, buffer_df)

        except Exception as e:
            print(f"⚠️ Manual Loading Error: {e}")
            import traceback
            traceback.print_exc()

    def _process_chunk(self, chunk_rows, buffer_df):
        """
        Process one chunk of row-dicts (prefixed with the rolling buffer for
        context) and yield shuffled (X, y) batches. Raises ValueError for an
        unknown self.model_type.
        """
        new_df = pd.DataFrame(chunk_rows)
        # Merge with buffer (previous context)
        if not buffer_df.empty:
            combined_df = pd.concat([buffer_df, new_df])
        else:
            combined_df = new_df

        # Process via the model-specific tensor builder
        if self.model_type == "deeplob":
            X, y = self.processor.get_deeplob_tensors_from_df(combined_df)
        elif self.model_type == "trm":
            X, y = self.processor.get_trm_tensors_from_df(combined_df)
        elif self.model_type == "lstm":
            X, y = self.processor.get_lstm_tensors_from_df(combined_df)
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Yield batches (shuffled; the final partial batch is dropped)
        if len(X) > 0:
            dataset_size = len(X)
            indices = torch.randperm(dataset_size)
            X = X[indices]
            y = y[indices]
            for k in range(0, dataset_size, self.batch_size):
                batch_X = X[k:k + self.batch_size]
                batch_y = y[k:k + self.batch_size]
                if len(batch_X) == self.batch_size:
                    yield batch_X, batch_y

    def get_sample_batch(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the first (X, y) batch from the stream, e.g. for shape checks."""
        for batch_X, batch_y in self:
            return batch_X, batch_y
        raise RuntimeError("Stream empty or failed")