Update tasks/text.py
Browse files- tasks/text.py +2 -2
tasks/text.py
CHANGED
|
@@ -100,8 +100,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 100 |
with torch.no_grad():
|
| 101 |
text_input_ids = text_encoding["input_ids"].to(device)
|
| 102 |
text_attention_mask = text_encoding["attention_mask"].to(device)
|
| 103 |
-
|
| 104 |
-
predictions = torch.argmax(
|
| 105 |
|
| 106 |
#--------------------------------------------------------------------------------------------
|
| 107 |
# YOUR MODEL INFERENCE STOPS HERE
|
|
|
|
| 100 |
with torch.no_grad():
|
| 101 |
text_input_ids = text_encoding["input_ids"].to(device)
|
| 102 |
text_attention_mask = text_encoding["attention_mask"].to(device)
|
| 103 |
+
logits = model(text_input_ids, text_attention_mask)
|
| 104 |
+
predictions = torch.argmax(logits, dim=1).cpu().numpy()
|
| 105 |
|
| 106 |
#--------------------------------------------------------------------------------------------
|
| 107 |
# YOUR MODEL INFERENCE STOPS HERE
|