import torch.nn as nn
from torchvision import models


class EncoderResnet18(nn.Module):
    """ResNet-18 image encoder with a classification head and a captioning head.

    Two views of the same pretrained ResNet-18 are exposed:

    * ``backbone`` -- every child of the ResNet except the final ``fc`` layer;
      its output is flattened and fed to ``classifier``.
    * ``cap_backbone`` -- every child except the last two; its spatial feature
      map is flattened into a sequence of locations and fed to ``projector``.

    Both ``nn.Sequential`` containers wrap the *same* layer objects, so the
    pretrained weights are stored once in memory (although they appear under
    both prefixes in ``state_dict()``). Those weights are frozen; only
    ``classifier`` and ``projector`` have trainable parameters.

    Args:
        num_classes: Output size of the classification head.
        embed_size: Output size of the per-location projection used for
            captioning.
    """

    def __init__(self, num_classes: int = 50, embed_size: int = 512) -> None:
        super().__init__()

        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        layers = list(model.children())
        feature_dim = model.fc.in_features

        # Classification trunk: everything except the original fc layer.
        self.backbone = nn.Sequential(*layers[:-1])
        self.classifier = nn.Linear(feature_dim, num_classes)

        # Captioning trunk: additionally drops the layer just before fc so the
        # spatial dimensions of the feature map are kept.
        self.cap_backbone = nn.Sequential(*layers[:-2])
        self.projector = nn.Linear(feature_dim, embed_size)

        # Freeze the pretrained weights. ``cap_backbone`` reuses a subset of
        # the layer objects held by ``backbone``, so this single pass covers
        # both containers.
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, images, return_features=False):
        """Encode a batch of images for classification or captioning.

        Only the branch selected by ``return_features`` is evaluated, so each
        call performs a single pass through the ResNet trunk.

        Args:
            images: Batch of input images. Assumed to be a float tensor of
                shape ``(batch, 3, H, W)`` as expected by torchvision's
                ResNet -- TODO confirm against the data pipeline.
            return_features: If falsy (default), return classification
                logits. If truthy, return projected spatial features for a
                captioning decoder.

        Returns:
            ``(batch, num_classes)`` logits when ``return_features`` is
            falsy; otherwise a ``(batch, locations, embed_size)`` tensor where
            ``locations`` is the number of spatial positions in the trunk's
            final feature map.
        """
        # NOTE(review): freezing via requires_grad does not stop BatchNorm
        # layers from updating their running statistics while the module is
        # in train mode; call ``.eval()`` on the trunks if that is unwanted.
        if return_features:
            # captioning
            feature_map = self.cap_backbone(images)
            # (B, C, H', W') -> (B, C, H'*W') -> (B, H'*W', C)
            regions = feature_map.flatten(2).permute(0, 2, 1)
            return self.projector(regions)

        # classification
        pooled = self.backbone(images)
        return self.classifier(pooled.flatten(1))