# kai-2054's picture
# Initial commit: add code
# cb0ad2d
# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
from typing import Iterable, List, Tuple
def load_chunks(chunk_path):
    """Load the chunks stored in a JSON file.

    The file must hold a JSON object with a "chunks" list. Each entry is
    turned into a Chunk with `Chunk.load_from_dict`.

    Args:
      chunk_path: path of the JSON file to read.

    Returns:
      A list of Chunk objects in file order. Zero-sized and empty-text
      chunks are kept; no filtering is applied here.
    """
    # JSON is UTF-8 by specification, so do not depend on the locale default.
    with open(chunk_path, 'r', encoding='utf-8') as f:
        chunks = json.load(f)['chunks']

    ret = []
    for chunk in chunks:
        pos = chunk["pos"]
        # Box asserts pos[0] <= pos[1]; repair a flipped pair in place
        # instead of failing on it, and report that it happened.
        if pos[1] < pos[0]:
            pos[0], pos[1] = pos[1], pos[0]
            print("Warning load illegal chunk.")
        ret.append(Chunk.load_from_dict(chunk))
    return ret
class Box(object):
    """Rectangle described by ``pos = (x1, x2, y1, y2)``.

    `pos` may be any indexable sequence of four numbers (a tuple, or a list
    as produced by JSON decoding). It is stored unchanged in `self.pos`.
    """

    def __init__(self, pos):
        """pos: (x1, x2, y1, y2)"""
        self.set_pos(pos)

    def set_pos(self, pos):
        """Store `pos` and refresh the derived coordinates and extents.

        Raises:
          AssertionError: if x1 > x2 or y1 > y2.
        """
        assert pos[0] <= pos[1]
        assert pos[2] <= pos[3]
        self.x1 = pos[0]
        self.x2 = pos[1]
        self.y1 = pos[2]
        self.y2 = pos[3]
        self.w = self.x2 - self.x1
        self.h = self.y2 - self.y1
        self.pos = pos

    def __lt__(self, other):
        """Order boxes lexicographically by their position."""
        # Normalize both sides so a list position and a tuple position
        # can still be compared with each other.
        return tuple(self.pos) < tuple(other.pos)

    def __contains__(self, other):
        """Return True when `other` lies fully inside this box (edges included)."""
        return (
            other.x1 >= self.x1 and other.x2 <= self.x2
            and other.y1 >= self.y1 and other.y2 <= self.y2
        )

    def __str__(self):
        # The % operator only unpacks tuples, so a list position must be
        # converted first or formatting raises TypeError.
        return 'Box(%d, %d, %d, %d)' % tuple(self.pos)

    def __hash__(self):
        # Lists are unhashable; hashing the tuple form gives the same value
        # as before for tuple positions and also works for list positions.
        return hash(tuple(self.pos))
class Chunk(Box):
    """A Box that also carries a text string, a size value and a cell id."""

    def __init__(self, text: str, pos: Tuple, size: float = 0.0, cell_id=None):
        """
        Args:
          text: text attached to the box.
          pos: (x1, x2, y1, y2), validated and stored by Box.
          size: numeric value stored as-is; it is not interpreted here.
          cell_id: optional identifier stored as-is.
        """
        super().__init__(pos)
        self.text = text
        self.size = size
        self.cell_id = cell_id

    def __str__(self):
        # %d truncates non-integer coordinates for display only.
        return 'Chunk(text="%s", pos=(%d, %d, %d, %d))' % (self.text, *self.pos)

    def __repr__(self):
        return self.__str__()

    def dump_as_json_obj(self):
        """Return a dict with the "text", "pos" and "cell_id" of this chunk.

        `size` is not part of the dumped object.
        """
        return {"text": self.text, "pos": self.pos, "cell_id": self.cell_id}

    @classmethod
    def load_from_dict(cls, d):
        """Build a chunk from a dict with "text", "pos" and optional "cell_id".

        The text is stripped of surrounding whitespace. `size` is not read
        from the dict, so the new chunk keeps the default size.

        Raises:
          AssertionError: if `d` is not a dict, "text" is not a string, or
            "pos" does not have exactly four entries.
        """
        # isinstance also accepts dict/str subclasses (e.g. OrderedDict).
        assert isinstance(d, dict)
        assert isinstance(d["text"], str)
        assert len(d["pos"]) == 4
        return cls(d["text"].strip(), d["pos"], cell_id=d.get("cell_id"))
