File size: 3,173 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from torch import nn
from torchvision.ops import roi_align


def convert_to_roi_format(lines_box):
    """Pack per-image box tensors into the single-tensor layout used by ``roi_align``.

    Args:
        lines_box: sequence of tensors, one per image, each with K_i rows
            (presumably (K_i, 4) boxes in (x1, y1, x2, y2) order -- TODO
            confirm with callers). All must live on the same device.

    Returns:
        Tensor of shape (sum K_i, 1 + D), where D is the box width. Column 0
        holds the index of the image each row came from, cast to the boxes'
        dtype. The remaining columns are the boxes, concatenated in input
        order. For (K_i, 4) boxes this is the (K, 5) ``[batch_idx, x1, y1,
        x2, y2]`` format ``roi_align`` expects.
    """
    boxes = torch.cat(lines_box, dim=0)
    per_image_counts = torch.tensor(
        [per_image.shape[0] for per_image in lines_box], device=boxes.device
    )
    # Image index i, repeated once per box of image i, as a column vector.
    # Built as int64 and cast afterwards so reduced-precision box dtypes work.
    batch_index = (
        torch.arange(len(lines_box), device=boxes.device)
        .repeat_interleave(per_image_counts)
        .to(boxes.dtype)
        .unsqueeze(1)
    )
    return torch.cat([batch_index, boxes], dim=1)


class RoiFeatExtraxtor(nn.Module):
    """RoIAlign-pool a feature vector per box and project it with a two-layer MLP.

    Args:
        scale: ``spatial_scale`` passed to ``roi_align``. It multiplies box
            coordinates to map them onto the grid of ``feats``.
        pool_size: RoIAlign output size, either an ``int`` (square) or an
            ``(h, w)`` pair.
        input_dim: number of channels in ``feats``.
        output_dim: width of the returned per-box feature vectors.
    """

    def __init__(self, scale, pool_size, input_dim, output_dim):
        super().__init__()
        # roi_align accepts an int output size; normalise so indexing below works.
        if isinstance(pool_size, int):
            pool_size = (pool_size, pool_size)
        self.scale = scale
        self.pool_size = pool_size
        self.output_dim = output_dim
        # Each RoI is pooled to (input_dim, h, w) and flattened before the MLP.
        input_dim = input_dim * self.pool_size[0] * self.pool_size[1]
        self.fc = nn.Sequential(
            nn.Linear(input_dim, self.output_dim),
            nn.ReLU(),
            nn.Linear(self.output_dim, self.output_dim)
        )

    def forward(self, feats, lines_box):
        """Pool and embed every box, returning one feature tensor per image.

        Args:
            feats: feature map given to ``roi_align`` as ``input``. Assumed to
                be (N, C, H, W) with C == ``input_dim`` -- TODO confirm.
            lines_box: list of per-image box tensors, presumably (K_i, 4) in
                (x1, y1, x2, y2) before scaling by ``self.scale``.

        Returns:
            list with one (K_i, output_dim) tensor per entry of ``lines_box``.
        """
        rois = convert_to_roi_format(lines_box)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=self.pool_size,
            spatial_scale=self.scale,
            sampling_ratio=2
        )
        # flatten(1) rather than reshape(K, -1): the latter raises when the
        # whole batch has K == 0 boxes, since -1 cannot be inferred from 0 elements.
        lines_feat = self.fc(lines_feat.flatten(1))
        return list(torch.split(lines_feat, [item.shape[0] for item in lines_box]))


class RoiPosFeatExtraxtor(nn.Module):
    """RoIAlign box features fused with a learned embedding of each box's position.

    Args:
        scale: ``spatial_scale`` passed to ``roi_align``. It also maps box
            coordinates onto the feature grid for the position embedding.
        pool_size: RoIAlign output size, either an ``int`` (square) or an
            ``(h, w)`` pair.
        input_dim: number of channels in ``feats``.
        output_dim: width of the returned per-box feature vectors.
    """

    def __init__(self, scale, pool_size, input_dim, output_dim):
        super().__init__()
        # roi_align accepts an int output size; normalise so indexing below works.
        if isinstance(pool_size, int):
            pool_size = (pool_size, pool_size)
        self.scale = scale
        self.pool_size = pool_size
        self.output_dim = output_dim
        # Each RoI is pooled to (input_dim, h, w) and flattened before the MLP.
        input_dim = input_dim * self.pool_size[0] * self.pool_size[1]
        self.fc = nn.Sequential(
            nn.Linear(input_dim, self.output_dim),
            nn.ReLU(),
            nn.Linear(self.output_dim, self.output_dim)
        )
        # Position branch: 4 normalised box coordinates -> output_dim, then LayerNorm.
        # (Attribute names, including the "tranform" spelling, are kept so
        # existing state_dict checkpoints still load.)
        self.bbox_ln = nn.LayerNorm(self.output_dim)
        self.bbox_tranform = nn.Linear(4, self.output_dim)

        # LayerNorm applied to the sum of appearance and position embeddings.
        self.add_ln = nn.LayerNorm(self.output_dim)

    def forward(self, feats, lines_box, img_sizes):
        """Pool, embed and position-encode every box; return one tensor per image.

        Args:
            feats: feature map given to ``roi_align`` as ``input``. Assumed to
                be (N, C, H, W) with C == ``input_dim`` -- TODO confirm.
            lines_box: list of per-image box tensors with 4 columns (required by
                ``bbox_tranform``). Columns 0/2 are treated as x and 1/3 as y.
                These tensors are not modified.
            img_sizes: currently unused; kept for interface compatibility.

        Returns:
            list with one (K_i, output_dim) tensor per entry of ``lines_box``.
        """
        rois = convert_to_roi_format(lines_box)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=self.pool_size,
            spatial_scale=self.scale,
            sampling_ratio=2
        )
        # flatten(1) rather than reshape(K, -1): the latter raises when the
        # whole batch has K == 0 boxes.
        lines_feat = self.fc(lines_feat.flatten(1))

        # Position embedding: map boxes onto the feature grid (x scale) and
        # divide by the grid size. This is computed out of place, so the
        # caller's box tensors are left untouched, and for all boxes at once.
        # LayerNorm is per-row, so normalising before the split matches the
        # per-image result.
        feats_H, feats_W = feats.shape[-2:]
        boxes = torch.cat(lines_box, dim=0)
        grid_size = boxes.new_tensor([feats_W, feats_H, feats_W, feats_H])
        norm_boxes = boxes * self.scale / grid_size
        pos_embed = self.bbox_ln(self.bbox_tranform(norm_boxes))
        lines_feat = self.add_ln(lines_feat + pos_embed)

        return list(torch.split(lines_feat, [item.shape[0] for item in lines_box]))