# kai-2054's picture
# Initial commit: add code
# cb0ad2d
import cv2
import copy
import torch
import numpy as np
from torch.nn import functional as F
def proposal_colspan(layout, layout_score, srow, scol):
    """Build a rectangular proposal grown columns-first from ``(srow, scol)``.

    If the bounding box of every position where ``layout == 1`` is already
    completely filled, ``layout`` is returned unchanged together with the mean
    of ``layout_score`` over that box.

    Otherwise the rectangle is grown from the seed: first to the right along
    row ``srow`` while entries equal 1, then downwards while every entry of
    the chosen column range equals 1.  ``layout`` is then overwritten IN PLACE
    so that only this rectangle is set.

    Returns:
        ``(layout, score)`` where ``score`` is the mean of ``layout_score``
        over the returned rectangle.
    """
    ones_r, ones_c = torch.where(layout == 1)
    top, bottom = ones_r.min(), ones_r.max() + 1
    left, right = ones_c.min(), ones_c.max() + 1

    # Already a solid rectangle: nothing to repair.
    if torch.all(layout[top:bottom, left:right] == 1):
        return layout, layout_score[top:bottom, left:right].mean()

    # Width: consecutive ones to the right of the seed on the seed row.
    width = 0
    for c in range(scol, right):
        if layout[srow, c] != 1:
            break
        width += 1

    # Height: consecutive rows whose [scol, scol + width) slice is all ones.
    height = 0
    for r in range(srow, bottom):
        if not torch.all(layout[r, scol:scol + width] == 1):
            break
        height += 1

    row_sel = slice(srow, srow + height)
    col_sel = slice(scol, scol + width)
    layout[:, :] = 0
    layout[row_sel, col_sel] = 1
    return layout, layout_score[row_sel, col_sel].mean()
def proposal_rowspan(layout, layout_score, srow, scol):
    """Build a rectangular proposal grown rows-first from ``(srow, scol)``.

    If the bounding box of every position where ``layout == 1`` is already
    completely filled, ``layout`` is returned unchanged together with the mean
    of ``layout_score`` over that box.

    Otherwise the rectangle is grown from the seed: first downwards along
    column ``scol`` while entries equal 1, then to the right while every entry
    of the chosen row range equals 1.  ``layout`` is then overwritten IN PLACE
    so that only this rectangle is set.

    Returns:
        ``(layout, score)`` where ``score`` is the mean of ``layout_score``
        over the returned rectangle.
    """
    ones_r, ones_c = torch.where(layout == 1)
    top, bottom = ones_r.min(), ones_r.max() + 1
    left, right = ones_c.min(), ones_c.max() + 1

    # Already a solid rectangle: nothing to repair.
    if torch.all(layout[top:bottom, left:right] == 1):
        return layout, layout_score[top:bottom, left:right].mean()

    # Height: consecutive ones below the seed in the seed column.
    height = 0
    for r in range(srow, bottom):
        if layout[r, scol] != 1:
            break
        height += 1

    # Width: consecutive columns whose [srow, srow + height) slice is all ones.
    width = 0
    for c in range(scol, right):
        if not torch.all(layout[srow:srow + height, c] == 1):
            break
        width += 1

    row_sel = slice(srow, srow + height)
    col_sel = slice(scol, scol + width)
    layout[:, :] = 0
    layout[row_sel, col_sel] = 1
    return layout, layout_score[row_sel, col_sel].mean()
