# Source revision: cb0ad2d
import torch
import math
from torch import nn
from torch.nn import functional as F
from .extractor import RoiPosFeatExtraxtor
class SALayer(nn.Module):
    """Multi-head self-attention over a length-n feature sequence, with residual.

    Keys/queries are 1x1 Conv1d projections to ``att_dim`` channels, values to
    ``in_dim`` channels; both channel dimensions are split across ``head_nums``
    heads inside :meth:`forward`.
    """

    def __init__(self, in_dim, att_dim, head_nums):
        super().__init__()
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.head_nums = head_nums
        # forward() reshapes BOTH the value channels (in_dim) and the
        # key/query channels (att_dim) into head_nums groups; each must divide
        # evenly, otherwise the reshape fails with a cryptic runtime error.
        # The att_dim check was previously missing.
        assert self.in_dim % self.head_nums == 0
        assert self.att_dim % self.head_nums == 0
        self.key_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
        self.query_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
        self.value_layer = nn.Conv1d(self.in_dim, self.in_dim, 1, 1, 0)
        # NOTE(review): scales by 1/sqrt(att_dim), not by the per-head width
        # att_dim/head_nums as in standard scaled dot-product attention —
        # preserved as-is since trained weights may depend on it.
        self.scale = 1 / math.sqrt(self.att_dim)

    def forward(self, feats, masks=None):
        """Apply self-attention and return ``attended + feats`` (residual).

        Args:
            feats: tensor of shape (batch, in_dim, n).
            masks: optional (batch, n) tensor; positions equal to 0 are
                suppressed as attention *keys*.

        Returns:
            Tensor of shape (batch, in_dim, n).
        """
        bs, c, n = feats.shape
        keys = self.key_layer(feats).reshape(bs, -1, self.head_nums, n)
        querys = self.query_layer(feats).reshape(bs, -1, self.head_nums, n)
        values = self.value_layer(feats).reshape(bs, -1, self.head_nums, n)
        # logits[b, h, k, q]: similarity of key position k with query position q.
        logits = torch.einsum('bchk,bchq->bhkq', keys, querys) * self.scale
        if masks is not None:
            # Push masked-out key positions toward -inf so the softmax over
            # the key axis (dim=2) assigns them ~zero weight.
            logits = logits - (1 - masks[:, None, :, None]) * 1e8
        weights = torch.softmax(logits, dim=2)
        new_feats = torch.einsum('bchk,bhkq->bchq', values, weights)
        new_feats = new_feats.reshape(bs, -1, n)
        return new_feats + feats
def gen_cells_bbox(row_segments, col_segments, device):
    """Build per-image cell bounding boxes from row/column segment coordinates.

    Args:
        row_segments: per-image sequences of row boundary coordinates
            (len = num_rows + 1).
        col_segments: per-image sequences of column boundary coordinates
            (len = num_cols + 1).
        device: torch device for the output tensors.

    Returns:
        List of float tensors, one per image, each of shape
        (num_rows * num_cols, 4) with boxes as (x1, y1, x2, y2) laid out in
        row-major order.
    """
    cells_bbox = []
    for rows, cols in zip(row_segments, col_segments):
        # Each consecutive (row, col) boundary pair delimits one cell.
        boxes = [
            [cols[c], rows[r], cols[c + 1], rows[r + 1]]
            for r in range(len(rows) - 1)
            for c in range(len(cols) - 1)
        ]
        cells_bbox.append(torch.tensor(boxes, dtype=torch.float, device=device))
    return cells_bbox
def align_cells_feat(cells_feat, num_rows, num_cols):
    """Pad per-image cell feature grids to a common size and build valid masks.

    Args:
        cells_feat: list of per-image tensors of shape
            (num_rows_i * num_cols_i, channels), cells in row-major order.
        num_rows: per-image row counts.
        num_cols: per-image column counts.

    Returns:
        Tuple of (aligned, masks) where aligned has shape
        (batch, channels, max_rows, max_cols), zero-padded at the
        bottom/right, and masks has shape (batch, max_rows, max_cols) with 1
        marking real cells.
    """
    max_rows = max(num_rows)
    max_cols = max(num_cols)
    ref = cells_feat[0]
    masks = torch.zeros(
        [len(cells_feat), max_rows, max_cols], dtype=ref.dtype, device=ref.device
    )
    padded = []
    for idx, (feat, nr, nc) in enumerate(zip(cells_feat, num_rows, num_cols)):
        # (cells, c) -> (c, rows, cols) grid, then zero-pad to the max extent.
        grid = feat.transpose(0, 1).reshape(-1, nr, nc)
        padded.append(
            F.pad(
                grid,
                (0, max_cols - nc, 0, max_rows - nr, 0, 0),
                mode='constant',
                value=0,
            )
        )
        masks[idx, :nr, :nc] = 1
    return torch.stack(padded, dim=0), masks
