Update train_mlp.py
Browse files- train_mlp.py +22 -15
train_mlp.py
CHANGED
|
@@ -3,7 +3,8 @@ import os
|
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.optim as optim
|
| 6 |
-
from datasets import
|
|
|
|
| 7 |
|
| 8 |
# Define the MLP model
|
| 9 |
class MLP(nn.Module):
|
|
@@ -20,12 +21,6 @@ class MLP(nn.Module):
|
|
| 20 |
def forward(self, x):
|
| 21 |
return self.model(x)
|
| 22 |
|
| 23 |
-
# Custom collate function
|
| 24 |
-
def custom_collate(batch):
|
| 25 |
-
images = torch.stack([item['image'] for item in batch])
|
| 26 |
-
labels = torch.tensor([item['label'] for item in batch])
|
| 27 |
-
return {'image': images, 'label': labels}
|
| 28 |
-
|
| 29 |
# Train the model
|
| 30 |
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_path=None):
|
| 31 |
criterion = nn.CrossEntropyLoss()
|
|
@@ -84,20 +79,23 @@ def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_
|
|
| 84 |
|
| 85 |
# Main function
|
| 86 |
def main():
|
| 87 |
-
parser = argparse.ArgumentParser(description='Train an MLP on
|
| 88 |
parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
|
| 89 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
| 90 |
args = parser.parse_args()
|
| 91 |
|
| 92 |
-
# Load the
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
# Determine the number of classes
|
| 97 |
num_classes = len(set(train_dataset['label']))
|
| 98 |
|
| 99 |
# Determine the fixed resolution of the images
|
| 100 |
-
image_size =
|
| 101 |
|
| 102 |
# Define the model
|
| 103 |
input_size = image_size * image_size * 3
|
|
@@ -106,9 +104,18 @@ def main():
|
|
| 106 |
|
| 107 |
model = MLP(input_size, hidden_sizes, output_size)
|
| 108 |
|
| 109 |
-
#
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Train the model and get the final loss
|
| 114 |
save_loss_path = 'losses.txt'
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.optim as optim
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from torchvision import transforms
|
| 8 |
|
| 9 |
# Define the MLP model
|
| 10 |
class MLP(nn.Module):
|
|
|
|
| 21 |
def forward(self, x):
|
| 22 |
return self.model(x)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
# Train the model
|
| 25 |
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_path=None):
|
| 26 |
criterion = nn.CrossEntropyLoss()
|
|
|
|
| 79 |
|
| 80 |
# Main function
|
| 81 |
def main():
|
| 82 |
+
parser = argparse.ArgumentParser(description='Train an MLP on the zh-plus/tiny-imagenet dataset.')
|
| 83 |
parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
|
| 84 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
| 85 |
args = parser.parse_args()
|
| 86 |
|
| 87 |
+
# Load the zh-plus/tiny-imagenet dataset
|
| 88 |
+
dataset = load_dataset('zh-plus/tiny-imagenet')
|
| 89 |
+
|
| 90 |
+
# Split the dataset into train and validation sets
|
| 91 |
+
train_dataset = dataset['train']
|
| 92 |
+
val_dataset = dataset['valid']
|
| 93 |
|
| 94 |
# Determine the number of classes
|
| 95 |
num_classes = len(set(train_dataset['label']))
|
| 96 |
|
| 97 |
# Determine the fixed resolution of the images
|
| 98 |
+
image_size = 64 # Assuming the images are square
|
| 99 |
|
| 100 |
# Define the model
|
| 101 |
input_size = image_size * image_size * 3
|
|
|
|
| 104 |
|
| 105 |
model = MLP(input_size, hidden_sizes, output_size)
|
| 106 |
|
| 107 |
+
# Define the transformation to convert PIL images to tensors
|
| 108 |
+
transform = transforms.Compose([
|
| 109 |
+
transforms.ToTensor(),
|
| 110 |
+
])
|
| 111 |
+
|
| 112 |
+
# Apply the transformation to the datasets
|
| 113 |
+
train_dataset.set_transform(lambda example: {'image': transform(example['image']), 'label': example['label']})
|
| 114 |
+
val_dataset.set_transform(lambda example: {'image': transform(example['image']), 'label': example['label']})
|
| 115 |
+
|
| 116 |
+
# Create data loaders
|
| 117 |
+
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
|
| 118 |
+
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
|
| 119 |
|
| 120 |
# Train the model and get the final loss
|
| 121 |
save_loss_path = 'losses.txt'
|