##### Discover mutations in new sequences.
"""Predict position-wise mutations for fusion sequences with a FusOn-pLM masked LM.

For every residue inside a user-specified (1-indexed, inclusive) domain, the residue
is replaced by the tokenizer's mask token. The masked-LM logits at that position are
then used for three things:
  * to rank candidate amino acids by raw logit and by log-likelihood ratio (LLR)
    against the original residue;
  * to flag the position as conserved when the original residue's probability
    exceeds 0.7.
``main`` runs this for each configured/input sequence. It pickles the raw results,
plots heatmaps and writes CSV summaries under ``results/<timestamp>/<fusion_name>/``.
"""
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
import os
import pickle

# Must be set before torch/transformers are imported so the GPU selection takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import pandas as pd
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F

from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
from fuson_plm.utils.embedding import load_esm2_type
from fuson_plm.utils.visualizing import set_font
from fuson_plm.benchmarking.mutation_prediction.recovery.recover import check_env_variables, predict_positionwise_mutations
from fuson_plm.benchmarking.mutation_prediction.discovery.plot import plot_conservation_heatmap, plot_full_heatmap

# The 20 canonical amino acids and the token ids this script assumes for them
# (ids 4-23; NOTE(review): matches an ESM-2 style vocabulary — confirm for other checkpoints).
_AA_TOKENS = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
              'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
