# File size: 3,625 Bytes (cb0ad2d)
# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
from typing import Iterable, List, Tuple
def load_chunks(chunk_path):
    """Load the chunks stored in a JSON file as a list of ``Chunk`` objects.

    The file must contain a JSON object whose ``"chunks"`` entry is a list
    of dicts accepted by ``Chunk.load_from_dict``.  When a coordinate pair
    of ``pos`` (``pos[0]``/``pos[1]`` or ``pos[2]``/``pos[3]``) is stored in
    descending order, the pair is swapped in place and a warning is printed,
    so that the ordering assertions made when the chunk is built hold.

    Args:
        chunk_path: path of the JSON file to read.

    Returns:
        A list with one ``Chunk`` per entry of ``"chunks"``, in file order.
    """
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's default locale encoding.
    with open(chunk_path, 'r', encoding='utf-8') as f:
        chunks = json.load(f)['chunks']

    ret = []
    for chunk in chunks:
        pos = chunk["pos"]
        # Repair both the (x1, x2) and the (y1, y2) pair.
        for lo, hi in ((0, 1), (2, 3)):
            if pos[hi] < pos[lo]:
                pos[lo], pos[hi] = pos[hi], pos[lo]
                print("Warning load illegal chunk.")
        # NOTE chunks with zero width/height or empty text are currently
        # kept; no filtering is applied here.
        ret.append(Chunk.load_from_dict(chunk))
    return ret
class Box(object):
    """Axis-aligned box described by ``pos = (x1, x2, y1, y2)``.

    ``pos`` may be any 4-item sequence (a tuple, or a list as produced by
    ``json.load``).  It is stored unchanged in ``self.pos``; ordering,
    hashing and string formatting work on a tuple copy so that they behave
    the same for lists and tuples.

    No ``__eq__`` is defined, so equality stays identity-based even though
    the hash is derived from ``pos``.
    """

    def __init__(self, pos):
        """pos: (x1, x2, y1, y2)"""
        self.set_pos(pos)

    def set_pos(self, pos):
        """Set the coordinates and the derived width ``w`` and height ``h``.

        Raises:
            AssertionError: if ``x1 > x2`` or ``y1 > y2``.
        """
        assert pos[0] <= pos[1]
        assert pos[2] <= pos[3]
        self.x1 = pos[0]
        self.x2 = pos[1]
        self.y1 = pos[2]
        self.y2 = pos[3]
        self.w = self.x2 - self.x1
        self.h = self.y2 - self.y1
        self.pos = pos

    def __lt__(self, other):
        """Order boxes lexicographically by ``(x1, x2, y1, y2)``."""
        # Comparing tuples avoids the NotImplemented result that a direct
        # list-vs-tuple ``__lt__`` call would return.
        return tuple(self.pos) < tuple(other.pos)

    def __contains__(self, other):
        """Return True when ``other`` lies entirely inside this box."""
        return (other.x1 >= self.x1 and other.x2 <= self.x2 and
                other.y1 >= self.y1 and other.y2 <= self.y2)

    def __str__(self):
        # The % operator needs a tuple to fill several placeholders.
        return 'Box(%d, %d, %d, %d)' % tuple(self.pos)

    def __hash__(self):
        # Lists are unhashable, so always hash a tuple copy.
        return hash(tuple(self.pos))
class Chunk(Box):
    """A box that carries a text, a size and an optional cell id."""

    def __init__(self, text:str, pos:Tuple, size:float=0.0, cell_id=None):
        super().__init__(pos)
        self.cell_id = cell_id
        self.size = size
        self.text = text

    def __str__(self):
        template = 'Chunk(text="%s", pos=(%d, %d, %d, %d))'
        return template % (self.text, *self.pos)

    def __repr__(self):
        return self.__str__()

    def dump_as_json_obj(self):
        """Return the text, position and cell id as a plain dict."""
        json_obj = {}
        json_obj["text"] = self.text
        json_obj["pos"] = self.pos
        json_obj["cell_id"] = self.cell_id
        return json_obj

    @classmethod
    def load_from_dict(cls, d):
        """Build a chunk from a dict with "text", "pos" and optional "cell_id".

        The text is stripped of surrounding whitespace; a missing
        "cell_id" becomes None.
        """
        assert type(d) is dict
        text = d["text"]
        assert type(text) is str
        pos = d["pos"]
        assert len(pos) == 4
        return cls(text.strip(), pos, cell_id=d.get("cell_id"))
