LearnWaterFlow committed on
Commit e02bb38 · verified · 1 Parent(s): 2fb1fad

Update modeling_resnet.py

Files changed (1)
  1. modeling_resnet.py +20 -27
modeling_resnet.py CHANGED
@@ -4,26 +4,9 @@ from transformers import PreTrainedModel, PretrainedConfig
 
 class ResNetConfig(PretrainedConfig):
     model_type = "custom_resnet"
-
-    def __init__(
-        self,
-        num_classes=10,
-        image_size=64,
-        input_channels=3,
-        layers=(3, 4, 6, 3),
-        hidden_sizes=(64, 128, 256, 512),
-        activation_function="relu",
-        label_smoothing=0.0,
-        **kwargs
-    ):
+    def __init__(self, num_classes=10, **kwargs):
         super().__init__(**kwargs)
         self.num_classes = num_classes
-        self.image_size = image_size
-        self.input_channels = input_channels
-        self.layers = layers
-        self.hidden_sizes = hidden_sizes
-        self.activation_function = activation_function
-        self.label_smoothing = label_smoothing
 
 class BasicBlock(nn.Module):
     expansion = 1
@@ -51,6 +34,7 @@ class ResNet(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.in_channels = 64
+        # Matches your training script ResNet-34
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
@@ -62,26 +46,35 @@ class ResNet(PreTrainedModel):
         self.layer4 = self._make_layer(BasicBlock, 512, 3, stride=2)
 
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(512, config.num_classes)
+        self.fc = nn.Linear(512 * BasicBlock.expansion, config.num_classes)
 
     def _make_layer(self, block, out_channels, blocks, stride=1):
         downsample = None
-        if stride != 1 or self.in_channels != out_channels:
+        if stride != 1 or self.in_channels != out_channels * block.expansion:
             downsample = nn.Sequential(
-                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(out_channels),
+                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_channels * block.expansion),
             )
-        layers = [block(self.in_channels, out_channels, stride, downsample)]
-        self.in_channels = out_channels
+        layers = []
+        layers.append(block(self.in_channels, out_channels, stride, downsample))
+        self.in_channels = out_channels * block.expansion
         for _ in range(1, blocks):
             layers.append(block(self.in_channels, out_channels))
         return nn.Sequential(*layers)
 
-    def forward(self, x):
-        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
+    def forward(self, x, labels=None):
+        x = self.relu(self.bn1(self.conv1(x)))
+        x = self.maxpool(x)
         x = self.layer1(x)
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
         x = torch.flatten(self.avgpool(x), 1)
-        return self.fc(x)
+        logits = self.fc(x)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits, labels)
+
+        return {"loss": loss, "logits": logits} if labels is not None else logits
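
Net effect of the commit: the config is trimmed to the single field the model actually reads (num_classes), the fc head and _make_layer now account for block.expansion, and forward() gains an optional labels argument so the model returns a {"loss", "logits"} dict that the transformers Trainer can consume. A minimal smoke test of the updated module is sketched below; it assumes the elided BasicBlock, maxpool, and layer1-3 definitions follow the standard torchvision ResNet pattern, and the 224x224 input size is an arbitrary choice (AdaptiveAvgPool2d makes the head size-agnostic):

import torch
from modeling_resnet import ResNetConfig, ResNet

config = ResNetConfig(num_classes=10)
model = ResNet(config)
model.eval()

x = torch.randn(2, 3, 224, 224)           # batch of 2 RGB images (size assumed)
with torch.no_grad():
    logits = model(x)                     # no labels -> plain logits tensor
print(logits.shape)                       # torch.Size([2, 10])

labels = torch.tensor([1, 7])
out = model(x, labels=labels)             # with labels -> {"loss", "logits"}
print(out["loss"].item(), out["logits"].shape)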