File size: 946 Bytes
9f83ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from configs import DataConfig
from features import BaseDataset, VSL98Dataset, VSL400Dataset


def load_dataset(data_config: DataConfig) -> BaseDataset:
    """Instantiate the dataset selected by ``data_config.dataset``.

    Args:
        data_config: configuration whose ``dataset`` attribute names the
            dataset to build (``"vsl_98"`` or ``"vsl_400"``). The same object
            is passed unchanged to the dataset constructor.

    Returns:
        A ``VSL98Dataset`` or ``VSL400Dataset`` constructed from
        ``data_config``.

    Raises:
        KeyError: if ``data_config.dataset`` is not a registered name; the
            message lists the valid choices.
    """
    datasets = {
        "vsl_98": VSL98Dataset,
        "vsl_400": VSL400Dataset,
    }
    name = data_config.dataset
    try:
        dataset_cls = datasets[name]
    except KeyError:
        # Keep KeyError for callers that catch it, but make the failure
        # self-explanatory instead of echoing only the bad key.
        raise KeyError(
            f"Unknown dataset {name!r}; expected one of {sorted(datasets)}"
        ) from None
    return dataset_cls(data_config)


def rgb_collate_fn(examples) -> dict:
    """Collate RGB video examples into a single batch dict.

    Each example's ``"video"`` tensor has its first two axes swapped via
    ``permute(1, 0, 2, 3)``. The results are then stacked along a new leading
    batch axis. All ``"label"`` values are gathered into one tensor.

    NOTE(review): the target layout per example is documented as
    (num_frames, num_channels, height, width), which implies the dataset
    yields (num_channels, num_frames, height, width). Confirm against the
    dataset.

    Returns:
        dict with ``"pixel_values"`` (stacked, permuted videos) and
        ``"labels"`` (one label per example).
    """
    # Swap the first two axes so frames come first, per the target layout above.
    reordered = [sample["video"].permute(1, 0, 2, 3) for sample in examples]
    batch = {"pixel_values": torch.stack(reordered)}
    batch["labels"] = torch.tensor([sample["label"] for sample in examples])
    return batch


def pose_collate_fn(examples) -> dict:
    """Collate pose examples into a single batch dict.

    Unlike ``rgb_collate_fn``, no axis reordering is applied. Each example's
    ``"pose"`` tensor is stacked as-is along a new leading batch dimension,
    so all poses in the batch must share the same shape.

    Args:
        examples: sequence of mappings, each with a ``"pose"`` tensor and a
            ``"label"`` value accepted by ``torch.tensor``.

    Returns:
        dict with ``"poses"`` (shape ``(batch, *pose.shape)``) and
        ``"labels"`` (one entry per example).
    """
    # No permute: poses keep whatever per-example layout the dataset yields.
    # NOTE(review): that layout is defined by the dataset. Confirm it there.
    poses = torch.stack([example["pose"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"poses": poses, "labels": labels}