kai-2054's picture
Initial commit: add code
cb0ad2d
raw
history blame
17.3 kB
import os
import cv2
import json
import copy
import tqdm
import numpy as np
import fitz
from .extract_table_lines import extract_fg_bg_spans
def get_paths(root_dir, sub_names, names_path, exts, val=None):
    """Build per-sample path tuples from a names file.

    For each name listed in ``names_path`` (one per line), collect the paths
    ``root_dir/sub_name/name + ext`` for every aligned (sub_name, ext) pair.

    Args:
        root_dir: directory containing the ``sub_names`` subdirectories.
        sub_names: subdirectory names, aligned one-to-one with ``exts``.
        names_path: text file with one sample name per line.
        exts: file extensions (including the dot), one per subdirectory.
        val: optional cap on the number of names processed (``names[:val]``).

    Returns:
        list of lists: one entry per name, holding its path in each subdir.

    Raises:
        AssertionError: if a directory or an expected file is missing.
    """
    # Check the existence of directories
    assert os.path.isdir(root_dir), '"%s" is not dir.' % root_dir
    with open(names_path, "r") as f:
        names = f.readlines()
    names = [name.strip() for name in names]
    # TODO: sub_dirs redundancy
    sub_dirs = []
    for sub_name in sub_names:
        sub_dir = os.path.join(root_dir, sub_name)
        assert os.path.isdir(sub_dir), '"%s" is not dir.' % sub_dir
        sub_dirs.append(sub_dir)
    paths = []
    names = names[:val]
    for name in tqdm.tqdm(names):
        sub_paths = []
        for sub_dir, ext in zip(sub_dirs, exts):
            sub_path = os.path.join(sub_dir, name + ext)
            # BUGFIX: the old `assert cond, print(...)` produced an
            # AssertionError whose message was None (print returns None, the
            # text only went to stdout); use a real message expression.
            assert os.path.exists(sub_path), '%s is not exist' % sub_path
            sub_paths.append(sub_path)
        paths.append(sub_paths)
    return paths
def get_sub_paths(root_dir, sub_names, exts, val=None):
    """Build per-sample path tuples by listing the first subdirectory.

    Every file found in ``root_dir/sub_names[0]`` defines a sample; its
    counterparts in the remaining subdirectories are expected at
    ``root_dir/sub_name/name + ext`` for the aligned (sub_name, ext) pairs.

    Args:
        root_dir: directory containing the ``sub_names`` subdirectories.
        sub_names: subdirectory names; the first one is the listing anchor.
        exts: extensions aligned with ``sub_names`` (``exts[0]`` is unused —
            the anchor file keeps its listed name).
        val: optional cap on the number of files processed.

    Returns:
        list of lists: one entry per anchor file, holding its path in each
        subdirectory.

    Raises:
        AssertionError: if a directory or an expected file is missing.
    """
    # Check the existence of directories
    assert os.path.isdir(root_dir), '"%s" is not dir.' % root_dir
    # TODO: sub_dirs redundancy
    sub_dirs = []
    for sub_name in sub_names:
        sub_dir = os.path.join(root_dir, sub_name)
        assert os.path.isdir(sub_dir), '"%s" is not dir.' % sub_dir
        sub_dirs.append(sub_dir)
    paths = []
    # BUGFIX: os.listdir order is platform-dependent; sort it so the output
    # (and the effect of the `val` cap) is deterministic across runs.
    d = sorted(os.listdir(sub_dirs[0]))[:val]
    for file_name in tqdm.tqdm(d):
        sub_paths = [os.path.join(sub_dirs[0], file_name)]
        name = os.path.splitext(file_name)[0]
        for sub_name, ext in zip(sub_names[1:], exts[1:]):
            sub_path = os.path.join(root_dir, sub_name, name + ext)
            # BUGFIX: give the assert an explanatory message (was bare)
            assert os.path.exists(sub_path), '%s is not exist' % sub_path
            sub_paths.append(sub_path)
        paths.append(sub_paths)
    return paths
