| import torch |
| from .utils import match_segment_spans, find_unmatch_segment_spans |
| from .teds import TEDS |
|
|
|
|
class CellMergeAcc:
    """Cell-level merge accuracy: a cell counts as correct only when every
    element of its masked prediction equals the masked label, and only
    cells with at least one positive label contribute to the total."""

    def __call__(self, preds, labels, labels_mask):
        # Restrict predictions and labels to the valid region of the mask.
        masked_preds = preds & labels_mask
        masked_labels = labels & labels_mask
        batch, cells = masked_labels.shape[0], masked_labels.shape[1]
        # A cell is correct iff all of its flattened elements agree
        # (min over a bool tensor acts as a logical AND).
        per_cell_ok = (masked_preds == masked_labels).reshape(batch, cells, -1).min(-1)[0]
        # A cell is "active" iff it has at least one positive label
        # (max over a bool tensor acts as a logical OR).
        active = masked_labels.reshape(batch, cells, -1).max(-1)[0]
        correct_nums = float(torch.sum(per_cell_ok & active).detach().cpu().item())
        # Clamp to avoid a zero denominator downstream.
        total_nums = max(float(torch.sum(active).detach().cpu().item()), 1e-6)
        return correct_nums, total_nums
|
|
|
|
class AccMetric:
    """Element-wise accuracy; positions where the mask is 0 or the label
    is -1 (ignore index) are excluded from both counts."""

    def __call__(self, preds, labels, labels_mask):
        # Valid positions: mask is set and the label is not the ignore value.
        valid = (labels_mask != 0) & (labels != -1)
        hits = (preds == labels) & valid
        correct_nums = float(hits.sum().detach().cpu().item())
        # Clamp to avoid a zero denominator downstream.
        total_nums = max(float(valid.sum().detach().cpu().item()), 1e-6)
        return correct_nums, total_nums
|
|
|
|
def cal_cls_acc(cls_preds, cls_labels):
    """Count correct classification predictions, ignoring positions labelled -1.

    Args:
        cls_preds: predicted class-id tensor.
        cls_labels: ground-truth class-id tensor; -1 marks ignored positions.

    Returns:
        (pred_nums, total_nums): correct-prediction count and valid-position
        count. total_nums is clamped to at least 1e-6 — consistent with
        AccMetric and CellMergeAcc — so callers can divide without a
        zero-denominator check when every label is ignored.
    """
    mask = (cls_labels != -1)
    # max(..., 1e-6) only changes the all-ignored case (previously 0.0,
    # which made a downstream accuracy division crash).
    total_nums = max(float(torch.sum(mask).item()), 1e-6)
    pred_nums = float(torch.sum((cls_preds == cls_labels) & mask).item())
    return pred_nums, total_nums
|
|
|
|
def cal_segment_pr(pred_segments, fg_spans, bg_spans):
    """Accumulate segment precision/recall counts over a batch.

    Args:
        pred_segments: per-item predicted segments.
        fg_spans: per-item foreground (positive) spans.
        bg_spans: per-item background spans.

    Returns:
        (correct_nums, segment_nums, span_nums): matched-prediction count,
        prediction count excluding segments matching neither fg nor bg
        spans, and ground-truth foreground span count.
    """
    correct_nums = segment_nums = span_nums = 0
    for segments, fg, bg in zip(pred_segments, fg_spans, bg_spans):
        matched_idx, _ = match_segment_spans(segments, fg)
        # Predictions matching neither fg nor bg are dropped from the
        # precision denominator.
        unmatched_idx = find_unmatch_segment_spans(segments, fg + bg)
        correct_nums += len(matched_idx)
        segment_nums += len(segments) - len(unmatched_idx)
        span_nums += len(fg)
    return correct_nums, segment_nums, span_nums
|
|
|
|
class TEDSMetric:
    """Batch wrapper around the TEDS (Tree-Edit-Distance-based Similarity)
    evaluator for comparing predicted and ground-truth table HTML."""

    def __init__(self, num_workers=1, structure_only=False):
        # Delegate the actual tree-edit computation to the TEDS evaluator.
        self.evaluator = TEDS(n_jobs=num_workers, structure_only=structure_only)

    def __call__(self, pred_htmls, label_htmls):
        assert len(pred_htmls) == len(label_htmls)
        # batch_evaluate expects dicts keyed by sample id; labels are
        # wrapped as {"html": ...} per the evaluator's expected schema.
        pred_jsons = {}
        label_jsons = {}
        for idx, (pred_html, label_html) in enumerate(zip(pred_htmls, label_htmls)):
            pred_jsons[idx] = pred_html
            label_jsons[idx] = dict(html=label_html)
        score_map = self.evaluator.batch_evaluate(pred_jsons, label_jsons)
        # Return scores as a list in input order.
        return [score_map[idx] for idx in range(len(pred_htmls))]
|
|