File size: 4,710 Bytes
cb0ad2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from .relation import Relation
from .table import Table, Chunk
# Direction codes passed as the `direction` argument of Relation.
# Horizontal: set on relations produced by the "look right" pass of
# Table2Relations.
DIR_HORIZ = 1
# Vertical: set on relations produced by the "look down" pass of
# Table2Relations.
DIR_VERT = 2
# Not referenced in this part of the file.
# NOTE(review): presumably marks a relation between chunks of one cell —
# confirm against the Relation implementation.
DIR_SAME_CELL = 3
def normalize(s: str, rule=0):
    """Return a canonical form of ``s`` for text comparison.

    Args:
        s: the text to normalize.
        rule: normalization scheme; only ``0`` is implemented.

    Returns:
        ``s`` with every carriage return, newline, space and tab removed,
        converted to upper case.

    Raises:
        NotImplementedError: if ``rule`` is anything other than ``0``.
    """
    if rule != 0:
        raise NotImplementedError
    # One pass that deletes the four whitespace characters of rule 0.
    drop_whitespace = str.maketrans("", "", "\r\n \t")
    return s.translate(drop_whitespace).upper()
def eval_relations(gt: List[List], res: List[List], cmp_blank=True):
    """Evaluate results as precision/recall averaged over instances.

    For every aligned pair of inner lists the number of matching relations
    is obtained from ``compare_rel``; per-instance precision and recall are
    then averaged with equal weight per instance.

    Args:
        gt: a list of list of Relation (ground truth), one inner list per
            instance.
        res: a list of list of Relation (results), aligned with ``gt``.
        cmp_blank: forwarded unchanged to ``compare_rel``.

    Returns:
        A ``(precision, recall)`` tuple. An instance with an empty result
        list contributes precision 0, and one with an empty ground-truth
        list contributes recall 0. When there are no instances at all,
        ``(0.0, 0.0)`` is returned instead of dividing by zero.

    Raises:
        AssertionError: if ``gt`` and ``res`` differ in length.
    """
    # TODO to know how to calculate the total recall and prec
    assert len(gt) == len(res)

    total = len(gt)
    if total == 0:
        # Nothing to evaluate; avoid ZeroDivisionError in the averages below.
        return 0.0, 0.0

    tot_prec = 0
    tot_recall = 0
    for idx, (_gt, _res) in enumerate(zip(gt, res), start=1):
        # Progress indicator rewritten in place on one terminal line.
        print('Eval %d/%d (%d%%)' % (idx, total, idx / total * 100), ' ' * 45, end='\r')
        corr = compare_rel(_gt, _res, cmp_blank)
        tot_prec += corr / len(_res) if len(_res) != 0 else 0
        tot_recall += corr / len(_gt) if len(_gt) != 0 else 0

    precision = tot_prec / total
    recall = tot_recall / total
    return precision, recall
def compare_rel(gt_rel: List[Relation], res_rel: List[Relation], cmp_blank=True):
    """Count ground-truth relations that can be paired with a result relation.

    Each ground-truth relation is compared, in order, against the result
    relations that have not been paired yet, using
    ``gt.equal(res, cmp_blank)``. The first match is consumed, so one result
    relation is never counted twice. ``res_rel`` itself is not modified.

    Args:
        gt_rel: ground-truth relations.
        res_rel: result relations.
        cmp_blank: passed as the second argument of ``Relation.equal``.

    Returns:
        The number of ground-truth relations for which a match was found.
    """
    unmatched = list(res_rel)  # private copy; matched entries are removed
    matched = 0
    for expected in gt_rel:
        for pos, candidate in enumerate(unmatched):
            if expected.equal(candidate, cmp_blank):
                matched += 1
                # Safe to delete while enumerating: we leave the loop at once.
                del unmatched[pos]
                break
    return matched
