IvanBanny commited on
Commit
5f436de
·
1 Parent(s): 5634aef

feat(data): implemented data augmentation

Browse files
Files changed (5) hide show
  1. model.py +6 -6
  2. performance.json +64 -88
  3. performance_plot.png +0 -0
  4. train.py +51 -4
  5. train_dist.py +193 -148
model.py CHANGED
@@ -34,9 +34,9 @@ class MyModel(nn.Module):
34
  self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
35
 
36
  # Residual blocks
37
- self.layer1 = self._resnet_layers(64, 128, num_blocks=2) # 2 residual blocks
38
- self.layer2 = self._resnet_layers(128, 256, num_blocks=2) # 2 residual blocks
39
- self.layer3 = self._resnet_layers(256, 512, num_blocks=2) # 2 residual blocks
40
 
41
  # Global average pooling
42
  self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
@@ -47,9 +47,9 @@ class MyModel(nn.Module):
47
  self.bn1,
48
  nn.ReLU(),
49
  self.pool1,
50
- self.layer1,
51
- self.layer2,
52
- self.layer3,
53
  self.global_avg_pool
54
  )
55
 
 
34
  self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
35
 
36
  # Residual blocks
37
+ self.block1 = self._resnet_layers(64, 128, num_blocks=3) # 3 residual blocks
38
+ self.block2 = self._resnet_layers(128, 256, num_blocks=3) # 3 residual blocks
39
+ self.block3 = self._resnet_layers(256, 512, num_blocks=3) # 3 residual blocks
40
 
41
  # Global average pooling
42
  self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
 
47
  self.bn1,
48
  nn.ReLU(),
49
  self.pool1,
50
+ self.block1,
51
+ self.block2,
52
+ self.block3,
53
  self.global_avg_pool
54
  )
55
 
