IvanBanny commited on
Commit
fa13b6c
·
1 Parent(s): 972a6a4

feat(model, train): improved architecture, overfit prevention, re-trained the model

Browse files
Files changed (8) hide show
  1. README.md +36 -0
  2. model.py +23 -18
  3. performance.json +649 -115
  4. performance_plot.png +0 -0
  5. plots.py +3 -3
  6. predictions.csv +0 -0
  7. train.py +0 -394
  8. train_dist.py +40 -7
README.md CHANGED
@@ -1,2 +1,38 @@
1
  # Places-ResNet
 
2
  My experiment training a ResNet-inspired model for image classification using PyTorch.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Places-ResNet
2
+
3
  My experiment training a ResNet-inspired model for image classification using PyTorch.
4
+
5
+ **Key terms: distributed training, residual layers, convolutional layers, batch normalization, dropout, pooling, SGD, label smoothing, learning rate scheduling, early stopping, data augmentation.**
6
+
7
+ Training time was approximately 10 hours (108 epochs) using **distributed training** across university server GPUs.
8
+
9
+ ## Dataset:
10
+
11
+ **MIT MiniPlaces Dataset:** Contains 100,000 training images, 10,000 validation images, and 10,000 testing images across 100 scene categories. Each image is 128x128 pixels.
12
+
13
+ ## Model:
14
+
15
+ I implemented a 13-layer ResNet-inspired model for image classification. The architecture consists of:
16
+ - Initial **convolutional layer** with 64 filters, followed by **batch normalization, max pooling, and dropout**
17
+ - 3 stages of **residual blocks**, each with 4 convolutional layers
18
+ - Each **residual block** has two 3x3 **convolutional layers** with **batch normalization and dropout**
19
+ - The number of filters increases from 64 in the first stage, to 128, 256, and 512 in the later stages
20
+ - **Global average pooling and dropout** before a final **fully connected layer**
21
+
22
+ The total number of trainable model parameters is 29,678,180.
23
+
24
+ ## Training:
25
+
26
+ The training setup used a **distributed training** approach, with **early stopping** to prevent overfitting. **Data augmentation** techniques were applied to the training and validation sets. An **SGD optimizer** with **label smoothing** was used, along with a **ReduceLROnPlateau learning rate scheduler**.
27
+
28
+ ## Performance:
29
+
30
+ Best model checkpoint results (epoch 108):
31
+
32
+ - Training Loss: 2.3231, Training Accuracy: 53.02%
33
+ - Validation Loss: 2.3426, Validation Accuracy: 54.09%
34
+ - Top-5 Validation Accuracy: 81.34%
35
+
36
+ The best checkpoint achieves a **Top-1 accuracy of 54.09% and a Top-5 accuracy of 81.34%** on the validation set.
37
+
38
+ ![Training Performance Plot](performance_plot.png "Training Performance Plot")
model.py CHANGED
@@ -3,14 +3,15 @@ import torch.nn as nn
3
 
4
 
5
  class ResidualBlock(nn.Module):
6
- def __init__(self, in_channels, out_channels):
7
  super(ResidualBlock, self).__init__()
8
  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
9
  self.bn1 = nn.BatchNorm2d(out_channels)
 
10
  self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
11
  self.bn2 = nn.BatchNorm2d(out_channels)
 
12
 
13
- # Skip connection (identity mapping)
14
  self.skip_connection = nn.Sequential()
15
  if in_channels != out_channels:
16
  self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
@@ -18,53 +19,57 @@ class ResidualBlock(nn.Module):
18
  def forward(self, x):
19
  residual = self.skip_connection(x)
20
  out = nn.functional.relu(self.bn1(self.conv1(x)))
 
21
  out = self.bn2(self.conv2(out))
22
- out += residual # Adding the skip connection
 
23
  out = nn.functional.relu(out)
24
  return out
25
 
26
 
27
  class MyModel(nn.Module):
28
- def __init__(self, num_classes=100):
29
  super(MyModel, self).__init__()
 
30
 
31
- # Initial convolutional layer
32
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
33
  self.bn1 = nn.BatchNorm2d(64)
34
  self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
35
 
36
- # Residual blocks
37
- self.block1 = self._resnet_layers(64, 128, num_blocks=3) # 3 residual blocks
38
- self.block2 = self._resnet_layers(128, 256, num_blocks=3) # 3 residual blocks
39
- self.block3 = self._resnet_layers(256, 512, num_blocks=3) # 3 residual blocks
40
 
41
- # Global average pooling
42
  self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
 
 
 
 
43
 
44
- # Combine features
45
  self.features = nn.Sequential(
46
  self.conv1,
47
  self.bn1,
48
  nn.ReLU(),
49
  self.pool1,
 
50
  self.block1,
51
  self.block2,
52
  self.block3,
53
- self.global_avg_pool
 
54
  )
55
 
56
- # Fully connected layer
57
- self.fc = nn.Linear(512, num_classes)
58
-
59
  @staticmethod
60
  def _resnet_layers(in_channels, out_channels, num_blocks):
61
  return nn.Sequential(
62
- ResidualBlock(in_channels, out_channels),
63
- *[ResidualBlock(out_channels, out_channels) for _ in range(num_blocks)]
64
  )
65
 
66
  def forward(self, x):
67
  x = self.features(x)
68
- x = torch.flatten(x, 1) # Flatten the output for the fully connected layer
69
  x = self.fc(x)
70
  return x
 
3
 
4
 
5
class ResidualBlock(nn.Module):
    """Two-layer 3x3 residual block with batch norm and spatial dropout.

    The skip path is the identity when channel counts match, otherwise a
    1x1 convolution that projects the input to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.2):
        super(ResidualBlock, self).__init__()
        # Main path: conv -> BN -> (ReLU) -> dropout, twice.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout2d(p=dropout_rate)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout2 = nn.Dropout2d(p=dropout_rate)

        # Identity shortcut by default; 1x1 projection when channels differ.
        self.skip_connection = nn.Sequential()
        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        shortcut = self.skip_connection(x)
        y = self.conv1(x)
        y = nn.functional.relu(self.bn1(y))
        y = self.dropout1(y)
        y = self.dropout2(self.bn2(self.conv2(y)))
        return nn.functional.relu(y + shortcut)
28
 
29
 
30
class MyModel(nn.Module):
    """ResNet-inspired classifier for 128x128 scene images (MiniPlaces).

    Architecture: a 7x7 stem conv (stride 2) + BN + ReLU + max pool +
    dropout, three residual stages (64->128, 128->256, 256->512 channels),
    global average pooling, dropout, and a final linear classifier.

    Args:
        num_classes: Number of output classes (default 100).
        dropout_rate: Dropout probability used throughout the network.
            Fix: this value is now forwarded to the residual blocks, which
            previously hard-coded 0.2 regardless of the constructor argument.
    """

    def __init__(self, num_classes=100, dropout_rate=0.2):
        super(MyModel, self).__init__()
        self.dropout_rate = dropout_rate

        # Stem: 7x7 conv halves the spatial size; max pool halves it again.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.dropout1 = nn.Dropout2d(p=self.dropout_rate)

        # Three residual stages. NOTE(review): _resnet_layers builds
        # num_blocks + 1 blocks per stage (one channel-changing block plus
        # num_blocks same-channel blocks) — kept as-is so trained
        # checkpoints remain loadable, but the name is misleading.
        self.block1 = self._resnet_layers(64, 128, num_blocks=4, dropout_rate=self.dropout_rate)
        self.block2 = self._resnet_layers(128, 256, num_blocks=4, dropout_rate=self.dropout_rate)
        self.block3 = self._resnet_layers(256, 512, num_blocks=4, dropout_rate=self.dropout_rate)

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout2 = nn.Dropout(p=self.dropout_rate)

        # Final classifier over the 512 pooled features.
        self.fc = nn.Linear(512, num_classes)

        # Convenience container running the whole convolutional trunk.
        # The submodules are shared with the attributes above, so parameters
        # are not duplicated.
        self.features = nn.Sequential(
            self.conv1,
            self.bn1,
            nn.ReLU(),
            self.pool1,
            self.dropout1,
            self.block1,
            self.block2,
            self.block3,
            self.global_avg_pool,
            self.dropout2,
        )

    @staticmethod
    def _resnet_layers(in_channels, out_channels, num_blocks, dropout_rate=0.2):
        """Build one stage: a channel-expanding block followed by
        ``num_blocks`` same-channel blocks (``num_blocks + 1`` blocks total).

        ``dropout_rate`` was previously hard-coded to 0.2; it now defaults
        to 0.2 (preserving the original behavior) but can be overridden so
        ``MyModel(dropout_rate=...)`` affects the residual blocks too.
        """
        return nn.Sequential(
            ResidualBlock(in_channels, out_channels, dropout_rate=dropout_rate),
            *[ResidualBlock(out_channels, out_channels, dropout_rate=dropout_rate)
              for _ in range(num_blocks)],
        )

    def forward(self, x):
        """Return raw class logits for a batch of (N, 3, H, W) images."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 512, 1, 1) -> (N, 512)
        x = self.fc(x)
        return x
