"""Accuracy, segment precision/recall and TEDS helpers for evaluation."""

import torch

from .utils import match_segment_spans, find_unmatch_segment_spans
from .teds import TEDS


class CellMergeAcc:
    """Count entries whose masked prediction matches the masked label exactly."""

    def __call__(self, preds, labels, labels_mask):
        """Return ``(correct_nums, total_nums)`` for one batch.

        Everything after the first two dimensions is flattened, so each
        ``(dim0, dim1)`` position is scored as a single entry.
        NOTE(review): presumably dim0 is the batch and dim1 indexes cells --
        confirm against the caller.

        Args:
            preds: Predicted tensor; must support ``&`` with ``labels_mask``.
            labels: Ground-truth tensor, compared element-wise with ``preds``.
            labels_mask: Mask applied to both tensors before comparison.

        Returns:
            A pair of floats. ``correct_nums`` is the number of entries that
            have a non-zero label and match on every flattened element.
            ``total_nums`` is the number of entries with a non-zero label,
            floored at ``1e-6`` so it is never zero.
        """
        # Positions outside the mask are zeroed on both sides, so they always
        # compare equal and cannot make an entry count as wrong.
        preds = preds & labels_mask
        labels = labels & labels_mask

        flag = preds == labels
        # min over the flattened tail: an entry is correct only if all of its
        # elements match.
        flag = flag.reshape(flag.shape[0], flag.shape[1], -1).min(-1)[0]
        # max over the flattened tail: an entry is counted only if its masked
        # label has at least one non-zero element.
        mask = labels.reshape(labels.shape[0], labels.shape[1], -1).max(-1)[0]

        correct_nums = float(torch.sum(flag & mask).item())
        total_nums = max(float(torch.sum(mask).item()), 1e-6)
        return correct_nums, total_nums


class AccMetric:
    """Element-wise accuracy counts over the valid positions of a batch."""

    def __call__(self, preds, labels, labels_mask):
        """Return ``(correct_nums, total_nums)`` for one batch.

        A position is counted when ``labels_mask`` is non-zero there and the
        label is not ``-1``.

        Args:
            preds: Predicted tensor, compared element-wise with ``labels``.
            labels: Ground-truth tensor; entries equal to ``-1`` are skipped.
            labels_mask: Tensor whose non-zero entries select valid positions.

        Returns:
            A pair of floats: the number of counted positions where
            ``preds == labels``, and the number of counted positions floored
            at ``1e-6`` so it is never zero.
        """
        mask = (labels_mask != 0) & (labels != -1)
        correct_nums = float(torch.sum((preds == labels) & mask).item())
        total_nums = max(float(torch.sum(mask).item()), 1e-6)
        return correct_nums, total_nums


def cal_cls_acc(cls_preds, cls_labels):
    """Count correct classification predictions, skipping labels equal to -1.

    Args:
        cls_preds: Predicted class tensor.
        cls_labels: Ground-truth class tensor; ``-1`` entries are skipped.

    Returns:
        ``(pred_nums, total_nums)`` as floats: the number of non-skipped
        positions predicted correctly and the number of non-skipped positions.
        Unlike the metric classes in this module, ``total_nums`` is not
        floored and can be ``0.0``.
    """
    mask = cls_labels != -1
    total_nums = float(torch.sum(mask).item())
    pred_nums = float(torch.sum((cls_preds == cls_labels) & mask).item())
    return pred_nums, total_nums


def cal_segment_pr(pred_segments, fg_spans, bg_spans):
    """Accumulate the counts needed for segment precision and recall.

    The three arguments are iterated in lockstep, one item per sample.
    NOTE(review): the per-sample span containers are joined with ``+``, so
    they are presumably lists -- confirm against the caller.

    Args:
        pred_segments: Per-sample predicted segments.
        fg_spans: Per-sample foreground spans.
        bg_spans: Per-sample background spans.

    Returns:
        ``(correct_nums, segment_nums, span_nums)``: the number of matches
        reported by ``match_segment_spans`` against the foreground spans, the
        number of predicted segments minus those reported by
        ``find_unmatch_segment_spans`` against foreground plus background
        spans, and the number of foreground spans.
    """
    correct_nums = 0
    segment_nums = 0
    span_nums = 0
    for pred_segments_pi, fg_spans_pi, bg_spans_pi in zip(pred_segments, fg_spans, bg_spans):
        matched_segments_idx, _ = match_segment_spans(pred_segments_pi, fg_spans_pi)
        unmatched_segments_idx = find_unmatch_segment_spans(
            pred_segments_pi, fg_spans_pi + bg_spans_pi
        )
        correct_nums += len(matched_segments_idx)
        # Segments reported as unmatched against every span are left out of
        # the predicted-segment total.
        segment_nums += len(pred_segments_pi) - len(unmatched_segments_idx)
        span_nums += len(fg_spans_pi)
    return correct_nums, segment_nums, span_nums


class TEDSMetric:
    """Thin wrapper that scores batches of HTML strings with ``TEDS``."""

    def __init__(self, num_workers=1, structure_only=False):
        """Create the underlying evaluator.

        Args:
            num_workers: Forwarded to ``TEDS`` as ``n_jobs``.
            structure_only: Forwarded to ``TEDS`` unchanged.
        """
        self.evaluator = TEDS(n_jobs=num_workers, structure_only=structure_only)

    def __call__(self, pred_htmls, label_htmls):
        """Score each prediction against the label at the same position.

        Args:
            pred_htmls: Sequence of predicted HTML values.
            label_htmls: Sequence of reference HTML values, same length.

        Returns:
            A list of scores in the same order as ``pred_htmls``.

        Raises:
            AssertionError: If the two sequences differ in length.
        """
        assert len(pred_htmls) == len(label_htmls), (
            "pred_htmls and label_htmls must have the same length"
        )
        # batch_evaluate is keyed by sample id; the position in the input is
        # used as the id so results can be put back in input order.
        pred_jsons = dict(enumerate(pred_htmls))
        label_jsons = {idx: dict(html=label_html) for idx, label_html in enumerate(label_htmls)}
        scores = self.evaluator.batch_evaluate(pred_jsons, label_jsons)
        return [scores[idx] for idx in range(len(pred_htmls))]