Spaces:
Runtime error
Runtime error
| import torch | |
| from configs import DataConfig | |
| from features import BaseDataset, VSL98Dataset, VSL400Dataset | |
def load_dataset(data_config: DataConfig) -> BaseDataset:
    """Instantiate the dataset selected by ``data_config.dataset``.

    Args:
        data_config: Configuration whose ``dataset`` attribute names the
            dataset to build. The whole config object is passed to the
            selected dataset class's constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If ``data_config.dataset`` is not one of the supported
            names. The message lists the supported names.
    """
    datasets = {
        "vsl_98": VSL98Dataset,
        "vsl_400": VSL400Dataset,
    }
    name = data_config.dataset
    # Keep the try narrow: a KeyError raised inside a dataset constructor
    # must not be misreported as an unsupported dataset name.
    try:
        dataset_cls = datasets[name]
    except KeyError:
        supported = ", ".join(sorted(datasets))
        raise KeyError(
            f"Unsupported dataset {name!r}; expected one of: {supported}"
        ) from None
    return dataset_cls(data_config)
def rgb_collate_fn(examples) -> dict:
    """Collate RGB video samples into a single batch.

    Args:
        examples: Iterable of samples, each a mapping with a ``"video"``
            4-D tensor and a ``"label"`` value accepted by ``torch.tensor``.
            The iterable is consumed exactly once, so one-shot iterators
            (e.g. generators) work as well as lists.

    Returns:
        A dict with:
            ``"pixel_values"``: the videos, each with its first two axes
            swapped, stacked along a new leading batch dimension.
            ``"labels"``: a 1-D tensor built from the labels, in order.
    """
    videos = []
    labels = []
    # Single pass so videos and labels always stay aligned, even when
    # `examples` cannot be iterated twice.
    for example in examples:
        # Permute to (num_frames, num_channels, height, width); swapping axes
        # 0 and 1 implies the stored layout is
        # (num_channels, num_frames, height, width) — TODO confirm upstream.
        videos.append(example["video"].permute(1, 0, 2, 3))
        labels.append(example["label"])
    return {"pixel_values": torch.stack(videos), "labels": torch.tensor(labels)}
def pose_collate_fn(examples) -> dict:
    """Collate pose samples into a single batch.

    Args:
        examples: Iterable of samples, each a mapping with a ``"pose"``
            tensor (all of identical shape, as ``torch.stack`` requires) and
            a ``"label"`` value accepted by ``torch.tensor``. The iterable is
            consumed exactly once, so one-shot iterators (e.g. generators)
            work as well as lists.

    Returns:
        A dict with:
            ``"poses"``: the pose tensors stacked as-is along a new leading
            batch dimension.
            ``"labels"``: a 1-D tensor built from the labels, in order.
    """
    poses = []
    labels = []
    # Single pass so poses and labels always stay aligned, even when
    # `examples` cannot be iterated twice.
    for example in examples:
        # Poses are stacked unchanged; no axis permutation is applied here.
        poses.append(example["pose"])
        labels.append(example["label"])
    return {"poses": torch.stack(poses), "labels": torch.tensor(labels)}