oucgc1996 committed on
Commit ca16e8e · verified · 1 Parent(s): 70d95e3

Update app.py

Files changed (1)
  1. app.py +37 -121
app.py CHANGED
@@ -4,138 +4,54 @@ from utils import create_vocab, setup_seed
 from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
 setup_seed(4)
 
-# def CTXGen(X1,X2,X3,model_name):
-#     device = torch.device("cpu")
-#     vocab_mlm = create_vocab()
-#     vocab_mlm = add_tokens_to_vocab(vocab_mlm)
-#     save_path = model_name
-#     model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
-#     model = model.to(device)
-
-#     predicted_token_probability_all = []
-#     model.eval()
-#     topk = []
-#     with torch.no_grad():
-#         new_seq = None
-#         seq = [f"{X1}|{X2}|{X3}|||"]
-#         vocab_mlm.token_to_idx["X"] = 4
-#         padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-#         idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
-#         mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
-#         #if not mask_positions:
-#             #raise ValueError("Nothing found in the sequence to predict.")
-
-#         for mask_position in mask_positions:
-#             padded_seq[mask_position] = "[MASK]"
-#             input_ids = vocab_mlm.__getitem__(padded_seq)
-#             input_ids = torch.tensor([input_ids]).to(device)
-#             logits = model(input_ids, idx_msa)
-#             mask_logits = logits[0, mask_position, :]
-#             predicted_token_probability, predicted_token_id = torch.topk((torch.softmax(mask_logits, dim=-1)), k=5)
-#             topk.append(predicted_token_id)
-#             predicted_token = vocab_mlm.idx_to_token[predicted_token_id[0].item()]
-#             predicted_token_probability_all.append(predicted_token_probability[0].item())
-#             padded_seq[mask_position] = predicted_token
-
-#         cls_pos = vocab_mlm.to_tokens(list(topk[0]))
-#         if X1 != "X":
-#             Topk = X1
-#             Subtype = X1
-#             Potency = padded_seq[2],predicted_token_probability_all[0]
-#         elif X2 != "X":
-#             Topk = cls_pos
-#             Subtype = padded_seq[1],predicted_token_probability_all[0]
-#             Potency = X2
-#         else:
-#             Topk = cls_pos
-#             Subtype = padded_seq[1],predicted_token_probability_all[0]
-#             Potency = padded_seq[2],predicted_token_probability_all[1]
-#     return Subtype, Potency, Topk
-
-def CTXGen(X1, X2, X3, model_name):
+def CTXGen(X1,X2,X3,model_name):
     device = torch.device("cpu")
     vocab_mlm = create_vocab()
     vocab_mlm = add_tokens_to_vocab(vocab_mlm)
-
-    model = torch.load(model_name, weights_only=False, map_location=device)
+    save_path = model_name
+    model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
     model = model.to(device)
-    model.eval()
-
-    seq = [f"{X1}|{X2}|{X3}|||"]
-    vocab_mlm.token_to_idx["X"] = 4
-
-    padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
-    idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
-
-    mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
-
+
     predicted_token_probability_all = []
+    model.eval()
     topk = []
-
     with torch.no_grad():
-        if mask_positions:
-            for mask_position in mask_positions:
-                padded_seq[mask_position] = "[MASK]"
-                input_ids = torch.tensor(
-                    [vocab_mlm.__getitem__(padded_seq)]
-                ).to(device)
-
-                logits = model(input_ids, idx_msa)
-                mask_logits = logits[0, mask_position, :]
-                probs = torch.softmax(mask_logits, dim=-1)
-
-                prob, token_id = torch.topk(probs, k=5)
-                topk.append(token_id)
-
-                predicted_token = vocab_mlm.idx_to_token[token_id[0].item()]
-                predicted_token_probability_all.append(prob[0].item())
-
-                padded_seq[mask_position] = predicted_token
-
-            cls_pos = vocab_mlm.to_tokens(list(topk[0]))
-
-            if X1 != "X":
-                Subtype = X1
-                Potency = (padded_seq[2], predicted_token_probability_all[0])
-            elif X2 != "X":
-                Subtype = (padded_seq[1], predicted_token_probability_all[0])
-                Potency = X2
-            else:
-                Subtype = (padded_seq[1], predicted_token_probability_all[0])
-                Potency = (padded_seq[2], predicted_token_probability_all[1])
-
+        new_seq = None
+        seq = [f"{X1}|{X2}|{X3}|||"]
+        vocab_mlm.token_to_idx["X"] = 4
+        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
+        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
+        if not mask_positions:
+            raise ValueError("Nothing found in the sequence to predict.")
+
+        for mask_position in mask_positions:
+            padded_seq[mask_position] = "[MASK]"
+            input_ids = vocab_mlm.__getitem__(padded_seq)
+            input_ids = torch.tensor([input_ids]).to(device)
+            logits = model(input_ids, idx_msa)
+            mask_logits = logits[0, mask_position, :]
+            predicted_token_probability, predicted_token_id = torch.topk((torch.softmax(mask_logits, dim=-1)), k=5)
+            topk.append(predicted_token_id)
+            predicted_token = vocab_mlm.idx_to_token[predicted_token_id[0].item()]
+            predicted_token_probability_all.append(predicted_token_probability[0].item())
+            padded_seq[mask_position] = predicted_token
+
+        cls_pos = vocab_mlm.to_tokens(list(topk[0]))
+        if X1 != "X":
+            Topk = X1
+            Subtype = X1
+            Potency = padded_seq[2],predicted_token_probability_all[0]
+        elif X2 != "X":
             Topk = cls_pos
-
+            Subtype = padded_seq[1],predicted_token_probability_all[0]
+            Potency = X2
         else:
-            probs_known = {}
-
-            for pos, token in enumerate(padded_seq):
-                if token in ["|", "[PAD]"]:
-                    continue
-
-                original_token = token
-                padded_seq[pos] = "[MASK]"
-
-                input_ids = torch.tensor(
-                    [vocab_mlm.__getitem__(padded_seq)]
-                ).to(device)
-
-                logits = model(input_ids, idx_msa)
-                mask_logits = logits[0, pos, :]
-                probs = torch.softmax(mask_logits, dim=-1)
-
-                token_id = vocab_mlm.token_to_idx[original_token]
-                probs_known[pos] = probs[token_id].item()
-
-                padded_seq[pos] = original_token
-
-            Subtype = (X1, probs_known.get(0, None))
-            Potency = (X2, probs_known.get(2, None))
-            Topk = "All known (no prediction)"
-
+            Topk = cls_pos
+            Subtype = padded_seq[1],predicted_token_probability_all[0]
+            Potency = padded_seq[2],predicted_token_probability_all[1]
     return Subtype, Potency, Topk
 
-
 iface = gr.Interface(
     fn=CTXGen,
     inputs=[
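
Below is a minimal, hypothetical sketch of how the updated CTXGen function could be exercised outside the Gradio UI. The checkpoint path, the example sequence, and the "X" placeholder values are illustrative assumptions, not part of this commit; the actual gr.Interface input components are unchanged and therefore fall outside the hunk shown above.

# Hypothetical usage sketch -- "model.pt" and the example inputs below are
# assumptions for illustration, not values from this commit.
# "X" marks a field the model should fill in; known fields are passed through.
subtype, potency, topk = CTXGen(
    X1="X",                 # subtype: unknown, to be predicted
    X2="X",                 # potency label: unknown, to be predicted
    X3="GCCSDPRCAWRC",      # peptide sequence (example only)
    model_name="model.pt",  # path to the checkpoint loaded via torch.load
)
print(subtype)  # (predicted subtype token, its probability)
print(potency)  # (predicted potency token, its probability)
print(topk)     # top-5 candidate tokens for the first masked position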