Update file trainer.py

trainer.py CHANGED (+10, -11)
@@ -4,13 +4,13 @@ from util import Config
 
 
 
-class NoOptimizer(torch.optim.Optimizer):
-    def __init__(self, params, lr=0):
-        defaults = dict(lr=lr)
-        super(NoOptimizer, self).__init__(params, defaults)
-
-    def step(self, closure=None):
-        pass
+#class NoOptimizer(torch.optim.Optimizer):
+#    def __init__(self, params, lr=0):
+#        defaults = dict(lr=lr)
+#        super(NoOptimizer, self).__init__(params, defaults)
+#
+#    def step(self, closure=None):
+#        pass
 
 
 
@@ -19,7 +19,6 @@ class Trainer:
         self.__dict__ = dict(config.__dict__)
 
 
-
     def log(self, loss: float):
         print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
 
@@ -32,9 +31,9 @@ class Trainer:
 
 
     def train(self, batches):
-
-        self.optimizer = NoOptimizer(self.model.parameters(), lr=self.learning_rate)
-
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+        #self.optimizer = NoOptimizer(self.model.parameters(), lr=self.learning_rate)
+
         self.model.unfreeze()
 
         for self.epoch in range(self.num_epochs):
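
For context: the NoOptimizer class commented out above is a no-op optimizer. It satisfies the torch.optim.Optimizer interface but its step() never updates a parameter, which makes it a convenient switch for dry-running the training loop (the logged loss should stay flat). Below is a minimal self-contained sketch of the same idea; the closure handling is an addition here, mirroring PyTorch's built-in optimizers, and is not part of the original class:

import torch

class NoOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=0):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        # Evaluate an optional closure, as the built-in optimizers do,
        # but apply no update to any parameter group.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        return loss

Used in place of a real optimizer, this is a quick sanity check: if the loss still moves, something other than optimizer.step() is mutating the weights.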
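
The replacement wires a real update rule into train(): torch.optim.Adam is constructed from self.model.parameters() and self.learning_rate, both of which reach the Trainer through the self.__dict__ = dict(config.__dict__) copy shown above. The diff does not include the rest of the loop, so the following is only a sketch of how the new optimizer would typically be driven; self.criterion and the (inputs, targets) batch format are assumptions, not part of the commit:

    def train(self, batches):
        # Adam replaces the no-op optimizer, so parameters actually update.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.model.unfreeze()

        for self.epoch in range(self.num_epochs):
            for self.batch, (inputs, targets) in enumerate(batches):
                self.optimizer.zero_grad()  # clear gradients from the previous step
                loss = self.criterion(self.model(inputs), targets)  # assumed loss fn
                loss.backward()             # backpropagate
                self.optimizer.step()       # apply the Adam update
                self.log(loss.item())

Keeping the commented NoOptimizer line next to the Adam line leaves a one-line toggle between really training and a dry run, which is presumably why the class was commented out rather than deleted.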