Update app.py
Browse files
app.py
CHANGED
|
@@ -56,8 +56,8 @@ def process_sequence(sequence, domain_bounds, n):
|
|
| 56 |
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
| 57 |
transposed_logits_array = normalized_logits_array.T
|
| 58 |
|
| 59 |
-
|
| 60 |
-
x_tick_positions = np.arange(start_index, end_index)
|
| 61 |
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
| 62 |
|
| 63 |
plt.figure(figsize=(15, 8))
|
|
@@ -66,7 +66,7 @@ def process_sequence(sequence, domain_bounds, n):
|
|
| 66 |
plt.ylabel('Token')
|
| 67 |
plt.xlabel('Residue Index')
|
| 68 |
plt.yticks(rotation=0)
|
| 69 |
-
plt.xticks(x_tick_positions, x_tick_labels, rotation
|
| 70 |
|
| 71 |
# Save the figure to a BytesIO object
|
| 72 |
buf = BytesIO()
|
|
|
|
| 56 |
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
| 57 |
transposed_logits_array = normalized_logits_array.T
|
| 58 |
|
| 59 |
+
# Plotting the heatmap
|
| 60 |
+
x_tick_positions = np.arange(start_index, end_index, 10)
|
| 61 |
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
| 62 |
|
| 63 |
plt.figure(figsize=(15, 8))
|
|
|
|
| 66 |
plt.ylabel('Token')
|
| 67 |
plt.xlabel('Residue Index')
|
| 68 |
plt.yticks(rotation=0)
|
| 69 |
+
plt.xticks(x_tick_positions - start_index, x_tick_labels, rotation=0)
|
| 70 |
|
| 71 |
# Save the figure to a BytesIO object
|
| 72 |
buf = BytesIO()
|