Spaces:
Running on Zero
Running on Zero
File size: 3,511 Bytes
454ecdd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | from __future__ import annotations
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
import pyarrow.parquet as pq
from manifold.data.generator import SyntheticDataGenerator, PlayerSession
class MANIFOLDDataset(Dataset):
    """
    PyTorch Dataset for MANIFOLD training data.

    Each sample is a 2-D array of shape ``(steps, feature_dim)``. On access the
    sample is right-padded with ``pad_value`` or truncated so that it has
    exactly ``sequence_length`` steps, and a mask marking the real (non-padded)
    steps is returned alongside it.

    Instances can be built directly from arrays, from a Parquet file
    (``from_parquet``) or from synthetic sessions (``from_generator``).
    """

    def __init__(
        self,
        data: Optional[np.ndarray] = None,
        labels: Optional[np.ndarray] = None,
        sequence_length: int = 128,
        pad_value: float = 0.0,
    ):
        """
        Args:
            data: Indexable collection of per-sample ``(steps, feature_dim)``
                arrays, or ``None`` for an empty dataset.
            labels: Per-sample integer labels aligned with ``data``, or
                ``None`` for an unlabeled dataset.
            sequence_length: Number of steps every returned sample will have.
            pad_value: Fill value used for the padded steps.
        """
        self.data = data
        self.labels = labels
        self.sequence_length = sequence_length
        self.pad_value = pad_value

    def __len__(self) -> int:
        """Return the number of samples (0 when no data was provided)."""
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return sample ``idx`` padded/truncated to ``sequence_length`` steps.

        Returns:
            A dict with:
              * ``"features"``: float32 tensor ``(sequence_length, feature_dim)``
              * ``"labels"``: int64 scalar tensor (only when labels exist)
              * ``"mask"``: float32 tensor ``(sequence_length,)`` holding 1.0
                for real steps and 0.0 for padding

        Raises:
            IndexError: If the dataset holds no data or ``idx`` is out of range.
        """
        if self.data is None:
            # Consistent with __len__() == 0: any index is out of range.
            raise IndexError("MANIFOLDDataset is empty (no data was provided)")

        features = self.data[idx]
        target_len = self.sequence_length
        seq_len = features.shape[0]
        valid_len = min(seq_len, target_len)

        if seq_len < target_len:
            padding = np.full(
                (target_len - seq_len, features.shape[1]),
                self.pad_value,
            )
            features = np.concatenate([features, padding], axis=0)
        else:
            features = features[:target_len]

        # 1.0 for the leading real steps, 0.0 for the padded tail.
        mask = np.zeros(target_len)
        mask[:valid_len] = 1.0

        sample: Dict[str, torch.Tensor] = {
            "features": torch.tensor(features, dtype=torch.float32),
        }
        if self.labels is not None:
            # Unlabeled datasets simply omit the key instead of crashing.
            sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        sample["mask"] = torch.tensor(mask, dtype=torch.float32)
        return sample

    @classmethod
    def from_parquet(cls, path: Union[str, Path], **kwargs) -> "MANIFOLDDataset":
        """
        Build a dataset from a Parquet file.

        The file is read into a pandas DataFrame and its ``"features"`` and
        ``"label"`` columns are converted to NumPy arrays.

        Args:
            path: Location of the Parquet file.
            **kwargs: Forwarded to the constructor (``sequence_length``,
                ``pad_value``).
        """
        table = pq.read_table(path)
        df = table.to_pandas()
        # NOTE(review): presumably every row stores a rectangular
        # (steps, feature_dim) nested list of equal size; ragged rows would
        # not form a regular ndarray here -- confirm against the writer.
        data = np.array(df["features"].tolist())
        labels = np.array(df["label"].tolist())
        return cls(data=data, labels=labels, **kwargs)

    @classmethod
    def from_generator(
        cls,
        num_samples: int,
        cheater_ratio: float = 0.3,
        seed: Optional[int] = None,
        **kwargs
    ) -> "MANIFOLDDataset":
        """
        Build a dataset from synthetically generated sessions.

        Args:
            num_samples: Total number of sessions to generate.
            cheater_ratio: Fraction of sessions requested as cheaters; the
                cheater count is ``int(num_samples * cheater_ratio)`` and the
                remainder is requested as legitimate sessions.
            seed: Seed handed to ``SyntheticDataGenerator``.
            **kwargs: Forwarded to the constructor (``sequence_length``,
                ``pad_value``).
        """
        gen = SyntheticDataGenerator(seed=seed)
        num_cheaters = int(num_samples * cheater_ratio)
        num_legit = num_samples - num_cheaters
        sessions = gen.generate_batch(num_legit, num_cheaters)
        data = np.array([s.to_tensor() for s in sessions])
        # Cheaters are labelled 2 and everyone else 0.
        # NOTE(review): label 1 is never produced here -- presumably reserved
        # for an intermediate class; confirm against the model's label scheme.
        labels = np.array([2 if s.is_cheater else 0 for s in sessions])
        return cls(data=data, labels=labels, **kwargs)
def create_dataloader(
    dataset: "MANIFOLDDataset",
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Wrap a dataset in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch.
        num_workers: Number of worker subprocesses used for loading.
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.
        drop_last: Whether to discard a trailing batch smaller than
            ``batch_size``. Defaults to ``True``, which keeps every batch the
            same size; pass ``False`` (e.g. for evaluation, or for datasets
            smaller than one batch) so that no sample is silently skipped.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Merge per-sample dicts into a single batched dict.

    For each of the keys ``"features"``, ``"labels"`` and ``"mask"`` the
    tensors of all samples are stacked along a new leading dimension.

    Args:
        batch: Samples, each a dict holding the three keys above.

    Returns:
        A dict with the same three keys mapping to the stacked tensors.
    """
    collated: Dict[str, torch.Tensor] = {}
    for key in ("features", "labels", "mask"):
        per_sample = [sample[key] for sample in batch]
        collated[key] = torch.stack(per_sample)
    return collated
|