# Spaces:
# Sleeping
# Sleeping
| import pdbx | |
| from pdbx.reader.PdbxReader import PdbxReader | |
| from pdbx.reader.PdbxContainers import DataCategory | |
| import gzip | |
| import numpy as np | |
| import torch | |
| import os,sys | |
| import glob | |
| import re | |
| from scipy.spatial import KDTree | |
| from itertools import combinations,permutations | |
| import tempfile | |
| import subprocess | |
# the 20 standard amino acids, 3-letter codes
RES_NAMES = [
    'ALA','ARG','ASN','ASP','CYS',
    'GLN','GLU','GLY','HIS','ILE',
    'LEU','LYS','MET','PHE','PRO',
    'SER','THR','TRP','TYR','VAL'
]

# matching 1-letter codes, same order as RES_NAMES
RES_NAMES_1 = 'ARNDCQEGHILKMFPSTWYV'

# 3-letter -> 1-letter and 1-letter -> 3-letter lookup tables
to1letter = {aaa:a for a,aaa in zip(RES_NAMES_1,RES_NAMES)}
to3letter = {a:aaa for a,aaa in zip(RES_NAMES_1,RES_NAMES)}

# heavy-atom names per residue type, same order as RES_NAMES;
# the position of an atom in its tuple is the atom's slot in the
# fixed 14-slot per-residue layout used by the (L,14,...) arrays below
ATOM_NAMES = [
    ("N", "CA", "C", "O", "CB"),  # ala
    ("N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"),  # arg
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"),  # asn
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"),  # asp
    ("N", "CA", "C", "O", "CB", "SG"),  # cys
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"),  # gln
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"),  # glu
    ("N", "CA", "C", "O"),  # gly
    ("N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"),  # his
    ("N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"),  # ile
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"),  # leu
    ("N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"),  # lys
    ("N", "CA", "C", "O", "CB", "CG", "SD", "CE"),  # met
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"),  # phe
    ("N", "CA", "C", "O", "CB", "CG", "CD"),  # pro
    ("N", "CA", "C", "O", "CB", "OG"),  # ser
    ("N", "CA", "C", "O", "CB", "OG1", "CG2"),  # thr
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"),  # trp
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"),  # tyr
    ("N", "CA", "C", "O", "CB", "CG1", "CG2")  # val
]

# (1-letter residue, atom slot index) -> (3-letter residue, atom name)
idx2ra = {(RES_NAMES_1[i],j):(RES_NAMES[i],a) for i in range(20) for j,a in enumerate(ATOM_NAMES[i])}

# (3-letter residue, atom name) -> atom slot index
aa2idx = {(r,a):i for r,atoms in zip(RES_NAMES,ATOM_NAMES)
          for i,a in enumerate(atoms)}
# map the C-terminal OXT onto the backbone O slot (index 3)
aa2idx.update({(r,'OXT'):3 for r in RES_NAMES})
def writepdb(f, xyz, seq, bfac=None):
    """Write one protein chain in PDB format to an open text file handle.

    Parameters
    ----------
    f : writable, seekable text file object (rewound and flushed, not closed)
    xyz : np.ndarray of shape (L, n_slots, 3) — per-residue atom coordinates
        in the ATOM_NAMES slot order; NaN marks a missing atom
    seq : 1-letter amino-acid sequence of length L
    bfac : optional (L, n_slots) array of per-atom B-factors; zeros if None

    Returns
    -------
    np.ndarray of residue indices for which a CA atom was written
    """
    f.seek(0)
    ctr = 1
    seq = str(seq)
    L = len(seq)
    if bfac is None:
        # bug fix: the original allocated a 1-D array np.zeros((L)), but
        # bfac is indexed below as bfac[i, j] -> IndexError; allocate a
        # per-residue, per-atom-slot 2-D array instead
        bfac = np.zeros((L, xyz.shape[1]))
    idx = []
    for i in range(L):
        for j, xyz_ij in enumerate(xyz[i]):
            key = (seq[i], j)
            # skip atom slots not defined for this residue type
            if key not in idx2ra:
                continue
            # skip atoms with missing (NaN) coordinates
            if np.isnan(xyz_ij).sum() > 0:
                continue
            r, a = idx2ra[key]
            f.write("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" % (
                "ATOM", ctr, a, r,
                "A", i+1, xyz_ij[0], xyz_ij[1], xyz_ij[2],
                1.0, bfac[i, j]))
            # record residues whose CA atom made it into the file
            if a == 'CA':
                idx.append(i)
            ctr += 1
    f.flush()
    return np.array(idx)
