|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
from torch.nn.modules.activation import ReLU |
|
|
from libs.utils.metric import cal_segment_pr |
|
|
from .utils import draw_spans, save_logitmap |
|
|
|
|
|
def cal_segments(cls_probs, spans, scale=1.0):
    """Pick the highest-probability position inside each span.

    Args:
        cls_probs: tensor of per-position probabilities. Span bounds are
            multiplied by ``scale`` before being used as indices into it.
            Expected to be 1-D (the argmax is taken over the flattened
            window).
        spans: iterable of ``[start, end)`` pairs; only elements 0 and 1 of
            each pair are read.
        scale: factor that maps span coordinates onto ``cls_probs`` indices.

    Returns:
        A list with one int per span: the index of the window maximum,
        divided by ``scale`` and truncated to an int.
    """
    segments = []
    for span in spans:
        start = int(span[0] * scale)
        end = int(span[1] * scale)
        window = cls_probs[start:end]
        # A span can collapse to zero width after scaling and truncation
        # (e.g. [0, 1] with scale=0.25) or lie past the end of cls_probs.
        # torch.argmax raises on an empty tensor, so fall back to the span's
        # own start instead of crashing.
        offset = torch.argmax(window).item() if window.numel() > 0 else 0
        segments.append(int((start + offset) / scale))
    return segments
|
|
|
|
|
|
|
|
def cal_spans(cls_probs, threshold=0.5):
    """Group consecutive above-threshold positions into half-open spans.

    Args:
        cls_probs: tensor of probabilities; each position is compared
            against ``threshold`` with a strict ``>``.
        threshold: cut-off above which a position counts as active.

    Returns:
        A list of ``[start, end]`` lists (``end`` exclusive), one per maximal
        run of active positions, in increasing order.
    """
    flags = (cls_probs > threshold).long().tolist()
    spans = []
    run_start = None  # index where the currently open run began, if any
    for pos, flag in enumerate(flags):
        if flag == 1:
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            # The run ended on the previous position; close it.
            spans.append([run_start, pos])
            run_start = None
    if run_start is not None:
        # A run that reaches the last position is closed at the sequence end.
        spans.append([run_start, len(flags)])
    return spans
|
|
|
|
|
|
|
|
def cls_logits_to_segments(segments_logit, masks, type, spans=None, scale=1, threshold=0.5):
    """Turn a batch of segment logit maps into per-sample segment positions.

    The logits are squeezed on dim 1, passed through a sigmoid and averaged
    over one axis (dim 1 when ``type == 'col'``, dim 2 otherwise), giving one
    probability profile per sample. Each profile is cropped to the sample's
    valid length, which is the sum of the first row (``'col'``) or first
    column (otherwise) of its mask.

    Args:
        segments_logit: logit tensor; dim 1 is squeezed away.
        masks: per-sample masks used to derive the valid lengths.
        type: ``'col'`` selects the column branch; any other value selects
            the other branch.
        spans: optional per-sample spans. When ``None`` the spans are derived
            from the probabilities with ``cal_spans``.
        scale: forwarded to ``cal_segments``.
        threshold: forwarded to ``cal_spans`` when spans are derived.

    Returns:
        ``(segments, cls_probs, lengths)``: the per-sample segment lists, the
        averaged probabilities and the valid lengths.
    """
    prob_maps = segments_logit.squeeze(1).sigmoid()
    if type == 'col':
        cls_probs = prob_maps.mean(dim=1)
        lengths = [int(mask[0, :].sum().item()) for mask in masks]
    else:
        cls_probs = prob_maps.mean(dim=2)
        lengths = [int(mask[:, 0].sum().item()) for mask in masks]

    segments = []
    for batch_idx in range(cls_probs.shape[0]):
        length = lengths[batch_idx]
        profile = cls_probs[batch_idx, :length]
        if spans is not None:
            spans_pi = spans[batch_idx]
        else:
            spans_pi = cal_spans(profile, threshold)
            if len(spans_pi) <= 2:
                # Two or fewer detected spans: use only the first and last
                # positions of the valid range instead.
                spans_pi = [[0, 1], [length - 1, length]]
        segments.append(cal_segments(profile, spans_pi, scale))
    return segments, cls_probs, lengths
