ablang2 / extra_utils.py
hemantn's picture
Integrate utility files into main repository - make self-contained
712d350
import string, re
import numpy as np
def res_to_list(logits, seq):
    """
    Trim ``logits`` to the length of ``seq``.

    Returns the leading ``len(seq)`` entries of ``logits``; anything past
    the sequence length is discarded.
    """
    n_residues = len(seq)
    return logits[0:n_residues]
def res_to_seq(a, mode='mean'):
    """
    Collapse per-residue values into a single value per sequence.

    Only the first ``int(a[-1])`` entries are used, so trailing padding
    tokens are left out. NOTE(review): this assumes the caller stores the
    sequence length as the last element of ``a`` -- confirm against caller.

    mode:
        'sum'     -- sum of the kept entries of ``a``
        'mean'    -- mean of the kept entries of ``a``
        'restore' -- the first ``int(a[-1])`` entries of ``a[0]``
    Any other mode returns None.
    """
    if mode in ('sum', 'mean'):
        kept = a[0:int(a[-1])]
        if mode == 'sum':
            return kept.sum()
        return kept.mean()
    if mode == 'restore':
        return a[0][0:int(a[-1])]
    return None
def get_number_alignment(numbered_seqs):
    """
    Creates a number alignment from the anarci results.

    Collects every position key that carries a residue in at least one of
    ``numbered_seqs``. Returns those keys in the canonical order defined by
    ``get_max_alignment()``.

    Parameters
    ----------
    numbered_seqs : list of list of tuple
        One list per sequence of ``(position_key, chain, residue)`` triples.
        A residue of ``'-'`` marks a gap.

    Returns
    -------
    pandas.DataFrame
        Column ``0`` holds the position key. Column ``1`` holds the chain
        label from the first sequence that has a residue at that position.
        Rows follow the order of ``get_max_alignment()``.
    """
    import pandas as pd
    per_seq_frames = [
        pd.DataFrame(aligned_seq, columns=[0, 1, 'resi']) for aligned_seq in numbered_seqs
    ]
    # Drop gaps *before* de-duplicating positions. Otherwise a position that
    # is a gap in the first sequence but occupied in a later one would keep
    # only its gap row, and the gap filter would then discard the position.
    occupied = pd.concat(per_seq_frames).query("resi!='-'").drop_duplicates(subset=0)
    max_alignment = get_max_alignment()
    # Inner merge: keep only occupied positions, in max_alignment's order.
    return max_alignment.merge(occupied, left_on=0, right_on=0)[[0, 1]]
def get_max_alignment():
    """
    Create maximum possible alignment for sorting.

    Returns a single-column DataFrame (column ``0``) of
    ``(position, insertion)`` keys in this order:

    - a leading ``("<", "")`` marker;
    - positions 1..128, each expanded into 27 keys;
    - a trailing ``(">", "")`` marker.

    Most positions list the bare position (insertion ``' '``) first,
    followed by insertions ``'A'``..``'Z'``. Positions 33, 61 and 112 reverse
    this: ``'Z'``..``'A'`` come first and the bare position comes last.
    """
    import pandas as pd
    reversed_insertion_positions = (33, 61, 112)
    rows = [[("<", "")]]
    for position in range(1, 129):
        if position in reversed_insertion_positions:
            codes = list(string.ascii_uppercase[::-1]) + [' ']
        else:
            codes = [' '] + list(string.ascii_uppercase)
        # One single-cell row per key, so the frame has exactly one column.
        rows.extend([(position, code)] for code in codes)
    rows.append([(">", "")])
    return pd.DataFrame(rows)
