Spaces:
Sleeping
Sleeping
Atheer Aljuraib (k23108174)
committed on
Update Training.py
Browse files
fixed code to work with data loaders
- Training.py +12 -8
Training.py
CHANGED
|
@@ -26,6 +26,8 @@ def train_model(
|
|
| 26 |
num_classes : int = 39,
|
| 27 |
|
| 28 |
):
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# Move model to device
|
| 31 |
model.to(device)
|
|
@@ -54,18 +56,20 @@ def train_model(
|
|
| 54 |
|
| 55 |
|
| 56 |
|
| 57 |
-
|
| 58 |
# training loop
|
|
|
|
|
|
|
| 59 |
for epoch in range(n_epochs):
|
| 60 |
model.train()
|
| 61 |
train_accuracy_fn.reset()
|
| 62 |
|
| 63 |
# iterate over all the dataloader's mini-batches
|
| 64 |
-
for i,
|
| 65 |
|
| 66 |
# move to GPU memory
|
| 67 |
-
inputs =
|
| 68 |
-
labels =
|
| 69 |
|
| 70 |
# flatten if not cnn REVISE LATER
|
| 71 |
if flatten_input:
|
|
@@ -113,9 +117,9 @@ def train_model(
|
|
| 113 |
# The context 'torch.no_grad()' tells pytorch we are not interested in computing
|
| 114 |
# gradients here, so forward pass is more efficient
|
| 115 |
with torch.no_grad():
|
| 116 |
-
for i,
|
| 117 |
-
inputs =
|
| 118 |
-
labels =
|
| 119 |
|
| 120 |
# flatten if not cnn REVISE LATER
|
| 121 |
if flatten_input:
|
|
@@ -130,7 +134,7 @@ def train_model(
|
|
| 130 |
val_accuracies[epoch] = current_accuracy
|
| 131 |
|
| 132 |
|
| 133 |
-
|
| 134 |
if current_accuracy > best_accuracy:
|
| 135 |
best_accuracy = current_accuracy
|
| 136 |
torch.save(model.state_dict(), save_path)
|
|
|
|
| 26 |
num_classes : int = 39,
|
| 27 |
|
| 28 |
):
|
| 29 |
+
|
| 30 |
+
|
| 31 |
|
| 32 |
# Move model to device
|
| 33 |
model.to(device)
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
|
| 59 |
+
#----------------------
|
| 60 |
# training loop
|
| 61 |
+
#----------------------
|
| 62 |
+
|
| 63 |
for epoch in range(n_epochs):
|
| 64 |
model.train()
|
| 65 |
train_accuracy_fn.reset()
|
| 66 |
|
| 67 |
# iterate over all the dataloader's mini-batches
|
| 68 |
+
for i, batch in enumerate(train_loader):
|
| 69 |
|
| 70 |
# move to GPU memory
|
| 71 |
+
inputs = batch["image"].to(device)
|
| 72 |
+
labels = batch["label"].to(device)
|
| 73 |
|
| 74 |
# flatten if not cnn REVISE LATER
|
| 75 |
if flatten_input:
|
|
|
|
| 117 |
# The context 'torch.no_grad()' tells pytorch we are not interested in computing
|
| 118 |
# gradients here, so forward pass is more efficient
|
| 119 |
with torch.no_grad():
|
| 120 |
+
for i, batch in enumerate(val_loader):
|
| 121 |
+
inputs = batch["image"].to(device)
|
| 122 |
+
labels = batch["label"].to(device)
|
| 123 |
|
| 124 |
# flatten if not cnn REVISE LATER
|
| 125 |
if flatten_input:
|
|
|
|
| 134 |
val_accuracies[epoch] = current_accuracy
|
| 135 |
|
| 136 |
|
| 137 |
+
# keep track of best validation accuracy and save best model so far
|
| 138 |
if current_accuracy > best_accuracy:
|
| 139 |
best_accuracy = current_accuracy
|
| 140 |
torch.save(model.state_dict(), save_path)
|