File size: 7,463 Bytes
3dac39e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 | #!/usr/bin/env python3
"""Document-level and cross-document entity propagation post-processor.
Boosts recall by propagating entity labels to all identical string occurrences
within a document (or across all documents with --cross-document).
Usage:
python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl
python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl --cross-document
"""
import argparse
import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
# Minimum entity-text length (in characters) eligible for propagation.
# main() overwrites this module-level value from --min-length before processing.
_MIN_ENTITY_LENGTH = 3  # Don't propagate 1-2 char entities
def parse_spans(spans_dict: dict) -> list[tuple[str, str, int, int]]:
    """Flatten a ``{"LABEL: text": [[start, end], ...]}`` mapping into span tuples.

    Each key is split on its first ``": "`` only, so the entity text itself may
    contain ``": "``. Output order follows the dict's key order, then the order
    of each key's positions.

    Returns:
        List of (label, text, start, end) tuples.

    Raises:
        ValueError: If a key does not contain ``": "``.
    """
    parsed: list[tuple[str, str, int, int]] = []
    for span_key, position_pairs in spans_dict.items():
        label, text = span_key.split(": ", 1)
        parsed.extend((label, text, start, end) for start, end in position_pairs)
    return parsed
def build_spans_dict(spans: list[tuple[str, str, int, int]]) -> dict:
    """Inverse of parse_spans: group (label, text, start, end) tuples by "label: text".

    Returns:
        A plain dict whose keys are sorted and whose values are sorted lists of
        ``[start, end]`` pairs, so the output is deterministic.
    """
    grouped: dict[str, list] = {}
    for label, text, start, end in spans:
        grouped.setdefault(f"{label}: {text}", []).append([start, end])
    # Sort keys and positions for consistency
    return {span_key: sorted(grouped[span_key]) for span_key in sorted(grouped)}
def find_all_occurrences(text: str, entity_text: str) -> list[tuple[int, int]]:
    """Find all exact (case-sensitive) whole-token occurrences of entity_text in text.

    A match must not be immediately preceded or followed by a word character,
    which avoids substring matches like "Win" in "Windows". Matches are
    non-overlapping and returned left to right.

    Args:
        text: Text to search.
        entity_text: Literal string to look for; regex metacharacters are escaped.

    Returns:
        List of end-exclusive (start, end) character offsets. An empty
        entity_text yields no matches.
    """
    # An empty needle would otherwise produce a zero-length match at every
    # position not adjacent to a word character, i.e. bogus zero-width spans.
    if not entity_text:
        return []
    # Escape for regex; lookarounds act as word boundaries that also work
    # when entity_text starts or ends with punctuation.
    pattern = r"(?<!\w)" + re.escape(entity_text) + r"(?!\w)"
    return [(m.start(), m.end()) for m in re.finditer(pattern, text)]
def overlaps(start: int, end: int, existing: set[tuple[int, int]]) -> bool:
    """Return True if the half-open interval [start, end) intersects any span in existing.

    Spans that merely touch (one ends exactly where the other starts) do not overlap.
    """
    return any(
        start < other_end and end > other_start
        for other_start, other_end in existing
    )
def propagate_within_document(
    text: str,
    spans: list[tuple[str, str, int, int]],
    entity_label_map: dict[str, str] | None = None,
) -> tuple[list[tuple[str, str, int, int]], int]:
    """Propagate entity labels to unlabeled occurrences within a single document.

    Every whole-token, case-sensitive occurrence of a known entity string that
    does not coincide with or overlap an existing (or already added) span
    becomes a new span.

    Args:
        text: Document text.
        spans: Existing spans as (label, text, start, end).
        entity_label_map: Optional pre-built entity→label map (for cross-doc
            mode). When omitted, the map is built from this document's own
            spans: entity strings shorter than ``_MIN_ENTITY_LENGTH`` are
            skipped, and a string carrying several labels gets the most
            frequent one (ties go to the label seen first), matching the
            majority rule of ``build_global_entity_map``.

    Returns:
        (all_spans, num_new) — the original spans followed by the new ones
        (new spans are always appended at the end), and the count of new ones.
    """
    if entity_label_map is None:
        # Majority vote per entity string instead of "last label wins".
        label_votes: dict[str, Counter] = defaultdict(Counter)
        for label, ent_text, _, _ in spans:
            if len(ent_text) >= _MIN_ENTITY_LENGTH:
                label_votes[ent_text][label] += 1
        label_map = {
            ent_text: votes.most_common(1)[0][0]
            for ent_text, votes in label_votes.items()
        }
    else:
        label_map = entity_label_map

    # Positions already claimed by a span; grows as new spans are added.
    occupied = {(s, e) for _, _, s, e in spans}
    new_spans = []
    # Process longer entities first so they claim their tokens before any
    # shorter entity contained in them.
    for ent_text in sorted(label_map, key=len, reverse=True):
        label = label_map[ent_text]
        for start, end in find_all_occurrences(text, ent_text):
            # The exact-position check also catches zero-width spans,
            # which overlaps() never reports.
            if (start, end) in occupied or overlaps(start, end, occupied):
                continue
            new_spans.append((label, ent_text, start, end))
            occupied.add((start, end))
    return spans + new_spans, len(new_spans)
