| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
class VisualExtractor(nn.Module):
    """Convolutional feature extractor built from a torchvision backbone.

    The backbone named by ``args.visual_extractor`` is instantiated from
    ``torchvision.models`` and its last two child modules are dropped, so the
    module emits a spatial feature map instead of classification logits.
    A 1x1 convolution additionally projects that map to ``args.nhidden``
    channels.

    NOTE(review): the 1x1 projection hard-codes ``in_channels=2048``, so the
    chosen backbone's truncated trunk presumably must emit 2048 channels
    (e.g. a ResNet-50-style network) -- confirm against the configs in use.

    Args:
        args: configuration object providing the attributes ``nhidden``,
            ``visual_extractor`` and ``visual_extractor_pretrained``.
    """

    def __init__(self, args):
        super().__init__()
        # Created before the backbone on purpose: keeps the parameter
        # initialisation order (and thus seeded RNG consumption) unchanged.
        self.cov1x1 = nn.Conv2d(
            in_channels=2048, out_channels=args.nhidden, kernel_size=(1, 1)
        )
        self.visual_extractor = args.visual_extractor
        self.pretrained = args.visual_extractor_pretrained

        backbone = getattr(models, self.visual_extractor)(pretrained=self.pretrained)
        # Drop the last two children of the backbone and keep the rest as the trunk.
        trunk_layers = list(backbone.children())[:-2]
        self.model = nn.Sequential(*trunk_layers)

        # Global average pooling.  An adaptive pool always yields a 1x1 map,
        # whereas a fixed 7x7 kernel is only correct for 7x7 feature maps and
        # silently produces mis-shaped (batch-mixed) outputs for larger ones.
        self.avg_fnt = nn.AdaptiveAvgPool2d(output_size=1)

        if self.pretrained is True:
            print('first init the imagenet pretrained!')

    def forward(self, images: torch.Tensor):
        """Extract patch-level and globally pooled features from ``images``.

        Args:
            images: batch of images fed directly to the backbone trunk.

        Returns:
            A 4-tuple ``(patch_feats, avg_feats, att_feat_it, avg_feat_it)``
            where, with ``B`` the batch size, ``C`` the trunk's channel count
            and ``H x W`` the spatial size of its output:

            * ``patch_feats`` -- trunk features flattened over space,
              shape ``(B, H*W, C)``.
            * ``avg_feats`` -- spatial average of the trunk features,
              shape ``(B, C)``.
            * ``att_feat_it`` -- trunk features after the 1x1 projection,
              shape ``(B, nhidden, H, W)``.
            * ``avg_feat_it`` -- spatial average of ``att_feat_it``,
              shape ``(B, nhidden)``.
        """
        patch_feats = self.model(images)
        att_feat_it = self.cov1x1(patch_feats)

        # flatten(1) turns (B, C, 1, 1) into (B, C) and, unlike squeeze(),
        # can never drop a batch dimension of size 1.
        avg_feat_it = self.avg_fnt(att_feat_it).flatten(1)
        avg_feats = self.avg_fnt(patch_feats).flatten(1)

        batch_size, feat_size, _, _ = patch_feats.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one row per location.
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats, att_feat_it, avg_feat_it