import torch
from torch import nn
from torch.nn import functional as F

from libs.utils.metric import CellMergeAcc, AccMetric
from .utils import gen_proposals


class ImageAttention(nn.Module):
    def __init__(self, key_dim, query_dim, cover_kernel):
        super().__init__()
        self.query_transform = nn.Linear(query_dim, key_dim)
        self.weight_transform = nn.Conv2d(1, key_dim, cover_kernel, 1, padding=cover_kernel // 2)
        self.cum_weight_transform = nn.Conv2d(1, key_dim, cover_kernel, 1, padding=cover_kernel // 2)
        self.logit_transform = nn.Conv2d(key_dim, 1, 1, 1, 0)

    def forward(self, key, key_mask, query, spatial_att_weight, cum_spatial_att_weight,
                value, state, layouts=None, layouts_cum=None, spatial_att_weight_scores=None):
        query = self.query_transform(query)
        weight_query = self.weight_transform(spatial_att_weight)
        cum_weight_query = self.cum_weight_transform(cum_spatial_att_weight)
        fusion = key + query[:, :, None, None] + weight_query + cum_weight_query

        # attention logit over the feature map
        new_spatial_att_logit = self.logit_transform(torch.tanh(fusion))

        # masked softmax over all spatial positions
        new_spatial_att_weight = new_spatial_att_logit - (1 - key_mask) * 1e8
        bs, _, h, w = new_spatial_att_weight.shape
        new_spatial_att_weight = new_spatial_att_weight.reshape(bs, h * w)
        new_spatial_att_weight = torch.softmax(new_spatial_att_weight, dim=1).reshape(bs, 1, h, w)

        if self.training:
            # pool the value features over the ground-truth cell region
            outputs = list()
            for value_pi, layout in zip(value, layouts):
                h, w = torch.where(layout == 1.)
                if len(h) == 0 or len(w) == 0:
                    outputs.append(torch.zeros_like(query[0]))
                else:
                    outputs.append(value_pi[:, h, w].mean(-1))
            outputs = torch.stack(outputs, dim=0)
            new_cum_spatial_att_weight = torch.clamp(layouts.unsqueeze(1).float() + cum_spatial_att_weight, max=1.)
            return state, outputs, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, None, None
        else:
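            # Inference keeps a small beam of layout hypotheses: a hypothesis
            # whose cumulative attention already covers the whole map is
            # carried through unchanged; otherwise it is expanded with cell
            # proposals generated around the top-left uncovered position.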
            state_list = list()
            outputs_list = list()
            scores_list = list()
            proposals_list = list()
            new_spatial_att_weight_list = list()
            new_cum_spatial_att_weight_list = list()
            layouts_pred = new_spatial_att_logit.squeeze(1).sigmoid()
            for idx, (value_pi, state_pi, layout) in enumerate(zip(value, state, layouts_pred)):
                if cum_spatial_att_weight[idx].min() == 1:
                    # fully covered: keep the hypothesis as-is
                    state_list.append(state_pi)
                    outputs_list.append(torch.zeros_like(query[0]))
                    proposals_list.append(torch.cat((layouts_cum[idx], torch.zeros_like(layout.unsqueeze(0))), dim=0))
                    scores_list.append(spatial_att_weight_scores[idx])
                    new_spatial_att_weight_list.append(new_spatial_att_weight[idx])
                    new_cum_spatial_att_weight_list.append(cum_spatial_att_weight[idx])
                else:
                    # top-left uncovered position seeds the cell proposals
                    srow, scol = torch.where(cum_spatial_att_weight[idx].squeeze(0) == cum_spatial_att_weight[idx].squeeze(0).min())
                    scol = scol[srow == srow.min()].min()
                    srow = srow.min()
                    proposals, scores = gen_proposals(layout, srow, scol, score_threshold=0.5)
                    scores = scores + spatial_att_weight_scores[idx]
                    for s in scores:
                        scores_list.append(s)
                    for p in proposals:
                        proposals_list.append(torch.cat((layouts_cum[idx], p.unsqueeze(0)), dim=0))
                        h, w = torch.where(p == 1.)
                        outputs_list.append(value_pi[:, h, w].mean(-1))
                        state_list.append(state_pi)
                        new_spatial_att_weight_list.append(new_spatial_att_weight[idx])
                        new_cum_spatial_att_weight_list.append(torch.clamp(cum_spatial_att_weight[idx] + p.unsqueeze(0), max=1.))
            state_list = torch.stack(state_list, dim=0)
            proposals_list = torch.stack(proposals_list, dim=0)
            scores_list = torch.stack(scores_list, dim=0)
            outputs_list = torch.stack(outputs_list, dim=0)
            new_spatial_att_weight_list = torch.stack(new_spatial_att_weight_list, dim=0)
            new_cum_spatial_att_weight_list = torch.stack(new_cum_spatial_att_weight_list, dim=0)
            # keep the 6 best-scoring hypotheses (beam width)
            sorted_scores, sorted_idxes = torch.sort(scores_list, dim=0, descending=True)
            sorted_scores = sorted_scores[:6]
            sorted_idxes = sorted_idxes[:6]
            proposals = proposals_list[sorted_idxes]
            new_spatial_att_weight = new_spatial_att_weight_list[sorted_idxes]
            new_cum_spatial_att_weight = new_cum_spatial_att_weight_list[sorted_idxes]
            outputs = outputs_list[sorted_idxes]
            state = state_list[sorted_idxes]
            return state, outputs, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, proposals, sorted_scores


class Decoder(nn.Module):
    def __init__(self, vocab, embed_dim, feat_dim, lm_state_dim, proj_dim, cover_kernel,
                 att_threshold, spatial_att_logit_loss_weight):
        super().__init__()
        self.vocab = vocab
        self.embed_dim = embed_dim
        self.feat_dim = feat_dim
        self.lm_state_dim = lm_state_dim
        self.proj_dim = proj_dim
        self.cover_kernel = cover_kernel
        self.att_threshold = att_threshold
        self.spatial_att_logit_loss_weight = spatial_att_logit_loss_weight
        self.feat_projection = nn.Conv2d(self.feat_dim, self.proj_dim, 1, 1, 0)
        self.state_init_projection = nn.Conv2d(self.feat_dim, self.lm_state_dim, 1, 1, 0)
        self.lm_rnn1 = nn.GRUCell(input_size=self.feat_dim, hidden_size=self.lm_state_dim)
        self.lm_rnn2 = nn.GRUCell(input_size=self.feat_dim, hidden_size=self.lm_state_dim)
        self.image_attention = ImageAttention(self.proj_dim, self.feat_dim + self.lm_state_dim, cover_kernel)
        self.struct_cls = nn.Sequential(
            nn.Linear(self.feat_dim + self.lm_state_dim, self.lm_state_dim),
            nn.Tanh(),
            nn.Linear(self.lm_state_dim, len(self.vocab))
        )

    def init_state(self, feats, feats_mask):
        bs, _, h, w = feats.shape
        project_feats = self.feat_projection(feats) * feats_mask
        init_state = torch.sum(self.state_init_projection(feats), dim=(2, 3)) / torch.sum(feats_mask, dim=(2, 3))
        init_context = torch.sum(feats, dim=(2, 3)) / torch.sum(feats_mask, dim=(2, 3))
        init_spatial_att_weight = torch.zeros([bs, 1, h, w], dtype=torch.float, device=feats.device)
        init_cum_spatial_att_weight = torch.zeros([bs, 1, h, w], dtype=torch.float, device=feats.device)
        return project_feats, init_state, init_context, init_spatial_att_weight, init_cum_spatial_att_weight

    def step(self, feats, project_feats, feats_mask, state, context, spatial_att_weight,
             cum_spatial_att_weight, layouts=None, layouts_cum=None, spatial_att_weight_scores=None):
        new_state = self.lm_rnn1(context, state)
        new_state, new_context, new_spatial_att_logit, \
            new_spatial_att_weight, new_cum_spatial_att_weight, \
            layouts_cum, spatial_att_weight_scores = self.image_attention(
                project_feats, feats_mask,
                torch.cat([context, new_state], dim=1),
                spatial_att_weight, cum_spatial_att_weight,
                feats, new_state, layouts, layouts_cum, spatial_att_weight_scores
            )
        new_state = self.lm_rnn2(new_context, new_state)
        cls_feat = torch.cat([new_context, new_state], dim=1)
        cls_logits_pt = self.struct_cls(cls_feat)
        return cls_logits_pt, new_state, new_context, new_spatial_att_logit, new_spatial_att_weight, \
            new_cum_spatial_att_weight, layouts_cum, spatial_att_weight_scores
    def forward(self, feats, feats_mask, cls_labels=None, labels_mask=None, layouts=None):
        if self.training:
            return self.forward_backward(feats, feats_mask, cls_labels, labels_mask, layouts)
        else:
            return self.inference(feats, feats_mask)

    def inference(self, feats, feats_mask):
        bs, _, h, w = feats.shape
        device = feats.device
        assert bs == 1, 'bs should be 1'
        layouts_cum = torch.zeros_like(feats[:, :1])
        spatial_att_weight_scores = torch.zeros(bs, device=device, dtype=feats.dtype)
        project_feats, init_state, init_context, spatial_att_weight, cum_spatial_att_weight = \
            self.init_state(feats, feats_mask)
        state = init_state
        context = init_context
        for _ in range(h * w):
            cls_logits_pt, state, context, spatial_att_logit, spatial_att_weight, \
                cum_spatial_att_weight, layouts_cum, spatial_att_weight_scores = self.step(
                    feats, project_feats, feats_mask, state, context,
                    spatial_att_weight, cum_spatial_att_weight,
                    None, layouts_cum, spatial_att_weight_scores)
            # broadcast the (batch-1) inputs to the current beam size
            feats = feats[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
            feats_mask = feats_mask[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
            project_feats = project_feats[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
            if cum_spatial_att_weight.min() == 1:
                break
        # best-scoring hypothesis; drop the all-zero initial layout channel
        spatial_att_logit_preds = layouts_cum[spatial_att_weight_scores.argmax(), 1:].unsqueeze(0)
        return spatial_att_logit_preds, {}

    def forward_backward(self, feats, feats_mask, cls_labels, labels_mask, layouts):
        device = feats.device
        valid_cls_length = torch.sum((labels_mask == 1) & (cls_labels != -1), dim=1).detach()
        valid_spatial_att_logit_length = torch.stack([layout.max() + 1 for layout in layouts])
        max_length = valid_cls_length.max()
        project_feats, init_state, init_context, spatial_att_weight, cum_spatial_att_weight = \
            self.init_state(feats, feats_mask)
        state = init_state
        context = init_context
        loss_cache = dict()
        cls_loss = list()
        cls_preds = list()
        spatial_att_logit_loss = list()
        spatial_att_logit_preds = list()
        spatial_att_logit_masks = list()
        spatial_att_logit_labels = list()
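        # Teacher forcing: at step t the ground-truth cell mask (layouts ==
        # time_t) drives the attention read-out, and the predicted logits are
        # supervised against the same mask.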
        for time_t in range(max_length):
            cls_logits_pt, state, context, spatial_att_logit, spatial_att_weight, cum_spatial_att_weight, *_ = \
                self.step(
                    feats, project_feats, feats_mask, state, context,
                    spatial_att_weight, cum_spatial_att_weight,
                    layouts == time_t
                )
            cls_label = cls_labels[:, time_t]
            label_mask = labels_mask[:, time_t]
            # classification loss
            cls_loss_pt = F.cross_entropy(cls_logits_pt, cls_label, ignore_index=-1, reduction='none') * label_mask
            cls_loss.append(cls_loss_pt)
            # save predictions for accuracy metrics
            cls_preds.append(torch.argmax(cls_logits_pt, dim=1).detach())
            spatial_att_logit_preds.append(spatial_att_logit.sigmoid() > self.att_threshold)
            spatial_att_logit_masks.append((layouts != -1).unsqueeze(1))
            spatial_att_logit_labels.append((layouts == time_t).unsqueeze(1))
            # spatial attention loss
            spatial_att_logit_loss_pt = list()
            for spatial_att_logit_pi, layout in zip(spatial_att_logit, layouts):
                target = layout == time_t
                if not torch.any(target):
                    spatial_att_logit_loss_pt_pi = torch.tensor(0.0, dtype=torch.float, device=device)
                else:
                    mask = (layout != -1).float()
                    spatial_att_logit_loss_pt_pi = F.binary_cross_entropy_with_logits(
                        spatial_att_logit_pi, target.float().unsqueeze(0), reduction='none'
                    )
                    spatial_att_logit_loss_pt_pi = (spatial_att_logit_loss_pt_pi * mask).sum()
                spatial_att_logit_loss_pt.append(spatial_att_logit_loss_pt_pi)
            spatial_att_logit_loss_pt = torch.stack(spatial_att_logit_loss_pt, dim=0)
            spatial_att_logit_loss.append(spatial_att_logit_loss_pt)

        cls_loss = torch.mean(torch.sum(torch.stack(cls_loss, dim=1), dim=1) / valid_cls_length)
        spatial_att_logit_loss = self.spatial_att_logit_loss_weight * torch.mean(
            torch.sum(torch.stack(spatial_att_logit_loss, dim=1), dim=1) / valid_spatial_att_logit_length)
        loss_cache['cls_loss'] = cls_loss
        loss_cache['spatial_att_logit_loss'] = spatial_att_logit_loss
        cls_preds = torch.stack(cls_preds, dim=1)
        spatial_att_logit_preds = torch.stack(spatial_att_logit_preds, dim=1)
        spatial_att_logit_masks = torch.stack(spatial_att_logit_masks, dim=1)
        spatial_att_logit_labels = torch.stack(spatial_att_logit_labels, dim=1)
        acc_metric = AccMetric()
        cell_merge_acc = CellMergeAcc()
        cls_correct, cls_total = acc_metric(cls_preds, cls_labels, labels_mask)
        cls_none_correct, cls_none_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.none_id))
        cls_bold_correct, cls_bold_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.bold_id))
        cls_space_correct, cls_space_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.space_id))
        cls_blank_correct = cls_none_correct + cls_bold_correct + cls_space_correct
        cls_blank_total = cls_none_total + cls_bold_total + cls_space_total
        cells_correct_nums, cells_total_nums = cell_merge_acc(spatial_att_logit_preds, spatial_att_logit_labels, spatial_att_logit_masks)
        loss_cache['cls_acc'] = cls_correct / cls_total
        loss_cache['cls_none_acc'] = cls_none_correct / cls_none_total
        loss_cache['cls_bold_acc'] = cls_bold_correct / cls_bold_total
        loss_cache['cls_space_acc'] = cls_space_correct / cls_space_total
        loss_cache['cls_blank_acc'] = cls_blank_correct / cls_blank_total
        loss_cache['spatial_att_logit_acc'] = cells_correct_nums / cells_total_nums
        return spatial_att_logit_preds, loss_cache


def build_decoder(cfg):
    # att_threshold and spatial_att_logit_loss_weight are assumed cfg fields,
    # since Decoder.__init__ requires them
    decoder = Decoder(
        vocab=cfg.vocab,
        embed_dim=cfg.embed_dim,
        feat_dim=cfg.encode_dim,
        lm_state_dim=cfg.lm_state_dim,
        proj_dim=cfg.proj_dim,
        cover_kernel=cfg.cover_kernel,
        att_threshold=cfg.att_threshold,
        spatial_att_logit_loss_weight=cfg.spatial_att_logit_loss_weight
    )
    return decoder
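

# ---------------------------------------------------------------------------
# Minimal shape-check sketch (an illustrative assumption, not part of the
# training pipeline): it exercises init_state and one teacher-forced decoding
# step with random features and a stand-in vocab (Decoder.__init__ only needs
# len(vocab)). Run it as a module, e.g. `python -m <package>.<this_module>`,
# so that the relative import of gen_proposals resolves.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    decoder = Decoder(
        vocab=['none', 'bold', 'space'],  # hypothetical stand-in vocab
        embed_dim=32, feat_dim=64, lm_state_dim=128, proj_dim=64,
        cover_kernel=3, att_threshold=0.5, spatial_att_logit_loss_weight=1.0,
    )
    decoder.train()
    bs, h, w = 2, 4, 6
    feats = torch.randn(bs, 64, h, w)
    feats_mask = torch.ones(bs, 1, h, w)
    project_feats, state, context, att_weight, cum_att_weight = decoder.init_state(feats, feats_mask)
    # every pixel assigned to cell 0, so `layouts == 0` is the step-0 target
    layouts = torch.zeros(bs, h, w, dtype=torch.long)
    cls_logits, state, context, att_logit, att_weight, cum_att_weight, *_ = decoder.step(
        feats, project_feats, feats_mask, state, context,
        att_weight, cum_att_weight, layouts == 0)
    print(cls_logits.shape)  # expected: [bs, len(vocab)]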