from __future__ import annotations

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
import pyarrow.parquet as pq

from manifold.data.generator import SyntheticDataGenerator, PlayerSession


class MANIFOLDDataset(Dataset):
    """
    PyTorch Dataset for MANIFOLD training data.

    Supports loading from Parquet files and on-the-fly generation.
    Handles sequence padding/truncation for consistent tensor shapes.
    """

    def __init__(
        self,
        data: Optional[np.ndarray] = None,
        labels: Optional[np.ndarray] = None,
        sequence_length: int = 128,
        pad_value: float = 0.0,
    ):
        """
        Store the samples and the padding configuration.

        Args:
            data: Indexable collection of per-sample feature arrays. Each
                sample is indexed as ``(time, feature)`` by ``__getitem__``.
                ``None`` yields an empty dataset.
            labels: Per-sample labels aligned with ``data``. ``None`` yields
                an unlabeled dataset whose items carry no ``"labels"`` key.
            sequence_length: Number of time steps every returned sample is
                padded or truncated to.
            pad_value: Fill value used for padded time steps.

        Raises:
            ValueError: If ``data`` and ``labels`` are both given but differ
                in length.
        """
        # Fail early instead of surfacing an IndexError mid-epoch.
        if data is not None and labels is not None and len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )

        self.data = data
        self.labels = labels
        self.sequence_length = sequence_length
        self.pad_value = pad_value

    def __len__(self) -> int:
        """Return the number of samples, or 0 when no data was provided."""
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return one sample padded or truncated to ``sequence_length`` steps.

        Returns:
            A dict with ``"features"`` (float32), ``"mask"`` (float32, 1.0 for
            real time steps and 0.0 for padded ones) and, when the dataset
            has labels, ``"labels"`` (long).

        Raises:
            IndexError: If the dataset holds no data or ``idx`` is out of
                range for the underlying storage.
        """
        if self.data is None:
            # Consistent with __len__() == 0: any index is out of range.
            raise IndexError("MANIFOLDDataset is empty (no data was provided)")

        features = self.data[idx]
        seq_len = features.shape[0]
        target_len = self.sequence_length

        if seq_len < target_len:
            # Short sequence: append pad rows and mark them 0 in the mask.
            pad_rows = target_len - seq_len
            padding = np.full((pad_rows, features.shape[1]), self.pad_value)
            features = np.concatenate([features, padding], axis=0)
            mask = np.concatenate([np.ones(seq_len), np.zeros(pad_rows)])
        else:
            # Long (or exact) sequence: keep the first target_len steps.
            features = features[:target_len]
            mask = np.ones(target_len)

        # Keys are inserted in the order features / labels / mask.
        sample = {"features": torch.tensor(features, dtype=torch.float32)}
        if self.labels is not None:
            sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        sample["mask"] = torch.tensor(mask, dtype=torch.float32)
        return sample

    @classmethod
    def from_parquet(cls, path: Union[str, Path], **kwargs) -> "MANIFOLDDataset":
        """
        Build a dataset from a Parquet file.

        The file must provide a ``"features"`` column and a ``"label"``
        column. Extra keyword arguments are forwarded to the constructor.
        """
        table = pq.read_table(path)
        df = table.to_pandas()

        # NOTE(review): assumes every row's "features" cell converts to a
        # regular (time, feature) numeric array; nested or ragged cells may
        # produce an object array here -- confirm against the writer.
        data = np.array(df["features"].tolist())
        labels = np.array(df["label"].tolist())
        return cls(data=data, labels=labels, **kwargs)

    @classmethod
    def from_generator(
        cls,
        num_samples: int,
        cheater_ratio: float = 0.3,
        seed: Optional[int] = None,
        **kwargs
    ) -> "MANIFOLDDataset":
        """
        Build a dataset from synthetically generated sessions.

        Args:
            num_samples: Total number of sessions to generate.
            cheater_ratio: Fraction of sessions requested as cheaters; the
                cheater count is ``int(num_samples * cheater_ratio)``.
            seed: Seed passed to ``SyntheticDataGenerator``.
            **kwargs: Forwarded to the constructor.

        Raises:
            ValueError: If ``num_samples`` is negative or ``cheater_ratio``
                is outside ``[0, 1]``.
        """
        # Out-of-range inputs would otherwise request negative counts.
        if num_samples < 0:
            raise ValueError(f"num_samples must be >= 0, got {num_samples}")
        if not 0.0 <= cheater_ratio <= 1.0:
            raise ValueError(
                f"cheater_ratio must be in [0, 1], got {cheater_ratio}"
            )

        gen = SyntheticDataGenerator(seed=seed)
        num_cheaters = int(num_samples * cheater_ratio)
        num_legit = num_samples - num_cheaters
        sessions = gen.generate_batch(num_legit, num_cheaters)

        # NOTE(review): assumes all sessions share one shape; sessions of
        # differing length would not stack into a regular array -- confirm.
        data = np.array([s.to_tensor() for s in sessions])
        # Cheaters are labelled 2 and everyone else 0.
        # NOTE(review): label 1 is never produced here -- presumably reserved
        # for another class; confirm against the model's label scheme.
        labels = np.array([2 if s.is_cheater else 0 for s in sessions])
        return cls(data=data, labels=labels, **kwargs)


def create_dataloader(
    dataset: MANIFOLDDataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Samples per batch.
        shuffle: Whether the loader shuffles samples.
        num_workers: Worker count passed to the loader.
        pin_memory: ``pin_memory`` flag passed to the loader.
        drop_last: Whether to discard a final batch smaller than
            ``batch_size``. Defaults to ``True``; pass ``False`` (e.g. for
            evaluation) to keep every sample.

    No ``collate_fn`` is passed, so the loader's default collation applies.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )


def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Stack per-sample tensors into batch tensors along a new leading axis.

    ``"features"`` and ``"mask"`` are always stacked. ``"labels"`` is stacked
    when any sample carries it, and omitted for fully unlabeled batches.
    """
    collated = {"features": torch.stack([b["features"] for b in batch])}
    # A partially labeled batch still raises KeyError rather than silently
    # dropping labels.
    if any("labels" in b for b in batch):
        collated["labels"] = torch.stack([b["labels"] for b in batch])
    collated["mask"] = torch.stack([b["mask"] for b in batch])
    return collated