|
|
|
|
|
|
|
|
def cal_ext_segments(cls_probs, lengths, bg_spans, scale=1, threshold=0.5):
    """Find extra segment candidates inside the background spans.

    For every sample, the peak position of each span in ``bg_spans`` is
    located with ``cal_segments``; a peak is kept only when its probability
    is strictly greater than ``threshold``.

    Args:
        cls_probs: 2-D tensor indexed as ``[batch_idx, position]``.
        lengths: valid length of each sample's row of ``cls_probs``.
        bg_spans: per-sample list of spans to search in.
        scale: factor mapping span coordinates onto ``cls_probs`` indices.
        threshold: minimum probability (exclusive) for a peak to be kept.

    Returns:
        A list with, for each sample, the list of kept peak positions as
        returned by ``cal_segments``.
    """
    ext_segments = []
    for batch_idx in range(cls_probs.shape[0]):
        length = lengths[batch_idx]
        candidates = cal_segments(cls_probs[batch_idx, :length], bg_spans[batch_idx], scale)
        # cal_segments returns positions divided by `scale`, while cls_probs
        # is indexed in scaled coordinates, so map each position back before
        # looking up its probability (a no-op when scale == 1).
        kept = [
            segment for segment in candidates
            if cls_probs[batch_idx, int(segment * scale)] > threshold
        ]
        ext_segments.append(kept)
    return ext_segments
|
|
|
|
|
|
|
|
def gen_masks(sizes, scale, device):
    """Build per-sample validity masks for a padded batch.

    Args:
        sizes: one 2-element size per sample; element 0 maps to dim 1 of the
            mask and element 1 to dim 2.
        scale: factor applied to the sizes to get mask coordinates.
        device: device on which the mask tensor is allocated.

    Returns:
        A float tensor of shape ``[len(sizes), D0, D1]`` where ``D0``/``D1``
        are the scaled per-axis maxima over the batch. Entries inside each
        sample's scaled extent are 1, all others are 0.
    """
    batch_size = len(sizes)
    max_size = [int(max(item) * scale) for item in zip(*sizes)]
    masks = torch.zeros([batch_size, *max_size], dtype=torch.float, device=device)
    for batch_idx, size in enumerate(sizes):
        # The tensor is allocated in scaled coordinates, so the valid extent
        # of each sample must be scaled the same way; using the raw sizes
        # would mark padding as valid whenever scale != 1.
        extent0 = int(size[0] * scale)
        extent1 = int(size[1] * scale)
        masks[batch_idx, :extent0, :extent1] = 1.
    return masks
|
|
|
|
|
|
|
|
def gen_targets(sizes, scale, device, fg_spans, bg_spans, type):
    """Rasterise foreground spans into a dense 0/1 target tensor.

    Args:
        sizes: one 2-element size per sample; the scaled per-axis maxima
            define the last two dims of the output.
        scale: factor applied to sizes and span bounds.
        device: device on which the target tensor is allocated.
        fg_spans: per-sample list of spans; elements 0 and 1 of each span are
            the start and (exclusive) end of the band to fill.
        bg_spans: accepted but not used by this function.
        type: ``'col'`` fills bands along the last dim (all rows); any other
            value fills bands along dim 1 (all columns).

    Returns:
        A float tensor of shape ``[len(sizes), D0, D1]`` with ones inside
        every foreground band and zeros elsewhere.
    """
    padded_shape = [int(max(dims) * scale) for dims in zip(*sizes)]
    targets = torch.zeros([len(sizes), *padded_shape], dtype=torch.float, device=device)
    along_last_dim = type == 'col'
    for sample_idx, sample_spans in enumerate(fg_spans):
        for span in sample_spans:
            band = slice(int(span[0] * scale), int(span[1] * scale))
            if along_last_dim:
                targets[sample_idx, :, band] = 1.
            else:
                targets[sample_idx, band, :] = 1.
    return targets
