Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| from torch.utils.data._utils.collate import default_collate | |
# Per-key fill values used when padding; keys not listed here are padded with 0.
# NOTE(review): 21 for 'aa' presumably marks a padding/unknown residue type — confirm.
DEFAULT_PAD_VALUES = {
    'aa': 21,
    'chain_id': ' ',
    'icode': ' ',
}
# Keys whose values are passed through to ``default_collate`` without padding.
DEFAULT_NO_PADDING = {
    'origin',
}


class PaddingCollate(object):
    """Collate variable-length samples by padding them to a common length.

    Each sample is a dict.  Only the keys present in *every* sample of the
    batch are kept.  Tensors are padded along dim 0 and lists are extended up
    to the longest ``sample[length_ref_key].size(0)`` in the batch (rounded up
    to a multiple of 8 when ``eight`` is true).  A boolean ``'mask'`` entry
    marking the real (unpadded) positions is added to every sample before the
    batch is handed to ``default_collate``.

    Args:
        length_ref_key: Key whose tensor's first dimension defines each
            sample's length.
        pad_values: Mapping from key to fill value; keys not present use 0.
            It is only read, never modified.
        no_padding: Keys whose values are passed through unpadded.
        eight: If true, round the padded length up to a multiple of 8.
    """

    def __init__(self, length_ref_key='aa', pad_values=DEFAULT_PAD_VALUES, no_padding=DEFAULT_NO_PADDING, eight=True):
        super().__init__()
        self.length_ref_key = length_ref_key
        self.pad_values = pad_values
        self.no_padding = no_padding
        self.eight = eight

    # The helpers below take no ``self`` but are invoked as ``self._helper(...)``
    # in ``__call__``; they must be static methods, otherwise the instance is
    # bound to their first parameter and every call raises ``TypeError``.
    @staticmethod
    def _pad_last(x, n, value=0):
        """Pad ``x`` along its first dimension to length ``n`` with ``value``.

        Tensors get a block of ``value`` appended (cast to ``x``'s dtype and
        device); lists are extended with ``value``; any other object is
        returned unchanged.
        """
        if isinstance(x, torch.Tensor):
            assert x.size(0) <= n
            if x.size(0) == n:
                return x
            pad_size = [n - x.size(0)] + list(x.shape[1:])
            # ``.to(x)`` gives the pad the same dtype and device as ``x``.
            pad = torch.full(pad_size, fill_value=value).to(x)
            return torch.cat([x, pad], dim=0)
        if isinstance(x, list):
            return x + [value] * (n - len(x))
        return x

    @staticmethod
    def _get_pad_mask(l, n):
        """Return a bool tensor of length ``n`` whose first ``l`` entries are True."""
        return torch.cat([
            torch.ones([l], dtype=torch.bool),
            torch.zeros([n - l], dtype=torch.bool),
        ], dim=0)

    @staticmethod
    def _get_common_keys(list_of_dict):
        """Return the set of keys present in every dict of ``list_of_dict``."""
        keys = set(list_of_dict[0].keys())
        for d in list_of_dict[1:]:
            keys.intersection_update(d.keys())
        return keys

    def _get_pad_value(self, key):
        """Return the configured fill value for ``key``, or 0 if none is set."""
        if key not in self.pad_values:
            return 0
        return self.pad_values[key]

    def __call__(self, data_list):
        """Pad every sample in ``data_list`` and collate them into one batch.

        Args:
            data_list: Non-empty sequence of sample dicts, each containing
                ``self.length_ref_key`` mapped to a tensor.

        Returns:
            The result of ``default_collate`` on the padded samples, each of
            which carries an extra bool ``'mask'`` entry.

        Raises:
            ValueError: If ``data_list`` is empty.
        """
        if not data_list:
            # Same exception type ``max([])`` would raise, with a clearer message.
            raise ValueError('PaddingCollate received an empty batch')
        lengths = [data[self.length_ref_key].size(0) for data in data_list]
        max_length = max(lengths)
        if self.eight:
            max_length = math.ceil(max_length / 8) * 8
        keys = self._get_common_keys(data_list)

        data_list_padded = []
        for data, length in zip(data_list, lengths):
            data_padded = {
                k: v if k in self.no_padding else self._pad_last(v, max_length, value=self._get_pad_value(k))
                for k, v in data.items()
                if k in keys
            }
            data_padded['mask'] = self._get_pad_mask(length, max_length)
            data_list_padded.append(data_padded)
        return default_collate(data_list_padded)
def apply_patch_to_tensor(x_full, x_patch, patch_idx):
    """Return a copy of ``x_full`` with the rows at ``patch_idx`` set to ``x_patch``.

    The input tensor is cloned first, so ``x_full`` itself is left untouched.

    Args:
        x_full: (N, ...) tensor to patch.
        x_patch: (M, ...) replacement values.
        patch_idx: (M, ) indices into the first dimension of ``x_full``.

    Returns:
        (N, ...) patched copy of ``x_full``.
    """
    patched = x_full.clone()  # never write into the caller's tensor
    patched[patch_idx] = x_patch
    return patched