def proposal_maxcontain(layout, layout_score, srow, scol):
    """Build the largest proposal: from the seed to the far corner of the ones.

    If the bounding box of every position where ``layout == 1`` is already
    completely filled, ``layout`` is returned unchanged together with the mean
    of ``layout_score`` over that box.

    Otherwise ``layout`` is overwritten IN PLACE so that exactly the rectangle
    from ``(srow, scol)`` to the maximum row/column holding a 1 (inclusive)
    is set.

    Returns:
        ``(layout, score)`` where ``score`` is the mean of ``layout_score``
        over the returned rectangle.
    """
    ones_r, ones_c = torch.where(layout == 1)
    top, bottom = ones_r.min(), ones_r.max() + 1
    left, right = ones_c.min(), ones_c.max() + 1

    # Already a solid rectangle: nothing to repair.
    if torch.all(layout[top:bottom, left:right] == 1):
        return layout, layout_score[top:bottom, left:right].mean()

    row_sel = slice(srow, bottom)
    col_sel = slice(scol, right)
    layout[:, :] = 0
    layout[row_sel, col_sel] = 1
    return layout, layout_score[row_sel, col_sel].mean()
def proposal_maxrowspan(layout, layout_score, srow, scol):
    """Build a proposal spanning the rows that are identical to the seed row.

    If the bounding box of every position where ``layout == 1`` is already
    completely filled, ``layout`` is returned unchanged together with the mean
    of ``layout_score`` over that box.

    Otherwise rows below ``srow`` are accepted while the WHOLE row equals row
    ``srow`` of ``layout``.  ``layout`` is then overwritten IN PLACE so that
    only the rectangle covering those rows, from ``scol`` to the maximum
    column holding a 1 (inclusive), is set.

    Returns:
        ``(layout, score)`` where ``score`` is the mean of ``layout_score``
        over the returned rectangle.
    """
    ones_r, ones_c = torch.where(layout == 1)
    top, bottom = ones_r.min(), ones_r.max() + 1
    left, right = ones_c.min(), ones_c.max() + 1

    # Already a solid rectangle: nothing to repair.
    if torch.all(layout[top:bottom, left:right] == 1):
        return layout, layout_score[top:bottom, left:right].mean()

    # The seed row always counts; extend while following rows match it exactly.
    height = 1
    for r in range(srow + 1, bottom):
        if not torch.all(layout[srow] == layout[r]):
            break
        height += 1

    row_sel = slice(srow, srow + height)
    col_sel = slice(scol, right)
    layout[:, :] = 0
    layout[row_sel, col_sel] = 1
    return layout, layout_score[row_sel, col_sel].mean()
def proposal_maxcolspan(layout, layout_score, srow, scol):
    """Build a proposal spanning the columns identical to the seed column.

    If the bounding box of every position where ``layout == 1`` is already
    completely filled, ``layout`` is returned unchanged together with the mean
    of ``layout_score`` over that box.

    Otherwise columns right of ``scol`` are accepted while the WHOLE column
    equals column ``scol`` of ``layout``.  ``layout`` is then overwritten IN
    PLACE so that only the rectangle covering those columns, from ``srow`` to
    the maximum row holding a 1 (inclusive), is set.

    Returns:
        ``(layout, score)`` where ``score`` is the mean of ``layout_score``
        over the returned rectangle.
    """
    ones_r, ones_c = torch.where(layout == 1)
    top, bottom = ones_r.min(), ones_r.max() + 1
    left, right = ones_c.min(), ones_c.max() + 1

    # Already a solid rectangle: nothing to repair.
    if torch.all(layout[top:bottom, left:right] == 1):
        return layout, layout_score[top:bottom, left:right].mean()

    # The seed column always counts; extend while following columns match it.
    width = 1
    for c in range(scol + 1, right):
        if not torch.all(layout[:, scol] == layout[:, c]):
            break
        width += 1

    row_sel = slice(srow, bottom)
    col_sel = slice(scol, scol + width)
    layout[:, :] = 0
    layout[row_sel, col_sel] = 1
    return layout, layout_score[row_sel, col_sel].mean()