def TMalign(chainA, chainB):
    """Superpose two chains with the external TM-align binary and parse its output.

    Parameters
    ----------
    chainA, chainB : dicts with keys 'xyz' (L,14,3), 'seq', 'bfac'
        as produced by parse_mmcif()

    Returns
    -------
    (resAB, resBA) : two dicts with keys 'rmsd', 'seqid', 'tm',
        'aln' (2 x Naln array of aligned residue-index pairs),
        't' (translation) and 'u' (rotation); (None, None) on failure.
    """
    # temp files for the two input chains and the TM-align transformation;
    # /dev/shm keeps the I/O in memory
    fA = tempfile.NamedTemporaryFile(mode='w+t', dir='/dev/shm')
    fB = tempfile.NamedTemporaryFile(mode='w+t', dir='/dev/shm')
    mtx = tempfile.NamedTemporaryFile(mode='w+t', dir='/dev/shm')

    try:
        # create temp PDB files; keep track of residue indices which were saved
        idxA = writepdb(fA, chainA['xyz'], chainA['seq'], bfac=chainA['bfac'])
        idxB = writepdb(fB, chainB['xyz'], chainB['seq'], bfac=chainB['bfac'])

        # run TM-align (NOTE(review): hard-coded binary path — consider
        # making it configurable)
        tm = subprocess.Popen('/home/aivan/prog/TMalign %s %s -m %s' % (fA.name, fB.name, mtx.name),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              encoding='utf-8')
        stdout, stderr = tm.communicate()
        lines = stdout.split('\n')

        # if TMalign failed
        if len(stderr) > 0:
            return None, None

        # parse transformation: lines 2..4 of the '-m' file hold
        # "<row>  t[row]  u[row][0] u[row][1] u[row][2]"
        # (np.fromstring(..., sep=' ') is deprecated; split+np.array is equivalent)
        mtx.seek(0)
        tu = np.array(''.join(l[2:] for l in mtx.readlines()[2:5]).split(),
                      dtype=float).reshape((3, 4))
        t = tu[:, 0]
        u = tu[:, 1:]

        # parse rmsd, sequence identity, and two TM-scores
        # (fixed line positions in TM-align's standard output)
        rmsd = float(lines[16].split()[4][:-1])
        seqid = float(lines[16].split()[-1])
        tm1 = float(lines[17].split()[1])
        tm2 = float(lines[18].split()[1])

        # parse alignment: the two aligned sequences with '-' gaps
        seq1 = lines[-5]
        seq2 = lines[-3]
        ss1 = np.array(list(seq1.strip())) != '-'
        ss2 = np.array(list(seq2.strip())) != '-'
        mask = np.logical_and(ss1, ss2)
        # map aligned (non-gap) columns back to the residue indices
        # actually written by writepdb
        alnAB = np.stack((idxA[(np.cumsum(ss1)-1)[mask]],
                          idxB[(np.cumsum(ss2)-1)[mask]]))
        alnBA = np.stack((alnAB[1], alnAB[0]))

        resAB = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm1, 'aln': alnAB, 't': t, 'u': u}
        # B->A uses the inverse rigid transform: u' = u^T, t' = -u^T t
        resBA = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm2, 'aln': alnBA, 't': -u.T@t, 'u': u.T}
        return resAB, resBA
    finally:
        # bug fix: close the temp files on every path (the original leaked
        # all three open handles when TM-align failed and returned early)
        fA.close()
        fB.close()
        mtx.close()
def get_tm_pairs(chains):
    """Run TM-align for all ordered pairs of chains.

    Returns a dict mapping (chainID_A, chainID_B) to the TMalign result
    dict (or None when TM-align failed for that pair).
    """
    tm_pairs = {}
    for chA, chB in combinations(chains.keys(), r=2):
        resAB, resBA = TMalign(chains[chA], chains[chB])
        tm_pairs[(chA, chB)] = resAB
        tm_pairs[(chB, chA)] = resBA

    # add self-alignments: each residue with a resolved CA (slot 1 of the
    # mask) aligns to itself with perfect scores
    for ch, chain in chains.items():
        nres = chain['xyz'].shape[0]
        aln = np.arange(nres)[chain['mask'][:, 1]]
        tm_pairs[(ch, ch)] = {'rmsd': 0.0, 'seqid': 1.0, 'tm': 1.0,
                              'aln': np.stack((aln, aln))}
    return tm_pairs
