import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
import numpy as np
import os
import pandas as pd
import pickle
from transformers import AutoTokenizer
from fuson_plm.utils.visualizing import set_font
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config


def get_x_tick_labels(start, end):
    """Compute x-tick positions and labels for a residue window.

    Args:
        start: 1-based first residue of the window.
        end: 1-based last residue of the window (inclusive).

    Returns:
        A tuple ``(x_tick_positions, x_tick_labels)``. Positions are 0-based
        sequence indices (a numpy array); labels are the matching 1-based
        residue numbers as strings.
    """
    # 0-based slice bounds used to index the sequence
    start_index = start - 1
    end_index = end

    # Pick a tick spacing that keeps the axis readable for the window size
    domain_len = end - start
    if 500 > domain_len > 100:
        step_size = 50
    elif 500 <= domain_len:
        step_size = 100
    elif domain_len < 10:
        step_size = 1
    else:
        step_size = 10

    x_tick_positions = np.arange(start_index, end_index, step_size)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
    return x_tick_positions, x_tick_labels


def plot_conservation_heatmap(mutation_results, fusion_name="Fusion Oncoprotein", save_path="conservation_heatmap.png"):
    """Plot original-residue logits and conservation values as a two-row heatmap.

    Args:
        mutation_results: dict loaded from ``raw_mutation_results.pkl``. This
            function reads the ``start``, ``end``, ``originals_logits`` and
            ``conservation_likelihoods`` keys.
        fusion_name: name shown in the plot title.
        save_path: PNG path the figure is written to (300 dpi).
    """
    start = mutation_results['start']
    end = mutation_results['end']
    originals_logits = mutation_results['originals_logits']
    conservation_likelihoods = mutation_results['conservation_likelihoods']

    start_index = start - 1
    end_index = end

    # Row 1: original-residue logits (stacked, then transposed to a row)
    all_logits_array = np.vstack(originals_logits)
    transposed_logits_array = all_logits_array.T
    # Row 2: conservation values, in the dict's insertion order
    conservation_likelihoods_array = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
    combined_array = np.vstack((transposed_logits_array, conservation_likelihoods_array))

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()
    # Constant height; width grows with the window length, capped at 15 inches.
    # NOTE(review): very short windows give very narrow figures - confirm intended.
    sequence_length = end_index - start_index
    fig = plt.figure(figsize=(min(15, sequence_length / 10), 3))
    ax = sns.heatmap(
        combined_array,
        cmap='viridis',
        xticklabels=x_tick_labels,
        yticklabels=['Original Logits', 'Conserved'],
        cbar=True,
        cbar_kws={
            'aspect': 2,
            'pad': 0.02,
            'shrink': 1.0,  # overall size of the color bar
        },
    )

    # Enlarge the color-bar tick labels
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=20)

    # Heatmap columns are indexed from 0 at start_index; +0.5 centres each tick on its cell
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=90, fontsize=20)
    plt.yticks(fontsize=20, rotation=0)
    plt.title(f'{fusion_name} Residues {start}-{end}', fontsize=30)
    plt.xlabel('Residue Index', fontsize=30)
    plt.tight_layout()

    # Save BEFORE show: on interactive backends the figure is released once the
    # window closes, so saving afterwards would write a blank image.
    fig.savefig(save_path, format='png', dpi=300)
    plt.show()
    # Release the figure so repeated calls (e.g. from main's loop) don't accumulate figures
    plt.close(fig)


def plot_full_heatmap(mutation_results, tokenizer, fusion_name="Fusion Oncoprotein", save_path="full_heatmap.png"):
    """Plot per-residue softmax probabilities over the filtered tokens as a heatmap.

    Args:
        mutation_results: dict with ``start``, ``end``, ``logits``,
            ``logits_for_each_AA`` and ``filtered_indices`` keys.
            ``logits`` must support ``.size(-1)`` (the vocabulary size).
        tokenizer: object with a ``decode(list_of_ids)`` method, used to label rows.
        fusion_name: name shown in the plot title.
        save_path: PNG path the figure is written to (300 dpi).
    """
    start = mutation_results['start']
    end = mutation_results['end']
    logits = mutation_results['logits']
    logits_for_each_AA = mutation_results['logits_for_each_AA']
    filtered_indices = mutation_results['filtered_indices']

    start_index = start - 1

    # Row labels: decode every vocabulary id, then keep the filtered ones
    token_indices = torch.arange(logits.size(-1))
    tokens = [tokenizer.decode([idx]) for idx in token_indices]
    filtered_tokens = [tokens[i] for i in filtered_indices]

    # Softmax along the last axis turns each residue's logits into probabilities
    all_logits_array = np.vstack(logits_for_each_AA)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    transposed_logits_array = normalized_logits_array.T

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()
    fig = plt.figure(figsize=(15, 8))
    # NOTE(review): this changes the global rcParams for every later plot as well.
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
    plt.title(f'{fusion_name} Residues {start}-{end}: Token Probability')
    plt.ylabel('Amino Acid')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
    plt.tight_layout()
    fig.savefig(save_path, format='png', dpi=300)
    # Release the figure so repeated calls don't accumulate open figures
    plt.close(fig)


def plot_color_bar():
    """
    Create a Viridis color bar ranging from 0 to 1 and save it to viridis_color_bar.png.
    """
    # A 1x256 gradient from 0 to 1, rendered as an image strip
    gradient = np.linspace(0, 1, 256).reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.imshow(gradient, aspect="auto", cmap="viridis")
    ax.set_xticks([0, 255])
    ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
    ax.set_yticks([])
    ax.set_title("Original Residue Logits", fontsize=40)

    plt.tight_layout()
    # Save BEFORE show so interactive backends don't leave a blank canvas to save
    fig.savefig("viridis_color_bar.png", dpi=300)
    plt.show()
    plt.close(fig)


def main():
    """Draw the color bar, then a conservation heatmap for each results subfolder.

    Each subfolder of ``results/final`` that contains a ``raw_mutation_results.pkl``
    gets a ``conservation_heatmap.png`` written next to it. Subfolders without
    that file are skipped.
    """
    plot_color_bar()

    results_dir = "results/final"
    # sorted() gives a reproducible processing order (os.listdir order is arbitrary)
    for subfolder in sorted(os.listdir(results_dir)):
        full_path = os.path.join(results_dir, subfolder)
        if not os.path.isdir(full_path):
            continue
        results_file = os.path.join(full_path, "raw_mutation_results.pkl")
        if not os.path.isfile(results_file):
            print(f"Skipping {full_path}: no raw_mutation_results.pkl found")
            continue
        # NOTE: pickle can execute arbitrary code - only load trusted, locally produced results.
        with open(results_file, "rb") as f:
            mutation_results = pickle.load(f)
        plot_conservation_heatmap(
            mutation_results,
            fusion_name=subfolder,
            save_path=os.path.join(full_path, "conservation_heatmap.png"),
        )


if __name__ == "__main__":
    main()