flpelerin commited on
Commit
76bd6b1
·
1 Parent(s): db3e9ba

Update 2 files

Browse files

- /trainer.cli.py
- /trainer.py

Files changed (2) hide show
  1. trainer.cli.py +8 -2
  2. trainer.py +2 -8
trainer.cli.py CHANGED
@@ -40,5 +40,11 @@ if __name__ == '__main__':
40
  print(f"batches: {num_batches}")
41
 
42
 
43
- trainer = Trainer(config)
44
- trainer.train(dataset)
 
 
 
 
 
 
 
40
  print(f"batches: {num_batches}")
41
 
42
 
43
+ model = Model(config.model)
44
+ self.wandb = Wandb(config.wandb)
45
+
46
+ config.trainer.model = model
47
+ config.trainer.wandb = wandb
48
+
49
+ trainer = Trainer(config.trainer)
50
+ trainer.train(batches)
trainer.py CHANGED
@@ -8,12 +8,7 @@ from model import Model
8
  class Trainer:
9
  def __init__(self, config: Config):
10
  self.__dict__ = dict(config.trainer.__dict__)
11
-
12
-
13
- #self.wandb = Wandb(config.wandb)
14
- self.model = Model(config.model)
15
 
16
- self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
17
 
18
 
19
  def log(self, loss: float):
@@ -21,9 +16,8 @@ class Trainer:
21
 
22
 
23
 
24
- def train(self, batches):
25
- #pass
26
-
27
  self.model.unfreeze()
28
 
29
  for self.epoch in range(self.num_epochs):
 
8
  class Trainer:
9
  def __init__(self, config: Config):
10
  self.__dict__ = dict(config.trainer.__dict__)
 
 
 
 
11
 
 
12
 
13
 
14
  def log(self, loss: float):
 
16
 
17
 
18
 
19
+ def train(self, batches):
20
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
 
21
  self.model.unfreeze()
22
 
23
  for self.epoch in range(self.num_epochs):