def cal_wer(label, rec):
    """Return the similarity ``1 - edit_distance(label, rec) / len(label)``.

    Despite the name this is an accuracy-style score, not a raw error rate:
    1.0 means ``rec`` reproduces ``label`` exactly, and the value can go
    negative when ``rec`` needs more edits than ``label`` has elements.

    Args:
        label: reference sequence (string or list of tokens).
        rec: recognized sequence to compare against the reference.

    Returns:
        float similarity score (<= 1.0).
    """
    # BUGFIX: guard the empty-reference case, which previously raised
    # ZeroDivisionError (reachable from match_cells/extract_ocr when every
    # ground-truth cell content is empty). An empty label matches only an
    # empty rec; any non-empty rec scores 0.0 ("no match").
    if len(label) == 0:
        return 1.0 if len(rec) == 0 else 0.0
    # Standard Levenshtein DP: dist_mat[i, j] is the edit distance between
    # label[:i] and rec[:j].
    dist_mat = np.zeros((len(label) + 1, len(rec) + 1), dtype='int32')
    dist_mat[0, :] = range(len(rec) + 1)
    dist_mat[:, 0] = range(len(label) + 1)
    for i in range(1, len(label) + 1):
        for j in range(1, len(rec) + 1):
            hit_score = dist_mat[i - 1, j - 1] + (label[i - 1] != rec[j - 1])
            ins_score = dist_mat[i, j - 1] + 1
            del_score = dist_mat[i - 1, j] + 1
            dist_mat[i, j] = min(hit_score, ins_score, del_score)
    dist = dist_mat[len(label), len(rec)]
    return 1 - dist / len(label)
def visualize(img_path, chunks, structures):
    """Draw each chunk's box and transcript onto the image at ``img_path``.

    ``structures`` is accepted for interface compatibility but not used.
    Returns the annotated image (BGR array).
    """
    canvas = cv2.imread(img_path)
    for chunk in chunks:
        # chunk['pos'] is ordered [x1, x2, y1, y2]
        x1, x2, y1, y2 = chunk["pos"]
        label = ''.join(chunk["text"])
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        cv2.rectangle(canvas, top_left, bottom_right, (0, 0, 255))
        text_origin = (int(x1), int(max(0, y1 - 1)))
        cv2.putText(canvas, label, text_origin, cv2.FONT_HERSHEY_COMPLEX, 0.25, (0, 0, 255), 1)
    return canvas
def visualize_table(img_path, output_dir, table):
    """Draw every cell's bbox and transcript on the image and save it under
    ``output_dir`` with the original file name."""
    canvas = cv2.imread(img_path)
    for cell in table['cells']:
        x1, y1, x2, y2 = cell['bbox']
        label = ''.join(cell['transcript'])
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255))
        text_origin = (int(x1), int(max(0, y1 - 1)))
        cv2.putText(canvas, label, text_origin, cv2.FONT_HERSHEY_COMPLEX, 0.25, (0, 0, 255), 1)
    out_path = os.path.join(output_dir, os.path.basename(img_path))
    cv2.imwrite(out_path, canvas)
