| |
| """Document-level and cross-document entity propagation post-processor. |
| |
Boosts recall by propagating entity labels to every exact (case-sensitive),
whole-token occurrence of an already-labeled string within a document, or,
with --cross-document, to occurrences of entities from a dictionary built
across all documents.
| |
| Usage: |
| python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl |
| python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl --cross-document |
| """ |
|
|
| import argparse |
| import json |
| import re |
| import sys |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
# Shortest entity string (in characters) eligible for propagation.
# main() rebinds this module global from the --min-length CLI option.
_MIN_ENTITY_LENGTH: int = 3
|
|
|
|
def parse_spans(spans_dict: dict) -> list[tuple[str, str, int, int]]:
    """Flatten a spans mapping into (label, text, start, end) tuples.

    Each key has the form "LABEL: text" (split on the first ": ") and maps to
    a sequence of [start, end] pairs. One tuple is produced per pair, in the
    mapping's iteration order.
    """
    flattened: list[tuple[str, str, int, int]] = []
    for key in spans_dict:
        label, text = key.split(": ", 1)
        flattened.extend((label, text, start, end) for start, end in spans_dict[key])
    return flattened
|
|
|
|
def build_spans_dict(spans: list[tuple[str, str, int, int]]) -> dict:
    """Group (label, text, start, end) tuples back into the spans mapping.

    Keys are "LABEL: text" and appear in sorted order. Each value is the
    sorted list of [start, end] pairs recorded for that key.
    """
    grouped: dict[str, list[list[int]]] = defaultdict(list)
    for label, text, start, end in spans:
        grouped[f"{label}: {text}"].append([start, end])
    return {key: sorted(grouped[key]) for key in sorted(grouped)}
|
|
|
|
def find_all_occurrences(text: str, entity_text: str) -> list[tuple[int, int]]:
    """Find all exact (case-sensitive) whole-token occurrences of entity_text in text.

    An occurrence counts only when it is neither immediately preceded nor
    immediately followed by a word character. This avoids substring matches
    like "Win" in "Windows".

    Matches are non-overlapping and returned left to right as (start, end)
    offsets, with end exclusive.

    An empty entity_text yields []. Without that guard the pattern would
    consist only of lookarounds and match zero-width positions, producing
    empty spans.
    """
    if not entity_text:
        return []
    pattern = r"(?<!\w)" + re.escape(entity_text) + r"(?!\w)"
    return [(m.start(), m.end()) for m in re.finditer(pattern, text)]
|
|
|
|
def overlaps(start: int, end: int, existing: set[tuple[int, int]]) -> bool:
    """Return True if the half-open interval [start, end) intersects any span in existing."""
    return any(
        start < other_end and other_start < end
        for other_start, other_end in existing
    )
|
|
|
|
def propagate_within_document(
    text: str,
    spans: list[tuple[str, str, int, int]],
    entity_label_map: dict[str, str] | None = None,
) -> tuple[list[tuple[str, str, int, int]], int]:
    """Propagate entity labels to unlabeled occurrences in a single document.

    The document's own spans always contribute an entity→label map. It holds
    every labeled string at least _MIN_ENTITY_LENGTH characters long, labeled
    by majority vote over its occurrences; ties go to the label seen first.

    When entity_label_map is given (cross-document mode), it is layered on
    top: its labels win for the strings it contains. The document's own
    entities still propagate even if they are absent from that map, e.g.
    because they fell below the global min_count.

    Longer entity strings are placed first. A candidate occurrence is skipped
    if it overlaps any existing or previously added span, so longer matches
    take precedence over shorter strings nested inside them.

    Args:
        text: Document text.
        spans: Existing spans as (label, text, start, end).
        entity_label_map: Optional pre-built entity→label map (for cross-doc mode).

    Returns:
        (all_spans, num_new) — the original spans followed by the newly added
        ones, and the number of new spans.
    """
    # Majority vote per string, consistent with build_global_entity_map,
    # instead of letting whichever label happens to come last win.
    label_votes: dict[str, Counter] = defaultdict(Counter)
    for label, ent_text, _, _ in spans:
        if len(ent_text) >= _MIN_ENTITY_LENGTH:
            label_votes[ent_text][label] += 1
    label_map = {ent: votes.most_common(1)[0][0] for ent, votes in label_votes.items()}

    if entity_label_map is not None:
        # Global labels override local ones. Local-only entities are kept so
        # cross-document mode never loses within-document propagation.
        label_map = {**label_map, **entity_label_map}

    occupied = {(s, e) for _, _, s, e in spans}
    new_spans: list[tuple[str, str, int, int]] = []

    # Longest strings first, so a multi-word entity claims its span before
    # any shorter entity contained in it.
    for ent_text in sorted(label_map, key=len, reverse=True):
        label = label_map[ent_text]
        for start, end in find_all_occurrences(text, ent_text):
            if (start, end) in occupied or overlaps(start, end, occupied):
                continue
            new_spans.append((label, ent_text, start, end))
            occupied.add((start, end))

    return spans + new_spans, len(new_spans)
