# Baaz / code / miscc / losses.py
# Uploaded by SrinivasMudiraj ("Upload 34 files"), commit 9c7aa86 (verified)
import torch
import torch.nn as nn
import numpy as np
from miscc.config import cfg
from GlobalAttention import func_attention
# ##################Loss for matching text-image###################
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Return the cosine similarity of ``x1`` and ``x2`` along ``dim``.

    The product of the two L2 norms is clamped from below by ``eps`` so that a
    zero vector yields 0 instead of NaN. Every size-1 dimension of the result
    is squeezed away (so a single-row input gives a 0-d tensor).
    """
    dot = (x1 * x2).sum(dim)
    norm_product = x1.norm(2, dim) * x2.norm(2, dim)
    return torch.div(dot, norm_product.clamp(min=eps)).squeeze()
def sent_loss(cnn_code, rnn_code, labels, class_ids,
              batch_size, eps=1e-8):
    """Sentence-level matching loss between image codes and sentence codes.

    Builds the cosine-similarity matrix ``scores0[i, j]`` between image code
    ``i`` and sentence code ``j``, scaled by ``cfg.TRAIN.SMOOTH.GAMMA3``, and
    applies cross-entropy along rows (image -> sentence) and along columns
    (sentence -> image).

    Args:
        cnn_code: image codes, batch_size x nef (or seq_len x batch_size x nef).
        rnn_code: sentence codes with the same layout as ``cnn_code``.
        labels: cross-entropy targets, or None to skip the loss computation
            (presumably ``range(batch_size)`` - verify against caller).
        class_ids: numpy array of per-sample class ids, or None. A pair whose
            class matches, other than the true pair itself, is excluded from
            the negatives by setting its score to -inf.
        batch_size: number of samples in the batch.
        eps: lower bound on the norm product, to avoid division by zero.

    Returns:
        (loss0, loss1): row-wise and column-wise cross-entropy losses, or
        (None, None) when ``labels`` is None.
    """
    # --> seq_len x batch_size x nef
    if cnn_code.dim() == 2:
        cnn_code = cnn_code.unsqueeze(0)
        rnn_code = rnn_code.unsqueeze(0)

    # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1
    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    # scores0 / norm0: seq_len x batch_size x batch_size
    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3
    # --> batch_size x batch_size. Only the leading seq_len axis is dropped,
    # so a batch of one still yields a 1 x 1 matrix for cross-entropy.
    scores0 = scores0.squeeze(0)

    if class_ids is not None:
        # masks[i, j] is True when sample j shares sample i's class but is not
        # sample i itself; such pairs must not count as negatives.
        class_ids = np.asarray(class_ids)
        masks = class_ids[:batch_size, None] == class_ids[None, :]
        diag = np.arange(batch_size)
        masks[diag, diag] = False
        # masked_fill requires a bool mask on the same device as the scores.
        masks = torch.from_numpy(masks).to(scores0.device)
        # Out-of-place fill keeps autograd aware of the modification.
        scores0 = scores0.masked_fill(masks, -float('inf'))
    scores1 = scores0.transpose(0, 1)

    if labels is not None:
        loss0 = nn.CrossEntropyLoss()(scores0, labels)
        loss1 = nn.CrossEntropyLoss()(scores1, labels)
    else:
        loss0, loss1 = None, None
    return loss0, loss1
def words_loss(img_features, words_emb, labels,
               cap_lens, class_ids, batch_size):
    """Word-level matching loss between image regions and caption words.

    For every caption i, each of its words attends over the region features
    of every image in the batch (via ``func_attention``). The cosine
    similarity between each word and its attended region context is then
    aggregated per image with a GAMMA2-sharpened log-sum-exp (Eq. (10)). The
    resulting batch_size x batch_size matrix is scaled by GAMMA3 and fed to
    cross-entropy in both directions.

    Args:
        img_features: region features (the attention context). The original
            notes give batch x nef x 17 x 17 (TODO confirm for other encoders).
        words_emb: word embeddings (the attention query), batch x nef x seq_len.
        labels: cross-entropy targets, or None to skip the loss computation
            (presumably ``range(batch_size)`` - verify against caller).
        cap_lens: tensor of caption lengths; only the first ``cap_lens[i]``
            words of caption i are used.
        class_ids: numpy array of per-sample class ids, or None. A same-class
            pair, other than the true pair, is excluded from the negatives.
        batch_size: number of samples in the batch.

    Returns:
        (loss0, loss1, att_maps): the two cross-entropy losses (both None when
        ``labels`` is None), plus a list holding, for each caption i, its
        attention map over image i (1 x words_num x spatial dims).
    """
    att_maps = []
    similarities = []
    cap_lens = cap_lens.data.tolist()
    for i in range(batch_size):
        # Get the i-th text description
        words_num = cap_lens[i]
        # -> 1 x nef x words_num
        word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
        # -> batch_size x nef x words_num: pair caption i with every image
        word = word.repeat(batch_size, 1, 1)
        # weiContext: batch x nef x words_num (region context per word)
        # attn: batch x words_num x 17 x 17 (per the original notes)
        weiContext, attn = func_attention(word, img_features,
                                          cfg.TRAIN.SMOOTH.GAMMA1)
        att_maps.append(attn[i].unsqueeze(0).contiguous())
        # --> batch_size x words_num x nef
        word = word.transpose(1, 2).contiguous()
        weiContext = weiContext.transpose(1, 2).contiguous()
        # --> batch_size*words_num x nef
        word = word.view(batch_size * words_num, -1)
        weiContext = weiContext.view(batch_size * words_num, -1)
        # --> batch_size*words_num, then batch_size x words_num
        row_sim = cosine_similarity(word, weiContext)
        row_sim = row_sim.view(batch_size, words_num)
        # Eq. (10): log(sum_w exp(GAMMA2 * sim)). logsumexp gives the same
        # value without overflowing exp() when GAMMA2 is large.
        row_sim = torch.logsumexp(row_sim * cfg.TRAIN.SMOOTH.GAMMA2,
                                  dim=1, keepdim=True)
        # --> batch_size x 1; similarities(j, i): image j vs caption i
        similarities.append(row_sim)

    # batch_size x batch_size
    similarities = torch.cat(similarities, 1)
    similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3
    if class_ids is not None:
        # masks[i, j] is True when sample j shares sample i's class but is not
        # sample i itself; such pairs must not count as negatives.
        class_ids = np.asarray(class_ids)
        masks = class_ids[:batch_size, None] == class_ids[None, :]
        diag = np.arange(batch_size)
        masks[diag, diag] = False
        # masked_fill requires a bool mask on the same device as the scores.
        masks = torch.from_numpy(masks).to(similarities.device)
        similarities = similarities.masked_fill(masks, -float('inf'))
    similarities1 = similarities.transpose(0, 1)

    if labels is not None:
        loss0 = nn.CrossEntropyLoss()(similarities, labels)
        loss1 = nn.CrossEntropyLoss()(similarities1, labels)
    else:
        loss0, loss1 = None, None
    return loss0, loss1, att_maps
# ##################Loss for G and Ds##############################
def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                       real_labels, fake_labels):
    """Discriminator loss over real, fake and mismatched (wrong) pairs.

    Args:
        netD: discriminator. It is called on images to get features and must
            expose ``COND_DNET(features, conditions)`` and ``UNCOND_DNET``
            (a callable on features, or None).
        real_imgs / fake_imgs: image batches. ``fake_imgs`` is detached so this
            loss does not backpropagate into the generator.
        conditions: conditioning vectors, one per sample.
        real_labels / fake_labels: BCE targets for real and fake outputs.

    Returns:
        (errD, log): the scalar loss, and a string with the mean discriminator
        output on real and on fake images. The unconditional head is used for
        the log when present, otherwise the conditional head.
    """
    criterion = nn.BCELoss()
    # Forward
    real_features = netD(real_imgs)
    fake_features = netD(fake_imgs.detach())

    # Conditional terms: (real, cond) -> real, (fake, cond) -> fake.
    cond_real_logits = netD.COND_DNET(real_features, conditions)
    cond_real_errD = criterion(cond_real_logits, real_labels)
    cond_fake_logits = netD.COND_DNET(fake_features, conditions)
    cond_fake_errD = criterion(cond_fake_logits, fake_labels)
    # Mismatched pairs: real image k paired with the condition of sample k+1.
    batch_size = real_features.size(0)
    cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)],
                                       conditions[1:batch_size])
    cond_wrong_errD = criterion(cond_wrong_logits, fake_labels[1:batch_size])

    if netD.UNCOND_DNET is not None:
        real_logits = netD.UNCOND_DNET(real_features)
        fake_logits = netD.UNCOND_DNET(fake_features)
        real_errD = criterion(real_logits, real_labels)
        fake_errD = criterion(fake_logits, fake_labels)
        errD = ((real_errD + cond_real_errD) / 2. +
                (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.)
    else:
        # No unconditional head: report the conditional outputs instead.
        # The log below previously raised NameError on this path.
        real_logits, fake_logits = cond_real_logits, cond_fake_logits
        errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2.
    # NOTE: "Acc" here is the mean discriminator output, not a thresholded
    # accuracy.
    log = 'Real_Acc: {:.4f} Fake_Acc: {:.4f} '.format(
        torch.mean(real_logits).item(), torch.mean(fake_logits).item())
    return errD, log
def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                   words_embs, sent_emb, match_labels,
                   cap_lens, class_ids):
    """Generator loss: adversarial terms plus word/sentence matching terms.

    For each discriminator ``netsD[i]`` applied to ``fake_imgs[i]``, the
    generator is penalised with BCE against ``real_labels``. The conditional
    head is always used; the unconditional head is added when present. The
    fake images of the last stage are then encoded with ``image_encoder`` and
    scored with ``words_loss`` and ``sent_loss``, each weighted by
    ``cfg.TRAIN.SMOOTH.LAMBDA``.

    Returns:
        (errG_total, logs): the summed loss and a human-readable log string.
    """
    criterion = nn.BCELoss()
    numDs = len(netsD)
    batch_size = real_labels.size(0)
    logs = ''
    errG_total = 0

    # Adversarial terms: G wants every discriminator to call its output real.
    for i, netD in enumerate(netsD):
        features = netD(fake_imgs[i])
        cond_logits = netD.COND_DNET(features, sent_emb)
        g_loss = criterion(cond_logits, real_labels)
        if netD.UNCOND_DNET is not None:
            logits = netD.UNCOND_DNET(features)
            g_loss = criterion(logits, real_labels) + g_loss
        errG_total += g_loss
        logs += 'g_loss%d: %.2f ' % (i, g_loss.item())

    # Ranking (image-text matching) loss, on the last stage's images only.
    if numDs > 0:
        # region_features: batch_size x nef x 17 x 17 (per original notes)
        # cnn_code: batch_size x nef
        region_features, cnn_code = image_encoder(fake_imgs[numDs - 1])
        w_loss0, w_loss1, _ = words_loss(region_features, words_embs,
                                         match_labels, cap_lens,
                                         class_ids, batch_size)
        w_loss = (w_loss0 + w_loss1) * cfg.TRAIN.SMOOTH.LAMBDA
        s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb,
                                     match_labels, class_ids, batch_size)
        s_loss = (s_loss0 + s_loss1) * cfg.TRAIN.SMOOTH.LAMBDA
        errG_total += w_loss + s_loss
        logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item())
    return errG_total, logs
##################################################################
def KL_loss(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from N(0, 1), averaged per element.

    Evaluates ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))`` over every
    element of ``mu`` / ``logvar``, where ``logvar`` is log(sigma^2). This is
    a mean, not a sum, so the value does not grow with batch or latent size.
    """
    # Element-wise term, built in the order -(mu^2 + sigma^2) + 1 + log(sigma^2)
    penalty = torch.neg(mu.pow(2) + logvar.exp())
    per_element = penalty + 1 + logvar
    return per_element.mean() * -0.5