File size: 3,511 Bytes
454ecdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from __future__ import annotations
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
import pyarrow.parquet as pq

from manifold.data.generator import SyntheticDataGenerator, PlayerSession


class MANIFOLDDataset(Dataset):
    """
    PyTorch Dataset for MANIFOLD training data.

    Each item is a dict of tensors: ``features`` (float32, shape
    ``(sequence_length, n_features)``), ``labels`` (long scalar), and
    ``mask`` (float32, 1.0 on real timesteps, 0.0 on padding).
    Sequences shorter than ``sequence_length`` are right-padded with
    ``pad_value``; longer ones are truncated. Alternate constructors load
    from Parquet or synthesize data on the fly.
    """

    def __init__(
        self,
        data: Optional[np.ndarray] = None,
        labels: Optional[np.ndarray] = None,
        sequence_length: int = 128,
        pad_value: float = 0.0,
    ):
        # data/labels may be None for an empty placeholder dataset.
        self.data = data
        self.labels = labels
        self.sequence_length = sequence_length
        self.pad_value = pad_value

    def __len__(self) -> int:
        # An unpopulated dataset reports zero length instead of raising.
        if self.data is None:
            return 0
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        feats = self.data[idx]
        target = self.labels[idx]
        max_len = self.sequence_length
        n_steps = feats.shape[0]

        if n_steps >= max_len:
            # Long enough: keep the first max_len steps, all positions valid.
            feats = feats[:max_len]
            mask = np.ones(max_len)
        else:
            # Too short: append pad_value rows and zero out their mask slots.
            n_pad = max_len - n_steps
            filler = np.full((n_pad, feats.shape[1]), self.pad_value)
            feats = np.concatenate([feats, filler], axis=0)
            mask = np.concatenate([np.ones(n_steps), np.zeros(n_pad)])

        return {
            "features": torch.tensor(feats, dtype=torch.float32),
            "labels": torch.tensor(target, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.float32),
        }

    @classmethod
    def from_parquet(cls, path: Union[str, Path], **kwargs) -> "MANIFOLDDataset":
        """Load a dataset from a Parquet file with 'features' and 'label' columns."""
        frame = pq.read_table(path).to_pandas()
        return cls(
            data=np.array(frame["features"].tolist()),
            labels=np.array(frame["label"].tolist()),
            **kwargs,
        )

    @classmethod
    def from_generator(
        cls,
        num_samples: int,
        cheater_ratio: float = 0.3,
        seed: Optional[int] = None,
        **kwargs
    ) -> "MANIFOLDDataset":
        """Synthesize ``num_samples`` sessions, ``cheater_ratio`` of them cheaters."""
        generator = SyntheticDataGenerator(seed=seed)
        n_cheat = int(num_samples * cheater_ratio)
        sessions = generator.generate_batch(num_samples - n_cheat, n_cheat)

        feature_stack = np.array([session.to_tensor() for session in sessions])
        # Cheaters get label 2, legitimate sessions 0 — presumably other ids are
        # reserved for intermediate classes; TODO confirm against the label schema.
        targets = np.array([2 if session.is_cheater else 0 for session in sessions])

        return cls(data=feature_stack, labels=targets, **kwargs)


def create_dataloader(
    dataset: MANIFOLDDataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """Wrap a dataset in a DataLoader configured for MANIFOLD training.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Samples per batch.
        shuffle: Reshuffle the data at the start of every epoch.
        num_workers: Worker subprocesses used for loading.
        pin_memory: Copy batches into pinned (page-locked) host memory.
        drop_last: Drop the final incomplete batch. Previously hard-coded
            to True, which silently discarded up to ``batch_size - 1``
            samples; exposed so evaluation runs can keep every sample.

    Returns:
        A configured :class:`torch.utils.data.DataLoader`.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )


def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack per-sample tensors into batched tensors, key by key.

    Expects every sample dict to carry 'features', 'labels', and 'mask';
    any extra keys are ignored.
    """
    keys = ("features", "labels", "mask")
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}