def gen_proposals(layout_score, srow, scol, score_threshold=0.5):
    """Turn a score map into one or five rectangular proposals with log-scores.

    ``layout_score`` is thresholded with ``> score_threshold`` and the seed
    position ``(srow, scol)`` is forced on.

    * If the bounding box of the resulting mask is completely filled, a single
      proposal is returned: the mask with a leading dimension of 1 and the log
      of the mean score over the box (shape ``(1,)``).
    * Otherwise each of the five ``proposal_*`` strategies is run on its own
      copy of the mask; the five masks and the logs of their mean scores are
      stacked along a new leading dimension.

    Returns:
        ``(proposals, scores)`` as described above.
    """
    layout = layout_score > score_threshold
    layout[srow, scol] = 1

    ones_r, ones_c = torch.where(layout == 1)
    top, bottom = ones_r.min(), ones_r.max() + 1
    left, right = ones_c.min(), ones_c.max() + 1

    if torch.all(layout[top:bottom, left:right] == 1):
        box_score = layout_score[top:bottom, left:right].mean()
        return layout.unsqueeze(0), box_score.unsqueeze(0).log()

    strategies = (
        proposal_colspan,
        proposal_rowspan,
        proposal_maxcontain,
        proposal_maxrowspan,
        proposal_maxcolspan,
    )
    masks = []
    log_scores = []
    for strategy in strategies:
        # Every strategy overwrites the mask it is given, so hand each a copy.
        mask, mean_score = strategy(copy.deepcopy(layout), layout_score, srow, scol)
        masks.append(mask)
        log_scores.append(mean_score.log())

    return torch.stack(masks, dim=0), torch.stack(log_scores, dim=0)
def extend_segments(row_segments, rows_es, col_segments, cols_es, cells_spans, layouts, divide_labels):
    """Merge extra row/column segments into each sample and re-index its cells.

    For every batch item the original segments and the extra ones
    (``rows_es`` / ``cols_es``) are concatenated and sorted.  The position a
    segment lands on after sorting becomes its new index, and:

    * each cell span ``[l, t, r, b]`` (inclusive indices into the ORIGINAL
      segments) is translated to the new indices,
    * a layout grid of cell indices is painted from the translated spans
      (``-1`` where no cell covers a position),
    * the divide label is translated to the sorted position of the row index
      it referred to.

    NOTE(review): assumes the per-sample segment containers are Python lists,
    so ``+`` concatenates them — confirm against callers.

    Returns:
        ``(ext_row_segments, ext_col_segments, ext_cells_spans, layouts,
        divide_labels)`` where the layouts are padded/stacked by
        ``aligned_layouts`` using the dtype/device of ``layouts`` and the
        divide labels are a tensor on ``divide_labels.device``.
    """
    ext_row_segments = []
    ext_col_segments = []
    ext_cells_spans = []
    ext_layouts = []
    ext_divide_labels = []

    for batch_idx in range(len(row_segments)):
        merged_rows = row_segments[batch_idx] + rows_es[batch_idx]
        merged_cols = col_segments[batch_idx] + cols_es[batch_idx]

        # order[k] is the pre-sort index of the segment that ends up at slot k,
        # so order.index(i) is the new slot of original segment i.
        row_order = sorted(range(len(merged_rows)), key=merged_rows.__getitem__)
        col_order = sorted(range(len(merged_cols)), key=merged_cols.__getitem__)

        ext_divide_labels.append(row_order.index(divide_labels[batch_idx].item()))
        ext_row_segments.append([merged_rows[i] for i in row_order])
        ext_col_segments.append([merged_cols[i] for i in col_order])

        grid = np.full((len(merged_rows) - 1, len(merged_cols) - 1), -1)
        remapped_spans = []
        for cell_idx, (l, t, r, b) in enumerate(cells_spans[batch_idx]):
            # Spans are inclusive: the end boundary is segment r + 1 / b + 1.
            new_l = col_order.index(l)
            new_r = col_order.index(r + 1) - 1
            new_t = row_order.index(t)
            new_b = row_order.index(b + 1) - 1
            remapped_spans.append([new_l, new_t, new_r, new_b])
            grid[new_t:new_b + 1, new_l:new_r + 1] = cell_idx

        ext_cells_spans.append(remapped_spans)
        ext_layouts.append(grid)

    return (
        ext_row_segments,
        ext_col_segments,
        ext_cells_spans,
        aligned_layouts(ext_layouts, layouts),
        torch.tensor(ext_divide_labels).to(divide_labels.device),
    )
