Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any, List, Union | |
| import pyarrow.parquet as pq | |
| from manifold.data.generator import SyntheticDataGenerator, PlayerSession | |
class MANIFOLDDataset(Dataset):
    """
    PyTorch Dataset for MANIFOLD training data.

    Supports loading from Parquet files and on-the-fly generation.
    Handles sequence padding/truncation for consistent tensor shapes.

    Every item is a dict with three tensors:

    * ``"features"``: float32, shape ``(sequence_length, *feature_dims)``
    * ``"labels"``: int64 label of the sample
    * ``"mask"``: float32, shape ``(sequence_length,)``; 1.0 on real
      time steps and 0.0 on padded ones
    """

    def __init__(
        self,
        data: Optional[np.ndarray] = None,
        labels: Optional[np.ndarray] = None,
        sequence_length: int = 128,
        pad_value: float = 0.0,
    ):
        """
        Args:
            data: Per-sample feature sequences, indexed by sample. Either a
                stacked array or any indexable container of array-likes
                (samples may then differ in length).
            labels: Per-sample labels, indexed like ``data``.
            sequence_length: Number of time steps every returned sample is
                padded or truncated to.
            pad_value: Fill value used for padded time steps.
        """
        self.data = data
        self.labels = labels
        self.sequence_length = sequence_length
        self.pad_value = pad_value

    def __len__(self) -> int:
        """Return the number of samples (0 when no data was supplied)."""
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return sample ``idx`` padded/truncated to ``sequence_length`` steps.

        Sequences shorter than ``sequence_length`` are right-padded with
        ``pad_value``; longer ones keep only their first
        ``sequence_length`` steps.
        """
        # np.asarray accepts both a row of a stacked array and an element of
        # a plain list of per-sample arrays, so ragged datasets work too.
        features = np.asarray(self.data[idx], dtype=np.float32)
        label = self.labels[idx]

        target_len = self.sequence_length
        valid_len = min(features.shape[0], target_len)

        # Pre-fill with the pad value and copy the real steps over it. Using
        # shape[1:] (instead of shape[1]) keeps this working for 1-D
        # sequences as well as for multi-dimensional per-step features.
        padded = np.full(
            (target_len, *features.shape[1:]),
            self.pad_value,
            dtype=np.float32,
        )
        padded[:valid_len] = features[:valid_len]

        mask = np.zeros(target_len, dtype=np.float32)
        mask[:valid_len] = 1.0

        return {
            "features": torch.tensor(padded, dtype=torch.float32),
            "labels": torch.tensor(label, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.float32),
        }

    @classmethod
    def from_parquet(cls, path: Union[str, Path], **kwargs) -> "MANIFOLDDataset":
        """
        Build a dataset from a Parquet file.

        The file must contain a ``"features"`` column and a ``"label"``
        column; remaining keyword arguments are forwarded to the
        constructor.
        """
        table = pq.read_table(path)
        df = table.to_pandas()
        # NOTE(review): assumes every row's "features" entry converts to a
        # regular nested list of equal shape -- confirm against the writer.
        data = np.array(df["features"].tolist())
        labels = np.array(df["label"].tolist())
        return cls(data=data, labels=labels, **kwargs)

    @classmethod
    def from_generator(
        cls,
        num_samples: int,
        cheater_ratio: float = 0.3,
        seed: Optional[int] = None,
        **kwargs
    ) -> "MANIFOLDDataset":
        """
        Build a dataset from freshly generated synthetic sessions.

        Args:
            num_samples: Total number of sessions to generate.
            cheater_ratio: Fraction of sessions generated as cheaters; the
                cheater count is ``int(num_samples * cheater_ratio)``.
            seed: Seed passed to ``SyntheticDataGenerator``.
            **kwargs: Forwarded to the constructor.

        Raises:
            ValueError: If ``cheater_ratio`` is outside ``[0, 1]``.
        """
        if not 0.0 <= cheater_ratio <= 1.0:
            raise ValueError(
                f"cheater_ratio must be within [0, 1], got {cheater_ratio}"
            )

        gen = SyntheticDataGenerator(seed=seed)
        num_cheaters = int(num_samples * cheater_ratio)
        num_legit = num_samples - num_cheaters
        sessions = gen.generate_batch(num_legit, num_cheaters)

        data = np.array([s.to_tensor() for s in sessions])
        # Cheaters are labelled 2 and everyone else 0.
        # NOTE(review): label id 1 is never produced here -- presumably
        # reserved for another class; confirm against the model's label map.
        labels = np.array([2 if s.is_cheater else 0 for s in sessions])
        return cls(data=data, labels=labels, **kwargs)
def create_dataloader(
    dataset: Dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch.
        num_workers: Number of worker processes used for loading.
        pin_memory: Whether batches are placed in pinned memory.
        drop_last: Whether to discard a final batch smaller than
            ``batch_size``. Defaults to ``True``; pass ``False`` for
            evaluation, or for datasets smaller than one batch, so that no
            sample is silently skipped.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Merge per-sample dicts into one batched dict.

    For each of the "features", "labels" and "mask" entries, the tensors of
    all samples are stacked along a new leading dimension.
    """
    field_names = ("features", "labels", "mask")
    return {
        name: torch.stack([sample[name] for sample in batch])
        for name in field_names
    }