# Source-listing metadata (extraction artifact): 16,571 bytes, commit 3efa812
##### Discover mutations in new sequences: a standalone mutation-discovery tool.
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
import os
import pickle
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
import pandas as pd
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
from fuson_plm.utils.embedding import load_esm2_type
from fuson_plm.utils.visualizing import set_font
from fuson_plm.benchmarking.mutation_prediction.recovery.recover import check_env_variables, predict_positionwise_mutations
from fuson_plm.benchmarking.mutation_prediction.discovery.plot import plot_conservation_heatmap, plot_full_heatmap
def check_seq_inputs(sequence, AAs_tokens):
    """Validate that `sequence` is a non-empty string of recognized amino acids.

    Args:
        sequence (str): candidate protein sequence.
        AAs_tokens (list[str]): the accepted single-letter amino-acid codes.

    Returns:
        None on success.

    Raises:
        Exception: if the sequence is empty/whitespace or contains any
            character outside `AAs_tokens`.
    """
    # (Removed unreachable `return None, None, None` statements that followed
    # each raise — dead code.)
    if not sequence.strip():
        raise Exception("Error: The sequence input is empty. Please enter a valid protein sequence.")
    if any(char not in AAs_tokens for char in sequence):
        raise Exception("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")
def check_domain_bounds(domain_bounds, sequence=None):
    """Validate 1-indexed inclusive domain bounds and return them as ints.

    The original implementation returned (start, end) from inside the `try`,
    which made every validation below it unreachable; it also referenced an
    out-of-scope `sequence` variable. The checks now actually run, and the
    length check is opt-in via the new backward-compatible `sequence` keyword.

    Args:
        domain_bounds (dict): {'start': int-like, 'end': int-like}, 1-indexed.
        sequence (str | None): if given, bounds are also checked against its length.

    Returns:
        tuple[int, int]: (start, end).

    Raises:
        Exception: if the bounds are non-integer, out of order, non-positive,
            zero-indexed, or (when `sequence` is given) exceed its length.
    """
    try:
        start = int(domain_bounds['start'])
        end = int(domain_bounds['end'])
    except ValueError:
        raise Exception("Error: Start and end indices must be integers.")
    if start >= end:
        raise Exception("Start index must be smaller than end index.")
    if start == 0 and end != 0:
        raise Exception("Indexing starts at 1. Please enter valid domain bounds.")
    if start <= 0 or end <= 0:
        raise Exception("Domain bounds must be positive integers. Please enter valid domain bounds.")
    if sequence is not None and (start > len(sequence) or end > len(sequence)):
        raise Exception("Domain bounds exceed sequence length.")
    return start, end
def check_n_input(n):
    """Validate that the requested number of mutations per position is >= 1.

    Args:
        n (int): number of top mutations to predict at each position.

    Returns:
        None on success.

    Raises:
        Exception: if n < 1.
    """
    # (Removed an unreachable `return None, None, None` that followed the raise.)
    if n < 1:
        raise Exception("Choose N>=1")
