Update train_mlp.py
Browse files
- train_mlp.py +12 -3
train_mlp.py
CHANGED
|
@@ -23,7 +23,7 @@ class MLP(nn.Module):
|
|
| 23 |
return self.model(x)
|
| 24 |
|
| 25 |
# Train the model
|
| 26 |
-
def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_loss_path=None):
|
| 27 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 28 |
model.to(device)
|
| 29 |
|
|
@@ -80,6 +80,11 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
|
|
| 80 |
val_losses.append(avg_val_loss)
|
| 81 |
print(f'Validation Loss: {avg_val_loss}, Accuracy: {100 * correct / total}%')
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
if save_loss_path:
|
| 84 |
with open(save_loss_path, 'w') as f:
|
| 85 |
for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
|
|
@@ -92,6 +97,7 @@ def main():
|
|
| 92 |
parser = argparse.ArgumentParser(description='Train an MLP on the zh-plus/tiny-imagenet dataset.')
|
| 93 |
parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
|
| 94 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
|
|
|
| 95 |
args = parser.parse_args()
|
| 96 |
|
| 97 |
# Load the zh-plus/tiny-imagenet dataset
|
|
@@ -114,9 +120,12 @@ def main():
|
|
| 114 |
|
| 115 |
model = MLP(input_size, hidden_sizes, output_size)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
# Train the model and get the final loss
|
| 118 |
save_loss_path = 'losses.txt'
|
| 119 |
-
final_loss = train_model(model, train_dataset, val_dataset, save_loss_path=save_loss_path)
|
| 120 |
|
| 121 |
# Calculate the number of parameters
|
| 122 |
param_count = sum(p.numel() for p in model.parameters())
|
|
@@ -125,7 +134,7 @@ def main():
|
|
| 125 |
model_folder = f'mlp_model_l{args.layer_count}w{args.width}'
|
| 126 |
os.makedirs(model_folder, exist_ok=True)
|
| 127 |
|
| 128 |
-
# Save the model
|
| 129 |
model_path = os.path.join(model_folder, 'model.pth')
|
| 130 |
torch.save(model.state_dict(), model_path)
|
| 131 |
|
|
|
|
| 23 |
return self.model(x)
|
| 24 |
|
| 25 |
# Train the model
|
| 26 |
+
def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_loss_path=None, save_model_dir=None):
|
| 27 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 28 |
model.to(device)
|
| 29 |
|
|
|
|
| 80 |
val_losses.append(avg_val_loss)
|
| 81 |
print(f'Validation Loss: {avg_val_loss}, Accuracy: {100 * correct / total}%')
|
| 82 |
|
| 83 |
+
# Save the model after each epoch
|
| 84 |
+
if save_model_dir:
|
| 85 |
+
model_path = os.path.join(save_model_dir, f'model_epoch_{epoch+1}.pth')
|
| 86 |
+
torch.save(model.state_dict(), model_path)
|
| 87 |
+
|
| 88 |
if save_loss_path:
|
| 89 |
with open(save_loss_path, 'w') as f:
|
| 90 |
for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
|
|
|
|
| 97 |
parser = argparse.ArgumentParser(description='Train an MLP on the zh-plus/tiny-imagenet dataset.')
|
| 98 |
parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
|
| 99 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
| 100 |
+
parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
|
| 101 |
args = parser.parse_args()
|
| 102 |
|
| 103 |
# Load the zh-plus/tiny-imagenet dataset
|
|
|
|
| 120 |
|
| 121 |
model = MLP(input_size, hidden_sizes, output_size)
|
| 122 |
|
| 123 |
+
# Create the directory to save models
|
| 124 |
+
os.makedirs(args.save_model_dir, exist_ok=True)
|
| 125 |
+
|
| 126 |
# Train the model and get the final loss
|
| 127 |
save_loss_path = 'losses.txt'
|
| 128 |
+
final_loss = train_model(model, train_dataset, val_dataset, save_loss_path=save_loss_path, save_model_dir=args.save_model_dir)
|
| 129 |
|
| 130 |
# Calculate the number of parameters
|
| 131 |
param_count = sum(p.numel() for p in model.parameters())
|
|
|
|
| 134 |
model_folder = f'mlp_model_l{args.layer_count}w{args.width}'
|
| 135 |
os.makedirs(model_folder, exist_ok=True)
|
| 136 |
|
| 137 |
+
# Save the final model
|
| 138 |
model_path = os.path.join(model_folder, 'model.pth')
|
| 139 |
torch.save(model.state_dict(), model_path)
|
| 140 |
|