File size: 5,726 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.activation import ReLU
from libs.utils.metric import cal_segment_pr
from .utils import draw_spans, save_logitmap

def cal_segments(cls_probs, spans, scale=1.0):
    """
    Locate one segment (separator) position inside each span.

    Each ``[start, end)`` span is given in image coordinates and mapped onto
    the grid of ``cls_probs`` by multiplying with ``scale``. The argmax of
    ``cls_probs`` inside the mapped range is taken and converted back to image
    coordinates by dividing by ``scale``.

    Args:
        cls_probs: probability profile; presumably 1-D, since the flat argmax
            of the slice is added to the slice start.
        spans: iterable of ``[start, end)`` pairs in image coordinates.
        scale: factor mapping image coordinates onto the ``cls_probs`` grid.

    Returns:
        List of ints, one segment position per span, in image coordinates.

    A span whose start lies beyond the end of ``cls_probs`` still yields an
    empty slice, on which ``torch.argmax`` raises.
    """
    segments = list()
    for span in spans:
        start = int(span[0] * scale)
        # A span narrower than one grid cell would collapse to an empty slice
        # after truncation; keep at least the cell containing its start.
        end = max(int(span[1] * scale), start + 1)
        segment = torch.argmax(cls_probs[start:end]).item() + start
        segments.append(segment)
    # Map grid positions back to image coordinates.
    segments = [int(item / scale) for item in segments]
    return segments


def cal_spans(cls_probs, threshold=0.5):
    """
    Group consecutive above-threshold positions into half-open spans.

    Args:
        cls_probs: probability tensor, presumably 1-D (its ``tolist()`` is
            iterated position by position).
        threshold: positions whose probability is strictly greater than this
            count as foreground.

    Returns:
        List of ``[start, end)`` index pairs (as lists), one per maximal run of
        foreground positions, in ascending order.
    """
    flags = (cls_probs > threshold).long().tolist()
    spans = []
    inside_run = False
    for pos, flag in enumerate(flags):
        is_fg = flag == 1
        if is_fg and inside_run:
            # Extend the run that is currently open.
            spans[-1][1] = pos + 1
        elif is_fg:
            # First foreground position after background (or at index 0).
            spans.append([pos, pos + 1])
        inside_run = is_fg
    return spans
# draw_spans('row_segment_spans.png', 'row_segment.png', spans, 'row')

def cls_logits_to_segments(segments_logit, masks, type, spans=None, scale=1, threshold=0.5):
    """
    Turn separator logits into per-image separator positions.

    Args:
        segments_logit: logit map; ``squeeze(1)`` is applied first, so it is
            presumably shaped (B, 1, H, W) -- TODO confirm.
        masks: per-image validity masks; ``mask[0, :]`` (for ``'col'``) or
            ``mask[:, 0]`` (otherwise) is summed to get each profile's length.
        type: ``'col'`` averages sigmoid probabilities over dim 1 of the
            squeezed map (one score per column); anything else averages over
            dim 2 (one score per row).
        spans: optional per-image spans in image coordinates. When given, the
            peak inside each span becomes a segment. When ``None``, spans are
            detected by thresholding the probability profile.
        scale: factor mapping image coordinates onto the probability grid.
        threshold: probability threshold used for span detection.

    Returns:
        ``(segments, cls_probs, lengths)``: per-image lists of segment
        positions in image coordinates, the averaged probability profile per
        image, and the lengths measured from ``masks``.
    """
    # NOTE(review): the mean runs over the whole map, including positions that
    # ``masks`` marks as padding; a masked mean may be intended.
    if type == 'col':
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=1)
        lengths = [int(mask[0, :].sum().item()) for mask in masks]
    else:
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=2)
        lengths = [int(mask[:, 0].sum().item()) for mask in masks]

    segments = list()
    for batch_idx in range(cls_probs.shape[0]):
        length = lengths[batch_idx]
        probs = cls_probs[batch_idx, :length]
        if spans is None:
            spans_pi = cal_spans(probs, threshold)
            # NOTE(review): `<= 2` also discards exactly two detections; confirm
            # this is intended rather than `< 2`.
            if len(spans_pi) <= 2:
                # Fall back to the first and last position of the profile.
                spans_pi = [[0, 1], [length - 1, length]]
            # Detected spans are already indices into ``probs`` (grid
            # coordinates), so they must not be rescaled again; only the
            # resulting peaks are mapped back to image coordinates.
            segments_pi = [int(segment / scale) for segment in cal_segments(probs, spans_pi, 1.0)]
        else:
            segments_pi = cal_segments(probs, spans[batch_idx], scale)
        segments.append(segments_pi)
    return segments, cls_probs, lengths


