Fill-Mask
Transformers
Safetensors
esm
File size: 6,056 Bytes
3efa812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
import numpy as np
import os
import pandas as pd
import pickle
from transformers import AutoTokenizer
from fuson_plm.utils.visualizing import set_font
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config

def get_x_tick_labels(start, end):
    """Compute x-axis tick positions and labels for a 1-indexed residue span.

    Args:
        start: First residue of the span (1-indexed, inclusive).
        end: Last residue of the span (1-indexed, inclusive).

    Returns:
        A ``(positions, labels)`` pair. ``positions`` is a NumPy array of
        0-indexed sequence positions running from ``start - 1`` up to (but
        not including) ``end``, spaced by a step chosen from the span width.
        ``labels`` holds the matching 1-indexed positions as strings.
    """
    first_idx, stop_idx = start - 1, end
    width = end - start

    # Wider spans get sparser ticks so the axis stays readable.
    if width < 10:
        step = 1
    elif width <= 100:
        step = 10
    elif width < 500:
        step = 50
    else:
        step = 100

    positions = np.arange(first_idx, stop_idx, step)
    # Convert 0-indexed positions back to 1-indexed residue numbers for display.
    labels = list(map(str, positions + 1))
    return positions, labels

    
def plot_conservation_heatmap(mutation_results, fusion_name="Fusion Oncoprotein", save_path="conservation_heatmap.png"):
    """Plot original-residue logits stacked above conservation likelihoods as a heatmap.

    Args:
        mutation_results: Dict read for the keys ``'start'``, ``'end'``,
            ``'originals_logits'`` and ``'conservation_likelihoods'``.
            ``start``/``end`` are the 1-indexed, inclusive residue span.
            ``originals_logits`` is stacked with ``np.vstack`` and transposed.
            ``conservation_likelihoods`` is a mapping whose values are
            flattened into a single row.
        fusion_name: Name used in the plot title.
        save_path: Output path for the PNG (written at 300 dpi).
    """
    start = mutation_results['start']
    end = mutation_results['end']
    originals_logits = mutation_results['originals_logits']
    conservation_likelihoods = mutation_results['conservation_likelihoods']

    # 0-indexed half-open bounds of the span within the full sequence
    start_index = start - 1
    end_index = end

    # Transpose so residues run along the x-axis, then append the conservation row.
    # NOTE(review): the two y labels below assume the transposed logits form a single row.
    transposed_logits_array = np.vstack(originals_logits).T
    conservation_row = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
    combined_array = np.vstack((transposed_logits_array, conservation_row))

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()
    # Width scales with span length (capped at 15 in); height is fixed at 3 in.
    sequence_length = end_index - start_index
    fig = plt.figure(figsize=(min(15, sequence_length / 10), 3))

    ax = sns.heatmap(
        combined_array,
        cmap='viridis',
        xticklabels=x_tick_labels,
        yticklabels=['Original Logits', 'Conserved'],
        cbar=True,
        cbar_kws={'aspect': 2,
                  'pad': 0.02,
                  'shrink': 1.0,  # overall size of the color bar
        }
    )
    # Enlarge the color bar's tick labels.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=20)

    # Heatmap column j spans [j, j+1]; shift sequence positions into column
    # coordinates and add 0.5 so each tick sits at the cell centre.
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=90, fontsize=20)
    plt.yticks(fontsize=20, rotation=0)
    plt.title(f'{fusion_name} Residues {start}-{end}', fontsize=30)
    plt.xlabel('Residue Index', fontsize=30)
    plt.tight_layout()

    # Save BEFORE show(): interactive backends discard the figure once the
    # window is closed, which made the original save a blank image.
    fig.savefig(save_path, format='png', dpi=300)
    plt.show()
    # Release the figure; this is called once per results folder.
    plt.close(fig)
    
