Update train.py
Browse files
train.py
CHANGED
|
@@ -67,7 +67,6 @@ class TensorMapperImageClassification(TensorMapper):
|
|
| 67 |
in_channels=3,
|
| 68 |
num_classes=10,
|
| 69 |
d_model = 256,
|
| 70 |
-
|
| 71 |
num_layers=4,
|
| 72 |
|
| 73 |
|
|
@@ -160,7 +159,7 @@ def test(dataloader, model, loss_fn):
|
|
| 160 |
|
| 161 |
|
| 162 |
|
| 163 |
-
logname = "/
|
| 164 |
if not os.path.exists(logname):
|
| 165 |
with open(logname, 'w') as logfile:
|
| 166 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
@@ -181,7 +180,7 @@ print("Done!")
|
|
| 181 |
|
| 182 |
|
| 183 |
|
| 184 |
-
path = "/
|
| 185 |
model_name = "TensorMapperImageClassification_cifar10"
|
| 186 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 187 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
|
|
|
| 67 |
in_channels=3,
|
| 68 |
num_classes=10,
|
| 69 |
d_model = 256,
|
|
|
|
| 70 |
num_layers=4,
|
| 71 |
|
| 72 |
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
|
| 162 |
+
logname = "/PATH/Tensor_Mapper/Experiments_cifar10/logs_tensormapper/logs_cifar10.csv"
|
| 163 |
if not os.path.exists(logname):
|
| 164 |
with open(logname, 'w') as logfile:
|
| 165 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
|
| 183 |
+
path = "/PATH/Tensor_Mapper/Experiments_cifar10/weights_tensormapper"
|
| 184 |
model_name = "TensorMapperImageClassification_cifar10"
|
| 185 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 186 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|