performance.json CHANGED
@@ -1,176 +1,710 @@
1
  [
2
  {
3
- "avg_train_loss": 4.105753168425597,
4
- "train_accuracy": 0.08726,
5
- "avg_val_loss": 3.8632843403876582,
6
- "val_accuracy": 0.1306000053882599
7
  },
8
  {
9
- "avg_train_loss": 3.7184383619167005,
10
- "train_accuracy": 0.1609,
11
- "avg_val_loss": 3.5157296868819223,
12
- "val_accuracy": 0.20409999787807465
13
  },
14
  {
15
- "avg_train_loss": 3.5134875548770057,
16
- "train_accuracy": 0.20752,
17
- "avg_val_loss": 3.3557024605666537,
18
- "val_accuracy": 0.24459999799728394
19
  },
20
  {
21
- "avg_train_loss": 3.3635203539562957,
22
- "train_accuracy": 0.244,
23
- "avg_val_loss": 3.320155662826345,
24
- "val_accuracy": 0.25769999623298645
25
  },
26
  {
27
- "avg_train_loss": 3.2561175189054836,
28
- "train_accuracy": 0.2721,
29
- "avg_val_loss": 3.2409366655953322,
30
- "val_accuracy": 0.2786000072956085
31
  },
32
  {
33
- "avg_train_loss": 3.165564750466505,
34
- "train_accuracy": 0.2952,
35
- "avg_val_loss": 3.3207412912875793,
36
- "val_accuracy": 0.28110000491142273
37
  },
38
  {
39
- "avg_train_loss": 3.089012709724934,
40
- "train_accuracy": 0.31326,
41
- "avg_val_loss": 3.1548544968230816,
42
- "val_accuracy": 0.3131999969482422
43
  },
44
  {
45
- "avg_train_loss": 3.0239714097488872,
46
- "train_accuracy": 0.33192,
47
- "avg_val_loss": 3.0669574978985366,
48
- "val_accuracy": 0.3246999979019165
49
  },
50
  {
51
- "avg_train_loss": 2.9728026246780628,
52
- "train_accuracy": 0.34622,
53
- "avg_val_loss": 3.1410958978194223,
54
- "val_accuracy": 0.3125999867916107
55
  },
56
  {
57
- "avg_train_loss": 2.926501644236962,
58
- "train_accuracy": 0.35726,
59
- "avg_val_loss": 3.0194991872280457,
60
- "val_accuracy": 0.34369999170303345
61
  },
62
  {
63
- "avg_train_loss": 2.881767414719857,
64
- "train_accuracy": 0.37002,
65
- "avg_val_loss": 3.1654707510260085,
66
- "val_accuracy": 0.3264000117778778
67
  },
68
  {
69
- "avg_train_loss": 2.8386977173178396,
70
- "train_accuracy": 0.37992,
71
- "avg_val_loss": 2.908680589893196,
72
- "val_accuracy": 0.3734000027179718
73
  },
74
  {
75
- "avg_train_loss": 2.7958365852570597,
76
- "train_accuracy": 0.39322,
77
- "avg_val_loss": 2.818336969689478,
78
- "val_accuracy": 0.38659998774528503
79
  },
80
  {
81
- "avg_train_loss": 2.7660993075431763,
82
- "train_accuracy": 0.40192,
83
- "avg_val_loss": 2.941794866248022,
84
- "val_accuracy": 0.3686000108718872
85
  },
86
  {
87
- "avg_train_loss": 2.7263018761754343,
88
- "train_accuracy": 0.41194,
89
- "avg_val_loss": 2.8387841333316852,
90
- "val_accuracy": 0.39089998602867126
91
  },
92
  {
93
- "avg_train_loss": 2.6966828406619294,
94
- "train_accuracy": 0.42182,
95
- "avg_val_loss": 2.975855187524723,
96
- "val_accuracy": 0.36570000648498535
97
  },
98
  {
99
- "avg_train_loss": 2.6595005979928215,
100
- "train_accuracy": 0.43252,
101
- "avg_val_loss": 2.866811245302611,
102
- "val_accuracy": 0.3898000121116638
103
  },
104
  {
105
- "avg_train_loss": 2.626687426396343,
106
- "train_accuracy": 0.4389,
107
- "avg_val_loss": 2.797353237490111,
108
- "val_accuracy": 0.4056999981403351
109
  },
110
  {
111
- "avg_train_loss": 2.607301542521133,
112
- "train_accuracy": 0.44456,
113
- "avg_val_loss": 2.8504348948032043,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  "val_accuracy": 0.4027000069618225
115
  },
116
  {
117
- "avg_train_loss": 2.571250948454718,
118
- "train_accuracy": 0.45616,
119
- "avg_val_loss": 2.87233859677858,
120
- "val_accuracy": 0.39640000462532043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  },
122
  {
123
- "avg_train_loss": 2.5492486542143173,
124
- "train_accuracy": 0.46084,
125
- "avg_val_loss": 2.743163096753857,
126
- "val_accuracy": 0.42489999532699585
127
  },
128
  {
129
- "avg_train_loss": 2.525591837780555,
130
- "train_accuracy": 0.46876,
131
- "avg_val_loss": 2.9510807085640822,
132
- "val_accuracy": 0.3882000148296356
133
  },
134
  {
135
- "avg_train_loss": 2.5095781770813494,
136
- "train_accuracy": 0.47198,
137
- "avg_val_loss": 2.7676040069966374,
138
- "val_accuracy": 0.4284000098705292
139
  },
140
  {
141
- "avg_train_loss": 2.4809405361599937,
142
- "train_accuracy": 0.48236,
143
- "avg_val_loss": 2.7205014772053007,
144
- "val_accuracy": 0.4325999915599823
145
  },
146
  {
147
- "avg_train_loss": 2.4620410210031376,
148
- "train_accuracy": 0.4867,
149
- "avg_val_loss": 2.674741914000692,
150
- "val_accuracy": 0.439300000667572
151
  },
152
  {
153
- "avg_train_loss": 2.431113924059417,
154
- "train_accuracy": 0.494,
155
- "avg_val_loss": 2.6500483645668513,
156
- "val_accuracy": 0.4472000002861023
157
  },
158
  {
159
- "avg_train_loss": 2.4075386673593155,
160
- "train_accuracy": 0.50418,
161
- "avg_val_loss": 2.7652997366989713,
162
- "val_accuracy": 0.4194999933242798
163
  },
164
  {
165
- "avg_train_loss": 2.390994114796524,
166
- "train_accuracy": 0.50546,
167
- "avg_val_loss": 2.750720060324367,
168
- "val_accuracy": 0.42570000886917114
169
  },
170
  {
171
- "avg_train_loss": 2.3609870321610393,
172
- "train_accuracy": 0.5147,
173
- "avg_val_loss": 2.6785139252867878,
174
- "val_accuracy": 0.4438999891281128
175
  }
176
  ]
 