|
|
|
|
def build_global_entity_map(records: list[dict], min_count: int = 2) -> dict[str, str]:
    """Build a global entity→label map from all documents.

    Each document casts at most one vote per (entity, label) pair, however
    many times it labels that string. The label with the most votes wins;
    ties go to the label encountered first.

    An entity is kept only if it is labeled in at least min_count distinct
    documents. A single document that labels the same string inconsistently
    therefore no longer counts as multiple occurrences. Entities shorter than
    _MIN_ENTITY_LENGTH are ignored.

    Args:
        records: Parsed JSONL records; each may carry a "spans" mapping in
            the format accepted by parse_spans.
        min_count: Minimum number of distinct documents that must label an
            entity for it to be included.

    Returns:
        Mapping from entity text to its majority label.
    """
    label_votes: dict[str, Counter] = defaultdict(Counter)
    doc_counts: Counter = Counter()

    for rec in records:
        # Ordered insertion (not set iteration) keeps tie-breaking deterministic.
        seen_pairs: set[tuple[str, str]] = set()
        seen_entities: set[str] = set()
        for label, ent_text, _, _ in parse_spans(rec.get("spans", {})):
            if len(ent_text) < _MIN_ENTITY_LENGTH:
                continue
            if (ent_text, label) not in seen_pairs:
                seen_pairs.add((ent_text, label))
                label_votes[ent_text][label] += 1
            if ent_text not in seen_entities:
                seen_entities.add(ent_text)
                doc_counts[ent_text] += 1

    return {
        ent_text: votes.most_common(1)[0][0]
        for ent_text, votes in label_votes.items()
        if doc_counts[ent_text] >= min_count
    }
|
|
|
|
def main():
    """Command-line entry point: read predictions, propagate entities, write results.

    Reads JSONL records from --input. Each record needs a "text" field and
    may have a "spans" mapping. Propagated spans are added to every record,
    the augmented records are written to --output as UTF-8 JSONL, and
    per-class statistics are printed.
    """
    parser = argparse.ArgumentParser(description="Entity propagation post-processor")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output", required=True, help="Output JSONL file")
    parser.add_argument("--cross-document", action="store_true",
                        help="Build global entity dict across all documents")
    parser.add_argument("--min-length", type=int, default=3,
                        help="Minimum entity text length to propagate (default: 3)")
    parser.add_argument("--min-count", type=int, default=2,
                        help="Min occurrences for cross-document mode (default: 2)")
    args = parser.parse_args()

    # The helper functions read this module-level threshold at call time.
    global _MIN_ENTITY_LENGTH
    _MIN_ENTITY_LENGTH = args.min_length

    # Explicit UTF-8: output is written with ensure_ascii=False, so relying on
    # the platform default encoding can fail or corrupt non-ASCII text.
    records = []
    with open(args.input, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))

    print(f"Loaded {len(records)} records from {args.input}")

    global_map = None
    if args.cross_document:
        global_map = build_global_entity_map(records, min_count=args.min_count)
        print(f"Global entity dictionary: {len(global_map)} entities (min_count={args.min_count})")

    total_new = 0
    class_new: Counter = Counter()
    output_records = []

    for rec in records:
        existing_spans = parse_spans(rec.get("spans", {}))
        # global_map stays None outside cross-document mode, which selects
        # plain within-document propagation.
        all_spans, num_new = propagate_within_document(rec["text"], existing_spans, global_map)

        # Newly added spans are appended after the existing ones.
        class_new.update(label for label, _, _, _ in all_spans[len(existing_spans):])
        total_new += num_new

        out_rec = dict(rec)
        out_rec["spans"] = build_spans_dict(all_spans)
        output_records.append(out_rec)

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        for rec in output_records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    mode = "cross-document" if args.cross_document else "within-document"
    print(f"\n=== Entity Propagation Stats ({mode}) ===")
    print(f"Documents processed: {len(records)}")
    print(f"New spans added: {total_new}")
    if class_new:
        print("\nPer-class breakdown:")
        for label, count in class_new.most_common():
            print(f"  {label}: +{count}")
    else:
        print("No new spans added.")
    print(f"\nOutput written to {args.output}")


if __name__ == "__main__":
    main()
|
|