import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

from ...converter import AttnLabelConverter as ATTN
# LuongAttention, LocationAwareAttention and BahdanauAttention are expected
# to come from this star import.
from .addon_module import *


class Attention(nn.Module):
    def __init__(self,
                 kernel_size,
                 kernel_dim,
                 input_size,
                 hidden_size,
                 num_classes,
                 embed_dim=None,
                 attn_type='coverage',
                 embed_target=False,
                 enc_init=False,  # init the decoder hidden state from the encoder output
                 teacher_forcing=1.0,
                 droprate=0.1,
                 method='concat',
                 seqmodel='ViT',
                 viz_attn: bool = False,
                 device='cuda'):
        super(Attention, self).__init__()
        if embed_dim is None:
            embed_dim = input_size
        if embed_target:
            self.embedding = nn.Embedding(num_classes, embed_dim, padding_idx=ATTN.START())
        common = {
            'input_size': input_size,
            'hidden_size': hidden_size,
            'num_embeddings': embed_dim if embed_target else num_classes,
            'num_classes': num_classes
        }
        if attn_type == 'luong':
            common['method'] = method
            self.attention_cell = LuongAttention(**common)
        elif attn_type in ('loc_aware', 'coverage'):
            # Both variants share the same cell; they differ only in what is
            # fed back through set_mem() during decoding.
            self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common)
        else:
            self.attention_cell = BahdanauAttention(**common)
        self.dropout = nn.Dropout(droprate)
        self.embed_target = embed_target
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_classes = num_classes
        self.teacher_forcing = teacher_forcing
        self.device = device
        self.attn_type = attn_type
        self.enc_init = enc_init
        self.viz_attn = viz_attn
        self.seqmodel = seqmodel
        if enc_init:
            self.init_hidden()

    def _embed_text(self, input_char):
        return self.embedding(input_char)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        # (batch,) int64 char indices -> (batch, onehot_dim) one-hot floats.
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.zeros(batch_size, onehot_dim, dtype=torch.float32, device=self.device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot
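
    # Note: given the original (batch,) index tensor, the scatter above is
    # equivalent to F.one_hot(input_char, num_classes=onehot_dim).float()
    # (a sketch, assuming all indices lie in [0, onehot_dim)).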

    def init_hidden(self):
        # Projections that map an encoder summary vector to the decoder's
        # initial LSTM state; only created when enc_init=True.
        self.proj_init_h = nn.Linear(self.input_size, self.hidden_size, bias=True)
        self.proj_init_c = nn.Linear(self.input_size, self.hidden_size, bias=True)

    def forward_beam(self,
                     batch_H: torch.Tensor,
                     batch_max_length=25,
                     beam_size=4):
        batch_size = batch_H.size(0)
        assert batch_size == 1, 'beam search decodes one image at a time'
        num_steps = batch_max_length + 1  # +1 for the [END] token
        # Replicate the single encoder sequence across the beam dimension.
        batch_H = batch_H.squeeze(dim=0)
        batch_H = repeat(batch_H, 's e -> b s e', b=beam_size)
        if self.enc_init:
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
            else:
                init_embedding = batch_H[:, 0, :]
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device))
        if self.attn_type == 'coverage':
            # Running sum of attention weights, used as the coverage vector.
            alpha_cum = torch.zeros(beam_size, batch_H.shape[1], 1, dtype=torch.float32, device=self.device)
            self.attention_cell.reset_mem()
        k_prev_words = torch.full((beam_size, 1), ATTN.START(), dtype=torch.long, device=self.device)
        seqs = k_prev_words
        targets = k_prev_words.squeeze(dim=-1)
        top_k_scores = torch.zeros(beam_size, 1, device=self.device)
        if self.viz_attn:
            seqs_alpha = torch.ones(beam_size, 1, batch_H.shape[1], device=self.device)
        complete_seqs = list()
        if self.viz_attn:
            complete_seqs_alpha = list()
        complete_seqs_scores = list()
        for step in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, batch_H, embed_text)
            output = self.dropout(output)
            vocab_size = output.shape[1]
            scores = F.log_softmax(output, dim=-1)
            # Add the running beam scores so topk ranks full hypotheses.
            scores = top_k_scores.expand_as(scores) + scores
            if step == 0:
                # All beams start identical, so rank within the first row only.
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)
            # Recover (source beam, next character) from the flattened indices.
            prev_word_inds = torch.div(top_k_words, vocab_size, rounding_mode='floor')
            next_word_inds = top_k_words % vocab_size
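            # Worked example of the flat-index arithmetic (illustrative values,
            # not from the model): with vocab_size = 38, a flat topk index of 43
            # decodes to beam 43 // 38 = 1 and character 43 % 38 = 5, i.e.
            # beam 1 gets extended with class 5.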
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            if self.viz_attn:
                seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)],
                                       dim=1)
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds)
                               if next_word != ATTN.END()]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
            # Move finished hypotheses out of the beam.
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                if self.viz_attn:
                    complete_seqs_alpha.extend(seqs_alpha[complete_inds])
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            beam_size = beam_size - len(complete_inds)
            if beam_size == 0:
                break
            # Reorder every piece of beam state with the same permutation.
            seqs = seqs[incomplete_inds]
            if self.viz_attn:
                seqs_alpha = seqs_alpha[incomplete_inds]
            hidden = (hidden[0][prev_word_inds[incomplete_inds]],
                      hidden[1][prev_word_inds[incomplete_inds]])
            batch_H = batch_H[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            targets = next_word_inds[incomplete_inds]
            if self.attn_type == 'coverage':
                # The attention memory lives in old-beam order, so it must be
                # permuted with prev_word_inds, exactly like hidden and batch_H.
                alpha_cum = (alpha_cum + alpha)[prev_word_inds[incomplete_inds]]
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha[prev_word_inds[incomplete_inds]])
        if len(complete_seqs) == 0:
            # No hypothesis ever produced [END]: fall back to the best live beam.
            seq = seqs[0][1:].tolist()
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = top_k_scores[0]
            if self.viz_attn:
                alphas = seqs_alpha[0][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
        else:
            combine_lst = tuple(zip(complete_seqs, complete_seqs_scores))
            # Pick the completed hypothesis with the best length-normalised
            # log-probability.  https://youtu.be/XXtpJxZBa2c?t=2407
            best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0])))
            seq = complete_seqs[best_ind][1:]  # drop the leading [GO] token
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = complete_seqs_scores[best_ind]
            if self.viz_attn:
                alphas = complete_seqs_alpha[best_ind][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
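
    # Return contract of forward_beam: `seq` is a (1, L) LongTensor without the
    # leading [GO] token, `score` the hypothesis' cumulative log-probability,
    # and `alphas` the per-step attention maps when viz_attn is on, else None.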

    def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25):
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1  # +1 for the [END] token
        encoder_hidden = batch_H
        if self.enc_init:
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
            else:
                init_embedding = batch_H[:, 0, :]  # first token (e.g. a ViT class token)
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device))
        targets = torch.full((batch_size,), ATTN.START(), dtype=torch.long, device=self.device)  # [GO] token
        probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device)
        if self.viz_attn:
            self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
            self.attention_cell.reset_mem()
        if is_test:
            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        for i in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            if self.viz_attn:
                self.alpha_stores[:, i] = alpha
            if self.attn_type == 'coverage':
                # Feed the cumulative attention back as the coverage memory.
                alpha_cum = alpha_cum + alpha
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                # Location-aware attention conditions on the previous step's weights.
                self.attention_cell.set_mem(alpha)
            probs_step = output
            probs[:, i, :] = probs_step
            if i == num_steps - 1:
                break
            if is_train:
                # Scheduled sampling: with probability `teacher_forcing`, feed the
                # ground-truth next character; otherwise feed the model's argmax.
                if self.teacher_forcing < random.random():
                    _, next_input = probs_step.max(1)
                    targets = next_input
                else:
                    targets = text[:, i + 1]
            else:
                _, next_input = probs_step.max(1)
                targets = next_input
                if is_test:
                    # Stop early once every sequence in the batch emitted [END].
                    end_flag = end_flag | (next_input == ATTN.END())
                    if end_flag.all():
                        break
        _, preds_index = probs.max(2)
        return preds_index, probs, None  # probs: batch_size x num_steps x num_classes

    def forward(self, beam_size, batch_H, text, batch_max_length, is_train=True, is_test=False):
        # Training always decodes greedily (with teacher forcing); at inference,
        # beam search is used when beam_size > 1.
        if is_train:
            return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
        else:
            if beam_size > 1:
                return self.forward_beam(batch_H, batch_max_length, beam_size)
            else:
                return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
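

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the model. The dimensions and the
    # attention-cell kwargs here are assumptions for illustration (256-d encoder
    # features, 38 classes); run with `python -m <package>.<module>` so the
    # relative imports resolve.
    decoder = Attention(kernel_size=3, kernel_dim=10, input_size=256,
                        hidden_size=256, num_classes=38, attn_type='coverage',
                        device='cpu')
    batch_H = torch.randn(2, 26, 256)            # (batch, seq_len, input_size)
    text = torch.zeros(2, 27, dtype=torch.long)  # [GO] + batch_max_length slots
    preds_index, probs, _ = decoder(beam_size=1, batch_H=batch_H, text=text,
                                    batch_max_length=26, is_train=True)
    print(preds_index.shape, probs.shape)  # (2, 27) and (2, 27, 38)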