def crop_pdf(path, output_dir, zoom_x = 2.0, zoom_y = 2.0, rotate=0, expand=10, y_fix=.0):
    '''
    Render a single-page PDF to PNG, crop the render to the bounding box of
    all text chunks (padded by `expand` pixels), and map chunk coordinates
    into the cropped image frame.

    path: [pdf_path, chunk_path]
    zoom_x / zoom_y: render scale factors for the PDF page
    rotate: rotation (degrees) applied when rendering
    expand: margin in pixels added around the chunks' bounding box
    y_fix: optional vertical correction, applied in PDF units (scaled by zoom_y)
    crop table region in pdf
    save pdf_name.png
    return list[x1, x2, y1, y2], [str]. note these are corresponding to crop pdf
    '''
    # load data
    with open(path[1], 'r') as f:
        chunks = json.load(f)['chunks']
    doc = fitz.open(path[0])
    pdf_name = os.path.splitext(os.path.basename(path[0]))[0]
    # NOTE(review): print() in an assert message evaluates to None, so the
    # AssertionError itself carries no message (the text only goes to stdout).
    assert doc.pageCount == 1, print(pdf_name, ' has more than 1 page!')
    # transfer pdf to img
    # NOTE(review): pageCount/preRotate/getPixmap/writePNG are the legacy
    # camelCase PyMuPDF API, removed in recent releases — this pins an old fitz.
    trans = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate)
    pm = doc[0].getPixmap(matrix=trans, alpha=False)
    pm.writePNG(os.path.join(output_dir, '%s.png' % pdf_name))
    # crop table region
    pdf_img = cv2.imread(os.path.join(output_dir, '%s.png' % pdf_name))
    h, w, *_ = pdf_img.shape
    positions = []
    transcripts = []
    for chunk in chunks:
        # reorder chunk['pos'] into [x1, x2, y2, y1]; the y values are flipped
        # to image coordinates (origin top-left) below
        positions.append([chunk['pos'][0], chunk['pos'][1], chunk['pos'][3], chunk['pos'][2]]) # x1, x2, y2, y1
        transcripts.append(chunk["text"])
    # the last chunk transcrip is repeated
    # (drop the trailing repeated element of the final transcript)
    transcripts[-1] = transcripts[-1][:-1]
    positions = np.array(positions)
    # scale x to the rendered resolution; flip y (PDF origin is bottom-left,
    # image origin is top-left) while scaling
    positions[:, :2] *= zoom_x
    positions[:, 2:] = h - positions[:, 2:] * zoom_y
    # tight bounding box over all chunks, padded by `expand`, clamped to image
    x_min = int(max(0, positions[:, :2].min() - expand))
    y_min = int(max(0, positions[:, 2:].min() - expand))
    x_max = int(min(w, positions[:, :2].max() + expand))
    y_max = int(min(h, positions[:, 2:].max() + expand))
    img_crop = pdf_img[y_min:y_max, x_min:x_max]
    # overwrite the full-page render with the cropped region
    cv2.imwrite(os.path.join(output_dir, '%s.png' % pdf_name), img_crop)
    # shift chunk coordinates into the crop frame
    positions[:, :2] = np.clip(positions[:, :2] - x_min, 0, w)
    # vertical correction on column 2 only, applied before the y shift
    # NOTE(review): only one of the two y columns is adjusted — confirm intended
    positions[:, 2] -= y_fix * zoom_y
    positions[:, 2:] = np.clip(positions[:, 2:] - y_min, 0, h)
    return positions, transcripts
def crop_cells(img_path, output_dir, info, expand=10):
    # Crop the image to the union of all cell bboxes (padded by `expand`) and
    # shift every cell's bbox and segmentation into the cropped frame.
    # Mutates `info['cells']` in place and writes the crop to `output_dir`.
    cells = info['cells']
    img = cv2.imread(img_path)
    h, w, *_ = img.shape
    # cells produced by match_cells may lack 'bbox' (empty-transcript cells)
    bboxes = [cell['bbox'] for cell in cells if 'bbox' in cell.keys()]
    bboxes = np.array(bboxes)
    # padded bounding box over all cells, clamped to the image
    x_min = int(max(bboxes[:, 0].min() - expand, 0))
    y_min = int(max(bboxes[:, 1].min() - expand, 0))
    x_max = int(min(bboxes[:, 2].max() + expand, w))
    y_max = int(min(bboxes[:, 3].max() + expand, h))
    cv2.imwrite(os.path.join(output_dir, os.path.splitext(os.path.basename(img_path))[0] + '.png'), img[y_min:y_max, x_min:x_max])
    # refine cell bbox
    new_cells = []
    for cell in cells:
        if 'bbox' not in cell.keys():
            # cells without a bbox pass through unchanged
            new_cells.append(cell)
        else:
            # shift the bbox into the crop frame, clamping at 0
            cell['bbox'][0] = max(0, cell['bbox'][0] - x_min)
            cell['bbox'][1] = max(0, cell['bbox'][1] - y_min)
            cell['bbox'][2] = max(0, cell['bbox'][2] - x_min)
            cell['bbox'][3] = max(0, cell['bbox'][3] - y_min)
            # NOTE(review): segmentation points are shifted but, unlike the
            # bbox, NOT clamped at 0 — confirm this asymmetry is intended.
            segmentation = cell['segmentation']
            cell['segmentation'] = [[[pt[0] - x_min, pt[1] - y_min] for pt in contour] for contour in segmentation]
            new_cells.append(cell)
    info['cells'] = new_cells
