Abdullah-Nazhat commited on
Commit
02856ff
·
verified ·
1 Parent(s): 613daed

Delete train_mlp_nin.py

Browse files
Files changed (1) hide show
  1. train_mlp_nin.py +0 -194
train_mlp_nin.py DELETED
@@ -1,194 +0,0 @@
1
- #imports
2
-
3
- import os
4
- import csv
5
- import torch
6
- from torch import nn
7
- from torch.utils.data import DataLoader
8
- from torchvision import datasets
9
- from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
10
- from contextualizer_mlp_nin import ContextualizerNiN
11
-
12
- # data transforms
13
-
14
- transform = Compose([
15
- RandomCrop(32, padding=4),
16
- RandomHorizontalFlip(),
17
- ToTensor(),
18
- Normalize((0.5, 0.5,0.5),(0.5, 0.5,0.5))
19
-
20
- ])
21
-
22
- training_data = datasets.CIFAR10(
23
- root='data',
24
- train=True,
25
- download=True,
26
- transform=transform
27
- )
28
-
29
- test_data = datasets.CIFAR10(
30
- root='data',
31
- train=False,
32
- download=True,
33
- transform=transform
34
- )
35
- # create dataloaders
36
-
37
- batch_size = 128
38
-
39
- train_dataloader = DataLoader(training_data, batch_size=batch_size,shuffle=True)
40
- test_dataloader = DataLoader(test_data, batch_size=batch_size)
41
-
42
-
43
- for X, y in test_dataloader:
44
- print(f"Shape of X [N,C,H,W]:{X.shape}")
45
- print(f"Shape of y:{y.shape}{y.dtype}")
46
- break
47
-
48
- # size checking for loading images
49
- def check_sizes(image_size, patch_size):
50
- sqrt_num_patches, remainder = divmod(image_size, patch_size)
51
- assert remainder == 0, "`image_size` must be divisibe by `patch_size`"
52
- num_patches = sqrt_num_patches ** 2
53
- return num_patches
54
-
55
-
56
-
57
- # create model
58
- # Get cpu or gpu device for training.
59
- device = "cuda" if torch.cuda.is_available() else "cpu"
60
-
61
- print(f"using {device} device")
62
-
63
- # model definition
64
-
65
- class ContextualizerNiNImageClassification(ContextualizerNiN):
66
- def __init__(
67
- self,
68
- image_size=32,
69
- patch_size=4,
70
- in_channels=3,
71
- num_classes=10,
72
- d_ffn=512,
73
- d_model = 256,
74
- num_tokens = 64,
75
- num_layers=4,
76
- dropout=0.5
77
- ):
78
- num_patches = check_sizes(image_size, patch_size)
79
- super().__init__(d_model,d_ffn,num_layers,dropout, num_tokens)
80
- self.patcher = nn.Conv2d(
81
- in_channels, d_model, kernel_size=patch_size, stride=patch_size
82
- )
83
- self.classifier = nn.Linear(d_model, num_classes)
84
-
85
- def forward(self, x):
86
-
87
- patches = self.patcher(x)
88
- batch_size, num_channels, _, _ = patches.shape
89
- patches = patches.permute(0, 2, 3, 1)
90
- patches = patches.view(batch_size, -1, num_channels)
91
- embedding = self.model(patches)
92
- embedding = embedding.mean(dim=1) # global average pooling
93
- out = self.classifier(embedding)
94
- return out
95
-
96
- model = ContextualizerNiNImageClassification().to(device)
97
- print(model)
98
-
99
- # Optimizer
100
-
101
- loss_fn = nn.CrossEntropyLoss()
102
- optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
103
-
104
-
105
- # Training Loop
106
-
107
- def train(dataloader, model, loss_fn, optimizer):
108
- size = len(dataloader.dataset)
109
- num_batches = len(dataloader)
110
- model.train()
111
- train_loss = 0
112
- correct = 0
113
- for batch, (X,y) in enumerate(dataloader):
114
- X, y = X.to(device), y.to(device)
115
-
116
- #compute prediction error
117
- pred = model(X)
118
- loss = loss_fn(pred,y)
119
-
120
- # backpropagation
121
- optimizer.zero_grad()
122
- loss.backward()
123
- optimizer.step()
124
- train_loss += loss.item()
125
- _, labels = torch.max(pred.data, 1)
126
- correct += labels.eq(y.data).type(torch.float).sum()
127
-
128
-
129
-
130
-
131
- if batch % 100 == 0:
132
- loss, current = loss.item(), batch * len(X)
133
- print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
134
-
135
- train_loss /= num_batches
136
- train_accuracy = 100. * correct.item() / size
137
- print(train_accuracy)
138
- return train_loss,train_accuracy
139
-
140
-
141
-
142
- # Test loop
143
-
144
- def test(dataloader, model, loss_fn):
145
- size = len(dataloader.dataset)
146
- num_batches = len(dataloader)
147
- model.eval()
148
- test_loss = 0
149
- correct = 0
150
- with torch.no_grad():
151
- for X,y in dataloader:
152
- X,y = X.to(device), y.to(device)
153
- pred = model(X)
154
- test_loss += loss_fn(pred, y).item()
155
- correct += (pred.argmax(1) == y).type(torch.float).sum().item()
156
- test_loss /= num_batches
157
- correct /= size
158
- print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
159
- test_accuracy = 100*correct
160
- return test_loss, test_accuracy
161
-
162
-
163
-
164
- # apply train and test
165
-
166
- logname = "/PATH/Contextualizer_mlp_NiN/Experiments_cifar10/logs_contextualizer/logs_cifar10.csv"
167
- if not os.path.exists(logname):
168
- with open(logname, 'w') as logfile:
169
- logwriter = csv.writer(logfile, delimiter=',')
170
- logwriter.writerow(['epoch', 'train loss', 'train acc',
171
- 'test loss', 'test acc'])
172
-
173
-
174
- epochs = 100
175
- for epoch in range(epochs):
176
- print(f"Epoch {epoch+1}\n-----------------------------------")
177
- train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
178
- # learning rate scheduler
179
- #if scheduler is not None:
180
- # scheduler.step()
181
- test_loss, test_acc = test(test_dataloader, model, loss_fn)
182
- with open(logname, 'a') as logfile:
183
- logwriter = csv.writer(logfile, delimiter=',')
184
- logwriter.writerow([epoch+1, train_loss, train_acc,
185
- test_loss, test_acc])
186
- print("Done!")
187
-
188
- # saving trained model
189
-
190
- path = "/PATH/Contextualizer_mlp_NiN/Experiments_cifar10/weights_contextualizer"
191
- model_name = "ContextualizerMLPNiNImageClassification_cifar10"
192
- torch.save(model.state_dict(), f"{path}/{model_name}.pth")
193
- print(f"Saved Model State to {path}/{model_name}.pth ")
194
-