File size: 5,726 Bytes
cb0ad2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.activation import ReLU
from libs.utils.metric import cal_segment_pr
from .utils import draw_spans, save_logitmap
def cal_segments(cls_probs, spans, scale=1.0):
    """Pick the highest-probability position inside each span.

    Args:
        cls_probs: 1-D tensor of per-position probabilities (in scaled coords).
        spans: iterable of ``[start, end)`` pairs in unscaled coordinates.
        scale: factor mapping span coordinates into ``cls_probs`` indices.

    Returns:
        List of int peak positions, mapped back to unscaled coordinates.
    """
    peaks = []
    for span in spans:
        lo = int(span[0] * scale)
        hi = int(span[1] * scale)
        # argmax is relative to the slice; offset by the slice start.
        peaks.append(lo + torch.argmax(cls_probs[lo:hi]).item())
    return [int(peak / scale) for peak in peaks]
def cal_spans(cls_probs, threshold=0.5):
    """Group consecutive above-threshold positions into ``[start, end)`` spans.

    Args:
        cls_probs: 1-D tensor of probabilities.
        threshold: positions with probability strictly above this are kept.

    Returns:
        List of ``[start, end)`` index pairs, one per contiguous run.
    """
    flags = (cls_probs > threshold).long().tolist()
    spans = []
    prev = 0
    for pos, flag in enumerate(flags):
        if flag == 1:
            if prev == 1:
                # Extend the run currently being built.
                spans[-1][1] = pos + 1
            else:
                # A new run starts here.
                spans.append([pos, pos + 1])
        prev = flag
    return spans
def cls_logits_to_segments(segments_logit, masks, type, spans=None, scale=1, threshold=0.5):
    """Turn a (B, 1, H, W) split-logit map into per-image segment positions.

    Args:
        segments_logit: raw logits from the segment head.
        masks: (B, H, W) binary validity masks.
        type: 'col' averages probabilities over rows (one score per column);
            anything else averages over columns (one score per row).
        spans: optional precomputed spans per image; derived from the
            probabilities when None.
        scale: feature-to-image scale factor.
        threshold: probability threshold used when deriving spans.

    Returns:
        Tuple ``(segments, cls_probs, lengths)``.
    """
    if type == 'col':
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=1)
        # Valid width per image, read off the mask's first row.
        lengths = [int(mask[0, :].sum().item()) for mask in masks]
    else:
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=2)
        # Valid height per image, read off the mask's first column.
        lengths = [int(mask[:, 0].sum().item()) for mask in masks]

    segments = []
    for batch_idx in range(cls_probs.shape[0]):
        length = lengths[batch_idx]
        probs = cls_probs[batch_idx, :length]
        if spans is not None:
            spans_pi = spans[batch_idx]
        else:
            spans_pi = cal_spans(probs, threshold)
            if len(spans_pi) <= 2:
                # Too few detections: fall back to the two border positions.
                spans_pi = [[0, 1], [length - 1, length]]
        segments.append(cal_segments(probs, spans_pi, scale))
    return segments, cls_probs, lengths
def cal_ext_segments(cls_probs, lengths, bg_spans, scale=1, threshold=0.5):
    """Search background spans for extra split lines.

    For each image, finds the most likely split position inside every
    background span and keeps only the positions whose probability
    exceeds ``threshold``.

    Args:
        cls_probs: (B, L) per-position probabilities.
        lengths: valid length per image.
        bg_spans: per-image list of ``[start, end)`` background spans.
        scale: feature-to-image scale factor.
        threshold: minimum probability for a candidate to survive.

    Returns:
        Per-image lists of kept split positions.
    """
    ext_segments = []
    for batch_idx in range(cls_probs.shape[0]):
        probs = cls_probs[batch_idx, :lengths[batch_idx]]
        candidates = cal_segments(probs, bg_spans[batch_idx], scale)
        kept = [seg for seg in candidates if cls_probs[batch_idx, seg] > threshold]
        ext_segments.append(kept)
    return ext_segments
def gen_masks(sizes, scale, device):
    """Build a (B, H_max, W_max) float mask batch; 1 marks valid area.

    The canvas is sized to the per-axis maxima scaled by ``scale``.
    NOTE(review): the per-image fill bounds below are NOT scaled, only the
    canvas is — behavior preserved as-is; presumably scale is always 1 here.
    """
    dims_max = [int(max(axis) * scale) for axis in zip(*sizes)]
    masks = torch.zeros([len(sizes), *dims_max], dtype=torch.float, device=device)
    for idx, (height, width) in enumerate(sizes):
        masks[idx, :height, :width] = 1.
    return masks
