training_sem / libs /model /segment_predictor.py
kai-2054's picture
Initial commit: add code
cb0ad2d
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.activation import ReLU
from libs.utils.metric import cal_segment_pr
from .utils import draw_spans, save_logitmap
def cal_segments(cls_probs, spans, scale=1.0):
    """Pick the highest-probability position inside each span.

    Args:
        cls_probs: tensor of per-position probabilities, indexed along its
            first dimension (callers in this file pass a 1-D slice).
        spans: iterable of ``[start, end)`` pairs. Each bound is multiplied by
            ``scale`` and truncated to obtain an index into ``cls_probs``.
        scale: factor mapping span coordinates onto ``cls_probs`` indices.

    Returns:
        A list with one int per span: the arg-max index inside the scaled
        span, divided by ``scale`` and truncated.

    Raises:
        Whatever ``torch.argmax`` raises for an empty tensor, if ``cls_probs``
        itself is empty while ``spans`` is not.
    """
    num_positions = cls_probs.shape[0]
    segments = []
    for span in spans:
        start = int(span[0] * scale)
        end = int(span[1] * scale)
        if num_positions > 0:
            # Truncation can collapse a narrow span to a zero-width window (or
            # push it past the end); argmax would then fail on an empty
            # tensor. Keep the window inside [0, num_positions) and at least
            # one element wide. Well-formed spans are left unchanged.
            start = min(max(start, 0), num_positions - 1)
            end = min(max(end, start + 1), num_positions)
        peak = torch.argmax(cls_probs[start:end]).item() + start
        # Map the index back to span coordinates.
        segments.append(int(peak / scale))
    return segments
def cal_spans(cls_probs, threshold=0.5):
    """Return the maximal runs of positions whose probability exceeds ``threshold``.

    Args:
        cls_probs: tensor of probabilities; it is thresholded and converted to
            a Python list, then scanned left to right.
        threshold: a position counts as active when its value is strictly
            greater than this.

    Returns:
        A list of ``[start, end]`` pairs (end exclusive), one per consecutive
        run of active positions, in ascending order.
    """
    flags = (cls_probs > threshold).long().tolist()
    spans = []
    run_start = None  # index where the currently open run began, if any
    for pos, flag in enumerate(flags):
        if flag == 1:
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            spans.append([run_start, pos])
            run_start = None
    # Close a run that reaches the last position.
    if run_start is not None:
        spans.append([run_start, len(flags)])
    return spans
# draw_spans('row_segment_spans.png', 'row_segment.png', spans, 'row')
def cls_logits_to_segments(segments_logit, masks, type, spans=None, scale=1, threshold=0.5):
    """Turn a segmentation logit map into one segment position per span.

    The logits are squeezed on dim 1, passed through a sigmoid and averaged
    over one spatial dimension (dim 1 for ``'col'``, dim 2 otherwise) to give
    a per-position probability profile for every sample.

    Args:
        segments_logit: logit tensor; dim 1 is squeezed away, so it is
            presumably ``(batch, 1, H, W)`` -- confirm against the caller.
        masks: per-sample masks; the sum of the first row (``'col'``) or first
            column (otherwise) is used as the valid length of the profile.
        type: ``'col'`` selects the column branch, any other value the row
            branch.
        spans: optional per-sample list of ``[start, end)`` spans, interpreted
            by ``cal_segments`` in coordinates that are multiplied by
            ``scale``. When ``None``, spans are detected from the profile by
            thresholding.
        scale: factor between span coordinates and profile indices.
        threshold: probability threshold used when detecting spans.

    Returns:
        ``(segments, cls_probs, lengths)``: the per-sample segment lists, the
        averaged probability profiles and the per-sample valid lengths.
    """
    probs_map = segments_logit.squeeze(1).sigmoid()
    if type == 'col':
        cls_probs = probs_map.mean(dim=1)
        lengths = [int(mask[0, :].sum().item()) for mask in masks]
    else:
        cls_probs = probs_map.mean(dim=2)
        lengths = [int(mask[:, 0].sum().item()) for mask in masks]

    segments = []
    for batch_idx in range(cls_probs.shape[0]):
        length = lengths[batch_idx]
        valid_probs = cls_probs[batch_idx, :length]
        if spans is None:
            spans_pi = cal_spans(valid_probs, threshold)
            if len(spans_pi) <= 2:
                # Too few detected spans: fall back to the two border positions.
                spans_pi = [[0, 1], [length - 1, length]]
            # Detected spans are already indices into the profile, so they
            # must not be multiplied by ``scale`` a second time. Locate the
            # peaks with scale 1.0, then map them back by 1 / scale so both
            # branches return positions in the same coordinate system.
            peaks = cal_segments(valid_probs, spans_pi, 1.0)
            segments_pi = [int(peak / scale) for peak in peaks]
        else:
            segments_pi = cal_segments(valid_probs, spans[batch_idx], scale)
        segments.append(segments_pi)
    return segments, cls_probs, lengths