1
  [
2
  {
3
+ "avg_train_loss": 4.452685312236971,
4
+ "train_accuracy": 0.03584,
5
+ "avg_val_loss": 4.1456576963014244,
6
+ "val_accuracy": 0.07199999690055847
7
  },
8
  {
9
+ "avg_train_loss": 4.22897495028308,
10
+ "train_accuracy": 0.06454,
11
+ "avg_val_loss": 3.984914272646361,
12
+ "val_accuracy": 0.10199999809265137
13
  },
14
  {
15
+ "avg_train_loss": 4.097483895318892,
16
+ "train_accuracy": 0.08686,
17
+ "avg_val_loss": 3.8411196938043908,
18
+ "val_accuracy": 0.13650000095367432
19
  },
20
  {
21
+ "avg_train_loss": 3.993970947192453,
22
+ "train_accuracy": 0.1051,
23
+ "avg_val_loss": 3.7907390353045884,
24
+ "val_accuracy": 0.14059999585151672
25
  },
26
  {
27
+ "avg_train_loss": 3.9059129904603105,
28
+ "train_accuracy": 0.1232,
29
+ "avg_val_loss": 3.656763390649723,
30
+ "val_accuracy": 0.164900004863739
31
  },
32
  {
33
+ "avg_train_loss": 3.8222918010428737,
34
+ "train_accuracy": 0.13996,
35
+ "avg_val_loss": 3.5380990716475473,
36
+ "val_accuracy": 0.19670000672340393
37
  },
38
  {
39
+ "avg_train_loss": 3.756975746520645,
40
+ "train_accuracy": 0.1505,
41
+ "avg_val_loss": 3.4915640142899527,
42
+ "val_accuracy": 0.20239999890327454
43
  },
44
  {
45
+ "avg_train_loss": 3.6912049436203356,
46
+ "train_accuracy": 0.16672,
47
+ "avg_val_loss": 3.4289465795589398,
48
+ "val_accuracy": 0.2176000028848648
49
  },
50
  {
51
+ "avg_train_loss": 3.6327530968829493,
52
+ "train_accuracy": 0.18014,
53
+ "avg_val_loss": 3.3895603614517404,
54
+ "val_accuracy": 0.2402999997138977
55
  },
56
  {
57
+ "avg_train_loss": 3.578799500489784,
58
+ "train_accuracy": 0.19102,
59
+ "avg_val_loss": 3.3035911849782438,
60
+ "val_accuracy": 0.2533000111579895
61
  },
62
  {
63
+ "avg_train_loss": 3.5295029982276587,
64
+ "train_accuracy": 0.20518,
65
+ "avg_val_loss": 3.2572307345233384,
66
+ "val_accuracy": 0.2711000144481659
67
  },
68
  {
69
+ "avg_train_loss": 3.49012098257499,
70
+ "train_accuracy": 0.21266,
71
+ "avg_val_loss": 3.194729382478738,
72
+ "val_accuracy": 0.29030001163482666
73
  },
74
  {
75
+ "avg_train_loss": 3.444543719901453,
76
+ "train_accuracy": 0.22732,
77
+ "avg_val_loss": 3.1562333891663372,
78
+ "val_accuracy": 0.2962000072002411
79
  },
80
  {
81
+ "avg_train_loss": 3.409660898206179,
82
+ "train_accuracy": 0.23328,
83
+ "avg_val_loss": 3.1249258306962027,
84
+ "val_accuracy": 0.303600013256073
85
  },
86
  {
87
+ "avg_train_loss": 3.365924084278019,
88
+ "train_accuracy": 0.2484,
89
+ "avg_val_loss": 3.1003157217291335,
90
+ "val_accuracy": 0.3172000050544739
91
  },
92
  {
93
+ "avg_train_loss": 3.3438522516918914,
94
+ "train_accuracy": 0.25078,
95
+ "avg_val_loss": 3.0775443934187106,
96
+ "val_accuracy": 0.3222000002861023
97
  },
98
  {
99
+ "avg_train_loss": 3.310065208188713,
100
+ "train_accuracy": 0.26132,
101
+ "avg_val_loss": 3.0591898085195806,
102
+ "val_accuracy": 0.3260999917984009
103
  },
104
  {
105
+ "avg_train_loss": 3.2758816282462586,
106
+ "train_accuracy": 0.2669,
107
+ "avg_val_loss": 3.0423757818680777,
108
+ "val_accuracy": 0.3319999873638153
109
  },
110
  {
111
+ "avg_train_loss": 3.2471999869017343,
112
+ "train_accuracy": 0.276,
113
+ "avg_val_loss": 3.0076527655879155,
114
+ "val_accuracy": 0.3409999907016754
115
+ },
116
+ {
117
+ "avg_train_loss": 3.2196996800429987,
118
+ "train_accuracy": 0.28116,
119
+ "avg_val_loss": 2.9659501087816458,
120
+ "val_accuracy": 0.3515999913215637
121
+ },
122
+ {
123
+ "avg_train_loss": 3.195254132875701,
124
+ "train_accuracy": 0.28704,
125
+ "avg_val_loss": 2.9751936514166335,
126
+ "val_accuracy": 0.35580000281333923
127
+ },
128
+ {
129
+ "avg_train_loss": 3.171722950532918,
130
+ "train_accuracy": 0.29404,
131
+ "avg_val_loss": 2.9623639070535006,
132
+ "val_accuracy": 0.3531999886035919
133
+ },
134
+ {
135
+ "avg_train_loss": 3.1580248549771124,
136
+ "train_accuracy": 0.2984,
137
+ "avg_val_loss": 2.909536240976068,
138
+ "val_accuracy": 0.3702000081539154
139
+ },
140
+ {
141
+ "avg_train_loss": 3.125413342814921,
142
+ "train_accuracy": 0.30764,
143
+ "avg_val_loss": 2.8885996371884888,
144
+ "val_accuracy": 0.3765999972820282
145
+ },
146
+ {
147
+ "avg_train_loss": 3.11991933300672,
148
+ "train_accuracy": 0.30932,
149
+ "avg_val_loss": 2.8885996371884888,
150
+ "val_accuracy": 0.3788999915122986
151
+ },
152
+ {
153
+ "avg_train_loss": 3.0930506728799143,
154
+ "train_accuracy": 0.3174,
155
+ "avg_val_loss": 2.852832456178303,
156
+ "val_accuracy": 0.3903000056743622
157
+ },
158
+ {
159
+ "avg_train_loss": 3.0655275760099405,
160
+ "train_accuracy": 0.32236,
161
+ "avg_val_loss": 2.868115388894383,
162
+ "val_accuracy": 0.38449999690055847
163
+ },
164
+ {
165
+ "avg_train_loss": 3.0548000393621146,
166
+ "train_accuracy": 0.32792,
167
+ "avg_val_loss": 2.8296788855444026,
168
+ "val_accuracy": 0.3880999982357025
169
+ },
170
+ {
171
+ "avg_train_loss": 3.032537115809253,
172
+ "train_accuracy": 0.3308,
173
+ "avg_val_loss": 2.7847178012509888,
174
+ "val_accuracy": 0.4097999930381775
175
+ },
176
+ {
177
+ "avg_train_loss": 3.009990167434868,
178
+ "train_accuracy": 0.3352,
179
+ "avg_val_loss": 2.7986656864987145,
180
+ "val_accuracy": 0.4027999937534332
181
+ },
182
+ {
183
+ "avg_train_loss": 2.9918381585489455,
184
+ "train_accuracy": 0.34312,
185
+ "avg_val_loss": 2.7371235135235366,
186
+ "val_accuracy": 0.42089998722076416
187
+ },
188
+ {
189
+ "avg_train_loss": 2.9780573729054094,
190
+ "train_accuracy": 0.34844,
191
+ "avg_val_loss": 2.7818868130068237,
192
+ "val_accuracy": 0.4077000021934509
193
+ },
194
+ {
195
+ "avg_train_loss": 2.965448642020945,
196
+ "train_accuracy": 0.34998,
197
+ "avg_val_loss": 2.7794544847705698,
198
  "val_accuracy": 0.4027000069618225
199
  },
200
  {
201
+ "avg_train_loss": 2.949932183451055,
202
+ "train_accuracy": 0.354,
203
+ "avg_val_loss": 2.7443399550039556,
204
+ "val_accuracy": 0.4169999957084656
205
+ },
206
+ {
207
+ "avg_train_loss": 2.934478478968296,
208
+ "train_accuracy": 0.35772,
209
+ "avg_val_loss": 2.7268808099287973,
210
+ "val_accuracy": 0.4259999990463257
211
+ },
212
+ {
213
+ "avg_train_loss": 2.925810183710454,
214
+ "train_accuracy": 0.36066,
215
+ "avg_val_loss": 2.7194407016416138,
216
+ "val_accuracy": 0.4253999888896942
217
+ },
218
+ {
219
+ "avg_train_loss": 2.9101742749933694,
220
+ "train_accuracy": 0.36768,
221
+ "avg_val_loss": 2.682816324354727,
222
+ "val_accuracy": 0.4357999861240387
223
+ },
224
+ {
225
+ "avg_train_loss": 2.897818704395343,
226
+ "train_accuracy": 0.36902,
227
+ "avg_val_loss": 2.7000809681566458,
228
+ "val_accuracy": 0.4271000027656555
229
+ },
230
+ {
231
+ "avg_train_loss": 2.882525757450582,
232
+ "train_accuracy": 0.3717,
233
+ "avg_val_loss": 2.6881412554390822,
234
+ "val_accuracy": 0.4334000051021576
235
+ },
236
+ {
237
+ "avg_train_loss": 2.868310006683135,
238
+ "train_accuracy": 0.37626,
239
+ "avg_val_loss": 2.696658943272844,
240
+ "val_accuracy": 0.4318999946117401
241
+ },
242
+ {
243
+ "avg_train_loss": 2.8593416683509223,
244
+ "train_accuracy": 0.37938,
245
+ "avg_val_loss": 2.6929442973076543,
246
+ "val_accuracy": 0.4361000061035156
247
+ },
248
+ {
249
+ "avg_train_loss": 2.8498354638019183,
250
+ "train_accuracy": 0.38216,
251
+ "avg_val_loss": 2.6896122799643987,
252
+ "val_accuracy": 0.43959999084472656
253
+ },
254
+ {
255
+ "avg_train_loss": 2.841897433066307,
256
+ "train_accuracy": 0.38478,
257
+ "avg_val_loss": 2.6792281911342957,
258
+ "val_accuracy": 0.4350000023841858
259
+ },
260
+ {
261
+ "avg_train_loss": 2.829337977387411,
262
+ "train_accuracy": 0.38544,
263
+ "avg_val_loss": 2.6498478756675237,
264
+ "val_accuracy": 0.44699999690055847
265
+ },
266
+ {
267
+ "avg_train_loss": 2.817666833967809,
268
+ "train_accuracy": 0.38926,
269
+ "avg_val_loss": 2.6291940423506723,
270
+ "val_accuracy": 0.454800009727478
271
+ },
272
+ {
273
+ "avg_train_loss": 2.798290427993326,
274
+ "train_accuracy": 0.39618,
275
+ "avg_val_loss": 2.6448449243473102,
276
+ "val_accuracy": 0.4528999924659729
277
+ },
278
+ {
279
+ "avg_train_loss": 2.7893726972057995,
280
+ "train_accuracy": 0.39842,
281
+ "avg_val_loss": 2.6335499437549448,
282
+ "val_accuracy": 0.447299987077713
283
+ },
284
+ {
285
+ "avg_train_loss": 2.779167408528535,
286
+ "train_accuracy": 0.401,
287
+ "avg_val_loss": 2.625039595591871,
288
+ "val_accuracy": 0.4578000009059906
289
+ },
290
+ {
291
+ "avg_train_loss": 2.769523390723616,
292
+ "train_accuracy": 0.40362,
293
+ "avg_val_loss": 2.606762849831883,
294
+ "val_accuracy": 0.4546999931335449
295
+ },
296
+ {
297
+ "avg_train_loss": 2.7618973175887866,
298
+ "train_accuracy": 0.40552,
299
+ "avg_val_loss": 2.6070912035205698,
300
+ "val_accuracy": 0.4577000141143799
301
+ },
302
+ {
303
+ "avg_train_loss": 2.744235341811119,
304
+ "train_accuracy": 0.4127,
305
+ "avg_val_loss": 2.577495912962322,
306
+ "val_accuracy": 0.4657000005245209
307
+ },
308
+ {
309
+ "avg_train_loss": 2.7439488306679687,
310
+ "train_accuracy": 0.4089,
311
+ "avg_val_loss": 2.610289561597607,
312
+ "val_accuracy": 0.45730000734329224
313
+ },
314
+ {
315
+ "avg_train_loss": 2.726577682263406,
316
+ "train_accuracy": 0.41524,
317
+ "avg_val_loss": 2.5453265226339994,
318
+ "val_accuracy": 0.4781000018119812
319
+ },
320
+ {
321
+ "avg_train_loss": 2.7213859829451423,
322
+ "train_accuracy": 0.41772,
323
+ "avg_val_loss": 2.5277760179736948,
324
+ "val_accuracy": 0.4763000011444092
325
+ },
326
+ {
327
+ "avg_train_loss": 2.712547524810752,
328
+ "train_accuracy": 0.421,
329
+ "avg_val_loss": 2.5841199657585046,
330
+ "val_accuracy": 0.4708000123500824
331
+ },
332
+ {
333
+ "avg_train_loss": 2.6989601530382394,
334
+ "train_accuracy": 0.4234,
335
+ "avg_val_loss": 2.5469041655335247,
336
+ "val_accuracy": 0.47589999437332153
337
+ },
338
+ {
339
+ "avg_train_loss": 2.6853119339174625,
340
+ "train_accuracy": 0.42606,
341
+ "avg_val_loss": 2.5726125210146362,
342
+ "val_accuracy": 0.4747999906539917
343
+ },
344
+ {
345
+ "avg_train_loss": 2.6801007284837612,
346
+ "train_accuracy": 0.42704,
347
+ "avg_val_loss": 2.5191534501087816,
348
+ "val_accuracy": 0.4855000078678131
349
+ },
350
+ {
351
+ "avg_train_loss": 2.673392378765604,
352
+ "train_accuracy": 0.429,
353
+ "avg_val_loss": 2.543144902096519,
354
+ "val_accuracy": 0.46720001101493835
355
+ },
356
+ {
357
+ "avg_train_loss": 2.6638625219959735,
358
+ "train_accuracy": 0.43398,
359
+ "avg_val_loss": 2.537560185299644,
360
+ "val_accuracy": 0.486299991607666
361
+ },
362
+ {
363
+ "avg_train_loss": 2.651780778184876,
364
+ "train_accuracy": 0.43742,
365
+ "avg_val_loss": 2.536473914037777,
366
+ "val_accuracy": 0.4837999939918518
367
+ },
368
+ {
369
+ "avg_train_loss": 2.6511071432581947,
370
+ "train_accuracy": 0.43532,
371
+ "avg_val_loss": 2.520467637460443,
372
+ "val_accuracy": 0.4832000136375427
373
+ },
374
+ {
375
+ "avg_train_loss": 2.6305635774227056,
376
+ "train_accuracy": 0.44088,
377
+ "avg_val_loss": 2.49658203125,
378
+ "val_accuracy": 0.4912000000476837
379
+ },
380
+ {
381
+ "avg_train_loss": 2.634286204262463,
382
+ "train_accuracy": 0.44186,
383
+ "avg_val_loss": 2.554328242434731,
384
+ "val_accuracy": 0.47690001130104065
385
+ },
386
+ {
387
+ "avg_train_loss": 2.624523153390421,
388
+ "train_accuracy": 0.44464,
389
+ "avg_val_loss": 2.492234821561017,
390
+ "val_accuracy": 0.4921000003814697
391
+ },
392
+ {
393
+ "avg_train_loss": 2.6103764879124243,
394
+ "train_accuracy": 0.44712,
395
+ "avg_val_loss": 2.4733666528629352,
396
+ "val_accuracy": 0.49219998717308044
397
+ },
398
+ {
399
+ "avg_train_loss": 2.597310994592164,
400
+ "train_accuracy": 0.45222,
401
+ "avg_val_loss": 2.4752954410601267,
402
+ "val_accuracy": 0.4984999895095825
403
+ },
404
+ {
405
+ "avg_train_loss": 2.6000741299460914,
406
+ "train_accuracy": 0.45098,
407
+ "avg_val_loss": 2.4956019920638846,
408
+ "val_accuracy": 0.4957999885082245
409
+ },
410
+ {
411
+ "avg_train_loss": 2.585392991295251,
412
+ "train_accuracy": 0.45412,
413
+ "avg_val_loss": 2.4977424718156644,
414
+ "val_accuracy": 0.4909999966621399
415
+ },
416
+ {
417
+ "avg_train_loss": 2.5859423464216538,
418
+ "train_accuracy": 0.45544,
419
+ "avg_val_loss": 2.4699751455572585,
420
+ "val_accuracy": 0.4993000030517578
421
+ },
422
+ {
423
+ "avg_train_loss": 2.5671919788545963,
424
+ "train_accuracy": 0.4598,
425
+ "avg_val_loss": 2.4434667659711233,
426
+ "val_accuracy": 0.5040000081062317
427
+ },
428
+ {
429
+ "avg_train_loss": 2.552455409103647,
430
+ "train_accuracy": 0.4653,
431
+ "avg_val_loss": 2.4661942494066458,
432
+ "val_accuracy": 0.4943000078201294
433
+ },
434
+ {
435
+ "avg_train_loss": 2.5519124113995098,
436
+ "train_accuracy": 0.46254,
437
+ "avg_val_loss": 2.43858723700801,
438
+ "val_accuracy": 0.5048999786376953
439
+ },
440
+ {
441
+ "avg_train_loss": 2.5457511021353096,
442
+ "train_accuracy": 0.4658,
443
+ "avg_val_loss": 2.48284912109375,
444
+ "val_accuracy": 0.5008000135421753
445
+ },
446
+ {
447
+ "avg_train_loss": 2.539912354915648,
448
+ "train_accuracy": 0.46686,
449
+ "avg_val_loss": 2.4563181430478638,
450
+ "val_accuracy": 0.5048999786376953
451
+ },
452
+ {
453
+ "avg_train_loss": 2.5310022376687327,
454
+ "train_accuracy": 0.4681,
455
+ "avg_val_loss": 2.4287455112119263,
456
+ "val_accuracy": 0.5077000260353088
457
+ },
458
+ {
459
+ "avg_train_loss": 2.5193640140011486,
460
+ "train_accuracy": 0.4733,
461
+ "avg_val_loss": 2.4722490914260287,
462
+ "val_accuracy": 0.4941999912261963
463
+ },
464
+ {
465
+ "avg_train_loss": 2.511948669353105,
466
+ "train_accuracy": 0.47668,
467
+ "avg_val_loss": 2.4270836552487145,
468
+ "val_accuracy": 0.510200023651123
469
+ },
470
+ {
471
+ "avg_train_loss": 2.5021252122986346,
472
+ "train_accuracy": 0.47688,
473
+ "avg_val_loss": 2.3986882076987737,
474
+ "val_accuracy": 0.5188000202178955
475
+ },
476
+ {
477
+ "avg_train_loss": 2.4900416806530767,
478
+ "train_accuracy": 0.48144,
479
+ "avg_val_loss": 2.4308195476290546,
480
+ "val_accuracy": 0.5117999911308289
481
+ },
482
+ {
483
+ "avg_train_loss": 2.493807424059914,
484
+ "train_accuracy": 0.48098,
485
+ "avg_val_loss": 2.4242190590387658,
486
+ "val_accuracy": 0.5123999714851379
487
+ },
488
+ {
489
+ "avg_train_loss": 2.4797395354951433,
490
+ "train_accuracy": 0.4871,
491
+ "avg_val_loss": 2.4191836586481408,
492
+ "val_accuracy": 0.5139999985694885
493
+ },
494
+ {
495
+ "avg_train_loss": 2.474879029461795,
496
+ "train_accuracy": 0.48526,
497
+ "avg_val_loss": 2.4089353537257714,
498
+ "val_accuracy": 0.5188000202178955
499
+ },
500
+ {
501
+ "avg_train_loss": 2.4681123287781426,
502
+ "train_accuracy": 0.48606,
503
+ "avg_val_loss": 2.3880093731457674,
504
+ "val_accuracy": 0.5206000208854675
505
+ },
506
+ {
507
+ "avg_train_loss": 2.464694506219586,
508
+ "train_accuracy": 0.48762,
509
+ "avg_val_loss": 2.4599369869956487,
510
+ "val_accuracy": 0.5094000101089478
511
+ },
512
+ {
513
+ "avg_train_loss": 2.4511944366538008,
514
+ "train_accuracy": 0.4928,
515
+ "avg_val_loss": 2.3902290440812894,
516
+ "val_accuracy": 0.5202000141143799
517
+ },
518
+ {
519
+ "avg_train_loss": 2.4458736345896024,
520
+ "train_accuracy": 0.4954,
521
+ "avg_val_loss": 2.3775387534612342,
522
+ "val_accuracy": 0.5228999853134155
523
+ },
524
+ {
525
+ "avg_train_loss": 2.4391189106285114,
526
+ "train_accuracy": 0.4964,
527
+ "avg_val_loss": 2.3863923278035997,
528
+ "val_accuracy": 0.5250999927520752
529
+ },
530
+ {
531
+ "avg_train_loss": 2.431982190102872,
532
+ "train_accuracy": 0.49786,
533
+ "avg_val_loss": 2.399633093725277,
534
+ "val_accuracy": 0.5228000283241272
535
+ },
536
+ {
537
+ "avg_train_loss": 2.4256825097991377,
538
+ "train_accuracy": 0.49946,
539
+ "avg_val_loss": 2.4045251773882517,
540
+ "val_accuracy": 0.515500009059906
541
+ },
542
+ {
543
+ "avg_train_loss": 2.4132598466275597,
544
+ "train_accuracy": 0.50316,
545
+ "avg_val_loss": 2.38846037659464,
546
+ "val_accuracy": 0.5232999920845032
547
+ },
548
+ {
549
+ "avg_train_loss": 2.4151184967411754,
550
+ "train_accuracy": 0.50272,
551
+ "avg_val_loss": 2.3768584818779668,
552
+ "val_accuracy": 0.5232999920845032
553
+ },
554
+ {
555
+ "avg_train_loss": 2.405930356906198,
556
+ "train_accuracy": 0.50574,
557
+ "avg_val_loss": 2.378061415273932,
558
+ "val_accuracy": 0.5267000198364258
559
+ },
560
+ {
561
+ "avg_train_loss": 2.4047722526828346,
562
+ "train_accuracy": 0.50608,
563
+ "avg_val_loss": 2.3851123278654076,
564
+ "val_accuracy": 0.5238000154495239
565
+ },
566
+ {
567
+ "avg_train_loss": 2.3914314154773724,
568
+ "train_accuracy": 0.50882,
569
+ "avg_val_loss": 2.3767078254796283,
570
+ "val_accuracy": 0.5264999866485596
571
+ },
572
+ {
573
+ "avg_train_loss": 2.3860856683357903,
574
+ "train_accuracy": 0.51156,
575
+ "avg_val_loss": 2.376488021657437,
576
+ "val_accuracy": 0.5235000252723694
577
+ },
578
+ {
579
+ "avg_train_loss": 2.386218143546063,
580
+ "train_accuracy": 0.51266,
581
+ "avg_val_loss": 2.3630600941332083,
582
+ "val_accuracy": 0.5235000252723694
583
+ },
584
+ {
585
+ "avg_train_loss": 2.3744330151611583,
586
+ "train_accuracy": 0.51428,
587
+ "avg_val_loss": 2.3920840492731408,
588
+ "val_accuracy": 0.5213000178337097
589
+ },
590
+ {
591
+ "avg_train_loss": 2.3711826228119834,
592
+ "train_accuracy": 0.51532,
593
+ "avg_val_loss": 2.3973187796677213,
594
+ "val_accuracy": 0.5235000252723694
595
+ },
596
+ {
597
+ "avg_train_loss": 2.3610368445706182,
598
+ "train_accuracy": 0.51854,
599
+ "avg_val_loss": 2.3672910279865507,
600
+ "val_accuracy": 0.5357000231742859
601
+ },
602
+ {
603
+ "avg_train_loss": 2.359229836165143,
604
+ "train_accuracy": 0.52028,
605
+ "avg_val_loss": 2.3615512123590783,
606
+ "val_accuracy": 0.5253999829292297
607
+ },
608
+ {
609
+ "avg_train_loss": 2.3491924016372017,
610
+ "train_accuracy": 0.52172,
611
+ "avg_val_loss": 2.3583934156200552,
612
+ "val_accuracy": 0.5288000106811523
613
+ },
614
+ {
615
+ "avg_train_loss": 2.336290584348352,
616
+ "train_accuracy": 0.52592,
617
+ "avg_val_loss": 2.3765073365803007,
618
+ "val_accuracy": 0.5221999883651733
619
+ },
620
+ {
621
+ "avg_train_loss": 2.3395893063081803,
622
+ "train_accuracy": 0.52642,
623
+ "avg_val_loss": 2.377288818359375,
624
+ "val_accuracy": 0.5306000113487244
625
+ },
626
+ {
627
+ "avg_train_loss": 2.3326609063026544,
628
+ "train_accuracy": 0.52884,
629
+ "avg_val_loss": 2.3460967631279668,
630
+ "val_accuracy": 0.5358999967575073
631
+ },
632
+ {
633
+ "avg_train_loss": 2.3343486081608726,
634
+ "train_accuracy": 0.5261,
635
+ "avg_val_loss": 2.3548485478268395,
636
+ "val_accuracy": 0.5310999751091003
637
+ },
638
+ {
639
+ "avg_train_loss": 2.319910573532514,
640
+ "train_accuracy": 0.53014,
641
+ "avg_val_loss": 2.3654568829113924,
642
+ "val_accuracy": 0.5289999842643738
643
+ },
644
+ {
645
+ "avg_train_loss": 2.3231105095590165,
646
+ "train_accuracy": 0.53024,
647
+ "avg_val_loss": 2.3426397782337816,
648
+ "val_accuracy": 0.5408999919891357
649
+ },
650
+ {
651
+ "avg_train_loss": 2.3172146889864638,
652
+ "train_accuracy": 0.52916,
653
+ "avg_val_loss": 2.3707584429390822,
654
+ "val_accuracy": 0.5284000039100647
655
  },
656
  {
657
+ "avg_train_loss": 2.3137913414889284,
658
+ "train_accuracy": 0.5338,
659
+ "avg_val_loss": 2.378921701938291,
660
+ "val_accuracy": 0.5299000144004822
661
  },
662
  {
663
+ "avg_train_loss": 2.292580285188182,
664
+ "train_accuracy": 0.5417,
665
+ "avg_val_loss": 2.3381581366816655,
666
+ "val_accuracy": 0.5406000018119812
667
  },
668
  {
669
+ "avg_train_loss": 2.2915546366625734,
670
+ "train_accuracy": 0.53912,
671
+ "avg_val_loss": 2.36516909056072,
672
+ "val_accuracy": 0.5367000102996826
673
  },
674
  {
675
+ "avg_train_loss": 2.291307615197223,
676
+ "train_accuracy": 0.53922,
677
+ "avg_val_loss": 2.3603519487984572,
678
+ "val_accuracy": 0.535099983215332
679
  },
680
  {
681
+ "avg_train_loss": 2.2937697252958937,
682
+ "train_accuracy": 0.5405,
683
+ "avg_val_loss": 2.3516428500791138,
684
+ "val_accuracy": 0.5401999950408936
685
  },
686
  {
687
+ "avg_train_loss": 2.2855114741703435,
688
+ "train_accuracy": 0.53978,
689
+ "avg_val_loss": 2.3426988818977454,
690
+ "val_accuracy": 0.536300003528595
691
  },
692
  {
693
+ "avg_train_loss": 2.2704710070129552,
694
+ "train_accuracy": 0.54678,
695
+ "avg_val_loss": 2.3620895192592957,
696
+ "val_accuracy": 0.5358999967575073
697
  },
698
  {
699
+ "avg_train_loss": 2.2660938359587393,
700
+ "train_accuracy": 0.5457,
701
+ "avg_val_loss": 2.3297165496439876,
702
+ "val_accuracy": 0.5368000268936157
703
  },
704
  {
705
+ "avg_train_loss": 2.2636455373690865,
706
+ "train_accuracy": 0.54736,
707
+ "avg_val_loss": 2.370776212668117,
708
+ "val_accuracy": 0.5335999727249146
709
  }
710
  ]