def Table2Relations(t:Table):
    """Convert a Table object to a List of Relation.

    The grid ``t.coo2cell_id`` is scanned twice: first along every row
    ("look right"), then along every column ("look down"). From each start
    position the scan advances until it reaches a position holding a
    different cell id (ignoring ``-1`` entries) and emits one Relation for
    that pair of cells.

    Args:
        t: the table to convert; ``row_n``, ``col_n``, ``coo2cell_id`` and
            ``cells`` are read.

    Returns:
        A list of Relation: horizontal relations (``DIR_HORIZ``) first,
        followed by vertical relations (``DIR_VERT``). Each ``(from_id,
        to_id)`` pair appears at most once over both passes.

    NOTE(review): the nesting of ``loop = False`` and of the ``+= 1``
    statements was reconstructed from a source whose indentation had been
    lost — confirm against the upstream file.
    """
    ret = []
    # Grid indexed as cl[row][col]; entries are indices into t.cells.
    # NOTE(review): -1 presumably marks a grid position without a cell —
    # confirm against Table.
    cl = t.coo2cell_id
    # remove duplicates with pair set
    # (shared by both passes, so a pair emitted horizontally is not emitted
    # again vertically)
    used = set()
    # look right
    for r in range(t.row_n):
        for cFrom in range(t.col_n - 1):
            cTo = cFrom + 1
            loop = True
            while loop and cTo < t.col_n:
                fid, tid = cl[r][cFrom], cl[r][cTo]
                if fid != -1 and tid != -1 and fid != tid:
                    # Reached a different cell: record the pair once, then
                    # stop scanning from this start position.
                    if (fid, tid) not in used:
                        ret.append(Relation(
                            from_text=t.cells[fid].text,
                            to_text=t.cells[tid].text,
                            direction=DIR_HORIZ,
                            from_id=fid,
                            to_id=tid,
                            # grid positions strictly between the two ends
                            no_blanks=cTo - cFrom - 1
                        ))
                        used.add((fid, tid))
                    loop = False
                else:
                    if fid != -1 and tid != -1 and fid == tid:
                        # Still inside the same cell: move the start forward
                        # so no_blanks does not count the cell's own span.
                        # Rebinding cFrom only affects this iteration; the
                        # for statement reassigns it on the next one.
                        cFrom = cTo
                    cTo += 1
    # look down
    for c in range(t.col_n):
        for rFrom in range(t.row_n - 1):
            rTo = rFrom + 1
            loop = True
            while loop and rTo < t.row_n:
                fid, tid = cl[rFrom][c], cl[rTo][c]
                if fid != -1 and tid != -1 and fid != tid:
                    # Same rule as the horizontal pass, applied along a column.
                    if (fid, tid) not in used:
                        ret.append(Relation(
                            from_text=t.cells[fid].text,
                            to_text=t.cells[tid].text,
                            direction=DIR_VERT,
                            from_id=fid,
                            to_id=tid,
                            # grid positions strictly between the two ends
                            no_blanks=rTo - rFrom - 1
                        ))
                        used.add((fid, tid))
                    loop = False
                else:
                    if fid != -1 and tid != -1 and fid == tid:
                        # Still inside the same cell: advance the start row.
                        rFrom = rTo
                    rTo += 1
    return ret
def json2Table(json_obj, tid="", splitted_content=False):
    """Construct a Table object from json object.

    Every entry of ``json_obj["cells"]`` whose content is neither ``None``
    nor empty becomes one Chunk. The table dimensions are taken from the
    largest ``end_row`` / ``end_col`` among the kept entries, plus one.

    Args:
        json_obj: a json object with a ``"cells"`` list; each entry provides
            ``"content"``, ``"start_row"``, ``"end_row"``, ``"start_col"``
            and ``"end_col"``.
        tid: passed through as the last argument of Table.
        splitted_content: when true, ``content`` is joined with single
            spaces; otherwise it is stripped of surrounding whitespace.

    Returns:
        a Table object
    """
    chunks = []
    last_row = last_col = 0
    for entry in json_obj["cells"]:
        text = entry["content"]
        if text is None:
            continue
        text = " ".join(text) if splitted_content else text.strip()
        if text == "":
            continue
        # (start_row, end_row, start_col, end_col), as expected by Chunk.
        span = (
            entry["start_row"],
            entry["end_row"],
            entry["start_col"],
            entry["end_col"],
        )
        last_row = max(last_row, span[1])
        last_col = max(last_col, span[3])
        chunks.append(Chunk(text, span))
    return Table(last_row + 1, last_col + 1, chunks, tid)
def json2Relations(json_obj, splitted_content=False):
    """Convert a json table description directly to a list of Relation.

    Args:
        json_obj: a json object accepted by ``json2Table``.
        splitted_content: forwarded to ``json2Table``. Defaults to ``False``
            so the signature agrees with ``json2Table``'s own default.

    Returns:
        The result of ``Table2Relations`` applied to the Table built from
        ``json_obj`` with an empty table id.
    """
    return Table2Relations(json2Table(json_obj, "", splitted_content))
|