Update inference.py
Browse files- inference.py +5 -1
inference.py
CHANGED
|
@@ -22,11 +22,15 @@ class DebertaEvaluator(nn.Module):
|
|
| 22 |
|
| 23 |
return linear_output
|
| 24 |
|
| 25 |
-
def inference():
|
| 26 |
saved_model_path = './'
|
| 27 |
model = torch.load(saved_model_path + 'fine-tuned-model.pt', map_location=torch.device('cpu'))
|
| 28 |
tokenizer = torch.load(saved_model_path + 'fine-tuned-tokenizer.pt', map_location=torch.device('cpu'))
|
| 29 |
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
if __name__ == "__main__":
|
| 32 |
inference()
|
|
|
|
| 22 |
|
| 23 |
return linear_output
|
| 24 |
|
| 25 |
+
def inference(input_text):
|
| 26 |
saved_model_path = './'
|
| 27 |
model = torch.load(saved_model_path + 'fine-tuned-model.pt', map_location=torch.device('cpu'))
|
| 28 |
tokenizer = torch.load(saved_model_path + 'fine-tuned-tokenizer.pt', map_location=torch.device('cpu'))
|
| 29 |
model.eval()
|
| 30 |
+
input = tokenizer(input_text)
|
| 31 |
+
output = model(input_data['input_ids'].squeeze(1), input_data['attention_mask'])
|
| 32 |
+
|
| 33 |
+
return output.tolist()
|
| 34 |
|
| 35 |
if __name__ == "__main__":
|
| 36 |
inference()
|