def visualize_ocr(img_path, output_dir, positions, transcripts):
    """Overlay OCR boxes and transcripts on the image and save the result as
    '<name>_ocr.png' inside ``output_dir``.

    ``positions`` rows are ordered [x1, x2, y1, y2].
    """
    canvas = cv2.imread(img_path)
    for box, text in zip(positions, transcripts):
        x1, x2, y1, y2 = box
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255))
        cv2.putText(canvas, text, (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,0,0), 1)
    stem = os.path.splitext(os.path.basename(img_path))[0]
    cv2.imwrite(os.path.join(output_dir, stem + '_ocr.png'), canvas)
def cal_cell_spans(table):
    """Compute each cell's grid span from the layout matrix.

    Args:
        table: dict with 'layout' (2-D array whose entries are cell ids) and
            'cells' (list; only its length is used here).

    Returns:
        list of [start_col, start_row, end_col, end_row] per cell id, with
        inclusive row/column indices.

    Raises:
        AssertionError: if a cell's occupied positions do not form a solid
            rectangle in the layout.
    """
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        # all (row, col) positions occupied by this cell
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # BUGFIX: verify the FULL inclusive rectangle belongs to this cell.
        # The previous slice layout[y1:y2, x1:x2] excluded the last row and
        # column, so non-rectangular/overlapping cells could slip through.
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span
def visualize_cell(img_path, output_dir, table):
    # Draw the resolved table grid (cell rectangles derived from row/column
    # spans) in red, plus per-cell OCR bboxes (green) and transcripts (blue),
    # then save the annotated image as '<name>.png' under `output_dir`.
    def spans2lines(spans):
        # Convert a list of (start, end) spans into grid-line coordinates:
        # the outer lines are the outer edges of the first/last span, the
        # interior lines are the midpoints of the interior spans.
        lines = []
        lines.append(spans[0][0])
        for span in spans[1:-1]:
            t1, t2 = span
            lines.append(int((t1 + t2) / 2))
        lines.append(spans[-1][-1])
        return lines
    img = cv2.imread(img_path)
    # draw table lines
    # img.shape is (h, w, c); shape[::-1][-2:] passes (w, h) to the helper
    rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span = extract_fg_bg_spans(table, img.shape[::-1][-2:])
    row_lines = spans2lines(rows_fg_span)
    col_lines = spans2lines(cols_fg_span)
    for span in cells_span:
        # span is in grid indices; a cell ending at grid column x2 is bounded
        # by grid line x2 + 1, hence the +1 on the far edges
        x1, y1, x2, y2 = span
        cv2.rectangle(img, (int(col_lines[x1]), int(row_lines[y1])), (int(col_lines[x2 + 1]), int(row_lines[y2 + 1])), (0, 0, 255), 2)
    # draw ocr results
    for cell in table['cells']:
        if 'bbox' not in cell.keys():
            # unmatched (empty) cells have no bbox to draw
            continue
        x1, y1, x2, y2 = cell['bbox']
        transcript = cell['transcript']
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
        cv2.putText(img, ''.join(transcript), (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,0,0), 1)
    cv2.imwrite(os.path.join(output_dir, os.path.splitext(os.path.basename(img_path))[0] + '.png'), img)
