LearnWaterFlow committed on
Commit
15a2251
·
verified ·
1 Parent(s): e95062a

Update modeling_resnet.py

Browse files
Files changed (1) hide show
  1. modeling_resnet.py +13 -3
modeling_resnet.py CHANGED
@@ -1,13 +1,14 @@
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, PretrainedConfig
 
4
 
5
  class ResNetConfig(PretrainedConfig):
6
  model_type = "custom_resnet"
7
  def __init__(self, num_classes=10, num_channels=3, **kwargs):
8
  super().__init__(**kwargs)
9
  self.num_classes = num_classes
10
- self.num_channels = num_channels
11
 
12
  class BasicBlock(nn.Module):
13
  expansion = 1
@@ -61,7 +62,8 @@ class ResNet(PreTrainedModel):
61
  layers.append(block(self.in_channels, out_channels))
62
  return nn.Sequential(*layers)
63
 
64
- def forward(self, x, labels=None):
 
65
  x = self.relu(self.bn1(self.conv1(x)))
66
  x = self.maxpool(x)
67
  x = self.layer1(x)
@@ -71,4 +73,12 @@ class ResNet(PreTrainedModel):
71
  x = torch.flatten(self.avgpool(x), 1)
72
  logits = self.fc(x)
73
 
74
- return {"logits": logits}
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, PretrainedConfig
4
+ from transformers.modeling_outputs import ImageClassifierOutput
5
 
6
  class ResNetConfig(PretrainedConfig):
7
  model_type = "custom_resnet"
8
  def __init__(self, num_classes=10, num_channels=3, **kwargs):
9
  super().__init__(**kwargs)
10
  self.num_classes = num_classes
11
+ self.num_channels = num_channels
12
 
13
  class BasicBlock(nn.Module):
14
  expansion = 1
 
62
  layers.append(block(self.in_channels, out_channels))
63
  return nn.Sequential(*layers)
64
 
65
+ def forward(self, pixel_values=None, labels=None, **kwargs):
66
+ x = pixel_values
67
  x = self.relu(self.bn1(self.conv1(x)))
68
  x = self.maxpool(x)
69
  x = self.layer1(x)
 
73
  x = torch.flatten(self.avgpool(x), 1)
74
  logits = self.fc(x)
75
 
76
+ loss = None
77
+ if labels is not None:
78
+ loss_fct = nn.CrossEntropyLoss()
79
+ loss = loss_fct(logits, labels)
80
+
81
+ return ImageClassifierOutput(
82
+ loss=loss,
83
+ logits=logits
84
+ )