# Snapshot 3efa812 (file size: 6,056 bytes)
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
import numpy as np
import os
import pandas as pd
import pickle
from transformers import AutoTokenizer
from fuson_plm.utils.visualizing import set_font
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
def get_x_tick_labels(start, end):
    """Compute heatmap x-tick positions and labels for residues ``start``..``end``.

    ``start`` and ``end`` are 1-based residue numbers. The tick spacing is
    chosen from ``end - start``:

    * fewer than 10        -> every residue
    * 10 to 100            -> every 10 residues
    * more than 100, < 500 -> every 50 residues
    * 500 or more          -> every 100 residues

    Args:
        start: 1-based number of the first residue.
        end: 1-based number of the last residue.

    Returns:
        A tuple ``(x_tick_positions, x_tick_labels)``: a NumPy array of
        0-based sequence indices beginning at ``start - 1``, and a list with
        the matching 1-based residue numbers as strings.
    """
    # 0-based index of the first residue; ``end`` itself serves as the
    # exclusive upper bound handed to ``np.arange``.
    start_index = start - 1
    end_index = end

    domain_len = end - start
    if domain_len >= 500:
        step_size = 100
    elif domain_len > 100:
        step_size = 50
    elif domain_len < 10:
        step_size = 1
    else:
        step_size = 10

    x_tick_positions = np.arange(start_index, end_index, step_size)
    # Labels are shown in 1-based residue numbering.
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
    return x_tick_positions, x_tick_labels
