Update 2 files
Browse files- /trainer.cli.py
- /trainer.py
- trainer.cli.py +8 -2
- trainer.py +2 -8
trainer.cli.py
CHANGED
|
@@ -40,5 +40,11 @@ if __name__ == '__main__':
|
|
| 40 |
print(f"batches: {num_batches}")
|
| 41 |
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
print(f"batches: {num_batches}")
|
| 41 |
|
| 42 |
|
| 43 |
+
model = Model(config.model)
|
| 44 |
+
self.wandb = Wandb(config.wandb)
|
| 45 |
+
|
| 46 |
+
config.trainer.model = model
|
| 47 |
+
config.trainer.wandb = wandb
|
| 48 |
+
|
| 49 |
+
trainer = Trainer(config.trainer)
|
| 50 |
+
trainer.train(batches)
|
trainer.py
CHANGED
|
@@ -8,12 +8,7 @@ from model import Model
|
|
| 8 |
class Trainer:
|
| 9 |
def __init__(self, config: Config):
|
| 10 |
self.__dict__ = dict(config.trainer.__dict__)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
#self.wandb = Wandb(config.wandb)
|
| 14 |
-
self.model = Model(config.model)
|
| 15 |
|
| 16 |
-
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
|
| 17 |
|
| 18 |
|
| 19 |
def log(self, loss: float):
|
|
@@ -21,9 +16,8 @@ class Trainer:
|
|
| 21 |
|
| 22 |
|
| 23 |
|
| 24 |
-
def train(self, batches):
|
| 25 |
-
|
| 26 |
-
|
| 27 |
self.model.unfreeze()
|
| 28 |
|
| 29 |
for self.epoch in range(self.num_epochs):
|
|
|
|
| 8 |
class Trainer:
|
| 9 |
def __init__(self, config: Config):
|
| 10 |
self.__dict__ = dict(config.trainer.__dict__)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def log(self, loss: float):
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
|
| 19 |
+
def train(self, batches):
|
| 20 |
+
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
|
|
|
|
| 21 |
self.model.unfreeze()
|
| 22 |
|
| 23 |
for self.epoch in range(self.num_epochs):
|