# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple

from .relation import Relation
from .table import Table, Chunk

DIR_HORIZ = 1
DIR_VERT = 2
DIR_SAME_CELL = 3

# Translation table that deletes the characters stripped by normalize(rule=0).
_STRIP_TABLE = str.maketrans("", "", "\r\n \t")


def normalize(s: str, rule=0):
    """Return a canonical form of ``s`` for text comparison.

    Args:
        s: the string to normalize.
        rule: normalization scheme. Only ``0`` is implemented: carriage
            returns, newlines, spaces and tabs are removed and the result is
            upper-cased.

    Returns:
        The normalized string.

    Raises:
        NotImplementedError: if ``rule`` is anything other than ``0``.
    """
    if rule != 0:
        raise NotImplementedError
    return s.translate(_STRIP_TABLE).upper()


def eval_relations(gt: List[List], res: List[List],
                   cmp_blank=True) -> Tuple[float, float]:
    """Evaluate predicted relations against the ground truth.

    Precision and recall are computed per instance and then averaged over
    all instances. A progress indicator is written to stdout (terminated with
    a carriage return) while evaluating.

    Args:
        gt: a list of list of Relation (ground truth), one inner list per
            instance.
        res: a list of list of Relation (predictions), aligned with ``gt``.
        cmp_blank: forwarded to ``compare_rel``.

    Returns:
        A ``(precision, recall)`` tuple. ``(0.0, 0.0)`` is returned when there
        are no instances to evaluate.
    """
    # TODO: decide how the overall precision/recall should be aggregated;
    # for now it is the plain mean of the per-instance scores.
    assert len(gt) == len(res), \
        "gt and res must contain the same number of instances"

    total = len(gt)
    if total == 0:
        # Nothing to average; avoid dividing by zero below.
        return 0.0, 0.0

    tot_prec = 0.0
    tot_recall = 0.0
    for idx, (_gt, _res) in enumerate(zip(gt, res), start=1):
        print('Eval %d/%d (%d%%)' % (idx, total, idx / total * 100),
              ' ' * 45, end='\r')
        corr = compare_rel(_gt, _res, cmp_blank)
        # An empty prediction / ground-truth list scores 0 for that instance.
        tot_prec += corr / len(_res) if len(_res) != 0 else 0
        tot_recall += corr / len(_gt) if len(_gt) != 0 else 0

    return tot_prec / total, tot_recall / total


def compare_rel(gt_rel: List[Relation], res_rel: List[Relation],
                cmp_blank=True) -> int:
    """Count ground-truth relations that are matched by a predicted relation.

    Each predicted relation can be matched at most once: as soon as it is
    matched with a ground-truth relation it is removed from the candidate
    pool. ``res_rel`` itself is never modified.

    Args:
        gt_rel: ground-truth relations.
        res_rel: predicted relations.
        cmp_blank: passed as the second argument of ``Relation.equal``.

    Returns:
        The number of ground-truth relations for which a match was found.
    """
    # Work on a copy so the caller's list is left untouched.
    candidates = list(res_rel)
    count = 0
    for gt in gt_rel:
        for i, res in enumerate(candidates):
            if gt.equal(res, cmp_blank):
                # Consume the matched prediction so it cannot match again.
                del candidates[i]
                count += 1
                break
    return count


def _collect_line_relations(t: Table, ids, direction, used, out):
    """Append relations between neighbouring cells along one row or column.

    Args:
        t: the table; ``t.cells`` is indexed by cell id to fetch the text.
        ids: cell ids along the line, in scan order. Entries equal to ``-1``
            are skipped (presumably they mark grid positions without a cell).
        direction: ``DIR_HORIZ`` or ``DIR_VERT``, stored on each relation.
        used: set of ``(from_id, to_id)`` pairs already emitted; updated in
            place so that every pair is reported only once.
        out: list the new ``Relation`` objects are appended to.
    """
    n = len(ids)
    for start in range(n - 1):
        fid = ids[start]
        if fid == -1:
            # No source cell at this position, so nothing can be emitted.
            continue
        # Last position seen that still belongs to the source cell; blanks
        # are counted from there, not from ``start``.
        last_same = start
        for pos in range(start + 1, n):
            tid = ids[pos]
            if tid == -1:
                continue
            if tid == fid:
                last_same = pos
                continue
            # First different cell along the line: it is the neighbour.
            if (fid, tid) not in used:
                out.append(Relation(
                    from_text=t.cells[fid].text,
                    to_text=t.cells[tid].text,
                    direction=direction,
                    from_id=fid,
                    to_id=tid,
                    no_blanks=pos - last_same - 1
                ))
                used.add((fid, tid))
            break


def Table2Relations(t: Table):
    """Convert a Table object to a List of Relation.

    For every grid position the nearest different cell to the right (same
    row) and below (same column) is looked up through ``t.coo2cell_id`` and
    reported as a horizontal / vertical relation. All horizontal relations
    come first, followed by the vertical ones, and each ``(from_id, to_id)``
    pair is emitted only once.

    Args:
        t: the table to convert.

    Returns:
        A list of Relation.
    """
    ret = []
    cl = t.coo2cell_id
    # remove duplicates with pair set
    used = set()

    # look right
    if t.col_n > 1:
        for r in range(t.row_n):
            row_ids = [cl[r][c] for c in range(t.col_n)]
            _collect_line_relations(t, row_ids, DIR_HORIZ, used, ret)

    # look down
    if t.row_n > 1:
        for c in range(t.col_n):
            col_ids = [cl[r][c] for r in range(t.row_n)]
            _collect_line_relations(t, col_ids, DIR_VERT, used, ret)

    return ret


def json2Table(json_obj, tid="", splitted_content=False):
    """Construct a Table object from json object

    Args:
        json_obj: a json object with a ``"cells"`` list; each cell provides
            ``"content"``, ``"start_row"``, ``"end_row"``, ``"start_col"``
            and ``"end_col"``.
        tid: identifier passed through to the ``Table`` constructor.
        splitted_content: if true, ``"content"`` is treated as a sequence of
            strings and joined with single spaces; otherwise it is treated
            as a string and stripped.
    Returns:
        a Table object
    """
    row_n, col_n = 0, 0
    cells = []
    for co in json_obj["cells"]:
        content = co["content"]
        if content is None:
            continue
        if splitted_content:
            content = " ".join(content)
        else:
            content = content.strip()
        # Cells without any text are dropped and do not extend the grid.
        if content == "":
            continue
        start_row, end_row = co["start_row"], co["end_row"]
        start_col, end_col = co["start_col"], co["end_col"]
        row_n = max(row_n, end_row)
        col_n = max(col_n, end_col)
        cells.append(Chunk(content, (start_row, end_row, start_col, end_col)))
    # The table size is the largest end index + 1 (presumably the indices
    # are 0-based and inclusive -- confirm against the data format).
    return Table(row_n + 1, col_n + 1, cells, tid)


def json2Relations(json_obj, splitted_content=False):
    """Build the list of Relation described by a table json object.

    Args:
        json_obj: a json object accepted by ``json2Table``.
        splitted_content: forwarded to ``json2Table``.

    Returns:
        The result of ``Table2Relations`` on the constructed table.
    """
    return Table2Relations(json2Table(json_obj, "", splitted_content))