| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Image Annotation/Search for COCO with Pytorch |
| | """ |
| | from __future__ import absolute_import, division, unicode_literals |
| |
|
| | import logging |
| | import copy |
| | import numpy as np |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.autograd import Variable |
| | import torch.optim as optim |
| |
|
| |
|
class COCOProjNet(nn.Module):
    """Projects image and sentence features into a shared L2-normalized space.

    Two linear maps (``imgproj``, ``sentproj``) send image features
    (``imgdim``) and sentence features (``sentdim``) to a common
    ``projdim``-dimensional space; every projection is L2-normalized so
    dot products are cosine similarities.

    config keys: 'imgdim', 'sentdim', 'projdim'.
    """

    def __init__(self, config):
        super(COCOProjNet, self).__init__()
        self.imgdim = config['imgdim']
        self.sentdim = config['sentdim']
        self.projdim = config['projdim']
        self.imgproj = nn.Sequential(
            nn.Linear(self.imgdim, self.projdim),
        )
        self.sentproj = nn.Sequential(
            nn.Linear(self.sentdim, self.projdim),
        )

    @staticmethod
    def _l2norm(x):
        # Row-wise L2 normalization. Same formula the original repeated
        # six times: x / sqrt(sum(x^2, dim=1)). No epsilon — a zero row
        # would divide by zero, exactly as before.
        return x / torch.sqrt(torch.pow(x, 2).sum(1, keepdim=True)).expand_as(x)

    def forward(self, img, sent, imgc, sentc):
        """Return cosine scores for anchors and contrastive pairs.

        img:   (B, imgdim)   anchor images
        sent:  (B, sentdim)  anchor sentences
        imgc:  (B, K, imgdim)   contrastive images
        sentc: (B, K, sentdim)  contrastive sentences

        Returns four (B*K,) tensors:
        anchor1/anchor2 (both img·sent, kept separate for the caller's
        loss interface), img_sentc (img vs contrastive sents),
        sent_imgc (sent vs contrastive imgs).
        """
        # Repeat each anchor K times so it lines up with its K contrastive
        # samples, then flatten everything to 2-D.
        img = img.unsqueeze(1).expand_as(imgc).contiguous()
        img = img.view(-1, self.imgdim)
        imgc = imgc.view(-1, self.imgdim)
        sent = sent.unsqueeze(1).expand_as(sentc).contiguous()
        sent = sent.view(-1, self.sentdim)
        sentc = sentc.view(-1, self.sentdim)

        imgproj = self._l2norm(self.imgproj(img))
        imgcproj = self._l2norm(self.imgproj(imgc))
        sentproj = self._l2norm(self.sentproj(sent))
        sentcproj = self._l2norm(self.sentproj(sentc))

        # Row-wise dot products of unit vectors == cosine similarities.
        anchor1 = torch.sum((imgproj * sentproj), 1)
        anchor2 = torch.sum((sentproj * imgproj), 1)
        img_sentc = torch.sum((imgproj * sentcproj), 1)
        sent_imgc = torch.sum((sentproj * imgcproj), 1)

        return anchor1, anchor2, img_sentc, sent_imgc

    def proj_sentence(self, sent):
        """Project sentence features and L2-normalize (for retrieval eval)."""
        return self._l2norm(self.sentproj(sent))

    def proj_image(self, img):
        """Project image features and L2-normalize (for retrieval eval)."""
        return self._l2norm(self.imgproj(img))
| |
|
| |
|
class PairwiseRankingLoss(nn.Module):
    """Bidirectional pairwise (triplet/hinge) ranking loss.

    For each anchor pair the loss penalizes any contrastive score that
    comes within ``margin`` of the anchor score, summed over both
    retrieval directions (image->text and text->image).
    """

    def __init__(self, margin):
        super(PairwiseRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor1, anchor2, img_sentc, sent_imgc):
        """Return the summed hinge loss over both directions.

        anchor1/anchor2: anchor similarity scores; img_sentc/sent_imgc:
        scores of anchors against contrastive sentences/images.
        """
        # max(0, margin - anchor + contrastive), elementwise.
        sent_hinge = (self.margin - anchor1 + img_sentc).clamp(min=0.0)
        img_hinge = (self.margin - anchor2 + sent_imgc).clamp(min=0.0)
        return sent_hinge.sum() + img_hinge.sum()
| |
|
| |
|
class ImageSentenceRankingPytorch(object):
    """Trainer/evaluator for COCO image-caption retrieval.

    Trains a ``COCOProjNet`` with ``PairwiseRankingLoss`` on precomputed
    image and sentence features, with early stopping on a dev recall
    score, then reports Recall@{1,5,10} and median rank for both
    image->text and text->image retrieval on the test split.

    NOTE(review): requires a CUDA device — the model, the loss, and all
    dev/test tensors are moved to the GPU unconditionally.
    """

    def __init__(self, train, valid, test, config):
        # Seed numpy and torch (CPU + CUDA) for reproducibility.
        self.seed = config['seed']
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        # Each split is a dict with 'imgfeat' and 'sentfeat' feature lists.
        self.train = train
        self.valid = valid
        self.test = test

        # Feature dimensionalities inferred from the first training sample.
        self.imgdim = len(train['imgfeat'][0])
        self.sentdim = len(train['sentfeat'][0])
        self.projdim = config['projdim']
        self.margin = config['margin']

        # Fixed training hyper-parameters.
        self.batch_size = 128
        self.ncontrast = 30   # contrastive samples drawn per anchor
        self.maxepoch = 20
        self.early_stop = True

        config_model = {'imgdim': self.imgdim,'sentdim': self.sentdim,
                        'projdim': self.projdim}
        self.model = COCOProjNet(config_model).cuda()

        self.loss_fn = PairwiseRankingLoss(margin=self.margin).cuda()

        self.optimizer = optim.Adam(self.model.parameters())

    def prepare_data(self, trainTxt, trainImg, devTxt, devImg,
                     testTxt, testImg):
        """Convert feature arrays to FloatTensors.

        Train tensors stay on CPU (batches are moved to GPU per step in
        ``trainepoch``); dev/test tensors are placed on the GPU up front.
        """
        trainTxt = torch.FloatTensor(trainTxt)
        trainImg = torch.FloatTensor(trainImg)
        devTxt = torch.FloatTensor(devTxt).cuda()
        devImg = torch.FloatTensor(devImg).cuda()
        testTxt = torch.FloatTensor(testTxt).cuda()
        testImg = torch.FloatTensor(testImg).cuda()

        return trainTxt, trainImg, devTxt, devImg, testTxt, testImg

    def run(self):
        """Train with early stopping on dev score; return dev + test metrics.

        Returns a 9-tuple: bestdevscore, then i2t r1/r5/r10/medr, then
        t2i r1/r5/r10/medr, each averaged over 5 test folds.
        """
        self.nepoch = 0
        bestdevscore = -1
        early_stop_count = 0
        stop_train = False

        logging.info('prepare data')
        trainTxt, trainImg, devTxt, devImg, testTxt, testImg = \
            self.prepare_data(self.train['sentfeat'], self.train['imgfeat'],
                              self.valid['sentfeat'], self.valid['imgfeat'],
                              self.test['sentfeat'], self.test['imgfeat'])

        # Train one epoch at a time, evaluating on dev after each.
        # (<= maxepoch means up to maxepoch+1 passes can run.)
        while not stop_train and self.nepoch <= self.maxepoch:
            logging.info('start epoch')
            self.trainepoch(trainTxt, trainImg, devTxt, devImg, nepoches=1)
            logging.info('Epoch {0} finished'.format(self.nepoch))

            # Dev evaluation in 5 folds of 5000.
            # NOTE(review): assumes the dev split has 5*5000 rows with 5
            # captions per image — TODO confirm against the data loader.
            results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                       't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                       'dev': bestdevscore}
            score = 0
            for i in range(5):
                devTxt_i = devTxt[i*5000:(i+1)*5000]
                devImg_i = devImg[i*5000:(i+1)*5000]

                r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg_i,
                                                             devTxt_i)
                results['i2t']['r1'] += r1_i2t / 5
                results['i2t']['r5'] += r5_i2t / 5
                results['i2t']['r10'] += r10_i2t / 5
                results['i2t']['medr'] += medr_i2t / 5
                logging.info("Image to text: {0}, {1}, {2}, {3}"
                             .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))

                r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg_i,
                                                             devTxt_i)
                results['t2i']['r1'] += r1_t2i / 5
                results['t2i']['r5'] += r5_t2i / 5
                results['t2i']['r10'] += r10_t2i / 5
                results['t2i']['medr'] += medr_t2i / 5
                logging.info("Text to Image: {0}, {1}, {2}, {3}"
                             .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))
                # Model-selection score: mean sum of the six recall values.
                score += (r1_i2t + r5_i2t + r10_i2t +
                          r1_t2i + r5_t2i + r10_t2i) / 5

            logging.info("Dev mean Text to Image: {0}, {1}, {2}, {3}".format(
                results['t2i']['r1'], results['t2i']['r5'],
                results['t2i']['r10'], results['t2i']['medr']))
            logging.info("Dev mean Image to text: {0}, {1}, {2}, {3}".format(
                results['i2t']['r1'], results['i2t']['r5'],
                results['i2t']['r10'], results['i2t']['medr']))

            # Early stopping: on a non-improving epoch, bump the counter
            # and roll the model back to the best checkpoint; stop once
            # the counter has reached 3 at check time. (The first epoch
            # always improves since score >= 0 > -1, so bestmodel is
            # always bound before it is read.)
            if score > bestdevscore:
                bestdevscore = score
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
            self.model = bestmodel

        # Final test evaluation, same 5-fold scheme as dev.
        results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   'dev': bestdevscore}
        for i in range(5):
            testTxt_i = testTxt[i*5000:(i+1)*5000]
            testImg_i = testImg[i*5000:(i+1)*5000]

            r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(testImg_i, testTxt_i)
            results['i2t']['r1'] += r1_i2t / 5
            results['i2t']['r5'] += r5_i2t / 5
            results['i2t']['r10'] += r10_i2t / 5
            results['i2t']['medr'] += medr_i2t / 5

            r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(testImg_i, testTxt_i)
            results['t2i']['r1'] += r1_t2i / 5
            results['t2i']['r5'] += r5_t2i / 5
            results['t2i']['r10'] += r10_t2i / 5
            results['t2i']['medr'] += medr_t2i / 5

        return bestdevscore, results['i2t']['r1'], results['i2t']['r5'], \
            results['i2t']['r10'], results['i2t']['medr'], \
            results['t2i']['r1'], results['t2i']['r5'], \
            results['t2i']['r10'], results['t2i']['medr']

    def trainepoch(self, trainTxt, trainImg, devTxt, devImg, nepoches=1):
        """Run ``nepoches`` passes over the shuffled training data.

        Each anchor batch is paired with ``ncontrast`` negatives per
        example, sampled uniformly from outside the batch. Every 500
        batches, logs intermediate dev retrieval metrics.
        """
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + nepoches):
            permutation = list(np.random.permutation(len(trainTxt)))
            all_costs = []
            for i in range(0, len(trainTxt), self.batch_size):
                # Periodic dev evaluation for progress monitoring.
                if i % (self.batch_size*500) == 0 and i > 0:
                    logging.info('samples : {0}'.format(i))
                    r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg,
                                                                 devTxt)
                    logging.info("Image to text: {0}, {1}, {2}, {3}".format(
                        r1_i2t, r5_i2t, r10_i2t, medr_i2t))

                    r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg,
                                                                 devTxt)
                    logging.info("Text to Image: {0}, {1}, {2}, {3}".format(
                        r1_t2i, r5_t2i, r10_t2i, medr_t2i))
                # Anchor batch: move the selected rows to GPU.
                idx = torch.LongTensor(permutation[i:i + self.batch_size])
                imgbatch = Variable(trainImg.index_select(0, idx)).cuda()
                sentbatch = Variable(trainTxt.index_select(0, idx)).cuda()

                # Negatives: sample ncontrast indices per anchor from all
                # permutation positions NOT in the current batch
                # (sampling is with replacement — np.random.choice default).
                idximgc = np.random.choice(permutation[:i] +
                                           permutation[i + self.batch_size:],
                                           self.ncontrast*idx.size(0))
                idxsentc = np.random.choice(permutation[:i] +
                                            permutation[i + self.batch_size:],
                                            self.ncontrast*idx.size(0))
                idximgc = torch.LongTensor(idximgc)
                idxsentc = torch.LongTensor(idxsentc)
                # Reshape negatives to (batch, ncontrast, featdim).
                imgcbatch = Variable(trainImg.index_select(0, idximgc)).view(
                    -1, self.ncontrast, self.imgdim).cuda()
                sentcbatch = Variable(trainTxt.index_select(0, idxsentc)).view(
                    -1, self.ncontrast, self.sentdim).cuda()

                anchor1, anchor2, img_sentc, sent_imgc = self.model(
                    imgbatch, sentbatch, imgcbatch, sentcbatch)

                loss = self.loss_fn(anchor1, anchor2, img_sentc, sent_imgc)
                all_costs.append(loss.data.item())

                self.optimizer.zero_grad()
                loss.backward()

                self.optimizer.step()
        self.nepoch += nepoches

    def t2i(self, images, captions):
        """Text-to-image retrieval metrics.

        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions
        Assumes images are stored in groups of 5 identical entries (one
        per caption); only every 5th image row is kept as a candidate.
        Returns (r1, r5, r10, medr) over the 5N caption queries.
        """
        with torch.no_grad():
            # Project both modalities in batches, then concatenate.
            img_embed, sent_embed = [], []
            for i in range(0, len(images), self.batch_size):
                img_embed.append(self.model.proj_image(
                    Variable(images[i:i + self.batch_size])))
                sent_embed.append(self.model.proj_sentence(
                    Variable(captions[i:i + self.batch_size])))
            img_embed = torch.cat(img_embed, 0).data
            sent_embed = torch.cat(sent_embed, 0).data

            # Keep one image per group of 5 duplicated rows.
            npts = int(img_embed.size(0) / 5)
            idxs = torch.cuda.LongTensor(range(0, len(img_embed), 5))
            ims = img_embed.index_select(0, idxs)

            ranks = np.zeros(5 * npts)
            for index in range(npts):
                # The 5 captions belonging to image `index`.
                queries = sent_embed[5*index: 5*index + 5]

                # Cosine scores of each query against all images,
                # ranked descending; rank = position of the true image.
                scores = torch.mm(queries, ims.transpose(0, 1)).cpu().numpy()
                inds = np.zeros(scores.shape)
                for i in range(len(inds)):
                    inds[i] = np.argsort(scores[i])[::-1]
                    ranks[5 * index + i] = np.where(inds[i] == index)[0][0]

            # Recall@K = fraction of queries ranked strictly below K.
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)

    def i2t(self, images, captions):
        """Image-to-text retrieval metrics.

        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions
        One query per image group (every 5th row); a query's rank is the
        best rank among its 5 ground-truth captions.
        Returns (r1, r5, r10, medr) over the N image queries.
        """
        with torch.no_grad():
            # Project both modalities in batches, then concatenate.
            img_embed, sent_embed = [], []
            for i in range(0, len(images), self.batch_size):
                img_embed.append(self.model.proj_image(
                    Variable(images[i:i + self.batch_size])))
                sent_embed.append(self.model.proj_sentence(
                    Variable(captions[i:i + self.batch_size])))
            img_embed = torch.cat(img_embed, 0).data
            sent_embed = torch.cat(sent_embed, 0).data

            npts = int(img_embed.size(0) / 5)
            index_list = []  # top-1 retrieved caption per query (unused here)

            ranks = np.zeros(npts)
            for index in range(npts):
                # One representative image row per group of 5.
                query_img = img_embed[5 * index]

                # Score against all captions, ranked descending.
                scores = torch.mm(query_img.view(1, -1),
                                  sent_embed.transpose(0, 1)).view(-1)
                scores = scores.cpu().numpy()
                inds = np.argsort(scores)[::-1]
                index_list.append(inds[0])

                # Best (lowest) rank among the 5 ground-truth captions.
                rank = 1e20
                for i in range(5*index, 5*index + 5, 1):
                    tmp = np.where(inds == i)[0][0]
                    if tmp < rank:
                        rank = tmp
                ranks[index] = rank

            # Recall@K = fraction of queries ranked strictly below K.
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)
| |
|