def aligned_layouts(layouts_list, layouts):
    """Pad per-sample layout grids to a common size and stack them.

    Args:
        layouts_list: sequence of 2-D NumPy arrays, one per batch item.
        layouts: reference tensor; only its ``dtype`` and ``device`` are used.

    Returns:
        A tensor of shape ``(batch, max_rows, max_cols)`` in which every grid
        is padded on the bottom and right with ``-1``.
    """
    dtype = layouts.dtype
    device = layouts.device
    target_rows = max([grid.shape[0] for grid in layouts_list])
    target_cols = max([grid.shape[1] for grid in layouts_list])

    padded = []
    for grid in layouts_list:
        grid_tensor = torch.from_numpy(grid).to(dtype=dtype, device=device)
        extra_rows = target_rows - grid.shape[0]
        extra_cols = target_cols - grid.shape[1]
        # F.pad takes the last dimension first: (left, right, top, bottom).
        padded.append(
            F.pad(grid_tensor, (0, extra_cols, 0, extra_rows), mode='constant', value=-1)
        )

    return torch.stack(padded, dim=0)
def parse_layout(spans, num_rows, num_cols):
    """Rasterise cell spans into a grid and renumber cells in reading order.

    Args:
        spans: iterable of ``(x1, y1, x2, y2, prob)`` with inclusive grid
            coordinates; ``prob`` is ignored.  Later spans overwrite earlier
            ones where they overlap.
        num_rows: number of grid rows.
        num_cols: number of grid columns.

    Returns:
        An integer array of shape ``(num_rows, num_cols)`` whose ids are
        assigned consecutively (0, 1, 2, ...) in the order cells are first met
        when scanning row by row.  Positions covered by no span all share one
        id, because they share the placeholder value ``-1`` before
        renumbering.
    """
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    for cell_count, (x1, y1, x2, y2, _prob) in enumerate(spans):
        layout[y1:y2 + 1, x1:x2 + 1] = cell_count

    # Map each raw id to the order in which it is first encountered.
    new_ids = {}
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            raw_id = int(layout[row_idx, col_idx])
            layout[row_idx, col_idx] = new_ids.setdefault(raw_id, len(new_ids))
    return layout
def parse_cells(layout, row_segments, col_segments):
    """Convert a cell-id grid into per-cell quadrilaterals in segment coordinates.

    Args:
        layout: 2-D integer array of cell ids; ids ``0 .. layout.max()`` are
            each expected to occur at least once (``np.min`` raises on an id
            with no positions).
        row_segments: boundary coordinate per row index; entry ``i + 1`` is
            used as the bottom edge of row ``i``.
        col_segments: boundary coordinate per column index; entry ``j + 1`` is
            used as the right edge of column ``j``.

    Returns:
        A list with one dict per cell id, each holding ``segmentation`` as
        ``[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]``.

    Raises:
        AssertionError: if a cell does not fill its bounding box.
    """
    cells = []
    num_cells = np.max(layout) + 1
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # y2/x2 are inclusive maxima, so the slice must extend one past them;
        # otherwise the last row and column of the box are never checked.
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)

        left = col_segments[x1]
        right = col_segments[x2 + 1]
        top = row_segments[y1]
        bottom = row_segments[y2 + 1]
        cells.append(
            dict(segmentation=[[[left, top], [right, top], [right, bottom], [left, bottom]]])
        )
    return cells