def build_global_entity_map(records: list[dict], min_count: int = 2) -> dict[str, str]:
    """Build global entity→label map from all documents.

    Each document casts at most one vote per (entity, label) pair. An entity
    is kept only if it appears in at least ``min_count`` distinct documents, so
    a single document that labels the same string two different ways does not
    satisfy the threshold on its own. Conflicting labels are resolved by
    majority vote; ties go to the label encountered first. Entity strings
    shorter than ``_MIN_ENTITY_LENGTH`` are ignored.

    Args:
        records: Parsed JSONL records; each may carry a ``"spans"`` dict in the
            format accepted by ``parse_spans`` (a missing key means no spans).
        min_count: Minimum number of distinct documents an entity must appear in.

    Returns:
        Mapping from entity text to its majority label.
    """
    label_votes: dict[str, Counter] = defaultdict(Counter)
    doc_counts: Counter = Counter()
    for rec in records:
        seen_pairs: set[tuple[str, str]] = set()
        seen_entities: set[str] = set()
        for label, ent_text, _, _ in parse_spans(rec.get("spans", {})):
            if len(ent_text) < _MIN_ENTITY_LENGTH:
                continue
            if (ent_text, label) not in seen_pairs:
                seen_pairs.add((ent_text, label))
                label_votes[ent_text][label] += 1
            # Document frequency is tracked per entity, independent of label.
            if ent_text not in seen_entities:
                seen_entities.add(ent_text)
                doc_counts[ent_text] += 1

    return {
        ent_text: votes.most_common(1)[0][0]
        for ent_text, votes in label_votes.items()
        if doc_counts[ent_text] >= min_count
    }
def main():
    """CLI entry point: load JSONL predictions, propagate entity labels, write JSONL, print stats."""
    parser = argparse.ArgumentParser(description="Entity propagation post-processor")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output", required=True, help="Output JSONL file")
    parser.add_argument("--cross-document", action="store_true",
                        help="Build global entity dict across all documents")
    parser.add_argument("--min-length", type=int, default=3,
                        help="Minimum entity text length to propagate (default: 3)")
    parser.add_argument("--min-count", type=int, default=2,
                        help="Min occurrences for cross-document mode (default: 2)")
    args = parser.parse_args()

    # The helper functions read this module-level threshold at call time.
    global _MIN_ENTITY_LENGTH
    _MIN_ENTITY_LENGTH = args.min_length

    # Load data. Use explicit UTF-8 rather than the platform default: output is
    # written with ensure_ascii=False, so a non-UTF-8 locale would fail on
    # non-ASCII text.
    records = []
    with open(args.input, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    print(f"Loaded {len(records)} records from {args.input}")

    # Build global map if cross-document mode; None makes each document
    # build its own entity map instead.
    global_map = None
    if args.cross_document:
        global_map = build_global_entity_map(records, min_count=args.min_count)
        print(f"Global entity dictionary: {len(global_map)} entities (min_count={args.min_count})")

    # Propagate
    total_new = 0
    class_new: Counter = Counter()
    output_records = []
    for rec in records:
        existing_spans = parse_spans(rec.get("spans", {}))
        all_spans, num_new = propagate_within_document(rec["text"], existing_spans, global_map)
        # New spans are appended after the existing ones; count them per label.
        class_new.update(label for label, *_ in all_spans[len(existing_spans):])
        total_new += num_new
        output_records.append({**rec, "spans": build_spans_dict(all_spans)})

    # Write output
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in output_records)

    # Stats
    mode = "cross-document" if args.cross_document else "within-document"
    print(f"\n=== Entity Propagation Stats ({mode}) ===")
    print(f"Documents processed: {len(records)}")
    print(f"New spans added: {total_new}")
    if class_new:
        print("\nPer-class breakdown:")
        for label, count in class_new.most_common():
            print(f" {label}: +{count}")
    else:
        print("No new spans added.")
    print(f"\nOutput written to {args.output}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|