def predict_positionwise_mutations(sequence: str, domain_bounds: dict, n: int, model, tokenizer, device) -> dict:
    """Mask each residue inside the domain one at a time and score mutations with a masked-LM.

    For every position in [start, end] (1-indexed, inclusive) the residue is
    replaced with '<mask>', the model scores the masked token, and three
    rankings are kept: top-n by raw logit, top-n by log-likelihood ratio vs.
    the original residue ("advantageous"), and bottom-n by LLR
    ("disadvantageous"). A position is flagged conserved when the softmax
    probability of the original residue exceeds 0.7.

    NOTE(review): this definition shadows the same-named function imported
    from recovery.recover at the top of the file — confirm that is intentional.

    Args:
        sequence: full amino-acid sequence (validated against the 20 canonical AAs).
        domain_bounds: {'start': int, 'end': int}, 1-indexed inclusive bounds.
        n: number of top mutations to keep per position (must be >= 1).
        model: HuggingFace masked-LM; assumed to use an ESM-2-style vocabulary
            where amino-acid tokens occupy ids 4-23 — TODO confirm for other checkpoints.
        tokenizer: matching tokenizer (supplies mask_token_id and decode()).
        device: torch device the model lives on.

    Returns:
        dict with keys 'start', 'end', 'originals_logits',
        'conservation_likelihoods', 'logits', 'filtered_indices',
        'top_n_mutations', 'top_n_advantage_mutations',
        'top_n_disadvantage_mutations', 'logits_for_each_AA', 'llrs_for_each_AA'.
    """
    # Define amino acids and their token indices
    AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
    AAs_tokens_indices = {'L' : 4, 'A' : 5, 'G' : 6, 'V': 7, 'S' : 8, 'E' : 9, 'R' : 10, 'T' : 11, 'I': 12, 'D' : 13, 'P' : 14,
                          'K' : 15, 'Q' : 16, 'N' : 17, 'F' : 18, 'Y' : 19, 'M' : 20, 'H' : 21, 'W' : 22, 'C' : 23}
    # checking all inputs for validity
    log_update("\nChecking validity of sequence input, domain bounds, and N mutations")
    check_seq_inputs(sequence, AAs_tokens)
    start, end = check_domain_bounds(domain_bounds)
    check_n_input(n)
    # define start_index as start - 1 (because residues are 1-indexed, while Python is 0-indexed). end is same
    start_index = start - 1
    end_index = end
    # place to store top n mutations and all logits
    top_n_mutations = {}
    top_n_advantage_mutations = {}
    top_n_disadvantage_mutations = {}
    logits_for_each_AA = []
    llrs_for_each_AA = []
    # storage for the conservation heatmap
    originals_logits = []
    conservation_likelihoods = {}
    log_update("\nCalculating mutations. Printing currently masked position and mutation results.")
    for i in range(len(sequence)):
        # only iterate through the residues inside the domain
        if start_index <= i <= (end_index - 1):
            # isolate original residue and its index
            original_residue = sequence[i]
            original_residue_index = AAs_tokens_indices[original_residue]
            masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
            # prepare log: show a short window of context around the masked position
            masked_seq_list = list(sequence[:i]) + ['<mask>'] + list(sequence[i+1:])
            log_starti = i-min(5, i)
            log_endi = i+5
            log_update(f"\t{i+1}: residue = {original_residue}, masked sequence preview (pos {log_starti+1}-{log_endi}) = {''.join(masked_seq_list[log_starti:log_endi])}")
            # prepare inputs (+2 leaves room for the special BOS/EOS tokens)
            inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=len(masked_seq)+2)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # forward pass
            with torch.no_grad():
                logits = model(**inputs).logits
            # Find masked positions and extract their logits
            mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = logits[0, mask_token_index, :] # shape: [1, vocab_size] == [1, 33]. logits for each vocab word at this position in the sequence
            # Collect logits for the full heatmap
            logits_array = mask_token_logits.cpu().numpy() # shape: [1, 33]
            # filter out non-amino acid tokens
            filtered_indices = list(range(4, 23 + 1)) # filtered indices are indices of amino acids
            filtered_logits = logits_array[:, filtered_indices] # shape: [1, 20] only the 20 amino acids
            logits_for_each_AA.append(filtered_logits) # get logits for each amino acid
            # Collect LLRs for the LLR heatmap
            # NOTE(review): mask_token_logits is already a tensor; torch.tensor(...)
            # re-copies it (emits a UserWarning) but values are unchanged.
            log_probabilities = F.log_softmax(torch.tensor(mask_token_logits).cpu(), dim=-1).squeeze(0) # take log softmax of the [33] dimension
            log_prob_og = log_probabilities[original_residue_index] # get the log probability of the TRUE residue underneath the mask
            llrs = torch.tensor([(x-log_prob_og) for x in log_probabilities]) # calculate the LLR
            filtered_llrs = llrs[filtered_indices].numpy() # filter so it's [20], just the amino acids; only save this
            filtered_llrs_array = np.array([filtered_llrs])
            llrs_for_each_AA.append(filtered_llrs_array)
            ######### Top tokens
            # Get top tokens based on LOGITS
            all_tokens_logits = mask_token_logits.squeeze(0) # shape: [vocab_size] == [33]
            top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True) # sort the logits
            mutation = []
            # make sure we don't include non-AA tokens
            for token_index in top_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                # decoded all tokens, pick the top n amino acid ones
                if decoded_token in AAs_tokens:
                    mutation.append(decoded_token)
                if len(mutation) == n:
                    break
            top_n_mutations[(sequence[i], i)] = mutation
            log_update(f"\t\ttop {n} predicted AAs: {','.join(mutation)}")
            # Get top tokens based on LLR (descending = most advantageous substitutions)
            top_advantage_tokens_indices = torch.argsort(llrs, dim=0, descending=True) # sort the LLRs
            advantage_mutation = []
            # make sure we don't include non-AA tokens
            for token_index in top_advantage_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                # decoded all tokens, pick the top n amino acid ones
                if decoded_token in AAs_tokens:
                    advantage_mutation.append(decoded_token)
                if len(advantage_mutation) == n:
                    break
            top_n_advantage_mutations[(sequence[i], i)] = advantage_mutation
            log_update(f"\t\ttop {n} predicted advantageous mutations: {','.join(advantage_mutation)}")
            # Get bottom tokens based on LLR (ascending = most disadvantageous substitutions)
            top_disadvantage_tokens_indices = torch.argsort(llrs, dim=0, descending=False) # sort the LLRs
            disadvantage_mutation = []
            # make sure we don't include non-AA tokens
            for token_index in top_disadvantage_tokens_indices:
                decoded_token = tokenizer.decode([token_index.item()])
                # decoded all tokens, pick the top n amino acid ones
                if decoded_token in AAs_tokens:
                    disadvantage_mutation.append(decoded_token)
                if len(disadvantage_mutation) == n:
                    break
            top_n_disadvantage_mutations[(sequence[i], i)] = disadvantage_mutation
            log_update(f"\t\ttop {n} predicted disadvantageous mutations: {','.join(disadvantage_mutation)}")
            # fill in the logits and conservation likelihoods for the second array
            normalized_mask_token_logits = F.softmax(torch.tensor(mask_token_logits).cpu(), dim=-1).numpy()
            normalized_mask_token_logits = np.squeeze(normalized_mask_token_logits)
            originals_logit = normalized_mask_token_logits[original_residue_index]
            originals_logits.append(originals_logit)
            # a region is conserved if the probability of the amino acid is > 0.7; AKA, probability of a mutation is <= 0.3
            if originals_logit > 0.7:
                conservation_likelihoods[(original_residue, i)] = 1
                log_update("\t\tConserved position")
            else:
                conservation_likelihoods[(original_residue, i)] = 0
                log_update("\t\tNot conserved position")
    # return a dictionary of all the things we need for the next part
    return {
        'start': start,
        'end': end,
        'originals_logits': originals_logits,
        'conservation_likelihoods': conservation_likelihoods,
        'logits': logits, # raw logits from the LAST forward pass only; downstream uses just its vocab-size dim
        'filtered_indices': filtered_indices,
        'top_n_mutations': top_n_mutations,
        'top_n_advantage_mutations': top_n_advantage_mutations,
        'top_n_disadvantage_mutations': top_n_disadvantage_mutations,
        'logits_for_each_AA': logits_for_each_AA,
        'llrs_for_each_AA': llrs_for_each_AA
    }
