Update train.py
Browse files
train.py
CHANGED
|
@@ -7,7 +7,7 @@ from torch import nn
|
|
| 7 |
from torch.utils.data import DataLoader
|
| 8 |
from torchvision import datasets
|
| 9 |
from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
|
| 10 |
-
from
|
| 11 |
# data transforms
|
| 12 |
|
| 13 |
transform = Compose([
|
|
@@ -63,7 +63,7 @@ print(f"using {device} device")
|
|
| 63 |
|
| 64 |
# model definition
|
| 65 |
|
| 66 |
-
class
|
| 67 |
def __init__(
|
| 68 |
self,
|
| 69 |
image_size=32,
|
|
@@ -93,7 +93,7 @@ class COBRAImageClassification(COBRA):
|
|
| 93 |
out = self.classifier(embedding)
|
| 94 |
return out
|
| 95 |
|
| 96 |
-
model =
|
| 97 |
print(model)
|
| 98 |
|
| 99 |
# Optimizer
|
|
@@ -163,7 +163,7 @@ def test(dataloader, model, loss_fn):
|
|
| 163 |
|
| 164 |
# apply train and test
|
| 165 |
|
| 166 |
-
logname = "/PATH/
|
| 167 |
if not os.path.exists(logname):
|
| 168 |
with open(logname, 'w') as logfile:
|
| 169 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
@@ -185,8 +185,8 @@ print("Done!")
|
|
| 185 |
|
| 186 |
# saving trained model
|
| 187 |
|
| 188 |
-
path = "/PATH/
|
| 189 |
-
model_name = "
|
| 190 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 191 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
| 192 |
|
|
|
|
| 7 |
from torch.utils.data import DataLoader
|
| 8 |
from torchvision import datasets
|
| 9 |
from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
|
| 10 |
+
from mcdpmamba import MCDPMAMBA
|
| 11 |
# data transforms
|
| 12 |
|
| 13 |
transform = Compose([
|
|
|
|
| 63 |
|
| 64 |
# model definition
|
| 65 |
|
| 66 |
+
class MCDPMAMBAImageClassification(MCDPMAMBA):
|
| 67 |
def __init__(
|
| 68 |
self,
|
| 69 |
image_size=32,
|
|
|
|
| 93 |
out = self.classifier(embedding)
|
| 94 |
return out
|
| 95 |
|
| 96 |
+
model = MCDPMAMBAImageClassification().to(device)
|
| 97 |
print(model)
|
| 98 |
|
| 99 |
# Optimizer
|
|
|
|
| 163 |
|
| 164 |
# apply train and test
|
| 165 |
|
| 166 |
+
logname = "/PATH/MCDP_MAMBA/Experiments_cifar10/logs_mcdpmamba/logs_cifar10.csv"
|
| 167 |
if not os.path.exists(logname):
|
| 168 |
with open(logname, 'w') as logfile:
|
| 169 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
|
| 185 |
|
| 186 |
# saving trained model
|
| 187 |
|
| 188 |
+
path = "/PATH/MCDP_MAMBA/Experiments_cifar10/weights_mcdpmamba"
|
| 189 |
+
model_name = "MCDPMAMBAImageClassification_cifar10"
|
| 190 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
| 191 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
| 192 |
|