class Table(object):
    """
    The output of table segmentation.
    With the Table object, we can get the set of cells
    and their corresponding text.

    Each cell is a Chunk whose position holds grid coordinates rather than
    geometry: (x1, x2, y1, y2) = (first row, last row, first column,
    last column), inclusive, indexing a row_n x col_n grid from 0.
    """

    def __init__(self, row_n, col_n, cells: "Iterable[Chunk]" = None, tid=""):
        """
        Args:
          row_n: number of rows of the grid.
          col_n: number of columns of the grid.
          cells: cells to add, in order; None means no cells.
          tid: identifier of the table, stored as-is.
        """
        # NOTE the Chunk object here represents the coordinate of
        #   the cell in the table.
        # NOTE x in cell object represents the row id
        self.tid = tid
        self.row_n = row_n
        self.col_n = col_n
        self.coo2cell_id = self._empty_grid()
        self.cells: "List[Chunk]" = []
        # The default is None, which cannot be iterated.
        if cells is not None:
            for cell in cells:
                self.add_cell(cell)

    def _empty_grid(self):
        """Return a fresh row_n x col_n grid with every slot set to -1 (no cell)."""
        return [[-1] * self.col_n for _ in range(self.row_n)]

    def reverse(self, is_col=True):
        """Mirror the table in place.

        Every cell is replaced by a new Chunk that keeps only the text and
        the mirrored coordinates; size and cell_id fall back to the Chunk
        defaults.

        Args:
          is_col: when True the row indices (x) are mirrored, otherwise the
            column indices (y) are mirrored.
        """
        old_cells = self.cells
        self.cells = []
        # Start from a clean grid so no slot keeps an id from the old layout.
        self.coo2cell_id = self._empty_grid()
        # Valid indices are 0..n-1, so index i mirrors to (n - 1) - i.
        last_row = self.row_n - 1
        last_col = self.col_n - 1
        for cell in old_cells:
            # Mirroring swaps the roles of the start and end index, which
            # keeps start <= end as Box requires.
            if is_col:
                pos = (last_row - cell.x2, last_row - cell.x1, cell.y1, cell.y2)
            else:
                pos = (cell.x1, cell.x2, last_col - cell.y2, last_col - cell.y1)
            self.add_cell(Chunk(cell.text, pos))

    def add_cell(self, cell: "Chunk"):
        """Append `cell` and mark every grid slot it spans with its index.

        A slot already owned by an earlier cell is overwritten.

        Raises:
          AssertionError: if the cell reaches past the last row or column.
        """
        # TODO Check conflicts of cells
        assert cell.y2 < self.col_n
        assert cell.x2 < self.row_n
        cell_index = len(self.cells)
        for x in range(cell.x1, cell.x2 + 1):
            for y in range(cell.y1, cell.y2 + 1):
                self.coo2cell_id[x][y] = cell_index
        self.cells.append(cell)

    def __getitem__(self, id_tuple):
        """Return the cell covering ``(row_id, col_id)``.

        Raises:
          AssertionError: if either index is not below the grid size.
        """
        row_id, col_id = id_tuple
        assert row_id < self.row_n and col_id < self.col_n
        # NOTE(review): a slot no cell covers holds -1, so this returns the
        #   last cell (or raises IndexError when the table has no cells).
        return self.cells[self.coo2cell_id[row_id][col_id]]