File size: 7,463 Bytes
3dac39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python3
"""Document-level and cross-document entity propagation post-processor.

Boosts recall by propagating entity labels to all identical string occurrences
within a document (or across all documents with --cross-document).

Usage:
    python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl
    python3 scripts/entity_propagation.py --input preds.jsonl --output out.jsonl --cross-document
"""

import argparse
import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path

# Entity texts shorter than this many characters are never used as propagation
# sources (the default of 3 skips 1-2 character entities). main() overwrites
# this value with the --min-length command-line argument.
_MIN_ENTITY_LENGTH = 3


def parse_spans(spans_dict: dict) -> list[tuple[str, str, int, int]]:
    """Flatten a spans mapping into (label, text, start, end) tuples.

    Every key has the form ``"<label>: <text>"`` and is split on the first
    ``": "`` only, so the entity text may itself contain that separator.
    Every value is an iterable of ``(start, end)`` pairs; one tuple is
    produced per pair, in the mapping's iteration order.
    """
    flattened = []
    for key, positions in spans_dict.items():
        label, text = key.split(": ", 1)
        flattened.extend((label, text, start, end) for start, end in positions)
    return flattened


def build_spans_dict(spans: list[tuple[str, str, int, int]]) -> dict:
    """Group (label, text, start, end) tuples into the spans mapping format.

    The result maps ``"<label>: <text>"`` to a list of ``[start, end]`` pairs.
    Keys are emitted in sorted order and each position list is sorted, so the
    output is stable regardless of the order of the input spans.
    """
    grouped: dict[str, list[list[int]]] = {}
    for label, text, start, end in spans:
        grouped.setdefault(f"{label}: {text}", []).append([start, end])
    return {key: sorted(grouped[key]) for key in sorted(grouped)}


def find_all_occurrences(text: str, entity_text: str) -> list[tuple[int, int]]:
    """Find all exact (case-sensitive) whole-token occurrences of entity_text in text.

    A match must not be directly preceded or followed by a word character, which
    avoids substring hits such as "Win" inside "Windows". Lookarounds on ``\\w``
    are used instead of ``\\b`` so that entities starting or ending with
    punctuation (e.g. "C++") can still match.

    Args:
        text: Text to search.
        entity_text: Literal string to look for; it is regex-escaped.

    Returns:
        Non-overlapping ``(start, end)`` offsets in left-to-right order. An
        empty ``entity_text`` yields an empty list.
    """
    # An empty needle would reduce the pattern to bare lookarounds and produce
    # zero-length "matches" between non-word characters.
    if not entity_text:
        return []
    pattern = r'(?<!\w)' + re.escape(entity_text) + r'(?!\w)'
    return [m.span() for m in re.finditer(pattern, text)]


def overlaps(start: int, end: int, existing: set[tuple[int, int]]) -> bool:
    """Return True when the half-open range [start, end) intersects any span in existing."""
    return any(
        start < other_end and other_start < end
        for other_start, other_end in existing
    )


def propagate_within_document(
    text: str,
    spans: list[tuple[str, str, int, int]],
    entity_label_map: dict[str, str] | None = None,
) -> tuple[list[tuple[str, str, int, int]], int]:
    """Label every unlabelled whole-token occurrence of known entities in text.

    Args:
        text: Document text.
        spans: Existing spans as (label, text, start, end).
        entity_label_map: Optional entity→label map to propagate from (used in
            cross-document mode). When omitted, the map is derived from
            ``spans``, skipping entity texts shorter than the module's minimum
            length; if one entity text carries several labels, the label of
            its last span wins.

    Returns:
        (all_spans, num_new) — the original spans followed by the newly added
        ones, and the number of spans that were added.
    """
    if entity_label_map is not None:
        label_by_entity = entity_label_map
    else:
        label_by_entity = {
            ent_text: label
            for label, ent_text, _, _ in spans
            if len(ent_text) >= _MIN_ENTITY_LENGTH
        }

    # Offsets already taken, either by input spans or by spans added below.
    occupied = {(start, end) for _, _, start, end in spans}

    added = []
    # Longest entities go first so they claim their offsets before any shorter
    # entity contained in them is considered.
    for candidate in sorted(label_by_entity, key=len, reverse=True):
        label = label_by_entity[candidate]
        for start, end in find_all_occurrences(text, candidate):
            if (start, end) in occupied or overlaps(start, end, occupied):
                continue
            added.append((label, candidate, start, end))
            occupied.add((start, end))

    return spans + added, len(added)