performance.json CHANGED
@@ -1,122 +1,98 @@
1
  [
2
  {
3
- "avg_train_loss": 2.0,
4
- "train_accuracy": 0.0,
5
- "avg_val_loss": 4.0,
6
- "val_accuracy": 0.0
7
  },
8
  {
9
- "avg_train_loss": 1.3333333333333333,
10
- "train_accuracy": 0.125,
11
- "avg_val_loss": 2.0,
12
- "val_accuracy": 0.1
13
  },
14
  {
15
- "avg_train_loss": 1.0,
16
- "train_accuracy": 0.2222222222222222,
17
- "avg_val_loss": 1.3333333333333333,
18
- "val_accuracy": 0.18181818181818182
19
  },
20
  {
21
- "avg_train_loss": 0.8,
22
- "train_accuracy": 0.3,
23
- "avg_val_loss": 1.0,
24
- "val_accuracy": 0.25
25
  },
26
  {
27
- "avg_train_loss": 0.6666666666666666,
28
- "train_accuracy": 0.36363636363636365,
29
- "avg_val_loss": 0.8,
30
- "val_accuracy": 0.3076923076923077
31
  },
32
  {
33
- "avg_train_loss": 0.5714285714285714,
34
- "train_accuracy": 0.4166666666666667,
35
- "avg_val_loss": 0.6666666666666666,
36
- "val_accuracy": 0.35714285714285715
37
  },
38
  {
39
- "avg_train_loss": 0.5,
40
- "train_accuracy": 0.46153846153846156,
41
- "avg_val_loss": 0.5714285714285714,
42
- "val_accuracy": 0.4
43
  },
44
  {
45
- "avg_train_loss": 0.4444444444444444,
46
- "train_accuracy": 0.5,
47
- "avg_val_loss": 0.5,
48
- "val_accuracy": 0.4375
49
  },
50
  {
51
- "avg_train_loss": 0.4,
52
- "train_accuracy": 0.5333333333333333,
53
- "avg_val_loss": 0.4444444444444444,
54
- "val_accuracy": 0.47058823529411764
55
  },
56
  {
57
- "avg_train_loss": 0.36363636363636365,
58
- "train_accuracy": 0.5625,
59
- "avg_val_loss": 0.4,
60
- "val_accuracy": 0.5
61
  },
62
  {
63
- "avg_train_loss": 0.3333333333333333,
64
- "train_accuracy": 0.5882352941176471,
65
- "avg_val_loss": 0.36363636363636365,
66
- "val_accuracy": 0.5263157894736842
67
  },
68
  {
69
- "avg_train_loss": 0.3076923076923077,
70
- "train_accuracy": 0.6111111111111112,
71
- "avg_val_loss": 0.3333333333333333,
72
- "val_accuracy": 0.55
73
  },
74
  {
75
- "avg_train_loss": 0.2857142857142857,
76
- "train_accuracy": 0.631578947368421,
77
- "avg_val_loss": 0.3076923076923077,
78
- "val_accuracy": 0.5714285714285714
79
  },
80
  {
81
- "avg_train_loss": 0.26666666666666666,
82
- "train_accuracy": 0.65,
83
- "avg_val_loss": 0.2857142857142857,
84
- "val_accuracy": 0.5909090909090909
85
  },
86
  {
87
- "avg_train_loss": 0.25,
88
- "train_accuracy": 0.6666666666666666,
89
- "avg_val_loss": 0.26666666666666666,
90
- "val_accuracy": 0.6086956521739131
91
  },
92
  {
93
- "avg_train_loss": 0.23529411764705882,
94
- "train_accuracy": 0.6818181818181818,
95
- "avg_val_loss": 0.25,
96
- "val_accuracy": 0.625
97
- },
98
- {
99
- "avg_train_loss": 0.2222222222222222,
100
- "train_accuracy": 0.6956521739130435,
101
- "avg_val_loss": 0.23529411764705882,
102
- "val_accuracy": 0.64
103
- },
104
- {
105
- "avg_train_loss": 0.21052631578947367,
106
- "train_accuracy": 0.7083333333333334,
107
- "avg_val_loss": 0.2222222222222222,
108
- "val_accuracy": 0.6538461538461539
109
- },
110
- {
111
- "avg_train_loss": 0.2,
112
- "train_accuracy": 0.72,
113
- "avg_val_loss": 0.21052631578947367,
114
- "val_accuracy": 0.6666666666666666
115
- },
116
- {
117
- "avg_train_loss": 0.19047619047619047,
118
- "train_accuracy": 0.7307692307692307,
119
- "avg_val_loss": 0.2,
120
- "val_accuracy": 0.6785714285714286
121
  }
122
  ]
 
