Abdullah-Nazhat commited on
Commit
2e7b6d1
·
verified ·
1 Parent(s): a07b801

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -5
train.py CHANGED
@@ -163,7 +163,7 @@ def test(dataloader, model, loss_fn):
163
 
164
  # apply train and test
165
 
166
- logname = "/home/abdullah/Desktop/Proposals_experiments/COBRA/Experiments_cifar10/logs_cobra/logs_cifar10.csv"
167
  if not os.path.exists(logname):
168
  with open(logname, 'w') as logfile:
169
  logwriter = csv.writer(logfile, delimiter=',')
@@ -175,9 +175,7 @@ epochs = 100
175
  for epoch in range(epochs):
176
  print(f"Epoch {epoch+1}\n-----------------------------------")
177
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
178
- # learning rate scheduler
179
- #if scheduler is not None:
180
- # scheduler.step()
181
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
182
  with open(logname, 'a') as logfile:
183
  logwriter = csv.writer(logfile, delimiter=',')
@@ -187,7 +185,7 @@ print("Done!")
187
 
188
  # saving trained model
189
 
190
- path = "/home/abdullah/Desktop/Proposals_experiments/COBRA/Experiments_cifar10/weights_cobra"
191
  model_name = "COBRAImageClassification_cifar10"
192
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
193
  print(f"Saved Model State to {path}/{model_name}.pth ")
 
163
 
164
  # apply train and test
165
 
166
+ logname = "/PATH/COBRA/Experiments_cifar10/logs_cobra/logs_cifar10.csv"
167
  if not os.path.exists(logname):
168
  with open(logname, 'w') as logfile:
169
  logwriter = csv.writer(logfile, delimiter=',')
 
175
  for epoch in range(epochs):
176
  print(f"Epoch {epoch+1}\n-----------------------------------")
177
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
178
+
 
 
179
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
180
  with open(logname, 'a') as logfile:
181
  logwriter = csv.writer(logfile, delimiter=',')
 
185
 
186
  # saving trained model
187
 
188
+ path = "/PATH/COBRA/Experiments_cifar10/weights_cobra"
189
  model_name = "COBRAImageClassification_cifar10"
190
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
191
  print(f"Saved Model State to {path}/{model_name}.pth ")