Spaces:
Runtime error
Runtime error
| import torch | |
| from ..protein import constants | |
| from ._base import register_transform | |
class MergeChains(object):
    """Merge the heavy, light and antigen entries of a structure into one record.

    Calling an instance with a mapping that has the keys ``'heavy'``,
    ``'light'`` and ``'antigen'`` (each either ``None`` or a dict of
    per-residue properties) returns a single dict in which every property
    is the concatenation, in that order, of the entries that are present.
    """

    # Properties stored as Python lists; merged by list concatenation.
    _LIST_PROPS = ('chain_id', 'icode')

    # Properties stored as tensors; merged with ``torch.cat`` along dim 0.
    _TENSOR_PROPS = (
        'chain_nb',
        'resseq',
        'res_nb',
        'aa',
        'pos_heavyatom',
        'mask_heavyatom',
        'generate_flag',
        'cdr_flag',
        'anchor_flag',
        'fragment_type',
    )

    # Boolean flags that fall back to all-False when an entry lacks them.
    _OPTIONAL_FLAGS = ('generate_flag', 'anchor_flag')

    def __init__(self):
        super().__init__()

    def assign_chain_number_(self, data_list):
        """Attach an integer ``'chain_nb'`` tensor to every dict in *data_list*.

        Each distinct value found in the ``'chain_id'`` sequences gets one
        integer, numbered from 0 in order of first appearance across
        *data_list*. The dicts are modified in place; nothing is returned.

        Args:
            data_list: List of dicts, each with an iterable ``'chain_id'``
                of hashable chain identifiers.
        """
        # An insertion-ordered dict (rather than a set) keeps the numbering
        # reproducible: set iteration order for strings depends on the
        # per-process hash seed.
        first_seen = {}
        for data in data_list:
            first_seen.update(dict.fromkeys(data['chain_id']))
        chain_index = {chain: idx for idx, chain in enumerate(first_seen)}

        for data in data_list:
            data['chain_nb'] = torch.tensor(
                [chain_index[chain] for chain in data['chain_id']],
                dtype=torch.long,
            )

    def _data_attr(self, data, name):
        """Return ``data[name]``, defaulting optional flags to all-False.

        For ``'generate_flag'`` and ``'anchor_flag'`` a missing key yields a
        boolean zero tensor shaped like ``data['aa']``; any other missing
        key raises ``KeyError``.
        """
        if name in self._OPTIONAL_FLAGS and name not in data:
            # zeros_like keeps the default on the same device as 'aa'.
            return torch.zeros_like(data['aa'], dtype=torch.bool)
        return data[name]

    def __call__(self, structure):
        """Tag each present entry with its fragment type and merge them.

        Side effects: the ``'heavy'``, ``'light'`` and ``'antigen'`` dicts
        inside *structure* are modified in place (``'fragment_type'`` and
        ``'chain_nb'`` are added; the antigen's ``'cdr_flag'`` is
        overwritten with zeros).

        Args:
            structure: Mapping with keys ``'heavy'``, ``'light'`` and
                ``'antigen'``; each value is ``None`` or a dict of
                per-residue properties.

        Returns:
            A dict holding the list properties (``'chain_id'``, ``'icode'``)
            followed by the tensor properties, each concatenated over the
            present entries in heavy, light, antigen order.

        Note:
            When all three entries are ``None`` there is nothing to
            concatenate and the ``torch.cat`` call fails, as it did before.
        """
        data_list = []

        heavy = structure['heavy']
        if heavy is not None:
            heavy['fragment_type'] = torch.full_like(
                heavy['aa'],
                fill_value=constants.Fragment.Heavy,
            )
            data_list.append(heavy)

        light = structure['light']
        if light is not None:
            light['fragment_type'] = torch.full_like(
                light['aa'],
                fill_value=constants.Fragment.Light,
            )
            data_list.append(light)

        antigen = structure['antigen']
        if antigen is not None:
            antigen['fragment_type'] = torch.full_like(
                antigen['aa'],
                fill_value=constants.Fragment.Antigen,
            )
            # The antigen's cdr_flag is always reset to zeros, replacing
            # whatever value the entry carried.
            antigen['cdr_flag'] = torch.zeros_like(antigen['aa'])
            data_list.append(antigen)

        self.assign_chain_number_(data_list)

        merged = {}
        for name in self._LIST_PROPS:
            # Linear-time flattening (sum(..., start=[]) re-copies the
            # accumulator for every chain).
            merged[name] = [
                item
                for data in data_list
                for item in self._data_attr(data, name)
            ]
        for name in self._TENSOR_PROPS:
            merged[name] = torch.cat(
                [self._data_attr(data, name) for data in data_list],
                dim=0,
            )
        return merged