# plotting heatmap 1
def plot_full_heatmap(mutation_results, tokenizer, fusion_name="Fusion Oncoprotein", save_path="full_heatmap.png"):
  """Plot per-residue token probabilities (softmaxed logits) as a heatmap and save it.

  Args:
      mutation_results: Dict read for ``'start'``, ``'end'``, ``'logits'``,
          ``'logits_for_each_AA'`` and ``'filtered_indices'``.
          ``start``/``end`` are the 1-indexed, inclusive residue span.
      tokenizer: Object with a ``decode`` method, used to label rows with
          token strings.
      fusion_name: Name used in the plot title.
      save_path: Output path for the PNG (written at 300 dpi).
  """
  start = mutation_results['start']
  end = mutation_results['end']
  logits = mutation_results['logits']
  logits_for_each_AA = mutation_results['logits_for_each_AA']
  filtered_indices = mutation_results['filtered_indices']

  start_index = start - 1

  # Decode every id along the last dim of `logits` (presumably the vocabulary),
  # then keep the entries selected by filtered_indices as row labels.
  # NOTE(review): the count of filtered_indices must match the columns of logits_for_each_AA.
  token_indices = torch.arange(logits.size(-1))
  tokens = [tokenizer.decode([idx]) for idx in token_indices]
  filtered_tokens = [tokens[i] for i in filtered_indices]

  # Softmax over the last axis turns each row of logits into probabilities;
  # transpose so residues run along the x-axis.
  all_logits_array = np.vstack(logits_for_each_AA)
  probabilities = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
  transposed_logits_array = probabilities.T

  x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

  set_font()
  # Scope the font-size override to this figure instead of mutating global
  # rcParams, which previously leaked into every later plot.
  with plt.rc_context({'font.size': 16.5}):
    fig = plt.figure(figsize=(15, 8))
    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
    plt.title(f'{fusion_name} Residues {start}-{end}: Token Probability')
    plt.ylabel('Amino Acid')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    # Shift sequence positions into heatmap column coordinates, centred in each cell.
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
    plt.tight_layout()
    fig.savefig(save_path, format='png', dpi=300)
  # Release the figure so repeated calls do not accumulate open figures.
  plt.close(fig)
 
def plot_color_bar(save_path="viridis_color_bar.png"):
    """Render and save a standalone viridis color bar spanning 0 to 1.

    The left end (0) is labelled "most likely to mutate" and the right end
    (1) "least likely to mutate", under the title "Original Residue Logits".

    Args:
        save_path: Output image path, written at 300 dpi. Defaults to the
            original hard-coded ``"viridis_color_bar.png"``.
    """
    # One row of 256 evenly spaced values from 0 to 1
    gradient = np.linspace(0, 1, 256).reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.imshow(gradient, aspect="auto", cmap="viridis")
    # Columns 0 and 255 are the first and last gradient samples (values 0 and 1).
    ax.set_xticks([0, 255])
    ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
    ax.set_yticks([])
    ax.set_title("Original Residue Logits", fontsize=40)

    fig.tight_layout()
    # Save BEFORE show(): interactive backends discard the figure after the
    # window closes, which made the original save a blank image.
    fig.savefig(save_path, dpi=300)
    plt.show()
    plt.close(fig)

def main():
    """Save the standalone color bar, then a conservation heatmap for each results subfolder.

    Every immediate subdirectory of ``results/final`` that contains
    ``raw_mutation_results.pkl`` gets a ``conservation_heatmap.png`` written
    next to it, titled with the subfolder name. Subdirectories without that
    file are skipped with a message.
    """
    plot_color_bar()

    results_dir = "results/final"
    # Sort because os.listdir order is arbitrary; this keeps output order deterministic.
    for subfolder in sorted(os.listdir(results_dir)):
        full_path = os.path.join(results_dir, subfolder)
        if not os.path.isdir(full_path):
            continue

        pickle_path = os.path.join(full_path, "raw_mutation_results.pkl")
        if not os.path.isfile(pickle_path):
            print(f"Skipping {full_path}: raw_mutation_results.pkl not found")
            continue

        # NOTE: pickle.load can execute arbitrary code; only run this on
        # results produced by this pipeline.
        with open(pickle_path, "rb") as f:
            mutation_results = pickle.load(f)

        plot_conservation_heatmap(
            mutation_results,
            fusion_name=subfolder,
            save_path=os.path.join(full_path, "conservation_heatmap.png"),
        )


if __name__ == "__main__":
    main()