class CellsExtractor(nn.Module):
    """Extract per-cell features from a feature map and refine them with
    alternating column-wise and row-wise self-attention passes.
    """

    def __init__(self, in_dim, cell_dim, heads, head_nums, pool_size, scale=1):
        """
        Args:
            in_dim: channel count of the input feature map.
            cell_dim: channel count of the extracted per-cell features.
            heads: number of alternating (col SA, row SA) refinement rounds.
            head_nums: attention heads inside each SALayer.
            pool_size: RoI pooling output size passed to the extractor.
            scale: spatial scale passed to the extractor (feature map vs image).
        """
        super().__init__()
        self.in_dim = in_dim
        self.cell_dim = cell_dim
        self.pool_size = pool_size
        self.scale = scale
        # Project RoI-pooled features for each cell bbox to cell_dim channels.
        # (RoiPosFeatExtraxtor is project-local; presumably RoI pooling plus a
        # positional encoding — confirm against its definition.)
        self.box_feat_extractor = RoiPosFeatExtraxtor(
            self.scale,
            self.pool_size,
            self.in_dim,
            self.cell_dim
        )
        self.heads = heads
        # One independent (row, col) SALayer pair per refinement round.
        self.row_sas = nn.ModuleList()
        self.col_sas = nn.ModuleList()
        for _ in range(self.heads):
            self.row_sas.append(SALayer(cell_dim, cell_dim, head_nums))
            self.col_sas.append(SALayer(cell_dim, cell_dim, head_nums))

    def forward(self, feats, row_segments, col_segments, img_sizes):
        """Return refined cell features plus a validity mask.

        Args:
            feats: input feature map (assumed (batch, in_dim, H, W) — consumed
                by the RoI extractor; TODO confirm).
            row_segments: per-image row boundary coordinates (num_rows + 1).
            col_segments: per-image column boundary coordinates (num_cols + 1).
            img_sizes: per-image sizes forwarded to the RoI extractor.

        Returns:
            Tuple of (cells_feat, masks): cells_feat has shape
            (batch, cell_dim, max_rows, max_cols); masks is
            (batch, max_rows, max_cols) with 1 marking real (unpadded) cells.
        """
        device = feats.device
        # Segment lists hold boundaries, so cell counts are one less.
        num_rows = [len(row_segments_pi) - 1 for row_segments_pi in row_segments]
        num_cols = [len(col_segments_pi) - 1 for col_segments_pi in col_segments]
        cells_bbox = gen_cells_bbox(row_segments, col_segments, device)
        cells_feat = self.box_feat_extractor(feats, cells_bbox, img_sizes)
        # Pad variable-size grids to a common (max_rows, max_cols) batch.
        aligned_cells_feat, masks = align_cells_feat(cells_feat, num_rows, num_cols)
        bs, c, nr, nc = aligned_cells_feat.shape
        for idx in range(self.heads):
            # Column pass: fold rows into the batch so attention runs over the
            # nc cells within each row; (bs, c, nr, nc) -> (bs*nr, c, nc).
            col_cells_feat = aligned_cells_feat.permute(0, 2, 1, 3).contiguous().reshape(bs * nr, c, nc)
            col_masks = masks.reshape(bs * nr, nc)
            col_cells_feat = self.col_sas[idx](col_cells_feat, col_masks)  # self-attention within each row
            # Restore (bs, c, nr, nc) layout.
            aligned_cells_feat = col_cells_feat.reshape(bs, nr, c, nc).permute(0, 2, 1, 3).contiguous()
            # Row pass: fold columns into the batch so attention runs over the
            # nr cells within each column; (bs, c, nr, nc) -> (bs*nc, c, nr).
            row_cells_feat = aligned_cells_feat.permute(0, 3, 1, 2).contiguous().reshape(bs * nc, c, nr)
            row_masks = masks.transpose(1, 2).reshape(bs * nc, nr)
            row_cells_feat = self.row_sas[idx](row_cells_feat, row_masks)  # self-attention within each column
            # Restore (bs, c, nr, nc) layout.
            aligned_cells_feat = row_cells_feat.reshape(bs, nc, c, nr).permute(0, 2, 3, 1).contiguous()
        return aligned_cells_feat, masks
|