Spaces:
Runtime error
Runtime error
| import torch | |
| from .protein import constants | |
def find_cdrs(structure):
    """List the CDR names present in an antibody structure.

    For each of ``structure['heavy']`` and ``structure['light']`` that is
    not ``None``, the chain's ``'cdr_flag'`` entry is searched for the
    integer value of each ``constants.CDR`` member; every match contributes
    the corresponding name.

    Args:
        structure: Mapping with ``'heavy'`` and ``'light'`` entries. Each is
            either ``None`` or a mapping holding a ``'cdr_flag'`` container
            that supports the ``in`` operator.

    Returns:
        A list drawn, in this order, from ``'H_CDR1'``, ``'H_CDR2'``,
        ``'H_CDR3'``, ``'L_CDR1'``, ``'L_CDR2'``, ``'L_CDR3'``.
    """
    found = []

    heavy_chain = structure['heavy']
    if heavy_chain is not None:
        heavy_flags = heavy_chain['cdr_flag']
        heavy_regions = (
            (constants.CDR.H1, 'H_CDR1'),
            (constants.CDR.H2, 'H_CDR2'),
            (constants.CDR.H3, 'H_CDR3'),
        )
        found.extend(
            label for code, label in heavy_regions if int(code) in heavy_flags
        )

    light_chain = structure['light']
    if light_chain is not None:
        light_flags = light_chain['cdr_flag']
        light_regions = (
            (constants.CDR.L1, 'L_CDR1'),
            (constants.CDR.L2, 'L_CDR2'),
            (constants.CDR.L3, 'L_CDR3'),
        )
        found.extend(
            label for code, label in light_regions if int(code) in light_flags
        )

    return found
def get_residue_first_last(data):
    """Identify the first and last residues selected by ``generate_flag``.

    Args:
        data: Mapping with the entries read here:
            ``'generate_flag'`` - mask tensor used to index the range of
            positions along its first dimension (presumably boolean; confirm
            against the caller);
            ``'resseq'`` - tensor whose selected element is converted with
            ``.item()``;
            ``'chain_id'`` and ``'icode'`` - containers indexable by an
            integer position.

    Returns:
        ``(residue_first, residue_last)``, each a tuple
        ``(chain_id, resseq, icode)`` taken at the lowest and highest
        selected position respectively.

    Raises:
        RuntimeError: Expected from torch when no position is selected,
            because ``min``/``max`` are then taken over an empty tensor.
    """
    loop_flag = data['generate_flag']
    # Create the position range on the mask's own device: a CPU range cannot
    # be indexed with a mask that lives on an accelerator.
    positions = torch.arange(loop_flag.size(0), device=loop_flag.device)
    loop_idx = positions[loop_flag]
    idx_first, idx_last = loop_idx.min().item(), loop_idx.max().item()

    def _residue_at(idx):
        # One (chain, residue number, insertion code) triple for a position.
        return (data['chain_id'][idx], data['resseq'][idx].item(), data['icode'][idx])

    return _residue_at(idx_first), _residue_at(idx_last)
class RemoveNative(object):
    """Transform that overwrites the entries selected by ``generate_flag``.

    Depending on the two switches given at construction, the selected
    positions of ``data['aa']`` are set to ``constants.AA.UNK`` and/or the
    selected positions of ``data['pos_heavyatom']`` are replaced with
    standard-normal noise multiplied by 10. Unselected positions are kept.
    """

    def __init__(self, remove_structure, remove_sequence):
        """Store which of the two fields should be overwritten."""
        super().__init__()
        self.remove_sequence = remove_sequence
        self.remove_structure = remove_structure

    def __call__(self, data):
        """Overwrite the selected entries of ``data`` and return it.

        Args:
            data: Mapping with ``'generate_flag'`` and, for each enabled
                switch, ``'aa'`` and/or ``'pos_heavyatom'`` tensors. The
                flag is indexed as ``[:, None, None]`` and expanded to the
                coordinate tensor's shape, so it is presumably 1-D with the
                coordinates 3-D - confirm against the caller.

        Returns:
            The same ``data`` object, with the enabled fields reassigned.
        """
        selected = data['generate_flag'].clone()

        if self.remove_sequence:
            residue_types = data['aa']
            # Selected positions become the "unknown" residue type.
            unknown = torch.full_like(residue_types, fill_value=int(constants.AA.UNK))
            data['aa'] = torch.where(selected, unknown, residue_types)

        if self.remove_structure:
            coords = data['pos_heavyatom']
            selected_atoms = selected[:, None, None].expand(coords.shape)
            # Noise is drawn for the whole tensor, then kept only where selected.
            noise = torch.randn_like(coords) * 10
            data['pos_heavyatom'] = torch.where(selected_atoms, noise, coords)

        return data