##### Discover mutations in new sequences: a tool for position-wise mutation discovery with FusOn-pLM
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
import os
import pickle
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import pandas as pd
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F

from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
from fuson_plm.utils.embedding import load_esm2_type
from fuson_plm.utils.visualizing import set_font
from fuson_plm.benchmarking.mutation_prediction.recovery.recover import check_env_variables  # predict_positionwise_mutations is defined locally below
from fuson_plm.benchmarking.mutation_prediction.discovery.plot import plot_conservation_heatmap, plot_full_heatmap 

def check_seq_inputs(sequence, AAs_tokens):
  # check the sequence input for validity
  if not sequence.strip():
    raise Exception("Error: The sequence input is empty. Please enter a valid protein sequence.")
  if any(char not in AAs_tokens for char in sequence):
    raise Exception("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")

def check_domain_bounds(domain_bounds, sequence):
  # check the domain bounds for validity and return them as integers
  try:
    start = int(domain_bounds['start'])
    end = int(domain_bounds['end'])
  except ValueError:
    raise Exception("Error: Start and end indices must be integers.")
  if start >= end:
    raise Exception("Start index must be smaller than end index.")
  if start == 0 and end != 0:
    raise Exception("Indexing starts at 1. Please enter valid domain bounds.")
  if start <= 0 or end <= 0:
    raise Exception("Domain bounds must be positive integers. Please enter valid domain bounds.")
  if start > len(sequence) or end > len(sequence):
    raise Exception("Domain bounds exceed sequence length.")
  return start, end

def check_n_input(n):
  # check that the requested number of mutations is valid
  if n < 1:
    raise Exception("Choose N>=1")
            
def predict_positionwise_mutations(sequence, domain_bounds, n, model, tokenizer, device):
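  """
  Mask each residue within the domain [start, end] one at a time, run a masked-LM forward
  pass, and record (i) the top-n amino acids by raw logit, (ii) the top-n by log-likelihood
  ratio (LLR) versus the wild-type residue, (iii) the bottom-n by LLR, and (iv) whether the
  position appears conserved. Returns a dict with these results plus the per-position
  logits and LLRs used for the heatmaps.
  """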
  # Define amino acids and their token indices 
  AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
  AAs_tokens_indices = {'L' : 4, 'A' : 5, 'G' : 6, 'V': 7, 'S' : 8, 'E' : 9, 'R' : 10, 'T' : 11, 'I': 12, 'D' : 13, 'P' : 14,
                        'K' : 15, 'Q' : 16, 'N' : 17, 'F' : 18, 'Y' : 19, 'M' : 20, 'H' : 21, 'W' : 22, 'C' : 23}
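  # Note: these hard-coded indices assume an ESM-2-style tokenizer vocabulary, in which
  # tokens 4-23 are the 20 standard amino acids. If in doubt, they can be verified at
  # runtime against the loaded tokenizer, e.g.:
  #   assert all(tokenizer.convert_tokens_to_ids(aa) == idx for aa, idx in AAs_tokens_indices.items())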
  
  # checking all inputs for validity
  log_update("\nChecking validity of sequence input, domain bounds, and N mutations")
  check_seq_inputs(sequence, AAs_tokens)
  start, end = check_domain_bounds(domain_bounds, sequence)
  check_n_input(n)
  
  # define start_index as start - 1 (because residues are 1-indexed, while Python is 0-indexed). end is same
  start_index = start - 1
  end_index = end

  # place to store top n mutations and all logits
  top_n_mutations = {}
  top_n_advantage_mutations = {}
  top_n_disadvantage_mutations = {}
  logits_for_each_AA = []
  llrs_for_each_AA = []

  # storage for the conservation heatmap
  originals_logits = []
  conservation_likelihoods = {}

  log_update("\nCalculating mutations. Printing currently masked position and mutation results.")
  for i in range(len(sequence)):
    # only iterate through the residues inside the domain
        if start_index <= i <= (end_index - 1):
            # isolate original residue and its index
            original_residue = sequence[i]
            original_residue_index = AAs_tokens_indices[original_residue]
            masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
            
            # prepare log
            masked_seq_list = list(sequence[:i]) + ['<mask>'] + list(sequence[i+1:])
            log_starti = i-min(5, i)
            log_endi = i+5
            log_update(f"\t{i+1}: residue = {original_residue}, masked sequence preview (pos {log_starti+1}-{log_endi}) = {''.join(masked_seq_list[log_starti:log_endi])}")
            # prepare inputs
            inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=len(masked_seq)+2)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # forward pass 
            with torch.no_grad():
                logits = model(**inputs).logits
                
            # Find masked positions and extract their logits 
            mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = logits[0, mask_token_index, :]    # shape: [1, vocab_size] == [1, 33]. logits for each vocab word at this position in the sequence

            # Collect logits for the full heatmap
            logits_array = mask_token_logits.cpu().numpy()  # shape: [1, 33]
            # filter out non-amino acid tokens
            filtered_indices = list(range(4, 23 + 1))   # filtered indices are indices of amino acids
            filtered_logits = logits_array[:, filtered_indices] # shape: [1, 20]  only the 20 amino acids
            logits_for_each_AA.append(filtered_logits)        # get logits for each amino acid
            
            # Collect LLRs for the LLR heatmap
            log_probabilities = F.log_softmax(mask_token_logits.cpu(), dim=-1).squeeze(0)  # log softmax over the vocab dimension ([33])
            log_prob_og = log_probabilities[original_residue_index] # log probability of the TRUE residue underneath the mask
            llrs = log_probabilities - log_prob_og # log-likelihood ratio of every token vs. the original residue
            #print(original_residue_index, llrs)
            filtered_llrs = llrs[filtered_indices].numpy()   # keep only the 20 amino acids, shape [20]; only save this
            filtered_llrs_array = np.array([filtered_llrs])
            llrs_for_each_AA.append(filtered_llrs_array)
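            # Interpretation: llrs[aa] = log p(aa | context) - log p(wild-type | context), so
            # the LLR of the wild-type residue itself is 0, positive values mark substitutions
            # the model scores as more likely than the wild type ("advantageous" below), and
            # negative values the opposite ("disadvantageous").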
            
            ######### Top tokens
            # Get top tokens based on LOGITS 
            all_tokens_logits = mask_token_logits.squeeze(0) # shape: [vocab_size] == [33]
            top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True) # sort the logits
            mutation = []
            # make sure we don't include non-AA tokens
            for token_index in top_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                # decoded all tokens, pick the top n amino acid ones
                if decoded_token in AAs_tokens:
                    mutation.append(decoded_token)
                    if len(mutation) == n:
                        break
            top_n_mutations[(sequence[i], i)] = mutation
            log_update(f"\t\ttop {n} predicted AAs: {','.join(mutation)}")
            
            # Get top tokens based on LLR 
            top_advantage_tokens_indices = torch.argsort(llrs, dim=0, descending=True) # sort the LLRs
            advantage_mutation = []
            # make sure we don't include non-AA tokens
            for token_index in top_advantage_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                # decoded all tokens, pick the top n amino acid ones
                if decoded_token in AAs_tokens:
                    advantage_mutation.append(decoded_token)
                    if len(advantage_mutation) == n:
                        break
            top_n_advantage_mutations[(sequence[i], i)] = advantage_mutation
            log_update(f"\t\ttop {n} predicted advantageous mutations: {','.join(advantage_mutation)}")   
            
            # Get top tokens based on LLR 
            top_disadvantage_tokens_indices = torch.argsort(llrs, dim=0, descending=False) # sort the LLRs
            disadvantage_mutation = []
            # make sure we don't include non-AA tokens
            for token_index in top_disadvantage_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                # decoded all tokens, pick the top n amino acid ones
                if decoded_token in AAs_tokens:
                    disadvantage_mutation.append(decoded_token)
                    if len(disadvantage_mutation) == n:
                        break
            top_n_disadvantage_mutations[(sequence[i], i)] = disadvantage_mutation
            log_update(f"\t\ttop {n} predicted disadvantageous mutations: {','.join(disadvantage_mutation)}")       
            
            # fill in the original-residue probabilities and conservation likelihoods for the conservation heatmap
            # (despite the "logit" naming, these values are softmax-normalized probabilities)
            normalized_mask_token_logits = F.softmax(mask_token_logits.cpu(), dim=-1).numpy()
            normalized_mask_token_logits = np.squeeze(normalized_mask_token_logits)
            originals_logit = normalized_mask_token_logits[original_residue_index]
            originals_logits.append(originals_logit)

            # a region is conserved if the probability of the amino acid is > 0.7; AKA, probability of a mutation is <= 0.3
            if originals_logit > 0.7:
                conservation_likelihoods[(original_residue, i)] = 1
                log_update("\t\tConserved position")
            else:
                conservation_likelihoods[(original_residue, i)] = 0
                log_update("\t\tNot conserved position")
                
  # return a dictionary of all the things we need for the next part 
  return {
    'start': start,
    'end': end,
    'originals_logits': originals_logits,
    'conservation_likelihoods': conservation_likelihoods,
    'logits': logits,
    'filtered_indices': filtered_indices,
    'top_n_mutations': top_n_mutations,
    'top_n_advantage_mutations': top_n_advantage_mutations,
    'top_n_disadvantage_mutations': top_n_disadvantage_mutations,
    'logits_for_each_AA': logits_for_each_AA,
    'llrs_for_each_AA': llrs_for_each_AA
  }
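
# Example (a minimal sketch; the checkpoint, sequence, and device below are placeholders):
# predict_positionwise_mutations can also be run standalone with any masked-LM checkpoint
# that uses the ESM-2 vocabulary, e.g.:
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
#   results = predict_positionwise_mutations(
#       sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",   # placeholder sequence
#       domain_bounds={"start": 1, "end": 33}, n=3,
#       model=model, tokenizer=tokenizer, device=torch.device("cpu"))
#   print(results["top_n_mutations"])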

def find_top_3(d):
  # return the 3 highest-probability amino acids from a {token: probability} dict
  temp = pd.DataFrame.from_dict(d, orient='index').reset_index()
  temp = temp.sort_values(by=0, ascending=False).reset_index(drop=True)
  temp = temp.iloc[0:3, :]
  return_d = dict(zip(temp['index'], temp[0]))
  return return_d
    
def make_full_results_df(mutation_results, tokenizer, original_sequence):
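  """
  Build a per-position dataframe of masked-LM probabilities: one row per residue, with the
  original residue, its probability, the full {amino acid: probability} dict, and the top-3
  predicted amino acids. Note: this assumes the scanned domain covers all of
  original_sequence, since exactly one row is produced per residue of original_sequence.
  """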
  # Unpack mutation results
  logits = mutation_results['logits']
  logits_for_each_AA = mutation_results['logits_for_each_AA']
  filtered_indices = mutation_results['filtered_indices']
  llrs_for_each_AA = mutation_results['llrs_for_each_AA']
  
  token_indices = torch.arange(logits.size(-1))
  tokens = [tokenizer.decode([idx]) for idx in token_indices]
  filtered_tokens = [tokens[i] for i in filtered_indices]
  all_logits_array = np.vstack(logits_for_each_AA)                    # shape: [num_positions, 20]
  normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()  # softmax over the 20 AAs at each position
  
  df = pd.DataFrame(normalized_logits_array)  # rows = positions, columns = amino acids
  df.columns = filtered_tokens
  df.index = list(range(1, len(df)+1))
  df['all_logits'] = df[filtered_tokens].to_dict(orient='index')
  df['top_3_mutations'] = df['all_logits'].apply(lambda x: find_top_3(x))
  df['original_residue'] = list(original_sequence)
  df['original_residue_logit'] = df.apply(lambda row: row['all_logits'][row['original_residue']],axis=1)
  df = df[['original_residue','original_residue_logit','all_logits','top_3_mutations']]
  df = df.reset_index().rename(columns={'index':'Residue'})
  return df

def make_small_results_df(mutation_results):
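  """
  Build a compact per-position dataframe with the original residue, the comma-joined
  top-n predicted residues, a 0/1 conservation flag, and the 1-indexed position.
  """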
  conservation_likelihoods = mutation_results['conservation_likelihoods']
  top_n_mutations = mutation_results['top_n_mutations']
  
  # store the predicted mutations in a dataframe
  original_residues = []
  mutations = []
  positions = []
  conserved = []

  for key, value in top_n_mutations.items():
      original_residue, position = key
      original_residues.append(original_residue)
      mutations.append(','.join(value))
      positions.append(position + 1)
  
  for i, (key, value) in enumerate(conservation_likelihoods.items()):
      original_residue, position = key
      if original_residues[i]==original_residue:  # it should, otherwise something is wrong
        conserved.append(value)

  df = pd.DataFrame({
      'Original Residue': original_residues,
      'Predicted Residues': mutations,
      'Conserved': conserved,
      'Position': positions
  })
  return df
  

def main():
  # Make results directory
  os.makedirs('results',exist_ok=True)
  output_dir = f'results/{get_local_time()}'
  os.makedirs(output_dir,exist_ok=True)
  
  # Predict mutations, writing results to a log inside of the output directory
  with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
      print_configpy(config)
      # Make sure environment variables are set correctly
      check_env_variables()
  
      # Get device
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      log_update(f"Using device: {device}")
      
      # Load fuson as AutoModelForMaskedLM
      fuson_ckpt_path = config.FUSON_PLM_CKPT
      if fuson_ckpt_path=="FusOn-pLM":
          fuson_ckpt_path="../../../.."
          model_name = "fuson_plm"
          model_epoch = "best"
          model_str = f"fuson_plm/best"
      else:
          model_name = list(fuson_ckpt_path.keys())[0]
          epoch = list(fuson_ckpt_path.values())[0]
          fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
          model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
          model_epoch = model_epoch.split('checkpoint_')[-1]
          model_str = f"{model_name}/{model_epoch}"
            
      log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
      fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
      fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
      fuson_model.to(device)
      fuson_model.eval()
      
      if (config.PATH_TO_INPUT_FILE is not None) and os.path.exists(config.PATH_TO_INPUT_FILE):
        input_file = pd.read_csv(config.PATH_TO_INPUT_FILE)
      else:
        input_file = pd.DataFrame(
          data={
            'fusion_name': [config.FUSION_NAME],
            'full_fusion_sequence': [config.FULL_FUSION_SEQUENCE],
            'start_residue_index': [config.START_RESIDUE_INDEX],
            'end_residue_index': [config.END_RESIDUE_INDEX],
            'n': [config.N]
          }
        )
        
      log_update(f"\nThere are {len(input_file)} total sequences on which to perform mutation discovery. Fusion Genes:")
      log_update("\t" + "\n\t".join(input_file['fusion_name']))
      # Loop through each input and make a subfolder with its data
      for i in range(len(input_file)):
        row = input_file.loc[i,:]
        fusion_name = row['fusion_name']
        full_fusion_sequence = row['full_fusion_sequence']
        start_residue_index = row['start_residue_index']
        end_residue_index = row['end_residue_index']
        n = row['n']
        sub_output_dir = f"{output_dir}/{fusion_name}"
        os.makedirs(sub_output_dir,exist_ok=True)
        
        # Predict positionwise mutations, plot the results
        domain_bounds = {'start': start_residue_index, 'end': end_residue_index}
        mutation_results = predict_positionwise_mutations(full_fusion_sequence, domain_bounds, n,
                                                          fuson_model, fuson_tokenizer, device)
        
        # Save mutation results
        with open(f"{sub_output_dir}/raw_mutation_results.pkl", "wb") as f:
          pickle.dump(mutation_results, f)
        
        # Plot the heatmaps
        plot_full_heatmap(mutation_results, fuson_tokenizer,
                          fusion_name=fusion_name, save_path=f"{sub_output_dir}/full_heatmap.png")
        plot_conservation_heatmap(mutation_results,
                                  fusion_name=fusion_name, save_path=f"{sub_output_dir}/conservation_heatmap.png")

        # Make results dataframe
        small_mutation_results_df = make_small_results_df(mutation_results)
        small_mutation_results_df.to_csv(f"{sub_output_dir}/predicted_tokens.csv",index=False)
        full_mutation_results_df = make_full_results_df(mutation_results, fuson_tokenizer, full_fusion_sequence)
        full_mutation_results_df.to_csv(f"{sub_output_dir}/full_results_with_logits.csv",index=False)
        
if __name__ == "__main__":
    main()
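
# Configuration note: the run is driven by the sibling config module, which is expected to
# define CUDA_VISIBLE_DEVICES, FUSON_PLM_CKPT, PATH_TO_INPUT_FILE, and (as a fallback when no
# input CSV is given) FUSION_NAME, FULL_FUSION_SEQUENCE, START_RESIDUE_INDEX, END_RESIDUE_INDEX,
# and N. An input CSV, if provided, needs the columns: fusion_name, full_fusion_sequence,
# start_residue_index, end_residue_index, n. Results are written under
# results/<local time>/<fusion_name>/.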