Spaces:
Runtime error
Runtime error
| import torch | |
| import random | |
| from typing import List, Optional | |
| from ..protein import constants | |
| from ._base import register_transform | |
def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
    """Randomly move both ends of the flagged span inward or outward.

    Each end is shifted independently by an offset drawn uniformly from
    ``[-shrink_limit, extend_limit]``: a positive offset grows the span by
    that many positions, a negative one trims it. The new ends are clamped
    to the valid index range of ``flag``.

    Args:
        flag: Mask tensor indexed along dim 0 with at least one entry set
            (presumably a 1-D boolean mask; confirm against callers). It is
            not modified.
        min_length: Shrinking is disabled when trimming ``shrink_limit``
            positions from both ends would leave fewer than this many
            flagged positions.
        shrink_limit: Maximum number of positions trimmed from each end.
        extend_limit: Maximum number of positions added to each end.

    Returns:
        A new tensor shaped like ``flag`` that is set exactly on the
        adjusted span, both ends inclusive.
    """
    first, last = continuous_flag_to_range(flag)
    length = flag.sum().item()
    if (length - 2 * shrink_limit) < min_length:
        # Too short to trim safely: only allow offsets >= 0 (extension).
        shrink_limit = 0
    first_ext = max(0, first - random.randint(-shrink_limit, extend_limit))
    last_ext = min(last + random.randint(-shrink_limit, extend_limit), flag.size(0) - 1)
    # Start from an all-False mask. Copying ``flag`` and setting the new span
    # on top of it would keep the original ends set, so a shrunk span could
    # never take effect.
    flag_ext = torch.zeros_like(flag)
    flag_ext[first_ext : last_ext + 1] = True
    return flag_ext
def continuous_flag_to_range(flag):
    """Return the first and last index at which ``flag`` is set.

    Args:
        flag: Boolean mask tensor indexed along dim 0 with at least one
            entry set.

    Returns:
        ``(first, last)`` as Python numbers, both inclusive. Positions
        between them are not inspected, so gaps inside the span are ignored.

    Note:
        When no entry is set, the min/max reductions run on an empty tensor
        and torch raises an error.
    """
    # Build the index tensor once, on the same device as the mask, so the
    # boolean indexing also works for masks that do not live on the CPU.
    positions = torch.arange(flag.size(0), device=flag.device)[flag]
    return positions.min().item(), positions.max().item()
