File size: 839 Bytes
bc33fb2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | import torch.nn as nn
import torch
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, AutoModel
from transformers.modeling_outputs import ImageClassifierOutput
class ResNetFN(nn.Module):
    """Pretrained ResNet backbone followed by a two-layer classification head.

    The backbone is loaded with ``AutoModel.from_pretrained``. Its pooled
    features pass through ``Linear -> ReLU -> Linear``, which produces one
    logit per class.

    Args:
        model_name: Checkpoint identifier passed to ``AutoModel.from_pretrained``.
            Defaults to ``'microsoft/resnet-50'``, as in the original code.
        hidden_dim: Width of the intermediate head layer (default 512).
        num_labels: Number of output logits / classes (default 2).
    """

    def __init__(self, model_name='microsoft/resnet-50', hidden_dim=512, num_labels=2):
        super().__init__()
        self.num_labels = num_labels
        self.resnet = AutoModel.from_pretrained(model_name)
        # Channel width of the backbone's final stage, which is what the pooler
        # emits. For resnet-50 this is 2048, the value previously hard-coded.
        # NOTE(review): assumes a ResNet-style config exposing `hidden_sizes`.
        feature_dim = self.resnet.config.hidden_sizes[-1]
        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_labels)

    def forward(self, pixel_values, labels=None):
        """Run the backbone and head, and compute a loss when labels are given.

        Args:
            pixel_values: Image tensor passed straight to the backbone as
                ``pixel_values``. Presumably (batch, 3, H, W); confirm with
                the feature extractor in use.
            labels: Optional integer class indices in ``[0, num_labels)``,
                one per example.

        Returns:
            ``ImageClassifierOutput`` with ``logits`` of shape
            (batch, num_labels). ``loss`` is ``None`` when ``labels`` is
            ``None``.
        """
        backbone_out = self.resnet(pixel_values=pixel_values)
        # The pooler output carries trailing 1x1 spatial dims (the previous code
        # squeezed them away). flatten(1) removes them and also handles an
        # already-2-D tensor.
        features = torch.flatten(backbone_out.pooler_output, 1)
        hidden = F.relu(self.fc1(features))
        logits = self.fc2(hidden)

        loss = None
        if labels is not None:
            # Each logit is scored independently against a one-hot target
            # (BCE with logits), rather than with a softmax cross-entropy.
            # one_hot requires int64 indices. Matching the logits' dtype keeps
            # the loss valid under half precision.
            targets = F.one_hot(labels.long(), num_classes=self.num_labels).to(logits.dtype)
            loss = nn.BCEWithLogitsLoss()(logits, targets)
        return ImageClassifierOutput(loss=loss, logits=logits)