kai-2054's picture
Initial commit: add code
cb0ad2d
import torch
from .utils import match_segment_spans, find_unmatch_segment_spans
from .teds import TEDS
class CellMergeAcc:
    """Row-level accuracy for cell-merge predictions.

    ``preds`` and ``labels`` are first restricted to ``labels_mask`` with ``&``
    (so they are presumably bool or 0/1 integer tensors -- confirm with the
    caller). Each tensor is then viewed as ``(shape[0], shape[1], -1)`` and
    every ``[i, j]`` slice is scored as a whole:

    * it counts toward the total when any masked label in it is set;
    * it counts as correct when, in addition, every masked prediction in it
      equals the corresponding masked label.

    Returns:
        ``(correct_nums, total_nums)`` as floats. ``total_nums`` is clamped to
        at least ``1e-6`` so ``correct_nums / total_nums`` never divides by 0.
    """

    def __call__(self, preds, labels, labels_mask):
        masked_preds = preds & labels_mask
        masked_labels = labels & labels_mask

        # A slice is "right" only if every element agrees (min over bools == all).
        agree = masked_preds == masked_labels
        slice_ok = agree.reshape(agree.shape[0], agree.shape[1], -1).min(dim=-1).values

        # A slice is "scored" only if it contains at least one set label (max == any).
        slice_has_label = masked_labels.reshape(
            masked_labels.shape[0], masked_labels.shape[1], -1
        ).max(dim=-1).values

        n_right = torch.sum(slice_ok & slice_has_label).detach().cpu().item()
        n_scored = torch.sum(slice_has_label).detach().cpu().item()
        return float(n_right), max(float(n_scored), 1e-6)
class AccMetric:
    """Element-wise accuracy restricted to valid positions.

    A position is valid when ``labels_mask`` is non-zero there and the label
    is not the ignore value ``-1``.

    Returns:
        ``(correct_nums, total_nums)`` as floats, where ``total_nums`` is the
        number of valid positions clamped to at least ``1e-6`` so a later
        division cannot fail.
    """

    def __call__(self, preds, labels, labels_mask):
        valid = (labels != -1) & (labels_mask != 0)
        hits = (preds == labels) & valid
        n_hit = float(hits.sum().item())
        n_valid = float(valid.sum().item())
        return n_hit, max(n_valid, 1e-6)
def cal_cls_acc(cls_preds, cls_labels):
    """Count correct classification predictions, ignoring ``-1`` labels.

    Args:
        cls_preds: tensor of predicted class ids, compared element-wise with
            ``cls_labels``.
        cls_labels: tensor of target class ids; entries equal to ``-1`` are
            excluded from both counts.

    Returns:
        ``(pred_nums, total_nums)`` as floats: the number of correct
        predictions among non-ignored entries, and the number of non-ignored
        entries. ``total_nums`` is clamped to at least ``1e-6`` -- the same
        guard ``AccMetric`` and ``CellMergeAcc`` use -- so
        ``pred_nums / total_nums`` is safe when every label is ignored.
    """
    valid = cls_labels != -1
    # Clamp so an all-ignored batch yields 0 accuracy instead of 0 / 0.
    total_nums = max(float(torch.sum(valid).item()), 1e-6)
    pred_nums = float(torch.sum((cls_preds == cls_labels) & valid).item())
    return pred_nums, total_nums
def cal_segment_pr(pred_segments, fg_spans, bg_spans):
    """Accumulate segment-matching counts over a batch of samples.

    For each sample, the predicted segments are matched against the
    foreground spans with ``match_segment_spans``. ``find_unmatch_segment_spans``
    then finds predicted segments that match none of the foreground *or*
    background spans.

    Returns:
        ``(correct_nums, segment_nums, span_nums)`` summed over samples:
        number of matched predicted segments; number of predicted segments
        minus those matching no span at all; number of foreground spans.
        Presumably used as precision = correct / segments and
        recall = correct / spans -- confirm with the caller.
    """
    n_correct = n_segments = n_spans = 0
    for segs, fg, bg in zip(pred_segments, fg_spans, bg_spans):
        matched_idx, _ = match_segment_spans(segs, fg)
        # ``fg + bg`` concatenates the span collections (assumed to be lists).
        stray_idx = find_unmatch_segment_spans(segs, fg + bg)
        n_correct += len(matched_idx)
        n_segments += len(segs) - len(stray_idx)
        n_spans += len(fg)
    return n_correct, n_segments, n_spans
class TEDSMetric:
    """Per-sample TEDS scores for predicted vs. ground-truth HTML tables."""

    def __init__(self, num_workers=1, structure_only=False):
        # ``num_workers`` is forwarded to TEDS as ``n_jobs``.
        self.evaluator = TEDS(n_jobs=num_workers, structure_only=structure_only)

    def __call__(self, pred_htmls, label_htmls):
        """Score each prediction against the label at the same position.

        Predictions are passed to ``batch_evaluate`` as ``{index: html}`` and
        labels as ``{index: {"html": html}}``. The returned mapping is read
        back by index, so the result is a list ordered like ``pred_htmls``.
        """
        assert len(pred_htmls) == len(label_htmls)
        preds_by_id = dict(enumerate(pred_htmls))
        labels_by_id = {i: {"html": html} for i, html in enumerate(label_htmls)}
        score_map = self.evaluator.batch_evaluate(preds_by_id, labels_by_id)
        return [score_map[i] for i in range(len(pred_htmls))]