File size: 267 Bytes
c1596ac
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

def collate_caption(batch):
    """Merge (image, token) samples into two concatenated tensors.

    Each element of ``batch`` is unpacked as an ``(image, token)`` pair.
    All images are joined with ``torch.cat`` along dimension 0, and all
    tokens likewise; existing dimension-0 sizes are summed, no new axis is
    added.

    NOTE(review): because ``torch.cat`` (not ``torch.stack``) is used, this
    presumably expects every sample to already carry a leading dimension --
    confirm against the dataset that feeds this function.

    Args:
        batch: Iterable of 2-item ``(image, token)`` pairs of tensors.

    Returns:
        A tuple ``(images, tokens)`` of the two concatenated tensors.
    """
    # Materialize once so a one-shot iterable is only consumed a single
    # time, exactly like a plain for-loop over ``batch`` would.
    pairs = [(img, tok) for img, tok in batch]

    merged_images = torch.cat([img for img, _ in pairs], dim=0)
    merged_tokens = torch.cat([tok for _, tok in pairs], dim=0)
    return merged_images, merged_tokens