Spaces:
Running
Running
dbg
Browse files
new_test_saved_finetuned_model.py
CHANGED
|
@@ -162,6 +162,9 @@ class BERTFineTuneTrainer:
|
|
| 162 |
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
| 163 |
|
| 164 |
logits = logits.cpu()
|
|
|
|
|
|
|
|
|
|
| 165 |
loss = self.criterion(logits, data["label"])
|
| 166 |
# if torch.cuda.device_count() > 1:
|
| 167 |
# loss = loss.mean()
|
|
|
|
| 162 |
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
| 163 |
|
| 164 |
logits = logits.cpu()
|
| 165 |
+
devic = logits.device # or self.model.device if available
|
| 166 |
+
labels = data["label"].to(devic)
|
| 167 |
+
|
| 168 |
loss = self.criterion(logits, data["label"])
|
| 169 |
# if torch.cuda.device_count() > 1:
|
| 170 |
# loss = loss.mean()
|