felix2703 committed on
Commit
da4f171
·
1 Parent(s): 95382f9

Fix model architectures to match trained checkpoints

Browse files
Files changed (2) hide show
  1. models_attack.py +108 -111
  2. models_shifted.py +98 -81
models_attack.py CHANGED
@@ -8,153 +8,150 @@ import torch.nn.functional as F
8
 
9
 
10
  class StandardCNN(nn.Module):
11
- """Standard CNN with BatchNorm for attack resistance"""
 
 
 
 
12
 
13
  def __init__(self, num_classes=10, dropout_rate=0.5):
14
  super(StandardCNN, self).__init__()
15
 
16
- # Convolutional layers with BatchNorm
17
- self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
18
  self.bn1 = nn.BatchNorm2d(32)
19
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
 
 
 
20
  self.bn2 = nn.BatchNorm2d(64)
21
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
 
 
 
22
  self.bn3 = nn.BatchNorm2d(128)
 
23
 
24
- # Pooling layer
25
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
26
 
27
  # Fully connected layers
28
- self.fc1 = nn.Linear(128 * 3 * 3, 256)
29
- self.fc2 = nn.Linear(256, 128)
30
- self.fc3 = nn.Linear(128, num_classes)
31
-
32
- # Dropout
33
- self.dropout = nn.Dropout(dropout_rate)
34
 
35
  def forward(self, x, return_logits=False):
36
- # Convolutional layers with BatchNorm, ReLU and pooling
37
- x = self.pool(F.relu(self.bn1(self.conv1(x)))) # 28x28 -> 14x14
38
- x = self.pool(F.relu(self.bn2(self.conv2(x)))) # 14x14 -> 7x7
39
- x = self.pool(F.relu(self.bn3(self.conv3(x)))) # 7x7 -> 3x3
40
-
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Flatten
42
  x = x.view(x.size(0), -1)
43
-
44
- # Fully connected layers with dropout
45
  x = F.relu(self.fc1(x))
46
- x = self.dropout(x)
47
  x = F.relu(self.fc2(x))
48
- x = self.dropout(x)
 
49
  logits = self.fc3(x)
50
-
51
  if return_logits:
52
  return logits
53
-
54
- # Apply softmax for probability distribution
55
  return F.softmax(logits, dim=1)
56
 
57
 
58
  class LighterCNN(nn.Module):
59
- """Lighter CNN with BatchNorm and Global Average Pooling"""
 
 
 
 
60
 
61
  def __init__(self, num_classes=10, dropout_rate=0.5):
62
  super(LighterCNN, self).__init__()
63
 
64
- # Convolutional layers with BatchNorm
65
- self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
66
- self.bn1 = nn.BatchNorm2d(16)
67
- self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
68
- self.bn2 = nn.BatchNorm2d(32)
69
- self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
70
- self.bn3 = nn.BatchNorm2d(64)
71
-
72
- # Pooling layers
73
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
74
- self.gap = nn.AdaptiveAvgPool2d(1)
75
-
76
- # Single fully connected layer
77
- self.fc = nn.Linear(64, num_classes)
78
-
79
- # Dropout
80
- self.dropout = nn.Dropout(dropout_rate)
81
-
82
  def forward(self, x, return_logits=False):
83
- # Convolutional layers with BatchNorm, ReLU and pooling
84
- x = self.pool(F.relu(self.bn1(self.conv1(x)))) # 28x28 -> 14x14
85
- x = self.pool(F.relu(self.bn2(self.conv2(x)))) # 14x14 -> 7x7
86
- x = self.pool(F.relu(self.bn3(self.conv3(x)))) # 7x7 -> 3x3
87
-
88
- # Global average pooling
89
- x = self.gap(x)
90
-
91
- # Flatten
92
- x = x.view(x.size(0), -1)
93
-
94
- # Apply dropout
95
- x = self.dropout(x)
96
-
97
- # Final classification layer
98
  logits = self.fc(x)
99
-
100
- if return_logits:
101
- return logits
102
-
103
- # Apply softmax for probability distribution
104
- return F.softmax(logits, dim=1)
 
 
 
 
 
 
 
 
105
 
106
 
107
  class DepthwiseCNN(nn.Module):
108
- """Ultra-efficient CNN using Depthwise Separable Convolutions"""
 
 
 
 
109
 
110
  def __init__(self, num_classes=10, dropout_rate=0.5):
