TeacherPuffy commited on
Commit
e460563
·
verified ·
1 Parent(s): cc52ed8

Update train_mlp.py

Browse files
Files changed (1) hide show
  1. train_mlp.py +8 -8
train_mlp.py CHANGED
@@ -38,9 +38,9 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
38
  running_loss = 0.0
39
  for example in train_dataset:
40
  img = example['image']
41
- img = np.array(img) # Convert PIL image to NumPy array
42
- img = img.transpose((2, 0, 1)) # Transpose to (channels, height, width)
43
- img = torch.from_numpy(img).float().reshape(1, -1).to(device) # Convert to tensor and reshape
44
  label = torch.tensor([example['label']]).to(device)
45
 
46
  optimizer.zero_grad()
@@ -63,9 +63,9 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
63
  with torch.no_grad():
64
  for example in val_dataset:
65
  img = example['image']
66
- img = np.array(img) # Convert PIL image to NumPy array
67
- img = img.transpose((2, 0, 1)) # Transpose to (channels, height, width)
68
- img = torch.from_numpy(img).float().reshape(1, -1).to(device) # Convert to tensor and reshape
69
  label = torch.tensor([example['label']]).to(device)
70
 
71
  outputs = model(img)
@@ -99,7 +99,7 @@ def main():
99
 
100
  # Split the dataset into train and validation sets
101
  train_dataset = dataset['train']
102
- val_dataset = dataset['valid']
103
 
104
  # Determine the number of classes
105
  num_classes = len(set(train_dataset['label']))
@@ -108,7 +108,7 @@ def main():
108
  image_size = 64 # Assuming the images are square
109
 
110
  # Define the model
111
- input_size = image_size * image_size * 3
112
  hidden_sizes = [args.width] * args.layer_count
113
  output_size = num_classes
114
 
 
38
  running_loss = 0.0
39
  for example in train_dataset:
40
  img = example['image']
41
+ img = np.array(img.convert('L')) # Convert PIL image to grayscale NumPy array
42
+ img = img.reshape(1, -1) # Flatten the image
43
+ img = torch.from_numpy(img).float().to(device) # Convert to tensor
44
  label = torch.tensor([example['label']]).to(device)
45
 
46
  optimizer.zero_grad()
 
63
  with torch.no_grad():
64
  for example in val_dataset:
65
  img = example['image']
66
+ img = np.array(img.convert('L')) # Convert PIL image to grayscale NumPy array
67
+ img = img.reshape(1, -1) # Flatten the image
68
+ img = torch.from_numpy(img).float().to(device) # Convert to tensor
69
  label = torch.tensor([example['label']]).to(device)
70
 
71
  outputs = model(img)
 
99
 
100
  # Split the dataset into train and validation sets
101
  train_dataset = dataset['train']
102
+ val_dataset = dataset['validation'] # Assuming 'validation' is the correct key
103
 
104
  # Determine the number of classes
105
  num_classes = len(set(train_dataset['label']))
 
108
  image_size = 64 # Assuming the images are square
109
 
110
  # Define the model
111
+ input_size = image_size * image_size # Since images are grayscale
112
  hidden_sizes = [args.width] * args.layer_count
113
  output_size = num_classes
114