flpelerin commited on
Commit
0049910
·
1 Parent(s): 1582608

Update 2 files

Browse files

- /model.py
- /trainer.py

Files changed (2) hide show
  1. model.py +0 -1
  2. trainer.py +14 -1
model.py CHANGED
@@ -34,7 +34,6 @@ class Model:
34
  shift_logits = lm_logits[:, :-1, :].contiguous()
35
  labels = labels[:, 1:].contiguous()
36
 
37
-
38
  loss_fct = criterion or torch.nn.CrossEntropyLoss()
39
  lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
40
 
 
34
  shift_logits = lm_logits[:, :-1, :].contiguous()
35
  labels = labels[:, 1:].contiguous()
36
 
 
37
  loss_fct = criterion or torch.nn.CrossEntropyLoss()
38
  lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
39
 
trainer.py CHANGED
@@ -3,6 +3,17 @@ import torch
3
  from util import Config
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  class Trainer:
7
  def __init__(self, config: Config):
8
  self.__dict__ = dict(config.__dict__)
@@ -21,7 +32,9 @@ class Trainer:
21
 
22
 
23
  def train(self, batches):
24
- self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
 
 
25
  self.model.unfreeze()
26
 
27
  for self.epoch in range(self.num_epochs):
 
3
  from util import Config
4
 
5
 
6
+
7
+ class NoOptimizer(torch.optim.Optimizer):
8
+ def __init__(self, params, lr=0):
9
+ defaults = dict(lr=lr)
10
+ super(NoOptimizer, self).__init__(params, defaults)
11
+
12
+ def step(self, closure=None):
13
+ pass
14
+
15
+
16
+
17
  class Trainer:
18
  def __init__(self, config: Config):
19
  self.__dict__ = dict(config.__dict__)
 
32
 
33
 
34
  def train(self, batches):
35
+ #self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
36
+ self.optimizer = NoOptimizer(self.model.parameters(), lr=self.learning_rate)
37
+
38
  self.model.unfreeze()
39
 
40
  for self.epoch in range(self.num_epochs):