performance_plot.png CHANGED
plots.py CHANGED
@@ -5,7 +5,7 @@ with open("performance.json", "r") as f:
5
  performance = json.load(f)
6
 
7
  # Extract values from the performance list
8
- epochs = range(1, len(performance) + 1)
9
  train_losses = [epoch["avg_train_loss"] for epoch in performance]
10
  val_losses = [epoch["avg_val_loss"] for epoch in performance]
11
  train_accuracies = [epoch["train_accuracy"] for epoch in performance]
@@ -22,7 +22,7 @@ plt.xlabel("Epochs")
22
  plt.ylabel("Loss")
23
  plt.title("Training and Validation Loss")
24
  plt.legend()
25
- plt.xticks(epochs)
26
 
27
  # Subplot for Accuracy
28
  plt.subplot(1, 2, 2)
@@ -32,7 +32,7 @@ plt.xlabel("Epochs")
32
  plt.ylabel("Accuracy")
33
  plt.title("Training and Validation Accuracy")
34
  plt.legend()
35
- plt.xticks(epochs)
36
 
37
  plt.tight_layout()
38
 
 
5
  performance = json.load(f)
6
 
7
  # Extract values from the performance list
8
+ epochs = list(range(1, len(performance) + 1))
9
  train_losses = [epoch["avg_train_loss"] for epoch in performance]