def match_cells(path, positions, transcripts, k=16, start=0.333, stop=0.1, stop_percent=0.3, gap=0.25):
    '''
    Greedily align ground-truth table cells with OCR chunks: cells are read
    left-to-right, top-to-bottom, and for each non-empty cell consecutive OCR
    transcripts are merged while the cal_wer score keeps improving; the merged
    chunks' boxes form the cell's bbox.

    path: [pdf_path, chunk_path, structure_path] — only structure_path is read here
    positions: per-OCR-chunk boxes [x1, x2, y1, y2]
    transcripts: per-OCR-chunk strings, in reading order
    k: lookahead window of OCR chunks considered per cell / per anchor
    start: minimum cal_wer score for an OCR chunk to anchor the table ends
    stop: score below which a single-chunk match is counted as an error match
    stop_percent: maximum tolerated fraction of error matches
    gap: minimum score improvement required to merge one more chunk
    return dict(
        'layout': np.array (n_row x n_col) of cell indices
        'cells': list of dicts with 'transcript' and, when matched,
            'bbox' [x1, y1, x2, y2] and 'segmentation'
        'head_rows': []
        'body_rows': []
    )
    '''
    # load data
    with open(path[2], 'r') as f:
        cells = json.load(f)['cells']
    # first sort cells from left to right, from top to down
    cells_pos = [] # xl1, yl1, xl2, yl2
    contents = []
    for cell in cells:
        cells_pos.append([cell['start_col'], cell['start_row'], cell['end_col'], cell['end_row']])
        contents.append(' '.join(cell['content']))
    # sorted cells from left to right, from top to down
    # (row dominates through the 1e6 weight; column breaks ties)
    sorted_idx = sorted(list(range(len(cells_pos))), key=lambda idx: cells_pos[idx][0] + 1e6 * cells_pos[idx][1])
    cells_pos = [cells_pos[idx] for idx in sorted_idx]
    contents = [contents[idx] for idx in sorted_idx]
    # layout
    n_row = np.array(cells_pos)[:, 3].max() + 1
    n_col = np.array(cells_pos)[:, 2].max() + 1
    layout = np.full((n_row, n_col), -1)
    # head_rows & body_rows: the header is every row spanned by cells that
    # start in row 0; all remaining rows form the body
    head_rows = list(range((np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 3] - np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 1]).max() + 1))
    body_rows = list(range((np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 3] - np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 1]).max() + 1, n_row))
    lt = [-1, -1]
    cells = []
    valid_idx = list(range(len(transcripts)))
    # init start/end index of ocr results
    # anchor the first OCR chunk: the first of the leading k transcripts whose
    # score against the first non-empty cell content exceeds `start`
    start_content = ''
    for content in contents:
        if len(content) > 0:
            start_content = content
            break
    try:
        start_index = [cal_wer(start_content, transcript) > start for transcript in transcripts[:k]].index(True)
    # NOTE(review): bare except also swallows errors unrelated to the
    # ValueError raised by .index when no transcript qualifies
    except:
        start_index = 0
    # anchor the last OCR chunk symmetrically, scanning from the back
    end_content = ''
    for content in contents[::-1]:
        if len(content) > 0:
            end_content = content
            break
    try:
        end_index = [cal_wer(end_content, transcript) > start for transcript in transcripts[::-1][:k]].index(True)
    except:
        end_index = 0
    # keep only the OCR chunks between the two anchors
    valid_idx = valid_idx[start_index:] if end_index == 0 else valid_idx[start_index: -end_index]
    # NOTE(review): print() in an assert message evaluates to None, so these
    # AssertionErrors carry no message (the text only goes to stdout)
    assert len(contents) >= len(valid_idx), print('OCR Results Have Error')
    stop_counts = 0
    for index, (cell_pos, content) in enumerate(zip(cells_pos, contents)):
        # confirm the cell pos is increase
        assert cell_pos[0] > lt[0] or cell_pos[1] > lt[1], print('Sorted Cells Have Error')
        lt = cell_pos[:2]
        xl1, yl1, xl2, yl2 = cell_pos
        # stamp this cell's index onto its (inclusive) span in the layout
        layout[yl1:yl2+1, xl1:xl2+1] = index
        if len(content) == 0:
            # empty ground-truth cell: consumes no OCR chunk, gets no bbox
            cells.append(dict(transcript=[]))
        else:
            is_completed = False
            # seed the match with the next unconsumed OCR chunk
            bboxes_list = [positions[valid_idx[0]]]
            transcripts_list = [transcripts[valid_idx[0]]]
            valid_idx.pop(0)
            wer_last = cal_wer(content, ' '.join(transcripts_list))
            if wer_last < stop:
                # the seed chunk barely matches: accept it as-is but count it
                # as an error match
                bboxes_list = np.array(bboxes_list)
                # position columns are [x1, x2, y1, y2]: :2 are x, 2: are y
                x1 = int(bboxes_list[:, :2].min())
                x2 = int(bboxes_list[:, :2].max())
                y1 = int(bboxes_list[:, 2:].min())
                y2 = int(bboxes_list[:, 2:].max())
                cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))
                stop_counts += 1
                continue
            # greedily merge up to k lookahead chunks while the score improves
            for idx in valid_idx[:k]:
                if content == ' '.join(transcripts_list):
                    # exact match reached: finalize with the bbox union
                    bboxes_list = np.array(bboxes_list)
                    x1 = int(bboxes_list[:, :2].min())
                    x2 = int(bboxes_list[:, :2].max())
                    y1 = int(bboxes_list[:, 2:].min())
                    y2 = int(bboxes_list[:, 2:].max())
                    cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))
                    is_completed = True
                    break
                else:
                    # tentatively append the candidate chunk and re-score
                    cur_trans = copy.deepcopy(transcripts_list)
                    cur_trans.append(transcripts[idx])
                    wer = cal_wer(content, ' '.join(cur_trans))
                    # if add new str, and wer is not increase a lot, it should not be added in
                    if wer < wer_last + gap:
                        continue
                    else:
                        transcripts_list.append(transcripts[idx])
                        bboxes_list.append(positions[idx])
                        valid_idx.pop(valid_idx.index(idx))
                        if wer == 1.0:
                            # perfect score: nothing more can improve the match
                            break
                        else:
                            wer_last = wer
            if not is_completed:
                # lookahead exhausted without an exact match: keep what merged
                bboxes_list = np.array(bboxes_list)
                x1 = int(bboxes_list[:, :2].min())
                x2 = int(bboxes_list[:, :2].max())
                y1 = int(bboxes_list[:, 2:].min())
                y2 = int(bboxes_list[:, 2:].max())
                cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))
    assert stop_counts / len(contents) < stop_percent, print('This Table Has Many Error Match with OCR Results')
    assert layout.min() == 0, print('This Table Layout is not Completely Resolved')
    return dict(
        layout=layout,
        cells=cells,
        head_rows=head_rows,
        body_rows=body_rows,
    )