def paired_msa_numbering(ab_seqs, fragmented = False, n_jobs = 10):
    """
    Number paired ``"heavy|light"`` sequences and build a joint alignment.

    Every ``'<'`` and ``'>'`` is stripped from each input, which is then
    split on ``'|'``. Part 0 is numbered as chain ``'H'`` and part 1 as
    chain ``'L'``, both via ``unpaired_msa_numbering``. The two results
    are then joined with a ``'|'`` separator.

    Returns
    -------
    (numbered_seqs, seqs, number_alignment)
        numbered_seqs : heavy triples + ``(("|", ""), "|", "|")`` + light
            triples, one list per pair.
        seqs : ``"<heavy>|<light>"`` strings built from the per-chain results.
        number_alignment : heavy alignment, a ``[("|", ""), "|"]`` row, then
            the light alignment, re-indexed from 0.
    """
    import pandas as pd
    split_pairs = []
    for pair in ab_seqs:
        split_pairs.append(pair.replace(">", "").replace("<", "").split("|"))

    heavy_numbered, heavy_seqs, heavy_alignment = unpaired_msa_numbering(
        [parts[0] for parts in split_pairs], 'H', fragmented=fragmented, n_jobs=n_jobs
    )
    light_numbered, light_seqs, light_alignment = unpaired_msa_numbering(
        [parts[1] for parts in split_pairs], 'L', fragmented=fragmented, n_jobs=n_jobs
    )

    separator_row = pd.DataFrame([[("|",""), "|"]])
    number_alignment = pd.concat(
        [heavy_alignment, separator_row, light_alignment]
    ).reset_index(drop=True)

    seqs = ["|".join((h_seq, l_seq)) for h_seq, l_seq in zip(heavy_seqs, light_seqs)]

    separator_entry = (("|",""), "|", "|")
    numbered_seqs = [
        h_num + [separator_entry] + l_num
        for h_num, l_num in zip(heavy_numbered, light_numbered)
    ]
    return numbered_seqs, seqs, number_alignment
def unpaired_msa_numbering(seqs, chain = 'H', fragmented = False, n_jobs = 10):
    """
    Number single-chain sequences and build their shared number alignment.

    Returns
    -------
    (numbered_seqs, seqs, number_alignment)
        numbered_seqs : output of ``number_with_anarci`` for ``seqs``.
        seqs : per sequence, the residue column of its numbered triples
            joined into a string, with ``'-'`` removed.
        number_alignment : ``get_number_alignment`` of the numbered
            sequences, with column ``1`` overwritten by ``chain``.
    """
    numbered = number_with_anarci(seqs, chain=chain, fragmented=fragmented, n_jobs=n_jobs)
    alignment = get_number_alignment(numbered)
    alignment[1] = chain
    plain_seqs = []
    for numbered_seq in numbered:
        residues = ''.join(entry[2] for entry in numbered_seq)
        plain_seqs.append(residues.replace('-', ''))
    return numbered, plain_seqs, alignment
