Update file trainer.py
Browse files- trainer.py +2 -2
trainer.py
CHANGED
|
@@ -24,13 +24,13 @@ class Trainer:
|
|
| 24 |
def train(self, batches):
|
| 25 |
#pass
|
| 26 |
|
| 27 |
-
model.unfreeze()
|
| 28 |
|
| 29 |
for self.epoch in range(self.num_epochs):
|
| 30 |
for self.batch in range(self.num_batches):
|
| 31 |
ids = batches[batch]
|
| 32 |
|
| 33 |
-
loss = model.compute_loss(ids)
|
| 34 |
|
| 35 |
self.optimizer.zero_grad()
|
| 36 |
loss.backward()
|
|
|
|
| 24 |
def train(self, batches):
|
| 25 |
#pass
|
| 26 |
|
| 27 |
+
self.model.unfreeze()
|
| 28 |
|
| 29 |
for self.epoch in range(self.num_epochs):
|
| 30 |
for self.batch in range(self.num_batches):
|
| 31 |
ids = batches[batch]
|
| 32 |
|
| 33 |
+
loss = self.model.compute_loss(ids)
|
| 34 |
|
| 35 |
self.optimizer.zero_grad()
|
| 36 |
loss.backward()
|