10
  val_losses = [epoch["avg_val_loss"] for epoch in performance]
11
  train_accuracies = [epoch["train_accuracy"] for epoch in performance]
 
22
  plt.ylabel("Loss")
23
  plt.title("Training and Validation Loss")
24
  plt.legend()
25
+ plt.xticks([1] + epochs[9::10] + [epochs[-1]])
26
 
27
  # Subplot for Accuracy
28
  plt.subplot(1, 2, 2)
 
32
  plt.ylabel("Accuracy")
33
  plt.title("Training and Validation Accuracy")
34
  plt.legend()
35
+ plt.xticks([1] + epochs[9::10] + [epochs[-1]])
36
 
37
  plt.tight_layout()
38
 
predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
train.py DELETED
@@ -1,394 +0,0 @@
1
- #!/usr/bin/env python3
2
- import os
3
- import csv
4
- import json
5
- from tqdm import tqdm
6
- import torch
7
- import argparse
8
- from PIL import Image
9
- from torchvision import transforms
10
- from torch.utils.data import DataLoader, Dataset
11
- from model import MyModel
12
- import numpy as np
13
-
14
-
15
- class MiniPlaces(Dataset):
16
- def __init__(self, root_dir, split, transform=None, label_dict=None):
17
- """
18
- Initialize the MiniPlaces dataset with the root directory for the images,
19
- the split (train/val/test), an optional data transformation,
20
- and an optional label dictionary.
21
-
22
- Args:
23
- root_dir (str): Root directory for the MiniPlaces images.
24
- split (str): Split to use ('train', 'val', or 'test').
25
- transform (callable, optional): Optional data transformation to apply to the images.
26
- label_dict (dict, optional): Optional dictionary mapping integer labels to class names.
27
- """
28
- assert split in ['train', 'val', 'test']
29
- self.root_dir = root_dir
30
- self.split = split
31
- self.transform = transform
32
- self.filenames = []
33
- self.labels = []
34
-
35
- self.label_dict = label_dict if label_dict is not None else {}
36
-
37
- with open(os.path.join(self.root_dir, self.split + '.txt')) as r:
38
- lines = r.readlines()
39
- for line in lines:
40
- line = line.split()
41
- self.filenames.append(line[0])
42
- if split == 'test':
43
- label = line[0]
44
- else:
45
- label = int(line[1])
46
- self.labels.append(label)
47
- if split == 'train':
48
- text_label = line[0].split('/')[2]
49
- self.label_dict[label] = text_label
50
-
51
- def __len__(self):
52
- """
53
- Return the number of images in the dataset.
54
-
55
- Returns:
56
- int: Number of images in the dataset.
57
- """
58
- return len(self.labels)
59
-
60
- def __getitem__(self, idx):
61
- """
62
- Return a single image and its corresponding label when given an index.
63
-
64
- Args:
65
- idx (int): Index of the image to retrieve.
66
-
67
- Returns:
68
- tuple: Tuple containing the image and its label.
69
- """
70
- if self.transform is not None:
71
- image = self.transform(
72
- Image.open(os.path.join(self.root_dir, "images", self.filenames[idx])))
73
- else:
74
- image = Image.open(os.path.join(self.root_dir, "images", self.filenames[idx]))
75
- label = self.labels[idx]
76
- return image, label
77
-
78
-
79
- def create_train_transform():
80
- """
81
- Create training data transformation with augmentation
82
- """
83
- image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
84
- image_net_std = torch.Tensor([0.229, 0.224, 0.225])
85
-
86
- return transforms.Compose([
87
- transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
88
- transforms.RandomHorizontalFlip(p=0.5),
89
- transforms.ColorJitter(
90
- brightness=0.4,
91
- contrast=0.4,
92
- saturation=0.4,
93
- hue=0.1
94
- ),
95
- transforms.RandomAffine(
96
- degrees=15, # rotation
97
- translate=(0.1, 0.1), # horizontal/vertical translation
98
- scale=(0.9, 1.1), # scale
99
- ),
100
- transforms.ToTensor(),
101
- transforms.Resize((128, 128)),
102
- transforms.Normalize(image_net_mean, image_net_std)
103
- ])
104
-
105
-
106
- def create_val_transform():
107
- """
108
- Create validation/test data transformation without augmentation
109
- """
110
- image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
111
- image_net_std = torch.Tensor([0.229, 0.224, 0.225])
112
-
113
- return transforms.Compose([
114
- transforms.ToTensor(),
115
- transforms.Resize((128, 128)),
116
- transforms.Normalize(image_net_mean, image_net_std)
117
- ])
118
-
119
-
120
- def evaluate(model, test_loader, criterion, device):
121
- """
122
- Evaluate the CNN classifier on the validation set.
123
-
124
- Args:
125
- model (CNN): CNN classifier to evaluate.
126
- test_loader (torch.utils.data.DataLoader): Data loader for the test set.
127
- criterion (callable): Loss function to use for evaluation.
128
- device (torch.device): Device to use for evaluation.
129
-
130
- Returns:
131
- float: Average loss on the test set.
132
- float: Accuracy on the test set.
133
- """
134
- model.eval() # Set model to evaluation mode
135
-
136
- with torch.no_grad():
137
- total_loss = 0.0
138
- num_correct = 0
139
- num_samples = 0
140
-
141
- for inputs, labels in test_loader:
142
- # Move inputs and labels to device
143
- inputs = inputs.to(device)
144
- labels = labels.to(device)
145
-
146
- # Compute the logits and loss
147
- logits = model(inputs)
148
- loss = criterion(logits, labels)
149
- total_loss += loss.item()
150
-
151
- # Compute the accuracy
152
- _, predictions = torch.max(logits, dim=1)
153
- num_correct += (predictions == labels).sum().item()
154
- num_samples += len(inputs)
155
-
156
- # Evaluate the model on the validation set
157
- avg_loss = total_loss / len(test_loader)
158
- accuracy = num_correct / num_samples
159
-
160
- return avg_loss, accuracy
161
-
162
-
163
- def train(model, train_loader, val_loader, optimizer, criterion, device,
164
- num_epochs):
165
- """
166
- Train the CNN classifer on the training set and evaluate it on the validation set every epoch.
167
-
168
- Args:
169
- model (CNN): CNN classifier to train.
170
- train_loader (torch.utils.data.DataLoader): Data loader for the training set.
171
- val_loader (torch.utils.data.DataLoader): Data loader for the validation set.
172
- optimizer (torch.optim.Optimizer): Optimizer to use for training.
173
- criterion (callable): Loss function to use for training.
174
- device (torch.device): Device to use for training.
175
- num_epochs (int): Number of epochs to train the model.
176
- """
177
-
178
- # Place the model on device
179
- model = model.to(device)
180
-
181
- # Define early stopping parameters
182
- patience = 5 # Number of epochs to wait for improvement
183
- best_val_accuracy = 0.0 # Best validation accuracy so far
184
- epochs_without_improvement = 0 # Counter for epochs without improvement
185
- best_model_state = None # To store the state of the best model
186
-
187
- # Performance tracking
188
- performance = []
189
-
190
- for epoch in range(num_epochs):
191
- model.train() # Set model to training mode
192
-
193
- running_loss = 0.0 # Track cumulative loss for averaging
194
- correct_predictions = 0
195
- total_samples = 0
196
-
197
- with tqdm(total=len(train_loader),
198
- desc=f'Epoch {epoch + 1}/{num_epochs}',
199
- position=0,
200
- leave=True) as pbar:
201
- for inputs, labels in train_loader:
202
- # Move inputs and labels to device
203
- inputs = inputs.to(device)
204
- labels = labels.to(device)
205
-
206
- # Zero the gradients
207
- optimizer.zero_grad()
208
-
209
- # Compute the logits and loss
210
- logits = model(inputs)
211
- loss = criterion(logits, labels)
212
-
213
- # Backward pass: Compute gradients
214
- loss.backward()
215
-
216
- # Optimize model parameters
217
- optimizer.step()
218
-
219
- # Track running loss
220
- running_loss += loss.item()
221
-
222
- # Track accuracy
223
- _, predicted = logits.max(1)
224
- correct_predictions += (predicted == labels).sum().item()
225
- total_samples += labels.size(0)
226
-
227
- # Update the progress bar
228
- pbar.update(1)
229
- pbar.set_postfix(loss=loss.item())
230
-
231
- # Calculate average loss and accuracy
232
- avg_train_loss = running_loss / len(train_loader)
233
- train_accuracy = correct_predictions / total_samples
234
- avg_val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
235
-
236
- performance.append({
237
- "avg_train_loss": avg_train_loss,
238
- "train_accuracy": train_accuracy,
239
- "avg_val_loss": avg_val_loss,
240
- "val_accuracy": val_accuracy
241
- })
242
- print(
243
- f"Train Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f} "
244
- f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}"
245
- )
246
-
247
- # Check for early stopping
248
- if val_accuracy > best_val_accuracy:
249
- best_val_accuracy = val_accuracy
250
- epochs_without_improvement = 0 # Reset counter if there's an improvement
251
-
252
- # Save the model checkpoint for the best model
253
- best_model_state = {
254
- 'model_state_dict': model.module.state_dict(),
255
- 'optimizer_state_dict': optimizer.state_dict(),
256
- 'epoch': epoch,
257
- }
258
- else:
259
- epochs_without_improvement += 1
260
-
261
- # Early stopping condition
262
- if epochs_without_improvement >= patience:
263
- print(f"Early stopping at epoch {epoch + 1}.")
264
- break # Stop training if no improvement for 'patience' epochs
265
-
266
- # Save the performance list to a JSON file
267
- with open("performance.json", "w") as f:
268
- json.dump(performance, f, indent=4)
269
- torch.save(best_model_state, 'model.ckpt')
270
-
271
-
272
- def test(model, test_loader, device):
273
- """
274
- Get predictions for the test set.
275
-
276
- Args:
277
- model (CNN): classifier to evaluate.
278
- test_loader (torch.utils.data.DataLoader): Data loader for the test set.
279
- device (torch.device): Device to use for evaluation.
280
-
281
- Returns:
282
- float: Average loss on the test set.
283
- float: Accuracy on the test set.
284
- """
285
- model = model.to(device)
286
- model.eval() # Set model to evaluation mode
287
-
288
- with torch.no_grad():
289
- all_preds = []
290
-
291
- for inputs, labels in test_loader:
292
- # Move inputs and labels to device
293
- inputs = inputs.to(device)
294
-
295
- logits = model(inputs)
296
-
297
- _, predictions = torch.max(logits, dim=1)
298
- preds = list(zip(labels, predictions.tolist()))
299
- all_preds.extend(preds)
300
-
301
- return all_preds
302
-
303
-
304
- def write_predictions(preds, filename):
305
- with open(filename, 'w') as f:
306
- writer = csv.writer(f, delimiter=',')
307
- for im, pred in preds:
308
- writer.writerow((im, pred))
309
-
310
-
311
- def main(args):
312
- image_net_mean = torch.Tensor([0.485, 0.456, 0.406])
313
- image_net_std = torch.Tensor([0.229, 0.224, 0.225])
314
-
315
- # Define data transformation
316
- data_transform = transforms.Compose([
317
- transforms.ToTensor(),
318
- transforms.Resize((128, 128)),
319
- transforms.Normalize(image_net_mean, image_net_std),
320
- ])
321
-
322
- # Separate transforms for training and validation
323
- train_transform = create_train_transform()
324
- val_transform = create_val_transform()
325
-
326
- # Create datasets
327
- data_root = 'data'
328
- miniplaces_train = MiniPlaces(data_root,
329
- split='train',
330
- transform=data_transform)
331
- miniplaces_val = MiniPlaces(data_root,
332
- split='val',
333
- transform=data_transform,
334
- label_dict=miniplaces_train.label_dict)
335
-
336
- # Create the dataloaders
337
-
338
- # Define the batch size and number of workers
339
- batch_size = int(args.batch_size)
340
- num_workers = 2
341
-
342
- # Create DataLoader for training and validation sets
343
- train_loader = DataLoader(miniplaces_train,
344
- batch_size=batch_size,
345
- num_workers=num_workers,
346
- shuffle=True)
347
- val_loader = DataLoader(miniplaces_val,
348
- batch_size=batch_size,
349
- num_workers=num_workers,
350
- shuffle=False)
351
-
352
- device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu') # TODO: check cuda
353
-
354
- model = MyModel(num_classes=len(miniplaces_train.label_dict))
355
-
356
- # optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4, amsgrad=False)
357
- optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, dampening=0, weight_decay=1e-4, nesterov=True)
358
-
359
- print("PARAMS NUM:", sum(p.numel() for p in model.parameters() if p.requires_grad))
360
-
361
- if args.checkpoint:
362
- checkpoint = torch.load(args.checkpoint)
363
- model.load_state_dict(checkpoint['model_state_dict'])
364
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
365
-
366
- criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
367
-
368
- if not args.test:
369
- train(model, train_loader, val_loader, optimizer, criterion,
370
- device, num_epochs=int(args.epochs))
371
-
372
- else:
373
- miniplaces_test = MiniPlaces(data_root,
374
- split='test',
375
- transform=data_transform)
376
- test_loader = DataLoader(miniplaces_test,
377
- batch_size=batch_size,
378
- num_workers=num_workers,
379
- shuffle=False)
380
- checkpoint = torch.load(args.checkpoint, weights_only=True)
381
- model.load_state_dict(checkpoint['model_state_dict'])
382
- preds = test(model, test_loader, device)
383
- write_predictions(preds, 'predictions.csv')
384
-
385
-
386
- if __name__ == "__main__":
387
- parser = argparse.ArgumentParser()
388
- parser.add_argument('--test', action='store_true')
389
- parser.add_argument('--checkpoint')
390
- parser.add_argument('--gpu', default=0)
391
- parser.add_argument('--epochs', default=100)
392
- parser.add_argument('--batch_size', default=32)
393
- args = parser.parse_args()
394
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_dist.py CHANGED
@@ -2,11 +2,13 @@
2
  import os
