flpelerin commited on
Commit
dcd6392
·
1 Parent(s): 0049910

Update file trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +10 -11
trainer.py CHANGED
@@ -4,13 +4,13 @@ from util import Config
4
 
5
 
6
 
7
- class NoOptimizer(torch.optim.Optimizer):
8
- def __init__(self, params, lr=0):
9
- defaults = dict(lr=lr)
10
- super(NoOptimizer, self).__init__(params, defaults)
11
-
12
- def step(self, closure=None):
13
- pass
14
 
15
 
16
 
@@ -19,7 +19,6 @@ class Trainer:
19
  self.__dict__ = dict(config.__dict__)
20
 
21
 
22
-
23
  def log(self, loss: float):
24
  print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
25
 
@@ -32,9 +31,9 @@ class Trainer:
32
 
33
 
34
  def train(self, batches):
35
- #self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
36
- self.optimizer = NoOptimizer(self.model.parameters(), lr=self.learning_rate)
37
-
38
  self.model.unfreeze()
39
 
40
  for self.epoch in range(self.num_epochs):
 
4
 
5
 
6
 
7
+ #class NoOptimizer(torch.optim.Optimizer):
8
+ # def __init__(self, params, lr=0):
9
+ # defaults = dict(lr=lr)
10
+ # super(NoOptimizer, self).__init__(params, defaults)
11
+ #
12
+ # def step(self, closure=None):
13
+ # pass
14
 
15
 
16
 
 
19
  self.__dict__ = dict(config.__dict__)
20
 
21
 
 
22
  def log(self, loss: float):
23
  print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
24
 
 
31
 
32
 
33
  def train(self, batches):
34
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
35
+ #self.optimizer = NoOptimizer(self.model.parameters(), lr=self.learning_rate)
36
+
37
  self.model.unfreeze()
38
 
39
  for self.epoch in range(self.num_epochs):