Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import * | |
| from . import utils | |
| from .CaptionModel import CaptionModel | |
| bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] | |
| bad_endings += ['UNK', 'has', 'and', 'more'] | |
| # torch.manual_seed(42) | |
| # if torch.cuda.is_available(): | |
| # torch.cuda.manual_seed(42) | |
| class ShowTellModel(CaptionModel): | |
| def __init__(self, opt): | |
| super(ShowTellModel, self).__init__() | |
| self.vocab_size = opt.vocab_size | |
| self.input_encoding_size = opt.input_encoding_size | |
| self.rnn_type = opt.rnn_type | |
| self.rnn_size = opt.rnn_size | |
| self.num_layers = opt.num_layers | |
| self.drop_prob_lm = opt.drop_prob_lm | |
| self.seq_length = opt.seq_length | |
| self.fc_feat_size = opt.fc_feat_size | |
| self.eos_idx = getattr(opt, 'eos_idx', 0) | |
| self.pad_idx = getattr(opt, 'pad_idx', 0) | |
| self.ss_prob = 0.0 # Schedule sampling probability | |
| self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) | |
| self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) | |
| self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) | |
| self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) | |
| self.dropout = nn.Dropout(self.drop_prob_lm) | |
| # For remove bad endding | |
| self.vocab = opt.vocab | |
| self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings] | |
| self.init_weights() | |
| def init_weights(self): | |
| initrange = 0.1 | |
| self.embed.weight.data.uniform_(-initrange, initrange) | |
| self.logit.bias.data.fill_(0) | |
| self.logit.weight.data.uniform_(-initrange, initrange) | |
| def init_hidden(self, bsz): | |
| weight = self.logit.weight | |
| if self.rnn_type == 'lstm': | |
| return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), | |
| weight.new_zeros(self.num_layers, bsz, self.rnn_size)) | |
| else: | |
| return weight.new_zeros(self.num_layers, bsz, self.rnn_size) | |
| def _forward(self, fc_feats, att_feats, seq, att_masks=None): | |
| batch_size = fc_feats.size(0) | |
| if seq.ndim == 3: # B * seq_per_img * seq_len | |
| seq = seq.reshape(-1, seq.shape[2]) | |
| seq_per_img = seq.shape[0] // batch_size | |
| state = self.init_hidden(batch_size*seq_per_img) | |
| outputs = [] | |
| if seq_per_img > 1: | |
| fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) | |
| for i in range(seq.size(1)+1): | |
| if i == 0: | |
| xt = self.img_embed(fc_feats) | |
| else: | |
| if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample | |
| sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) | |
| sample_mask = sample_prob < self.ss_prob | |
| if sample_mask.sum() == 0: | |
| it = seq[:, i-1].clone() | |
| else: | |
| sample_ind = sample_mask.nonzero().view(-1) | |
| it = seq[:, i-1].data.clone() | |
| #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) | |
| #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) | |
| prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) | |
| it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) | |
| else: | |
| it = seq[:, i-1].clone() | |
| # break if all the sequences end | |
| if i >= 2 and seq[:, i-1].data.sum() == 0: | |
| break | |
| xt = self.embed(it) | |
| output, state = self.core(xt.unsqueeze(0), state) | |
| output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) | |
| outputs.append(output) | |
| return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() | |
| def get_logprobs_state(self, it, state): | |
| # 'it' contains a word index | |
| xt = self.embed(it) | |
| output, state = self.core(xt.unsqueeze(0), state) | |
| logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) | |
| return logprobs, state | |
| def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): | |
| # beam_size = opt.get('beam_size', 10) | |
| # batch_size = fc_feats.size(0) | |
| # assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' | |
| # seq = torch.LongTensor(self.seq_length, batch_size).zero_() | |
| # seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) | |
| # # lets process every image independently for now, for simplicity | |
| beam_size = opt.get('beam_size', 10) | |
| group_size = opt.get('group_size', 1) | |
| sample_n = opt.get('sample_n', 10) | |
| # when sample_n == beam_size then each beam is a sample. | |
| assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' | |
| batch_size = fc_feats.size(0) | |
| assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' | |
| seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) | |
| seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) | |
| self.done_beams = [[] for _ in range(batch_size)] | |
| for k in range(batch_size): | |
| state = self.init_hidden(beam_size) | |
| for t in range(2): | |
| if t == 0: | |
| xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) | |
| elif t == 1: # input <bos> | |
| it = fc_feats.data.new(beam_size).long().zero_() | |
| xt = self.embed(it) | |
| output, state = self.core(xt.unsqueeze(0), state) | |
| logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) | |
| self.done_beams[k] = self.old_beam_search(state, logprobs, opt=opt) | |
| if sample_n == beam_size: | |
| for _n in range(sample_n): | |
| seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq'] | |
| seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps'] | |
| else: | |
| seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score | |
| seqLogprobs[k, :] = self.done_beams[k][0]['logps'] | |
| # return the samples and their log likelihoods | |
| return seq, seqLogprobs | |
| # seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score | |
| # seqLogprobs[:, k] = self.done_beams[k][0]['logps'] | |
| # # return the samples and their log likelihoods | |
| # return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) | |
| def _new_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): | |
| beam_size = opt.get('beam_size', 10) | |
| group_size = opt.get('group_size', 1) | |
| sample_n = opt.get('sample_n', 10) | |
| # when sample_n == beam_size then each beam is a sample. | |
| assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' | |
| batch_size = fc_feats.size(0) | |
| assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' | |
| seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) | |
| seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) | |
| self.done_beams = [[] for _ in range(batch_size)] | |
| state = self.init_hidden(batch_size) | |
| it = fc_feats.data.new(batch_size).long().zero_() | |
| xt = self.embed(it) | |
| output, state = self.core(xt.unsqueeze(0), state) | |
| logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) | |
| self.done_beams = self.beam_search(state, logprobs, opt=opt) | |
| for k in range(batch_size): | |
| if sample_n == beam_size: | |
| for _n in range(sample_n): | |
| seq_len = self.done_beams[k][_n]['seq'].shape[0] | |
| seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq'] | |
| seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps'] | |
| else: | |
| seq_len = self.done_beams[k][0]['seq'].shape[0] | |
| seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score | |
| seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] | |
| # return the samples and their log likelihoods | |
| return seq, seqLogprobs | |
| def _old_sample(self, fc_feats, att_feats, att_masks=None, opt={}): | |
| sample_method = opt.get('sample_method', 'greedy') | |
| beam_size = opt.get('beam_size', 1) | |
| temperature = opt.get('temperature', 1.0) | |
| if beam_size > 1 and sample_method in ['greedy', 'beam_search']: | |
| return self._sample_beam(fc_feats, att_feats, opt) | |
| batch_size = fc_feats.size(0) | |
| state = self.init_hidden(batch_size) | |
| seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) | |
| seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) | |
| for t in range(self.seq_length + 2): | |
| if t == 0: | |
| xt = self.img_embed(fc_feats) | |
| else: | |
| if t == 1: # input <bos> | |
| it = fc_feats.data.new(batch_size).long().zero_() | |
| xt = self.embed(it) | |
| output, state = self.core(xt.unsqueeze(0), state) | |
| logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) | |
| # sample the next word | |
| if t == self.seq_length + 1: # skip if we achieve maximum length | |
| break | |
| if sample_method == 'greedy': | |
| sampleLogprobs, it = torch.max(logprobs.data, 1) | |
| it = it.view(-1).long() | |
| else: | |
| if temperature == 1.0: | |
| prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) | |
| else: | |
| # scale logprobs by temperature | |
| prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() | |
| it = torch.multinomial(prob_prev, 1).to(logprobs.device) | |
| sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions | |
| it = it.view(-1).long() # and flatten indices for downstream processing | |
| if t >= 1: | |
| # stop when all finished | |
| if t == 1: | |
| unfinished = it > 0 | |
| else: | |
| unfinished = unfinished & (it > 0) | |
| it = it * unfinished.type_as(it) | |
| seq[:,t-1] = it #seq[t] the input of t+2 time step | |
| seqLogprobs[:,t-1] = sampleLogprobs.view(-1) | |
| if unfinished.sum() == 0: | |
| break | |
| return seq, seqLogprobs | |
| # remove bad endings and UNK | |
| def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): | |
| sample_method = opt.get('sample_method', 'greedy') | |
| beam_size = opt.get('beam_size', 1) | |
| temperature = opt.get('temperature', 1.0) | |
| sample_n = int(opt.get('sample_n', 1)) | |
| sample_n = 1 | |
| group_size = opt.get('group_size', 1) | |
| output_logsoftmax = opt.get('output_logsoftmax', 1) | |
| decoding_constraint = opt.get('decoding_constraint', 0) | |
| block_trigrams = opt.get('block_trigrams', 0) | |
| remove_bad_endings = opt.get('remove_bad_endings', 1) | |
| suppress_UNK = opt.get('suppress_UNK', 1) | |
| if beam_size > 1 and sample_method in ['greedy', 'beam_search']: | |
| return self._sample_beam(fc_feats, att_feats, opt=opt) | |
| batch_size = fc_feats.size(0) | |
| state = self.init_hidden(batch_size) | |
| trigrams = [] # will be a list of batch_size dictionaries | |
| # seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) | |
| # seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) | |
| seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) | |
| seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) | |
| for t in range(self.seq_length + 1): | |
| if t == 0: | |
| xt = self.img_embed(fc_feats) | |
| else: | |
| if t == 1: # input <bos> | |
| it = fc_feats.data.new(batch_size).long().zero_() | |
| xt = self.embed(it) | |
| output, state = self.core(xt.unsqueeze(0), state) | |
| logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) | |
| if decoding_constraint and t > 0: | |
| tmp = logprobs.new_zeros(logprobs.size()) | |
| tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) | |
| logprobs = logprobs + tmp | |
| # print('seq', seq) | |
| # print('self.seq_length',self.seq_length) | |
| # print('seq shape', seq.shape) | |
| if remove_bad_endings and t > 0: | |
| logprobs[torch.from_numpy(np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf') | |
| # suppress UNK tokens in the decoding | |
| if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK': | |
| logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000 | |
| # if remove_bad_endings and t > 0: | |
| # tmp = logprobs.new_zeros(logprobs.size()) | |
| # prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) | |
| # # Make it impossible to generate bad_endings | |
| # tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') | |
| # # tmp[torch.from_numpy(prev_bad.bool()), 0] = float('-inf') | |
| # logprobs = logprobs + tmp | |
| # Mess with trigrams | |
| # Copy from https://github.com/lukemelas/image-paragraph-captioning | |
| if block_trigrams and t >= 3: | |
| # Store trigram generated at last step | |
| prev_two_batch = seq[:,t-3:t-1] | |
| for i in range(batch_size): # = seq.size(0) | |
| prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) | |
| current = seq[i][t-1] | |
| if t == 3: # initialize | |
| trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} | |
| elif t > 3: | |
| if prev_two in trigrams[i]: # add to list | |
| trigrams[i][prev_two].append(current) | |
| else: # create list | |
| trigrams[i][prev_two] = [current] | |
| # Block used trigrams at next step | |
| prev_two_batch = seq[:,t-2:t] | |
| mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size | |
| for i in range(batch_size): | |
| prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) | |
| if prev_two in trigrams[i]: | |
| for j in trigrams[i][prev_two]: | |
| mask[i,j] += 1 | |
| # Apply mask to log probs | |
| #logprobs = logprobs - (mask * 1e9) | |
| alpha = 2.0 # = 4 | |
| logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) | |
| # sample the next word | |
| if t == self.seq_length+1: # skip if we achieve maximum length | |
| break | |
| it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) | |
| # stop when all finished | |
| if t == 0: | |
| unfinished = it != self.eos_idx | |
| else: | |
| it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 | |
| logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs) | |
| unfinished = unfinished & (it != self.eos_idx) | |
| # print('-------logprobs shape:',logprobs.shape) | |
| # print('-------it shape:',it.shape) | |
| seq[:,t-1] = it | |
| seqLogprobs[:,t-1] = logprobs | |
| # quit loop if all sequences have finished | |
| if unfinished.sum() == 0: | |
| break | |
| # print('-------seqLogprobs shape:',seqLogprobs.shape) | |
| # print('-------seq shape:',seq.shape) | |
| return seq, seqLogprobs |