class Table(object):
    """
    The output of table segmentation.
    With the Table object, we can get the set of cells
    and their corresponding text.

    Each cell is a Chunk whose coordinates are grid indices rather than
    positions on the page: ``x1..x2`` is the inclusive range of rows and
    ``y1..y2`` the inclusive range of columns covered by the cell.
    ``coo2cell_id[row][col]`` holds the index into ``cells`` of the cell
    covering that grid slot, or -1 when no cell covers it.
    """

    def __init__(self, row_n, col_n, cells:Iterable["Chunk"]=None, tid=""):
        """Create a ``row_n`` x ``col_n`` table holding ``cells``.

        Args:
            row_n: number of rows of the grid.
            col_n: number of columns of the grid.
            cells: cells to add, in order; ``None`` means no cells.
            tid: identifier stored on the table.
        """
        # NOTE the Chunk object here represents the coordinate of
        # the cell in the table.
        # NOTE x in cell object represents the row id
        self.tid = tid
        self.row_n = row_n
        self.col_n = col_n
        self.coo2cell_id = self._empty_grid()
        self.cells:List["Chunk"] = []
        # The default (None) must not be iterated.
        if cells is not None:
            for cell in cells:
                self.add_cell(cell)

    def _empty_grid(self):
        """Return a ``row_n`` x ``col_n`` grid filled with -1 ("no cell")."""
        return [[-1] * self.col_n for _ in range(self.row_n)]

    def reverse(self, is_col=True):
        """Mirror the table in place along one axis.

        With ``is_col`` true the row indices (x) are mirrored, otherwise the
        column indices (y) are mirrored.  Every cell is rebuilt as a new
        Chunk holding only the text and the mirrored coordinates; ``size``
        and ``cell_id`` of the old cells are not carried over.
        """
        old_cells = self.cells
        self.cells = []
        # Start from an empty grid so no slot keeps an id of the old layout.
        self.coo2cell_id = self._empty_grid()
        # Indices are 0-based, so index k mirrors to (n - 1 - k); the two
        # bounds swap roles to keep the lower bound first.
        last_row = self.row_n - 1
        last_col = self.col_n - 1
        for cell in old_cells:
            if is_col:
                pos = (last_row - cell.x2, last_row - cell.x1,
                       cell.y1, cell.y2)
            else:
                pos = (cell.x1, cell.x2,
                       last_col - cell.y2, last_col - cell.y1)
            self.add_cell(Chunk(cell.text, pos))

    def add_cell(self, cell:"Chunk"):
        """Append ``cell`` and map every grid slot it covers to it.

        A slot already mapped to another cell is overwritten.

        Raises:
            AssertionError: if the cell extends past the last row or column.
        """
        # TODO Check conflicts of cells
        assert cell.y2 < self.col_n
        assert cell.x2 < self.row_n
        cell_id = len(self.cells)
        for x in range(cell.x1, cell.x2 + 1):
            for y in range(cell.y1, cell.y2 + 1):
                self.coo2cell_id[x][y] = cell_id
        self.cells.append(cell)

    def __getitem__(self, id_tuple):
        """Return the cell covering ``(row_id, col_id)``, or None if empty."""
        row_id, col_id = id_tuple
        assert row_id < self.row_n and col_id < self.col_n
        cell_id = self.coo2cell_id[row_id][col_id]
        # -1 marks an uncovered slot; indexing cells with it would silently
        # return the last cell.
        if cell_id < 0:
            return None
        return self.cells[cell_id]