HemanM commited on
Commit
2e57b59
·
verified ·
1 Parent(s): 3d8c93e

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +21 -0
inference.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from init_model import get_tokenizer, get_base_model
3
+ from model import EvoTransformer
4
+
5
+ tokenizer = get_tokenizer()
6
+ bert = get_base_model()
7
+ model = EvoTransformer()
8
+ model.eval()
9
+
10
+ def evo_suggest(question, option1, option2):
11
+ inputs = [question + " " + option1, question + " " + option2]
12
+ scores = []
13
+
14
+ for text in inputs:
15
+ encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
16
+ with torch.no_grad():
17
+ outputs = bert(**encoded).last_hidden_state[:, 0, :] # [CLS] token
18
+ logits = model(outputs)
19
+ scores.append(logits[0][1].item()) # Confidence for class 1
20
+
21
+ return option1 if scores[0] > scores[1] else option2