def cal_ext_segments(cls_probs, lengths, bg_spans, scale=1, threshold=0.5):
    """
    Find extra segments inside background spans.

    For every span in ``bg_spans[b]`` (image coordinates, mapped onto the
    probability grid with ``scale``) the position of highest probability is
    located. It is kept only if that probability exceeds ``threshold``.

    Args:
        cls_probs: per-image probability profiles indexed ``[b, pos]``.
        lengths: valid profile length per image; later positions are ignored.
        bg_spans: per-image lists of ``[start, end)`` spans in image coordinates.
        scale: factor mapping image coordinates onto the probability grid.
        threshold: minimum peak probability for a segment to be kept.

    Returns:
        Per-image lists of kept segment positions in image coordinates.
    """
    ext_segments = list()
    for batch_idx in range(cls_probs.shape[0]):
        probs = cls_probs[batch_idx, :lengths[batch_idx]]
        ext_segments_pi = list()
        for span in bg_spans[batch_idx]:
            start = int(span[0] * scale)
            # Keep at least one cell so short spans do not become empty slices.
            end = max(int(span[1] * scale), start + 1)
            peak = torch.argmax(probs[start:end]).item() + start
            # Threshold at the grid position itself: the image coordinate
            # (peak / scale) is not a valid index into ``probs`` when scale != 1.
            if probs[peak] > threshold:
                ext_segments_pi.append(int(peak / scale))
        ext_segments.append(ext_segments_pi)
    return ext_segments
    

def gen_masks(sizes, scale, device):
    """
    Build validity masks for a batch of differently sized images.

    Args:
        sizes: per-image sizes in image coordinates; ``sizes[b][0]`` and
            ``sizes[b][1]`` give the extent along mask dims 1 and 2.
        scale: factor mapping image coordinates onto the mask grid.
        device: torch device for the returned tensor.

    Returns:
        Float tensor ``(B, int(max0 * scale), int(max1 * scale))`` holding 1
        inside each image's scaled extent and 0 in the padding.
    """
    batch_size = len(sizes)
    max_size = [int(max(item) * scale) for item in zip(*sizes)]
    masks = torch.zeros([batch_size, *max_size], dtype=torch.float, device=device)
    for batch_idx, size in enumerate(sizes):
        # The per-image extent must be scaled exactly like the batch maximum;
        # unscaled sizes would mark padding as valid whenever scale < 1.
        masks[batch_idx, :int(size[0] * scale), :int(size[1] * scale)] = 1.
    return masks


def gen_targets(sizes, scale, device, fg_spans, bg_spans, type):
    """
    Build binary separator targets on the scaled grid.

    Every ``[start, end)`` span in ``fg_spans[b]`` (image coordinates) is
    scaled and written as 1s: whole columns when ``type == 'col'``, whole rows
    otherwise. All other cells are 0.

    Args:
        sizes: per-image sizes; the grid shape is the scaled maximum of each
            size component across the batch.
        scale: factor mapping image coordinates onto the grid.
        device: torch device for the returned tensor.
        fg_spans: per-image lists of foreground spans.
        bg_spans: accepted for call-site symmetry; not read here.
        type: ``'col'`` or anything else (treated as rows).

    Returns:
        Float tensor ``(B, *grid_shape)`` of 0/1 targets.
    """
    grid_shape = [int(max(dim_sizes) * scale) for dim_sizes in zip(*sizes)]
    targets = torch.zeros([len(sizes), *grid_shape], dtype=torch.float, device=device)
    along_cols = type == 'col'
    for img_idx, img_spans in enumerate(fg_spans):
        for fg_span in img_spans:
            band = slice(int(fg_span[0] * scale), int(fg_span[1] * scale))
            if along_cols:
                targets[img_idx, :, band] = 1.
            else:
                targets[img_idx, band, :] = 1.
    return targets


