Update train.py
Browse files
train.py
CHANGED
|
@@ -160,7 +160,7 @@ def test(dataloader, model, loss_fn):
|
|
| 160 |
|
| 161 |
|
| 162 |
|
| 163 |
-
logname = "/
|
| 164 |
if not os.path.exists(logname):
|
| 165 |
with open(logname, 'w') as logfile:
|
| 166 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
@@ -181,7 +181,7 @@ print("Done!")
|
|
| 181 |
|
| 182 |
|
| 183 |
|
| 184 |
-
path = "/
|
| 185 |
model_name = "InteractorImageClassification_cifar10"
|
| 186 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 187 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
|
|
|
| 160 |
|
| 161 |
|
| 162 |
|
| 163 |
+
logname = "/PATH/Interactor/Experiments_cifar10/logs_interactor/logs_cifar10.csv"
|
| 164 |
if not os.path.exists(logname):
|
| 165 |
with open(logname, 'w') as logfile:
|
| 166 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
|
| 181 |
|
| 182 |
|
| 183 |
|
| 184 |
+
path = "/PATH/Interactor/Experiments_cifar10/weights_interactor"
|
| 185 |
model_name = "InteractorImageClassification_cifar10"
|
| 186 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 187 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|