LearnWaterFlow committed on
Commit
b128ae9
·
verified ·
1 Parent(s): e02bb38

Update modeling_resnet.py

Browse files
Files changed (1) hide show
  1. modeling_resnet.py +10 -16
modeling_resnet.py CHANGED
@@ -4,9 +4,10 @@ from transformers import PreTrainedModel, PretrainedConfig
4
 
5
class ResNetConfig(PretrainedConfig):
    """Configuration for the custom ResNet model.

    Args:
        num_classes: Number of output classes for the classification
            head (default 10).
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    # Architecture identifier stored in config.json for this model.
    model_type = "custom_resnet"

    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
 
10
 
11
  class BasicBlock(nn.Module):
12
  expansion = 1
@@ -34,8 +35,7 @@ class ResNet(PreTrainedModel):
34
  def __init__(self, config):
35
  super().__init__(config)
36
  self.in_channels = 64
37
- # Matches your training script ResNet-34
38
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
39
  self.bn1 = nn.BatchNorm2d(64)
40
  self.relu = nn.ReLU(inplace=True)
41
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -46,18 +46,17 @@ class ResNet(PreTrainedModel):
46
  self.layer4 = self._make_layer(BasicBlock, 512, 3, stride=2)
47
 
48
  self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
49
- self.fc = nn.Linear(512 * BasicBlock.expansion, config.num_classes)
50
 
51
  def _make_layer(self, block, out_channels, blocks, stride=1):
52
  downsample = None
53
- if stride != 1 or self.in_channels != out_channels * block.expansion:
54
  downsample = nn.Sequential(
55
- nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
56
- nn.BatchNorm2d(out_channels * block.expansion),
57
  )
58
- layers = []
59
- layers.append(block(self.in_channels, out_channels, stride, downsample))
60
- self.in_channels = out_channels * block.expansion
61
  for _ in range(1, blocks):
62
  layers.append(block(self.in_channels, out_channels))
63
  return nn.Sequential(*layers)
@@ -72,9 +71,4 @@ class ResNet(PreTrainedModel):
72
  x = torch.flatten(self.avgpool(x), 1)
73
  logits = self.fc(x)
74
 
75
- loss = None
76
- if labels is not None:
77
- loss_fct = nn.CrossEntropyLoss()
78
- loss = loss_fct(logits, labels)
79
-
80
- return {"loss": loss, "logits": logits} if loss is not labels else logits
 
4
 
5
class ResNetConfig(PretrainedConfig):
    """Configuration for the custom ResNet model.

    Args:
        num_classes: Number of output classes for the classification
            head (default 10).
        num_channels: Number of input image channels expected by the
            stem convolution (default 3).
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    # Architecture identifier stored in config.json for this model.
    model_type = "custom_resnet"

    def __init__(self, num_classes=10, num_channels=3, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.num_channels = num_channels
11
 
12
  class BasicBlock(nn.Module):
13
  expansion = 1
 
35
  def __init__(self, config):
36
  super().__init__(config)
37
  self.in_channels = 64
38
+ self.conv1 = nn.Conv2d(config.num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
 
39
  self.bn1 = nn.BatchNorm2d(64)
40
  self.relu = nn.ReLU(inplace=True)
41
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
46
  self.layer4 = self._make_layer(BasicBlock, 512, 3, stride=2)
47
 
48
  self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
49
+ self.fc = nn.Linear(512, config.num_classes)
50
 
51
  def _make_layer(self, block, out_channels, blocks, stride=1):
52
  downsample = None
53
+ if stride != 1 or self.in_channels != out_channels:
54
  downsample = nn.Sequential(
55
+ nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
56
+ nn.BatchNorm2d(out_channels),
57
  )
58
+ layers = [block(self.in_channels, out_channels, stride, downsample)]
59
+ self.in_channels = out_channels
 
60
  for _ in range(1, blocks):
61
  layers.append(block(self.in_channels, out_channels))
62
  return nn.Sequential(*layers)
 
71
  x = torch.flatten(self.avgpool(x), 1)
72
  logits = self.fc(x)
73
 
74
+ return {"logits": logits}