File size: 1,246 Bytes
45950ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch.nn.utils.rnn import pad_sequence
from mmengine.registry import FUNCTIONS

@FUNCTIONS.register_module()
def motion_collate_fn(batch):
    """Collate a list of per-sample dicts into a single batch dict.

    Args:
        batch: Iterable of dicts, each providing the keys ``'motion'``,
            ``'motion_length'``, ``'mean'`` and ``'std'``.
            Each ``'motion'`` is expected to be a ``[T, C]`` tensor with a
            per-sample ``T`` (per the note in the original code -- confirm
            against the dataset).

    Returns:
        dict with keys ``std``, ``mean``, ``motion`` and ``motion_length``:
        the motions padded to a common length and stacked batch-first, the
        per-sample ``mean``/``std`` stacked along a new leading dimension,
        and the lengths as a tensor. ``'caption'`` is not collated.
    """
    sequences = [sample['motion'] for sample in batch]
    lengths = [sample['motion_length'] for sample in batch]

    # pad_sequence fills shorter sequences with zeros (its default padding
    # value) up to the longest one; batch_first puts the batch dim in front.
    padded = pad_sequence(sequences, batch_first=True)

    stacked_mean = torch.stack([sample['mean'] for sample in batch])
    stacked_std = torch.stack([sample['std'] for sample in batch])

    return {
        'std': stacked_std,
        'mean': stacked_mean,
        'motion': padded,
        'motion_length': torch.tensor(lengths),
    }

@FUNCTIONS.register_module()
def motion_collate_fn_no_translation(batch):
    """Collate per-sample dicts into a batch, dropping 3 leading channels.

    Works like a plain padded collate, except that the first three entries
    along the last dimension of the padded motion are removed (presumably
    the translation channels, going by the function name -- confirm).

    Args:
        batch: Iterable of dicts, each providing the keys ``'motion'``,
            ``'motion_length'``, ``'mean'`` and ``'std'``.
            Each ``'motion'`` is expected to be a ``[T, C]`` tensor with a
            per-sample ``T`` (per the note in the original code -- confirm
            against the dataset).

    Returns:
        dict with keys ``std``, ``mean``, ``motion`` and ``motion_length``.
        ``'caption'`` is not collated.
    """
    sequences = [sample['motion'] for sample in batch]
    lengths = [sample['motion_length'] for sample in batch]

    # pad_sequence fills shorter sequences with zeros (its default padding
    # value) up to the longest one; batch_first puts the batch dim in front.
    padded = pad_sequence(sequences, batch_first=True)

    # NOTE(review): mean/std are stacked as-is and are NOT sliced like the
    # motion below -- verify that consumers expect the full channel count.
    stacked_mean = torch.stack([sample['mean'] for sample in batch])
    stacked_std = torch.stack([sample['std'] for sample in batch])

    # Keep everything from index 3 onward along the last dimension.
    trimmed = padded[..., 3:]

    return {
        'std': stacked_std,
        'mean': stacked_mean,
        'motion': trimmed,
        'motion_length': torch.tensor(lengths),
    }