1
  [
2
  {
3
+ "avg_train_loss": 3.6829103430493553,
4
+ "train_accuracy": 0.1709,
5
+ "avg_val_loss": 3.5155134261793393,
6
+ "val_accuracy": 0.21660000085830688
7
  },
8
  {
9
+ "avg_train_loss": 3.1779507774614175,
10
+ "train_accuracy": 0.28584,
11
+ "avg_val_loss": 3.3872365769307327,
12
+ "val_accuracy": 0.26499998569488525
13
  },
14
  {
15
+ "avg_train_loss": 2.948077251571001,
16
+ "train_accuracy": 0.3488,
17
+ "avg_val_loss": 2.960327925955414,
18
+ "val_accuracy": 0.35409998893737793
19
  },
20
  {
21
+ "avg_train_loss": 2.7825030597142506,
22
+ "train_accuracy": 0.39572,
23
+ "avg_val_loss": 2.9160548896546574,
24
+ "val_accuracy": 0.3675999939441681
25
  },
26
  {
27
+ "avg_train_loss": 2.6581287719619175,
28
+ "train_accuracy": 0.43032,
29
+ "avg_val_loss": 2.8124696768013533,
30
+ "val_accuracy": 0.39629998803138733
31
  },
32
  {
33
+ "avg_train_loss": 2.536289040659455,
34
+ "train_accuracy": 0.46174,
35
+ "avg_val_loss": 2.7144464383459397,
36
+ "val_accuracy": 0.42500001192092896
37
  },
38
  {
39
+ "avg_train_loss": 2.440945129400633,
40
+ "train_accuracy": 0.49412,
41
+ "avg_val_loss": 2.745724817749801,
42
+ "val_accuracy": 0.4189999997615814
43
  },
44
  {
45
+ "avg_train_loss": 2.3424960819483567,
46
+ "train_accuracy": 0.52302,
47
+ "avg_val_loss": 2.744392152045183,
48
+ "val_accuracy": 0.4237000048160553
49
  },
50
  {
51
+ "avg_train_loss": 2.245347209489277,
52
+ "train_accuracy": 0.5516,
53
+ "avg_val_loss": 2.7382394584121217,
54
+ "val_accuracy": 0.43230000138282776
55
  },
56
  {
57
+ "avg_train_loss": 2.155752474042901,
58
+ "train_accuracy": 0.57972,
59
+ "avg_val_loss": 2.7085890071407244,
60
+ "val_accuracy": 0.43689998984336853
61
  },
62
  {
63
+ "avg_train_loss": 2.0571537492218797,
64
+ "train_accuracy": 0.6087,
65
+ "avg_val_loss": 2.7106366005672773,
66
+ "val_accuracy": 0.44179999828338623
67
  },
68
  {
69
+ "avg_train_loss": 1.955078414747979,
70
+ "train_accuracy": 0.64364,
71
+ "avg_val_loss": 2.8602050003732087,
72
+ "val_accuracy": 0.421099990606308
73
  },
74
  {
75
+ "avg_train_loss": 1.8526526395891694,
76
+ "train_accuracy": 0.68184,
77
+ "avg_val_loss": 2.723868011668989,
78
+ "val_accuracy": 0.44020000100135803
79
  },
80
  {
81
+ "avg_train_loss": 1.7414095465830328,
82
+ "train_accuracy": 0.7196,
83
+ "avg_val_loss": 2.8222216952378583,
84
+ "val_accuracy": 0.4287000000476837
85
  },
86
  {
87
+ "avg_train_loss": 1.6265801092942251,
88
+ "train_accuracy": 0.7615,
89
+ "avg_val_loss": 2.75775924002289,
90
+ "val_accuracy": 0.430400013923645
91
  },
92
  {
93
+ "avg_train_loss": 1.5103181164308914,
94
+ "train_accuracy": 0.80724,
95
+ "avg_val_loss": 2.8081995484175954,
96
+ "val_accuracy": 0.43479999899864197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  }
98
  ]
performance_plot.png CHANGED
train.py CHANGED
@@ -9,6 +9,7 @@ from PIL import Image
9
  from torchvision import transforms
10
  from torch.utils.data import DataLoader, Dataset
11
  from model import MyModel
 
12
 
13
 
14
  class MiniPlaces(Dataset):