def plot_conservation_heatmap(mutation_results, fusion_name="Fusion Oncoprotein", save_path="conservation_heatmap.png"):
    """Plot and save a two-row heatmap of original-residue logits and conservation.

    The top row is built from ``mutation_results['originals_logits']`` and the
    bottom row from the values of
    ``mutation_results['conservation_likelihoods']``.

    Args:
        mutation_results: Mapping with at least the keys ``'start'``,
            ``'end'``, ``'originals_logits'`` and
            ``'conservation_likelihoods'``. ``start``/``end`` are used as
            1-based residue numbers.
            NOTE(review): presumably ``originals_logits`` holds one value per
            residue and ``conservation_likelihoods`` is a dict keyed by
            position -- confirm against the code that writes the results.
        fusion_name: Name shown in the plot title.
        save_path: Path of the PNG file to write.
    """
    start = mutation_results['start']
    end = mutation_results['end']
    originals_logits = mutation_results['originals_logits']
    conservation_likelihoods = mutation_results['conservation_likelihoods']

    # 0-based index of the first plotted residue.
    start_index = start - 1
    sequence_length = end - start_index

    # Stack the two data rows into the 2D array that seaborn draws.
    original_row = np.vstack(originals_logits).T
    conserved_row = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
    combined_array = np.vstack((original_row, conserved_row))

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()
    # Constant height; width grows with the sequence length, capped at 15.
    fig = plt.figure(figsize=(min(15, sequence_length / 10), 3))
    ax = sns.heatmap(
        combined_array,
        cmap='viridis',
        xticklabels=x_tick_labels,
        yticklabels=['Original Logits', 'Conserved'],
        cbar=True,
        cbar_kws={
            'aspect': 2,
            'pad': 0.02,
            'shrink': 1.0,
        },
    )

    # Enlarge the tick labels on the color bar.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=20)

    # Final tick placement: shift to column-local coordinates and centre each
    # tick on its heatmap cell (+0.5).
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=90, fontsize=20)
    plt.yticks(fontsize=20, rotation=0)
    plt.title(f'{fusion_name} Residues {start}-{end}', fontsize=30)
    plt.xlabel('Residue Index', fontsize=30)
    plt.tight_layout()

    # Save BEFORE showing: once an interactive window opened by plt.show() is
    # closed, the figure may be gone and savefig would write a blank image.
    fig.savefig(save_path, format='png', dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
# plotting heatmap 1
def plot_full_heatmap(mutation_results, tokenizer, fusion_name="Fusion Oncoprotein", save_path="full_heatmap.png"):
    """Plot and save a heatmap of per-position token probabilities.

    The rows of ``mutation_results['logits_for_each_AA']`` are stacked,
    softmax-normalised along the last axis, transposed and drawn with one
    heatmap row per entry of ``mutation_results['filtered_indices']``.

    Args:
        mutation_results: Mapping with the keys ``'start'``, ``'end'``,
            ``'logits'``, ``'logits_for_each_AA'`` and ``'filtered_indices'``.
            NOTE(review): presumably ``logits`` is a torch tensor whose last
            dimension is the vocabulary size, and ``logits_for_each_AA`` holds
            one row per residue with one column per filtered token -- confirm
            against the code that writes the results.
        tokenizer: Object with a ``decode(list_of_ids)`` method, used to turn
            token ids into y-axis labels.
        fusion_name: Name shown in the plot title.
        save_path: Path of the PNG file to write.
    """
    start = mutation_results['start']
    end = mutation_results['end']
    logits = mutation_results['logits']
    logits_for_each_AA = mutation_results['logits_for_each_AA']
    filtered_indices = mutation_results['filtered_indices']

    # 0-based index of the first plotted residue.
    start_index = start - 1

    # Decode every token id, then keep only the labels for the filtered ids.
    token_indices = torch.arange(logits.size(-1))
    tokens = [tokenizer.decode([idx]) for idx in token_indices]
    filtered_tokens = [tokens[i] for i in filtered_indices]

    all_logits_array = np.vstack(logits_for_each_AA)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    # Transpose so tokens run along the y axis and residues along the x axis.
    transposed_logits_array = normalized_logits_array.T

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()
    fig = plt.figure(figsize=(15, 8))
    # NOTE: this changes the global default font size and is not restored
    # afterwards, so it also affects figures created later in the process.
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
    plt.title(f'{fusion_name} Residues {start}-{end}: Token Probability')
    plt.ylabel('Amino Acid')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    # Shift to column-local coordinates and centre each tick on its cell.
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
    plt.tight_layout()
    fig.savefig(save_path, format='png', dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def plot_color_bar(save_path="viridis_color_bar.png"):
    """Create, save and display a Viridis color bar ranging from 0 to 1.

    Args:
        save_path: Path of the image file to write. Defaults to
            ``"viridis_color_bar.png"`` in the current working directory.
    """
    # A single-row gradient from 0 to 1, rendered as an image.
    gradient = np.linspace(0, 1, 256).reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.imshow(gradient, aspect="auto", cmap="viridis")
    # Label only the two ends of the 256-column gradient.
    ax.set_xticks([0, 255])
    ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
    ax.set_yticks([])
    ax.set_title("Original Residue Logits", fontsize=40)
    plt.tight_layout()

    # Save BEFORE showing: once an interactive window opened by plt.show() is
    # closed, the figure may be gone and savefig would write a blank image.
    fig.savefig(save_path, dpi=300)
    plt.show()
    # Release the figure once it has been written and displayed.
    plt.close(fig)
def main():
    """Draw the color-bar legend, then one conservation heatmap per result folder.

    Every subdirectory of ``results/final`` that contains a
    ``raw_mutation_results.pkl`` file gets a ``conservation_heatmap.png``
    written next to it; the subdirectory name is used as the fusion name in
    the plot title.
    """
    plot_color_bar()

    results_dir = "results/final"
    # Sorted so that runs process the folders in a reproducible order.
    for subfolder in sorted(os.listdir(results_dir)):
        full_path = f"{results_dir}/{subfolder}"
        if not os.path.isdir(full_path):
            continue

        results_file = f"{full_path}/raw_mutation_results.pkl"
        if not os.path.isfile(results_file):
            # One incomplete folder should not abort the remaining plots.
            print(f"Skipping {full_path}: raw_mutation_results.pkl not found")
            continue

        # NOTE: pickle.load can execute arbitrary code; only run this on
        # result files produced by a trusted pipeline.
        with open(results_file, "rb") as f:
            mutation_results = pickle.load(f)

        plot_conservation_heatmap(
            mutation_results,
            fusion_name=subfolder,
            save_path=f"{full_path}/conservation_heatmap.png",
        )
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()