File size: 4,654 Bytes
3dac39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python3
"""Audit span boundary quality in Arcspan training data."""

import json
import random
import re
import sys
from collections import defaultdict
from pathlib import Path

# Seed the global RNG once so the per-file samples drawn with random.sample
# are the same on every run.
random.seed(42)

_DATA_DIR = Path("/home/ubuntu/alkyline/data/processed")

# Training splits to audit, in the order they are reported.
FILES = [
    _DATA_DIR / name
    for name in (
        "enriched_5class_train_cleaned.jsonl",
        "aptner_5class_train.jsonl",
        "securebert2_5class_train.jsonl",
    )
]

# Maximum number of records sampled from each file.
SAMPLE_SIZE = 200


def parse_surface(key: str) -> str:
    """Return the surface-text part of a ``"Label: surface_text"`` span key.

    Everything after the first ``": "`` is returned. A key that contains no
    ``": "`` separator is returned unchanged.
    """
    _label, separator, surface = key.partition(": ")
    return surface if separator else key


def is_word_boundary(text, pos):
    """Return True unless ``pos`` sits strictly between two alphanumeric chars.

    Any position at or beyond either end of ``text`` (including negative
    values) counts as a boundary. An interior position is a boundary when at
    least one of its neighbours ``text[pos - 1]`` / ``text[pos]`` is not
    alphanumeric.
    """
    interior = 0 < pos < len(text)
    # Short-circuit on `interior` so out-of-range positions are never indexed.
    return not (interior and text[pos - 1].isalnum() and text[pos].isalnum())


def audit_file(path: Path):
    """Sample records from one JSONL file and check their span annotations.

    Up to ``SAMPLE_SIZE`` non-blank lines are drawn at random, using the
    module-level ``random`` state. Each sampled line is parsed as a JSON
    object with a ``"text"`` string and an optional ``"spans"`` mapping. That
    mapping goes from ``"Label: surface"`` keys to lists of ``[start, end]``
    character-offset pairs.

    Args:
        path: JSONL file to audit. It is read as UTF-8.

    Returns:
        ``(total, issues)``. ``total`` is the number of non-blank lines in the
        file. ``issues`` maps each check name to a list of dicts describing
        the offending spans. Every entry carries the 1-based file line number
        under ``"line"``.

    Raises:
        json.JSONDecodeError: if a sampled line is not valid JSON.
        KeyError: if a sampled record has no ``"text"`` field.
    """
    # Split on "\n" only. str.splitlines() would also break on U+2028/U+2029,
    # which may legally appear unescaped inside JSON strings.
    # Blank lines are skipped, because json.loads("") would raise. Each kept
    # line retains its real file line number, so reports stay accurate even
    # when the file has leading or interior blank lines.
    records = [
        (line_num, line)
        for line_num, line in enumerate(
            path.read_text(encoding="utf-8").split("\n"), start=1
        )
        if line.strip()
    ]
    total = len(records)
    indices = random.sample(range(total), min(SAMPLE_SIZE, total))

    issues = {
        "offset_mismatch": [],
        "leading_trailing_ws": [],
        "mid_word_boundary": [],
        "trailing_punct": [],
        "overlapping": [],
        "empty_or_zero": [],
        "out_of_bounds": [],
    }

    for idx in indices:
        line_num, line = records[idx]
        rec = json.loads(line)
        text = rec["text"]
        text_len = len(text)
        all_intervals = []

        for key, offsets in rec.get("spans", {}).items():
            surface = parse_surface(key)
            for start, end in offsets:
                ctx = {"line": line_num, "key": key, "start": start, "end": end}

                # Empty, zero-length or inverted span: nothing else to check.
                if start >= end:
                    issues["empty_or_zero"].append(ctx)
                    continue

                # Out of bounds. Slicing would silently clamp, so stop here.
                if end > text_len or start < 0:
                    issues["out_of_bounds"].append({**ctx, "text_len": text_len})
                    continue

                extracted = text[start:end]
                ctx["extracted"] = extracted

                # Offsets do not reproduce the surface text named in the key.
                if extracted != surface:
                    issues["offset_mismatch"].append({**ctx, "expected": surface})

                # Leading/trailing whitespace.
                if extracted != extracted.strip():
                    issues["leading_trailing_ws"].append(ctx)

                # Span starts or ends in the middle of an alphanumeric run.
                if not is_word_boundary(text, start) or not is_word_boundary(text, end):
                    issues["mid_word_boundary"].append({
                        **ctx,
                        "context": text[max(0, start - 5):end + 5],
                    })

                # Trailing punctuation (.,;:!? or a closing paren) that likely
                # belongs to the sentence rather than the entity.
                if extracted and extracted[-1] in ".,;:!?)":
                    issues["trailing_punct"].append(ctx)

                all_intervals.append((start, end, key))

        # Overlapping spans. After sorting by start, compare each span with
        # every later span that begins before it ends. Checking only adjacent
        # pairs would miss overlaps such as a long span that contains two
        # shorter ones.
        all_intervals.sort()
        for i, (s1, e1, k1) in enumerate(all_intervals):
            for s2, e2, k2 in all_intervals[i + 1:]:
                if s2 >= e1:
                    break  # later spans start even further right
                issues["overlapping"].append({
                    "line": line_num,
                    "span1": (k1, s1, e1),
                    "span2": (k2, s2, e2),
                })

    return total, issues


def main():
    """Audit each configured file and print its issue report and summary.

    Missing files are reported as skipped. For each audited file, at most
    five example entries are printed per non-empty issue category, followed
    by a per-category count table.
    """
    banner = "=" * 60
    for path in FILES:
        print(f"\n{banner}")
        if not path.exists():
            print(f"SKIPPED (not found): {path.name}")
            continue

        print(f"FILE: {path.name}")
        total, issues = audit_file(path)
        print(f"Total examples: {total}, sampled: {min(SAMPLE_SIZE, total)}")
        print(banner)

        # Only categories that actually have entries get a detailed section.
        flagged = {cat: items for cat, items in issues.items() if items}
        for cat, items in flagged.items():
            print(f"\n  [{cat.upper()}] — {len(items)} issues")
            for ex in items[:5]:
                print(f"    Line {ex.get('line', '?')}: {json.dumps(ex, ensure_ascii=False, default=str)}")
        if not flagged:
            print("\n  ✅ No issues found in sample!")

        # Summary table: every category, including those with zero hits.
        print("\n  --- Summary ---")
        for cat, items in issues.items():
            print(f"  {cat:25s}: {len(items)}")


if __name__ == "__main__":
    main()