Uploading trained models, logs and training code

#2
by rmgupte - opened
oxford-pets/Test model.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
oxford-pets/animals_cnn_epoch_100.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe9d79a7aed3b460ed3af5d5afc7496e06cf24be48b3393859fd6026f3661c90
size 25434791
oxford-pets/animals_nn_epoch_100.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1faa6c482d997081ff242c62abbdd2df9d90c5efe4ef934fe7124af9b6a3a2b1
size 79480573
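
Note: the two .pth entries above are Git LFS pointer files (spec version, object hash, and byte size), not the weights themselves. With git-lfs installed, running "git lfs pull" in a clone of the repo fetches the actual checkpoints.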
oxford-pets/cnn.out ADDED
The diff for this file is too large to render. See raw diff
 
oxford-pets/cnn_trainer.py ADDED
@@ -0,0 +1,284 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader

from datasets import load_dataset

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

# Image and mask transforms. The hyperparameters are open to discussion; for now we
# resize to 256 x 256 and normalize images to [-1, 1].
IMG_SIZE = 256

image_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Masks hold integer class IDs, so resize with nearest-neighbour interpolation
# (which never blends labels) and keep them as integer tensors.
mask_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
])
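
# Note: Normalize with mean = std = 0.5 maps ToTensor's [0, 1] output to [-1, 1],
# since (x - 0.5) / 0.5 = 2x - 1. For example:
#     x = transforms.Normalize([0.5] * 3, [0.5] * 3)(torch.zeros(3, 4, 4))
#     # x is now all -1.0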

# Helpful methods for visualizing the images and masks

def visualize_mask(mask: Image.Image, img: Image.Image = None):
    """
    Visualizes the segmentation mask. If an image is provided, it overlays the mask on the image.

    :param mask: the segmentation mask to visualize. Expects a Pillow image
    :param img: the image to overlay the mask on. Expects a Pillow image
    """
    mask_np = np.array(mask)

    class_colors = {
        1: [255, 0, 0],  # red for prediction
        2: [0, 255, 0],  # green for background
        3: [0, 0, 255],  # blue for ambiguous
    }

    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in class_colors.items():
        color_mask[mask_np == class_id] = color

    if img is not None:
        img = img.convert('RGBA')
        overlay = Image.fromarray(color_mask).convert('RGBA')
        blended = Image.blend(img, overlay, alpha=0.5)

        plt.figure(figsize=(6, 6))
        plt.imshow(blended)
        plt.title('Segmentation Overlay')
        plt.axis('off')
        plt.show()
    else:
        plt.figure(figsize=(6, 6))
        plt.imshow(color_mask)
        plt.title('Colorized Segmentation Mask')
        plt.axis('off')
        plt.show()

def visualize_mask_tensor(img_tensor: torch.Tensor, mask_tensor: torch.Tensor, alpha: float = 0.5, title: str = "Image Mask"):
    """
    Overlays a mask tensor onto an image tensor and displays the result.

    :param img_tensor: image tensor of shape (C, H, W)
    :param mask_tensor: mask tensor of shape (H, W)
    :param alpha: transparency for the overlay
    :param title: title for the plot
    """
    img_np = img_tensor.permute(1, 2, 0).numpy()
    mask_np = mask_tensor.numpy()

    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    color_mask[mask_np == 1] = [255, 0, 0]  # red for mask
    color_mask[mask_np == 2] = [0, 0, 255]  # blue for bg

    plt.figure(figsize=(6, 6))
    plt.imshow(img_np)
    plt.imshow(color_mask, alpha=alpha)
    plt.axis('off')
    plt.title(title)
    plt.show()


class_labels = ['cat', 'dog']
category_labels = ['Abyssinian', 'american bulldog', 'american pit bull terrier', 'basset hound', 'beagle', 'Bengal', 'Birman', 'Bombay',
                   'boxer', 'British Shorthair', 'chihuahua', 'Egyptian Mau', 'english cocker spaniel', 'english setter', 'german shorthaired', 'great pyrenees', 'havanese', 'japanese chin', 'keeshond', 'leonberger', 'Maine Coon', 'miniature pinscher', 'newfoundland',
                   'Persian', 'pomeranian', 'pug', 'Ragdoll', 'Russian Blue', 'saint bernard', 'samoyed', 'scottish terrier', 'shiba inu', 'Siamese',
                   'Sphynx', 'staffordshire bull terrier', 'wheaten terrier', 'yorkshire terrier']

# The index into these lists corresponds to the label associated with an image.

def transform_class_example(example: dict, include_ambiguous: bool = False) -> dict:
    """
    Transforms the image and mask from a given example.

    :param example: an example from a HF dataset; a dictionary with keys from the dataset
    :param include_ambiguous: whether to include the ambiguous class in the mask.
        If true, the ambiguous class is included in the mask (foreground); if false,
        it is removed from the mask (background).
    :return: a dictionary with the transformed image and mask
    """
    class_label = example['class']
    mask = mask_transform(example["msk"]).squeeze(0)

    # Swap 0 and 1 (using 2 as a temporary value), so that afterwards 0 is
    # background and 1 is the pet mask.
    mask[mask == 0] = 2
    mask[mask == 1] = 0
    mask[mask == 2] = 1
    if include_ambiguous:
        mask[mask == 3] = 1
    else:
        mask[mask == 3] = 0

    return {
        "image": image_transform(example["img"]),
        "mask": mask,
        "classification": class_label
    }
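
# For example, a raw mask row tensor([0, 1, 3]) becomes tensor([1, 0, 0]) with the
# default include_ambiguous=False, and tensor([1, 0, 1]) with include_ambiguous=True.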

def transform_category_example(example: dict, include_ambiguous: bool = False) -> dict:
    """
    Transforms the image and mask from a given example.

    :param example: an example from a HF dataset; a dictionary with keys from the dataset
    :param include_ambiguous: whether to include the ambiguous class in the mask.
        If true, the ambiguous class is included in the mask (foreground); if false,
        it is removed from the mask (background).
    :return: a dictionary with the transformed image and mask
    """
    category_label = example['category']
    mask = mask_transform(example["msk"]).squeeze(0)

    # Swap 0 and 1, as in transform_class_example: afterwards 0 is background
    # and 1 is the pet mask.
    mask[mask == 0] = 2
    mask[mask == 1] = 0
    mask[mask == 2] = 1
    if include_ambiguous:
        mask[mask == 3] = 1
    else:
        mask[mask == 3] = 0

    # The breed (category) label is returned separately under 'classification';
    # the mask itself stays binary.
    return {
        "image": image_transform(example["img"]),
        "mask": mask,
        "classification": category_label
    }

def generate_dataset(split: str, classification: str = 'class'):
    """
    Generates a dataset for the given split and classification type.

    :param split: which split to generate. Can be 'train', 'test' or 'valid'
    :param classification: either 'class' or 'category'. Class is either cat or dog, while category is the breed.
    :return: a dataset of the given split, with transformed images and masks
    """
    if split not in ['test', 'train', 'valid']:
        raise ValueError("Split must be either 'test', 'train' or 'valid'")
    if classification not in ['class', 'category']:
        raise ValueError("Classification must be either 'class' or 'category'. Class is either cat or dog, while category is the breed.")

    dataset = load_dataset('cvdl/oxford-pets', split=split)

    # transform the images and masks; remove the bbox and unused classification columns
    if classification == 'class':
        dataset = dataset.map(transform_class_example, remove_columns=['bbox', 'img', 'msk', 'category', 'class'])
    elif classification == 'category':
        dataset = dataset.map(transform_category_example, remove_columns=['bbox', 'img', 'msk', 'class', 'category'])

    dataset.set_format(type='torch', columns=['image', 'mask', 'classification'])

    return dataset

train_set = generate_dataset('train', classification='class')
test_set = generate_dataset('test', classification='class')

batch_size = 32

# Note: shuffle defaults to False, so batches are drawn in dataset order every epoch.
train_loader = DataLoader(train_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

"""This CNN contains 8 layers as of now. I start by outputting 16 channels from RGB
and keep doubling it in every layer. Layers can be increased or decreased depending
on how well it performs."""

class ConvNN(nn.Module):
    def __init__(self, num_classes=4):
        super(ConvNN, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 16, 11, stride=1, padding='same')
        self.conv2 = nn.Conv2d(16, 32, 11, stride=1, padding='same')
        self.conv3 = nn.Conv2d(32, 64, 3, stride=1, padding='same')
        self.conv4 = nn.Conv2d(64, 128, 3, stride=1, padding='same')
        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding='same')
        self.conv6 = nn.Conv2d(256, 512, 3, stride=1, padding='same')
        self.conv7 = nn.Conv2d(512, 1024, 3, stride=1, padding='same')
        # 1x1 conv mapping the final feature maps to per-pixel class logits
        self.conv8 = nn.Conv2d(1024, num_classes, kernel_size=1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        original_size = x.shape[2:]

        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.relu(self.conv5(x))
        x = self.relu(self.conv6(x))
        x = self.relu(self.conv7(x))
        x = self.conv8(x)

        # Upsample the logits back to the input resolution
        # (F.interpolate defaults to nearest-neighbour interpolation).
        x = F.interpolate(x, size=original_size)

        return x
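
# Quick shape check: the two pooling steps reduce 256x256 inputs to 64x64 feature
# maps, the 1x1 conv produces per-pixel class logits, and the final interpolation
# restores the input resolution:
#     out = ConvNN(num_classes=2)(torch.randn(1, 3, 256, 256))
#     out.shape  # torch.Size([1, 2, 256, 256])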


def train_model(model, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        optimizer.zero_grad()
        outputs = model(batch['image'])
        # CrossEntropyLoss over per-pixel logits: outputs are (N, C, H, W) and
        # the target masks are (N, H, W) integer class IDs.
        loss = criterion(outputs, batch['mask'])

        loss.backward()
        optimizer.step()


def test_model(model, test_loader, epoch):
    model.eval()
    diff = 0
    total = 0

    with torch.no_grad():
        for idx, value in tqdm(enumerate(test_loader)):
            output = model(value['image'])
            batch_pred = output.max(1, keepdim=True)[1]
            for p, ans in zip(batch_pred, value['mask']):
                pred = p[0]
                diff += (torch.abs(pred - ans) != 0).sum()
                total += torch.numel(ans)

    # Pixel-wise accuracy: the fraction of pixels whose predicted class matches the mask.
    test_acc = 1 - diff / total
    print('[Test set] Epoch: {:d}, Accuracy: {:.2f}%\n'.format(
        epoch + 1, 100. * test_acc))

    return test_acc

model = ConvNN(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_acc = 0
for epoch in range(0, 100):
    # train the model for one epoch
    print("now training")
    train_model(model, train_loader, optimizer, criterion, epoch)
    # evaluate the model on the test set after this epoch
    print("now testing")
    acc = test_model(model, test_loader, epoch)
    print(f"epoch {epoch+1}, best_acc {max(best_acc, acc)}")
    best_acc = max(best_acc, acc)

    torch.save(model.state_dict(), f'animals_cnn_epoch_{epoch+1}.pth')
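
Note: the checkpoints in this upload store state_dicts saved by this loop, so loading one needs the matching model definition. A minimal sketch, assuming the ConvNN class from cnn_trainer.py is in scope:

    model = ConvNN(num_classes=2)
    model.load_state_dict(torch.load('animals_cnn_epoch_100.pth'))
    model.eval()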
oxford-pets/nn.out ADDED
The diff for this file is too large to render. See raw diff
 
oxford-pets/nn_trainer.py ADDED
@@ -0,0 +1,269 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader

from datasets import load_dataset

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

# Image and mask transforms. The hyperparameters are open to discussion; for now we
# resize to 256 x 256 and normalize images to [-1, 1].
IMG_SIZE = 256

image_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Masks hold integer class IDs, so resize with nearest-neighbour interpolation
# (which never blends labels) and keep them as integer tensors.
mask_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
])

# Helpful methods for visualizing the images and masks

def visualize_mask(mask: Image.Image, img: Image.Image = None):
    """
    Visualizes the segmentation mask. If an image is provided, it overlays the mask on the image.

    :param mask: the segmentation mask to visualize. Expects a Pillow image
    :param img: the image to overlay the mask on. Expects a Pillow image
    """
    mask_np = np.array(mask)

    class_colors = {
        1: [255, 0, 0],  # red for prediction
        2: [0, 255, 0],  # green for background
        3: [0, 0, 255],  # blue for ambiguous
    }

    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in class_colors.items():
        color_mask[mask_np == class_id] = color

    if img is not None:
        img = img.convert('RGBA')
        overlay = Image.fromarray(color_mask).convert('RGBA')
        blended = Image.blend(img, overlay, alpha=0.5)

        plt.figure(figsize=(6, 6))
        plt.imshow(blended)
        plt.title('Segmentation Overlay')
        plt.axis('off')
        plt.show()
    else:
        plt.figure(figsize=(6, 6))
        plt.imshow(color_mask)
        plt.title('Colorized Segmentation Mask')
        plt.axis('off')
        plt.show()

def visualize_mask_tensor(img_tensor: torch.Tensor, mask_tensor: torch.Tensor, alpha: float = 0.5, title: str = "Image Mask"):
    """
    Overlays a mask tensor onto an image tensor and displays the result.

    :param img_tensor: image tensor of shape (C, H, W)
    :param mask_tensor: mask tensor of shape (H, W)
    :param alpha: transparency for the overlay
    :param title: title for the plot
    """
    img_np = img_tensor.permute(1, 2, 0).numpy()
    mask_np = mask_tensor.numpy()

    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    color_mask[mask_np == 1] = [255, 0, 0]  # red for mask
    color_mask[mask_np == 2] = [0, 0, 255]  # blue for bg

    plt.figure(figsize=(6, 6))
    plt.imshow(img_np)
    plt.imshow(color_mask, alpha=alpha)
    plt.axis('off')
    plt.title(title)
    plt.show()


class_labels = ['cat', 'dog']
category_labels = ['Abyssinian', 'american bulldog', 'american pit bull terrier', 'basset hound', 'beagle', 'Bengal', 'Birman', 'Bombay',
                   'boxer', 'British Shorthair', 'chihuahua', 'Egyptian Mau', 'english cocker spaniel', 'english setter', 'german shorthaired', 'great pyrenees', 'havanese', 'japanese chin', 'keeshond', 'leonberger', 'Maine Coon', 'miniature pinscher', 'newfoundland',
                   'Persian', 'pomeranian', 'pug', 'Ragdoll', 'Russian Blue', 'saint bernard', 'samoyed', 'scottish terrier', 'shiba inu', 'Siamese',
                   'Sphynx', 'staffordshire bull terrier', 'wheaten terrier', 'yorkshire terrier']

# The index into these lists corresponds to the label associated with an image.

def transform_class_example(example: dict, include_ambiguous: bool = False) -> dict:
    """
    Transforms the image and mask from a given example.

    :param example: an example from a HF dataset; a dictionary with keys from the dataset
    :param include_ambiguous: whether to include the ambiguous class in the mask.
        If true, the ambiguous class is included in the mask (foreground); if false,
        it is removed from the mask (background).
    :return: a dictionary with the transformed image and mask
    """
    class_label = example['class']
    mask = mask_transform(example["msk"]).squeeze(0)

    # Swap 0 and 1 (using 2 as a temporary value), so that afterwards 0 is
    # background and 1 is the pet mask.
    mask[mask == 0] = 2
    mask[mask == 1] = 0
    mask[mask == 2] = 1
    if include_ambiguous:
        mask[mask == 3] = 1
    else:
        mask[mask == 3] = 0

    return {
        "image": image_transform(example["img"]),
        "mask": mask,
        "classification": class_label
    }

def transform_category_example(example: dict, include_ambiguous: bool = False) -> dict:
    """
    Transforms the image and mask from a given example.

    :param example: an example from a HF dataset; a dictionary with keys from the dataset
    :param include_ambiguous: whether to include the ambiguous class in the mask.
        If true, the ambiguous class is included in the mask (foreground); if false,
        it is removed from the mask (background).
    :return: a dictionary with the transformed image and mask
    """
    category_label = example['category']
    mask = mask_transform(example["msk"]).squeeze(0)

    # Swap 0 and 1, as in transform_class_example: afterwards 0 is background
    # and 1 is the pet mask.
    mask[mask == 0] = 2
    mask[mask == 1] = 0
    mask[mask == 2] = 1
    if include_ambiguous:
        mask[mask == 3] = 1
    else:
        mask[mask == 3] = 0

    # The breed (category) label is returned separately under 'classification';
    # the mask itself stays binary.
    return {
        "image": image_transform(example["img"]),
        "mask": mask,
        "classification": category_label
    }

def generate_dataset(split: str, classification: str = 'class'):
    """
    Generates a dataset for the given split and classification type.

    :param split: which split to generate. Can be 'train', 'test' or 'valid'
    :param classification: either 'class' or 'category'. Class is either cat or dog, while category is the breed.
    :return: a dataset of the given split, with transformed images and masks
    """
    if split not in ['test', 'train', 'valid']:
        raise ValueError("Split must be either 'test', 'train' or 'valid'")
    if classification not in ['class', 'category']:
        raise ValueError("Classification must be either 'class' or 'category'. Class is either cat or dog, while category is the breed.")

    dataset = load_dataset('cvdl/oxford-pets', split=split)

    # transform the images and masks; remove the bbox and unused classification columns
    if classification == 'class':
        dataset = dataset.map(transform_class_example, remove_columns=['bbox', 'img', 'msk', 'category', 'class'])
    elif classification == 'category':
        dataset = dataset.map(transform_category_example, remove_columns=['bbox', 'img', 'msk', 'class', 'category'])

    dataset.set_format(type='torch', columns=['image', 'mask', 'classification'])

    return dataset

train_set = generate_dataset('train', classification='class')
test_set = generate_dataset('test', classification='class')

batch_size = 32

# Note: shuffle defaults to False, so batches are drawn in dataset order every epoch.
train_loader = DataLoader(train_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

class LeNet(nn.Module):
    def __init__(self, num_classes=2):
        super(LeNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 12, 5, 1)
        self.conv3 = nn.Conv2d(12, 24, 5, 1)

        self.lin1 = nn.Linear(24 * 28 * 28, 1024)
        self.lin2 = nn.Linear(1024, 512)
        self.lin3 = nn.Linear(512, 128)
        self.lin4 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.relu(self.lin3(x))
        out = self.lin4(x)
        return out
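
# Where does 24 * 28 * 28 come from? With 256x256 inputs, each unpadded 5x5 conv
# shrinks the spatial size by 4 and each 2x2 max-pool halves it (flooring):
# 256 -> 252 -> 126 -> 122 -> 61 -> 57 -> 28, leaving 24 channels of 28x28 after conv3.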

def train_model(model, train_loader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0.0
    for idx, batch in tqdm(enumerate(train_loader)):
        optimizer.zero_grad()
        output = model(batch['image'])
        # CrossEntropyLoss over image-level logits: output is (N, num_classes)
        # and the targets are integer class labels (0 = cat, 1 = dog).
        loss = criterion(output, batch['classification'])
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    print('[Training set] Epoch: {:d}, Average loss: {:.4f}'.format(epoch + 1, train_loss))

    return train_loss

def test_model(model, test_loader, epoch):
    model.eval()
    correct = 0

    with torch.no_grad():
        for idx, value in tqdm(enumerate(test_loader)):
            output = model(value['image'])
            batch_pred = output.max(1, keepdim=True)[1]
            correct += batch_pred.eq(value['classification'].view_as(batch_pred)).sum().item()

    test_acc = correct / len(test_loader.dataset)
    print('[Test set] Epoch: {:d}, Accuracy: {:.2f}%\n'.format(
        epoch + 1, 100. * test_acc))

    return test_acc

model = LeNet(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

best_acc = 0
for epoch in range(0, 100):
    print("now training")
    train_model(model, train_loader, optimizer, criterion, epoch)
    print("now testing")
    acc = test_model(model, test_loader, epoch)
    print(f"epoch {epoch+1}, best_acc {max(best_acc, acc)}")
    best_acc = max(best_acc, acc)

    torch.save(model.state_dict(), f'animals_nn_epoch_{epoch+1}.pth')
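
Note: animals_nn_epoch_100.pth loads the same way as the CNN checkpoint, assuming the LeNet class from nn_trainer.py is in scope:

    model = LeNet(num_classes=2)
    model.load_state_dict(torch.load('animals_nn_epoch_100.pth'))
    model.eval()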