_AA_TOKEN_INDICES = {'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11,
                     'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18,
                     'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23}
# A position is "conserved" when P(original residue | context) exceeds this value.
_CONSERVATION_THRESHOLD = 0.7


def check_seq_inputs(sequence, AAs_tokens):
    """Validate that ``sequence`` is a non-empty string made only of ``AAs_tokens``.

    Raises:
        Exception: if the sequence is empty/blank or contains any other character.
    """
    if not isinstance(sequence, str) or not sequence.strip():
        raise Exception("Error: The sequence input is empty. Please enter a valid protein sequence.")
    if any(char not in AAs_tokens for char in sequence):
        raise Exception("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")


def check_domain_bounds(domain_bounds, sequence=None):
    """Parse and validate 1-indexed, inclusive domain bounds.

    Args:
        domain_bounds: mapping with ``'start'`` and ``'end'`` entries convertible to int.
        sequence: optional sequence; when given, the bounds must not exceed its length.

    Returns:
        ``(start, end)`` as ints.

    Raises:
        Exception: if the bounds are not integers, ``start >= end``, either bound is
            not positive, or a bound exceeds ``len(sequence)``.
    """
    try:
        start = int(domain_bounds['start'])
        end = int(domain_bounds['end'])
    except (ValueError, TypeError) as err:
        raise Exception("Error: Start and end indices must be integers.") from err

    # These checks were previously unreachable (the function returned inside the try).
    if start >= end:
        raise Exception("Start index must be smaller than end index.")
    if start == 0 and end != 0:
        raise Exception("Indexing starts at 1. Please enter valid domain bounds.")
    if start <= 0 or end <= 0:
        raise Exception("Domain bounds must be positive integers. Please enter valid domain bounds.")
    if sequence is not None and (start > len(sequence) or end > len(sequence)):
        raise Exception("Domain bounds exceed sequence length.")
    return start, end


def check_n_input(n):
    """Raise if the requested number of top mutations ``n`` is below 1."""
    if n < 1:
        raise Exception("Choose N>=1")


def _top_n_amino_acids(scores, n, tokenizer, aa_tokens, descending=True):
    """Return up to ``n`` amino-acid tokens ranked by ``scores``, skipping non-AA tokens.

    ``scores`` is a 1-D tensor over the full vocabulary; ``descending`` selects
    highest-first (True) or lowest-first (False) ordering.
    """
    picked = []
    for token_index in torch.argsort(scores, dim=0, descending=descending):
        decoded_token = tokenizer.decode([token_index.item()])
        if decoded_token in aa_tokens:
            picked.append(decoded_token)
            if len(picked) == n:
                break
    return picked


def predict_positionwise_mutations(sequence, domain_bounds, n, model, tokenizer, device):
    """Mask each residue in the domain in turn and score all amino acids at that position.

    Args:
        sequence: protein sequence (canonical amino acids only).
        domain_bounds: ``{'start': int, 'end': int}``; 1-indexed, inclusive.
        n: number of top candidates to report per position.
        model: masked-LM returning ``.logits``.
        tokenizer: matching tokenizer; must define ``mask_token``/``mask_token_id``.
        device: torch device the model lives on.

    Returns:
        dict with keys ``start``, ``end``, ``originals_logits`` (softmax probability of
        the original residue per position), ``conservation_likelihoods`` (1/0 per
        ``(residue, 0-based index)``), ``logits`` (model logits from the last masked
        pass), ``filtered_indices`` (AA token ids), ``top_n_mutations``,
        ``top_n_advantage_mutations``, ``top_n_disadvantage_mutations`` (lists of AA
        letters per ``(residue, 0-based index)``), ``logits_for_each_AA`` and
        ``llrs_for_each_AA`` (one ``[1, 20]`` array per position).

    Raises:
        Exception: on invalid inputs, or if the masked input does not contain exactly
            one mask token.
    """
    AAs_tokens = list(_AA_TOKENS)
    AAs_tokens_indices = dict(_AA_TOKEN_INDICES)

    log_update("\nChecking validity of sequence input, domain bounds, and N mutations")
    check_seq_inputs(sequence, AAs_tokens)
    start, end = check_domain_bounds(domain_bounds, sequence)
    check_n_input(n)

    # Residues are 1-indexed and the domain is inclusive, so the Python range is [start-1, end).
    start_index = start - 1
    end_index = end

    filtered_indices = list(range(4, 23 + 1))  # vocabulary ids of the 20 amino acids
    mask_token = tokenizer.mask_token

    top_n_mutations = {}
    top_n_advantage_mutations = {}
    top_n_disadvantage_mutations = {}
    logits_for_each_AA = []
    llrs_for_each_AA = []
    originals_logits = []
    conservation_likelihoods = {}
    logits = None

    log_update("\nCalculating mutations. Printing currently masked position and mutation results.")
    for i in range(start_index, end_index):
        original_residue = sequence[i]
        original_residue_index = AAs_tokens_indices[original_residue]
        key = (original_residue, i)

        # Replace residue i with the mask token (previously an empty string, which
        # deleted the residue instead of masking it).
        masked_seq = sequence[:i] + mask_token + sequence[i + 1:]

        masked_seq_list = list(sequence[:i]) + [mask_token] + list(sequence[i + 1:])
        log_starti = i - min(5, i)
        log_endi = i + 5
        log_update(f"\t{i+1}: residue = {original_residue}, masked sequence preview (pos {log_starti+1}-{log_endi}) = {''.join(masked_seq_list[log_starti:log_endi])}")

        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True,
                           max_length=len(masked_seq) + 2)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        if mask_token_index.numel() != 1:
            raise Exception(f"Expected exactly one mask token when masking position {i+1}, "
                            f"found {mask_token_index.numel()}.")
        # Logits over the whole vocabulary at the masked position: shape [1, vocab_size].
        mask_token_logits = logits[0, mask_token_index, :].detach().cpu()

        # Raw logits restricted to the 20 amino acids, for the full heatmap: [1, 20].
        logits_for_each_AA.append(mask_token_logits.numpy()[:, filtered_indices])

        # LLR of every token vs. the original residue, for the LLR heatmap.
        log_probabilities = F.log_softmax(mask_token_logits, dim=-1).squeeze(0)
        llrs = log_probabilities - log_probabilities[original_residue_index]
        llrs_for_each_AA.append(np.array([llrs[filtered_indices].numpy()]))

        mutation = _top_n_amino_acids(mask_token_logits.squeeze(0), n, tokenizer, AAs_tokens)
        top_n_mutations[key] = mutation
        log_update(f"\t\ttop {n} predicted AAs: {','.join(mutation)}")

        advantage_mutation = _top_n_amino_acids(llrs, n, tokenizer, AAs_tokens, descending=True)
        top_n_advantage_mutations[key] = advantage_mutation
        log_update(f"\t\ttop {n} predicted advantageous mutations: {','.join(advantage_mutation)}")

        disadvantage_mutation = _top_n_amino_acids(llrs, n, tokenizer, AAs_tokens, descending=False)
        top_n_disadvantage_mutations[key] = disadvantage_mutation
        log_update(f"\t\ttop {n} predicted disadvantageous mutations: {','.join(disadvantage_mutation)}")

        # Probability (softmax over the full vocabulary) of the original residue.
        probabilities = np.squeeze(F.softmax(mask_token_logits, dim=-1).numpy())
        originals_logit = probabilities[original_residue_index]
        originals_logits.append(originals_logit)

        if originals_logit > _CONSERVATION_THRESHOLD:
            conservation_likelihoods[key] = 1
            log_update("\t\tConserved position")
        else:
            conservation_likelihoods[key] = 0
            log_update("\t\tNot conserved position")

    return {
        'start': start,
        'end': end,
        'originals_logits': originals_logits,
        'conservation_likelihoods': conservation_likelihoods,
        'logits': logits,
        'filtered_indices': filtered_indices,
        'top_n_mutations': top_n_mutations,
        'top_n_advantage_mutations': top_n_advantage_mutations,
        'top_n_disadvantage_mutations': top_n_disadvantage_mutations,
        'logits_for_each_AA': logits_for_each_AA,
        'llrs_for_each_AA': llrs_for_each_AA,
    }