class MaskSingleCDR(object):
    """Transform that marks a single CDR of one antibody chain for generation.

    Calling the transform writes ``generate_flag`` and ``anchor_flag`` into
    the chosen chain dict of the structure and returns the structure.

    Args:
        selection: Name of the CDR to mask (``'H1'``..``'L3'``, the
            ``'H_CDR1'``..``'L_CDR3'`` aliases, or ``'CDR3'`` which means H3
            when a heavy chain is present and L3 otherwise). ``None`` picks a
            random chain and a random CDR on every call.
        augmentation: When true, the masked span is randomly shrunk or
            extended with ``random_shrink_extend``.

    Raises:
        ValueError: If ``selection`` is not ``None`` and not a known name.
    """

    def __init__(self, selection=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
            'CDR3': 'CDR3',     # H3 first, then fallback to L3
        }
        # Raise explicitly rather than ``assert``: asserts are stripped under
        # ``python -O``, and an unknown name would then silently map to None,
        # i.e. "mask a random CDR".
        if selection is not None and selection not in cdr_str_to_enum:
            raise ValueError(
                f'Unknown CDR selection {selection!r}; '
                f'expected None or one of {sorted(cdr_str_to_enum)}.'
            )
        self.selection = cdr_str_to_enum.get(selection, None)
        self.augmentation = augmentation

    def perform_masking_(self, data, selection=None):
        """Write ``generate_flag`` and ``anchor_flag`` into ``data`` in place.

        Args:
            data: Chain dict; ``data['cdr_flag']`` labels every position with
                a CDR id (values > 0 are treated as CDRs) and ``data['aa']``
                provides the chain length and flag shape.
            selection: CDR id to mask. ``None`` picks uniformly among the CDR
                ids present in ``data['cdr_flag']``.
        """
        cdr_flag = data['cdr_flag']
        if selection is None:
            cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
            cdr_to_mask = random.choice(cdr_all)
        else:
            cdr_to_mask = selection

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        # Anchors are the positions directly before and after the masked
        # span, clamped to the chain ends (so an anchor can coincide with the
        # span boundary when the span touches the start or end of the chain).
        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        data['generate_flag'] = cdr_to_mask_flag
        data['anchor_flag'] = anchor_flag

    def __call__(self, structure):
        """Mask one CDR in ``structure`` and return the same object.

        Args:
            structure: Mapping with ``'heavy'`` and ``'light'`` entries, each
                either a chain dict or ``None``.

        Raises:
            ValueError: If the chain needed for the configured selection is
                missing from ``structure``.
        """
        if self.selection is None:
            ab_data = []
            if structure['heavy'] is not None:
                ab_data.append(structure['heavy'])
            if structure['light'] is not None:
                ab_data.append(structure['light'])
            if not ab_data:
                raise ValueError(
                    'MaskSingleCDR needs a heavy or a light chain, '
                    'but the structure has neither.'
                )
            data_to_mask = random.choice(ab_data)
            sel = None
        elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ):
            data_to_mask = structure['heavy']
            sel = int(self.selection)
        elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ):
            data_to_mask = structure['light']
            sel = int(self.selection)
        elif self.selection == 'CDR3':
            # Prefer the heavy chain; converted with int() like the explicit
            # selections above so perform_masking_ always gets a plain id.
            if structure['heavy'] is not None:
                data_to_mask = structure['heavy']
                sel = int(constants.CDR.H3)
            else:
                data_to_mask = structure['light']
                sel = int(constants.CDR.L3)
        else:
            raise ValueError(f'Unsupported CDR selection: {self.selection!r}')

        if data_to_mask is None:
            raise ValueError(
                f'The chain required for CDR selection {self.selection!r} '
                'is missing from the structure.'
            )

        self.perform_masking_(data_to_mask, selection=sel)
        return structure
class MaskMultipleCDRs(object):
    """Transform that marks several CDRs per antibody chain for generation.

    For every chain present in the structure, the masks of all chosen CDRs
    are OR-ed into ``generate_flag`` and their flanking positions into
    ``anchor_flag``.

    Args:
        selection: CDR names to mask (``'H1'``..``'L3'`` or the
            ``'H_CDR1'``..``'L_CDR3'`` aliases). Names whose CDR is absent
            from a chain are skipped for that chain. ``None`` masks a random
            non-empty subset of the CDRs found in each chain.
        augmentation: When true, each masked span is randomly shrunk or
            extended with ``random_shrink_extend``.

    Raises:
        KeyError: If ``selection`` contains an unknown name.
    """

    def __init__(self, selection: Optional[List[str]]=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
        }
        if selection is not None:
            self.selection = [cdr_str_to_enum[s] for s in selection]
        else:
            self.selection = None
        self.augmentation = augmentation

    def mask_one_cdr_(self, data, cdr_to_mask):
        """Add the mask and anchors of one CDR to ``data`` in place.

        Args:
            data: Chain dict with ``'cdr_flag'`` and ``'aa'`` entries.
            cdr_to_mask: CDR id compared against ``data['cdr_flag']``.
        """
        cdr_flag = data['cdr_flag']

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        # Anchors are the positions directly before and after the masked
        # span, clamped to the chain ends.
        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        if 'generate_flag' not in data:
            data['generate_flag'] = cdr_to_mask_flag
            data['anchor_flag'] = anchor_flag
        else:
            # Flags already present on the chain are extended in place.
            data['generate_flag'] |= cdr_to_mask_flag
            data['anchor_flag'] |= anchor_flag

    def mask_for_one_chain_(self, data):
        """Choose the CDRs to mask on one chain and mask each of them.

        A chain whose ``cdr_flag`` contains no positive id is left untouched.
        """
        cdr_flag = data['cdr_flag']
        cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
        if not cdr_all:
            # Nothing to mask; also avoids random.randint(1, 0), which raises
            # on the empty range.
            return

        if self.selection is not None:
            cdrs_to_mask = list(set(cdr_all).intersection(self.selection))
        else:
            # Draw the subset size only when it is actually used.
            num_cdrs_to_mask = random.randint(1, len(cdr_all))
            random.shuffle(cdr_all)
            cdrs_to_mask = cdr_all[:num_cdrs_to_mask]

        for cdr_to_mask in cdrs_to_mask:
            self.mask_one_cdr_(data, cdr_to_mask)

    def __call__(self, structure):
        """Mask CDRs on every chain present in ``structure`` and return it."""
        if structure['heavy'] is not None:
            self.mask_for_one_chain_(structure['heavy'])
        if structure['light'] is not None:
            self.mask_for_one_chain_(structure['light'])
        return structure
