import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
import os
import pickle

# Set the visible GPU(s) before torch is imported so the setting takes effect
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import pandas as pd
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F

from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
from fuson_plm.utils.embedding import load_esm2_type
from fuson_plm.utils.visualizing import set_font
# Only check_env_variables is needed here; predict_positionwise_mutations is defined in this module
from fuson_plm.benchmarking.mutation_prediction.recovery.recover import check_env_variables
from fuson_plm.benchmarking.mutation_prediction.discovery.plot import plot_conservation_heatmap, plot_full_heatmap


def check_seq_inputs(sequence, AAs_tokens):
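    """Raise an exception if the sequence is empty or contains characters outside the 20 canonical amino acids."""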
    if not sequence.strip():
        raise Exception("Error: The sequence input is empty. Please enter a valid protein sequence.")
    if any(char not in AAs_tokens for char in sequence):
        raise Exception("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")


def check_domain_bounds(domain_bounds, sequence):
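    """Validate 1-indexed, inclusive domain bounds against the sequence and return (start, end) as integers."""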
    try:
        start = int(domain_bounds['start'])
        end = int(domain_bounds['end'])
    except ValueError:
        raise Exception("Error: Start and end indices must be integers.")
    if start <= 0 or end <= 0:
        raise Exception("Domain bounds must be positive integers; indexing starts at 1. Please enter valid domain bounds.")
    if start >= end:
        raise Exception("Start index must be smaller than end index.")
    if start > len(sequence) or end > len(sequence):
        raise Exception("Domain bounds exceed sequence length.")
    return start, end


def check_n_input(n):
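    """Raise an exception if fewer than one mutation per position is requested."""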
    if n < 1:
        raise Exception("Choose N>=1")


def predict_positionwise_mutations(sequence, domain_bounds, n, model, tokenizer, device):
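    """
    Mask each residue in the 1-indexed domain [start, end] one at a time and score all 20
    amino acids at that position with the masked-language model.

    For every masked position, records the top-n residues by raw logit, the top-n most
    advantageous and most disadvantageous substitutions by log-likelihood ratio (LLR)
    relative to the original residue, the softmax probability of the original residue,
    and a conservation call (original-residue probability > 0.7).
    """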
    # The 20 canonical amino acids and their token indices in the model vocabulary
    AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
    AAs_tokens_indices = {'L' : 4, 'A' : 5, 'G' : 6, 'V': 7, 'S' : 8, 'E' : 9, 'R' : 10, 'T' : 11, 'I': 12, 'D' : 13, 'P' : 14,
                          'K' : 15, 'Q' : 16, 'N' : 17, 'F' : 18, 'Y' : 19, 'M' : 20, 'H' : 21, 'W' : 22, 'C' : 23}

    log_update("\nChecking validity of sequence input, domain bounds, and N mutations")
    check_seq_inputs(sequence, AAs_tokens)
    start, end = check_domain_bounds(domain_bounds, sequence)
    check_n_input(n)

    # Convert 1-indexed, inclusive bounds to 0-indexed Python indices
    start_index = start - 1
    end_index = end

    top_n_mutations = {}
    top_n_advantage_mutations = {}
    top_n_disadvantage_mutations = {}
    logits_for_each_AA = []
    llrs_for_each_AA = []

    originals_logits = []
    conservation_likelihoods = {}

    log_update("\nCalculating mutations. Printing currently masked position and mutation results.")
    for i in range(len(sequence)):
        if start_index <= i <= (end_index - 1):
            original_residue = sequence[i]
            original_residue_index = AAs_tokens_indices[original_residue]
            masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]

            # Log a short preview of the masked sequence around the current position
            masked_seq_list = list(sequence[:i]) + ['<mask>'] + list(sequence[i+1:])
            log_starti = i-min(5, i)
            log_endi = i+5
            log_update(f"\t{i+1}: residue = {original_residue}, masked sequence preview (pos {log_starti+1}-{log_endi}) = {''.join(masked_seq_list[log_starti:log_endi])}")

            inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=len(masked_seq)+2)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                logits = model(**inputs).logits

            # Logits at the masked position
            mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = logits[0, mask_token_index, :]

            logits_array = mask_token_logits.cpu().numpy()

            # Keep only the 20 amino-acid columns (vocabulary indices 4-23)
            filtered_indices = list(range(4, 23 + 1))
            filtered_logits = logits_array[:, filtered_indices]
            logits_for_each_AA.append(filtered_logits)
            # Log-likelihood ratio (LLR) of every token relative to the original residue
            log_probabilities = F.log_softmax(mask_token_logits.cpu(), dim=-1).squeeze(0)
            log_prob_og = log_probabilities[original_residue_index]
            llrs = log_probabilities - log_prob_og

            filtered_llrs = llrs[filtered_indices].numpy()
            filtered_llrs_array = np.array([filtered_llrs])
            llrs_for_each_AA.append(filtered_llrs_array)
            # Top n residues ranked by raw logit at the masked position
            all_tokens_logits = mask_token_logits.squeeze(0)
            top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
            mutation = []

            for token_index in top_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                if decoded_token in AAs_tokens:
                    mutation.append(decoded_token)
                    if len(mutation) == n:
                        break
            top_n_mutations[(sequence[i], i)] = mutation
            log_update(f"\t\ttop {n} predicted AAs: {','.join(mutation)}")
            # Top n residues ranked by LLR (most advantageous substitutions)
            top_advantage_tokens_indices = torch.argsort(llrs, dim=0, descending=True)
            advantage_mutation = []

            for token_index in top_advantage_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                if decoded_token in AAs_tokens:
                    advantage_mutation.append(decoded_token)
                    if len(advantage_mutation) == n:
                        break
            top_n_advantage_mutations[(sequence[i], i)] = advantage_mutation
            log_update(f"\t\ttop {n} predicted advantageous mutations: {','.join(advantage_mutation)}")
            # Bottom n residues ranked by LLR (most disadvantageous substitutions)
            top_disadvantage_tokens_indices = torch.argsort(llrs, dim=0, descending=False)
            disadvantage_mutation = []

            for token_index in top_disadvantage_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                if decoded_token in AAs_tokens:
                    disadvantage_mutation.append(decoded_token)
                    if len(disadvantage_mutation) == n:
                        break
            top_n_disadvantage_mutations[(sequence[i], i)] = disadvantage_mutation
            log_update(f"\t\ttop {n} predicted disadvantageous mutations: {','.join(disadvantage_mutation)}")
            # Softmax probability of the original residue; positions above 0.7 are called conserved
            normalized_mask_token_logits = F.softmax(mask_token_logits.cpu(), dim=-1).numpy()
            normalized_mask_token_logits = np.squeeze(normalized_mask_token_logits)
            originals_logit = normalized_mask_token_logits[original_residue_index]
            originals_logits.append(originals_logit)

            if originals_logit > 0.7:
                conservation_likelihoods[(original_residue, i)] = 1
                log_update("\t\tConserved position")
            else:
                conservation_likelihoods[(original_residue, i)] = 0
                log_update("\t\tNot conserved position")
    return {
        'start': start,
        'end': end,
        'originals_logits': originals_logits,
        'conservation_likelihoods': conservation_likelihoods,
        'logits': logits,
        'filtered_indices': filtered_indices,
        'top_n_mutations': top_n_mutations,
        'top_n_advantage_mutations': top_n_advantage_mutations,
        'top_n_disadvantage_mutations': top_n_disadvantage_mutations,
        'logits_for_each_AA': logits_for_each_AA,
        'llrs_for_each_AA': llrs_for_each_AA
    }


def find_top_3(d):
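    """Return the three highest-probability residues from a {residue: probability} dictionary."""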
    temp = pd.DataFrame.from_dict(d, orient='index').reset_index()
    temp = temp.sort_values(by=0, ascending=False).reset_index(drop=True)
    temp = temp.iloc[0:3, :]
    return_d = dict(zip(temp['index'], temp[0]))
    return return_d


def make_full_results_df(mutation_results, tokenizer, original_sequence):
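    """Build a per-position DataFrame of softmax-normalized amino-acid probabilities, the original
    residue and its probability, and the top-3 predicted residues."""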
    logits = mutation_results['logits']
    logits_for_each_AA = mutation_results['logits_for_each_AA']
    filtered_indices = mutation_results['filtered_indices']
    llrs_for_each_AA = mutation_results['llrs_for_each_AA']

    # Map the filtered vocabulary indices back to amino-acid tokens
    token_indices = torch.arange(logits.size(-1))
    tokens = [tokenizer.decode([idx]) for idx in token_indices]
    filtered_tokens = [tokens[i] for i in filtered_indices]
    all_logits_array = np.vstack(logits_for_each_AA)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()

    df = pd.DataFrame(normalized_logits_array)
    df.columns = filtered_tokens
    df.index = list(range(1, len(df)+1))
    df['all_logits'] = df[filtered_tokens].to_dict(orient='index')
    df['top_3_mutations'] = df['all_logits'].apply(lambda x: find_top_3(x))
    # Only the residues inside the scanned domain have predictions, so slice the sequence to match
    df['original_residue'] = list(original_sequence)[mutation_results['start']-1:mutation_results['end']]
    df['original_residue_logit'] = df.apply(lambda row: row['all_logits'][row['original_residue']], axis=1)
    df = df[['original_residue', 'original_residue_logit', 'all_logits', 'top_3_mutations']]
    df = df.reset_index().rename(columns={'index': 'Residue'})
    return df


def make_small_results_df(mutation_results):
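    """Build a compact per-position DataFrame with the original residue, the predicted residues,
    the conservation call, and the 1-indexed position."""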
    conservation_likelihoods = mutation_results['conservation_likelihoods']
    top_n_mutations = mutation_results['top_n_mutations']

    original_residues = []
    mutations = []
    positions = []
    conserved = []

    for key, value in top_n_mutations.items():
        original_residue, position = key
        original_residues.append(original_residue)
        mutations.append(','.join(value))
        positions.append(position + 1)

    for i, (key, value) in enumerate(conservation_likelihoods.items()):
        original_residue, position = key
        if original_residues[i] == original_residue:
            conserved.append(value)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Conserved': conserved,
        'Position': positions
    })
    return df


def main():
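    """Run mutation discovery for every input fusion sequence and save raw results, heatmaps,
    and CSV summaries under a timestamped results/ directory."""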
    # Create a timestamped output directory for this run
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
        print_configpy(config)

        check_env_variables()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"Using device: {device}")
        # Resolve the model checkpoint path from the config
        fuson_ckpt_path = config.FUSON_PLM_CKPT
        if fuson_ckpt_path == "FusOn-pLM":
            fuson_ckpt_path = "../../../.."
            model_name = "fuson_plm"
            model_epoch = "best"
            model_str = "fuson_plm/best"
        else:
            model_name = list(fuson_ckpt_path.keys())[0]
            epoch = list(fuson_ckpt_path.values())[0]
            fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
            model_name, model_epoch = fuson_ckpt_path.split('/')[-2:]
            model_epoch = model_epoch.split('checkpoint_')[-1]
            model_str = f"{model_name}/{model_epoch}"

        log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
        fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
        fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
        fuson_model.to(device)
        fuson_model.eval()
        # Read the input CSV if provided; otherwise build a one-row frame from the config values
        if (config.PATH_TO_INPUT_FILE is not None) and os.path.exists(config.PATH_TO_INPUT_FILE):
            input_file = pd.read_csv(config.PATH_TO_INPUT_FILE)
        else:
            input_file = pd.DataFrame(
                data={
                    'fusion_name': [config.FUSION_NAME],
                    'full_fusion_sequence': [config.FULL_FUSION_SEQUENCE],
                    'start_residue_index': [config.START_RESIDUE_INDEX],
                    'end_residue_index': [config.END_RESIDUE_INDEX],
                    'n': [config.N]
                }
            )

        log_update(f"\nThere are {len(input_file)} total sequences on which to perform mutation discovery. Fusion Genes:")
        log_update("\t" + "\n\t".join(input_file['fusion_name']))
        for i in range(len(input_file)):
            row = input_file.loc[i, :]
            fusion_name = row['fusion_name']
            full_fusion_sequence = row['full_fusion_sequence']
            start_residue_index = row['start_residue_index']
            end_residue_index = row['end_residue_index']
            n = row['n']
            sub_output_dir = f"{output_dir}/{fusion_name}"
            os.makedirs(sub_output_dir, exist_ok=True)

            # Run position-wise mutation prediction over the requested domain
            domain_bounds = {'start': start_residue_index, 'end': end_residue_index}
            mutation_results = predict_positionwise_mutations(full_fusion_sequence, domain_bounds, n,
                                                              fuson_model, fuson_tokenizer, device)

            # Save the raw results for later reuse
            with open(f"{sub_output_dir}/raw_mutation_results.pkl", "wb") as f:
                pickle.dump(mutation_results, f)
            # Heatmaps of the full per-position predictions and of the conservation calls
            plot_full_heatmap(mutation_results, fuson_tokenizer,
                              fusion_name=fusion_name, save_path=f"{sub_output_dir}/full_heatmap.png")
            plot_conservation_heatmap(mutation_results,
                                      fusion_name=fusion_name, save_path=f"{sub_output_dir}/conservation_heatmap.png")

            # Summary and full per-position result tables
            small_mutation_results_df = make_small_results_df(mutation_results)
            small_mutation_results_df.to_csv(f"{sub_output_dir}/predicted_tokens.csv", index=False)
            full_mutation_results_df = make_full_results_df(mutation_results, fuson_tokenizer, full_fusion_sequence)
            full_mutation_results_df.to_csv(f"{sub_output_dir}/full_results_with_logits.csv", index=False)


if __name__ == "__main__":
    main()