def find_top_3(d):
    """Return the 3 largest entries of `d` as a dict ordered by descending value.

    The previous implementation round-tripped the dict through a pandas
    DataFrame just to sort 20 entries; a plain `sorted` is simpler, faster,
    and (unlike pandas' default quicksort) stable on ties.

    Args:
        d (dict): mapping (here: amino-acid token -> probability).

    Returns:
        dict: up to 3 entries with the highest values, in descending order.
    """
    top = sorted(d.items(), key=lambda kv: kv[1], reverse=True)[:3]
    return dict(top)
def make_full_results_df(mutation_results, tokenizer, original_sequence):
    """Assemble the per-position amino-acid probability table.

    Args:
        mutation_results (dict): output of predict_positionwise_mutations; uses
            'logits' (only to size the vocabulary), 'logits_for_each_AA' (one
            [1, 20] array per scored position), and 'filtered_indices' (vocab
            ids of the 20 amino-acid tokens).
        tokenizer: tokenizer used for prediction; supplies decode() for vocab ids.
        original_sequence (str): residues corresponding, in order, to the
            scored positions.

    Returns:
        pd.DataFrame with columns ['Residue', 'original_residue',
        'original_residue_logit', 'all_logits', 'top_3_mutations'], one row per
        scored position ('Residue' is 1-indexed).

    NOTE(review): assumes len(original_sequence) equals the number of scored
    positions (i.e. the domain spanned the whole sequence); pandas raises a
    length error otherwise — confirm against the caller.
    """
    logits = mutation_results['logits']
    logits_for_each_AA = mutation_results['logits_for_each_AA']
    filtered_indices = mutation_results['filtered_indices']
    # (Dropped an unused read of 'llrs_for_each_AA'.)
    # Decode every vocabulary id once, then keep only the amino-acid tokens
    token_indices = torch.arange(logits.size(-1))
    tokens = [tokenizer.decode([idx]) for idx in token_indices]
    filtered_tokens = [tokens[i] for i in filtered_indices]
    # Stack the per-position logits into [n_positions, 20] and softmax across AAs
    all_logits_array = np.vstack(logits_for_each_AA)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    # (Removed a transpose immediately followed by its inverse — it was a no-op.)
    df = pd.DataFrame(normalized_logits_array, columns=filtered_tokens)
    df.index = list(range(1, len(df)+1))  # 1-indexed residue positions
    # Per-row {token: probability} dict, then the top-3 tokens by probability
    df['all_logits'] = df[filtered_tokens].to_dict(orient='index')
    df['top_3_mutations'] = df['all_logits'].apply(find_top_3)
    df['original_residue'] = list(original_sequence)
    df['original_residue_logit'] = df.apply(lambda row: row['all_logits'][row['original_residue']],axis=1)
    df = df[['original_residue','original_residue_logit','all_logits','top_3_mutations']]
    df = df.reset_index().rename(columns={'index':'Residue'})
    return df
