# Fill-Mask
# Transformers
# Safetensors
# esm
# Sophia Vincoff
# mutation prediction discovery and recovery
# 3efa812
##### Discover mutations in new sequences with the FusOn-pLM masked language model.
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
import os
import pickle
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
import pandas as pd
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
from fuson_plm.utils.embedding import load_esm2_type
from fuson_plm.utils.visualizing import set_font
from fuson_plm.benchmarking.mutation_prediction.recovery.recover import check_env_variables, predict_positionwise_mutations
from fuson_plm.benchmarking.mutation_prediction.discovery.plot import plot_conservation_heatmap, plot_full_heatmap
def check_seq_inputs(sequence, AAs_tokens):
    """Validate that ``sequence`` is a non-empty string of allowed residue letters.

    Args:
        sequence: Protein sequence to check.
        AAs_tokens: Iterable of the single-letter residue codes that are accepted.

    Returns:
        None; the function only raises on invalid input.

    Raises:
        ValueError: If ``sequence`` is empty or whitespace-only, or contains any
            character that is not in ``AAs_tokens``.
    """
    if not sequence.strip():
        raise ValueError("Error: The sequence input is empty. Please enter a valid protein sequence.")
    # Build the lookup set once so each per-character membership test is O(1).
    allowed = set(AAs_tokens)
    if any(char not in allowed for char in sequence):
        raise ValueError("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")
def check_domain_bounds(domain_bounds, sequence=None):
    """Parse and validate 1-indexed, inclusive domain bounds.

    Args:
        domain_bounds: Mapping with ``'start'`` and ``'end'`` entries; each must be
            convertible with ``int()``.
        sequence: Optional sequence the bounds refer to. When given, neither bound
            may exceed ``len(sequence)``.

    Returns:
        ``(start, end)`` as ints.

    Raises:
        ValueError: If a bound cannot be converted to an integer, ``start >= end``,
            ``start`` is 0, either bound is not positive, or (when ``sequence`` is
            given) a bound exceeds the sequence length.
    """
    # Parse first; the validations below must run AFTER parsing succeeds.
    try:
        start = int(domain_bounds['start'])
        end = int(domain_bounds['end'])
    except (TypeError, ValueError) as err:
        # TypeError covers inputs such as None, which int() rejects without a ValueError.
        raise ValueError("Error: Start and end indices must be integers.") from err
    if start >= end:
        raise ValueError("Start index must be smaller than end index.")
    # Checked before the generic positivity test so a 0 start gets the more specific hint.
    if start == 0 and end != 0:
        raise ValueError("Indexing starts at 1. Please enter valid domain bounds.")
    if start <= 0 or end <= 0:
        raise ValueError("Domain bounds must be positive integers. Please enter valid domain bounds.")
    if sequence is not None and (start > len(sequence) or end > len(sequence)):
        raise ValueError("Domain bounds exceed sequence length.")
    return start, end
