Saves model checkpoint every epoch
Changes model checkpoint saving frequency to every epoch,
allowing for more frequent progress tracking and easier
recovery of potentially better model states.
- fast_pointnet_class.py +9 -10
fast_pointnet_class.py
CHANGED
|
@@ -316,16 +316,15 @@ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, ba
|
|
| 316 |
|
| 317 |
scheduler.step()
|
| 318 |
|
| 319 |
-
# Save model checkpoint every
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
}, checkpoint_path)
|
| 329 |
|
| 330 |
# Save the trained model
|
| 331 |
torch.save({
|
|
|
|
| 316 |
|
| 317 |
scheduler.step()
|
| 318 |
|
| 319 |
+
# Save model checkpoint every epoch
|
| 320 |
+
checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
|
| 321 |
+
torch.save({
|
| 322 |
+
'model_state_dict': model.state_dict(),
|
| 323 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 324 |
+
'epoch': epoch + 1,
|
| 325 |
+
'loss': avg_loss,
|
| 326 |
+
'accuracy': accuracy,
|
| 327 |
+
}, checkpoint_path)
|
|
|
|
| 328 |
|
| 329 |
# Save the trained model
|
| 330 |
torch.save({
|