LimmeDev's picture
Initial MANIFOLD upload - CS2 cheat detection training
454ecdd verified
from __future__ import annotations
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
import pyarrow.parquet as pq
from manifold.data.generator import SyntheticDataGenerator, PlayerSession
class MANIFOLDDataset(Dataset):
    """
    PyTorch Dataset for MANIFOLD training data.

    Supports loading from Parquet files and on-the-fly generation.
    Every sample is padded (with ``pad_value``) or truncated to
    ``sequence_length`` time steps so batches have consistent tensor shapes.

    Each item is a dict with:
        "features": float32 tensor of shape (sequence_length, num_features)
        "labels":   int64 scalar tensor
        "mask":     float32 tensor of shape (sequence_length,); 1.0 marks
                    real time steps, 0.0 marks padding.

    ``data`` may be a stacked array of shape (num_samples, time, features)
    or a sequence of per-sample 2-D arrays/lists whose lengths differ.
    """

    def __init__(
        self,
        data: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        labels: Optional[np.ndarray] = None,
        sequence_length: int = 128,
        pad_value: float = 0.0,
    ):
        """
        Args:
            data: Per-sample feature sequences, indexable by sample.
            labels: One integer class label per sample.
            sequence_length: Number of time steps every item is padded or
                truncated to.
            pad_value: Fill value used for padded feature rows.

        Raises:
            ValueError: If ``data`` and ``labels`` are both given but have
                different lengths.
        """
        if data is not None and labels is not None and len(data) != len(labels):
            raise ValueError(
                "data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.sequence_length = sequence_length
        self.pad_value = pad_value

    def __len__(self) -> int:
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # asarray accepts both rows of a stacked ndarray and ragged per-sample
        # lists/arrays; the original `.shape` access only worked for ndarrays.
        features = np.asarray(self.data[idx], dtype=np.float32)
        label = self.labels[idx]
        seq_len = features.shape[0]

        if seq_len < self.sequence_length:
            padding = np.full(
                (self.sequence_length - seq_len, features.shape[1]),
                self.pad_value,
                dtype=np.float32,
            )
            features = np.concatenate([features, padding], axis=0)
            # 1.0 for real time steps, 0.0 for the padded tail.
            mask = np.concatenate([
                np.ones(seq_len, dtype=np.float32),
                np.zeros(self.sequence_length - seq_len, dtype=np.float32),
            ])
        else:
            features = features[: self.sequence_length]
            mask = np.ones(self.sequence_length, dtype=np.float32)

        return {
            "features": torch.tensor(features, dtype=torch.float32),
            "labels": torch.tensor(label, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.float32),
        }

    @staticmethod
    def _collect_sequences(sequences: Any) -> Union[np.ndarray, List[np.ndarray]]:
        """Convert per-sample sequences to float32 arrays.

        Returns one stacked array when every sample has the same shape,
        otherwise a list of per-sample arrays (ragged lengths are padded or
        truncated later by ``__getitem__``).
        """
        arrays = [np.asarray(seq, dtype=np.float32) for seq in sequences]
        if arrays and all(arr.shape == arrays[0].shape for arr in arrays):
            return np.stack(arrays)
        return arrays

    @classmethod
    def from_parquet(cls, path: Union[str, Path], **kwargs) -> "MANIFOLDDataset":
        """Load a dataset from a Parquet file with "features" and "label" columns.

        Extra keyword arguments are forwarded to the constructor.
        """
        table = pq.read_table(path)
        # to_pylist() yields plain nested Python lists. Converting through
        # pandas instead can give object arrays of per-row ndarrays for nested
        # list columns, which do not form a numeric (N, T, F) array.
        features = table.column("features").to_pylist()
        labels = np.array(table.column("label").to_pylist())
        return cls(data=cls._collect_sequences(features), labels=labels, **kwargs)

    @classmethod
    def from_generator(
        cls,
        num_samples: int,
        cheater_ratio: float = 0.3,
        seed: Optional[int] = None,
        **kwargs
    ) -> "MANIFOLDDataset":
        """Build a dataset from ``SyntheticDataGenerator`` sessions.

        Args:
            num_samples: Total number of sessions to generate.
            cheater_ratio: Fraction of sessions generated as cheaters
                (rounded down).
            seed: Seed forwarded to the generator.
            **kwargs: Forwarded to the constructor.

        Raises:
            ValueError: If ``num_samples`` is negative or ``cheater_ratio`` is
                outside [0, 1].
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        if not 0.0 <= cheater_ratio <= 1.0:
            raise ValueError(f"cheater_ratio must be in [0, 1], got {cheater_ratio}")

        gen = SyntheticDataGenerator(seed=seed)
        num_cheaters = int(num_samples * cheater_ratio)
        num_legit = num_samples - num_cheaters
        # Materialize once: sessions are iterated twice below.
        sessions = list(gen.generate_batch(num_legit, num_cheaters))
        data = cls._collect_sequences(s.to_tensor() for s in sessions)
        # Cheaters get label 2 and legitimate sessions 0; label 1 is never
        # produced here (presumably reserved for another class -- confirm).
        labels = np.array([2 if s.is_cheater else 0 for s in sessions])
        return cls(data=data, labels=labels, **kwargs)
def create_dataloader(
    dataset: MANIFOLDDataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Wrap ``dataset`` in a DataLoader with the project's default settings.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Samples per batch.
        shuffle: Reshuffle sample order every epoch.
        num_workers: Worker subprocesses used for loading (0 = main process).
        pin_memory: Copy batches into page-locked memory for faster host-to-GPU
            transfer.
        drop_last: Drop the final incomplete batch. Defaults to True (the
            previously hard-coded behavior). Pass False for evaluation, or for
            datasets smaller than ``batch_size``, which would otherwise yield
            no batches at all.

    Returns:
        The configured DataLoader.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Merge per-sample dicts into one batched dict.

    For each of "features", "labels" and "mask" (in that order), stacks the
    samples' tensors along a new leading batch dimension. Any other keys in
    the samples are ignored.
    """
    collated: Dict[str, torch.Tensor] = {}
    for key in ("features", "labels", "mask"):
        collated[key] = torch.stack([sample[key] for sample in batch])
    return collated