Atheer Aljuraib (k23108174) commited on
Commit
f61bf9a
·
unverified ·
1 Parent(s): 0174d4f

Update Training.py

Browse files

fixed code to work with data loaders

Files changed (1) hide show
  1. Training.py +12 -8
Training.py CHANGED
@@ -26,6 +26,8 @@ def train_model(
26
  num_classes : int = 39,
27
 
28
  ):
 
 
29
 
30
  # Move model to device
31
  model.to(device)
@@ -54,18 +56,20 @@ def train_model(
54
 
55
 
56
 
57
-
58
  # training loop
 
 
59
  for epoch in range(n_epochs):
60
  model.train()
61
  train_accuracy_fn.reset()
62
 
63
  # iterate over all the dataloader's mini-batches
64
- for i, (inputs, labels) in enumerate(train_loader):
65
 
66
  # move to GPU memory
67
- inputs = inputs.to(device)
68
- labels = labels.to(device)
69
 
70
  # flatten if not cnn REVISE LATER
71
  if flatten_input:
@@ -113,9 +117,9 @@ def train_model(
113
  # The context 'torch.no_grad()' tells pytorch we are not interested in computing
114
  # gradients here, so forward pass is more efficient
115
  with torch.no_grad():
116
- for i, (inputs, labels) in enumerate(val_loader):
117
- inputs = inputs.to(device)
118
- labels = labels.to(device)
119
 
120
  # flatten if not cnn REVISE LATER
121
  if flatten_input:
@@ -130,7 +134,7 @@ def train_model(
130
  val_accuracies[epoch] = current_accuracy
131
 
132
 
133
- # keep track of best validation accuracy and save best model so far
134
  if current_accuracy > best_accuracy:
135
  best_accuracy = current_accuracy
136
  torch.save(model.state_dict(), save_path)
 
26
  num_classes : int = 39,
27
 
28
  ):
29
+
30
+
31
 
32
  # Move model to device
33
  model.to(device)
 
56
 
57
 
58
 
59
+ #----------------------
60
  # training loop
61
+ #----------------------
62
+
63
  for epoch in range(n_epochs):
64
  model.train()
65
  train_accuracy_fn.reset()
66
 
67
  # iterate over all the dataloader's mini-batches
68
+ for i, batch in enumerate(train_loader):
69
 
70
  # move to GPU memory
71
+ inputs = batch["image"].to(device)
72
+ labels = batch["label"].to(device)
73
 
74
  # flatten if not cnn REVISE LATER
75
  if flatten_input:
 
117
  # The context 'torch.no_grad()' tells pytorch we are not interested in computing
118
  # gradients here, so forward pass is more efficient
119
  with torch.no_grad():
120
+ for i, batch in enumerate(val_loader):
121
+ inputs = batch["image"].to(device)
122
+ labels = batch["label"].to(device)
123
 
124
  # flatten if not cnn REVISE LATER
125
  if flatten_input:
 
134
  val_accuracies[epoch] = current_accuracy
135
 
136
 
137
+ # keep track of best validation accuracy and save best model so far
138
  if current_accuracy > best_accuracy:
139
  best_accuracy = current_accuracy
140
  torch.save(model.state_dict(), save_path)