TeacherPuffy commited on
Commit
116af7a
·
verified ·
1 Parent(s): 6258b13

Update train_mlp.py

Browse files
Files changed (1) hide show
  1. train_mlp.py +7 -4
train_mlp.py CHANGED
@@ -24,6 +24,9 @@ class MLP(nn.Module):
24
 
25
  # Train the model
26
  def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_loss_path=None):
 
 
 
27
  criterion = nn.CrossEntropyLoss()
28
  optimizer = optim.Adam(model.parameters(), lr=lr)
29
 
@@ -35,8 +38,8 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
35
  running_loss = 0.0
36
  for example in train_dataset:
37
  img = np.array(example['image'])
38
- img = torch.from_numpy(img).float().view(1, -1)
39
- label = torch.tensor([example['label']])
40
 
41
  optimizer.zero_grad()
42
  outputs = model(img)
@@ -58,8 +61,8 @@ def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_los
58
  with torch.no_grad():
59
  for example in val_dataset:
60
  img = np.array(example['image'])
61
- img = torch.from_numpy(img).float().view(1, -1)
62
- label = torch.tensor([example['label']])
63
 
64
  outputs = model(img)
65
  loss = criterion(outputs, label)
 
24
 
25
  # Train the model
26
  def train_model(model, train_dataset, val_dataset, epochs=10, lr=0.001, save_loss_path=None):
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ model.to(device)
29
+
30
  criterion = nn.CrossEntropyLoss()
31
  optimizer = optim.Adam(model.parameters(), lr=lr)
32
 
 
38
  running_loss = 0.0
39
  for example in train_dataset:
40
  img = np.array(example['image'])
41
+ img = torch.from_numpy(img).float().view(1, -1).to(device)
42
+ label = torch.tensor([example['label']]).to(device)
43
 
44
  optimizer.zero_grad()
45
  outputs = model(img)
 
61
  with torch.no_grad():
62
  for example in val_dataset:
63
  img = np.array(example['image'])
64
+ img = torch.from_numpy(img).float().view(1, -1).to(device)
65
+ label = torch.tensor([example['label']]).to(device)
66
 
67
  outputs = model(img)
68
  loss = criterion(outputs, label)