#!/usr/bin/env python3 """Document-level and cross-document entity propagation post-processor. Boosts recall by propagating entity labels to all identical string occurrences within a document (or across all documents with --cross-document). Usage: python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl --cross-document """ import argparse import json import re import sys from collections import Counter, defaultdict from pathlib import Path _MIN_ENTITY_LENGTH = 3 # Don't propagate 1-2 char entities def parse_spans(spans_dict: dict) -> list[tuple[str, str, int, int]]: """Parse spans dict into list of (label, text, start, end).""" results = [] for key, positions in spans_dict.items(): label, text = key.split(": ", 1) for start, end in positions: results.append((label, text, start, end)) return results def build_spans_dict(spans: list[tuple[str, str, int, int]]) -> dict: """Convert list of (label, text, start, end) back to spans dict format.""" d = defaultdict(list) for label, text, start, end in spans: d[f"{label}: {text}"].append([start, end]) # Sort positions for consistency return {k: sorted(v) for k, v in sorted(d.items())} def find_all_occurrences(text: str, entity_text: str) -> list[tuple[int, int]]: """Find all exact (case-sensitive) whole-token occurrences of entity_text in text. Uses word boundary matching to avoid substring matches like "Win" in "Windows". """ # Escape for regex, use word boundaries pattern = r'(? bool: """Check if [start, end) overlaps with any existing span.""" for es, ee in existing: if start < ee and end > es: return True return False def propagate_within_document( text: str, spans: list[tuple[str, str, int, int]], entity_label_map: dict[str, str] | None = None, ) -> tuple[list[tuple[str, str, int, int]], int]: """Propagate entities within a single document. Args: text: Document text. spans: Existing spans as (label, text, start, end). entity_label_map: Optional pre-built entity→label map (for cross-doc mode). Returns: (all_spans, num_new) — augmented spans and count of new ones added. """ # Build entity→label from this document's spans (or use provided map) if entity_label_map is None: local_map: dict[str, str] = {} for label, ent_text, _, _ in spans: if len(ent_text) >= _MIN_ENTITY_LENGTH: local_map[ent_text] = label else: local_map = entity_label_map # Existing span positions existing_positions = {(s, e) for _, _, s, e in spans} new_spans = [] # Process longer entities first to avoid substring issues for ent_text in sorted(local_map, key=len, reverse=True): label = local_map[ent_text] occurrences = find_all_occurrences(text, ent_text) for start, end in occurrences: if (start, end) not in existing_positions and not overlaps(start, end, existing_positions): new_spans.append((label, ent_text, start, end)) existing_positions.add((start, end)) return spans + new_spans, len(new_spans) def build_global_entity_map(records: list[dict], min_count: int = 2) -> dict[str, str]: """Build global entity→label map from all documents. Uses majority vote for entities with conflicting labels. Only includes entities appearing min_count+ times. """ # Count (entity_text, label) pairs entity_label_counts: dict[str, Counter] = defaultdict(Counter) entity_total: Counter = Counter() for rec in records: seen_in_doc = set() for label, ent_text, _, _ in parse_spans(rec.get("spans", {})): if len(ent_text) < _MIN_ENTITY_LENGTH: continue key = (ent_text, label) if key not in seen_in_doc: entity_label_counts[ent_text][label] += 1 entity_total[ent_text] += 1 seen_in_doc.add(key) # Filter by min_count and pick majority label result = {} for ent_text, label_counts in entity_label_counts.items(): total = sum(label_counts.values()) if total >= min_count: result[ent_text] = label_counts.most_common(1)[0][0] return result def main(): parser = argparse.ArgumentParser(description="Entity propagation post-processor") parser.add_argument("--input", required=True, help="Input JSONL file") parser.add_argument("--output", required=True, help="Output JSONL file") parser.add_argument("--cross-document", action="store_true", help="Build global entity dict across all documents") parser.add_argument("--min-length", type=int, default=3, help="Minimum entity text length to propagate (default: 3)") parser.add_argument("--min-count", type=int, default=2, help="Min occurrences for cross-document mode (default: 2)") args = parser.parse_args() global _MIN_ENTITY_LENGTH _MIN_ENTITY_LENGTH = args.min_length # Load data records = [] with open(args.input) as f: for line in f: line = line.strip() if line: records.append(json.loads(line)) print(f"Loaded {len(records)} records from {args.input}") # Build global map if cross-document mode global_map = None if args.cross_document: global_map = build_global_entity_map(records, min_count=args.min_count) print(f"Global entity dictionary: {len(global_map)} entities (min_count={args.min_count})") # Propagate total_new = 0 class_new: Counter = Counter() output_records = [] for rec in records: text = rec["text"] existing_spans = parse_spans(rec.get("spans", {})) if args.cross_document: all_spans, num_new = propagate_within_document(text, existing_spans, global_map) else: all_spans, num_new = propagate_within_document(text, existing_spans) # Count per-class new spans if num_new > 0: for label, ent_text, start, end in all_spans[len(existing_spans):]: class_new[label] += 1 total_new += num_new out_rec = dict(rec) out_rec["spans"] = build_spans_dict(all_spans) output_records.append(out_rec) # Write output Path(args.output).parent.mkdir(parents=True, exist_ok=True) with open(args.output, "w") as f: for rec in output_records: f.write(json.dumps(rec, ensure_ascii=False) + "\n") # Stats mode = "cross-document" if args.cross_document else "within-document" print(f"\n=== Entity Propagation Stats ({mode}) ===") print(f"Documents processed: {len(records)}") print(f"New spans added: {total_new}") if class_new: print(f"\nPer-class breakdown:") for label, count in class_new.most_common(): print(f" {label}: +{count}") else: print("No new spans added.") print(f"\nOutput written to {args.output}") if __name__ == "__main__": main()