import torch
from torch.nn.utils.rnn import pad_sequence
from mmengine.registry import FUNCTIONS

# Number of leading feature channels removed by
# ``motion_collate_fn_no_translation``. Presumably these hold the root
# translation (inferred from the function name) -- TODO confirm against the
# motion feature layout.
_NUM_TRANSLATION_CHANNELS = 3


def _collate_motion(batch):
    """Collate a list of motion samples into one batched dict.

    Shared implementation behind the registered collate functions below.

    Args:
        batch: Sequence of sample dicts, each providing ``'motion'``,
            ``'motion_length'``, ``'mean'`` and ``'std'``. ``'motion'`` is
            presumably a ``(T, C)`` tensor whose ``T`` varies per sample
            (per the original ``[motion: T, C]`` note) -- TODO confirm.

    Returns:
        dict with keys ``std``, ``mean`` (per-sample tensors stacked along a
        new dim 0), ``motion`` (zero-padded along the first per-sample dim,
        batch-first) and ``motion_length`` (a tensor built from the
        per-sample lengths).
    """
    motions = [item['motion'] for item in batch]
    lengths = [item['motion_length'] for item in batch]
    # pad_sequence zero-pads every sample up to the longest one along the
    # first dim; batch_first=True puts the batch dim in front.
    padded = pad_sequence(motions, batch_first=True)
    mean = torch.stack([item['mean'] for item in batch], dim=0)
    std = torch.stack([item['std'] for item in batch], dim=0)
    # The true (unpadded) lengths are returned alongside the padded motion,
    # presumably so consumers can ignore the zero padding.
    return dict(
        std=std,
        mean=mean,
        motion=padded,
        motion_length=torch.tensor(lengths),
    )


@FUNCTIONS.register_module()
def motion_collate_fn(batch):
    """Collate motion samples into a padded batch.

    Args:
        batch: Sequence of sample dicts with ``'motion'``,
            ``'motion_length'``, ``'mean'`` and ``'std'`` entries.

    Returns:
        dict with keys ``std``, ``mean``, ``motion`` (zero-padded,
        batch-first) and ``motion_length``.
    """
    return _collate_motion(batch)


@FUNCTIONS.register_module()
def motion_collate_fn_no_translation(batch):
    """Collate motion samples and drop the leading translation channels.

    Identical to ``motion_collate_fn`` except that the first
    ``_NUM_TRANSLATION_CHANNELS`` entries of the last dim of ``motion``
    are removed.

    NOTE(review): only ``motion`` is sliced; ``mean`` and ``std`` keep all
    channels, so their last dim no longer matches ``motion``'s. Confirm
    that downstream consumers expect this.

    Args:
        batch: Sequence of sample dicts with ``'motion'``,
            ``'motion_length'``, ``'mean'`` and ``'std'`` entries.

    Returns:
        dict with keys ``std``, ``mean``, ``motion`` and ``motion_length``.
    """
    collated = _collate_motion(batch)
    collated['motion'] = collated['motion'][..., _NUM_TRANSLATION_CHANNELS:]
    return collated