# NOTE(review): the original file began with paste artifacts ("Spaces:" and
# two "Runtime error" lines) — residue from a table/online-judge export, not code.
| import numpy as np | |
| import torch, torchvision | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from utils.utils_word_embedding import initialize_wordembedding_matrix | |
class Backbone(nn.Module):
    """ResNet feature extractor exposing the stem and the four residual stages.

    ``forward(x, returned)`` runs the full network and returns the feature
    maps of the requested stages (0 = conv stem, 1-4 = layer1..layer4).
    """

    def __init__(self, backbone='resnet18'):
        super(Backbone, self).__init__()
        # BUG FIX: the resnet50/resnet101 branches originally requested
        # ResNet18_Weights — the wrong pretrained-weights enum for those
        # architectures. Use the enum matching each model.
        if backbone == 'resnet18':
            resnet = torchvision.models.resnet.resnet18(
                weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
        elif backbone == 'resnet50':
            resnet = torchvision.models.resnet.resnet50(
                weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        elif backbone == 'resnet101':
            resnet = torchvision.models.resnet.resnet101(
                weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V1)
        else:
            # Original code fell through to a NameError on `resnet`;
            # fail with a clear message instead.
            raise ValueError(f"unsupported backbone: {backbone!r}")
        # Stage 0: conv stem (conv1 -> bn1 -> relu -> maxpool).
        self.block0 = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
        )
        self.block1 = resnet.layer1
        self.block2 = resnet.layer2
        self.block3 = resnet.layer3
        self.block4 = resnet.layer4

    def forward(self, x, returned=(4,)):
        """Return the feature maps of the stages listed in ``returned``.

        Default changed from a mutable list to a tuple (same value; callers
        that pass their own list are unaffected).
        """
        blocks = [self.block0(x)]
        blocks.append(self.block1(blocks[-1]))
        blocks.append(self.block2(blocks[-1]))
        blocks.append(self.block3(blocks[-1]))
        blocks.append(self.block4(blocks[-1]))
        return [blocks[i] for i in returned]
class CosineClassifier(nn.Module):
    """Cosine-similarity classifier with an optional temperature scaling.

    Scores each image embedding against every concept embedding by the
    cosine of the angle between them; when ``scale`` is on, logits are
    sharpened by dividing by ``temp``.
    """

    def __init__(self, temp=0.05):
        super(CosineClassifier, self).__init__()
        # Temperature applied to the cosine logits when scaling is requested.
        self.temp = temp

    def forward(self, img, concept, scale=True):
        """
        img: (bs, emb_dim)
        concept: (n_class, emb_dim)
        """
        # Unit-normalize both sides so the dot product is a cosine.
        logits = torch.matmul(F.normalize(img, dim=-1),
                              F.normalize(concept, dim=-1).t())
        return logits / self.temp if scale else logits
class Embedder(nn.Module):
    """
    Text and Visual Embedding Model.

    Maps images (ResNet backbone -> 1x1 conv head -> GAP -> linear) and
    scene-type names (pretrained word embeddings -> MLP) into a shared
    ``out_dim`` space, and scores images against the type embeddings with
    a cosine classifier.

    Args:
        type_name: ordered list of scene-type names (strings).
        feat_dim: channel count of the backbone's final feature map.
        mid_dim: hidden width of the conv embedding head.
        out_dim: dimensionality of the joint embedding space.
        drop_rate: Dropout2d rate in the image head (0 disables dropout).
        cosine_cls_temp: temperature for the cosine classifier.
        wordembs: word-embedding source passed to
            initialize_wordembedding_matrix (e.g. 'glove').
        extractor_name: ResNet variant used as the backbone.
    """

    def __init__(self,
                 type_name,
                 feat_dim=512,
                 mid_dim=1024,
                 out_dim=324,
                 drop_rate=0.35,
                 cosine_cls_temp=0.05,
                 wordembs='glove',
                 extractor_name='resnet18'):
        super(Embedder, self).__init__()
        # ImageNet statistics: the backbone is ImageNet-pretrained.
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        self.type_name = type_name
        self.feat_dim = feat_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        self.drop_rate = drop_rate
        self.cosine_cls_temp = cosine_cls_temp
        self.wordembs = wordembs
        self.extractor_name = extractor_name
        # Decide tensor placement once; the original recomputed the
        # "cuda if available" string in six separate places.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = transforms.Normalize(mean, std)
        self._setup_word_embedding()
        self._setup_image_embedding()

    def _setup_image_embedding(self):
        # Image pipeline: backbone -> 1x1 conv head -> global avg pool -> linear.
        self.feat_extractor = Backbone(self.extractor_name)
        img_emb_modules = [
            nn.Conv2d(self.feat_dim, self.mid_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.mid_dim),
            nn.ReLU()
        ]
        if self.drop_rate > 0:
            img_emb_modules += [nn.Dropout2d(self.drop_rate)]
        self.img_embedder = nn.Sequential(*img_emb_modules)
        self.img_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.img_final = nn.Linear(self.mid_dim, self.out_dim)
        self.classifier = CosineClassifier(temp=self.cosine_cls_temp)

    def _setup_word_embedding(self):
        # name -> index lookup; since type_name is an ordered list, the
        # original list comprehension was equivalent to range(num_type).
        self.type2idx = {name: i for i, name in enumerate(self.type_name)}
        self.num_type = len(self.type_name)
        self.train_type = torch.arange(self.num_type, dtype=torch.long,
                                       device=self._device)
        wordemb, self.word_dim = \
            initialize_wordembedding_matrix(self.wordembs, self.type_name)
        self.embedder = nn.Embedding(self.num_type, self.word_dim)
        self.embedder.weight.data.copy_(wordemb)
        self.mlp = nn.Sequential(
            nn.Linear(self.word_dim, self.out_dim),
            nn.ReLU(True)
        )

    def _scene_weights(self):
        # Embed every known type and project into the joint space:
        # (num_type, out_dim).
        return self.mlp(self.embedder(self.train_type))

    def _embed_image(self, img):
        # Shared image pipeline (was duplicated across three forward methods):
        # backbone stage-4 features -> conv head -> GAP -> linear projection.
        feat = self.feat_extractor(img)[0]
        feat = self.img_embedder(feat)
        feat = self.img_avg_pool(feat).squeeze(3).squeeze(2)
        return self.img_final(feat)

    def train_forward(self, batch):
        """Supervised step on (labels, images); returns loss and type accuracy."""
        scene, img = batch[0], self.transform(batch[1])
        bs = img.shape[0]
        scene_weight = self._scene_weights()
        img_emb = self._embed_image(img)
        pred = self.classifier(img_emb, scene_weight)
        label_loss = F.cross_entropy(pred, scene)
        pred_idx = torch.max(pred, dim=1)[1]
        type_pred = self.train_type[pred_idx]
        correct_type = (type_pred == scene)
        return {
            'loss_total': label_loss,
            'acc_type': torch.div(correct_type.sum(), float(bs)),
        }

    def image_encoder_forward(self, batch):
        """Classify images; return (type embedding, type index, type name) per image."""
        img = self.transform(batch)
        scene_weight = self._scene_weights()
        img_emb = self._embed_image(img)
        pred = torch.max(self.classifier(img_emb, scene_weight), dim=1)[1]
        # Single tensor gather replaces the original O(bs) per-row copy loop.
        out_embedding = scene_weight[pred].to(self._device)
        num_type = self.train_type[pred]
        text_type = [self.type_name[i] for i in num_type.tolist()]
        return out_embedding, num_type, text_type

    def text_encoder_forward(self, text):
        """Look up joint-space embeddings for a list of type-name strings."""
        scene_weight = self._scene_weights()
        idx = [self.type2idx[t] for t in text]
        # Kept as float32 to match the original return dtype (it was built
        # with torch.zeros and filled element-wise).
        num_type = torch.tensor(idx, dtype=torch.float32, device=self._device)
        out_embedding = scene_weight[idx].to(self._device)
        return out_embedding, num_type, text

    def text_idx_encoder_forward(self, idx):
        """Look up joint-space embeddings for a tensor/array of type indices."""
        scene_weight = self._scene_weights()
        # as_tensor accepts float/long tensors and numpy arrays alike,
        # mirroring the original per-element int(...) conversion.
        rows = torch.as_tensor(idx).long()
        return scene_weight[rows].to(self._device)

    def contrast_loss_forward(self, batch):
        """Image embedding only (no classification), e.g. for contrastive losses."""
        return self._embed_image(self.transform(batch))

    def forward(self, x, type='image_encoder'):
        # NOTE: `type` shadows the builtin but is kept for caller compatibility.
        if type == 'train':
            out = self.train_forward(x)
        elif type == 'image_encoder':
            with torch.no_grad():
                out = self.image_encoder_forward(x)
        elif type == 'text_encoder':
            out = self.text_encoder_forward(x)
        elif type == 'text_idx_encoder':
            out = self.text_idx_encoder_forward(x)
        elif type == 'visual_embed':
            x = F.interpolate(x, size=(224, 224), mode='bilinear')
            out = self.contrast_loss_forward(x)
        else:
            # Original hit NameError on `out` for unknown modes; fail loudly.
            raise ValueError(f"unknown forward mode: {type!r}")
        return out