def parseOperationExpression(expression):
    """Expand an mmCIF operation expression into a flat list of operation IDs.

    A token like '1-4' expands to ['1','2','3','4']; plain tokens are
    kept as-is. Surrounding parentheses and spaces are ignored.
    """
    ops = []
    for token in expression.strip('() ').split(','):
        token = token.strip()
        dash = token.find('-')
        if dash > 0:
            # numeric range 'start-stop', inclusive on both ends
            lo = int(token[:dash])
            hi = int(token[dash+1:])
            ops.extend(str(v) for v in range(lo, hi + 1))
        else:
            ops.append(token)
    return ops
def parseAssemblies(data, chids):
    """Extract biological-assembly transforms from a parsed mmCIF data block.

    Parameters
    ----------
    data : pdbx data container (parsed mmCIF block)
    chids : chain IDs (kept for signature compatibility; not used directly)

    Returns
    -------
    dict with keys 'asmb_chains', 'asmb_details', 'asmb_method', 'asmb_ids'
    plus one 'asmb_xform%d' entry per assembly holding a stack of 4x4
    homogeneous transformation matrices.
    """
    xforms = {'asmb_chains': None,
              'asmb_details': None,
              'asmb_method': None,
              'asmb_ids': None}

    assembly_data = data.getObj("pdbx_struct_assembly")
    assembly_gen = data.getObj("pdbx_struct_assembly_gen")
    oper_list = data.getObj("pdbx_struct_oper_list")
    if (assembly_data is None) or (assembly_gen is None) or (oper_list is None):
        return xforms

    # save all basic transformations (4x4 homogeneous matrices) in a dictionary
    opers = {}
    for k in range(oper_list.getRowCount()):
        key = oper_list.getValue("id", k)
        val = np.eye(4)
        for i in range(3):
            val[i, 3] = float(oper_list.getValue("vector[%d]" % (i+1), k))
            for j in range(3):
                val[i, j] = float(oper_list.getValue("matrix[%d][%d]" % (i+1, j+1), k))
        opers.update({key: val})

    chains, details, method, ids = [], [], [], []
    for index in range(assembly_gen.getRowCount()):
        # Retrieve the assembly_id attribute value for this assembly
        assemblyId = assembly_gen.getValue("assembly_id", index)
        ids.append(assemblyId)

        # Retrieve the operation expression for this assembly from the
        # oper_expression attribute.
        # fix: raw string for the regex (non-raw '\(' is an invalid escape
        # sequence, a SyntaxWarning on modern Python); also renamed from
        # 'oper_list' so it no longer shadows the category object above
        oper_expression = assembly_gen.getValue("oper_expression", index)
        op_groups = [parseOperationExpression(expression)
                     for expression in re.split(r'\(|\)', oper_expression) if expression]

        # chain IDs which the transform should be applied to
        chains.append(assembly_gen.getValue("asym_id_list", index))

        # some entries have fewer assembly rows than gen rows; clamp the index
        index_asmb = min(index, assembly_data.getRowCount()-1)
        details.append(assembly_data.getValue("details", index_asmb))
        method.append(assembly_data.getValue("method_details", index_asmb))

        if len(op_groups) == 1:
            xform = np.stack([opers[o] for o in op_groups[0]])
        elif len(op_groups) == 2:
            # two parenthesized groups: cartesian product of operations
            xform = np.stack([opers[o1] @ opers[o2]
                              for o1 in op_groups[0]
                              for o2 in op_groups[1]])
        else:
            print('Error in processing assembly')
            return xforms
        xforms.update({'asmb_xform%d' % (index): xform})

    xforms['asmb_chains'] = chains
    xforms['asmb_details'] = details
    xforms['asmb_method'] = method
    xforms['asmb_ids'] = ids
    return xforms