def make_small_results_df(mutation_results):
    """Build the compact per-position results table.

    Args:
        mutation_results (dict): output of predict_positionwise_mutations; uses
            'top_n_mutations' ({(residue, 0-indexed pos): [predicted AAs]}) and
            'conservation_likelihoods' ({(residue, 0-indexed pos): 0 or 1}).

    Returns:
        pd.DataFrame with columns ['Original Residue', 'Predicted Residues',
        'Conserved', 'Position'] (positions reported 1-indexed).

    Raises:
        ValueError: if the two input dicts disagree in length or residue order.
    """
    conservation_likelihoods = mutation_results['conservation_likelihoods']
    top_n_mutations = mutation_results['top_n_mutations']
    # Both dicts are built in the same upstream loop, so they must line up.
    if len(conservation_likelihoods) != len(top_n_mutations):
        raise ValueError("top_n_mutations and conservation_likelihoods cover different numbers of positions")
    # store the predicted mutations in a dataframe
    original_residues = []
    mutations = []
    positions = []
    conserved = []
    for (original_residue, position), predicted in top_n_mutations.items():
        original_residues.append(original_residue)
        mutations.append(','.join(predicted))
        positions.append(position + 1)  # convert 0-indexed position to 1-indexed
    for i, ((original_residue, position), value) in enumerate(conservation_likelihoods.items()):
        # The original code skipped mismatches silently, which later surfaced
        # as an opaque pandas length error — fail loudly instead.
        if original_residues[i] != original_residue:
            raise ValueError(f"Residue mismatch at index {i}: {original_residues[i]} vs {original_residue}")
        conserved.append(value)
    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Conserved': conserved,
        'Position': positions
    })
    return df
