import math

import torch
from torch import nn
from torch.nn import functional as F

from .extractor import RoiPosFeatExtraxtor


class SALayer(nn.Module):
    """Multi-head self-attention with a residual connection.

    Operates on tensors unpacked as ``(batch, channels, length)``. Keys,
    queries and values are produced by 1x1 ``Conv1d`` projections; the
    attended values are added back onto the input.

    Args:
        in_dim: Number of input channels (also the value/output channels).
        att_dim: Number of channels of the key and query projections.
        head_nums: Number of attention heads. Must divide both ``in_dim``
            and ``att_dim`` because every projection is reshaped to
            ``(batch, -1, head_nums, length)`` in ``forward``.
    """

    def __init__(self, in_dim, att_dim, head_nums):
        super().__init__()
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.head_nums = head_nums

        assert self.in_dim % self.head_nums == 0, (
            'in_dim (%d) must be divisible by head_nums (%d)'
            % (self.in_dim, self.head_nums)
        )
        # Keys/queries are split across heads as well; catching this here
        # avoids an opaque reshape error on the first forward pass.
        assert self.att_dim % self.head_nums == 0, (
            'att_dim (%d) must be divisible by head_nums (%d)'
            % (self.att_dim, self.head_nums)
        )

        self.key_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
        self.query_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
        self.value_layer = nn.Conv1d(self.in_dim, self.in_dim, 1, 1, 0)
        # Scaling uses the full att_dim, not the per-head width.
        self.scale = 1 / math.sqrt(self.att_dim)

    def forward(self, feats, masks=None):
        """Apply self-attention along the last axis of ``feats``.

        Args:
            feats: Tensor unpacked as ``(batch, channels, length)``.
            masks: Optional tensor indexed as ``(batch, length)``; entries
                equal to 1 are kept as attention keys, entries equal to 0
                are suppressed.

        Returns:
            ``feats`` plus the attended values, same shape as ``feats``.
        """
        bs, c, n = feats.shape

        # The channel axis is split into (channels_per_head, head_nums).
        head_shape = (bs, -1, self.head_nums, n)
        keys = self.key_layer(feats).reshape(*head_shape)
        querys = self.query_layer(feats).reshape(*head_shape)
        values = self.value_layer(feats).reshape(*head_shape)

        # logits[b, h, k, q]: similarity of key position k and query position q.
        logits = torch.einsum('bchk,bchq->bhkq', keys, querys) * self.scale
        if masks is not None:
            # Push masked key positions towards -inf before the softmax.
            # NOTE(review): 1e8 is outside the float16 range - confirm this
            # layer is not run on half-precision inputs.
            logits = logits - (1 - masks[:, None, :, None]) * 1e8
        # Normalise over the key axis so each query's weights sum to 1.
        weights = torch.softmax(logits, dim=2)

        new_feats = torch.einsum('bchk,bhkq->bchq', values, weights)
        new_feats = new_feats.reshape(bs, -1, n)
        return new_feats + feats


def gen_cells_bbox(row_segments, col_segments, device):
    """Build one box per grid cell from row/column split positions.

    Args:
        row_segments: Per-sample sequences of row boundaries; a sample with
            ``k`` boundaries yields ``k - 1`` rows.
        col_segments: Per-sample sequences of column boundaries, same
            convention as ``row_segments``.
        device: Device on which the returned tensors are created.

    Returns:
        A list with one float tensor of shape ``(num_rows * num_cols, 4)``
        per sample. Cells are enumerated row by row and each box is
        ``[col[c], row[r], col[c + 1], row[r + 1]]`` (presumably
        x1, y1, x2, y2 - confirm against the ROI extractor). A sample
        without any cell yields a ``(0, 4)`` tensor.
    """
    cells_bbox = []
    for row_segments_pi, col_segments_pi in zip(row_segments, col_segments):
        num_rows = len(row_segments_pi) - 1
        num_cols = len(col_segments_pi) - 1
        boxes = [
            [
                col_segments_pi[col_idx],
                row_segments_pi[row_idx],
                col_segments_pi[col_idx + 1],
                row_segments_pi[row_idx + 1],
            ]
            for row_idx in range(num_rows)
            for col_idx in range(num_cols)
        ]
        cells_bbox_pi = torch.tensor(boxes, dtype=torch.float, device=device)
        # Keep the (N, 4) layout even when the grid is empty; an empty list
        # would otherwise produce a 1-D tensor of shape (0,).
        cells_bbox.append(cells_bbox_pi.reshape(-1, 4))
    return cells_bbox


