SignLanguage / src /tools /features.py
thienphuc12339's picture
Add all source code
9f83ce9
raw
history blame contribute delete
946 Bytes
import torch
from configs import DataConfig
from features import BaseDataset, VSL98Dataset, VSL400Dataset
def load_dataset(data_config: DataConfig) -> BaseDataset:
    """Instantiate the dataset selected by ``data_config.dataset``.

    Args:
        data_config: Data configuration. Its ``dataset`` attribute names the
            dataset to build (``"vsl_98"`` or ``"vsl_400"``); the whole config
            object is then passed to that dataset class's constructor.

    Returns:
        An instance of the selected dataset class.

    Raises:
        KeyError: If ``data_config.dataset`` is not a supported dataset name.
            The message lists the supported names.
    """
    datasets = {
        "vsl_98": VSL98Dataset,
        "vsl_400": VSL400Dataset,
    }
    # Only the registry lookup is guarded, so a KeyError raised inside a
    # dataset constructor still propagates untouched.
    try:
        dataset_cls = datasets[data_config.dataset]
    except KeyError:
        # Stay a KeyError (callers may already catch it) but say what is valid.
        supported = ", ".join(sorted(datasets))
        raise KeyError(
            f"Unknown dataset {data_config.dataset!r}; "
            f"expected one of: {supported}"
        ) from None
    return dataset_cls(data_config)
def rgb_collate_fn(examples) -> dict:
    """Collate RGB video examples into a single batch dict.

    Every example's ``"video"`` tensor goes through ``permute(1, 0, 2, 3)``,
    which swaps its first two axes. The intent is to produce
    ``(num_frames, num_channels, height, width)`` clips, which assumes the
    stored layout is ``(num_channels, num_frames, height, width)``
    (TODO confirm against the dataset). The permuted clips are then stacked
    along a new leading batch axis, so all clips must share one shape.

    Args:
        examples: Re-iterable sequence of mapping-like examples, each with a
            4-D ``"video"`` tensor and a ``"label"`` entry.

    Returns:
        ``{"pixel_values": <stacked clips>, "labels": <label tensor>}``.
    """
    reordered_clips = [sample["video"].permute(1, 0, 2, 3) for sample in examples]
    batch_pixels = torch.stack(reordered_clips)
    batch_labels = torch.tensor([sample["label"] for sample in examples])
    return dict(pixel_values=batch_pixels, labels=batch_labels)
def pose_collate_fn(examples) -> dict:
    """Collate pose examples into a single batch dict.

    The ``"pose"`` tensors are stacked as-is along a new leading batch axis.
    No axis permutation is applied, so every example's pose tensor must have
    the same shape. Labels are gathered with ``torch.tensor``, which gives a
    1-D tensor when each label is a scalar (presumably a class id; confirm
    with the dataset).

    Args:
        examples: Re-iterable sequence of mapping-like examples, each with a
            ``"pose"`` tensor and a ``"label"`` entry.

    Returns:
        ``{"poses": <stacked poses>, "labels": <label tensor>}``.
    """
    stacked_poses = torch.stack([sample["pose"] for sample in examples])
    label_tensor = torch.tensor([sample["label"] for sample in examples])
    return dict(poses=stacked_poses, labels=label_tensor)