111
  super(DepthwiseCNN, self).__init__()
112
 
113
- # Depthwise Separable Conv 1: 1 -> 8 channels
114
- self.depthwise1 = nn.Conv2d(1, 1, kernel_size=3, padding=1, groups=1)
115
- self.pointwise1 = nn.Conv2d(1, 8, kernel_size=1)
116
- self.bn1 = nn.BatchNorm2d(8)
117
-
118
- # Depthwise Separable Conv 2: 8 -> 16 channels
119
- self.depthwise2 = nn.Conv2d(8, 8, kernel_size=3, padding=1, groups=8)
120
- self.pointwise2 = nn.Conv2d(8, 16, kernel_size=1)
121
- self.bn2 = nn.BatchNorm2d(16)
122
 
123
- # Pooling layers
124
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
 
 
125
  self.gap = nn.AdaptiveAvgPool2d(1)
126
-
127
- # Single fully connected layer
128
- self.fc = nn.Linear(16, num_classes)
129
-
130
- # Dropout
131
- self.dropout = nn.Dropout(dropout_rate)
132
-
133
  def forward(self, x, return_logits=False):
134
- # Depthwise Separable Conv 1
135
- x = self.depthwise1(x)
136
- x = self.pointwise1(x)
137
- x = self.pool(F.relu(self.bn1(x))) # 28x28 -> 14x14
138
-
139
- # Depthwise Separable Conv 2
140
- x = self.depthwise2(x)
141
- x = self.pointwise2(x)
142
- x = self.pool(F.relu(self.bn2(x))) # 14x14 -> 7x7
143
-
144
- # Global average pooling
145
- x = self.gap(x)
146
-
147
- # Flatten
148
- x = x.view(x.size(0), -1)
149
-
150
- # Apply dropout
151
- x = self.dropout(x)
152
-
153
- # Final classification layer
154
- logits = self.fc(x)
155
-
156
- if return_logits:
157
- return logits
158
-
159
- # Apply softmax for probability distribution
160
- return F.softmax(logits, dim=1)
 
8
 
9
 
10
class StandardCNN(nn.Module):
    """
    Standard CNN Model (Original)
    Architecture: 3 Conv blocks with BatchNorm + 3 FC layers
    Parameters: ~817K (817,354 trainable)
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super(StandardCNN, self).__init__()

        # Conv block 1: (1, 28, 28) -> (32, 14, 14)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv block 2: (32, 14, 14) -> (64, 7, 7)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv block 3: (64, 7, 7) -> (128, 3, 3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Flattened feature size: 28 -> 14 -> 7 -> 3 after three 2x2 pools
        self.flattened_size = 128 * 3 * 3

        # Classifier head: two hidden FC layers, each followed by dropout
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x, return_logits=False):
        """Run the network; returns softmax probabilities unless return_logits=True."""
        # Feature extractor: conv -> BN -> ReLU -> pool, three times
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        # Classifier head
        x = x.view(x.size(0), -1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        logits = self.fc3(x)

        return logits if return_logits else F.softmax(logits, dim=1)
78
 
79
 
80
class LighterCNN(nn.Module):
    """
    Lighter CNN Model
    Architecture: 3 Conv blocks with fewer filters + Global Average Pooling
    Parameters: ~94K (94,410 trainable)
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super(LighterCNN, self).__init__()

        # NOTE: dropout_rate is accepted for interface compatibility but this
        # variant applies no dropout.

        # Block 1
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Block 2
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Block 3
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)  # spatial: 28 -> 14 -> 7 -> 3

        # Collapse the 3x3 map so the classifier sees a 128-dim vector
        self.gap = nn.AdaptiveAvgPool2d(1)  # (B, 128, 1, 1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x, return_logits=False):
        """Run the network; returns softmax probabilities unless return_logits=True."""
        out = self.pool1(F.relu(self.bn1(self.conv1(x))))
        out = self.pool2(F.relu(self.bn2(self.conv2(out))))
        out = self.pool3(F.relu(self.bn3(self.conv3(out))))
        out = self.gap(out).view(out.size(0), -1)  # (B, 128)

        logits = self.fc(out)
        if return_logits:
            return logits
        return F.softmax(logits, dim=1)