3
  import csv
4
  import json
 
5
  from tqdm import tqdm
6
  import torch
7
  import torch.distributed as dist
8
  import torch.multiprocessing as mp
9
  from torch.nn.parallel import DistributedDataParallel as DDP
 
10
  from torch.utils.data.distributed import DistributedSampler
11
  import argparse
12
  from PIL import Image
@@ -36,6 +38,7 @@ def cleanup():
36
  if dist.is_initialized():
37
  dist.barrier() # Synchronize all processes before destroying process group
38
  dist.destroy_process_group()
 
39
 
40
 
41
  class MiniPlaces(Dataset):
@@ -161,6 +164,7 @@ def evaluate(model, test_loader, criterion, device):
161
  with torch.no_grad():
162
  total_loss = 0.0
163
  num_correct = 0
 
164
  num_samples = 0
165
 
166
  for inputs, labels in test_loader:
@@ -173,22 +177,29 @@ def evaluate(model, test_loader, criterion, device):
173
 
174
  _, predictions = torch.max(logits, dim=1)
175
  num_correct += (predictions == labels).sum().item()
 
 
 
 
176
  num_samples += len(inputs)
177
 
178
  # Gather metrics from all processes
179
  world_size = dist.get_world_size()
180
  total_loss = torch.tensor(total_loss).to(device)
