CircleStar commited on
Commit
2e34d29
·
verified ·
1 Parent(s): c3905ef

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +12 -6
model.py CHANGED
@@ -2,15 +2,13 @@ import torch.nn as nn
2
  from torchvision import models
3
 
4
 
5
- class SimpleCNN(nn.Module):
6
  def __init__(
7
  self,
8
  num_classes: int,
9
- conv1_channels: int = 16,
10
- conv2_channels: int = 32,
11
- kernel_size: int = 3,
12
- dropout: float = 0.2,
13
- fc_dim: int = 128,
14
  ):
15
  super().__init__()
16
 
@@ -18,6 +16,11 @@ class SimpleCNN(nn.Module):
18
  self.backbone = models.resnet18(weights=weights)
19
 
20
  in_features = self.backbone.fc.in_features
 
 
 
 
 
21
  self.backbone.fc = nn.Sequential(
22
  nn.Dropout(dropout),
23
  nn.Linear(in_features, fc_dim),
@@ -26,5 +29,8 @@ class SimpleCNN(nn.Module):
26
  nn.Linear(fc_dim, num_classes),
27
  )
28
 
 
 
 
29
  def forward(self, x):
30
  return self.backbone(x)
 
2
  from torchvision import models
3
 
4
 
5
+ class ResNet18Classifier(nn.Module):
6
  def __init__(
7
  self,
8
  num_classes: int,
9
+ dropout: float = 0.5,
10
+ fc_dim: int = 256,
11
+ freeze_backbone: bool = True,
 
 
12
  ):
13
  super().__init__()
14
 
 
16
  self.backbone = models.resnet18(weights=weights)
17
 
18
  in_features = self.backbone.fc.in_features
19
+
20
+ if freeze_backbone:
21
+ for param in self.backbone.parameters():
22
+ param.requires_grad = False
23
+
24
  self.backbone.fc = nn.Sequential(
25
  nn.Dropout(dropout),
26
  nn.Linear(in_features, fc_dim),
 
29
  nn.Linear(fc_dim, num_classes),
30
  )
31
 
32
+ for param in self.backbone.fc.parameters():
33
+ param.requires_grad = True
34
+
35
  def forward(self, x):
36
  return self.backbone(x)