def gen_targets(sizes, scale, device, fg_spans, bg_spans, type):
    """Rasterize foreground spans into a (B, H_max, W_max) binary target map.

    'col' spans fill vertical stripes (all rows, columns in span); otherwise
    spans fill horizontal stripes. ``bg_spans`` is accepted for signature
    symmetry with the rest of the pipeline but is not used here.
    """
    dims_max = [int(max(axis) * scale) for axis in zip(*sizes)]
    targets = torch.zeros([len(sizes), *dims_max], dtype=torch.float, device=device)
    vertical_stripes = (type == 'col')
    for batch_idx, spans_pb in enumerate(fg_spans):
        for span in spans_pb:
            lo, hi = int(span[0] * scale), int(span[1] * scale)
            if vertical_stripes:
                targets[batch_idx, :, lo:hi] = 1.
            else:
                targets[batch_idx, lo:hi, :] = 1.
    return targets
class SegmentPredictor(nn.Module):
    """Predicts row or column split segments from a backbone feature map.

    A small conv head (3x3 conv -> ReLU -> 1x1 conv) produces a 1-channel
    logit map. Downstream helpers average the sigmoid probabilities along
    one axis ('col': over rows, 'row': over columns) and turn peaks above
    ``threshold`` into split positions.
    """

    def __init__(self, in_dim, scale=1, threshold=0.5, type=None):
        super().__init__()
        self.scale = scale
        self.in_dim = in_dim
        assert type in ['col', 'row']
        self.type = type
        self.threshold = threshold
        # 3x3 conv halves the channels; 1x1 conv projects to one logit map.
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
            nn.ReLU(),
            nn.Conv2d(in_dim // 2, 1, kernel_size=(1,1), stride=(1,1), padding=(0,0))
        )

    def forward(self, feats, images_size, fg_spans=None, bg_spans=None):
        """Run the segment head.

        Args:
            feats: (B, C, H, W) feature maps.
            images_size: per-image (w, h) pairs; flipped to (h, w) below.
            fg_spans: per-image foreground spans (training / teacher-forced
                inference); used as the spans for the final prediction.
            bg_spans: per-image background spans (training only).

        Returns:
            Tuple ``(pred_segments, result_info, ext_info)`` where
            ``result_info`` holds the loss and precision/recall during
            training and ``ext_info`` holds extra background splits.
        """
        # Incoming sizes are (w, h); helpers expect (h, w).
        images_size = [image_size[::-1] for image_size in images_size]
        segments_logit = self.convs(feats)
        masks = gen_masks(images_size, self.scale, feats.device)
        result_info = dict()
        ext_info = dict()
        if self.training:
            targets = gen_targets(images_size, self.scale, feats.device, fg_spans, bg_spans, self.type)
            segments_loss = F.binary_cross_entropy_with_logits(
                segments_logit,
                targets.unsqueeze(1),
                reduction='none'
            )
            # Mask out padding, then normalize by the positive-pixel count.
            # clamp avoids a NaN loss when a batch has no foreground spans.
            denom = targets.sum().clamp(min=1.0)
            segments_loss = (segments_loss * masks[:, None, :, :]).sum() / denom
            result_info['segments_loss'] = segments_loss
            pred_segments, cls_probs, lengths = cls_logits_to_segments(
                segments_logit, masks, self.type,
                spans=None, scale=self.scale, threshold=self.threshold)
            correct_nums, segment_nums, span_nums = cal_segment_pr(pred_segments, fg_spans, bg_spans)
            if segment_nums != 0:
                result_info['precision'] = correct_nums/segment_nums
            if span_nums != 0:
                result_info['recall'] = correct_nums/span_nums
            # Extra candidate splits found inside the background regions.
            ext_segments = cal_ext_segments(cls_probs, lengths, bg_spans, self.scale, self.threshold)
            ext_info['ext_segments'] = ext_segments
        # Final prediction; with fg_spans=None the spans are derived from
        # the probabilities inside cls_logits_to_segments.
        pred_segments, *_ = cls_logits_to_segments(
            segments_logit, masks, self.type,
            spans=fg_spans, scale=self.scale, threshold=self.threshold)
        return pred_segments, result_info, ext_info
|