Update file trainer.py
Browse files- trainer.py +5 -5
trainer.py
CHANGED
|
@@ -27,10 +27,10 @@ class Trainer:
|
|
| 27 |
for self.batch in range(self.num_batches):
|
| 28 |
ids = batches[self.batch]
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
|
|
|
| 27 |
for self.batch in range(self.num_batches):
|
| 28 |
ids = batches[self.batch]
|
| 29 |
|
| 30 |
+
loss = self.model.compute_loss(ids)
|
| 31 |
|
| 32 |
+
self.optimizer.zero_grad()
|
| 33 |
+
loss.backward()
|
| 34 |
+
self.optimizer.step()
|
| 35 |
|
| 36 |
+
self.log(loss.item())
|