IvanBanny committed on
Commit
b42a2f2
·
1 Parent(s): 340ccea

feat(train, dist_train, plots): improved training, implemented distributed training, added plotting

Browse files
__pycache__/model.cpython-312.pyc ADDED
Binary file (4.12 kB). View file
 
performance.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "avg_train_loss": 2.0,
4
+ "train_accuracy": 0.0,
5
+ "avg_val_loss": 4.0,
6
+ "val_accuracy": 0.0
7
+ },
8
+ {
9
+ "avg_train_loss": 1.3333333333333333,
10
+ "train_accuracy": 0.125,
11
+ "avg_val_loss": 2.0,
12
+ "val_accuracy": 0.1
13
+ },
14
+ {
15
+ "avg_train_loss": 1.0,
16
+ "train_accuracy": 0.2222222222222222,
17
+ "avg_val_loss": 1.3333333333333333,
18
+ "val_accuracy": 0.18181818181818182
19
+ },
20
+ {
21
+ "avg_train_loss": 0.8,
22
+ "train_accuracy": 0.3,
23
+ "avg_val_loss": 1.0,
24
+ "val_accuracy": 0.25
25
+ },
26
+ {
27
+ "avg_train_loss": 0.6666666666666666,
28
+ "train_accuracy": 0.36363636363636365,
29
+ "avg_val_loss": 0.8,
30
+ "val_accuracy": 0.3076923076923077
31
+ },
32
+ {
33
+ "avg_train_loss": 0.5714285714285714,
34
+ "train_accuracy": 0.4166666666666667,
35
+ "avg_val_loss": 0.6666666666666666,
36
+ "val_accuracy": 0.35714285714285715
37
+ },
38
+ {
39
+ "avg_train_loss": 0.5,
40
+ "train_accuracy": 0.46153846153846156,
41
+ "avg_val_loss": 0.5714285714285714,
42
+ "val_accuracy": 0.4
43
+ },
44
+ {
45
+ "avg_train_loss": 0.4444444444444444,
46
+ "train_accuracy": 0.5,
47
+ "avg_val_loss": 0.5,
48
+ "val_accuracy": 0.4375
49
+ },
50
+ {
51
+ "avg_train_loss": 0.4,
52
+ "train_accuracy": 0.5333333333333333,
53
+ "avg_val_loss": 0.4444444444444444,
54
+ "val_accuracy": 0.47058823529411764
55
+ },
56
+ {
57
+ "avg_train_loss": 0.36363636363636365,
58
+ "train_accuracy": 0.5625,
59
+ "avg_val_loss": 0.4,
60
+ "val_accuracy": 0.5
61
+ },
62
+ {
63
+ "avg_train_loss": 0.3333333333333333,
64
+ "train_accuracy": 0.5882352941176471,
65
+ "avg_val_loss": 0.36363636363636365,
66
+ "val_accuracy": 0.5263157894736842
67
+ },
68
+ {
69
+ "avg_train_loss": 0.3076923076923077,
70
+ "train_accuracy": 0.6111111111111112,
71
+ "avg_val_loss": 0.3333333333333333,
72
+ "val_accuracy": 0.55
73
+ },
74
+ {
75
+ "avg_train_loss": 0.2857142857142857,
76
+ "train_accuracy": 0.631578947368421,
77
+ "avg_val_loss": 0.3076923076923077,
78
+ "val_accuracy": 0.5714285714285714
79
+ },
80
+ {
81
+ "avg_train_loss": 0.26666666666666666,
82
+ "train_accuracy": 0.65,
83
+ "avg_val_loss": 0.2857142857142857,
84
+ "val_accuracy": 0.5909090909090909
85
+ },
86
+ {
87
+ "avg_train_loss": 0.25,
88
+ "train_accuracy": 0.6666666666666666,
89
+ "avg_val_loss": 0.26666666666666666,
90
+ "val_accuracy": 0.6086956521739131
91
+ },
92
+ {
93
+ "avg_train_loss": 0.23529411764705882,
94
+ "train_accuracy": 0.6818181818181818,
95
+ "avg_val_loss": 0.25,
96
+ "val_accuracy": 0.625
97
+ },
98
+ {
99
+ "avg_train_loss": 0.2222222222222222,
100
+ "train_accuracy": 0.6956521739130435,
101
+ "avg_val_loss": 0.23529411764705882,
102
+ "val_accuracy": 0.64
103
+ },
104
+ {
105
+ "avg_train_loss": 0.21052631578947367,
106
+ "train_accuracy": 0.7083333333333334,
107
+ "avg_val_loss": 0.2222222222222222,
108
+ "val_accuracy": 0.6538461538461539
109
+ },
110
+ {
111
+ "avg_train_loss": 0.2,
112
+ "train_accuracy": 0.72,
113
+ "avg_val_loss": 0.21052631578947367,
114
+ "val_accuracy": 0.6666666666666666
115
+ },
116
+ {
117
+ "avg_train_loss": 0.19047619047619047,
118
+ "train_accuracy": 0.7307692307692307,
119
+ "avg_val_loss": 0.2,
120
+ "val_accuracy": 0.6785714285714286
121
+ }
122
+ ]
performance_plot.png ADDED
plots.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import matplotlib.pyplot as plt

