| |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from torch.autograd import Variable |
| from uniperceiver.config import configurable |
| from uniperceiver.functional import expand_tensor |
| from .decode_strategy import DecodeStrategy |
| from .build import DECODE_STRATEGY_REGISTRY |
| from uniperceiver.utils import comm |
| import math |
| from torch.cuda.amp import autocast |
|
|
@DECODE_STRATEGY_REGISTRY.register()
class CaptionBeamSearcherV3(DecodeStrategy):
    """Beam-search decode strategy for autoregressive caption generation.

    The image is encoded once (with per-layer hidden states kept as the
    attention "history" prefix), the vocabulary embeddings are encoded once
    and cached across calls, and tokens are then generated one position at a
    time.  Per-layer hidden states live in pre-allocated buffers that are
    re-gathered in place whenever the beam ordering changes.

    NOTE(review): relies on attributes presumably set up by ``DecodeStrategy``
    (``self.fp16``, ``self.beam_size``, ``self.max_seq_len``,
    ``self.spe_token_id``, ``self.eos_token_id``, ``self.len_penalty``,
    ``self.vocab_name``, ``self.cfg``) — confirm against the base class.
    """

    def data_half(self, data):
        """Cast every float32 tensor value in ``data`` to float16 when fp16
        inference is enabled; otherwise leave ``data`` untouched.

        The cast replaces values in the dict in place; the same dict object
        is returned in both branches.
        """
        if self.fp16:
            for k, v in data.items():
                if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
                    data[k] = v.half()

            return data
        else:
            return data

    def _select(self, batch_size, beam_size, t, candidate_logprob):
        """Pick the top ``beam_size`` candidates per batch element.

        ``candidate_logprob`` is flattened over its (beam, vocab) axes and
        sorted descending; the returned flat indices are decoded back into
        (source beam, word id) by the caller.

        NOTE(review): ``t`` is unused here — presumably kept for signature
        compatibility with sibling searcher variants; confirm before removing.
        """
        selected_logprob, selected_idx = torch.sort(candidate_logprob.view(batch_size, -1), -1, descending=True)
        selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]
        return selected_idx, selected_logprob

    def _expand_state(self, states, selected_beam, batch_size, beam_size, cur_beam_size):
        """Reorder every cached state tensor along the beam dimension.

        Each state of shape (batch*cur_beam, ...) is viewed as
        (batch, cur_beam, ...), gathered with ``selected_beam`` (a surviving
        beam may be duplicated), and flattened back to (batch*beam, ...).
        Mutates the ``states`` list in place.
        """
        for i in range(len(states)):
            shape = list(states[i].shape)
            beam = selected_beam
            # Unsqueeze the beam index until it can broadcast over the
            # state's trailing dimensions.
            for _ in shape[1:]:
                beam = beam.unsqueeze(-1)
            states[i] = torch.gather(states[i].view(*([batch_size, cur_beam_size] + shape[1:])), 1,
                                     beam.expand(*([batch_size, beam_size] + shape[1:])))
            states[i] = states[i].view(*([-1, ] + shape[1:]))

    def _forward(self, batched_inputs, model):
        """Generate captions for ``batched_inputs`` with beam search.

        Returns a dict with:
          - ``"IDS"``: sample ids gathered from the inputs' ``sample_info``;
          - ``"G_SENTS_IDS"``: generated token ids, (bs, out_size, max_seq_len),
            squeezed to (bs, max_seq_len) when ``OUT_SIZE`` is 1;
          - ``"G_LOGP"``: the matching per-token log-probabilities.
        """
        inputs = batched_inputs
        inputs = self.data_half(inputs)

        # Number of top beams to return per sample (default: best beam only).
        out_size = batched_inputs.get('OUT_SIZE', 1)

        task_info = inputs['task_info']
        bs = task_info['batch_size']
        if isinstance(bs, torch.Tensor):
            bs = bs.item()

        image_input = inputs['input_sample_list']
        vocab_input = inputs['shared_target_sets'][self.vocab_name]

        # Pick up the MoE routing embedding from the image stream, if any,
        # so the per-step text forwards below reuse the same routing.
        # NOTE(review): if several entries carry 'moe_embedding', the last
        # one wins — confirm that is intended.
        moe_embedding = None
        for image_data in image_input:
            if 'moe_embedding' in image_data:
                moe_embedding = image_data['moe_embedding']
        # Encode the image once; return_all=True yields per-layer hidden
        # states, which become the (prefix) history for the decoder steps.
        image_encode = model._forward_data(image_input, task_info=task_info, return_all=True)[0]['data']

        # The vocabulary encoding is input-independent, so compute it once
        # and cache it on the strategy instance across calls.
        if getattr(self, 'pre_computed_word_embeds', None) is None:
            vocab_encode = model._forward_data(vocab_input, task_info=task_info, return_all=False)[0]
            self.pre_computed_word_embeds = vocab_encode
        else:
            vocab_encode = self.pre_computed_word_embeds

        # Global flag switching the model into caption-generation mode.
        # NOTE(review): restored at the end of this method, but not under
        # try/finally — an exception mid-decode leaves it set.
        comm._CAPTION_GEN_MODE = True
        task_info.update({"prefix_spe_before_fuse": False})

        beam_size = self.beam_size
        log_probs = []
        selected_words = None
        # Cumulative log-prob per beam, and a mask that stays 1 while a beam
        # is alive (has not emitted EOS yet).
        seq_logprob = torch.zeros((bs, 1, 1)).cuda()
        seq_mask = torch.ones((bs, beam_size, 1)).cuda()
        # First input token: the special (BOS-like) token id for every sample.
        # NOTE(review): Variable() is a deprecated no-op wrapper in modern PyTorch.
        wt = Variable(torch.zeros(bs, dtype=torch.long).cuda().unsqueeze(1)) + self.spe_token_id
        u_tokens_type = wt.new_zeros(wt.shape)

        # Per-layer hidden states of the image prefix (last entry excluded);
        # pre-allocate history buffers sized for prefix + max_seq_len steps
        # across all beams, and copy the prefix in.
        history_states = image_encode[:-1]
        len_prefix = history_states[0].shape[1]
        total_history_states = [ history_states[0].new_zeros(beam_size * bs, image_encode[0].shape[1] + self.max_seq_len, image_encode[0].shape[2]) for _ in history_states]
        for i, ths in enumerate(total_history_states):
            hs = history_states[i]
            ths[:hs.shape[0], :hs.shape[1], :] = hs

        outputs = []
        # Fields shared by every per-step input dict.
        common_info = {
            "modality": "text",
            'data_type': 'input',
            'moe_embedding': moe_embedding,
        }
        for t in range(self.max_seq_len):
            # Before the first step there is only one (identical) beam.
            cur_beam_size = 1 if t == 0 else beam_size

            # Live views of the history buffers: current beam rows, prefix
            # plus the t tokens generated so far.
            history_states = [ ths[ :(cur_beam_size*bs), :(len_prefix+t), :] for ths in total_history_states]

            step_data = { "data": wt,
                          "time_step": t,
                          "sample_info":
                              {
                                  "data_cum_length": [1, len_prefix, len_prefix+t+1]
                              }
            }
            step_data.update(common_info)

            # One decoder step for the current token, attending to history.
            step_encode = model._forward_data([step_data], task_info=task_info, history_states=history_states, return_all=False)

            step_predictor_input = {
                "input_sample_list": step_encode,
                "target_sample_list": [],
                "shared_target_sets": {self.vocab_name: [vocab_encode]},
                "target_set_list": [self.vocab_name],
                "target_idx_list": [],
                "task_info": task_info
            }
            logit = model.loss_prepare(**step_predictor_input)['output']

            # Softmax precision is config-controlled: autocast is disabled
            # when fp16 softmax is forced.
            with autocast(enabled=not self.cfg.SOLVER.FORCE_SOFTMAX_FP16):
                word_logprob = F.log_softmax(logit, dim=-1)
            word_logprob = word_logprob.view(bs, cur_beam_size, -1)
            candidate_logprob = seq_logprob + word_logprob

            if t > 0:
                # Retire beams that emitted EOS at the previous step: zero
                # their new token log-probs and pin their candidates so only
                # the EOS continuation keeps the finished beam's score
                # (all other words are forced to -999).
                mask = (selected_words.view(bs, cur_beam_size) != self.eos_token_id).float().unsqueeze(-1)
                seq_mask = seq_mask * mask
                word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
                old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
                old_seq_logprob[:, :, :self.eos_token_id] = -999
                old_seq_logprob[:, :, self.eos_token_id + 1:] = -999
                candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)

            # Top-k over the flattened (beam, vocab) axis, then decode each
            # flat index back into (source beam, word id).
            selected_idx, selected_logprob = self._select(bs, beam_size, t, candidate_logprob)
            selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode='floor')
            selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]

            # Reorder the cached history (in place) to follow the survivors.
            self._expand_state(history_states, selected_beam, bs, beam_size, cur_beam_size)

            seq_logprob = selected_logprob.unsqueeze(-1)
            seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
            # Re-gather previously emitted tokens along the new beam order.
            outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
            outputs.append(selected_words.unsqueeze(-1))

            # Log-prob of the word each surviving beam just emitted.
            this_word_logprob = torch.gather(word_logprob, 1,
                                             selected_beam.unsqueeze(-1).expand(bs, beam_size, word_logprob.shape[-1]))
            this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
            log_probs = list(
                torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(bs, beam_size, 1)) for o in log_probs)
            log_probs.append(this_word_logprob)
            selected_words = selected_words.view(-1, 1)

            if t == 0:
                # From step 1 onward every sample carries beam_size rows.
                u_tokens_type = expand_tensor(u_tokens_type, beam_size)
                wt = expand_tensor(wt, beam_size)

            # Re-run the step for the words actually selected, this time with
            # return_all=True, to obtain the per-layer hidden states appended
            # to the history buffers for the next iteration.
            step_selected_data = {"data": selected_words, "time_step": t, "sample_info": {"data_cum_length": [1, len_prefix, len_prefix+t+1]}}
            step_selected_data.update(common_info)

            step_selected_encode = model._forward_data([step_selected_data], task_info=task_info, history_states=history_states, return_all=True)

            # Write the re-ordered history plus the new position back into
            # the pre-allocated buffers.
            for i, ths in enumerate(total_history_states):
                hs = history_states[i]
                ths[:hs.shape[0], :hs.shape[1], :] = hs
                ths[:hs.shape[0], hs.shape[1], :] = step_selected_encode[0]['data'][i].squeeze(1)

        outputs = torch.cat(outputs, -1)

        # Optional length penalty: divide each beam's score by
        # (generated length) ** len_penalty, then rank beams by final score.
        if self.len_penalty > 0:
            step = outputs.ne(self.eos_token_id).sum(-1, keepdim=True) + 1
            seq_logprob /= step ** self.len_penalty
        seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)

        outputs = torch.gather(outputs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))
        log_probs = torch.cat(log_probs, -1)
        log_probs = torch.gather(log_probs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))

        # Keep only the requested number of beams per sample.
        outputs = outputs.contiguous()[:, :out_size]
        log_probs = log_probs.contiguous()[:, :out_size]
        if out_size == 1:
            outputs = outputs.squeeze(1)
            log_probs = log_probs.squeeze(1)

        # Leave caption-generation mode (see NOTE above re: try/finally).
        comm._CAPTION_GEN_MODE = False

        ids = torch.stack([torch.tensor(v['id']) for v in inputs['input_sample_list'][0]['sample_info']])

        return {
            "IDS": ids,
            "G_SENTS_IDS": outputs,
            "G_LOGP": log_probs
        }
|
|