@@ -75,6 +76,47 @@ class MiniPlaces(Dataset):
75
  return image, label
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def evaluate(model, test_loader, criterion, device):
79
  """
80
  Evaluate the CNN classifier on the validation set.
@@ -137,7 +179,7 @@ def train(model, train_loader, val_loader, optimizer, criterion, device,
137
  model = model.to(device)
138
 
139
  # Define early stopping parameters
140
- patience = 3 # Number of epochs to wait for improvement
141
  best_val_accuracy = 0.0 # Best validation accuracy so far
142
  epochs_without_improvement = 0 # Counter for epochs without improvement
143
  best_model_state = None # To store the state of the best model
@@ -277,9 +319,12 @@ def main(args):
277
  transforms.Normalize(image_net_mean, image_net_std),
278
  ])
279
 
280
- data_root = 'data'
 
 
281
 
282
- # Create MiniPlaces dataset object
 
283
  miniplaces_train = MiniPlaces(data_root,
284
  split='train',
285
  transform=data_transform)
@@ -311,6 +356,8 @@ def main(args):
311
  # optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4, amsgrad=False)
312
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, dampening=0, weight_decay=1e-4, nesterov=True)
313
 
 
 
314
  if args.checkpoint:
315
  checkpoint = torch.load(args.checkpoint)
316
  model.load_state_dict(checkpoint['model_state_dict'])
@@ -341,7 +388,7 @@ if __name__ == "__main__":
341
  parser.add_argument('--test', action='store_true')
342
  parser.add_argument('--checkpoint')
343
  parser.add_argument('--gpu', default=0)
344
- parser.add_argument('--epochs', default=10)
345
  parser.add_argument('--batch_size', default=32)
346
  args = parser.parse_args()
347
  main(args)
 
9
  from torchvision import transforms
10
  from torch.utils.data import DataLoader, Dataset
11
  from model import MyModel
12
+ import numpy as np
13
 
14
 
15
  class MiniPlaces(Dataset):
 
76
  return image, label
77
 
78
 
79
+ def create_train_transform():
80
+ """
81
+ Create training data transformation with augmentation
82
+ """
83
+ image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
84
+ image_net_std = torch.Tensor([0.229, 0.224, 0.225])
85
+
86
+ return transforms.Compose([
87
+ transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
88
+ transforms.RandomHorizontalFlip(p=0.5),
89
+ transforms.ColorJitter(
90
+ brightness=0.4,
91
+ contrast=0.4,
92
+ saturation=0.4,
93
+ hue=0.1
94
+ ),
95
+ transforms.RandomAffine(
96
+ degrees=15, # rotation
97
+ translate=(0.1, 0.1), # horizontal/vertical translation
98
+ scale=(0.9, 1.1), # scale
99
+ ),
100
+ transforms.ToTensor(),
101
+ transforms.Resize((128, 128)),
102
+ transforms.Normalize(image_net_mean, image_net_std)
103
+ ])
104
+
105
+
106
+ def create_val_transform():
107
+ """
108
+ Create validation/test data transformation without augmentation
109
+ """
110
+ image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
111
+ image_net_std = torch.Tensor([0.229, 0.224, 0.225])
112
+
113
+ return transforms.Compose([
114
+ transforms.ToTensor(),
115
+ transforms.Resize((128, 128)),
116
+ transforms.Normalize(image_net_mean, image_net_std)
117
+ ])
118
+
119
+
120
  def evaluate(model, test_loader, criterion, device):
121
  """
122
  Evaluate the CNN classifier on the validation set.
 
179
  model = model.to(device)
180
 
181
  # Define early stopping parameters
182
+ patience = 5 # Number of epochs to wait for improvement
183
  best_val_accuracy = 0.0 # Best validation accuracy so far
184
  epochs_without_improvement = 0 # Counter for epochs without improvement
185
  best_model_state = None # To store the state of the best model
 
319
  transforms.Normalize(image_net_mean, image_net_std),
320
  ])
321
 
322
+ # Separate transforms for training and validation
323
+ train_transform = create_train_transform()
324
+ val_transform = create_val_transform()
325
 
326
+ # Create datasets
327
+ data_root = 'data'
328
  miniplaces_train = MiniPlaces(data_root,
329
  split='train',
330
  transform=data_transform)
 
356
  # optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4, amsgrad=False)
357
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, dampening=0, weight_decay=1e-4, nesterov=True)
358
 
359
+ print("PARAMS NUM:", sum(p.numel() for p in model.parameters() if p.requires_grad))
360
+
361
  if args.checkpoint:
362
  checkpoint = torch.load(args.checkpoint)
363
  model.load_state_dict(checkpoint['model_state_dict'])
 
388
  parser.add_argument('--test', action='store_true')
389
  parser.add_argument('--checkpoint')
390
  parser.add_argument('--gpu', default=0)
391
+ parser.add_argument('--epochs', default=100)
392
  parser.add_argument('--batch_size', default=32)
393
  args = parser.parse_args()
394
  main(args)
train_dist.py CHANGED
@@ -31,13 +31,14 @@ def setup(rank, world_size, port):
31
 
32
  def cleanup():
33
  """
34
- Clean up the distributed training environment by destroying the process group.
35
  """
