Update tasks/text.py
Browse files- tasks/text.py +5 -5
tasks/text.py
CHANGED
|
@@ -82,11 +82,11 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 82 |
|
| 83 |
predictions = []
|
| 84 |
with torch.no_grad():
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
|
| 91 |
true_labels = test_dataset["label"]
|
| 92 |
|
|
|
|
| 82 |
|
| 83 |
predictions = []
|
| 84 |
with torch.no_grad():
|
| 85 |
+
for batch in test_loader:
|
| 86 |
+
input_ids, attention_mask, labels = [x.to(device) for x in batch]
|
| 87 |
+
outputs = model(input_ids, attention_mask=attention_mask)
|
| 88 |
+
preds = torch.argmax(outputs.logits, dim=1)
|
| 89 |
+
predictions.extend(preds.cpu().numpy())
|
| 90 |
|
| 91 |
true_labels = test_dataset["label"]
|
| 92 |
|