Abdullah-Nazhat commited on
Commit
ae3c6b6
·
verified ·
1 Parent(s): 6ab54d2

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +2 -3
train.py CHANGED
@@ -67,7 +67,6 @@ class TensorMapperImageClassification(TensorMapper):
67
  in_channels=3,
68
  num_classes=10,
69
  d_model = 256,
70
-
71
  num_layers=4,
72
 
73
 
@@ -160,7 +159,7 @@ def test(dataloader, model, loss_fn):
160
 
161
 
162
 
163
- logname = "/home/abdullah/Desktop/Tensor_Mapper/Experiments_cifar10/logs_tensormapper/logs_cifar10.csv"
164
  if not os.path.exists(logname):
165
  with open(logname, 'w') as logfile:
166
  logwriter = csv.writer(logfile, delimiter=',')
@@ -181,7 +180,7 @@ print("Done!")
181
 
182
 
183
 
184
- path = "/home/abdullah/Desktop/Tensor_Mapper/Experiments_cifar10/weights_tensormapper"
185
  model_name = "TensorMapperImageClassification_cifar10"
186
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
187
  print(f"Saved Model State to {path}/{model_name}.pth ")
 
67
  in_channels=3,
68
  num_classes=10,
69
  d_model = 256,
 
70
  num_layers=4,
71
 
72
 
 
159
 
160
 
161
 
162
+ logname = "/PATH/Tensor_Mapper/Experiments_cifar10/logs_tensormapper/logs_cifar10.csv"
163
  if not os.path.exists(logname):
164
  with open(logname, 'w') as logfile:
165
  logwriter = csv.writer(logfile, delimiter=',')
 
180
 
181
 
182
 
183
+ path = "/PATH/Tensor_Mapper/Experiments_cifar10/weights_tensormapper"
184
  model_name = "TensorMapperImageClassification_cifar10"
185
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
186
  print(f"Saved Model State to {path}/{model_name}.pth ")