Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,138 +4,54 @@ from utils import create_vocab, setup_seed
|
|
| 4 |
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
|
| 5 |
setup_seed(4)
|
| 6 |
|
| 7 |
-
|
| 8 |
-
# device = torch.device("cpu")
|
| 9 |
-
# vocab_mlm = create_vocab()
|
| 10 |
-
# vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
| 11 |
-
# save_path = model_name
|
| 12 |
-
# model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
|
| 13 |
-
# model = model.to(device)
|
| 14 |
-
|
| 15 |
-
# predicted_token_probability_all = []
|
| 16 |
-
# model.eval()
|
| 17 |
-
# topk = []
|
| 18 |
-
# with torch.no_grad():
|
| 19 |
-
# new_seq = None
|
| 20 |
-
# seq = [f"{X1}|{X2}|{X3}|||"]
|
| 21 |
-
# vocab_mlm.token_to_idx["X"] = 4
|
| 22 |
-
# padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
|
| 23 |
-
# idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
|
| 24 |
-
# mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
|
| 25 |
-
# #if not mask_positions:
|
| 26 |
-
# #raise ValueError("Nothing found in the sequence to predict.")
|
| 27 |
-
|
| 28 |
-
# for mask_position in mask_positions:
|
| 29 |
-
# padded_seq[mask_position] = "[MASK]"
|
| 30 |
-
# input_ids = vocab_mlm.__getitem__(padded_seq)
|
| 31 |
-
# input_ids = torch.tensor([input_ids]).to(device)
|
| 32 |
-
# logits = model(input_ids, idx_msa)
|
| 33 |
-
# mask_logits = logits[0, mask_position, :]
|
| 34 |
-
# predicted_token_probability, predicted_token_id = torch.topk((torch.softmax(mask_logits, dim=-1)), k=5)
|
| 35 |
-
# topk.append(predicted_token_id)
|
| 36 |
-
# predicted_token = vocab_mlm.idx_to_token[predicted_token_id[0].item()]
|
| 37 |
-
# predicted_token_probability_all.append(predicted_token_probability[0].item())
|
| 38 |
-
# padded_seq[mask_position] = predicted_token
|
| 39 |
-
|
| 40 |
-
# cls_pos = vocab_mlm.to_tokens(list(topk[0]))
|
| 41 |
-
# if X1 != "X":
|
| 42 |
-
# Topk = X1
|
| 43 |
-
# Subtype = X1
|
| 44 |
-
# Potency = padded_seq[2],predicted_token_probability_all[0]
|
| 45 |
-
# elif X2 != "X":
|
| 46 |
-
# Topk = cls_pos
|
| 47 |
-
# Subtype = padded_seq[1],predicted_token_probability_all[0]
|
| 48 |
-
# Potency = X2
|
| 49 |
-
# else:
|
| 50 |
-
# Topk = cls_pos
|
| 51 |
-
# Subtype = padded_seq[1],predicted_token_probability_all[0]
|
| 52 |
-
# Potency = padded_seq[2],predicted_token_probability_all[1]
|
| 53 |
-
# return Subtype, Potency, Topk
|
| 54 |
-
|
| 55 |
-
def CTXGen(X1, X2, X3, model_name):
|
| 56 |
device = torch.device("cpu")
|
| 57 |
vocab_mlm = create_vocab()
|
| 58 |
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
| 59 |
-
|
| 60 |
-
model = torch.load(
|
| 61 |
model = model.to(device)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
seq = [f"{X1}|{X2}|{X3}|||"]
|
| 65 |
-
vocab_mlm.token_to_idx["X"] = 4
|
| 66 |
-
|
| 67 |
-
padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
|
| 68 |
-
idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
|
| 69 |
-
|
| 70 |
-
mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
|
| 71 |
-
|
| 72 |
predicted_token_probability_all = []
|
|
|
|
| 73 |
topk = []
|
| 74 |
-
|
| 75 |
with torch.no_grad():
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
else:
|
| 104 |
-
Subtype = (padded_seq[1], predicted_token_probability_all[0])
|
| 105 |
-
Potency = (padded_seq[2], predicted_token_probability_all[1])
|
| 106 |
-
|
| 107 |
Topk = cls_pos
|
| 108 |
-
|
|
|
|
| 109 |
else:
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
if token in ["|", "[PAD]"]:
|
| 114 |
-
continue
|
| 115 |
-
|
| 116 |
-
original_token = token
|
| 117 |
-
padded_seq[pos] = "[MASK]"
|
| 118 |
-
|
| 119 |
-
input_ids = torch.tensor(
|
| 120 |
-
[vocab_mlm.__getitem__(padded_seq)]
|
| 121 |
-
).to(device)
|
| 122 |
-
|
| 123 |
-
logits = model(input_ids, idx_msa)
|
| 124 |
-
mask_logits = logits[0, pos, :]
|
| 125 |
-
probs = torch.softmax(mask_logits, dim=-1)
|
| 126 |
-
|
| 127 |
-
token_id = vocab_mlm.token_to_idx[original_token]
|
| 128 |
-
probs_known[pos] = probs[token_id].item()
|
| 129 |
-
|
| 130 |
-
padded_seq[pos] = original_token
|
| 131 |
-
|
| 132 |
-
Subtype = (X1, probs_known.get(0, None))
|
| 133 |
-
Potency = (X2, probs_known.get(2, None))
|
| 134 |
-
Topk = "All known (no prediction)"
|
| 135 |
-
|
| 136 |
return Subtype, Potency, Topk
|
| 137 |
|
| 138 |
-
|
| 139 |
iface = gr.Interface(
|
| 140 |
fn=CTXGen,
|
| 141 |
inputs=[
|
|
|
|
| 4 |
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
|
| 5 |
setup_seed(4)
|
| 6 |
|
| 7 |
+
def CTXGen(X1, X2, X3, model_name):
    """Predict the masked fields of a ``X1|X2|X3|||`` sequence with a
    masked-language model, greedily filling each ``"X"`` position.

    Parameters
    ----------
    X1, X2, X3 : str
        Sequence fields; the literal string ``"X"`` marks a field the
        model should predict (presumably subtype / potency / sequence —
        TODO confirm against the Gradio interface labels).
    model_name : str
        Filesystem path of a pickled PyTorch model checkpoint.

    Returns
    -------
    tuple
        ``(Subtype, Potency, Topk)`` — each element is either the
        caller-supplied value (when that field was not masked) or a
        ``(predicted_token, probability)`` pair; ``Topk`` is the list of
        top-5 candidate tokens for the first masked position.

    Raises
    ------
    ValueError
        If the assembled sequence contains no ``"X"`` position.
    """
    device = torch.device("cpu")
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)

    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # trusted checkpoint files (model_name comes from the UI model picker).
    model = torch.load(model_name, weights_only=False, map_location=device)
    model = model.to(device)
    model.eval()

    predicted_token_probability_all = []  # top-1 probability per filled position
    topk = []                             # top-5 token ids per filled position
    with torch.no_grad():
        seq = [f"{X1}|{X2}|{X3}|||"]
        # Map the placeholder "X" onto token id 4 — presumably the [MASK]/unk
        # id in this vocabulary; TODO confirm against create_vocab().
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        # Greedy left-to-right unmasking: each prediction is written back into
        # padded_seq before the next position is scored.
        for mask_position in mask_positions:
            padded_seq[mask_position] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            mask_logits = logits[0, mask_position, :]
            predicted_token_probability, predicted_token_id = torch.topk(
                torch.softmax(mask_logits, dim=-1), k=5
            )
            topk.append(predicted_token_id)
            predicted_token = vocab_mlm.idx_to_token[predicted_token_id[0].item()]
            predicted_token_probability_all.append(
                predicted_token_probability[0].item()
            )
            padded_seq[mask_position] = predicted_token

        # Top-5 candidate tokens for the first masked position.
        cls_pos = vocab_mlm.to_tokens(list(topk[0]))

        # Route each output to either the user-supplied value or the
        # prediction; the index into predicted_token_probability_all depends
        # on how many earlier fields were masked.
        if X1 != "X":
            Topk = X1
            Subtype = X1
            Potency = padded_seq[2], predicted_token_probability_all[0]
        elif X2 != "X":
            Topk = cls_pos
            Subtype = padded_seq[1], predicted_token_probability_all[0]
            Potency = X2
        else:
            Topk = cls_pos
            Subtype = padded_seq[1], predicted_token_probability_all[0]
            Potency = padded_seq[2], predicted_token_probability_all[1]
    return Subtype, Potency, Topk
|
| 54 |
|
|
|
|
| 55 |
iface = gr.Interface(
|
| 56 |
fn=CTXGen,
|
| 57 |
inputs=[
|