| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import List, Dict |
| | from torch import Tensor |
| |
|
| |
|
def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor:
    """Stack same-rank tensors of varying sizes into one zero-padded tensor.

    The result has shape ``(len(batch), *max_size)``, where ``max_size[d]`` is
    the largest ``size(d)`` over the batch. Tensor ``i`` is written into the
    leading corner of slot ``i``; everything outside that region stays zero.
    The result's dtype and device are taken from ``batch[0]``; the other
    tensors are converted by ``copy_`` when written in.

    Args:
        batch: Non-empty list of tensors that all have the same number of dims.

    Returns:
        The padded, stacked tensor described above.

    Raises:
        ValueError: If ``batch`` is empty or its tensors differ in rank.
    """
    if len(batch) == 0:
        # There is no tensor to take a dtype, device or rank from.
        raise ValueError("collate_tensor_with_padding: batch is empty")
    first = batch[0]
    dims = first.dim()
    if any(b.dim() != dims for b in batch):
        raise ValueError(
            "collate_tensor_with_padding: all tensors must have the same number of dims"
        )
    max_size = [max(b.size(d) for b in batch) for d in range(dims)]
    canvas = first.new_zeros((len(batch), *max_size))
    for i, b in enumerate(batch):
        # Basic int/slice indexing returns a view, so copy_ writes into canvas.
        # copy_ stores b verbatim (e.g. keeps -0.0), unlike adding onto zeros.
        canvas[(i, *(slice(0, n) for n in b.shape))].copy_(b)
    return canvas
| |
|
| |
|
def collate_datastruct_and_text(lst_elements: List) -> Dict:
    """Merge a list of sample dicts into a single batch dict.

    The ``"datastruct"`` entries are merged with the ``transforms.collate``
    callable of the *first* sample's datastruct (presumably shared by every
    sample -- TODO confirm). ``"length"`` and ``"text"`` become plain
    per-sample lists. Every other key of the first sample is gathered into a
    list too, appended after those three in the first sample's key order.

    Args:
        lst_elements: Non-empty list of sample dicts. Each must hold
            ``"datastruct"``, ``"length"``, ``"text"`` and every key present
            in the first sample.

    Returns:
        dict keyed ``"datastruct"``, ``"length"``, ``"text"``, then the first
        sample's remaining keys.
    """
    head = lst_elements[0]
    merge = head["datastruct"].transforms.collate
    batch = {
        "datastruct": merge([sample["datastruct"] for sample in lst_elements]),
        "length": [sample["length"] for sample in lst_elements],
        "text": [sample["text"] for sample in lst_elements],
    }
    # The comprehension is fully built before update(), so the membership test
    # only sees the three keys above.
    batch.update({
        key: [sample[key] for sample in lst_elements]
        for key in head
        if key not in batch
    })
    return batch
| |
|
def collate_length_and_text(lst_elements: List) -> Dict:
    """Gather the paired length/text fields of each sample into per-key lists.

    Only the six keys named below are collected, in that order; any other key
    in the samples is ignored.

    Args:
        lst_elements: List of sample dicts, each containing all six keys.

    Returns:
        dict mapping each of the six keys to the list of per-sample values.
    """
    wanted = (
        "length_0",
        "length_1",
        "length_transition",
        "length_1_with_transition",
        "text_0",
        "text_1",
    )
    return {key: [sample[key] for sample in lst_elements] for key in wanted}
| |
|
def collate_pairs_and_text(lst_elements: List) -> Dict:
    """Collate samples that each carry two motion segments and their texts.

    The first sample alone decides which path is taken:

    * If it has no ``"features_0"`` key, the samples' ``"datastruct"`` entries
      are merged with the first sample's ``datastruct.transforms.collate``
      under the key ``"datastruct"``.
    * Otherwise ``features_0``, ``features_1`` and
      ``features_1_with_transition`` are zero-padded and stacked via
      ``collate_tensor_with_padding``. They are stored under
      ``motion_feats_0``, ``motion_feats_1`` and
      ``motion_feats_1_with_transition``.

    In both cases the length/text keys are then added as plain per-sample
    lists. Any other key is dropped.

    Args:
        lst_elements: Non-empty list of sample dicts.

    Returns:
        The batch dict described above.
    """
    head = lst_elements[0]
    if "features_0" in head:
        # (output key, per-sample input key), in output order.
        feature_keys = (
            ("motion_feats_0", "features_0"),
            ("motion_feats_1", "features_1"),
            ("motion_feats_1_with_transition", "features_1_with_transition"),
        )
        batch = {
            out_key: collate_tensor_with_padding([sample[in_key] for sample in lst_elements])
            for out_key, in_key in feature_keys
        }
    else:
        merge = head["datastruct"].transforms.collate
        batch = {"datastruct": merge([sample["datastruct"] for sample in lst_elements])}
    for key in ("length_0", "length_1", "length_transition",
                "length_1_with_transition", "text_0", "text_1"):
        batch[key] = [sample[key] for sample in lst_elements]
    return batch
| |
|
| |
|
def collate_text_and_length(lst_elements: List[Dict]) -> Dict:
    """Collate text/length samples, forwarding every other key as a list.

    ``"length"`` and ``"text"`` are gathered into per-sample lists first. Then
    every other key of the first sample, except ``"datastruct"``, is gathered
    the same way, in the first sample's key order. An empty input yields
    ``{"length": [], "text": []}``.

    Args:
        lst_elements: List of sample dicts. Each holds ``"length"``,
            ``"text"`` and every non-``"datastruct"`` key of the first sample.

    Returns:
        dict keyed ``"length"``, ``"text"``, then the forwarded extra keys.
    """
    batch = {"length": [sample["length"] for sample in lst_elements],
             "text": [sample["text"] for sample in lst_elements]}
    if not lst_elements:
        # No first sample to read extra keys from.
        return batch
    # Snapshot the extra keys before mutating batch; "datastruct" is never forwarded.
    otherkeys = [k for k in lst_elements[0].keys() if k not in batch and k != "datastruct"]
    for key in otherkeys:
        batch[key] = [sample[key] for sample in lst_elements]
    return batch
| |
|