File size: 839 Bytes
bc33fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch.nn as nn
import torch
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, AutoModel
from transformers.modeling_outputs import ImageClassifierOutput

class ResNetFN(nn.Module):
    """Hugging Face ResNet backbone with a small two-layer classification head.

    The backbone's pooled features go through ``fc1`` (+ ReLU) and ``fc2``
    to produce one logit per class. When labels are supplied, the loss is
    binary cross-entropy with logits against one-hot encoded targets, i.e.
    an independent sigmoid per class rather than a softmax cross-entropy.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``AutoModel.from_pretrained``.
        in_features: Size of the backbone's pooled feature vector
            (2048 for ``microsoft/resnet-50``); must match the backbone.
        hidden_dim: Width of the hidden layer ``fc1``.
        num_labels: Number of output classes (width of ``fc2``).
    """

    def __init__(
        self,
        model_name: str = "microsoft/resnet-50",
        in_features: int = 2048,
        hidden_dim: int = 512,
        num_labels: int = 2,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.resnet = AutoModel.from_pretrained(model_name)
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_labels)

    def forward(self, pixel_values, labels=None):
        """Run the backbone and head; optionally compute the training loss.

        Args:
            pixel_values: Image batch forwarded to the backbone. Presumably
                (batch, channels, height, width), as produced by the matching
                feature extractor; TODO confirm against callers.
            labels: Optional integer class indices in ``[0, num_labels)``,
                one per example. When given, a loss is computed.

        Returns:
            ``ImageClassifierOutput`` whose ``logits`` are (batch, num_labels)
            and whose ``loss`` is ``None`` when no labels were passed.
        """
        backbone_out = self.resnet(pixel_values=pixel_values)
        # HF ResNet's pooler_output keeps trailing 1x1 spatial dims
        # (batch, in_features, 1, 1); flatten everything after the batch dim.
        features = torch.flatten(backbone_out.pooler_output, start_dim=1)
        hidden = F.relu(self.fc1(features))
        logits = self.fc2(hidden)

        loss = None
        # Identity test; `labels != None` depends on Tensor.__ne__ semantics.
        if labels is not None:
            # one_hot only accepts int64 indices; cast so int32/uint8 work too.
            targets = F.one_hot(labels.long(), num_classes=self.num_labels)
            # Match the logits dtype (float32 by default, same as `.float()`)
            # to avoid a dtype mismatch when running in half precision.
            loss = F.binary_cross_entropy_with_logits(
                logits, targets.to(logits.dtype)
            )
        return ImageClassifierOutput(loss=loss, logits=logits)