36
- dist.destroy_process_group()
 
 
37
 
38
 
39
  class MiniPlaces(Dataset):
40
- # Your existing MiniPlaces class implementation remains the same
41
  def __init__(self, root_dir, split, transform=None, label_dict=None):
42
  """
43
  Initialize the MiniPlaces dataset with the root directory for the images,
@@ -100,6 +101,47 @@ class MiniPlaces(Dataset):
100
  return image, label
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def evaluate(model, test_loader, criterion, device):
104
  """
105
  Evaluate the CNN classifier on the validation set.
@@ -158,146 +200,146 @@ def train_worker(rank, world_size, args):
158
  world_size (int): The total number of processes (GPUs).
159
  args (argparse.Namespace): Command-line arguments.
160
  """
161
- setup(rank, world_size, args.port)
162
- device = torch.device(f'cuda:{rank}')
163
-
164
- # Define early stopping parameters
165
- patience = 3 # Number of epochs to wait for improvement
166
- best_val_accuracy = 0.0 # Best validation accuracy so far
167
- epochs_without_improvement = 0 # Counter for epochs without improvement
168
- best_model_state = None # To store the state of the best model
169
-
170
- # Data loading and preprocessing
171
- image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
172
- image_net_std = torch.Tensor([0.229, 0.224, 0.225])
173
- data_transform = transforms.Compose([
174
- transforms.ToTensor(),
175
- transforms.Resize((128, 128)),
176
- transforms.Normalize(image_net_mean, image_net_std),
177
- ])
178
-
179
- # Create datasets
180
- data_root = 'data'
181
- miniplaces_train = MiniPlaces(data_root, split='train', transform=data_transform)
182
- miniplaces_val = MiniPlaces(data_root, split='val', transform=data_transform,
183
- label_dict=miniplaces_train.label_dict)
184
-
185
- # Create distributed samplers
186
- train_sampler = DistributedSampler(miniplaces_train, num_replicas=world_size, rank=rank)
187
- val_sampler = DistributedSampler(miniplaces_val, num_replicas=world_size, rank=rank)
188
-
189
- # Create dataloaders
190
- train_loader = DataLoader(miniplaces_train, batch_size=args.batch_size,
191
- num_workers=2, sampler=train_sampler,
192
- pin_memory=True)
193
- val_loader = DataLoader(miniplaces_val, batch_size=args.batch_size,
194
- num_workers=2, sampler=val_sampler,
195
- pin_memory=True)
196
-
197
- # Create model and move to GPU
198
- model = MyModel(num_classes=len(miniplaces_train.label_dict))
199
- model = model.to(device)
200
- model = DDP(model, device_ids=[rank])
201
-
202
- optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9,
203
- dampening=0, weight_decay=1e-4, nesterov=True)
204
- criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
205
-
206
- if args.checkpoint:
207
- map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
208
- checkpoint = torch.load(args.checkpoint, map_location=map_location)
209
- model.module.load_state_dict(checkpoint['model_state_dict'])
210
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
211
-
212
- if not args.test:
213
- # Training loop
214
- performance = []
215
- for epoch in range(args.epochs):
216
- model.train()
217
- train_sampler.set_epoch(epoch) # Important for proper shuffling
218
-
219
- running_loss = 0.0
220
- correct_predictions = 0
221
- total_samples = 0
222
-
223
- if rank == 0: # Only show progress bar on rank 0
224
- pbar = tqdm(total=len(train_loader),
225
- desc=f'Epoch {epoch + 1}/{args.epochs}',
226
- position=0, leave=True)
227
-
228
- for inputs, labels in train_loader:
229
- inputs = inputs.to(device)
230
- labels = labels.to(device)
231
-
232
- optimizer.zero_grad()
233
- logits = model(inputs)
234
- loss = criterion(logits, labels)
235
- loss.backward()
236
- optimizer.step()
237
-
238
- running_loss += loss.item()
239
- _, predicted = logits.max(1)
240
- correct_predictions += (predicted == labels).sum().item()
241
- total_samples += labels.size(0)
242
 
243
  if rank == 0:
244
- pbar.update(1)
245
- pbar.set_postfix(loss=loss.item())
246
-
247
- if rank == 0:
248
- pbar.close()
249
-
250
- # Evaluate and log metrics
251
- avg_train_loss = running_loss / len(train_loader)
252
- train_accuracy = correct_predictions / total_samples
253
- avg_val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
254
-
255
- if rank == 0: # Only save metrics on rank 0
256
- performance.append({
257
- "avg_train_loss": avg_train_loss,
258
- "train_accuracy": train_accuracy,
259
- "avg_val_loss": avg_val_loss,
260
- "val_accuracy": val_accuracy
261
- })
262
- print(
263
- f"Train Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f} "
264
- f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}"
265
- )
266
-
267
- # Check for early stopping
268
- if val_accuracy > best_val_accuracy:
269
- best_val_accuracy = val_accuracy
270
- epochs_without_improvement = 0 # Reset counter if there's an improvement
271
-
272
- # Save the model checkpoint for the best model
273
- best_model_state = {
274
- 'model_state_dict': model.module.state_dict(),
275
- 'optimizer_state_dict': optimizer.state_dict(),
276
- 'epoch': epoch,
277
- }
278
- else:
279
- epochs_without_improvement += 1
280
-
281
- # Early stopping condition
282
- if epochs_without_improvement >= patience:
283
- print(f"Early stopping at epoch {epoch + 1}.")
284
- break # Stop training if no improvement for 'patience' epochs
285
-
286
- if rank == 0: # Save performance and the best model checkpoint only on rank 0
287
- with open("performance.json", "w") as f:
288
- json.dump(performance, f, indent=4)
289
- torch.save(best_model_state, 'model.ckpt')
290
-
291
- else: # Testing mode
292
- miniplaces_test = MiniPlaces(data_root, split='test', transform=data_transform)
293
- test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
294
- checkpoint = torch.load(args.checkpoint, map_location=device)
295
- model.module.load_state_dict(checkpoint['model_state_dict'])
296
- preds = test(model, test_loader, device)
297
- if rank == 0: # Only write predictions on rank 0
298
- write_predictions(preds, 'predictions.csv')
299
-
300
- cleanup()
301
 
302
 
303
  def test(model, test_loader, device):
@@ -345,20 +387,23 @@ def main(args):
345
  Args:
346
  args (argparse.Namespace): Command-line arguments.
347
  """