def number_with_anarci(seqs, chain = 'H', fragmented = False, n_jobs = 1):
    """
    Number sequences with ANARCI.

    ANARCI is run with the IMGT scheme, restricted to human and mouse.

    Parameters
    ----------
    seqs : list of str
        Sequences to number. ANARCI receives them as ``[index, sequence]``
        pairs.
    chain : str
        Label stored as the second element of every output triple.
    fragmented : bool
        If False, each numbered sequence is wrapped with the boundary entries
        ``(("<", ""), chain, "<")`` and ``((">", ""), chain, ">")``.
    n_jobs : int
        Forwarded to ANARCI as ``ncpu``.

    Returns
    -------
    list of list of tuple
        One list per input sequence of ``(position_key, chain, residue)``
        triples. Positions whose residue is ``'-'`` are dropped. Only the
        first numbering entry ANARCI reports for a sequence is used
        (presumably its first detected domain).

    Raises
    ------
    ValueError
        If ANARCI returns no numbering for a sequence.
    """
    import anarci
    # ANARCI takes (name, sequence) pairs; the positional index is the name.
    named_seqs = [[idx, seq] for idx, seq in enumerate(seqs)]
    anarci_out = anarci.run_anarci(
        named_seqs,
        ncpu=n_jobs,
        scheme='imgt',
        allowed_species=['human', 'mouse'],
    )
    numbered_seqs = []
    # anarci_out[1] holds one numbering result per input, in input order.
    for (idx, seq), onarci in zip(named_seqs, anarci_out[1]):
        if not onarci:
            # ANARCI reports None when it finds no domain. Fail loudly here
            # instead of with a "NoneType is not subscriptable" TypeError.
            raise ValueError(
                f"ANARCI could not number sequence at index {idx} "
                f"(chain {chain!r}): {seq!r}"
            )
        numbered_seq = [
            (entry[0], chain, entry[1]) for entry in onarci[0][0] if entry[1] != '-'
        ]
        if fragmented:
            numbered_seqs.append(numbered_seq)
        else:
            numbered_seqs.append(
                [(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")]
            )
    return numbered_seqs
def create_alignment(res_embeds, numbered_seqs, seq, number_alignment):
    """
    Lay one sequence's residue embeddings onto the rows of ``number_alignment``.

    Parameters
    ----------
    res_embeds : array
        Per-residue rows; only the first ``len(seq)`` rows are used.
        NOTE(review): presumably shaped (length, dim) -- confirm.
    numbered_seqs : list of tuple
        ``(position_key, chain, residue)`` triples of a single sequence.
    seq : str
        The sequence; its length bounds how many embedding rows are used.
    number_alignment : pandas.DataFrame
        Alignment rows keyed by columns ``0`` (position key) and ``1`` (chain).

    Returns
    -------
    numpy.ndarray
        One row per alignment row. Each row is the embedding values followed
        by the residue. Positions absent from this sequence get zeros and
        ``'-'``.
    """
    import pandas as pd
    residue_table = pd.DataFrame(numbered_seqs)
    merged = number_alignment.merge(residue_table, how='left', on=[0, 1])
    sequence_alignment = merged.fillna('-')[2]

    gap_rows = np.flatnonzero(sequence_alignment.values == '-')
    # np.insert indexes into the *unpadded* array. Shift each gap back by
    # the number of gaps that precede it in the alignment.
    insert_at = gap_rows - np.arange(gap_rows.size)
    padded = np.insert(res_embeds[:len(seq)], insert_at, 0, axis=0)

    aligned_embeds = pd.DataFrame(padded)
    return pd.concat([aligned_embeds, sequence_alignment], axis=1).values
def get_spread_sequences(seq, spread, start_position):
    """
    Build left-masked candidate copies of ``seq``.

    One candidate is made for every offset from ``start_position - 8`` to
    ``start_position + 2`` inclusive (11 candidates). Each candidate is
    ``seq`` prefixed with that many ``'*'``; a negative offset adds no
    prefix. Rationale from the original notes: 8 positions shorter
    (position 10 + max CDR1 gap of 7), up to 2 positions longer (possible
    insertions). ``spread`` is accepted but not used.
    """
    offsets = range(start_position - 8, start_position + 3)
    return np.array(['*' * offset + seq for offset in offsets])
def get_sequences_from_anarci(out_anarci, max_position, spread):
    """
    Ensures correct masking on each side of sequence.

    Parses a textual ANARCI numbering and returns candidate sequences.
    Each candidate is right-padded with '*' up to ``max_position`` and
    left-padded with '*' by a range of offsets (see get_spread_sequences).

    Parameters
    ----------
    out_anarci : str
        Text of an ANARCI numbering, or the sentinel ``'ANARCI_error'``.
        NOTE(review): the regexes below assume the repr of a list of
        ``((position, insertion), residue)`` entries, e.g.
        ``[((1, ' '), 'E'), ((2, ' '), 'V'), ...]`` -- confirm against caller.
    max_position : int
        The sequence is right-padded with one '*' for each position between
        the last number found in ``out_anarci`` and ``max_position``.
    spread : int
        Number of ``'ANARCI-ERR'`` entries returned on the error path.
        It is also forwarded to get_spread_sequences.
        NOTE(review): get_spread_sequences in this file does not use it, so
        the success-path length need not equal ``spread`` -- confirm callers.

    Returns
    -------
    numpy.ndarray of str
        ``['ANARCI-ERR'] * spread`` for the sentinel input, otherwise the
        result of get_spread_sequences.

    Raises
    ------
    AttributeError
        If either ``re.search`` finds no match (``.group()`` on None).
    """
    # Sentinel input: return one error marker per requested candidate.
    if out_anarci == 'ANARCI_error':
        return np.array(['ANARCI-ERR']*spread)
    # Last integer in the text: reverse the string, take the first run of
    # digits, then reverse that run back into normal order.
    end_position = int(re.search(r'\d+', out_anarci[::-1]).group()[::-1])
    # Fixes ANARCI error of poor numbering of the CDR1 region
    # Roughly: find the leftmost entry that starts a stretch of at least four
    # entries whose text contains no '-' (gap). The greedy [^-]+ may span
    # further entries. Take that entry's position number and subtract 1.
    start_position = int(re.search(r'\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+',
                                   out_anarci).group().split(',')[0]) - 1
    # Residue letters: each "), '<char>" marks one entry's residue. Gap
    # entries ('-') do not match [A-Z*], so they are skipped here.
    sequence = "".join(re.findall(r"(?i)[A-Z*]", "".join(re.findall(r'\),\s\'[A-Z*]', out_anarci))))
    # 'X' residues are masked as '*' (presumably unknown residues). Then
    # right-pad with '*' from the last numbered position up to max_position.
    # (replace('-','') is effectively a no-op: gaps were already excluded.)
    sequence_j = ''.join(sequence).replace('-','').replace('X','*') + '*'*(max_position-int(end_position))
    return get_spread_sequences(sequence_j, spread, start_position)