def main() -> None:
    """Run mutation discovery for each configured fusion sequence.

    Creates a timestamped results directory, loads the FusOn-pLM checkpoint
    named in config, reads the input CSV (or builds a single-row frame from
    config fields), then for each sequence predicts positionwise mutations,
    pickles the raw results, renders the two heatmaps, and writes both
    results CSVs into a per-fusion subfolder.
    """
    # Make results directory (one timestamped subfolder per run)
    os.makedirs('results',exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir,exist_ok=True)
    # Predict mutations, writing results to a log inside of the output directory
    with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
        print_configpy(config)
        # Make sure environment variables are set correctly
        check_env_variables()
        # Get device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"Using device: {device}")
        # Load fuson as AutoModelForMaskedLM
        fuson_ckpt_path = config.FUSON_PLM_CKPT
        if fuson_ckpt_path=="FusOn-pLM":
            # Special-cased name: load the checkpoint from the repository root
            fuson_ckpt_path="../../../.."
            model_name = "fuson_plm"
            model_epoch = "best"
            model_str = f"fuson_plm/best"
        else:
            # NOTE(review): this branch assumes FUSON_PLM_CKPT is a one-entry
            # dict of {model_name: epoch} — confirm against config.py
            model_name = list(fuson_ckpt_path.keys())[0]
            epoch = list(fuson_ckpt_path.values())[0]
            fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
            model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
            model_epoch = model_epoch.split('checkpoint_')[-1]
            model_str = f"{model_name}/{model_epoch}"
        # NOTE(review): model_str is built in both branches but never used below
        log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
        fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
        fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
        fuson_model.to(device)
        fuson_model.eval()
        # Batch mode if an input CSV exists; otherwise build a one-row frame from config
        if (config.PATH_TO_INPUT_FILE is not None) and os.path.exists(config.PATH_TO_INPUT_FILE):
            input_file = pd.read_csv(config.PATH_TO_INPUT_FILE)
        else:
            input_file = pd.DataFrame(
                data={
                    'fusion_name': [config.FUSION_NAME],
                    'full_fusion_sequence': [config.FULL_FUSION_SEQUENCE],
                    'start_residue_index': [config.START_RESIDUE_INDEX],
                    'end_residue_index': [config.END_RESIDUE_INDEX],
                    'n': [config.N]
                }
            )
        log_update(f"\nThere are {len(input_file)} total sequences on which to perform mutation discovery. Fusion Genes:")
        log_update("\t" + "\n\t".join(input_file['fusion_name']))
        # Loop through each input and make a subfolder with its data
        for i in range(len(input_file)):
            row = input_file.loc[i,:]
            fusion_name = row['fusion_name']
            full_fusion_sequence = row['full_fusion_sequence']
            start_residue_index = row['start_residue_index']
            end_residue_index = row['end_residue_index']
            n = row['n']
            sub_output_dir = f"{output_dir}/{fusion_name}"
            os.makedirs(sub_output_dir,exist_ok=True)
            # Predict positionwise mutations, plot the results
            domain_bounds = {'start': start_residue_index, 'end': end_residue_index}
            mutation_results = predict_positionwise_mutations(full_fusion_sequence, domain_bounds, n,
                                                              fuson_model, fuson_tokenizer, device)
            # Save mutation results (pickle of the raw results dict)
            with open(f"{sub_output_dir}/raw_mutation_results.pkl", "wb") as f:
                pickle.dump(mutation_results, f)
            # Plot the heatmaps
            plot_full_heatmap(mutation_results, fuson_tokenizer,
                              fusion_name=fusion_name, save_path=f"{sub_output_dir}/full_heatmap.png")
            plot_conservation_heatmap(mutation_results,
                                      fusion_name=fusion_name, save_path=f"{sub_output_dir}/conservation_heatmap.png")
            # Make results dataframe (compact table + full per-AA probability table)
            small_mutation_results_df = make_small_results_df(mutation_results)
            small_mutation_results_df.to_csv(f"{sub_output_dir}/predicted_tokens.csv",index=False)
            full_mutation_results_df = make_full_results_df(mutation_results, fuson_tokenizer, full_fusion_sequence)
            full_mutation_results_df.to_csv(f"{sub_output_dir}/full_results_with_logits.csv",index=False)
# Entry point guard: run discovery only when executed as a script.
if __name__ == "__main__":
    main()