def parse_mmcif(filename):
    """Parse a gzipped mmCIF file into per-chain structures and metadata.

    Parameters
    ----------
    filename : path to a gzipped .cif file

    Returns
    -------
    (chains, metadata):
        chains   : dict chain_id -> {'seq' : 1-letter sequence,
                   'xyz' (L,14,3), 'mask' (L,14), 'bfac' (L,14),
                   'occ' (L,14)} in the fixed ATOM_NAMES slot layout
        metadata : dict with experiment method, deposition date,
                   resolution, chain IDs, sequences, PDB ID and
                   assembly transforms; ({}, {}) if no polymer entities
    """
    #print(filename)

    chains = {}  # 'chain_id' -> chain structure

    # read a gzipped .cif file
    data = []
    with gzip.open(filename, 'rt') as cif:
        reader = PdbxReader(cif)
        reader.read(data)
    data = data[0]  # first (and typically only) data block

    #
    # get sequences
    #

    # map chain entity to chain ID
    entity_poly = data.getObj('entity_poly')
    if entity_poly is None:
        # no polymer entities -> nothing to parse
        return {}, {}

    # map author (pdb_strand_id) chain IDs to label (asym_id) chain IDs
    # NOTE(review): if one pdb strand maps to several asym IDs only one
    # survives in the dict — confirm against multi-asym entries
    pdbx_poly_seq_scheme = data.getObj('pdbx_poly_seq_scheme')
    pdb2asym = dict({
        (r[pdbx_poly_seq_scheme.getIndex('pdb_strand_id')],
         r[pdbx_poly_seq_scheme.getIndex('asym_id')])
        for r in data.getObj('pdbx_poly_seq_scheme').getRowList()
    })

    # asym chain ID -> entity ID, polypeptide chains only
    chs2num = {pdb2asym[ch]: r[entity_poly.getIndex('entity_id')]
               for r in entity_poly.getRowList()
               for ch in r[entity_poly.getIndex('pdbx_strand_id')].split(',')
               if r[entity_poly.getIndex('type')] == 'polypeptide(L)'}

    # get canonical sequences for polypeptide chains
    num2seq = {r[entity_poly.getIndex('entity_id')]: r[entity_poly.getIndex('pdbx_seq_one_letter_code_can')].replace('\n', '')
               for r in entity_poly.getRowList()
               if r[entity_poly.getIndex('type')] == 'polypeptide(L)'}

    # map chain entity to amino acid sequence
    #entity_poly_seq = data.getObj('entity_poly_seq')
    #num2seq = dict.fromkeys(set(chs2num.values()), "")
    #for row in entity_poly_seq.getRowList():
    #    num = row[entity_poly_seq.getIndex('entity_id')]
    #    res = row[entity_poly_seq.getIndex('mon_id')]
    #    if num not in num2seq.keys():
    #        continue
    #    num2seq[num] += (to1letter[res] if res in to1letter.keys() else 'X')

    # modified residues: map non-standard residue names to their parents
    pdbx_struct_mod_residue = data.getObj('pdbx_struct_mod_residue')
    if pdbx_struct_mod_residue is None:
        modres = {}
    else:
        modres = dict({(r[pdbx_struct_mod_residue.getIndex('label_comp_id')],
                        r[pdbx_struct_mod_residue.getIndex('parent_comp_id')])
                       for r in pdbx_struct_mod_residue.getRowList()})
    for k, v in modres.items():
        print("# non-standard residue: %s %s" % (k, v))

    # initialize dict of chains: 14 atom slots per residue (see ATOM_NAMES),
    # coordinates/B-factors start as NaN, occupancy as zero
    for c, n in chs2num.items():
        seq = num2seq[n]
        L = len(seq)
        chains.update({c: {'seq': seq,
                           'xyz': np.full((L, 14, 3), np.nan, dtype=np.float32),
                           'mask': np.zeros((L, 14), dtype=bool),
                           'bfac': np.full((L, 14), np.nan, dtype=np.float32),
                           'occ': np.zeros((L, 14), dtype=np.float32)}})

    #
    # populate structures
    #

    # get indices of fields of interest
    atom_site = data.getObj('atom_site')
    i = {k: atom_site.getIndex(val) for k, val in [('atm', 'label_atom_id'),    # atom name
                                                   ('atype', 'type_symbol'),    # atom chemical type
                                                   ('res', 'label_comp_id'),    # residue name (3-letter)
                                                   #('chid', 'auth_asym_id'),   # chain ID
                                                   ('chid', 'label_asym_id'),   # chain ID
                                                   ('num', 'label_seq_id'),     # sequence number
                                                   ('alt', 'label_alt_id'),     # alternative location ID
                                                   ('x', 'Cartn_x'),            # xyz coords
                                                   ('y', 'Cartn_y'),
                                                   ('z', 'Cartn_z'),
                                                   ('occ', 'occupancy'),        # occupancy
                                                   ('bfac', 'B_iso_or_equiv'),  # B-factors
                                                   ('model', 'pdbx_PDB_model_num')  # model number (for multi-model PDBs, e.g. NMR)
                                                   ]}

    for a in atom_site.getRowList():

        # skip HETATM
        #if a[0] != 'ATOM':
        #    continue

        # skip hydrogens
        if a[i['atype']] == 'H':
            continue

        # skip if not a polypeptide
        if a[i['chid']] not in chains.keys():
            continue

        # parse atom: convert each raw field with its expected type
        atm, res, chid, num, alt, x, y, z, occ, Bfac, model = \
            (t(a[i[k]]) for k, t in (('atm', str), ('res', str), ('chid', str),
                                     ('num', int), ('alt', str),
                                     ('x', float), ('y', float), ('z', float),
                                     ('occ', float), ('bfac', float), ('model', int)))
        #print(atm, res, chid, num, alt, x, y, z, occ, Bfac, model)

        c = chains[chid]

        # remap residue to canonical: trust the canonical 1-letter sequence
        # over the per-atom residue name ('a' is reused here as the
        # sequence letter, shadowing the row variable above)
        a = c['seq'][num-1]
        if a in to3letter.keys():
            res = to3letter[a]
        else:
            # sequence letter is non-standard: if the residue is a known
            # modification, substitute its parent and patch the sequence
            # in place; otherwise keep only the GLY-like backbone atoms
            if res in modres.keys() and modres[res] in to1letter.keys():
                res = modres[res]
                c['seq'] = c['seq'][:num-1] + to1letter[res] + c['seq'][num:]
            else:
                res = 'GLY'

        # skip if not a standard residue/atom
        if (res, atm) not in aa2idx.keys():
            continue

        # skip everything except model #1
        if model > 1:
            continue

        # populate chains keeping, per slot, the highest-occupancy atom seen
        idx = (num-1, aa2idx[(res, atm)])
        if occ > c['occ'][idx]:
            c['xyz'][idx] = [x, y, z]
            c['mask'][idx] = True
            c['occ'][idx] = occ
            c['bfac'][idx] = Bfac

    #
    # metadata
    #

    # resolution: X-ray refinement record first, then cryo-EM reconstruction
    #if data.getObj('reflns') is not None:
    #    res = data.getObj('reflns').getValue('d_resolution_high',0)
    res = None
    if data.getObj('refine') is not None:
        try:
            res = float(data.getObj('refine').getValue('ls_d_res_high', 0))
        except:  # NOTE(review): bare except — deliberately best-effort, placeholder values become None
            res = None
    if (data.getObj('em_3d_reconstruction') is not None) and (res is None):
        try:
            res = float(data.getObj('em_3d_reconstruction').getValue('resolution', 0))
        except:  # NOTE(review): same best-effort pattern as above
            res = None

    chids = list(chains.keys())
    # per chain: [canonical sequence, sequence masked to residues that
    # have a complete N,CA,C backbone ('-' elsewhere)]
    seq = []
    for ch in chids:
        mask = chains[ch]['mask'][:, :3].sum(1) == 3
        ref_seq = chains[ch]['seq']
        atom_seq = ''.join([a if m else '-' for a, m in zip(ref_seq, mask)])
        seq.append([ref_seq, atom_seq])

    metadata = {
        'method': data.getObj('exptl').getValue('method', 0).replace(' ', '_'),
        'date': data.getObj('pdbx_database_status').getValue('recvd_initial_deposition_date', 0),
        'resolution': res,
        'chains': chids,
        'seq': seq,
        'id': data.getObj('entry').getValue('id', 0)
    }

    #
    # assemblies
    #
    asmbs = parseAssemblies(data, chains)
    metadata.update(asmbs)

    return chains, metadata
