File size: 7,593 Bytes
0779202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c50d16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""Convert token-level anime filename JSONL datasets to character tokens.

Input records must contain parallel ``tokens`` and ``labels`` arrays. The
converter expands each original token into Unicode code points and projects BIO
labels onto the expanded sequence:

- ``B-X`` keeps ``B-X`` on the first character and uses ``I-X`` afterwards.
- ``I-X`` remains ``I-X`` on every character.
- ``O`` remains ``O`` on every character.

The script streams both input and output so it can process the full DMHY weak
dataset without loading hundreds of MB into memory.
"""

from __future__ import annotations

import argparse
import json
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
from typing import Iterable


# Reserved vocabulary entries; build_vocab assigns them IDs 0..3 in this order,
# ahead of every counted character.
SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]")


def projected_labels(token: str, label: str) -> tuple[list[str], list[str]]:
    """Split ``token`` into characters and give each character a BIO label.

    A ``B-X`` label stays on the first character only; every following
    character of the same token is relabelled ``I-X`` so the span keeps a
    single beginning. Any other label (``I-X``, ``O``, or anything else) is
    copied onto every character unchanged.

    Returns two equal-length lists; both are empty for an empty token.
    """
    characters = list(token)
    if not characters:
        return [], []
    if not label.startswith("B-"):
        # I-X, O and unrecognised labels are repeated verbatim per character.
        return characters, [label] * len(characters)
    _, entity = label.split("-", 1)
    continuation = [f"I-{entity}"] * (len(characters) - 1)
    return characters, [label, *continuation]


def convert_record(record: dict) -> dict:
    """Return a copy of ``record`` re-tokenised at character level.

    ``tokens`` and ``labels`` are replaced by their per-character projection
    (see ``projected_labels``); every other key is copied over unchanged, and
    ``tokenizer_variant``, ``source_token_count`` and ``char_token_count`` are
    set on the copy. Tokens and labels are passed through ``str`` first.

    Raises:
        KeyError: if ``tokens`` or ``labels`` is missing from ``record``.
        ValueError: if the two arrays differ in length.
    """
    source_tokens = record["tokens"]
    source_labels = record["labels"]
    n_tokens, n_labels = len(source_tokens), len(source_labels)
    if n_tokens != n_labels:
        raise ValueError(
            f"token/label length mismatch: {n_tokens} tokens, {n_labels} labels"
        )

    flat_tokens: list[str] = []
    flat_labels: list[str] = []
    for src_token, src_label in zip(source_tokens, source_labels):
        chars, char_labels = projected_labels(str(src_token), str(src_label))
        flat_tokens += chars
        flat_labels += char_labels

    # Shallow copy keeps metadata keys (and their order); the keyword order
    # below fixes the order of the appended bookkeeping keys.
    result = {**record}
    result.update(
        tokens=flat_tokens,
        labels=flat_labels,
        tokenizer_variant="char",
        source_token_count=n_tokens,
        char_token_count=len(flat_tokens),
    )
    return result


def iter_jsonl(path: Path) -> Iterable[dict]:
    """Yield one JSON object per non-blank line of the JSONL file at ``path``.

    The file is decoded as UTF-8, and a leading byte-order mark is tolerated.
    Empty and whitespace-only lines are skipped.

    Raises:
        ValueError: if a line is not valid JSON or does not decode to a JSON
            object. The message carries ``path`` and the 1-based line number.
    """
    # "utf-8-sig" drops a leading BOM and is otherwise identical to "utf-8".
    # str.strip() does not remove U+FEFF, and json.loads rejects it, so a
    # BOM-prefixed file would otherwise fail on line 1.
    with path.open("r", encoding="utf-8-sig") as handle:
        for line_no, raw in enumerate(handle, 1):
            line = raw.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{line_no}: invalid JSON ({exc.msg})") from exc
            # Reject arrays/scalars here, where the line number is still known.
            if not isinstance(record, dict):
                raise ValueError(
                    f"{path}:{line_no}: expected a JSON object, got {type(record).__name__}"
                )
            yield record


def build_vocab(counter: Counter[str], max_size: int | None = None) -> dict[str, int]:
    """Build a frequency-sorted vocab with fixed special-token IDs.

    ``SPECIAL_TOKENS`` always occupy IDs ``0..len(SPECIAL_TOKENS)-1`` in
    declaration order. Tokens from ``counter`` follow in descending frequency;
    ties keep the counter's insertion order, as ``Counter.most_common`` does.

    Args:
        counter: token frequencies.
        max_size: optional cap on the total vocab size, special tokens
            included. Special tokens are always kept, so a cap below
            ``len(SPECIAL_TOKENS)`` still yields exactly the special tokens.

    Returns:
        Mapping from token to a contiguous integer ID.
    """
    vocab = {token: idx for idx, token in enumerate(SPECIAL_TOKENS)}
    # Walk the full ranking and stop at the cap instead of pre-slicing it.
    # A counted token that duplicates a special token is skipped, and with a
    # pre-sliced ranking it would still have consumed one of the capped slots.
    for token, _count in counter.most_common():
        if max_size is not None and len(vocab) >= max_size:
            break
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab


def coverage(counter: Counter[str], vocab: dict[str, int]) -> float:
    """Return the fraction of counted occurrences whose token is in ``vocab``.

    An empty (zero-total) counter yields 1.0, since nothing is out of vocab.
    """
    in_vocab = 0
    grand_total = 0
    for token, count in counter.items():
        grand_total += count
        if token in vocab:
            in_vocab += count
    return in_vocab / grand_total if grand_total else 1.0


def percentile(values: list[int], pct: float) -> int:
    """Return the nearest-rank ``pct``-th percentile of ``values``.

    Picks the element at sorted position ``pct / 100 * (len(values) - 1)``,
    with that position rounded half-up. ``pct`` is clamped to ``[0, 100]``.
    Returns 0 for an empty list.
    """
    if not values:
        return 0
    ordered = sorted(values)
    last = len(ordered) - 1
    # Clamp first: a negative position would otherwise index from the end of
    # the list and silently return a large value.
    fraction = min(max(pct, 0), 100) / 100
    # int(x + 0.5) rounds half-up for x >= 0. Built-in round() is
    # half-to-even, which flips between lower/upper neighbour by list length.
    index = min(last, int(fraction * last + 0.5))
    return ordered[index]


def parse_args() -> argparse.Namespace:
    """Define and parse the converter's command-line interface.

    ``--input``, ``--output`` and ``--vocab-output`` are mandatory paths.
    The manifest path, vocab cap, record limit and progress interval are
    optional.
    """
    parser = argparse.ArgumentParser(
        description="Convert JSONL token labels to character labels",
    )
    # Mandatory file locations, registered in their original order.
    for flag, help_text in (
        ("--input", "Input token-level JSONL"),
        ("--output", "Output character-level JSONL"),
        ("--vocab-output", "Output vocab JSON"),
    ):
        parser.add_argument(flag, required=True, help=help_text)
    parser.add_argument(
        "--manifest-output",
        default=None,
        help="Output manifest JSON",
    )
    parser.add_argument(
        "--max-vocab-size",
        type=int,
        default=None,
        help="Optional vocab cap including special tokens",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Convert only the first N records",
    )
    parser.add_argument(
        "--progress",
        type=int,
        default=50_000,
        help="Print progress every N records",
    )
    namespace = parser.parse_args()
    return namespace


def main() -> None:
    """Convert ``--input`` to character-level JSONL and write vocab + manifest.

    Records are streamed: each converted record is written to ``--output`` as
    soon as it is produced. Character and label counts and per-record
    character lengths are accumulated for the vocab file and the manifest.
    Without ``--manifest-output``, the manifest goes to the output path with
    its suffix replaced by ``.manifest.json``.

    Raises:
        SystemExit: if ``--input`` and ``--output`` refer to the same file
            (argparse also exits on invalid arguments).
        ValueError: if an input line is not valid JSON or a record cannot be
            converted. The message names the 1-based record index, counting
            non-blank lines.
    """
    from itertools import islice

    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    vocab_path = Path(args.vocab_output)
    manifest_path = (
        Path(args.manifest_output)
        if args.manifest_output
        else output_path.with_suffix(".manifest.json")
    )

    # Opening the output in "w" mode truncates it before the first input line
    # is read, so converting a file onto itself would silently destroy it.
    if input_path.exists() and output_path.exists() and input_path.samefile(output_path):
        raise SystemExit(f"--output must not be the same file as --input: {input_path}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    vocab_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    char_counter: Counter[str] = Counter()
    label_counter: Counter[str] = Counter()
    row_count = 0
    source_token_count = 0
    char_token_count = 0
    lengths: list[int] = []  # per-record char lengths, for the length stats
    examples: list[dict] = []  # first few converted records, for the manifest

    # Apply --limit before converting anything. islice never pulls record
    # limit+1 from the reader, and --limit 0 converts nothing (previously one
    # record slipped through). islice rejects negative stops, hence the clamp.
    limit = None if args.limit is None else max(args.limit, 0)

    with output_path.open("w", encoding="utf-8", newline="\n") as out:
        for record in islice(iter_jsonl(input_path), limit):
            try:
                converted = convert_record(record)
            except (KeyError, TypeError, ValueError) as exc:
                # Name the failing record; a bare KeyError/ValueError deep in
                # a large dataset is otherwise impossible to locate.
                raise ValueError(
                    f"{input_path}: record {row_count + 1}: cannot convert: {exc!r}"
                ) from exc
            out.write(json.dumps(converted, ensure_ascii=False, separators=(",", ":")) + "\n")

            row_count += 1
            source_token_count += converted["source_token_count"]
            char_len = converted["char_token_count"]
            char_token_count += char_len
            lengths.append(char_len)
            char_counter.update(converted["tokens"])
            label_counter.update(converted["labels"])
            if len(examples) < 5:
                examples.append(converted)

            if args.progress and row_count % args.progress == 0:
                print(f"converted {row_count:,} rows; unique chars={len(char_counter):,}")

    vocab = build_vocab(char_counter, args.max_vocab_size)
    vocab_path.write_text(json.dumps(vocab, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")

    manifest = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "input": str(input_path),
        "output": str(output_path),
        "vocab_output": str(vocab_path),
        "tokenizer_variant": "char",
        "projection": {
            "B-X": "first char keeps B-X; remaining chars become I-X",
            "I-X": "all chars keep I-X",
            "O": "all chars keep O",
        },
        "row_count": row_count,
        "source_token_count": source_token_count,
        "char_token_count": char_token_count,
        "unique_char_count": len(char_counter),
        "vocab_size": len(vocab),
        "max_vocab_size": args.max_vocab_size,
        "vocab_coverage": coverage(char_counter, vocab),
        "label_counts": dict(label_counter),
        "char_length": {
            "min": min(lengths) if lengths else 0,
            "mean": mean(lengths) if lengths else 0,
            "p50": percentile(lengths, 50),
            "p90": percentile(lengths, 90),
            "p95": percentile(lengths, 95),
            "p99": percentile(lengths, 99),
            "max": max(lengths) if lengths else 0,
        },
        "examples": examples,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    # Echo the manifest summary, omitting the bulky example records.
    print(json.dumps({k: v for k, v in manifest.items() if k != "examples"}, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    # Run the converter CLI only when executed as a script, not on import.
    main()