Update train.py
Browse files
train.py
CHANGED
|
@@ -5,7 +5,7 @@ from torch import nn
|
|
| 5 |
from torch.utils.data import DataLoader
|
| 6 |
from torchvision import datasets
|
| 7 |
from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
|
| 8 |
-
|
| 9 |
|
| 10 |
|
| 11 |
|
|
@@ -159,7 +159,7 @@ def test(dataloader, model, loss_fn):
|
|
| 159 |
|
| 160 |
|
| 161 |
|
| 162 |
-
logname = "/
|
| 163 |
if not os.path.exists(logname):
|
| 164 |
with open(logname, 'w') as logfile:
|
| 165 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
@@ -180,7 +180,7 @@ print("Done!")
|
|
| 180 |
|
| 181 |
|
| 182 |
|
| 183 |
-
path = "/
|
| 184 |
model_name = "LiteTensorMapperImageClassification_cifar10"
|
| 185 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 186 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
|
|
|
| 5 |
from torch.utils.data import DataLoader
|
| 6 |
from torchvision import datasets
|
| 7 |
from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
|
| 8 |
+
from litetesnormapper import LiteTensorMapper
|
| 9 |
|
| 10 |
|
| 11 |
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
|
| 162 |
+
logname = "/PATH/Experiments_cifar10/logs_litetensormapper/logs_cifar10.csv"
|
| 163 |
if not os.path.exists(logname):
|
| 164 |
with open(logname, 'w') as logfile:
|
| 165 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
|
| 183 |
+
path = "/PATH/Experiments_cifar10/weights_litetensormapper"
|
| 184 |
model_name = "LiteTensorMapperImageClassification_cifar10"
|
| 185 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 186 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|