# ---------------- command-line driver ----------------
# usage: <script> <input.cif.gz> <output_prefix>
IN = sys.argv[1]
OUT = sys.argv[2]

chains, metadata = parse_mmcif(IN)
ID = metadata['id']

# all-vs-all TM-align comparisons between the chains
tm_pairs = get_tm_pairs(chains)
if 'chains' in metadata.keys() and len(metadata['chains']) > 0:
    chids = metadata['chains']
    # square matrix of [tm-score, seq-identity, rmsd] per chain pair;
    # failed alignments get the sentinel [0.0, 0.0, 999.9]
    tm = [[([0.0, 0.0, 999.9] if tm_pairs[(a, b)] is None else
            [tm_pairs[(a, b)][key] for key in ('tm', 'seqid', 'rmsd')])
           for b in chids]
          for a in chids]
    metadata.update({'tm': tm})

# save each chain as its own .pt file and print a FASTA-like summary line
for cid, chain in chains.items():
    nres = (chain['mask'][:, :3].sum(1) == 3).sum()
    print(">%s_%s %s %s %s %d %d\n%s" % (ID, cid, metadata['date'], metadata['method'],
                                         metadata['resolution'], len(chain['seq']), nres, chain['seq']))
    torch.save({key: (str(val) if key == 'seq' else torch.Tensor(val))
                for key, val in chain.items()}, f"{OUT}_{cid}.pt")

# save metadata, converting the numeric tables to tensors
meta_pt = {}
for key, val in metadata.items():
    if "asmb_xform" in key or key == "tm":
        val = torch.Tensor(val)
    meta_pt[key] = val
torch.save(meta_pt, f"{OUT}.pt")