Abdullah-Nazhat commited on
Commit
007d4b5
·
verified ·
1 Parent(s): c9ee00c

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +6 -6
train.py CHANGED
@@ -7,7 +7,7 @@ from torch import nn
7
  from torch.utils.data import DataLoader
8
  from torchvision import datasets
9
  from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
10
- from cobra import COBRA
11
  # data transforms
12
 
13
  transform = Compose([
@@ -63,7 +63,7 @@ print(f"using {device} device")
63
 
64
  # model definition
65
 
66
- class COBRAImageClassification(COBRA):
67
  def __init__(
68
  self,
69
  image_size=32,
@@ -93,7 +93,7 @@ class COBRAImageClassification(COBRA):
93
  out = self.classifier(embedding)
94
  return out
95
 
96
- model = COBRAImageClassification().to(device)
97
  print(model)
98
 
99
  # Optimizer
@@ -163,7 +163,7 @@ def test(dataloader, model, loss_fn):
163
 
164
  # apply train and test
165
 
166
- logname = "/PATH/COBRA/Experiments_cifar10/logs_cobra/logs_cifar10.csv"
167
  if not os.path.exists(logname):
168
  with open(logname, 'w') as logfile:
169
  logwriter = csv.writer(logfile, delimiter=',')
@@ -185,8 +185,8 @@ print("Done!")
185
 
186
  # saving trained model
187
 
188
- path = "/PATH/COBRA/Experiments_cifar10/weights_cobra"
189
- model_name = "COBRAImageClassification_cifar10"
190
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
191
  print(f"Saved Model State to {path}/{model_name}.pth ")
192
 
 
7
  from torch.utils.data import DataLoader
8
  from torchvision import datasets
9
  from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
10
+ from mcdpmamba import MCDPMAMBA
11
  # data transforms
12
 
13
  transform = Compose([
 
63
 
64
  # model definition
65
 
66
+ class MCDPMAMBAImageClassification(MCDPMAMBA):
67
  def __init__(
68
  self,
69
  image_size=32,
 
93
  out = self.classifier(embedding)
94
  return out
95
 
96
+ model = MCDPMAMBAImageClassification().to(device)
97
  print(model)
98
 
99
  # Optimizer
 
163
 
164
  # apply train and test
165
 
166
+ logname = "/PATH/MCDP_MAMBA/Experiments_cifar10/logs_mcdpmamba/logs_cifar10.csv"
167
  if not os.path.exists(logname):
168
  with open(logname, 'w') as logfile:
169
  logwriter = csv.writer(logfile, delimiter=',')
 
185
 
186
  # saving trained model
187
 
188
+ path = "/PATH/MCDP_MAMBA/Experiments_cifar10/weights_mcdpmamba"
189
+ model_name = "MCDPMAMBAImageClassification_cifar10"
190
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
191
  print(f"Saved Model State to {path}/{model_name}.pth ")
192