| | import torch
|
| | import numpy as np
|
| |
|
| | def lengths_to_mask(lengths):
|
| | max_len = max(lengths)
|
| | mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
|
| | return mask
|
| |
|
| |
|
| | def collate_tensors(batch):
|
| | dims = batch[0].dim()
|
| | max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
|
| | size = (len(batch),) + tuple(max_size)
|
| | canvas = batch[0].new_zeros(size=size)
|
| | for i, b in enumerate(batch):
|
| | sub_tensor = canvas[i]
|
| | for d in range(dims):
|
| | sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
|
| | sub_tensor.add_(b)
|
| | return canvas
|
| |
|
| |
|
| | def collate(batch):
|
| | notnone_batches = [b for b in batch if b is not None]
|
| | if len(notnone_batches) == 0:
|
| | out_batch = {"x": [], "y": [],
|
| | "mask": [], "lengths": [],
|
| | "clip_image": [], "clip_text": [],
|
| | "clip_path": [], "clip_images_emb": []
|
| | }
|
| | return out_batch
|
| | databatch = [b['inp'] for b in notnone_batches]
|
| | labelbatch = [b['target'] for b in notnone_batches]
|
| | lenbatch = [len(b['inp'][0][0]) for b in notnone_batches]
|
| |
|
| |
|
| | databatchTensor = collate_tensors(databatch)
|
| | labelbatchTensor = torch.as_tensor(labelbatch)
|
| | lenbatchTensor = torch.as_tensor(lenbatch)
|
| | maskbatchTensor = lengths_to_mask(lenbatchTensor)
|
| |
|
| |
|
| | out_batch = {"x": databatchTensor, "y": labelbatchTensor,
|
| | "mask": maskbatchTensor, "lengths": lenbatchTensor}
|
| |
|
| | if 'clip_image' in notnone_batches[0]:
|
| | clip_image_batch = [torch.as_tensor(b['clip_image']) for b in notnone_batches]
|
| | out_batch.update({'clip_images': collate_tensors(clip_image_batch)})
|
| |
|
| | if 'clip_text' in notnone_batches[0]:
|
| | textbatch = [b['clip_text'] for b in notnone_batches]
|
| | out_batch.update({'clip_text': textbatch})
|
| |
|
| | if 'clip_path' in notnone_batches[0]:
|
| | textbatch = [b['clip_path'] for b in notnone_batches]
|
| | out_batch.update({'clip_path': textbatch})
|
| |
|
| | if 'all_categories' in notnone_batches[0]:
|
| | textbatch = [b['all_categories'] for b in notnone_batches]
|
| | out_batch.update({'all_categories': textbatch})
|
| |
|
| | return out_batch
|
| |
|