TeacherPuffy committed on
Commit
49da7f3
·
verified ·
1 Parent(s): b1002cc

Update train_mlp.py

Browse files
Files changed (1) hide show
  1. train_mlp.py +12 -3
train_mlp.py CHANGED
@@ -23,7 +23,7 @@ class MLP(nn.Module):
23
  return self.model(x)
24
 
25
  # Train the model
26
- def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_loss_path=None):
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  model.to(device)
29
 
@@ -80,6 +80,11 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
80
  val_losses.append(avg_val_loss)
81
  print(f'Validation Loss: {avg_val_loss}, Accuracy: {100 * correct / total}%')
82
 
 
 
 
 
 
83
  if save_loss_path:
84
  with open(save_loss_path, 'w') as f:
85
  for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
@@ -92,6 +97,7 @@ def main():
92
  parser = argparse.ArgumentParser(description='Train an MLP on the zh-plus/tiny-imagenet dataset.')
93
  parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
94
  parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
 
95
  args = parser.parse_args()
96
 
97
  # Load the zh-plus/tiny-imagenet dataset
@@ -114,9 +120,12 @@ def main():
114
 
115
  model = MLP(input_size, hidden_sizes, output_size)
116
 
 
 
 
117
  # Train the model and get the final loss
118
  save_loss_path = 'losses.txt'
119
- final_loss = train_model(model, train_dataset, val_dataset, save_loss_path=save_loss_path)
120
 
121
  # Calculate the number of parameters
122
  param_count = sum(p.numel() for p in model.parameters())
@@ -125,7 +134,7 @@ def main():
125
  model_folder = f'mlp_model_l{args.layer_count}w{args.width}'
126
  os.makedirs(model_folder, exist_ok=True)
127
 
128
- # Save the model
129
  model_path = os.path.join(model_folder, 'model.pth')
130
  torch.save(model.state_dict(), model_path)
131
 
 
23
  return self.model(x)
24
 
25
  # Train the model
26
+ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_loss_path=None, save_model_dir=None):
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  model.to(device)
29
 
 
80
  val_losses.append(avg_val_loss)
81
  print(f'Validation Loss: {avg_val_loss}, Accuracy: {100 * correct / total}%')
82
 
83
+ # Save the model after each epoch
84
+ if save_model_dir:
85
+ model_path = os.path.join(save_model_dir, f'model_epoch_{epoch+1}.pth')
86
+ torch.save(model.state_dict(), model_path)
87
+
88
  if save_loss_path:
89
  with open(save_loss_path, 'w') as f:
90
  for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
 
97
  parser = argparse.ArgumentParser(description='Train an MLP on the zh-plus/tiny-imagenet dataset.')
98
  parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
99
  parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
100
+ parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
101
  args = parser.parse_args()
102
 
103
  # Load the zh-plus/tiny-imagenet dataset
 
120
 
121
  model = MLP(input_size, hidden_sizes, output_size)
122
 
123
+ # Create the directory to save models
124
+ os.makedirs(args.save_model_dir, exist_ok=True)
125
+
126
  # Train the model and get the final loss
127
  save_loss_path = 'losses.txt'
128
+ final_loss = train_model(model, train_dataset, val_dataset, save_loss_path=save_loss_path, save_model_dir=args.save_model_dir)
129
 
130
  # Calculate the number of parameters
131
  param_count = sum(p.numel() for p in model.parameters())
 
134
  model_folder = f'mlp_model_l{args.layer_count}w{args.width}'
135
  os.makedirs(model_folder, exist_ok=True)
136
 
137
+ # Save the final model
138
  model_path = os.path.join(model_folder, 'model.pth')
139
  torch.save(model.state_dict(), model_path)
140