def check_n_input(n):
    """Validate the number of top predictions requested per position.

    Args:
        n: Requested number of predictions.

    Returns:
        None; the function only raises on invalid input.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError("Choose N>=1")
def _top_n_amino_acids(scores, n, tokenizer, AAs_tokens, descending=True):
    """Return up to ``n`` amino-acid tokens ranked by ``scores``.

    ``scores`` is a 1-D tensor indexed by vocabulary id. Ids are visited in sorted
    order (highest first when ``descending``), decoded with ``tokenizer``, and kept
    only if the decoded string is in ``AAs_tokens``, so non-amino-acid tokens are
    skipped.
    """
    picked = []
    for token_index in torch.argsort(scores, dim=0, descending=descending):
        decoded_token = tokenizer.decode([token_index.item()])
        if decoded_token in AAs_tokens:
            picked.append(decoded_token)
            if len(picked) == n:
                break
    return picked


def predict_positionwise_mutations(sequence, domain_bounds, n, model, tokenizer, device,
                                   conservation_threshold=0.7):
    """Mask each residue of a domain in turn and score substitutions with a masked LM.

    For every 1-indexed position ``start..end`` (inclusive) from ``domain_bounds``,
    the residue is replaced by ``<mask>`` and the sequence is run through ``model``.
    The logits at the mask then rank amino acids three ways:
    - by raw logit;
    - by highest log-likelihood ratio (LLR) relative to the original residue;
    - by lowest LLR.

    NOTE(review): this definition shadows the ``predict_positionwise_mutations``
    imported from ``recovery.recover`` at the top of this module.

    Args:
        sequence: Protein sequence; every character must be one of the 20
            amino-acid letters in ``AAs_tokens`` below.
        domain_bounds: Mapping with ``'start'`` and ``'end'`` residue numbers
            (1-indexed, inclusive).
        n: Number of amino acids kept in each top-n list (must be >= 1).
        model: Masked-LM model, called as ``model(**inputs).logits``.
        tokenizer: Tokenizer matching ``model``. It is called on the masked string
            and must expose ``mask_token_id`` and ``decode``.
        device: Device the tokenized inputs are moved to.
        conservation_threshold: A position is marked conserved (1) when the softmax
            probability of its original residue is greater than this value, and
            not conserved (0) otherwise.

    Returns:
        dict with the following keys:
        - ``start``, ``end``: the validated bounds.
        - ``originals_logits``: per-position softmax probability of the original
          residue.
        - ``conservation_likelihoods``: ``{(residue, 0-based index): 0 or 1}``.
        - ``logits``: model logits from the last forward pass.
        - ``filtered_indices``: vocabulary ids treated as amino acids.
        - ``top_n_mutations``, ``top_n_advantage_mutations``,
          ``top_n_disadvantage_mutations``: ``{(residue, 0-based index): [AA, ...]}``.
        - ``logits_for_each_AA``, ``llrs_for_each_AA``: one array per scored
          position, shaped (number of mask tokens, 20); presumably (1, 20).

    Raises:
        Exception: from the sequence / bounds / ``n`` checks on invalid input.
        ValueError: if the bounds exceed the sequence length or select no residues.
    """
    # Define amino acids and their token indices
    AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
    AAs_tokens_indices = {'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14,
                          'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23}
    # Vocabulary ids 4..23 are exactly the amino-acid ids in AAs_tokens_indices.
    # Defined before the loop so it always exists for the returned dict.
    filtered_indices = list(range(4, 23 + 1))

    # checking all inputs for validity
    log_update("\nChecking validity of sequence input, domain bounds, and N mutations")
    check_seq_inputs(sequence, AAs_tokens)
    start, end = check_domain_bounds(domain_bounds)
    check_n_input(n)
    if start > len(sequence) or end > len(sequence):
        raise ValueError("Domain bounds exceed sequence length.")

    # Residues are 1-indexed and Python is 0-indexed: position `start` is index
    # start-1, and the inclusive `end` becomes the exclusive range stop. Clipping
    # keeps the loop inside 0 <= i < len(sequence), whatever the bounds.
    positions = range(max(start - 1, 0), min(end, len(sequence)))
    if not positions:
        raise ValueError("Domain bounds select no residues. Please enter valid domain bounds.")

    # place to store top n mutations and all logits
    top_n_mutations = {}
    top_n_advantage_mutations = {}
    top_n_disadvantage_mutations = {}
    logits_for_each_AA = []
    llrs_for_each_AA = []
    # storage for the conservation heatmap
    originals_logits = []
    conservation_likelihoods = {}

    log_update("\nCalculating mutations. Printing currently masked position and mutation results.")
    for i in positions:
        # isolate original residue and its vocabulary id
        original_residue = sequence[i]
        original_residue_index = AAs_tokens_indices[original_residue]
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]

        # log a short window of the masked sequence around position i
        masked_seq_list = list(sequence[:i]) + ['<mask>'] + list(sequence[i+1:])
        log_starti = i - min(5, i)
        log_endi = i + 5
        log_update(f"\t{i+1}: residue = {original_residue}, masked sequence preview (pos {log_starti+1}-{log_endi}) = {''.join(masked_seq_list[log_starti:log_endi])}")

        # forward pass on the masked sequence
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=len(masked_seq)+2)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits

        # logits at the mask position(s); presumably shape [1, vocab_size]
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_token_index, :]
        # Detached CPU view. This replaces torch.tensor(existing_tensor), which
        # copy-constructs and emits a UserWarning.
        mask_logits_cpu = mask_token_logits.detach().cpu()

        # raw amino-acid logits for the full heatmap
        logits_array = mask_logits_cpu.numpy()
        logits_for_each_AA.append(logits_array[:, filtered_indices])

        # LLR of every token vs. the true residue under the mask, vectorized
        log_probabilities = F.log_softmax(mask_logits_cpu, dim=-1).squeeze(0)
        log_prob_og = log_probabilities[original_residue_index]
        llrs = log_probabilities - log_prob_og
        llrs_for_each_AA.append(np.array([llrs[filtered_indices].numpy()]))

        # top-n amino acids by raw logit, by highest LLR and by lowest LLR
        mutation = _top_n_amino_acids(mask_token_logits.squeeze(0), n, tokenizer, AAs_tokens, descending=True)
        top_n_mutations[(sequence[i], i)] = mutation
        log_update(f"\t\ttop {n} predicted AAs: {','.join(mutation)}")

        advantage_mutation = _top_n_amino_acids(llrs, n, tokenizer, AAs_tokens, descending=True)
        top_n_advantage_mutations[(sequence[i], i)] = advantage_mutation
        log_update(f"\t\ttop {n} predicted advantageous mutations: {','.join(advantage_mutation)}")

        disadvantage_mutation = _top_n_amino_acids(llrs, n, tokenizer, AAs_tokens, descending=False)
        top_n_disadvantage_mutations[(sequence[i], i)] = disadvantage_mutation
        log_update(f"\t\ttop {n} predicted disadvantageous mutations: {','.join(disadvantage_mutation)}")

        # softmax probability of the original residue, used for conservation
        normalized_mask_token_logits = np.squeeze(F.softmax(mask_logits_cpu, dim=-1).numpy())
        originals_logit = normalized_mask_token_logits[original_residue_index]
        originals_logits.append(originals_logit)
        # conserved when the original residue's probability exceeds the threshold
        if originals_logit > conservation_threshold:
            conservation_likelihoods[(original_residue, i)] = 1
            log_update("\t\tConserved position")
        else:
            conservation_likelihoods[(original_residue, i)] = 0
            log_update("\t\tNot conserved position")

    return {
        'start': start,
        'end': end,
        'originals_logits': originals_logits,
        'conservation_likelihoods': conservation_likelihoods,
        'logits': logits,
        'filtered_indices': filtered_indices,
        'top_n_mutations': top_n_mutations,
        'top_n_advantage_mutations': top_n_advantage_mutations,
        'top_n_disadvantage_mutations': top_n_disadvantage_mutations,
        'logits_for_each_AA': logits_for_each_AA,
        'llrs_for_each_AA': llrs_for_each_AA
    }
def find_top_3(d):
    """Return the three highest-valued entries of ``d`` as a new dict.

    Entries are ordered from highest to lowest value. Ties keep ``d``'s insertion
    order, because ``sorted`` stays stable with ``reverse=True``. A dict with fewer
    than three entries returns all of them, and an empty dict returns ``{}``.
    The previous DataFrame-based version raised KeyError on an empty dict.
    """
    ranked = sorted(d.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:3])
def make_full_results_df(mutation_results, tokenizer, original_sequence):
    """Build a per-residue table of amino-acid probabilities for the scored domain.

    The amino-acid logits at each masked position are renormalized with a softmax
    over those amino-acid tokens only, so every row sums to 1.

    Args:
        mutation_results: Dict returned by ``predict_positionwise_mutations``. Reads
            ``logits_for_each_AA``, ``filtered_indices`` and ``start``; ``start``
            defaults to 1 if absent.
        tokenizer: Tokenizer used to decode ``filtered_indices`` into column names.
        original_sequence: The full sequence that was scored.

    Returns:
        DataFrame with these columns:
        - ``Residue``: 1-indexed position in ``original_sequence``, starting at
          the domain start.
        - ``original_residue``
        - ``original_residue_logit``: renormalized probability of the original
          residue.
        - ``all_logits``: ``{token: probability}``.
        - ``top_3_mutations``: the three most probable tokens.
    """
    logits_for_each_AA = mutation_results['logits_for_each_AA']
    filtered_indices = mutation_results['filtered_indices']

    # Only the amino-acid vocabulary ids are needed as column names.
    filtered_tokens = [tokenizer.decode([idx]) for idx in filtered_indices]

    # One row per scored residue; softmax across the amino-acid columns.
    all_logits_array = np.vstack(logits_for_each_AA)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    df = pd.DataFrame(normalized_logits_array, columns=filtered_tokens)

    # Rows cover the domain only, not the whole sequence. Number them from the
    # domain start and align them with the matching slice of original_sequence.
    # The old version assigned the full sequence and failed for partial domains.
    first_residue = max(int(mutation_results.get('start', 1)), 1)
    df.index = list(range(first_residue, first_residue + len(df)))
    df['all_logits'] = df[filtered_tokens].to_dict(orient='index')
    df['top_3_mutations'] = df['all_logits'].apply(find_top_3)
    df['original_residue'] = list(original_sequence[first_residue - 1:first_residue - 1 + len(df)])
    df['original_residue_logit'] = df.apply(lambda row: row['all_logits'][row['original_residue']], axis=1)
    df = df[['original_residue', 'original_residue_logit', 'all_logits', 'top_3_mutations']]
    df = df.reset_index().rename(columns={'index': 'Residue'})
    return df
def make_small_results_df(mutation_results):
    """Summarize top-N predictions and conservation calls, one row per scored residue.

    Args:
        mutation_results: Dict with ``top_n_mutations`` and
            ``conservation_likelihoods``. Both are keyed by
            ``(original_residue, 0-based position)``.

    Returns:
        DataFrame with these columns:
        - ``Original Residue``
        - ``Predicted Residues``: comma-joined.
        - ``Conserved``: the value from ``conservation_likelihoods``.
        - ``Position``: 1-indexed.

    Raises:
        ValueError: If a predicted position has no conservation entry.
    """
    conservation_likelihoods = mutation_results['conservation_likelihoods']
    top_n_mutations = mutation_results['top_n_mutations']

    original_residues = []
    mutations = []
    conserved = []
    positions = []
    for key, predicted in top_n_mutations.items():
        original_residue, position = key
        # Look up by the shared key. The old code paired the two dicts by
        # iteration order and silently dropped mismatches.
        try:
            conserved_flag = conservation_likelihoods[key]
        except KeyError:
            raise ValueError(
                f"No conservation result for residue {original_residue} at position {position + 1}"
            ) from None
        original_residues.append(original_residue)
        mutations.append(','.join(predicted))
        conserved.append(conserved_flag)
        positions.append(position + 1)

    return pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Conserved': conserved,
        'Position': positions
    })
def _resolve_fuson_ckpt_path(fuson_ckpt):
    """Translate ``config.FUSON_PLM_CKPT`` into a checkpoint directory path.

    The string ``"FusOn-pLM"`` selects ``../../../..``. Any other value is read as
    a mapping whose first key is a model name and first value an epoch, under
    ``../../training/checkpoints``.
    """
    if fuson_ckpt == "FusOn-pLM":
        return "../../../.."
    model_name = list(fuson_ckpt.keys())[0]
    epoch = list(fuson_ckpt.values())[0]
    return f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'


def _load_input_table():
    """Return the fusions to process, one row each.

    Reads ``config.PATH_TO_INPUT_FILE`` when it is set and exists on disk.
    Otherwise builds a single-row table from the single-sequence config values.
    """
    path = config.PATH_TO_INPUT_FILE
    if (path is not None) and os.path.exists(path):
        return pd.read_csv(path)
    return pd.DataFrame(
        data={
            'fusion_name': [config.FUSION_NAME],
            'full_fusion_sequence': [config.FULL_FUSION_SEQUENCE],
            'start_residue_index': [config.START_RESIDUE_INDEX],
            'end_residue_index': [config.END_RESIDUE_INDEX],
            'n': [config.N]
        }
    )


def main():
    """Run mutation discovery for every configured fusion and save the results.

    Creates ``results/<local time>/`` and logs the run to
    ``mutation_discovery_log.txt`` inside it. For each fusion, a subfolder named
    after ``fusion_name`` receives these files:
    - ``raw_mutation_results.pkl``
    - ``full_heatmap.png``
    - ``conservation_heatmap.png``
    - ``predicted_tokens.csv``
    - ``full_results_with_logits.csv``
    """
    # Each run gets its own timestamped folder under results/
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    # Predict mutations, writing results to a log inside of the output directory
    with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
        print_configpy(config)
        # Make sure environment variables are set correctly
        check_env_variables()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"Using device: {device}")

        # Load fuson as AutoModelForMaskedLM
        fuson_ckpt_path = _resolve_fuson_ckpt_path(config.FUSON_PLM_CKPT)
        log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
        fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
        fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
        fuson_model.to(device)
        fuson_model.eval()

        input_file = _load_input_table()
        log_update(f"\nThere are {len(input_file)} total sequences on which to perform mutation discovery. Fusion Genes:")
        log_update("\t" + "\n\t".join(input_file['fusion_name']))

        # One subfolder of outputs per input row
        for _, row in input_file.iterrows():
            fusion_name = row['fusion_name']
            full_fusion_sequence = row['full_fusion_sequence']
            sub_output_dir = f"{output_dir}/{fusion_name}"
            os.makedirs(sub_output_dir, exist_ok=True)

            # Predict positionwise mutations
            domain_bounds = {'start': row['start_residue_index'], 'end': row['end_residue_index']}
            mutation_results = predict_positionwise_mutations(full_fusion_sequence, domain_bounds, row['n'],
                                                              fuson_model, fuson_tokenizer, device)

            # Save raw mutation results
            with open(f"{sub_output_dir}/raw_mutation_results.pkl", "wb") as f:
                pickle.dump(mutation_results, f)

            # Plot the heatmaps
            plot_full_heatmap(mutation_results, fuson_tokenizer,
                              fusion_name=fusion_name, save_path=f"{sub_output_dir}/full_heatmap.png")
            plot_conservation_heatmap(mutation_results,
                                      fusion_name=fusion_name, save_path=f"{sub_output_dir}/conservation_heatmap.png")

            # Tabular summaries
            small_mutation_results_df = make_small_results_df(mutation_results)
            small_mutation_results_df.to_csv(f"{sub_output_dir}/predicted_tokens.csv", index=False)
            full_mutation_results_df = make_full_results_df(mutation_results, fuson_tokenizer, full_fusion_sequence)
            full_mutation_results_df.to_csv(f"{sub_output_dir}/full_results_with_logits.csv", index=False)
# Run mutation discovery only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()