348
- # Get number of available GPUs
349
  world_size = torch.cuda.device_count()
350
- mp.spawn(train_worker,
351
- args=(world_size, args),
352
- nprocs=world_size,
353
- join=True)
 
 
 
 
354
 
355
 
356
  if __name__ == "__main__":
357
  parser = argparse.ArgumentParser()
358
  parser.add_argument('--test', action='store_true')
359
  parser.add_argument('--checkpoint')
360
- parser.add_argument('--epochs', type=int, default=10)
361
- parser.add_argument('--batch_size', type=int, default=64)
362
  parser.add_argument('--port', type=int, default=4224)
363
  args = parser.parse_args()
364
  main(args)
 
31
 
32
  def cleanup():
33
  """
34
+ Clean up distributed training environment
35
  """
36
+ if dist.is_initialized():
37
+ dist.barrier() # Synchronize all processes before destroying process group
38
+ dist.destroy_process_group()
39
 
40
 
41
  class MiniPlaces(Dataset):
 
42
  def __init__(self, root_dir, split, transform=None, label_dict=None):
43
  """
44
  Initialize the MiniPlaces dataset with the root directory for the images,
 
101
  return image, label
102
 
103
 
104
+ def create_train_transform():
105
+ """
106
+ Create training data transformation with augmentation
107
+ """
108
+ image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
109
+ image_net_std = torch.Tensor([0.229, 0.224, 0.225])
110
+
111
+ return transforms.Compose([
112
+ transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
113
+ transforms.RandomHorizontalFlip(p=0.5),
114
+ transforms.ColorJitter(
115
+ brightness=0.4,
116
+ contrast=0.4,
117
+ saturation=0.4,
118
+ hue=0.1
119
+ ),
120
+ transforms.RandomAffine(
121
+ degrees=15, # rotation
122
+ translate=(0.1, 0.1), # horizontal/vertical translation
123
+ scale=(0.9, 1.1), # scale
124
+ ),
125
+ transforms.ToTensor(),
126
+ transforms.Resize((128, 128)),
127
+ transforms.Normalize(image_net_mean, image_net_std)
128
+ ])
129
+
130
+
131
+ def create_val_transform():
132
+ """
133
+ Create validation/test data transformation without augmentation
134
+ """
135
+ image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
136
+ image_net_std = torch.Tensor([0.229, 0.224, 0.225])
137
+
138
+ return transforms.Compose([
139
+ transforms.ToTensor(),
140
+ transforms.Resize((128, 128)),
141
+ transforms.Normalize(image_net_mean, image_net_std)
142
+ ])
143
+
144
+
145
  def evaluate(model, test_loader, criterion, device):
146
  """