def extract_ocr(path, positions, transcripts, k=16, start=0.333):
    '''
    Trim leading/trailing OCR chunks that do not belong to the table (scored
    against the first/last non-empty ground-truth cell content) and package
    the remaining chunks as cell annotations.

    path: [pdf_path, chunk_path, structure_path] — only structure_path is read here
    positions: per-OCR-chunk boxes [x1, x2, y1, y2] (numpy rows)
    transcripts: per-OCR-chunk strings, in reading order
    k: how many chunks from each end are scanned for an anchor
    start: minimum cal_wer score for a chunk to anchor the table
    return dict(
        'cells': list of dicts with
            'bbox': [x1, y1, x2, y2]
            'transcript': list of characters
            'segmentation': rectangle contour built from the bbox
    )
    '''
    # load data
    with open(path[2], 'r') as f:
        raw_cells = json.load(f)['cells']
    # sort cell contents left-to-right, top-to-bottom (the row dominates
    # through the 1e6 weight; the column breaks ties) — same order match_cells uses
    cells_pos = []  # xl1, yl1, xl2, yl2
    contents = []
    for cell in raw_cells:
        cells_pos.append([cell['start_col'], cell['start_row'], cell['end_col'], cell['end_row']])
        contents.append(' '.join(cell['content']))
    sorted_idx = sorted(range(len(cells_pos)), key=lambda idx: cells_pos[idx][0] + 1e6 * cells_pos[idx][1])
    contents = [contents[idx] for idx in sorted_idx]
    valid_idx = list(range(len(transcripts)))
    # anchor the first/last OCR chunk on the first/last non-empty cell content
    # (the duplicated start/end logic now shares one helper)
    start_index = _find_anchor_index(contents, transcripts, k, start)
    end_index = _find_anchor_index(contents[::-1], transcripts[::-1], k, start)
    valid_idx = valid_idx[start_index:] if end_index == 0 else valid_idx[start_index: -end_index]
    cells = []
    for idx in valid_idx:
        x1, x2, y1, y2 = positions[idx].astype('int').tolist()
        cells.append(dict(transcript=list(transcripts[idx]), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))
    return dict(
        cells=cells
    )


