Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import gradio as gr
|
|
| 2 |
import pandas as pd
|
| 3 |
import torch
|
| 4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
|
|
| 5 |
import logging
|
| 6 |
import numpy as np
|
| 7 |
import matplotlib.pyplot as plt
|
|
@@ -52,15 +53,15 @@ def process_sequence(sequence, domain_bounds, n):
|
|
| 52 |
filtered_tokens = [tokens[i] for i in filtered_indices]
|
| 53 |
|
| 54 |
all_logits_array = np.vstack(all_logits)
|
| 55 |
-
normalized_logits_array = (
|
| 56 |
transposed_logits_array = normalized_logits_array.T
|
| 57 |
|
| 58 |
# Plotting the heatmap
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
| 62 |
plt.figure(figsize=(15, 8))
|
| 63 |
-
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=
|
| 64 |
plt.title('Logits for masked per residue tokens')
|
| 65 |
plt.ylabel('Token')
|
| 66 |
plt.xlabel('Residue Index')
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import torch
|
| 4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
import logging
|
| 7 |
import numpy as np
|
| 8 |
import matplotlib.pyplot as plt
|
|
|
|
| 53 |
filtered_tokens = [tokens[i] for i in filtered_indices]
|
| 54 |
|
| 55 |
all_logits_array = np.vstack(all_logits)
|
| 56 |
+
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
| 57 |
transposed_logits_array = normalized_logits_array.T
|
| 58 |
|
| 59 |
# Plotting the heatmap
|
| 60 |
+
x_tick_positions = np.arange(start_index, end_index)
|
| 61 |
+
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
| 62 |
|
| 63 |
plt.figure(figsize=(15, 8))
|
| 64 |
+
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
|
| 65 |
plt.title('Logits for masked per residue tokens')
|
| 66 |
plt.ylabel('Token')
|
| 67 |
plt.xlabel('Residue Index')
|