def process_layout(score, index):
    """Split the id map ``index`` into cells and number them in scan order.

    NOTE: a later definition of ``process_layout`` in this file (with extra
    keyword arguments) rebinds the name, so this version is shadowed once the
    module has been fully executed.

    Repeatedly takes the first position (row-major) not yet assigned and
    looks at all positions sharing its id in ``index``:

    * a single position becomes a cell on its own;
    * if the bounding box of those positions holds only that id, the whole
      box becomes one cell;
    * otherwise a rectangle is grown from the seed: right along the seed row
      while the id matches, then down while the full column range matches.
      This may overwrite positions that were assigned to an earlier cell.

    Args:
        score: 2-D tensor; only its shape is used (to bound the cell count).
        index: 2-D integer tensor of ids; not modified.

    Returns:
        Tensor like ``index`` holding the new cell numbers (``-1`` where
        nothing was assigned).
    """
    layout = torch.full_like(index, -1)
    assigned = torch.full_like(index, -1)  # -1 = still free, 1 = taken
    nrow, ncol = score.shape

    for cell_id in range(nrow * ncol):
        if assigned.min() != -1:
            break

        # First free position in row-major order.
        free_rows, free_cols = torch.where(assigned == assigned.min())
        seed_row = free_rows.min()
        seed_col = free_cols[free_rows == seed_row].min()

        group = index[seed_row, seed_col]
        h, w = torch.where(index == group)

        if h.shape[0] == 1 or w.shape[0] == 1:
            assigned[h, w] = 1
            layout[h, w] = cell_id
            continue

        top, bottom = h.min(), h.max()
        left, right = w.min(), w.max()
        if torch.all(index[top:bottom + 1, left:right + 1] == group):
            assigned[top:bottom + 1, left:right + 1] = 1
            layout[top:bottom + 1, left:right + 1] = cell_id
            continue

        # Irregular region: grow a rectangle from the seed, columns first.
        last_col = -1
        for c in range(seed_col, right + 1):
            if index[seed_row, c] != group:
                break
            assigned[seed_row, c] = 1
            layout[seed_row, c] = cell_id
            last_col = c

        for r in range(seed_row + 1, bottom + 1):
            if not torch.all(index[r, seed_col:last_col + 1] == group):
                break
            assigned[r, seed_col:last_col + 1] = 1
            layout[r, seed_col:last_col + 1] = cell_id

    return layout
def process_layout(score, index, use_score=False, is_merge=True, score_threshold=0.5):
    """Split the id map ``index`` into cells and number them in scan order.

    When ``use_score`` is true, positions with ``score < score_threshold``
    are first detached from their id: with ``is_merge`` they all receive one
    new shared id, otherwise each receives its own new id.  This is done on a
    private copy, so the caller's ``index`` tensor is left untouched.

    Then, repeatedly, the first position (row-major) not yet assigned is
    taken and all positions sharing its id are examined:

    * a single position becomes a cell on its own;
    * if the bounding box of those positions holds only that id, the whole
      box becomes one cell;
    * otherwise a rectangle is grown from the seed: right along the seed row
      while the id matches, then down while the full column range matches.
      This may overwrite positions that were assigned to an earlier cell.

    Args:
        score: 2-D tensor of per-position scores.
        index: 2-D integer tensor of ids, same shape as ``score``.
        use_score: whether to re-label low-score positions first.
        is_merge: share one new id among low-score positions (True) or give
            each its own id (False).  Only used when ``use_score`` is true.
        score_threshold: scores strictly below this count as low.

    Returns:
        Tensor like ``index`` holding the new cell numbers (``-1`` where
        nothing was assigned).
    """
    if use_score:
        # Work on a copy so the caller's tensor is not silently rewritten.
        index = index.clone()
        low_rows, low_cols = torch.where(score < score_threshold)
        next_id = int(index.max()) + 1
        if is_merge:
            index[low_rows, low_cols] = next_id
        else:
            index[low_rows, low_cols] = torch.arange(
                next_id, next_id + len(low_rows)
            ).to(index.device, index.dtype)

    layout = torch.full_like(index, -1)
    assigned = torch.full_like(index, -1)  # -1 = still free, 1 = taken
    nrow, ncol = score.shape
    max_cells = max(nrow * ncol, int(index.max()) + 1)

    for cell_id in range(max_cells):
        if assigned.min() != -1:
            break

        # First free position in row-major order.
        free_rows, free_cols = torch.where(assigned == assigned.min())
        seed_row = free_rows.min()
        seed_col = free_cols[free_rows == seed_row].min()

        group = index[seed_row, seed_col]
        h, w = torch.where(index == group)

        if h.shape[0] == 1 or w.shape[0] == 1:
            assigned[h, w] = 1
            layout[h, w] = cell_id
            continue

        top, bottom = h.min(), h.max()
        left, right = w.min(), w.max()
        if torch.all(index[top:bottom + 1, left:right + 1] == group):
            assigned[top:bottom + 1, left:right + 1] = 1
            layout[top:bottom + 1, left:right + 1] = cell_id
            continue

        # Irregular region: grow a rectangle from the seed, columns first.
        last_col = -1
        for c in range(seed_col, right + 1):
            if index[seed_row, c] != group:
                break
            assigned[seed_row, c] = 1
            layout[seed_row, c] = cell_id
            last_col = c

        for r in range(seed_row + 1, bottom + 1):
            if not torch.all(index[r, seed_col:last_col + 1] == group):
                break
            assigned[r, seed_col:last_col + 1] = 1
            layout[r, seed_col:last_col + 1] = cell_id

    return layout