def cal_ext_segments(cls_probs, lengths, bg_spans, scale=1, threshold=0.5):
    """Find confident peaks inside the given background spans.

    For every span in ``bg_spans`` the highest-probability position of the
    sample's profile is located; it is kept only if its probability is
    strictly greater than ``threshold``.
    (The original comment was mis-encoded; this description follows the code.)

    Args:
        cls_probs: tensor of per-sample probability profiles, indexed as
            ``cls_probs[batch_idx, position]``.
        lengths: valid profile length for each sample.
        bg_spans: per-sample list of ``[start, end)`` spans; bounds are
            multiplied by ``scale`` and truncated to index the profile.
        scale: factor mapping span coordinates onto profile indices.
        threshold: minimum (exclusive) probability for a peak to be kept.

    Returns:
        A list with one list per sample holding the kept peak positions,
        divided by ``scale`` and truncated.
    """
    ext_segments = []
    for batch_idx in range(cls_probs.shape[0]):
        valid_probs = cls_probs[batch_idx, :lengths[batch_idx]]
        num_valid = valid_probs.shape[0]
        kept = []
        for span in bg_spans[batch_idx]:
            start = max(int(span[0] * scale), 0)
            # Keep the window at least one element wide after truncation.
            end = min(max(int(span[1] * scale), start + 1), num_valid)
            if start >= end:
                # The span lies entirely outside the valid profile.
                continue
            peak = torch.argmax(valid_probs[start:end]).item() + start
            # Test the probability at the profile index itself; the value
            # returned to the caller is rescaled and is not a valid index
            # into the profile when scale != 1.
            if valid_probs[peak] > threshold:
                kept.append(int(peak / scale))
        ext_segments.append(kept)
    return ext_segments
def gen_masks(sizes, scale, device):
    """Build per-sample validity masks on a common padded grid.

    Args:
        sizes: one size tuple per sample; the first two entries are used as
            the extents of the valid region along the first two mask
            dimensions.
        scale: factor applied to the sizes to obtain the mask resolution.
        device: torch device on which the mask tensor is allocated.

    Returns:
        A float tensor of shape ``[len(sizes), *padded]`` where ``padded`` is
        the per-dimension maximum size times ``scale`` (truncated). Entries
        inside each sample's scaled extent are 1, all others 0.
    """
    max_size = [int(max(item) * scale) for item in zip(*sizes)]
    masks = torch.zeros([len(sizes), *max_size], dtype=torch.float, device=device)
    for batch_idx, size in enumerate(sizes):
        # Scale the valid extent the same way as the padded shape; with the
        # unscaled size the padding is never masked out when scale < 1 and the
        # valid region is under-covered when scale > 1.
        masks[batch_idx, :int(size[0] * scale), :int(size[1] * scale)] = 1.
    return masks
