File size: 3,625 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json

from typing import Iterable, List, Tuple


def load_chunks(chunk_path):
  """Load every chunk stored under the ``"chunks"`` key of a JSON file.

  Args:
    chunk_path: path of the JSON file; its top-level object must have a
      ``"chunks"`` list whose items are dicts accepted by
      ``Chunk.load_from_dict``.

  Returns:
    A list of ``Chunk`` objects, in file order.

  A chunk whose first two ``pos`` values are inverted (``pos[1] < pos[0]``)
  is repaired by swapping them, and a warning is printed.
  """
  # JSON text is UTF-8; do not depend on the platform's default encoding.
  with open(chunk_path, 'r', encoding='utf-8') as f:
    chunks = json.load(f)['chunks']

  ret = []
  for chunk in chunks:
    pos = chunk["pos"]
    # NOTE only the (pos[0], pos[1]) pair is normalized here; an inverted
    # (pos[2], pos[3]) pair is passed through unchanged.
    if pos[1] < pos[0]:
      pos[0], pos[1] = pos[1], pos[0]
      print("Warning load illegal chunk.")
    # NOTE zero-size / empty-text chunks are currently NOT filtered out;
    # every chunk in the file is returned.
    ret.append(Chunk.load_from_dict(chunk))
  return ret


class Box(object):
  """Rectangle described by ``pos = (x1, x2, y1, y2)``.

  Attributes:
    x1, x2, y1, y2: the four values of ``pos``.
    w: ``x2 - x1``.
    h: ``y2 - y1``.
    pos: the sequence passed in, stored as given (tuple or list).

  ``pos`` may be any 4-item sequence. String formatting, hashing and
  ordering convert it to a tuple first, so a list (e.g. one decoded from
  JSON) behaves exactly like the equivalent tuple.
  """

  def __init__(self, pos):
    """pos: (x1, x2, y1, y2)"""
    self.set_pos(pos)

  def set_pos(self, pos):
    """Set the coordinates and recompute the derived width and height.

    Raises:
      AssertionError: if ``x1 > x2`` or ``y1 > y2``.
    """
    assert pos[0] <= pos[1]
    assert pos[2] <= pos[3]
    self.x1 = pos[0]
    self.x2 = pos[1]
    self.y1 = pos[2]
    self.y2 = pos[3]
    self.w = self.x2 - self.x1
    self.h = self.y2 - self.y1
    self.pos = pos

  def __lt__(self, other):
    """Order boxes lexicographically by ``(x1, x2, y1, y2)``."""
    # Compare as tuples so a list-backed box and a tuple-backed box can be
    # ordered against each other.
    return tuple(self.pos) < tuple(other.pos)

  def __contains__(self, other):
    """Return True if ``other`` lies entirely inside this box (edges included)."""
    return (other.x1 >= self.x1 and other.x2 <= self.x2 and
            other.y1 >= self.y1 and other.y2 <= self.y2)

  def __str__(self):
    # ``%`` needs a tuple on its right-hand side; a list would be treated
    # as a single argument and raise TypeError.
    return 'Box(%d, %d, %d, %d)' % tuple(self.pos)

  def __hash__(self):
    # Lists are unhashable, so hash the tuple form of the coordinates.
    # NOTE ``__eq__`` is not overridden: boxes with the same coordinates
    # share a hash value but still compare equal only by identity.
    return hash(tuple(self.pos))


class Chunk(Box):
  """A piece of text together with the box it occupies.

  Attributes (in addition to those set by ``Box``):
    text: the chunk's text.
    size: a float, ``0.0`` unless given.
    cell_id: optional identifier, ``None`` unless given.
  """

  def __init__(self, text:str, pos:Tuple, size:float=0.0, cell_id=None):
    super().__init__(pos)
    self.text, self.size, self.cell_id = text, size, cell_id

  def __str__(self):
    template = 'Chunk(text="%s", pos=(%d, %d, %d, %d))'
    return template % ((self.text,) + tuple(self.pos))

  def __repr__(self):
    # The repr is deliberately the same text as ``str()``.
    return self.__str__()

  def dump_as_json_obj(self):
    """Return a dict with the chunk's ``text``, ``pos`` and ``cell_id``.

    ``size`` is not part of the dumped object.
    """
    obj = {}
    obj["text"] = self.text
    obj["pos"] = self.pos
    obj["cell_id"] = self.cell_id
    return obj

  @classmethod
  def load_from_dict(cls, d):
    """Build a chunk from a dict with ``text``, ``pos`` and optional ``cell_id``.

    The text is stripped of surrounding whitespace; ``cell_id`` is ``None``
    when the key is absent.

    Raises:
      AssertionError: if ``d`` is not exactly a ``dict``, ``d["text"]`` is
        not exactly a ``str``, or ``d["pos"]`` does not have 4 items.
    """
    assert type(d) is dict
    raw_text = d["text"]
    assert type(raw_text) is str
    coords = d["pos"]
    assert len(coords) == 4
    return cls(raw_text.strip(), coords, cell_id=d.get("cell_id"))
  

