Update app.py
Browse files
app.py
CHANGED
|
@@ -29,26 +29,37 @@ def process_sequence(sequence, domain_bounds, n):
|
|
| 29 |
all_logits = []
|
| 30 |
|
| 31 |
for i in range(len(sequence)):
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
# filter out non-amino acid tokens
|
| 42 |
-
filtered_indices = list(range(4, 23 + 1))
|
| 43 |
-
filtered_logits = mask_token_logits[:, filtered_indices]
|
| 44 |
-
# Decode top n tokens
|
| 45 |
-
top_n_tokens = torch.topk(filtered_logits, n, dim=1).indices[0].tolist()
|
| 46 |
-
mutation = [tokenizer.decode([token]) for token in top_n_tokens]
|
| 47 |
-
top_n_mutations[(sequence[i], i)] = mutation
|
| 48 |
-
|
| 49 |
-
logits_array = mask_token_logits.cpu().numpy()
|
| 50 |
-
filtered_logits = logits_array[:, filtered_indices]
|
| 51 |
-
all_logits.append(filtered_logits)
|
| 52 |
|
| 53 |
token_indices = torch.arange(logits.size(-1))
|
| 54 |
tokens = [tokenizer.decode([idx]) for idx in token_indices]
|
|
@@ -63,7 +74,7 @@ def process_sequence(sequence, domain_bounds, n):
|
|
| 63 |
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
| 64 |
|
| 65 |
plt.figure(figsize=(15, 8))
|
| 66 |
-
plt.rcParams.update({'font.size':
|
| 67 |
|
| 68 |
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
|
| 69 |
plt.title('Token Probability Heatmap')
|
|
|
|
| 29 |
all_logits = []
|
| 30 |
|
| 31 |
for i in range(len(sequence)):
|
| 32 |
+
if start_index <= i <= (end_index - 1):
|
| 33 |
+
masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
|
| 34 |
+
inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
|
| 35 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 36 |
+
with torch.no_grad():
|
| 37 |
+
logits = model(**inputs).logits
|
| 38 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
| 39 |
+
mask_token_logits = logits[0, mask_token_index, :]
|
| 40 |
+
|
| 41 |
+
# Define amino acid tokens
|
| 42 |
+
AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
|
| 43 |
+
all_tokens_logits = mask_token_logits.squeeze(0)
|
| 44 |
+
top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
|
| 45 |
+
top_tokens_logits = all_tokens_logits[top_tokens_indices]
|
| 46 |
+
mutation = []
|
| 47 |
+
# make sure we don't include non-AA tokens
|
| 48 |
+
for token_index in top_tokens_indices:
|
| 49 |
+
decoded_token = tokenizer.decode([token_index.item()])
|
| 50 |
+
if decoded_token in AAs_tokens:
|
| 51 |
+
mutation.append(decoded_token)
|
| 52 |
+
if len(mutation) == n:
|
| 53 |
+
break
|
| 54 |
+
top_n_mutations[(sequence[i], i)] = mutation
|
| 55 |
+
|
| 56 |
+
# collecting logits for the heatmap
|
| 57 |
+
logits_array = mask_token_logits.cpu().numpy()
|
| 58 |
+
# filter out non-amino acid tokens
|
| 59 |
+
filtered_indices = list(range(4, 23 + 1))
|
| 60 |
+
filtered_logits = logits_array[:, filtered_indices]
|
| 61 |
+
all_logits.append(filtered_logits)
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
token_indices = torch.arange(logits.size(-1))
|
| 65 |
tokens = [tokenizer.decode([idx]) for idx in token_indices]
|
|
|
|
| 74 |
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
| 75 |
|
| 76 |
plt.figure(figsize=(15, 8))
|
| 77 |
+
plt.rcParams.update({'font.size': 18})
|
| 78 |
|
| 79 |
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
|
| 80 |
plt.title('Token Probability Heatmap')
|