def gen_targets(sizes, scale, device, fg_spans, bg_spans, type):
    """Rasterise foreground spans into dense 0/1 target maps.

    Args:
        sizes: one size tuple per sample; the per-dimension maximum times
            ``scale`` (truncated) gives the shape of every target map.
        scale: factor applied to sizes and span bounds.
        device: torch device on which the targets are allocated.
        fg_spans: per-sample list of ``[start, end)`` spans to mark with 1.
        bg_spans: accepted for signature compatibility; not used here.
        type: ``'col'`` marks the spans along the last dimension (full
            height); any other value marks them along the second dimension
            (full width).

    Returns:
        A float tensor of shape ``[len(sizes), *padded]`` that is 1 inside the
        scaled foreground spans and 0 elsewhere.
    """
    padded_shape = [int(max(dims) * scale) for dims in zip(*sizes)]
    targets = torch.zeros([len(sizes), *padded_shape], dtype=torch.float, device=device)
    mark_columns = type == 'col'
    for sample_idx, sample_spans in enumerate(fg_spans):
        for span in sample_spans:
            band = slice(int(span[0] * scale), int(span[1] * scale))
            if mark_columns:
                targets[sample_idx, :, band] = 1.
            else:
                targets[sample_idx, band, :] = 1.
    return targets
class SegmentPredictor(nn.Module):
    """Predict row or column segment positions from a feature map.

    A small convolutional head (3x3 conv, ReLU, 1x1 conv) produces a
    single-channel logit map, which ``cls_logits_to_segments`` reduces to one
    position per span.
    """

    def __init__(self, in_dim, scale=1, threshold=0.5, type=None):
        """Create the prediction head.

        Args:
            in_dim: number of input feature channels.
            scale: factor between span/image coordinates and the feature map.
            threshold: probability threshold used when detecting spans.
            type: ``'col'`` or ``'row'``; any other value fails the assertion.
        """
        super().__init__()
        self.scale = scale
        self.in_dim = in_dim
        assert type in ['col', 'row']
        self.type = type
        self.threshold = threshold
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_dim // 2, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        )

    def forward(self, feats, images_size, fg_spans=None, bg_spans=None):
        """Run the head and decode segment positions.

        Args:
            feats: input feature tensor fed to the convolutional head.
            images_size: per-sample size tuples; each is reversed before use
                (presumably (width, height) -> (height, width); confirm
                against the caller).
            fg_spans: per-sample foreground spans. Used to build the training
                targets and, when given, to decode the returned segments.
            bg_spans: per-sample background spans; used in training mode only.

        Returns:
            ``(pred_segments, result_info, ext_info)``. In training mode
            ``result_info`` holds ``'segments_loss'`` and, when defined,
            ``'precision'`` / ``'recall'``; ``ext_info`` holds
            ``'ext_segments'``. Both dicts are empty in eval mode.
        """
        images_size = [image_size[::-1] for image_size in images_size]
        segments_logit = self.convs(feats)
        masks = gen_masks(images_size, self.scale, feats.device)
        # save_logitmap('row_segment.png', segments_logit[0][0])

        result_info = dict()
        ext_info = dict()
        if self.training:
            targets = gen_targets(images_size, self.scale, feats.device, fg_spans, bg_spans, self.type)
            segments_loss = F.binary_cross_entropy_with_logits(
                segments_logit,
                targets.unsqueeze(1),
                reduction='none'
            )
            # Targets are 0/1, so the clamp only changes the all-zero case,
            # where dividing by targets.sum() would give an inf/NaN loss.
            normalizer = targets.sum().clamp(min=1.0)
            segments_loss = (segments_loss * masks[:, None, :, :]).sum() / normalizer
            result_info['segments_loss'] = segments_loss

            # Decode from detected spans to measure precision / recall.
            pred_segments, cls_probs, lengths = cls_logits_to_segments(
                segments_logit, masks, self.type,
                spans=None, scale=self.scale, threshold=self.threshold
            )
            correct_nums, segment_nums, span_nums = cal_segment_pr(pred_segments, fg_spans, bg_spans)
            if segment_nums != 0:
                result_info['precision'] = correct_nums / segment_nums
            if span_nums != 0:
                result_info['recall'] = correct_nums / span_nums

            ext_segments = cal_ext_segments(cls_probs, lengths, bg_spans, self.scale, self.threshold)
            ext_info['ext_segments'] = ext_segments

        # Returned segments are decoded inside fg_spans when they are given;
        # with fg_spans=None the spans are detected from the prediction.
        pred_segments, *_ = cls_logits_to_segments(
            segments_logit, masks, self.type,
            spans=fg_spans, scale=self.scale, threshold=self.threshold
        )
        return pred_segments, result_info, ext_info