arcspan / scripts /entity_propagation.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Document-level and cross-document entity propagation post-processor.
Boosts recall by propagating entity labels to all identical string occurrences
within a document (or across all documents with --cross-document).
Usage:
python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl
python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl --cross-document
"""
import argparse
import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
# Default minimum entity-text length (characters) required for an entity to be
# propagated, so 1-2 character entities are skipped. main() replaces this value
# with the --min-length argument.
_MIN_ENTITY_LENGTH = 3
def parse_spans(spans_dict: dict) -> list[tuple[str, str, int, int]]:
    """Flatten a spans mapping into ``(label, text, start, end)`` tuples.

    Each key of *spans_dict* has the form ``"LABEL: text"`` (split on the first
    ``": "`` only, so the text may itself contain ``": "``) and maps to an
    iterable of ``[start, end]`` pairs. One tuple is produced per pair, in dict
    order and then position order.
    """
    flattened: list[tuple[str, str, int, int]] = []
    for combined_key, offset_pairs in spans_dict.items():
        span_label, span_text = combined_key.split(": ", 1)
        flattened.extend(
            (span_label, span_text, begin, finish) for begin, finish in offset_pairs
        )
    return flattened
def build_spans_dict(spans: list[tuple[str, str, int, int]]) -> dict:
    """Group ``(label, text, start, end)`` tuples into the spans-dict format.

    The result maps ``"LABEL: text"`` to a list of ``[start, end]`` lists.
    Keys are emitted in sorted order and each position list is sorted, so the
    output is deterministic regardless of input order.
    """
    grouped: dict[str, list[list[int]]] = {}
    for span_label, span_text, begin, finish in spans:
        grouped.setdefault(f"{span_label}: {span_text}", []).append([begin, finish])
    return {key: sorted(grouped[key]) for key in sorted(grouped)}
def find_all_occurrences(text: str, entity_text: str) -> list[tuple[int, int]]:
    r"""Return ``(start, end)`` offsets of every whole-token match of *entity_text*.

    Matching is exact and case-sensitive. A match is rejected when it is
    immediately preceded or followed by a word character (``\w``), so "Win"
    does not match inside "Windows". Matches are non-overlapping and returned
    left to right, as produced by ``re.finditer``.

    An empty *entity_text* yields no matches. Without this guard the bare
    lookaround pattern would produce zero-length spans wherever no word
    character is adjacent (e.g. at the start of an empty text, or between
    spaces).
    """
    if not entity_text:
        return []
    # re.escape makes the entity literal (e.g. "U.S.", "C++"). The lookarounds
    # act as token boundaries that still work when the entity starts or ends
    # with punctuation, where \b would not match.
    pattern = r"(?<!\w)" + re.escape(entity_text) + r"(?!\w)"
    return [(m.start(), m.end()) for m in re.finditer(pattern, text)]
def overlaps(start: int, end: int, existing: set[tuple[int, int]]) -> bool:
    """Return True if the half-open interval [start, end) intersects any span in *existing*.

    Spans that merely touch (one ends exactly where the other starts) do not
    count as overlapping.
    """
    return any(
        other_start < end and start < other_end
        for other_start, other_end in existing
    )
def propagate_within_document(
    text: str,
    spans: list[tuple[str, str, int, int]],
    entity_label_map: dict[str, str] | None = None,
) -> tuple[list[tuple[str, str, int, int]], int]:
    """Label every not-yet-covered whole-token occurrence of known entities in *text*.

    Args:
        text: Document text.
        spans: Existing spans as ``(label, text, start, end)``; not modified.
        entity_label_map: Optional entity-text -> label map (used by
            cross-document mode). When omitted, the map is built from *spans*,
            keeping only entity texts of at least ``_MIN_ENTITY_LENGTH``
            characters; if one text carries several labels, the last one seen
            wins.

    Returns:
        ``(all_spans, num_new)``: *spans* followed by the newly added spans,
        and the number of spans that were added.

    NOTE(review): when *entity_label_map* is supplied, entities taken from this
    document's own spans but absent from the map are not propagated here.
    """
    if entity_label_map is not None:
        label_for = entity_label_map
    else:
        label_for = {
            ent: lab for lab, ent, _, _ in spans if len(ent) >= _MIN_ENTITY_LENGTH
        }

    occupied = {(s, e) for _, _, s, e in spans}
    added: list[tuple[str, str, int, int]] = []

    # Longest entities first, so a longer match claims its text before any
    # shorter entity contained in it can be placed inside that region.
    for ent in sorted(label_for, key=len, reverse=True):
        lab = label_for[ent]
        for s, e in find_all_occurrences(text, ent):
            # Skip positions already labelled or overlapping a labelled span.
            if (s, e) in occupied or overlaps(s, e, occupied):
                continue
            added.append((lab, ent, s, e))
            occupied.add((s, e))

    return spans + added, len(added)
def build_global_entity_map(records: list[dict], min_count: int = 2) -> dict[str, str]:
    """Build a corpus-wide entity-text -> label map for cross-document propagation.

    Each record's ``"spans"`` (a missing key means no spans) is flattened with
    ``parse_spans``; entity texts shorter than ``_MIN_ENTITY_LENGTH`` are
    ignored. Each document casts at most one vote per (entity text, label)
    pair, so repeated mentions inside one document do not inflate the counts.

    Args:
        records: Parsed input records.
        min_count: Minimum number of distinct documents in which an entity text
            must be labelled for it to be included.

    Returns:
        Mapping from entity text to its majority label. Ties go to the label
        that was encountered first (``Counter.most_common`` ordering).
    """
    # Per entity text: number of documents voting for each label.
    entity_label_counts: dict[str, Counter] = defaultdict(Counter)
    # Per entity text: number of distinct documents that labelled it at all.
    # Kept separate from the label votes so a document that gives one text two
    # different labels still counts only once toward min_count.
    entity_doc_counts: Counter = Counter()

    for rec in records:
        seen_pairs: set[tuple[str, str]] = set()
        seen_texts: set[str] = set()
        for label, ent_text, _, _ in parse_spans(rec.get("spans", {})):
            if len(ent_text) < _MIN_ENTITY_LENGTH:
                continue
            if (ent_text, label) not in seen_pairs:
                seen_pairs.add((ent_text, label))
                entity_label_counts[ent_text][label] += 1
            if ent_text not in seen_texts:
                seen_texts.add(ent_text)
                entity_doc_counts[ent_text] += 1

    # Keep entities seen in enough documents and pick the majority label.
    return {
        ent_text: label_counts.most_common(1)[0][0]
        for ent_text, label_counts in entity_label_counts.items()
        if entity_doc_counts[ent_text] >= min_count
    }
def main():
    """Command-line entry point: read JSONL predictions, propagate entities, write JSONL.

    Each input line is a JSON object with a ``"text"`` field and an optional
    ``"spans"`` field in the format read by ``parse_spans``. Every record is
    written back with ``"spans"`` replaced by the augmented set; all other
    fields are copied unchanged. Summary statistics are printed to stdout.
    """
    # Declared before any use of the name in this function so the CLI value
    # overrides the module-level default read by the helper functions.
    global _MIN_ENTITY_LENGTH

    parser = argparse.ArgumentParser(description="Entity propagation post-processor")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output", required=True, help="Output JSONL file")
    parser.add_argument("--cross-document", action="store_true",
                        help="Build global entity dict across all documents")
    parser.add_argument("--min-length", type=int, default=3,
                        help="Minimum entity text length to propagate (default: 3)")
    parser.add_argument("--min-count", type=int, default=2,
                        help="Min occurrences for cross-document mode (default: 2)")
    args = parser.parse_args()

    _MIN_ENTITY_LENGTH = args.min_length

    # Load data. Read as UTF-8 explicitly rather than relying on the platform's
    # locale encoding.
    records = []
    with open(args.input, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    print(f"Loaded {len(records)} records from {args.input}")

    # Build global map if cross-document mode; None means each document
    # derives its own map from its existing spans.
    global_map = None
    if args.cross_document:
        global_map = build_global_entity_map(records, min_count=args.min_count)
        print(f"Global entity dictionary: {len(global_map)} entities (min_count={args.min_count})")

    # Propagate
    total_new = 0
    class_new: Counter = Counter()
    output_records = []
    for rec in records:
        text = rec["text"]
        existing_spans = parse_spans(rec.get("spans", {}))
        all_spans, num_new = propagate_within_document(text, existing_spans, global_map)
        # New spans are appended after the existing ones; count them per label.
        class_new.update(label for label, _, _, _ in all_spans[len(existing_spans):])
        total_new += num_new
        out_rec = dict(rec)
        out_rec["spans"] = build_spans_dict(all_spans)
        output_records.append(out_rec)

    # Write output. ensure_ascii=False emits raw non-ASCII characters, so the
    # file is opened as UTF-8 explicitly; otherwise writing can fail on
    # platforms whose default encoding cannot represent them.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        for rec in output_records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    # Stats
    mode = "cross-document" if args.cross_document else "within-document"
    print(f"\n=== Entity Propagation Stats ({mode}) ===")
    print(f"Documents processed: {len(records)}")
    print(f"New spans added: {total_new}")
    if class_new:
        print("\nPer-class breakdown:")
        for label, count in class_new.most_common():
            print(f"  {label}: +{count}")
    else:
        print("No new spans added.")
    print(f"\nOutput written to {args.output}")


if __name__ == "__main__":
    main()