181
  num_correct = torch.tensor(num_correct).to(device)
 
182
  num_samples = torch.tensor(num_samples).to(device)
183
 
184
  dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
185
  dist.all_reduce(num_correct, op=dist.ReduceOp.SUM)
 
186
  dist.all_reduce(num_samples, op=dist.ReduceOp.SUM)
187
 
188
  avg_loss = (total_loss / world_size).item() / len(test_loader)
189
  accuracy = (num_correct / num_samples).item()
 
190
 
191
- return avg_loss, accuracy
192
 
193
 
194
  def train_worker(rank, world_size, args):
@@ -201,15 +212,18 @@ def train_worker(rank, world_size, args):
201
  args (argparse.Namespace): Command-line arguments.
202
  """
203
  try:
 
204
  setup(rank, world_size, args.port)
205
  device = torch.device(f'cuda:{rank}')
206
 
207
  # Define early stopping parameters
208
- patience = 3 # Number of epochs to wait for improvement
209
  best_val_accuracy = 0.0 # Best validation accuracy so far
210
  epochs_without_improvement = 0 # Counter for epochs without improvement
211
  best_model_state = None # To store the state of the best model
212
 
 
 
213
  # Separate transforms for training and validation
214
  train_transform = create_train_transform()
215
  val_transform = create_val_transform()
@@ -233,7 +247,7 @@ def train_worker(rank, world_size, args):
233
  pin_memory=True)
234
 
235
  # Create model and move to GPU
236
- model = MyModel(num_classes=len(miniplaces_train.label_dict))
237
  model = model.to(device)
238
  model = DDP(model, device_ids=[rank])
239
 
@@ -247,6 +261,9 @@ def train_worker(rank, world_size, args):
247
  model.module.load_state_dict(checkpoint['model_state_dict'])
248
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
249
 
 
 
 
250
  if not args.test:
251
  # Training loop
252
  performance = []
@@ -288,7 +305,14 @@ def train_worker(rank, world_size, args):
288
  # Evaluate and log metrics
289
  avg_train_loss = running_loss / len(train_loader)
290
  train_accuracy = correct_predictions / total_samples
291
- avg_val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
 
 
 
 
 
 
 
292
 
293
  if rank == 0: # Only save metrics on rank 0
294
  performance.append({
@@ -327,16 +351,25 @@ def train_worker(rank, world_size, args):
327
  torch.save(best_model_state, 'model.ckpt')
328
 
329
  else: # Testing mode
330
- miniplaces_test = MiniPlaces(data_root, split='test', transform=data_transform)
 
 
 
 
 
 
331
  test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
332
  checkpoint = torch.load(args.checkpoint, map_location=device)
333
  model.module.load_state_dict(checkpoint['model_state_dict'])
 
334
  preds = test(model, test_loader, device)
335
  if rank == 0: # Only write predictions on rank 0
336
  write_predictions(preds, 'predictions.csv')
 
 
337
  finally:
338
  cleanup()
339
- # Add explicit synchronization before exiting
340
  torch.cuda.synchronize()
341
  if dist.is_initialized():
342
  dist.barrier()
@@ -403,7 +436,7 @@ if __name__ == "__main__":
403
  parser.add_argument('--test', action='store_true')
404
  parser.add_argument('--checkpoint')
405
  parser.add_argument('--epochs', type=int, default=100)
406
- parser.add_argument('--batch_size', type=int, default=32)
407
  parser.add_argument('--port', type=int, default=4224)
408
  args = parser.parse_args()
409
  main(args)
 
2
  import os
3
  import csv
4
  import json
5
+ import warnings
6
  from tqdm import tqdm
7
  import torch
8
  import torch.distributed as dist
9
  import torch.multiprocessing as mp
10
  from torch.nn.parallel import DistributedDataParallel as DDP
11
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
12
  from torch.utils.data.distributed import DistributedSampler
13
  import argparse
14
  from PIL import Image
 
38
  if dist.is_initialized():
39
  dist.barrier() # Synchronize all processes before destroying process group
40
  dist.destroy_process_group()
41
+ torch.cuda.synchronize()
42
 
43
 
44
  class MiniPlaces(Dataset):
 
164
  with torch.no_grad():
165
  total_loss = 0.0
166
  num_correct = 0
167
+ num_correct_top5 = 0
168
  num_samples = 0
169
 
170
  for inputs, labels in test_loader:
 
177
 
178
  _, predictions = torch.max(logits, dim=1)
179
  num_correct += (predictions == labels).sum().item()
180
+
181
+ _, top5_predictions = torch.topk(logits, 5, dim=1)
182
+ num_correct_top5 += (top5_predictions == labels.unsqueeze(1)).any(dim=1).sum().item()
183
+
184
  num_samples += len(inputs)
185
 
186
  # Gather metrics from all processes
187
  world_size = dist.get_world_size()
188
  total_loss = torch.tensor(total_loss).to(device)
189
  num_correct = torch.tensor(num_correct).to(device)
190
+ num_correct_top5 = torch.tensor(num_correct_top5).to(device)
191
  num_samples = torch.tensor(num_samples).to(device)
192
 
193
  dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
194
  dist.all_reduce(num_correct, op=dist.ReduceOp.SUM)
195
+ dist.all_reduce(num_correct_top5, op=dist.ReduceOp.SUM)
196
  dist.all_reduce(num_samples, op=dist.ReduceOp.SUM)
197
 
198
  avg_loss = (total_loss / world_size).item() / len(test_loader)
199
  accuracy = (num_correct / num_samples).item()
200
+ top5_accuracy = (num_correct_top5 / num_samples).item()
201
 
202
+ return avg_loss, accuracy, top5_accuracy
203
 
204
 
205
  def train_worker(rank, world_size, args):
 
212
  args (argparse.Namespace): Command-line arguments.
213
  """
