# forensics-grpo/code/libs/datasets/datasets.py
# Uploaded by sdzt in commit 33569f9 ("Add source code"); original size 920 bytes.
import os
import torch
from .data_utils import trivial_batch_collator, worker_init_reset_seed
# Global registry mapping a dataset name to the class registered under it.
datasets = {}


def register_dataset(name):
    """Return a class decorator that records the class in ``datasets``.

    Usage::

        @register_dataset("my_dataset")
        class MyDataset(...):
            ...

    Registering a second class under an existing ``name`` replaces the
    earlier entry. The decorated class itself is returned unchanged.
    """
    def _register(dataset_cls):
        datasets.update({name: dataset_cls})
        return dataset_cls

    return _register
def make_dataset(name, is_training, split, **kwargs):
    """Instantiate the dataset class registered under ``name``.

    Args:
        name: key under which the class was registered via
            ``register_dataset``.
        is_training: passed positionally as the first constructor argument.
        split: passed positionally as the second constructor argument.
        **kwargs: forwarded unchanged to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: if no dataset is registered under ``name``. The message
            lists the names that are registered.
    """
    # Keep the try body to the lookup only, so a KeyError raised inside the
    # dataset constructor is not misreported as an unknown dataset name.
    try:
        dataset_cls = datasets[name]
    except KeyError:
        available = ", ".join(sorted(map(str, datasets))) or "<none>"
        raise KeyError(
            f"Unknown dataset {name!r}; registered datasets: {available}"
        ) from None
    return dataset_cls(is_training, split, **kwargs)
def make_data_loader(dataset, is_training, generator, batch_size, num_workers):
    """Build a ``torch.utils.data.DataLoader`` over ``dataset``.

    When ``is_training`` is true, the loader shuffles, drops the final
    incomplete batch, and runs ``worker_init_reset_seed`` in each worker
    process. Otherwise it does none of these.

    Args:
        dataset: the dataset to iterate over.
        is_training: selects training behaviour (shuffle, drop_last, worker
            seeding) versus evaluation behaviour.
        generator: passed to ``DataLoader(generator=...)`` (a
            ``torch.Generator`` or ``None``).
        batch_size: number of samples per batch.
        num_workers: number of loader worker processes; ``0`` loads in the
            main process.

    Returns:
        The configured ``DataLoader``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator,
        worker_init_fn=(worker_init_reset_seed if is_training else None),
        shuffle=is_training,
        drop_last=is_training,
        generator=generator,
        # DataLoader raises ValueError when persistent_workers=True and
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0,
    )