def align_cells_feat(cells_feat, num_rows, num_cols):
    """Arrange per-cell features on a grid and zero-pad to a common size.

    Args:
        cells_feat: Per-sample feature tensors. Each one is transposed on
            its first two axes and reshaped to ``(-1, rows, cols)``
            (presumably ``(rows * cols, channels)`` on input - confirm
            against the ROI extractor).
        num_rows: Number of grid rows of every sample.
        num_cols: Number of grid columns of every sample.

    Returns:
        A tuple ``(aligned_cells_feat, masks)``:
        ``aligned_cells_feat`` stacks the padded grids along a new leading
        batch axis, i.e. ``(batch, channels, max_rows, max_cols)``;
        ``masks`` has shape ``(batch, max_rows, max_cols)``, uses the dtype
        and device of the first feature tensor, and is 1 on real cells and
        0 on padding.
    """
    batch_size = len(cells_feat)
    dtype = cells_feat[0].dtype
    device = cells_feat[0].device

    max_row_nums = max(num_rows)
    max_col_nums = max(num_cols)

    aligned_cells_feat = []
    masks = torch.zeros(
        [batch_size, max_row_nums, max_col_nums], dtype=dtype, device=device
    )
    for batch_idx in range(batch_size):
        num_rows_pi = num_rows[batch_idx]
        num_cols_pi = num_cols[batch_idx]

        # Channels first, then unfold the flat cell axis into the grid.
        cells_feat_pi = cells_feat[batch_idx]
        cells_feat_pi = cells_feat_pi.transpose(0, 1).reshape(
            -1, num_rows_pi, num_cols_pi
        )
        # F.pad lists the padding starting from the last axis:
        # (col_before, col_after, row_before, row_after, ch_before, ch_after).
        aligned_cells_feat_pi = F.pad(
            cells_feat_pi,
            (0, max_col_nums - num_cols_pi, 0, max_row_nums - num_rows_pi, 0, 0),
            mode='constant',
            value=0
        )
        aligned_cells_feat.append(aligned_cells_feat_pi)
        masks[batch_idx, :num_rows_pi, :num_cols_pi] = 1

    aligned_cells_feat = torch.stack(aligned_cells_feat, dim=0)
    return aligned_cells_feat, masks


class CellsExtractor(nn.Module):
    """Extract grid-cell features and refine them with axis-wise attention.

    Cell boxes are built from the row/column segments, pooled by
    ``RoiPosFeatExtraxtor``, laid out on a padded grid, and then passed
    through ``heads`` rounds of self-attention, each round attending first
    along the column axis and then along the row axis.

    Args:
        in_dim: Forwarded to ``RoiPosFeatExtraxtor`` as its third argument.
        cell_dim: Channel width of the cell features; used as both
            ``in_dim`` and ``att_dim`` of every ``SALayer``.
        heads: Number of column/row attention rounds.
        head_nums: Number of attention heads inside each ``SALayer``.
        pool_size: Forwarded to ``RoiPosFeatExtraxtor``.
        scale: Forwarded to ``RoiPosFeatExtraxtor``.
    """

    def __init__(self, in_dim, cell_dim, heads, head_nums, pool_size, scale=1):
        super().__init__()
        self.in_dim = in_dim
        self.cell_dim = cell_dim
        self.pool_size = pool_size
        self.scale = scale

        self.box_feat_extractor = RoiPosFeatExtraxtor(
            self.scale,
            self.pool_size,
            self.in_dim,
            self.cell_dim
        )

        self.heads = heads
        self.row_sas = nn.ModuleList()
        self.col_sas = nn.ModuleList()
        for _ in range(self.heads):
            self.row_sas.append(SALayer(cell_dim, cell_dim, head_nums))
            self.col_sas.append(SALayer(cell_dim, cell_dim, head_nums))

    def forward(self, feats, row_segments, col_segments, img_sizes):
        """Compute refined cell features for a batch.

        Args:
            feats: Feature tensor handed to the ROI extractor.
            row_segments: Per-sample row boundaries.
            col_segments: Per-sample column boundaries.
            img_sizes: Passed through to the ROI extractor unchanged.

        Returns:
            A tuple ``(aligned_cells_feat, masks)`` where
            ``aligned_cells_feat`` is ``(batch, channels, max_rows,
            max_cols)`` and ``masks`` is ``(batch, max_rows, max_cols)``
            with 1 on real cells and 0 on padding.
        """
        device = feats.device

        num_rows = [len(row_segments_pi) - 1 for row_segments_pi in row_segments]
        num_cols = [len(col_segments_pi) - 1 for col_segments_pi in col_segments]

        cells_bbox = gen_cells_bbox(row_segments, col_segments, device)
        cells_feat = self.box_feat_extractor(feats, cells_bbox, img_sizes)
        aligned_cells_feat, masks = align_cells_feat(cells_feat, num_rows, num_cols)

        bs, c, nr, nc = aligned_cells_feat.shape

        # The masks do not change between rounds, so flatten them once.
        col_masks = masks.reshape(bs * nr, nc)
        row_masks = masks.transpose(1, 2).reshape(bs * nc, nr)

        for col_sa, row_sa in zip(self.col_sas, self.row_sas):
            # Every grid row becomes one sequence of length nc.
            col_cells_feat = (
                aligned_cells_feat.permute(0, 2, 1, 3).contiguous().reshape(bs * nr, c, nc)
            )
            col_cells_feat = col_sa(col_cells_feat, col_masks)  # self-attention
            aligned_cells_feat = (
                col_cells_feat.reshape(bs, nr, c, nc).permute(0, 2, 1, 3).contiguous()
            )

            # Every grid column becomes one sequence of length nr.
            row_cells_feat = (
                aligned_cells_feat.permute(0, 3, 1, 2).contiguous().reshape(bs * nc, c, nr)
            )
            row_cells_feat = row_sa(row_cells_feat, row_masks)  # self-attention
            aligned_cells_feat = (
                row_cells_feat.reshape(bs, nc, c, nr).permute(0, 2, 3, 1).contiguous()
            )

        return aligned_cells_feat, masks