|
|
|
|
|
|
|
|
class SegmentPredictor(nn.Module):
    """Small convolutional head that predicts segment positions.

    A 3x3 convolution, ReLU and 1x1 convolution map the input features to a
    single-channel logit map. ``cls_logits_to_segments`` then reduces that
    map to per-sample segment positions along the axis selected by ``type``.

    Args:
        in_dim: number of input feature channels.
        scale: factor between image coordinates and feature coordinates,
            forwarded to the mask/target/segment helpers.
        threshold: probability threshold forwarded to the segment helpers.
        type: ``'col'`` or ``'row'``; any other value fails the assertion.
    """

    def __init__(self, in_dim, scale=1, threshold=0.5, type=None):
        super().__init__()
        self.scale = scale
        self.in_dim = in_dim
        assert type in ['col', 'row']
        self.type = type
        self.threshold = threshold
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_dim // 2, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        )

    def forward(self, feats, images_size, fg_spans=None, bg_spans=None):
        """Predict segments and, in training mode, the loss and metrics.

        Args:
            feats: input feature tensor fed to the conv stack.
            images_size: per-sample sizes; each one is reversed before use.
            fg_spans: per-sample foreground spans. Required in training
                mode; when given they also constrain the returned segments.
            bg_spans: per-sample background spans, used in training mode for
                the metrics and the extra-segment search.

        Returns:
            ``(pred_segments, result_info, ext_info)``. ``result_info`` holds
            ``'segments_loss'`` and, when defined, ``'precision'`` and
            ``'recall'`` in training mode. ``ext_info`` holds
            ``'ext_segments'`` in training mode. Both are empty otherwise.
        """
        images_size = [image_size[::-1] for image_size in images_size]
        segments_logit = self.convs(feats)
        masks = gen_masks(images_size, self.scale, feats.device)

        result_info = dict()
        ext_info = dict()

        if self.training:
            targets = gen_targets(images_size, self.scale, feats.device, fg_spans, bg_spans, self.type)
            segments_loss = F.binary_cross_entropy_with_logits(
                segments_logit,
                targets.unsqueeze(1),
                reduction='none'
            )
            # Normalise by the number of positive target cells. Clamp the
            # denominator so a batch without any foreground span yields a
            # finite loss instead of inf/NaN from a division by zero.
            normalizer = targets.sum().clamp(min=1.0)
            segments_loss = (segments_loss * masks[:, None, :, :]).sum() / normalizer
            result_info['segments_loss'] = segments_loss

            # Free-running prediction (spans derived from the probabilities)
            # is used only to report precision/recall during training.
            pred_segments, cls_probs, lengths = cls_logits_to_segments(
                segments_logit, masks, self.type,
                spans=None, scale=self.scale, threshold=self.threshold
            )
            correct_nums, segment_nums, span_nums = cal_segment_pr(pred_segments, fg_spans, bg_spans)
            if segment_nums != 0:
                result_info['precision'] = correct_nums / segment_nums
            if span_nums != 0:
                result_info['recall'] = correct_nums / span_nums
            ext_segments = cal_ext_segments(cls_probs, lengths, bg_spans, self.scale, self.threshold)
            ext_info['ext_segments'] = ext_segments

        # With fg_spans given, the peaks are searched inside those spans;
        # with fg_spans=None the spans are derived from the probabilities.
        pred_segments, *_ = cls_logits_to_segments(
            segments_logit, masks, self.type,
            spans=fg_spans, scale=self.scale, threshold=self.threshold
        )
        return pred_segments, result_info, ext_info
|
|
|