fix: reset lr on load ckpt
Browse files- detector/model.py +5 -2
- train.py +1 -0
detector/model.py
CHANGED
|
@@ -130,6 +130,7 @@ class FontDetector(ptl.LightningModule):
|
|
| 130 |
betas: Tuple[float, float],
|
| 131 |
num_warmup_iters: int,
|
| 132 |
num_iters: int,
|
|
|
|
| 133 |
):
|
| 134 |
super().__init__()
|
| 135 |
self.model = model
|
|
@@ -156,6 +157,7 @@ class FontDetector(ptl.LightningModule):
|
|
| 156 |
self.betas = betas
|
| 157 |
self.num_warmup_iters = num_warmup_iters
|
| 158 |
self.num_iters = num_iters
|
|
|
|
| 159 |
self.load_step = 0
|
| 160 |
|
| 161 |
def forward(self, x):
|
|
@@ -240,7 +242,8 @@ class FontDetector(ptl.LightningModule):
|
|
| 240 |
self.scheduler = CosineWarmupScheduler(
|
| 241 |
optimizer, self.num_warmup_iters, self.num_iters
|
| 242 |
)
|
| 243 |
-
for _ in range(self.load_step):
|
|
|
|
| 244 |
self.scheduler.step()
|
| 245 |
print("Current learning rate set to:", self.scheduler.get_last_lr())
|
| 246 |
return optimizer
|
|
@@ -261,4 +264,4 @@ class FontDetector(ptl.LightningModule):
|
|
| 261 |
self.scheduler.step()
|
| 262 |
|
| 263 |
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
| 264 |
-
self.load_step = checkpoint["global_step"]
|
|
|
|
| 130 |
betas: Tuple[float, float],
|
| 131 |
num_warmup_iters: int,
|
| 132 |
num_iters: int,
|
| 133 |
+
num_epochs: int,
|
| 134 |
):
|
| 135 |
super().__init__()
|
| 136 |
self.model = model
|
|
|
|
| 157 |
self.betas = betas
|
| 158 |
self.num_warmup_iters = num_warmup_iters
|
| 159 |
self.num_iters = num_iters
|
| 160 |
+
self.num_epochs = num_epochs
|
| 161 |
self.load_step = 0
|
| 162 |
|
| 163 |
def forward(self, x):
|
|
|
|
| 242 |
self.scheduler = CosineWarmupScheduler(
|
| 243 |
optimizer, self.num_warmup_iters, self.num_iters
|
| 244 |
)
|
| 245 |
+
print("Load epoch:", self.load_epoch)
|
| 246 |
+
for _ in range(self.num_iters * self.load_epoch // self.num_epochs):
|
| 247 |
self.scheduler.step()
|
| 248 |
print("Current learning rate set to:", self.scheduler.get_last_lr())
|
| 249 |
return optimizer
|
|
|
|
| 264 |
self.scheduler.step()
|
| 265 |
|
| 266 |
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
| 267 |
+
self.load_epoch = checkpoint["epoch"]
|
train.py
CHANGED
|
@@ -87,6 +87,7 @@ detector = FontDetector(
|
|
| 87 |
betas=(b1, b2),
|
| 88 |
num_warmup_iters=num_warmup_iter,
|
| 89 |
num_iters=num_iters,
|
|
|
|
| 90 |
)
|
| 91 |
|
| 92 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
|
|
|
| 87 |
betas=(b1, b2),
|
| 88 |
num_warmup_iters=num_warmup_iter,
|
| 89 |
num_iters=num_iters,
|
| 90 |
+
num_epochs=num_epochs,
|
| 91 |
)
|
| 92 |
|
| 93 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|