def find_top_3(d):
    """Return the 3 highest-valued ``{key: value}`` pairs of ``d`` as a dict (descending)."""
    temp = pd.DataFrame.from_dict(d, orient='index').reset_index()
    temp = temp.sort_values(by=0, ascending=False).reset_index(drop=True)
    temp = temp.iloc[0:3, :]
    return dict(zip(temp['index'], temp[0]))


def make_full_results_df(mutation_results, tokenizer, original_sequence):
    """Build a per-residue table of amino-acid probabilities for the scanned domain.

    Probabilities are a softmax over the 20 amino-acid logits only. Rows are labelled
    with absolute 1-indexed residue positions (``Residue`` column).

    Returns:
        DataFrame with columns ``Residue``, ``original_residue``,
        ``original_residue_logit``, ``all_logits`` (dict AA -> probability) and
        ``top_3_mutations``.
    """
    logits_for_each_AA = mutation_results['logits_for_each_AA']
    filtered_indices = mutation_results['filtered_indices']
    start = int(mutation_results.get('start', 1))

    filtered_tokens = [tokenizer.decode([idx]) for idx in filtered_indices]

    all_logits_array = np.vstack(logits_for_each_AA)  # [n_positions, 20]
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    n_positions = normalized_logits_array.shape[0]

    # Label rows by absolute position so domains not starting at residue 1 line up
    # with the sequence (previously this raised a length-mismatch error).
    df = pd.DataFrame(normalized_logits_array, columns=filtered_tokens)
    df.index = list(range(start, start + n_positions))

    df['all_logits'] = pd.Series(df[filtered_tokens].to_dict(orient='index'))
    df['top_3_mutations'] = df['all_logits'].apply(find_top_3)
    df['original_residue'] = list(original_sequence[start - 1:start - 1 + n_positions])
    df['original_residue_logit'] = df.apply(
        lambda row: row['all_logits'][row['original_residue']], axis=1)
    df = df[['original_residue', 'original_residue_logit', 'all_logits', 'top_3_mutations']]
    return df.reset_index().rename(columns={'index': 'Residue'})