class SegmentPredictor(nn.Module):
    """
    Predict row or column separator positions from a feature map.

    A small conv head maps ``feats`` to a single-channel logit map. In
    training mode it also computes a masked BCE loss against targets built from
    ``fg_spans``. It reports the precision/recall of automatically detected
    segments and searches for extra segments inside ``bg_spans``.
    """

    def __init__(self, in_dim, scale=1, threshold=0.5, type=None):
        """
        Args:
            in_dim: number of input channels of ``feats``.
            scale: factor mapping image coordinates onto the feature grid.
            threshold: probability threshold for span detection.
            type: ``'col'`` or ``'row'``; any other value fails the assert.
        """
        super().__init__()
        self.scale = scale
        self.in_dim = in_dim
        assert type in ['col', 'row']
        self.type = type
        self.threshold = threshold
        # 3x3 conv + ReLU, then a 1x1 conv down to one logit channel. The layer
        # order defines the state_dict keys (convs.0.*, convs.2.*).
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_dim // 2, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        )

    def forward(self, feats, images_size, fg_spans=None, bg_spans=None):
        """
        Run the head and, in training mode, compute loss and metrics.

        Args:
            feats: feature map fed to ``self.convs``. In training, the output's
                batch/spatial dims must match the masks built from
                ``images_size``, because they are multiplied elementwise.
            images_size: per-image sizes; each entry is reversed (``[::-1]``)
                before use, presumably (W, H) -> (H, W) -- TODO confirm.
            fg_spans: per-image separator spans in image coordinates; required
                in training mode. When given, they also place the returned
                segments.
            bg_spans: per-image spans used for metrics and the extra-segment
                search in training mode.

        Returns:
            ``(pred_segments, result_info, ext_info)``. In training,
            ``result_info`` holds ``'segments_loss'`` and, when their
            denominators are non-zero, ``'precision'`` / ``'recall'``.
            ``ext_info`` holds ``'ext_segments'``. Both dicts are empty in
            eval mode.
        """
        images_size = [image_size[::-1] for image_size in images_size]
        segments_logit = self.convs(feats)
        masks = gen_masks(images_size, self.scale, feats.device)
        result_info = dict()
        ext_info = dict()

        if self.training:
            targets = gen_targets(images_size, self.scale, feats.device, fg_spans, bg_spans, self.type)
            segments_loss = F.binary_cross_entropy_with_logits(
                segments_logit,
                targets.unsqueeze(1),
                reduction='none'
            )
            # Normalize by the number of positive target cells. The clamp keeps
            # the loss finite (instead of inf/NaN) for a batch without any span.
            segments_loss = (segments_loss * masks[:, None, :, :]).sum() / targets.sum().clamp(min=1.0)
            result_info['segments_loss'] = segments_loss

            pred_segments, cls_probs, lengths = cls_logits_to_segments(
                segments_logit, masks, self.type, spans=None, scale=self.scale, threshold=self.threshold
            )
            correct_nums, segment_nums, span_nums = cal_segment_pr(pred_segments, fg_spans, bg_spans)
            if segment_nums != 0:
                result_info['precision'] = correct_nums / segment_nums
            if span_nums != 0:
                result_info['recall'] = correct_nums / span_nums
            ext_segments = cal_ext_segments(cls_probs, lengths, bg_spans, self.scale, self.threshold)
            ext_info['ext_segments'] = ext_segments

        # With fg_spans given, segments are taken inside those spans; otherwise
        # (fg_spans=None) spans are detected from the probabilities.
        pred_segments, *_ = cls_logits_to_segments(
            segments_logit, masks, self.type, spans=fg_spans, scale=self.scale, threshold=self.threshold
        )
        return pred_segments, result_info, ext_info