def build_global_entity_map(records: list[dict], min_count: int = 2) -> dict[str, str]:
    """Build a global entity→label map from all documents.

    Each distinct (entity text, label) pair contributes one vote per document,
    no matter how many times it occurs inside that document. Entity texts
    shorter than the module's minimum length are ignored.

    Args:
        records: Parsed records; the optional "spans" entry of each is read.
            A missing or null "spans" is treated as having no spans.
        min_count: Minimum total number of votes (summed over all labels) an
            entity text needs in order to be included.

    Returns:
        Mapping from entity text to its majority label. When labels are tied,
        the one that received its first vote earliest wins.
    """
    votes: dict[str, Counter] = defaultdict(Counter)

    for rec in records:
        # Preserves first-seen order of votes while de-duplicating per document,
        # which keeps tie-breaking in most_common() deterministic.
        seen_in_doc = set()
        for label, ent_text, _, _ in parse_spans(rec.get("spans") or {}):
            if len(ent_text) < _MIN_ENTITY_LENGTH:
                continue
            pair = (ent_text, label)
            if pair in seen_in_doc:
                continue
            seen_in_doc.add(pair)
            votes[ent_text][label] += 1

    return {
        ent_text: label_counts.most_common(1)[0][0]
        for ent_text, label_counts in votes.items()
        if sum(label_counts.values()) >= min_count
    }


def main() -> None:
    """Command-line entry point: read predictions, propagate entities, write results.

    Reads a JSONL file (one JSON object per non-blank line), where every record
    has a "text" string and an optional "spans" mapping. Spans are propagated
    either from each document's own spans or, with --cross-document, from a
    dictionary built over all records. The augmented records are written to
    the output JSONL file and summary statistics are printed to stdout.

    Both files are handled as UTF-8: the output is serialised with
    ``ensure_ascii=False``, so relying on the platform's default encoding could
    corrupt or fail on non-ASCII text.

    Raises:
        KeyError: If a record has no "text" entry.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    parser = argparse.ArgumentParser(description="Entity propagation post-processor")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output", required=True, help="Output JSONL file")
    parser.add_argument("--cross-document", action="store_true",
                        help="Build global entity dict across all documents")
    parser.add_argument("--min-length", type=int, default=3,
                        help="Minimum entity text length to propagate (default: 3)")
    parser.add_argument("--min-count", type=int, default=2,
                        help="Min occurrences for cross-document mode (default: 2)")
    args = parser.parse_args()

    # The helper functions read this module-level threshold, so the CLI value
    # has to be installed globally before any of them run.
    global _MIN_ENTITY_LENGTH
    _MIN_ENTITY_LENGTH = args.min_length

    # Load data, skipping blank lines.
    with open(args.input, encoding="utf-8") as f:
        records = [json.loads(stripped) for line in f if (stripped := line.strip())]

    print(f"Loaded {len(records)} records from {args.input}")

    # Build global map if cross-document mode; None selects per-document maps.
    global_map = None
    if args.cross_document:
        global_map = build_global_entity_map(records, min_count=args.min_count)
        print(f"Global entity dictionary: {len(global_map)} entities (min_count={args.min_count})")

    # Propagate
    total_new = 0
    class_new: Counter = Counter()
    output_records = []

    for rec in records:
        existing_spans = parse_spans(rec.get("spans") or {})
        all_spans, num_new = propagate_within_document(rec["text"], existing_spans, global_map)

        # New spans are appended after the existing ones, so the tail of the
        # list is exactly what was added for this record.
        class_new.update(label for label, *_ in all_spans[len(existing_spans):])
        total_new += num_new

        out_rec = dict(rec)
        out_rec["spans"] = build_spans_dict(all_spans)
        output_records.append(out_rec)

    # Write output
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in output_records)

    # Stats
    mode = "cross-document" if args.cross_document else "within-document"
    print(f"\n=== Entity Propagation Stats ({mode}) ===")
    print(f"Documents processed: {len(records)}")
    print(f"New spans added: {total_new}")
    if class_new:
        print("\nPer-class breakdown:")
        for label, count in class_new.most_common():
            print(f"  {label}: +{count}")
    else:
        print("No new spans added.")
    print(f"\nOutput written to {args.output}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()