def _find_anchor_index(contents, transcripts, k, start):
    # Index of the first of the leading k transcripts whose cal_wer score
    # against the first non-empty content exceeds `start`; 0 when none does.
    anchor = ''
    for content in contents:
        if len(content) > 0:
            anchor = content
            break
    try:
        return [cal_wer(anchor, transcript) > start for transcript in transcripts[:k]].index(True)
    # BUGFIX: the original bare `except:` hid every error. Catch only what the
    # lookup can legitimately raise: ValueError from .index when no transcript
    # qualifies, and ZeroDivisionError from cal_wer on an empty reference
    # (when every ground-truth cell content is empty).
    except (ValueError, ZeroDivisionError):
        return 0
def refine_table(table, img_path, output_dir, expand=10):
    """Crop the table image tightly around all cell bboxes (padded by
    ``expand``) and shift the matched cells' bboxes into the cropped frame.

    Mutates the bbox of every cell that has one, writes the cropped image to
    ``output_dir`` under the original file name, records its path in
    ``table['image_path']``, and returns the mutated ``table``.
    """
    cells = table['cells']
    # Only cells that were matched to OCR results carry a 'bbox'
    # (match_cells emits bbox-less dicts for empty cells).
    boxed_cells = [cell for cell in cells if 'bbox' in cell.keys()]
    bboxes = np.array([cell['bbox'] for cell in boxed_cells])
    img = cv2.imread(img_path)
    h, w, *_ = img.shape
    # padded bounding box over all cells, clamped to the image
    x1 = int(max(0, bboxes[:, 0].min() - expand))
    y1 = int(max(0, bboxes[:, 1].min() - expand))
    x2 = int(min(w, bboxes[:, 2].max() + expand))
    y2 = int(min(h, bboxes[:, 3].max() + expand))
    # refine cells: shift coordinates into the crop frame (clamped at 0)
    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2] - x1, 0, 1e6)
    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2] - y1, 0, 1e6)
    # BUGFIX: pair the shifted bboxes with the cells that actually have one.
    # The old `zip(cells, bboxes)` walked ALL cells against the filtered bbox
    # array, so boxes were assigned to the wrong cells whenever any cell
    # lacked a bbox.
    for cell, bbox in zip(boxed_cells, bboxes.tolist()):
        cell['bbox'] = bbox
    img = img[y1:y2, x1:x2]
    out_path = os.path.join(output_dir, os.path.basename(img_path))
    cv2.imwrite(out_path, img)
    table['image_path'] = out_path
    return table