| import torch | |
| from torch.utils.data.dataloader import default_collate | |
def collate_video(batch):
    """Collate dict samples whose ``"input"`` tensors differ in length.

    Each sample's ``"input"`` is expected to be a video tensor laid out as
    ``(temporal_crops, C, T, H, W)``, where ``temporal_crops`` differs from
    clip to clip. Such tensors cannot be stacked by the standard collate
    function, so they are concatenated along the first dimension instead.
    Every other key is batched with ``default_collate``.

    Note:
        This function does not record how many crops each sample
        contributed. Restoring individual videos from the concatenated
        tensor requires those lengths; presumably the dataset supplies them
        under another key (which would be collated like any other field).

    Args:
        batch: Non-empty sequence of dict samples. Each must contain an
            ``"input"`` tensor; the keys of the first sample determine
            which other fields are collated.

    Returns:
        A dict with the non-``"input"`` keys of the first sample, each
        collated by ``default_collate``, plus ``"input"`` holding the
        concatenation of all samples' ``"input"`` tensors.

    Raises:
        ValueError: If ``batch`` is empty.
        TypeError: If the first sample is not a dict.
    """
    if len(batch) == 0:
        raise ValueError("collate_video received an empty batch")

    elem = batch[0]
    # Explicit raise instead of ``assert``: asserts are removed under
    # ``python -O``, which would let malformed samples slip through.
    if not isinstance(elem, dict):
        raise TypeError(
            f"collate_video expects dict samples, got {type(elem).__name__}"
        )

    output = {
        key: default_collate([sample[key] for sample in batch])
        for key in elem
        if key != "input"
    }
    # torch.cat joins along dim 0, so samples may differ in their first
    # dimension but must agree on all remaining dimensions.
    output["input"] = torch.cat([sample["input"] for sample in batch])
    return output