Update train_mlp.py
Browse files- train_mlp.py +8 -8
train_mlp.py
CHANGED
|
@@ -38,9 +38,9 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
|
|
| 38 |
running_loss = 0.0
|
| 39 |
for example in train_dataset:
|
| 40 |
img = example['image']
|
| 41 |
-
img = np.array(img) # Convert PIL image to NumPy array
|
| 42 |
-
img = img.
|
| 43 |
-
img = torch.from_numpy(img).float().
|
| 44 |
label = torch.tensor([example['label']]).to(device)
|
| 45 |
|
| 46 |
optimizer.zero_grad()
|
|
@@ -63,9 +63,9 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
|
|
| 63 |
with torch.no_grad():
|
| 64 |
for example in val_dataset:
|
| 65 |
img = example['image']
|
| 66 |
-
img = np.array(img) # Convert PIL image to NumPy array
|
| 67 |
-
img = img.
|
| 68 |
-
img = torch.from_numpy(img).float().
|
| 69 |
label = torch.tensor([example['label']]).to(device)
|
| 70 |
|
| 71 |
outputs = model(img)
|
|
@@ -99,7 +99,7 @@ def main():
|
|
| 99 |
|
| 100 |
# Split the dataset into train and validation sets
|
| 101 |
train_dataset = dataset['train']
|
| 102 |
-
val_dataset = dataset['
|
| 103 |
|
| 104 |
# Determine the number of classes
|
| 105 |
num_classes = len(set(train_dataset['label']))
|
|
@@ -108,7 +108,7 @@ def main():
|
|
| 108 |
image_size = 64 # Assuming the images are square
|
| 109 |
|
| 110 |
# Define the model
|
| 111 |
-
input_size = image_size * image_size
|
| 112 |
hidden_sizes = [args.width] * args.layer_count
|
| 113 |
output_size = num_classes
|
| 114 |
|
|
|
|
| 38 |
running_loss = 0.0
|
| 39 |
for example in train_dataset:
|
| 40 |
img = example['image']
|
| 41 |
+
img = np.array(img.convert('L')) # Convert PIL image to grayscale NumPy array
|
| 42 |
+
img = img.reshape(1, -1) # Flatten the image
|
| 43 |
+
img = torch.from_numpy(img).float().to(device) # Convert to tensor
|
| 44 |
label = torch.tensor([example['label']]).to(device)
|
| 45 |
|
| 46 |
optimizer.zero_grad()
|
|
|
|
| 63 |
with torch.no_grad():
|
| 64 |
for example in val_dataset:
|
| 65 |
img = example['image']
|
| 66 |
+
img = np.array(img.convert('L')) # Convert PIL image to grayscale NumPy array
|
| 67 |
+
img = img.reshape(1, -1) # Flatten the image
|
| 68 |
+
img = torch.from_numpy(img).float().to(device) # Convert to tensor
|
| 69 |
label = torch.tensor([example['label']]).to(device)
|
| 70 |
|
| 71 |
outputs = model(img)
|
|
|
|
| 99 |
|
| 100 |
# Split the dataset into train and validation sets
|
| 101 |
train_dataset = dataset['train']
|
| 102 |
+
val_dataset = dataset['validation'] # Assuming 'validation' is the correct key
|
| 103 |
|
| 104 |
# Determine the number of classes
|
| 105 |
num_classes = len(set(train_dataset['label']))
|
|
|
|
| 108 |
image_size = 64 # Assuming the images are square
|
| 109 |
|
| 110 |
# Define the model
|
| 111 |
+
input_size = image_size * image_size # Since images are grayscale
|
| 112 |
hidden_sizes = [args.width] * args.layer_count
|
| 113 |
output_size = num_classes
|
| 114 |
|