Instructions to use SrinivasMudiraj/Baaz with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use SrinivasMudiraj/Baaz with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("SrinivasMudiraj/Baaz", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from miscc.config import cfg | |
| from GlobalAttention import func_attention | |
| # ##################Loss for matching text-image################### | |
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Return the cosine similarity of ``x1`` and ``x2`` along ``dim``.

    The product of the two L2 norms is clamped to at least ``eps`` before
    dividing, so an all-zero vector yields 0 instead of NaN. Any size-1
    dimensions of the result are squeezed away.
    """
    dot = (x1 * x2).sum(dim)
    norm_product = x1.norm(2, dim) * x2.norm(2, dim)
    return (dot / norm_product.clamp(min=eps)).squeeze()
def sent_loss(cnn_code, rnn_code, labels, class_ids,
              batch_size, eps=1e-8):
    """Sentence-level image/text matching loss.

    Builds a ``batch_size x batch_size`` matrix of cosine similarities between
    every image code and every sentence code, scales it by
    ``cfg.TRAIN.SMOOTH.GAMMA3``, and applies cross-entropy along both axes
    (rows and, via the transpose, columns).

    Args:
        cnn_code: image codes, ``batch_size x nef``; a 2-D input gets a
            leading ``seq_len`` axis of size 1 added.
        rnn_code: sentence codes with the same layout as ``cnn_code``.
        labels: targets for ``nn.CrossEntropyLoss``, or ``None`` to skip the
            loss computation.
        class_ids: per-sample class ids (array-like), or ``None``. For row
            ``i``, every column whose class equals ``class_ids[i]`` is set to
            ``-inf``, except column ``i`` itself.
        batch_size: number of mask rows to build.
        eps: lower bound on the norm product, avoiding division by zero.

    Returns:
        ``(loss0, loss1)``, or ``(None, None)`` when ``labels`` is ``None``.
    """
    # --> seq_len x batch_size x nef
    if cnn_code.dim() == 2:
        cnn_code = cnn_code.unsqueeze(0)
        rnn_code = rnn_code.unsqueeze(0)

    # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1
    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    # scores0 / norm0: seq_len x batch_size x batch_size
    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3
    # --> batch_size x batch_size. Only the seq_len axis is dropped: a bare
    # squeeze() would turn a 1 x 1 x 1 result (batch_size == 1) into a 0-d
    # tensor, and the transpose(0, 1) below would then fail.
    scores0 = scores0.squeeze(0)

    if class_ids is not None:
        # Mask mis-matched samples that come from the same class as the real
        # sample; the diagonal (the real pair) stays unmasked.
        class_ids = np.asarray(class_ids)
        same_class = class_ids[:batch_size, None] == class_ids[None, :]
        diag = np.arange(batch_size)
        same_class[diag, diag] = False
        # Boolean mask on the scores' own device: masked_fill requires a bool
        # mask, and the device must match regardless of global config.
        masks = torch.from_numpy(same_class).to(scores0.device)
        scores0 = scores0.masked_fill(masks, float('-inf'))
    scores1 = scores0.transpose(0, 1)

    if labels is None:
        return None, None
    loss0 = nn.CrossEntropyLoss()(scores0, labels)
    loss1 = nn.CrossEntropyLoss()(scores1, labels)
    return loss0, loss1
def words_loss(img_features, words_emb, labels,
               cap_lens, class_ids, batch_size):
    """Word-level image/text matching loss.

    For each caption ``i``, its first ``cap_lens[i]`` word embeddings are
    repeated across the batch and attend over every image's features via
    ``func_attention``. The cosine similarity between each word and its
    attended context is aggregated per image (Eq. (10)) into one column of a
    ``batch_size x batch_size`` similarity matrix. That matrix is scaled by
    ``cfg.TRAIN.SMOOTH.GAMMA3`` and scored with cross-entropy along both axes.

    words_emb(query): batch x nef x seq_len
    img_features(context): batch x nef x 17 x 17

    Args:
        img_features: image region features (shape above).
        words_emb: word embeddings (shape above).
        labels: targets for ``nn.CrossEntropyLoss``, or ``None``.
        cap_lens: per-caption word counts; converted with ``.tolist()`` and
            used to slice ``words_emb``.
        class_ids: per-sample class ids (array-like), or ``None``; same-class
            off-diagonal entries are set to ``-inf`` before the loss.
        batch_size: number of captions/images to process.

    Returns:
        ``(loss0, loss1, att_maps)``: the two losses, or ``None, None`` when
        ``labels`` is ``None``, plus a list holding, for each caption ``i``,
        ``attn[i].unsqueeze(0)`` from its attention call.
    """
    att_maps = []
    similarities = []
    cap_lens = cap_lens.tolist()
    # The attention context is the same for every caption.
    # batch x nef x 17 x 17
    context = img_features
    for i in range(batch_size):
        # Get the i-th text description
        words_num = cap_lens[i]
        # -> 1 x nef x words_num
        word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
        # -> batch_size x nef x words_num
        word = word.repeat(batch_size, 1, 1)
        # weiContext: batch x nef x words_num
        # attn: batch x words_num x 17 x 17
        weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1)
        att_maps.append(attn[i].unsqueeze(0).contiguous())
        # --> batch_size x words_num x nef --> batch_size*words_num x nef
        word = word.transpose(1, 2).contiguous().view(batch_size * words_num, -1)
        weiContext = weiContext.transpose(1, 2).contiguous().view(batch_size * words_num, -1)
        # --> batch_size x words_num
        row_sim = cosine_similarity(word, weiContext).view(batch_size, words_num)
        # Eq. (10): log(sum_k exp(GAMMA2 * sim_k)). logsumexp computes this
        # without forming exp() directly, so large GAMMA2 cannot overflow.
        # --> batch_size x 1
        row_sim = torch.logsumexp(row_sim * cfg.TRAIN.SMOOTH.GAMMA2, dim=1, keepdim=True)
        # similarities(j, i): similarity between the j-th image and the
        # i-th text description
        similarities.append(row_sim)

    # batch_size x batch_size
    similarities = torch.cat(similarities, 1) * cfg.TRAIN.SMOOTH.GAMMA3
    if class_ids is not None:
        # Mask mis-matched samples from the same class as the real sample;
        # the diagonal (the real pair) stays unmasked.
        class_ids = np.asarray(class_ids)
        same_class = class_ids[:batch_size, None] == class_ids[None, :]
        diag = np.arange(batch_size)
        same_class[diag, diag] = False
        # masked_fill requires a bool mask on the same device as the tensor.
        masks = torch.from_numpy(same_class).to(similarities.device)
        similarities = similarities.masked_fill(masks, float('-inf'))
    similarities1 = similarities.transpose(0, 1)

    if labels is None:
        return None, None, att_maps
    loss0 = nn.CrossEntropyLoss()(similarities, labels)
    loss1 = nn.CrossEntropyLoss()(similarities1, labels)
    return loss0, loss1, att_maps
| # ##################Loss for G and Ds############################## | |
def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                       real_labels, fake_labels):
    """Discriminator loss over real, fake and mismatched-condition pairs.

    Args:
        netD: callable returning features; must expose ``COND_DNET`` (called
            with features and conditions) and ``UNCOND_DNET`` (a callable on
            features, or ``None``).
        real_imgs: real image batch.
        fake_imgs: generated image batch; detached before use.
        conditions: per-sample conditioning vectors.
        real_labels: BCE targets for "real".
        fake_labels: BCE targets for "fake".

    Returns:
        ``(errD, log)``. ``log`` reports the mean discriminator output on
        real and fake images. These come from the unconditional head when
        present, otherwise from the conditional head.
    """
    # Forward. fake_imgs is detached so this loss does not backpropagate
    # into whatever produced it (presumably the generator).
    real_features = netD(real_imgs)
    fake_features = netD(fake_imgs.detach())

    # Conditional terms.
    cond_real_logits = netD.COND_DNET(real_features, conditions)
    cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels)
    cond_fake_logits = netD.COND_DNET(fake_features, conditions)
    cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels)
    # Mismatched pairs: real sample k paired with the condition of k + 1,
    # scored against the "fake" target.
    batch_size = real_features.size(0)
    cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size])
    cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size])

    if netD.UNCOND_DNET is not None:
        real_logits = netD.UNCOND_DNET(real_features)
        fake_logits = netD.UNCOND_DNET(fake_features)
        real_errD = nn.BCELoss()(real_logits, real_labels)
        fake_errD = nn.BCELoss()(fake_logits, fake_labels)
        errD = ((real_errD + cond_real_errD) / 2. +
                (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.)
    else:
        # Without an unconditional head there are no real_logits/fake_logits.
        # Report the conditional outputs so the log line below is valid in
        # both configurations.
        real_logits, fake_logits = cond_real_logits, cond_fake_logits
        errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2.
    # "Acc" here is the mean discriminator output, not a thresholded accuracy.
    log = 'Real_Acc: {:.4f} Fake_Acc: {:.4f} '.format(torch.mean(real_logits).item(), torch.mean(fake_logits).item())
    return errD, log
def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                   words_embs, sent_emb, match_labels,
                   cap_lens, class_ids):
    """Total generator loss: adversarial terms plus word/sentence ranking loss.

    For every discriminator ``netsD[i]``, the fake images ``fake_imgs[i]`` are
    scored against ``real_labels`` by the conditional head and, if present,
    the unconditional head. The images from the last stage are then encoded
    by ``image_encoder``, and ``words_loss`` and ``sent_loss`` are added,
    each scaled by ``cfg.TRAIN.SMOOTH.LAMBDA``.

    Returns:
        ``(errG_total, logs)``: the summed loss, and a human-readable string
        with one ``g_loss<i>`` entry per discriminator followed by the
        ``w_loss``/``s_loss`` values.
    """
    numDs = len(netsD)
    batch_size = real_labels.size(0)
    logs = ''
    # Forward
    errG_total = 0
    for i, netD in enumerate(netsD):
        features = netD(fake_imgs[i])
        cond_logits = netD.COND_DNET(features, sent_emb)
        cond_errG = nn.BCELoss()(cond_logits, real_labels)
        if netD.UNCOND_DNET is not None:
            logits = netD.UNCOND_DNET(features)
            errG = nn.BCELoss()(logits, real_labels)
            g_loss = errG + cond_errG
        else:
            g_loss = cond_errG
        errG_total += g_loss
        logs += 'g_loss%d: %.2f ' % (i, g_loss.item())

    # Ranking loss, computed once on the images of the last stage. The guard
    # keeps an empty netsD from indexing fake_imgs[-1].
    if numDs > 0:
        last = numDs - 1
        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        region_features, cnn_code = image_encoder(fake_imgs[last])
        w_loss0, w_loss1, _ = words_loss(region_features, words_embs,
                                         match_labels, cap_lens,
                                         class_ids, batch_size)
        w_loss = (w_loss0 + w_loss1) * cfg.TRAIN.SMOOTH.LAMBDA
        s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb,
                                     match_labels, class_ids, batch_size)
        s_loss = (s_loss0 + s_loss1) * cfg.TRAIN.SMOOTH.LAMBDA
        errG_total += w_loss + s_loss
        logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item())
    return errG_total, logs
| ################################################################## | |
def KL_loss(mu, logvar):
    """Gaussian KL term, averaged over all elements.

    Evaluates ``-0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)`` with
    ``sigma^2 = exp(logvar)``. The result is a mean over every element, not
    a per-sample sum. Neither input is modified.
    """
    squared_terms = mu.pow(2) + logvar.exp()
    kld_terms = (1 - squared_terms) + logvar
    return -0.5 * kld_terms.mean()