112
+
113
+
114
class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 conv + pointwise 1x1 conv, followed by BatchNorm and ReLU."""

    def __init__(self, in_ch, out_ch, stride=1):
        super(DepthwiseSeparableConv, self).__init__()
        # Depthwise: one 3x3 filter per input channel (groups=in_ch).
        # No bias on either conv because BatchNorm follows.
        self.dw = nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1,
                            groups=in_ch, bias=False)
        # Pointwise: 1x1 conv mixes channels, in_ch -> out_ch
        self.pw = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        """Apply depthwise then pointwise conv, then BN and (in-place) ReLU."""
        out = self.pw(self.dw(x))
        return F.relu(self.bn(out), inplace=True)
125
 
126
 
127
class DepthwiseCNN(nn.Module):
    """
    Depthwise Separable CNN
    Ultra-efficient using Depthwise Separable Convolutions
    Parameters: ~1.4K (1,370 trainable)
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super(DepthwiseCNN, self).__init__()

        # NOTE: dropout_rate is accepted for interface compatibility but this
        # variant applies no dropout.

        # Stem: plain 3x3 conv, 1 -> 8 channels; stride 2 halves 28 -> 14
        self.stem = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
        )

        # Depthwise-separable blocks
        self.ds1 = DepthwiseSeparableConv(8, 16, stride=1)   # 14 -> 14
        self.ds2 = DepthwiseSeparableConv(16, 32, stride=2)  # 14 -> 7

        # Global average pool + linear classifier over 32 channels
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x, return_logits=False):
        """Run the network; returns softmax probabilities unless return_logits=True."""
        feats = self.stem(x)                 # (B, 8, 14, 14)
        feats = self.ds2(self.ds1(feats))    # (B, 32, 7, 7)
        pooled = self.gap(feats).flatten(1)  # (B, 32)

        logits = self.fc(pooled)
        if return_logits:
            return logits
        return F.softmax(logits, dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models_shifted.py CHANGED
@@ -8,118 +8,135 @@ import torch.nn.functional as F
8
 
9
 
10
  class CNNModel(nn.Module):
11
- """Standard CNN model for MNIST classification"""
 
 
 
 
12
 
13
  def __init__(self, num_classes=10, dropout_rate=0.5):
14
  super(CNNModel, self).__init__()
15
 
16
- # Convolutional layers
17
- self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
18
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
19
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
20
 
21
- # Pooling layer
22
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
 
 
23
 
24
- # Fully connected layers
25
- self.fc1 = nn.Linear(128 * 3 * 3, 256)
26
- self.fc2 = nn.Linear(256, 128)
27
- self.fc3 = nn.Linear(128, num_classes)
28
 
29
- # Dropout
30
- self.dropout = nn.Dropout(dropout_rate)
 
 
 
 
 
 
31
 
32
  def forward(self, x):
33
- # Convolutional layers with ReLU and pooling
34
- x = self.pool(F.relu(self.conv1(x))) # 28x28 -> 14x14
35
- x = self.pool(F.relu(self.conv2(x))) # 14x14 -> 7x7
36
- x = self.pool(F.relu(self.conv3(x))) # 7x7 -> 3x3
 
 
 
 
37
 
38
- # Flatten
 
 
 
 
39
  x = x.view(x.size(0), -1)
40
 
41
  # Fully connected layers with dropout
42
  x = F.relu(self.fc1(x))
43
- x = self.dropout(x)
44
  x = F.relu(self.fc2(x))
45
- x = self.dropout(x)
46
  x = self.fc3(x)
47
 
48
- # Apply softmax for probability distribution
49
- return F.softmax(x, dim=1)
50
 
51
 
52
  class TinyCNN(nn.Module):
53
- """Lightweight CNN model with fewer parameters"""
 
 
 
54
 
55
  def __init__(self, num_classes=10):
56
  super(TinyCNN, self).__init__()
57
 
58
- # Convolutional layers with fewer filters
59
- self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
60
- self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
61
- self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
62
-
63
- # Pooling layer
64
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
65
-
66
- # Global average pooling instead of large FC layer
67
- self.gap = nn.AdaptiveAvgPool2d(1)
68
-
69
- # Single fully connected layer
70
- self.fc = nn.Linear(64, num_classes)
71
 
72
- def forward(self, x):
73
- # Convolutional layers with ReLU and pooling
74
- x = self.pool(F.relu(self.conv1(x))) # 28x28 -> 14x14
75
- x = self.pool(F.relu(self.conv2(x))) # 14x14 -> 7x7
76
- x = self.pool(F.relu(self.conv3(x))) # 7x7 -> 3x3
 
 
 
 
77
 
78
  # Global average pooling
79
- x = self.gap(x)
80
 
81
- # Flatten
82
- x = x.view(x.size(0), -1)
83
-
84
- # Final classification layer
85
- x = self.fc(x)
86
-
87
- # Apply softmax for probability distribution
88
- return F.softmax(x, dim=1)
 
 
 
89
 
90
 
91
  class MiniCNN(nn.Module):
92
- """Ultra-lightweight CNN model for edge devices"""
93
-
 
 
 
94
  def __init__(self, num_classes=10):
95
  super(MiniCNN, self).__init__()
96
-
97
- # Minimal convolutional layers
98
- self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
99
- self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
100
-
101
- # Pooling layer
102
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
103
-
104
- # Global average pooling
105
- self.gap = nn.AdaptiveAvgPool2d(1)
106
-
107
- # Single fully connected layer
108
- self.fc = nn.Linear(16, num_classes)
109
-
 
 
 
110
  def forward(self, x):
111
- # Two convolutional layers with ReLU and pooling
112
- x = self.pool(F.relu(self.conv1(x))) # 28x28 -> 14x14
113
- x = self.pool(F.relu(self.conv2(x))) # 14x14 -> 7x7
114
-
115
- # Global average pooling
116
- x = self.gap(x)
117
-
118
- # Flatten
119
- x = x.view(x.size(0), -1)
120
-
121
- # Final classification layer
122
- x = self.fc(x)
123
-
124
- # Apply softmax for probability distribution
125
- return F.softmax(x, dim=1)
 
8
 
9
 
10
class CNNModel(nn.Module):
    """
    CNN Model for MNIST digit classification with shifted labels
    Architecture: Conv-BN-ReLU-Pool x3 + FC-Dropout x2 + FC
    Trainable parameters: 817,354

    forward() returns raw logits (no softmax applied).
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super(CNNModel, self).__init__()

        # Conv block 1: (1, 28, 28) -> (32, 14, 14)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv block 2: (32, 14, 14) -> (64, 7, 7)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv block 3: (64, 7, 7) -> (128, 3, 3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Flattened feature size after three 2x2 pools: 28 -> 14 -> 7 -> 3
        self.flattened_size = 128 * 3 * 3

        # Classifier head: two hidden FC layers, each followed by dropout
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Forward pass; returns raw logits of shape (batch, num_classes)."""
        # Feature extractor: conv -> BN -> ReLU -> pool, three times
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        # Classifier head
        x = x.view(x.size(0), -1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.fc3(x)
 
69
 
70
 
71
class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling
    Trainable parameters: 94,410

    forward() returns raw logits (no softmax applied).
    """

    def __init__(self, num_classes=10):
        super(TinyCNN, self).__init__()

        # Conv block 1: 28 -> 14
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Conv block 2: 14 -> 7
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Conv block 3: 7 -> 3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        # GAP collapses the spatial map to 1x1, so the classifier sees 128 features
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Forward pass; returns raw logits of shape (batch, num_classes)."""
        for conv, bn, pool in ((self.conv1, self.bn1, self.pool1),
                               (self.conv2, self.bn2, self.pool2),
                               (self.conv3, self.bn3, self.pool3)):
            x = pool(F.relu(bn(conv(x))))
        x = self.avgpool(x).view(x.size(0), -1)  # (batch, 128)
        return self.fc(x)                        # (batch, num_classes)
109
 
110
 
111
class MiniCNN(nn.Module):
    """
    Mini CNN for MNIST using only 2 convolution layers + Global Avg Pooling
    Trainable parameters: ~19K (19,658)

    forward() returns raw logits (no softmax applied).
    """

    def __init__(self, num_classes=10):
        super(MiniCNN, self).__init__()

        # Conv block 1: (1, 28, 28) -> (32, 14, 14)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Conv block 2: (32, 14, 14) -> (64, 7, 7)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Global average pooling + linear classifier over 64 channels
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Forward pass; returns raw logits of shape (batch, num_classes)."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))  # (batch, 32, 14, 14)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))  # (batch, 64, 7, 7)
        x = self.avgpool(x).view(x.size(0), -1)          # (batch, 64)
        return self.fc(x)                                # (batch, num_classes)