def layout2spans(layout):
    """Extract the inclusive bounding span of every cell id in ``layout``.

    Cell ids ``0 .. rows * cols - 1`` are looked up in increasing order; ids
    that do not occur (and any negative ids) are skipped.

    Args:
        layout: 2-D NumPy array of cell ids.

    Returns:
        A one-element list wrapping the list of ``[x1, y1, x2, y2]`` spans
        (column/row indices, inclusive).

    Raises:
        AssertionError: if a cell does not fill its bounding box.
    """
    rows, cols = layout.shape[-2:]
    cells_span = []
    for cell_id in range(rows * cols):
        cell_positions = np.argwhere(layout == cell_id)
        if len(cell_positions) == 0:
            continue
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # y2/x2 are inclusive maxima, so the slice must extend one past them;
        # otherwise the last row and column of the box are never checked.
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return [cells_span]
def spatial_att_to_spans(spatial_att_weight_pred):
    """Convert stacked attention maps into cell spans.

    The maximum over the first dimension gives, per position, the best score
    and the index of the map it came from.  ``process_layout`` is run twice:
    first with score filtering (each low-score position gets its own id),
    then again on the resulting layout.  The final layout is moved to NumPy
    and handed to ``layout2spans``.

    Returns:
        Whatever ``layout2spans`` returns for the final layout.
    """
    best_score, best_index = spatial_att_weight_pred.max(dim=0)
    first_pass = process_layout(best_score, best_index, use_score=True, is_merge=False)
    second_pass = process_layout(best_score, first_pass)
    return layout2spans(second_pass.cpu().numpy())
def save_logitmap(filename, logit):
    """Write ``sigmoid(logit)`` scaled to 0-255 as a ``uint8`` image file."""
    scaled = logit.sigmoid() * 255
    pixels = scaled.cpu().numpy().astype('uint8')
    cv2.imwrite(filename, pixels)
def draw_spans(dst, src, spans, type):
    """Draw each span as a 1-pixel rectangle on image ``src`` and save to ``dst``.

    Args:
        dst: output image path.
        src: input image path, read with ``cv2.imread``.
        spans: iterable of ``(start, end)`` pairs.
        type: ``'col'`` draws rectangles spanning the full image height between
            x = start and x = end; ``'row'`` draws rectangles spanning the full
            width between y = start and y = end.  Any other value draws
            nothing.
    """
    image = cv2.imread(src)
    H, W, *_ = image.shape
    color = (0, 0, 255)

    for span in spans:
        if type == 'col':
            corner_a, corner_b = (span[0], 0), (span[1], H)
        elif type == 'row':
            corner_a, corner_b = (0, span[0]), (W, span[1])
        else:
            continue
        cv2.rectangle(image, corner_a, corner_b, color, thickness=1)

    cv2.imwrite(dst, image)