jskvrna commited on
Commit
25d87ae
·
1 Parent(s): 708ba65

Saves model checkpoint every epoch

Browse files

Changes model checkpoint saving frequency to every epoch,
allowing for more frequent progress tracking and easier
recovery of potentially better model states.

Files changed (1) hide show
  1. fast_pointnet_class.py +9 -10
fast_pointnet_class.py CHANGED
@@ -316,16 +316,15 @@ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, ba
316
 
317
  scheduler.step()
318
 
319
- # Save model checkpoint every 10 epochs
320
- if (epoch + 1) % 10 == 0:
321
- checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
322
- torch.save({
323
- 'model_state_dict': model.state_dict(),
324
- 'optimizer_state_dict': optimizer.state_dict(),
325
- 'epoch': epoch + 1,
326
- 'loss': avg_loss,
327
- 'accuracy': accuracy,
328
- }, checkpoint_path)
329
 
330
  # Save the trained model
331
  torch.save({
 
316
 
317
  scheduler.step()
318
 
319
+ # Save model checkpoint every epoch
320
+ checkpoint_path = model_save_path.replace('.pth', f'_epoch_{epoch+1}.pth')
321
+ torch.save({
322
+ 'model_state_dict': model.state_dict(),
323
+ 'optimizer_state_dict': optimizer.state_dict(),
324
+ 'epoch': epoch + 1,
325
+ 'loss': avg_loss,
326
+ 'accuracy': accuracy,
327
+ }, checkpoint_path)
 
328
 
329
  # Save the trained model
330
  torch.save({