class Table(object):

  """
  The output of table segmentation.
  With the Table object, we can get the set of cells
  and their corresponding text.

  Attributes:
    tid: identifier given at construction (``""`` by default).
    row_n, col_n: number of rows / columns of the grid.
    cells: the cells, in insertion order.
    coo2cell_id: ``row_n x col_n`` grid mapping ``[row][col]`` to an index
      into ``cells``; ``-1`` marks a coordinate no cell covers.

  Cell coordinates are 0-based and inclusive: a cell spans rows
  ``x1..x2`` and columns ``y1..y2``.
  """
  def __init__(self, row_n, col_n, cells:"Iterable[Chunk]"=None, tid=""):
    """Create a ``row_n x col_n`` table and add ``cells`` to it, if any.

    Args:
      row_n: number of rows.
      col_n: number of columns.
      cells: optional iterable of cells; ``None`` creates an empty table.
      tid: table identifier.
    """
    # NOTE the Chunk object here represents the coordinate of
    # the cell in the table.
    # NOTE x in cell object represents the row id
    self.tid = tid
    self.row_n = row_n
    self.col_n = col_n
    self.coo2cell_id = self._empty_grid()
    self.cells:List[Chunk] = []
    # ``cells`` defaults to None, which must mean "no cells" rather than
    # fail on iteration.
    if cells is not None:
      for cell in cells:
        self.add_cell(cell)

  def _empty_grid(self):
    """Return a fresh ``row_n x col_n`` grid filled with ``-1``."""
    return [[-1 for _ in range(self.col_n)] for _ in range(self.row_n)]

  def reverse(self, is_col=True):
    """Mirror the table in place.

    Args:
      is_col: when true, mirror the row indices (``x1``/``x2``);
        otherwise mirror the column indices (``y1``/``y2``).

    Every cell is replaced by a new ``Chunk`` that carries only the text
    and the mirrored position; ``size`` and ``cell_id`` are not carried
    over.
    """
    old_cells = self.cells
    self.cells = []
    # Start from an empty grid so coordinates that no mirrored cell covers
    # do not keep stale ids.
    self.coo2cell_id = self._empty_grid()
    last_row = self.row_n - 1
    last_col = self.col_n - 1
    for cell in old_cells:
      # Inclusive 0-based index i mirrors to (n - 1 - i); the start and end
      # swap roles so the new start stays <= the new end.
      if is_col:
        pos = (last_row - cell.x2, last_row - cell.x1, cell.y1, cell.y2)
      else:
        pos = (cell.x1, cell.x2, last_col - cell.y2, last_col - cell.y1)
      self.add_cell(Chunk(cell.text, pos))

  def add_cell(self, cell:"Chunk"):
    """Append ``cell`` and mark every coordinate it spans as belonging to it.

    A coordinate already owned by an earlier cell is overwritten.

    Raises:
      AssertionError: if the cell extends past the last row or column.
    """
    # TODO Check conflicts of cells
    assert cell.y2 < self.col_n
    assert cell.x2 < self.row_n

    cell_idx = len(self.cells)
    for x in range(cell.x1, cell.x2 + 1):
      for y in range(cell.y1, cell.y2 + 1):
        self.coo2cell_id[x][y] = cell_idx
    self.cells.append(cell)

  def __getitem__(self, id_tuple):
    """Return the cell covering ``(row_id, col_id)``.

    Raises:
      AssertionError: if an index is not below the table's row/column count.
      IndexError: if no cell covers the coordinate.
    """
    row_id, col_id = id_tuple
    assert row_id < self.row_n and col_id < self.col_n
    cell_idx = self.coo2cell_id[row_id][col_id]
    # -1 marks an uncovered coordinate; indexing ``cells`` with it would
    # silently return the last cell.
    if cell_idx < 0:
      raise IndexError(
        'no cell covers coordinate (%r, %r)' % (row_id, col_id))
    return self.cells[cell_idx]