def make_small_results_df(mutation_results):
    """Summarize top-N predictions and conservation calls, one row per scanned residue.

    Raises:
        Exception: if ``top_n_mutations`` and ``conservation_likelihoods`` do not cover
            the same positions in the same order.
    """
    conservation_likelihoods = mutation_results['conservation_likelihoods']
    top_n_mutations = mutation_results['top_n_mutations']

    if list(top_n_mutations) != list(conservation_likelihoods):
        raise Exception("Mismatch between predicted-mutation positions and conservation positions.")

    original_residues = []
    mutations = []
    conserved = []
    positions = []
    for (original_residue, position), predicted in top_n_mutations.items():
        original_residues.append(original_residue)
        mutations.append(','.join(predicted))
        conserved.append(conservation_likelihoods[(original_residue, position)])
        positions.append(position + 1)  # back to 1-indexed

    return pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Conserved': conserved,
        'Position': positions,
    })


def main():
    """Run mutation discovery for every configured sequence and save all outputs."""
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
        print_configpy(config)
        check_env_variables()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"Using device: {device}")

        # FUSON_PLM_CKPT is either the literal "FusOn-pLM" or a {model_name: epoch} mapping.
        fuson_ckpt_path = config.FUSON_PLM_CKPT
        if fuson_ckpt_path == "FusOn-pLM":
            fuson_ckpt_path = "../../../.."
        else:
            model_name = list(fuson_ckpt_path.keys())[0]
            epoch = list(fuson_ckpt_path.values())[0]
            fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'

        log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
        fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
        fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
        fuson_model.to(device)
        fuson_model.eval()

        if (config.PATH_TO_INPUT_FILE is not None) and os.path.exists(config.PATH_TO_INPUT_FILE):
            input_file = pd.read_csv(config.PATH_TO_INPUT_FILE)
        else:
            input_file = pd.DataFrame(
                data={
                    'fusion_name': [config.FUSION_NAME],
                    'full_fusion_sequence': [config.FULL_FUSION_SEQUENCE],
                    'start_residue_index': [config.START_RESIDUE_INDEX],
                    'end_residue_index': [config.END_RESIDUE_INDEX],
                    'n': [config.N],
                }
            )

        log_update(f"\nThere are {len(input_file)} total sequences on which to perform mutation discovery. Fusion Genes:")
        log_update("\t" + "\n\t".join(input_file['fusion_name'].astype(str)))

        for _, row in input_file.iterrows():
            fusion_name = row['fusion_name']
            full_fusion_sequence = row['full_fusion_sequence']
            n = int(row['n'])

            sub_output_dir = f"{output_dir}/{fusion_name}"
            os.makedirs(sub_output_dir, exist_ok=True)

            domain_bounds = {'start': row['start_residue_index'], 'end': row['end_residue_index']}
            mutation_results = predict_positionwise_mutations(
                full_fusion_sequence, domain_bounds, n, fuson_model, fuson_tokenizer, device)

            with open(f"{sub_output_dir}/raw_mutation_results.pkl", "wb") as f:
                pickle.dump(mutation_results, f)

            plot_full_heatmap(mutation_results, fuson_tokenizer, fusion_name=fusion_name,
                              save_path=f"{sub_output_dir}/full_heatmap.png")
            plot_conservation_heatmap(mutation_results, fusion_name=fusion_name,
                                      save_path=f"{sub_output_dir}/conservation_heatmap.png")

            small_mutation_results_df = make_small_results_df(mutation_results)
            small_mutation_results_df.to_csv(f"{sub_output_dir}/predicted_tokens.csv", index=False)

            full_mutation_results_df = make_full_results_df(mutation_results, fuson_tokenizer, full_fusion_sequence)
            full_mutation_results_df.to_csv(f"{sub_output_dir}/full_results_with_logits.csv", index=False)


if __name__ == "__main__":
    main()