class MaskAntibody(object):
    """Transform that marks every antibody residue for generation.

    All positions of the heavy and light chains get ``generate_flag`` set.
    When an antigen is present, the antigen residues whose C-alpha atom lies
    within ``contact_threshold`` of any antibody C-alpha are recorded in
    ``contact_flag``, and one of them is sampled as the single anchor.

    Args:
        contact_threshold: Maximum C-alpha distance for an antigen residue to
            count as being in contact, in the units of ``pos_heavyatom``
            (presumably angstroms; confirm against the data pipeline).
            Defaults to 6.0.
    """

    def __init__(self, contact_threshold=6.0):
        super().__init__()
        self.contact_threshold = contact_threshold

    def mask_ab_chain_(self, data):
        """Set ``data['generate_flag']`` to all True, shaped like ``data['aa']``."""
        data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)

    def __call__(self, structure):
        """Mask the antibody chains, annotate the antigen, return ``structure``.

        Raises:
            ValueError: If the structure has neither a heavy nor a light chain.
        """
        pos_ab_alpha = []
        for chain_name in ('heavy', 'light'):
            chain = structure[chain_name]
            if chain is None:
                continue
            self.mask_ab_chain_(chain)
            pos_ab_alpha.append(
                chain['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            )
        if not pos_ab_alpha:
            # torch.cat would fail on an empty list with a less helpful error.
            raise ValueError(
                'MaskAntibody needs a heavy or a light chain, '
                'but the structure has neither.'
            )
        pos_ab_alpha = torch.cat(pos_ab_alpha, dim=0)   # (L_Ab, 3)

        if structure['antigen'] is not None:
            pos_ag_alpha = structure['antigen']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha)    # (L_Ag, L_Ab)
            nn_ab_dist = ag_ab_dist.min(dim=1)[0]   # (L_Ag)
            contact_flag = (nn_ab_dist <= self.contact_threshold)   # (L_Ag)
            if not contact_flag.any():
                # No residue within the cutoff: fall back to the closest one
                # so that there is always something to sample from.
                contact_flag[nn_ab_dist.argmin()] = True

            # Equal weights on the contact residues -> uniform choice of anchor.
            anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item()
            anchor_flag = torch.zeros(structure['antigen']['aa'].shape, dtype=torch.bool)
            anchor_flag[anchor_idx] = True
            structure['antigen']['anchor_flag'] = anchor_flag
            structure['antigen']['contact_flag'] = contact_flag

        return structure
class RemoveAntigen:
    """Transform that drops the antigen from a structure."""

    def __call__(self, structure):
        """Set the antigen entries of ``structure`` to None and return it."""
        for key in ('antigen', 'antigen_seqmap'):
            structure[key] = None
        return structure