TeacherPuffy commited on
Commit
cc52ed8
·
verified ·
1 Parent(s): 116af7a

Update train_mlp.py

Browse files
Files changed (1) hide show
  1. train_mlp.py +8 -4
train_mlp.py CHANGED
@@ -37,8 +37,10 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
37
  model.train()
38
  running_loss = 0.0
39
  for example in train_dataset:
40
- img = np.array(example['image'])
41
- img = torch.from_numpy(img).float().view(1, -1).to(device)
 
 
42
  label = torch.tensor([example['label']]).to(device)
43
 
44
  optimizer.zero_grad()
@@ -60,8 +62,10 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
60
  total = 0
61
  with torch.no_grad():
62
  for example in val_dataset:
63
- img = np.array(example['image'])
64
- img = torch.from_numpy(img).float().view(1, -1).to(device)
 
 
65
  label = torch.tensor([example['label']]).to(device)
66
 
67
  outputs = model(img)
 
37
  model.train()
38
  running_loss = 0.0
39
  for example in train_dataset:
40
+ img = example['image']
41
+ img = np.array(img) # Convert PIL image to NumPy array
42
+ img = img.transpose((2, 0, 1)) # Transpose to (channels, height, width)
43
+ img = torch.from_numpy(img).float().reshape(1, -1).to(device) # Convert to tensor and reshape
44
  label = torch.tensor([example['label']]).to(device)
45
 
46
  optimizer.zero_grad()
 
62
  total = 0
63
  with torch.no_grad():
64
  for example in val_dataset:
65
+ img = example['image']
66
+ img = np.array(img) # Convert PIL image to NumPy array
67
+ img = img.transpose((2, 0, 1)) # Transpose to (channels, height, width)
68
+ img = torch.from_numpy(img).float().reshape(1, -1).to(device) # Convert to tensor and reshape
69
  label = torch.tensor([example['label']]).to(device)
70
 
71
  outputs = model(img)