TeacherPuffy committed on
Commit
d928b03
·
verified ·
1 Parent(s): e7059c0

Update train_mlp.py

Browse files
Files changed (1) hide show
  1. train_mlp.py +22 -15
train_mlp.py CHANGED
@@ -3,7 +3,8 @@ import os
3
  import torch
4
  import torch.nn as nn
5
  import torch.optim as optim
6
- from datasets import load_from_disk
 
7
 
8
  # Define the MLP model
9
  class MLP(nn.Module):
@@ -20,12 +21,6 @@ class MLP(nn.Module):
20
  def forward(self, x):
21
  return self.model(x)
22
 
23
- # Custom collate function
24
- def custom_collate(batch):
25
- images = torch.stack([item['image'] for item in batch])
26
- labels = torch.tensor([item['label'] for item in batch])
27
- return {'image': images, 'label': labels}
28
-
29
  # Train the model
30
  def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_path=None):
31
  criterion = nn.CrossEntropyLoss()
@@ -84,20 +79,23 @@ def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_
84
 
85
  # Main function
86
  def main():
87
- parser = argparse.ArgumentParser(description='Train an MLP on a Hugging Face dataset with JPEG images and class labels.')
88
  parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
89
  parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
90
  args = parser.parse_args()
91
 
92
- # Load the preprocessed datasets
93
- train_dataset = load_from_disk('preprocessed_train_dataset')
94
- val_dataset = load_from_disk('preprocessed_val_dataset')
 
 
 
95
 
96
  # Determine the number of classes
97
  num_classes = len(set(train_dataset['label']))
98
 
99
  # Determine the fixed resolution of the images
100
- image_size = train_dataset[0]['image'].size(1) # Assuming the images are square
101
 
102
  # Define the model
103
  input_size = image_size * image_size * 3
@@ -106,9 +104,18 @@ def main():
106
 
107
  model = MLP(input_size, hidden_sizes, output_size)
108
 
109
- # Create data loaders with custom collate function
110
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
111
- val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate)
 
 
 
 
 
 
 
 
 
112
 
113
  # Train the model and get the final loss
114
  save_loss_path = 'losses.txt'
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.optim as optim
6
+ from datasets import load_dataset
7
+ from torchvision import transforms
8
 
9
  # Define the MLP model
10
  class MLP(nn.Module):
 
21
  def forward(self, x):
22
  return self.model(x)
23
 
 
 
 
 
 
 
24
  # Train the model
25
  def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_path=None):
26
  criterion = nn.CrossEntropyLoss()
 
79
 
80
  # Main function
81
  def main():
82
+ parser = argparse.ArgumentParser(description='Train an MLP on the zh-plus/tiny-imagenet dataset.')
83
  parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
84
  parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
85
  args = parser.parse_args()
86
 
87
+ # Load the zh-plus/tiny-imagenet dataset
88
+ dataset = load_dataset('zh-plus/tiny-imagenet')
89
+
90
+ # Split the dataset into train and validation sets
91
+ train_dataset = dataset['train']
92
+ val_dataset = dataset['valid']
93
 
94
  # Determine the number of classes
95
  num_classes = len(set(train_dataset['label']))
96
 
97
  # Determine the fixed resolution of the images
98
+ image_size = 64 # Assuming the images are square
99
 
100
  # Define the model
101
  input_size = image_size * image_size * 3
 
104
 
105
  model = MLP(input_size, hidden_sizes, output_size)
106
 
107
+ # Define the transformation to convert PIL images to tensors
108
+ transform = transforms.Compose([
109
+ transforms.ToTensor(),
110
+ ])
111
+
112
+ # Apply the transformation to the datasets
113
+ train_dataset.set_transform(lambda example: {'image': transform(example['image']), 'label': example['label']})
114
+ val_dataset.set_transform(lambda example: {'image': transform(example['image']), 'label': example['label']})
115
+
116
+ # Create data loaders
117
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
118
+ val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
119
 
120
  # Train the model and get the final loss
121
  save_loss_path = 'losses.txt'