147
  Evaluate the CNN classifier on the validation set.
 
200
  world_size (int): The total number of processes (GPUs).
201
  args (argparse.Namespace): Command-line arguments.
202
  """
203
+ try:
204
+ setup(rank, world_size, args.port)
205
+ device = torch.device(f'cuda:{rank}')
206
+
207
+ # Define early stopping parameters
208
+ patience = 3 # Number of epochs to wait for improvement
209
+ best_val_accuracy = 0.0 # Best validation accuracy so far
210
+ epochs_without_improvement = 0 # Counter for epochs without improvement
211
+ best_model_state = None # To store the state of the best model
212
+
213
+ # Separate transforms for training and validation
214
+ train_transform = create_train_transform()
215
+ val_transform = create_val_transform()
216
+
217
+ # Create datasets
218
+ data_root = 'data'
219
+ miniplaces_train = MiniPlaces(data_root, split='train', transform=train_transform)
220
+ miniplaces_val = MiniPlaces(data_root, split='val', transform=val_transform,
221
+ label_dict=miniplaces_train.label_dict)
222
+
223
+ # Create distributed samplers
224
+ train_sampler = DistributedSampler(miniplaces_train, num_replicas=world_size, rank=rank)
225
+ val_sampler = DistributedSampler(miniplaces_val, num_replicas=world_size, rank=rank)
226
+
227
+ # Create dataloaders
228
+ train_loader = DataLoader(miniplaces_train, batch_size=args.batch_size,
229
+ num_workers=2, sampler=train_sampler,
230
+ pin_memory=True)
231
+ val_loader = DataLoader(miniplaces_val, batch_size=args.batch_size,
232
+ num_workers=2, sampler=val_sampler,
233
+ pin_memory=True)
234
+
235
+ # Create model and move to GPU
236
+ model = MyModel(num_classes=len(miniplaces_train.label_dict))
237
+ model = model.to(device)
238
+ model = DDP(model, device_ids=[rank])
239
+
240
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9,
241
+ dampening=0, weight_decay=1e-4, nesterov=True)
242
+ criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
243
+
244
+ if args.checkpoint:
245
+ map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
246
+ checkpoint = torch.load(args.checkpoint, map_location=map_location)
247
+ model.module.load_state_dict(checkpoint['model_state_dict'])
248
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
249
+
250
+ if not args.test:
251
+ # Training loop
252
+ performance = []
253
+ for epoch in range(args.epochs):
254
+ model.train()
255
+ train_sampler.set_epoch(epoch) # Important for proper shuffling
256
+
257
+ running_loss = 0.0
258
+ correct_predictions = 0
259
+ total_samples = 0
260
+
261
+ if rank == 0: # Only show progress bar on rank 0
262
+ pbar = tqdm(total=len(train_loader),
263
+ desc=f'Epoch {epoch + 1}/{args.epochs}',
264
+ position=0, leave=True)
265
+
266
+ for inputs, labels in train_loader:
267
+ inputs = inputs.to(device)
268
+ labels = labels.to(device)
269
+
270
+ optimizer.zero_grad()
271
+ logits = model(inputs)
272
+ loss = criterion(logits, labels)
273
+ loss.backward()
274
+ optimizer.step()
275
+
276
+ running_loss += loss.item()
277
+ _, predicted = logits.max(1)
278
+ correct_predictions += (predicted == labels).sum().item()
279
+ total_samples += labels.size(0)
280
+
281
+ if rank == 0:
282
+ pbar.update(1)
283
+ pbar.set_postfix(loss=loss.item())
284
 
285
  if rank == 0:
286
+ pbar.close()
287
+
288
+ # Evaluate and log metrics
289
+ avg_train_loss = running_loss / len(train_loader)
290
+ train_accuracy = correct_predictions / total_samples
291
+ avg_val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
292
+
293
+ if rank == 0: # Only save metrics on rank 0
294
+ performance.append({
295
+ "avg_train_loss": avg_train_loss,
296
+ "train_accuracy": train_accuracy,
297
+ "avg_val_loss": avg_val_loss,
298
+ "val_accuracy": val_accuracy
299
+ })
300
+ print(
301
+ f"Train Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f} "
302
+ f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}"
303
+ )
304
+
305
+ # Check for early stopping
306
+ if val_accuracy > best_val_accuracy:
307
+ best_val_accuracy = val_accuracy
308
+ epochs_without_improvement = 0 # Reset counter if there's an improvement
309
+
310
+ # Save the model checkpoint for the best model
311
+ best_model_state = {
312
+ 'model_state_dict': model.module.state_dict(),
313
+ 'optimizer_state_dict': optimizer.state_dict(),
314
+ 'epoch': epoch,
315
+ }
316
+ else:
317
+ epochs_without_improvement += 1
318
+
319
+ # Early stopping condition
320
+ if epochs_without_improvement >= patience:
321
+ print(f"Early stopping at epoch {epoch + 1}.")
322
+ break # Stop training if no improvement for 'patience' epochs
323
+
324
+ if rank == 0: # Save performance and the best model checkpoint only on rank 0
325
+ with open("performance.json", "w") as f:
326
+ json.dump(performance, f, indent=4)
327
+ torch.save(best_model_state, 'model.ckpt')
328
+
329
+ else: # Testing mode
330
+ miniplaces_test = MiniPlaces(data_root, split='test', transform=data_transform)
331
+ test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
332
+ checkpoint = torch.load(args.checkpoint, map_location=device)
333
+ model.module.load_state_dict(checkpoint['model_state_dict'])
334
+ preds = test(model, test_loader, device)
335
+ if rank == 0: # Only write predictions on rank 0
336
+ write_predictions(preds, 'predictions.csv')
337
+ finally:
338
+ cleanup()
339
+ # Add explicit synchronization before exiting
340
+ torch.cuda.synchronize()
341
+ if dist.is_initialized():
342
+ dist.barrier()
343
 
344
 
345
  def test(model, test_loader, device):
 
387
  Args:
388
  args (argparse.Namespace): Command-line arguments.
389
  """
 
390
  world_size = torch.cuda.device_count()
391
+ try:
392
+ mp.spawn(train_worker,
393
+ args=(world_size, args),
394
+ nprocs=world_size,
395
+ join=True)
396
+ finally:
397
+ # Force cleanup of any remaining CUDA resources
398
+ torch.cuda.empty_cache()
399
 
400
 
401
  if __name__ == "__main__":
402
  parser = argparse.ArgumentParser()
403
  parser.add_argument('--test', action='store_true')
404
  parser.add_argument('--checkpoint')
405
+ parser.add_argument('--epochs', type=int, default=100)
406
+ parser.add_argument('--batch_size', type=int, default=32)
407
  parser.add_argument('--port', type=int, default=4224)
408
  args = parser.parse_args()
409
  main(args)