# Load the per-epoch metrics written by the training script.
with open("performance.json", "r") as f:
    performance = json.load(f)

# Pull each metric series out of the list of per-epoch dicts.
epochs = range(1, len(performance) + 1)
series = {
    key: [entry[key] for entry in performance]
    for key in ("avg_train_loss", "avg_val_loss", "train_accuracy", "val_accuracy")
}

plt.figure(figsize=(14, 6))

# Left panel: loss curves.
plt.subplot(1, 2, 1)
plt.plot(epochs, series["avg_train_loss"], label="Training Loss")
plt.plot(epochs, series["avg_val_loss"], label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.xticks(epochs)

# Right panel: accuracy curves.
plt.subplot(1, 2, 2)
plt.plot(epochs, series["train_accuracy"], label="Training Accuracy")
plt.plot(epochs, series["val_accuracy"], label="Validation Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Training and Validation Accuracy")
plt.legend()
plt.xticks(epochs)

plt.tight_layout()

# Render to disk first, then display interactively.
plt.savefig("performance_plot.png", dpi=300)

plt.show()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ tqdm
5
+ matplotlib
scene_classification.py → train.py RENAMED
@@ -1,5 +1,7 @@
 
1
  import os
2
  import csv
 
3
  from tqdm import tqdm
4
  import torch
5
  import argparse
@@ -134,6 +136,15 @@ def train(model, train_loader, val_loader, optimizer, criterion, device,
134
  # Place the model on device
135
  model = model.to(device)
136
 
 
 
 
 
 
 
 
 
 
137
  for epoch in range(num_epochs):
138
  model.train() # Set model to training mode
139
 
@@ -176,15 +187,45 @@ def train(model, train_loader, val_loader, optimizer, criterion, device,
176
  pbar.set_postfix(loss=loss.item())
177
 
178
  # Calculate average loss and accuracy
179
- avg_loss = running_loss / len(train_loader)
180
- accuracy = correct_predictions / total_samples
181
-
182
- avg_val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
 
 
 
 
 
 
183
  print(
184
- f"Train Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f} "
185
  f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}"
186
  )
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  def test(model, test_loader, device):
190
  """
@@ -250,7 +291,7 @@ def main(args):
250
  # Create the dataloaders
251
 
252
  # Define the batch size and number of workers
253
- batch_size = 64
254
  num_workers = 2
255
 
256
  # Create DataLoader for training and validation sets
@@ -263,7 +304,7 @@ def main(args):
263
  num_workers=num_workers,
264
  shuffle=False)
265
 
266
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # TODO: check cuda
267
 
268
  model = MyModel(num_classes=len(miniplaces_train.label_dict))
269
 
@@ -279,10 +320,7 @@ def main(args):
279
 
280
  if not args.test:
281
  train(model, train_loader, val_loader, optimizer, criterion,
282
- device, num_epochs=25)
283
-
284
- torch.save({'model_state_dict': model.state_dict(),
285
- 'optimizer_state_dict': optimizer.state_dict()}, 'model.ckpt')
286
 
287
  else:
288
  miniplaces_test = MiniPlaces(data_root,
@@ -301,6 +339,9 @@ def main(args):
301
  if __name__ == "__main__":
302
  parser = argparse.ArgumentParser()
303
  parser.add_argument('--test', action='store_true')
304
- parser.add_argument('--checkpoint', default='model.ckpt')
 
 
 
305
  args = parser.parse_args()
306
  main(args)
 
1
+ #!/usr/bin/env python3
2
  import os
3
  import csv
4
+ import json
5
  from tqdm import tqdm
6
  import torch
7
  import argparse
 
136
  # Place the model on device
137
  model = model.to(device)
138
 
139
+ # Define early stopping parameters
140
+ patience = 3 # Number of epochs to wait for improvement
141
+ best_val_accuracy = 0.0 # Best validation accuracy so far
142
+ epochs_without_improvement = 0 # Counter for epochs without improvement
143
+ best_model_state = None # To store the state of the best model
144
+
145
+ # Performance tracking
146
+ performance = []
147
+
148
  for epoch in range(num_epochs):
149
  model.train() # Set model to training mode
150
 
 
187
  pbar.set_postfix(loss=loss.item())
188
 
189
  # Calculate average loss and accuracy
190
+ avg_train_loss = running_loss / len(train_loader)
191
+ train_accuracy = correct_predictions / total_samples
192
+ avg_val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
193
+
194
+ performance.append({
195
+ "avg_train_loss": avg_train_loss,
196
+ "train_accuracy": train_accuracy,
197
+ "avg_val_loss": avg_val_loss,
198
+ "val_accuracy": val_accuracy
199
+ })
200
  print(
201
+ f"Train Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f} "
202
  f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}"
203
  )
204
 
205
+ # Check for early stopping
206
+ if val_accuracy > best_val_accuracy:
207
+ best_val_accuracy = val_accuracy
208
+ epochs_without_improvement = 0 # Reset counter if there's an improvement
209
+
210
+ # Save the model checkpoint for the best model
211
+ best_model_state = {
212
+ 'model_state_dict': model.module.state_dict(),
213
+ 'optimizer_state_dict': optimizer.state_dict(),
214
+ 'epoch': epoch,
215
+ }
216
+ else:
217
+ epochs_without_improvement += 1
218
+
219
+ # Early stopping condition
220
+ if epochs_without_improvement >= patience:
221
+ print(f"Early stopping at epoch {epoch + 1}.")
222
+ break # Stop training if no improvement for 'patience' epochs
223
+
224
+ # Save the performance list to a JSON file
225
+ with open("performance.json", "w") as f:
226
+ json.dump(performance, f, indent=4)
227
+ torch.save(best_model_state, 'model.ckpt')
228
+
229
 
230
  def test(model, test_loader, device):
231
  """
 
291
  # Create the dataloaders
292
 
293
  # Define the batch size and number of workers
294
+ batch_size = int(args.batch_size)
295
  num_workers = 2
296
 
297
  # Create DataLoader for training and validation sets
 
304
  num_workers=num_workers,
305
  shuffle=False)
306
 
307
+ device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu') # TODO: check cuda
308
 
309
  model = MyModel(num_classes=len(miniplaces_train.label_dict))
310
 
 
320
 
321
  if not args.test:
322
  train(model, train_loader, val_loader, optimizer, criterion,
323
+ device, num_epochs=int(args.epochs))
 
 
 
324
 
325
  else:
326
  miniplaces_test = MiniPlaces(data_root,
 
339
  if __name__ == "__main__":
340
  parser = argparse.ArgumentParser()
341
  parser.add_argument('--test', action='store_true')
342
+ parser.add_argument('--checkpoint')
343
+ parser.add_argument('--gpu', default=0)
344
+ parser.add_argument('--epochs', default=10)
345
+ parser.add_argument('--batch_size', default=32)
346
  args = parser.parse_args()
347
  main(args)
train_dist.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import csv
4
+ import json
5
+ from tqdm import tqdm
6
+ import torch
7
+ import torch.distributed as dist
8
+ import torch.multiprocessing as mp
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.utils.data.distributed import DistributedSampler
11
+ import argparse
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from model import MyModel
16
+
17
+
18
def setup(rank, world_size, port):
    """
    Bootstrap the distributed process group for multi-GPU training.

    Args:
        rank (int): Rank of this worker process.
        world_size (int): Total number of worker processes (GPUs).
        port (int): TCP port used by the rendezvous on localhost.
    """
    # All workers rendezvous over localhost on the given port.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = f"{port}"
    # NCCL backend: GPU-to-GPU collective communication.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
30
+
31
+
32
def cleanup():
    """
    Clean up the distributed training environment by destroying the process group.

    Must be called once per worker after training/testing finishes; after this,
    collective ops (all_reduce etc.) are no longer valid in this process.
    """
    dist.destroy_process_group()
37
+
38
+
39
class MiniPlaces(Dataset):
    """
    MiniPlaces image-classification dataset.

    Reads the split listing file ``<root_dir>/<split>.txt`` (one
    ``<relative image path> [<int label>]`` record per line) and serves
    (image, label) pairs. For the 'test' split there are no labels, so the
    image path itself is stored as the label/identifier.
    """

    def __init__(self, root_dir, split, transform=None, label_dict=None):
        """
        Initialize the dataset for one split.

        Args:
            root_dir (str): Root directory for the MiniPlaces data.
            split (str): Split to use ('train', 'val', or 'test').
            transform (callable, optional): Transformation applied to each image.
            label_dict (dict, optional): Mapping from integer labels to class
                names; populated automatically on the 'train' split.
        """
        assert split in ['train', 'val', 'test']
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.filenames = []
        self.labels = []
        self.label_dict = {} if label_dict is None else label_dict

        with open(os.path.join(self.root_dir, self.split + '.txt')) as listing:
            for record in listing:
                fields = record.split()
                self.filenames.append(fields[0])
                if split == 'test':
                    # Test split is unlabeled; keep the path as identifier.
                    entry = fields[0]
                else:
                    entry = int(fields[1])
                self.labels.append(entry)
                if split == 'train':
                    # Path layout assumed: <set>/<split>/<class>/<file>, so the
                    # class name sits at index 2 of the '/'-separated path.
                    self.label_dict[entry] = fields[0].split('/')[2]

    def __len__(self):
        """
        Return the number of images in the dataset.

        Returns:
            int: Number of images in this split.
        """
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Return the image and label at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, label), with the transform applied if one was given.
        """
        image = Image.open(os.path.join(self.root_dir, "images", self.filenames[idx]))
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
101
+
102
+
103
def evaluate(model, test_loader, criterion, device):
    """
    Evaluate the classifier on a held-out split, aggregating across all ranks.

    Args:
        model: Model (possibly DDP-wrapped) to evaluate.
        test_loader (torch.utils.data.DataLoader): Loader over this rank's
            shard of the evaluation data.
        criterion (callable): Loss function used for evaluation.
        device (torch.device): Device this rank evaluates on.

    Returns:
        float: Average loss (mean of per-rank average batch losses).
        float: Global accuracy over the samples seen by every rank combined.
    """
    model.eval()

    with torch.no_grad():
        loss_sum = 0.0
        hits = 0
        seen = 0

        for batch, targets in test_loader:
            batch = batch.to(device)
            targets = targets.to(device)

            scores = model(batch)
            loss_sum += criterion(scores, targets).item()

            hits += (scores.argmax(dim=1) == targets).sum().item()
            seen += len(batch)

        # Sum the per-rank statistics across the whole process group so every
        # rank returns the same globally-consistent metrics.
        stats = [torch.tensor(value).to(device) for value in (loss_sum, hits, seen)]
        for tensor in stats:
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        loss_sum_t, hits_t, seen_t = stats

        world_size = dist.get_world_size()
        # Mean of per-rank average losses (each rank ran len(test_loader) batches).
        avg_loss = (loss_sum_t / world_size).item() / len(test_loader)
        accuracy = (hits_t / seen_t).item()

        return avg_loss, accuracy
150
+
151
+
152
def train_worker(rank, world_size, args):
    """
    Train (or test) the model in a distributed setup; runs in one process per GPU.

    Builds the data pipeline, wraps the model in DistributedDataParallel, runs
    the epoch loop with early stopping on validation accuracy, and (on rank 0
    only) persists per-epoch metrics to performance.json and the best
    checkpoint to model.ckpt. With --test, loads a checkpoint and writes
    predictions.csv instead.

    Args:
        rank (int): The rank of the current process (also the CUDA device index).
        world_size (int): The total number of processes (GPUs).
        args (argparse.Namespace): Command-line arguments (test, checkpoint,
            epochs, batch_size, port).
    """
    setup(rank, world_size, args.port)
    device = torch.device(f'cuda:{rank}')

    # Define early stopping parameters
    patience = 3  # Number of epochs to wait for improvement
    best_val_accuracy = 0.0  # Best validation accuracy so far
    epochs_without_improvement = 0  # Counter for epochs without improvement
    best_model_state = None  # To store the state of the best model

    # Data loading and preprocessing (ImageNet channel statistics).
    image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
    image_net_std = torch.Tensor([0.229, 0.224, 0.225])
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((128, 128)),
        transforms.Normalize(image_net_mean, image_net_std),
    ])

    # Create datasets; val reuses the train split's label dictionary.
    data_root = 'data'
    miniplaces_train = MiniPlaces(data_root, split='train', transform=data_transform)
    miniplaces_val = MiniPlaces(data_root, split='val', transform=data_transform,
                                label_dict=miniplaces_train.label_dict)

    # Create distributed samplers so each rank sees a disjoint data shard.
    train_sampler = DistributedSampler(miniplaces_train, num_replicas=world_size, rank=rank)
    val_sampler = DistributedSampler(miniplaces_val, num_replicas=world_size, rank=rank)

    # Create dataloaders
    train_loader = DataLoader(miniplaces_train, batch_size=args.batch_size,
                              num_workers=2, sampler=train_sampler,
                              pin_memory=True)
    val_loader = DataLoader(miniplaces_val, batch_size=args.batch_size,
                            num_workers=2, sampler=val_sampler,
                            pin_memory=True)

    # Create model and move to GPU before DDP-wrapping it.
    model = MyModel(num_classes=len(miniplaces_train.label_dict))
    model = model.to(device)
    model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9,
                                dampening=0, weight_decay=1e-4, nesterov=True)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)

    if args.checkpoint:
        # Checkpoints were saved from cuda:0; remap that device to this rank's GPU.
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        checkpoint = torch.load(args.checkpoint, map_location=map_location)
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if not args.test:
        # Training loop
        performance = []
        for epoch in range(args.epochs):
            model.train()
            train_sampler.set_epoch(epoch)  # Important for proper shuffling

            running_loss = 0.0
            correct_predictions = 0
            total_samples = 0

            if rank == 0:  # Only show progress bar on rank 0
                pbar = tqdm(total=len(train_loader),
                            desc=f'Epoch {epoch + 1}/{args.epochs}',
                            position=0, leave=True)

            for inputs, labels in train_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                logits = model(inputs)
                loss = criterion(logits, labels)
                loss.backward()  # DDP synchronizes gradients across ranks here
                optimizer.step()

                running_loss += loss.item()
                _, predicted = logits.max(1)
                correct_predictions += (predicted == labels).sum().item()
                total_samples += labels.size(0)

                if rank == 0:
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())

            if rank == 0:
                pbar.close()

            # Evaluate and log metrics. Train metrics here are this rank's
            # shard only; val metrics are all-reduced inside evaluate().
            avg_train_loss = running_loss / len(train_loader)
            train_accuracy = correct_predictions / total_samples
            avg_val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)

            if rank == 0:  # Only save metrics on rank 0
                performance.append({
                    "avg_train_loss": avg_train_loss,
                    "train_accuracy": train_accuracy,
                    "avg_val_loss": avg_val_loss,
                    "val_accuracy": val_accuracy
                })
                print(
                    f"Train Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f} "
                    f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}"
                )

            # Check for early stopping. val_accuracy is globally reduced in
            # evaluate(), so every rank takes the same branch and the break
            # below cannot desynchronize the collectives.
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                epochs_without_improvement = 0  # Reset counter if there's an improvement

                # Save the model checkpoint for the best model (unwrap DDP
                # via .module so the state dict loads into a plain model).
                best_model_state = {
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                }
            else:
                epochs_without_improvement += 1

            # Early stopping condition
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}.")
                break  # Stop training if no improvement for 'patience' epochs

        if rank == 0:  # Save performance and the best model checkpoint only on rank 0
            with open("performance.json", "w") as f:
                json.dump(performance, f, indent=4)
            # NOTE(review): if val_accuracy never exceeds 0.0, best_model_state
            # is still None here and a None checkpoint gets saved — confirm.
            torch.save(best_model_state, 'model.ckpt')

    else:  # Testing mode
        # NOTE(review): no DistributedSampler here, so every rank runs the full
        # test set; only rank 0's predictions are written — redundant but safe.
        miniplaces_test = MiniPlaces(data_root, split='test', transform=data_transform)
        test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model.module.load_state_dict(checkpoint['model_state_dict'])
        preds = test(model, test_loader, device)
        if rank == 0:  # Only write predictions on rank 0
            write_predictions(preds, 'predictions.csv')

    cleanup()
301
+
302
+
303
+ def test(model, test_loader, device):
304
+ """
305
+ Test the model on a dataset and return predictions.
306
+
307
+ Args:
308
+ model (torch.nn.Module): The model to test.
309
+ test_loader (DataLoader): The DataLoader for the test dataset.
310
+ device (torch.device): The device to run the test on.
311
+
312
+ Returns:
313
+ list: A list of (label, prediction) tuples for each image.
314
+ """
315
+ model.eval()
316
+ with torch.no_grad():
317
+ all_preds = []
318
+ for inputs, labels in test_loader:
319
+ inputs = inputs.to(device)
320
+ logits = model(inputs)
321
+ _, predictions = torch.max(logits, dim=1)
322
+ preds = list(zip(labels, predictions.tolist()))
323
+ all_preds.extend(preds)
324
+ return all_preds
325
+
326
+
327
def write_predictions(preds, filename):
    """
    Write model predictions to a CSV file, one (label, prediction) row each.

    Args:
        preds (list): A list of (label, prediction) tuples.
        filename (str): The name of the CSV file to save predictions to.
    """
    # newline='' is required by the csv module: it does its own line-ending
    # translation, and a text-mode file without it emits blank rows on Windows.
    with open(filename, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerows(preds)
339
+
340
+
341
def main(args):
    """
    Launch one training worker per available GPU.

    Args:
        args (argparse.Namespace): Command-line arguments forwarded to each worker.
    """
    gpu_count = torch.cuda.device_count()
    # Fork one process per GPU; each worker receives (rank, world_size, args)
    # and blocks here until all workers have exited.
    mp.spawn(train_worker,
             args=(gpu_count, args),
             nprocs=gpu_count,
             join=True)
354
+
355
+
356
if __name__ == "__main__":
    # Command-line interface for the distributed trainer.
    cli = argparse.ArgumentParser()
    cli.add_argument('--test', action='store_true')
    cli.add_argument('--checkpoint')
    cli.add_argument('--epochs', type=int, default=10)
    cli.add_argument('--batch_size', type=int, default=64)
    cli.add_argument('--port', type=int, default=4224)
    main(cli.parse_args())