Spaces:
Runtime error
Runtime error
File size: 946 Bytes
9f83ce9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | import torch
from configs import DataConfig
from features import BaseDataset, VSL98Dataset, VSL400Dataset
def load_dataset(data_config: DataConfig) -> BaseDataset:
    """Build the dataset named by ``data_config.dataset``.

    Args:
        data_config: Data configuration. Its ``dataset`` attribute selects the
            dataset class, and the whole config object is passed to that
            class's constructor.

    Returns:
        The object constructed by the dataset class registered under
        ``data_config.dataset``.

    Raises:
        KeyError: If ``data_config.dataset`` is not one of the registered
            dataset names.
    """
    datasets = {
        "vsl_98": VSL98Dataset,
        "vsl_400": VSL400Dataset,
    }
    try:
        dataset_cls = datasets[data_config.dataset]
    except KeyError:
        # Keep the KeyError type callers may already catch, but make the
        # message actionable by listing the valid names.
        raise KeyError(
            f"Unknown dataset {data_config.dataset!r}; "
            f"expected one of {sorted(datasets)}"
        ) from None
    return dataset_cls(data_config)
def rgb_collate_fn(examples) -> dict:
    """Collate RGB video examples into a single batch dictionary.

    Each example's ``"video"`` tensor has its first two dimensions swapped
    via ``permute(1, 0, 2, 3)``. The per-example results are then stacked
    along a new leading batch dimension. All ``"label"`` values are gathered
    into one tensor.

    Args:
        examples: Sequence of mappings with ``"video"`` (a 4-D tensor) and
            ``"label"`` entries.

    Returns:
        ``{"pixel_values": <stacked videos>, "labels": <label tensor>}``.
    """
    # Swap dims 0 and 1 so each clip ends up as
    # (num_frames, num_channels, height, width); the incoming layout is
    # presumably (num_channels, num_frames, height, width) -- TODO confirm.
    clips = [sample["video"].permute(1, 0, 2, 3) for sample in examples]
    batch_videos = torch.stack(clips)
    batch_labels = torch.tensor([sample["label"] for sample in examples])
    return {"pixel_values": batch_videos, "labels": batch_labels}
def pose_collate_fn(examples) -> dict:
    """Collate pose examples into a single batch dictionary.

    Each example's ``"pose"`` tensor is stacked as-is (no permutation) along
    a new leading batch dimension, so every pose tensor must have the same
    shape. All ``"label"`` values are gathered into one tensor.

    Args:
        examples: Sequence of mappings with ``"pose"`` (a tensor) and
            ``"label"`` entries.

    Returns:
        ``{"poses": <stacked poses>, "labels": <label tensor>}``.
    """
    # Poses keep whatever per-example layout the dataset provides; only a
    # batch dimension is added in front.
    poses = torch.stack([example["pose"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"poses": poses, "labels": labels}
|