214
  try:
215
+ warnings.filterwarnings("ignore")
216
  setup(rank, world_size, args.port)
217
  device = torch.device(f'cuda:{rank}')
218
 
219
  # Define early stopping parameters
220
+ patience = 10 # Number of epochs to wait for improvement
221
  best_val_accuracy = 0.0 # Best validation accuracy so far
222
  epochs_without_improvement = 0 # Counter for epochs without improvement
223
  best_model_state = None # To store the state of the best model
224
 
225
+ last_lr = 0
226
+
227
  # Separate transforms for training and validation
228
  train_transform = create_train_transform()
229
  val_transform = create_val_transform()
 
247
  pin_memory=True)
248
 
249
  # Create model and move to GPU
250
+ model = MyModel(num_classes=len(miniplaces_train.label_dict), dropout_rate=0.2)
251
  model = model.to(device)
252
  model = DDP(model, device_ids=[rank])
253
 
 
261
  model.module.load_state_dict(checkpoint['model_state_dict'])
262
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
263
 
264
+ # Initialize the ReduceLROnPlateau scheduler
265
+ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=4)
266
+
267
  if not args.test:
268
  # Training loop
269
  performance = []
 
305
  # Evaluate and log metrics
306
  avg_train_loss = running_loss / len(train_loader)
307
  train_accuracy = correct_predictions / total_samples
308
+ avg_val_loss, val_accuracy, val_top5_accuracy = evaluate(model, val_loader, criterion, device)
309
+
310
+ # Step the scheduler with the validation loss
311
+ scheduler.step(avg_val_loss)
312
+ if scheduler.get_last_lr()[0] != last_lr:
313
+ last_lr = scheduler.get_last_lr()[0]
314
+ if epoch != 0:
315
+ print(f"New learning rate: {scheduler.get_last_lr()[0]}")
316
 
317
  if rank == 0: # Only save metrics on rank 0
318
  performance.append({
 
351
  torch.save(best_model_state, 'model.ckpt')
352
 
353
  else: # Testing mode
354
+ avg_val_loss, val_accuracy, val_top5_accuracy = evaluate(model, val_loader, criterion, device)
355
+ if rank == 0:
356
+ print(f"\nValidation Loss: {avg_val_loss:.4f}\n"
357
+ f"Validation Accuracy: {val_accuracy:.4f}\n"
358
+ f"Validation Top-5 Accuracy: {val_top5_accuracy:.4f}\n")
359
+
360
+ miniplaces_test = MiniPlaces(data_root, split='test', transform=val_transform)
361
  test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
362
  checkpoint = torch.load(args.checkpoint, map_location=device)
363
  model.module.load_state_dict(checkpoint['model_state_dict'])
364
+
365
  preds = test(model, test_loader, device)
366
  if rank == 0: # Only write predictions on rank 0
367
  write_predictions(preds, 'predictions.csv')
368
+ print("Predictions saved to predictions.csv\n")
369
+
370
  finally:
371
  cleanup()
372
+ # Explicit synchronization before exiting
373
  torch.cuda.synchronize()
374
  if dist.is_initialized():
375
  dist.barrier()
 
436
  parser.add_argument('--test', action='store_true')
437
  parser.add_argument('--checkpoint')
438
  parser.add_argument('--epochs', type=int, default=100)
439
+ parser.add_argument('--batch_size', type=int, default=64)
440
  parser.add_argument('--port', type=int, default=4224)
441
  args = parser.parse_args()
442
  main(args)