| from collections import defaultdict |
|
|
| from rex.metrics.tagging import tagging_prf1 |
| from rex.utils.io import load_jsonlines |
| from rex.utils.position import find_all_positions |
|
|
|
|
def main():
    """Pause in the debugger on every instance whose predicted spans differ from gold.

    Reads the hard-coded ``test.final.jsonl`` middle-output file with
    ``load_jsonlines`` and, for each instance, compares ``gold["spans"]``
    with ``pred["spans"]``; a mismatch triggers ``breakpoint()``.
    """
    instances = load_jsonlines(
        "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_labelmap_Rel_updateTag_bs32/middle/test.final.jsonl"
    )
    for instance in instances:
        gold, pred = instance["gold"], instance["pred"]
        if gold["spans"] == pred["spans"]:
            continue
        # Mismatch: stop so `instance`, `gold` and `pred` can be inspected.
        breakpoint()
|
|
|
|
def check_ent_string_matching_upper_bound(filepath: str, strategy: str = "first"):
    """Score an entity baseline that re-locates gold entity strings in the text.

    For every instance loaded from ``filepath`` the gold entity strings are
    searched for in the instance text with ``find_all_positions``; the hits
    become predictions, which are scored against the (de-duplicated) gold
    entities with ``tagging_prf1``. The micro result is printed.

    Args:
        filepath: Path handed to ``load_jsonlines``. Each instance is read
            through ``ins["text"]`` and ``ins["ans"]["ent"]``, whose entries
            provide ``text``, ``type`` and ``span``.
        strategy: With ``"first"``, every occurrence of an entity string that
            has not already been claimed for the same type is predicted.
            With any other value, each gold entity claims only the first
            occurrence that does not overlap a previously claimed span
            (entities with longer strings are processed first).

    Returns:
        None. The strategy and the micro scores are printed.
    """

    def _check_overlap(x, y):
        # Equality on either boundary counts as overlapping, i.e. spans that
        # merely touch are treated as overlapping.
        # NOTE(review): confirm this matches the span convention returned by
        # find_all_positions (inclusive vs. exclusive end).
        return x[0] <= y[1] and y[0] <= x[1]

    data = load_jsonlines(filepath)
    golds = []
    preds = []
    for ins in data:
        text = ins["text"]
        # De-duplicate gold entities, then handle longer strings first so they
        # get to claim their occurrence before shorter strings do.
        gold_ents = list(
            {(ent["text"], ent["type"], tuple(ent["span"])) for ent in ins["ans"]["ent"]}
        )
        gold_ents.sort(key=lambda x: len(x[0]), reverse=True)
        pred_ents = []
        matched = set()
        for ent_string, ent_type, _ in gold_ents:
            # Normalise to tuples once: positions are used as set members, so
            # they must be hashable even if the helper returns lists.
            positions = [
                tuple(position) for position in find_all_positions(text, ent_string)
            ]
            if strategy == "first":
                # NOTE(review): despite the strategy name there is no early
                # exit here, so all unclaimed occurrences are predicted --
                # confirm this is intended.
                for position in positions:
                    if (ent_type, position) not in matched:
                        matched.add((ent_type, position))
                        pred_ents.append((ent_string, ent_type, position))
            else:
                for position in positions:
                    # The overlap test is evaluated per occurrence. Previously
                    # a single flag was set outside this loop and never
                    # reset, so one overlapping occurrence caused every later
                    # occurrence to be skipped as well.
                    if any(_check_overlap(claimed, position) for _, claimed in matched):
                        continue

                    if (ent_type, position) not in matched:
                        matched.add((ent_type, position))
                        pred_ents.append((ent_string, ent_type, position))
                        break

        golds.append(gold_ents)
        preds.append(pred_ents)

    results = tagging_prf1(golds, preds)

    print(f"filepath: {filepath}, Strategy: {strategy}")
    print(f"Results: {results['micro']}")
|
|
|
|
def check_rel_tanl_upper_bound(filepath):
    """Score gold relations whose tail span is recovered from the tail text alone.

    For each instance loaded from ``filepath``, a lookup from entity text to
    the set of its gold spans is built. A gold relation is kept as a
    prediction only when its tail text maps to exactly one span; that span is
    used as the predicted tail, while the head span and relation label are
    copied from the gold relation. Predictions are scored against the
    de-duplicated gold relations with ``tagging_prf1`` and the micro result
    is printed.

    Args:
        filepath: Path handed to ``load_jsonlines``. Each instance is read
            through ``ins["text"]``, ``ins["ans"]["ent"]`` and
            ``ins["ans"]["rel"]``.

    Returns:
        None. The file path and the micro scores are printed.
    """
    gold_batches = []
    pred_batches = []
    for instance in load_jsonlines(filepath):
        # Not used below; the lookup is kept so an instance without "text"
        # still fails at this point.
        _ = instance["text"]
        relations = instance["ans"]["rel"]

        spans_by_text = defaultdict(set)
        for entity in instance["ans"]["ent"]:
            spans_by_text[entity["text"]].add(tuple(entity["span"]))

        unique_gold = {
            (
                tuple(relation["head"]["span"]),
                relation["relation"],
                tuple(relation["tail"]["span"]),
            )
            for relation in relations
        }

        recovered = []
        for relation in instance["ans"]["rel"]:
            tail_text = relation["tail"]["text"]
            # .get() avoids inserting an empty set for unseen tail texts.
            candidates = spans_by_text.get(tail_text)
            if candidates is None or len(candidates) != 1:
                # Unknown or ambiguous tail text: the span cannot be recovered.
                continue
            (tail_span,) = candidates
            recovered.append(
                (tuple(relation["head"]["span"]), relation["relation"], tail_span)
            )

        gold_batches.append(list(unique_gold))
        pred_batches.append(recovered)

    results = tagging_prf1(gold_batches, pred_batches)

    print(f"filepath: {filepath}")
    print(f"Results: {results['micro']}")
|
|
|
|
if __name__ == "__main__":
    # Relation test splits to report on; one result is printed per file.
    rel_test_files = (
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/ace05-rel/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/conll04/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/nyt/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/scierc/test.jsonl",
    )
    for filepath in rel_test_files:
        check_rel_tanl_upper_bound(filepath)
|
|