# AniFileBERT / tools/convert_to_char_dataset.py
# Last commit 8c50d16 by ModerRAS: "Organize parser modules and tools"
"""Convert token-level anime filename JSONL datasets to character tokens.

Input records must contain parallel ``tokens`` and ``labels`` arrays. The
converter expands each original token into Unicode code points and projects BIO
labels onto the expanded sequence:

- ``B-X`` keeps ``B-X`` on the first character and uses ``I-X`` afterwards.
- ``I-X`` remains ``I-X`` on every character.
- ``O`` remains ``O`` on every character.

Besides the converted JSONL, the script writes a frequency-sorted character
vocab (special tokens first) and a JSON manifest with row, token, label and
length statistics.

The script streams records from input to output. It keeps only aggregate
counters, per-row lengths and a few example rows in memory, so it can process
the full DMHY weak dataset without loading hundreds of MB into memory.
"""
from __future__ import annotations

import argparse
import json
from collections import Counter
from datetime import datetime, timezone
from itertools import islice
from pathlib import Path
from statistics import mean
from typing import Iterable
# Reserved vocab entries. build_vocab assigns them IDs 0..3 in this order,
# ahead of any token taken from the data.
SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]")
def projected_labels(token: str, label: str) -> tuple[list[str], list[str]]:
    """Split *token* into characters and give each character a BIO label.

    A ``B-X`` label stays ``B-X`` on the first character only; every later
    character continues the entity as ``I-X``. Any other label (``I-X``,
    ``O``, or anything unrecognised) is copied onto every character.
    An empty token yields two empty lists.
    """
    pieces = list(token)
    count = len(pieces)
    if count == 0:
        return [], []
    if not label.startswith("B-"):
        # I-X, O and unknown labels are simply repeated per character.
        return pieces, [label] * count
    continuation = "I-" + label[2:]
    return pieces, [label] + [continuation] * (count - 1)
def convert_record(record: dict) -> dict:
    """Return a shallow copy of *record* with its token arrays expanded to characters.

    Every key of *record* is carried over. Five keys are then set on the copy:
    ``tokens`` and ``labels`` (the character-level sequences produced by
    ``projected_labels``), ``tokenizer_variant`` (``"char"``),
    ``source_token_count`` and ``char_token_count``.

    Raises:
        KeyError: if ``tokens`` or ``labels`` is missing.
        ValueError: if the two arrays differ in length.
    """
    source_tokens = record["tokens"]
    source_labels = record["labels"]
    n_tokens, n_labels = len(source_tokens), len(source_labels)
    if n_tokens != n_labels:
        raise ValueError(
            f"token/label length mismatch: {n_tokens} tokens, {n_labels} labels"
        )

    expanded_tokens: list[str] = []
    expanded_labels: list[str] = []
    for raw_token, raw_label in zip(source_tokens, source_labels):
        # str() guards against non-string JSON values in the source arrays.
        chars, char_tags = projected_labels(str(raw_token), str(raw_label))
        expanded_tokens += chars
        expanded_labels += char_tags

    result = {**record}
    result.update(
        tokens=expanded_tokens,
        labels=expanded_labels,
        tokenizer_variant="char",
        source_token_count=n_tokens,
        char_token_count=len(expanded_tokens),
    )
    return result
def iter_jsonl(path: Path) -> Iterable[dict]:
    """Lazily yield the parsed JSON value of each non-blank line of *path*.

    Each line is expected to hold one JSON object. Blank and whitespace-only
    lines are skipped. The file is decoded as ``utf-8-sig``, so a leading
    UTF-8 byte-order mark (common in files saved by Windows tools) is dropped
    instead of making line 1 fail to parse. Files without a BOM decode exactly
    as with plain UTF-8.

    Raises:
        ValueError: ``"<path>:<line>: invalid JSON"`` for a malformed line,
            chained to the underlying ``json.JSONDecodeError``.
    """
    with path.open("r", encoding="utf-8-sig") as handle:
        for line_no, line in enumerate(handle, 1):
            line = line.strip()
            if not line:
                continue
            # Keep the try body to the parse only; yield happens outside it.
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{line_no}: invalid JSON") from exc
            yield record
def build_vocab(counter: Counter[str], max_size: int | None = None) -> dict[str, int]:
    """Map tokens to integer IDs, most frequent first, after the special tokens.

    ``SPECIAL_TOKENS`` always receive IDs ``0..len(SPECIAL_TOKENS) - 1``.
    When *max_size* is given, only the ``max_size - len(SPECIAL_TOKENS)`` most
    common tokens are considered (never fewer than zero), so the cap includes
    the special tokens. The specials are still kept if *max_size* is smaller
    than their number. A counted token equal to a special token keeps the
    special ID and is not added again.
    """
    vocab = dict(zip(SPECIAL_TOKENS, range(len(SPECIAL_TOKENS))))
    if max_size is None:
        budget = None
    else:
        budget = max(0, max_size - len(vocab))
    for token, _ in counter.most_common(budget):
        # len(vocab) is evaluated before insertion, giving the next free ID.
        vocab.setdefault(token, len(vocab))
    return vocab
def coverage(counter: Counter[str], vocab: dict[str, int]) -> float:
    """Return the fraction of counted occurrences whose token appears in *vocab*.

    An empty counter (total count of zero) is reported as full coverage, 1.0.
    """
    total = 0
    covered = 0
    for token, count in counter.items():
        total += count
        if token in vocab:
            covered += count
    return covered / total if total else 1.0
def percentile(values: list[int], pct: float) -> int:
    """Return the nearest-rank *pct*-th percentile of *values*.

    The rank is ``round(pct / 100 * (len(values) - 1))``, using Python's
    round-half-to-even, clamped to the valid index range. So ``pct`` below 0
    returns the minimum and ``pct`` above 100 returns the maximum.
    An empty list returns 0.
    """
    if not values:
        return 0
    ordered = sorted(values)
    last = len(ordered) - 1
    # Clamp on both sides: a negative rank would otherwise index from the end
    # of the list and silently return a value near the maximum.
    index = max(0, min(last, round((pct / 100) * last)))
    return ordered[index]
def parse_args() -> argparse.Namespace:
    """Parse the converter's command-line options from ``sys.argv``.

    ``--input``, ``--output`` and ``--vocab-output`` are required.
    ``--manifest-output``, ``--max-vocab-size`` and ``--limit`` default to
    ``None``. ``--progress`` defaults to 50,000 rows.
    """
    cli = argparse.ArgumentParser(
        description="Convert JSONL token labels to character labels"
    )
    # Required file locations, registered in their original order.
    for flag, text in (
        ("--input", "Input token-level JSONL"),
        ("--output", "Output character-level JSONL"),
        ("--vocab-output", "Output vocab JSON"),
    ):
        cli.add_argument(flag, required=True, help=text)
    cli.add_argument("--manifest-output", default=None, help="Output manifest JSON")
    # Optional integer knobs.
    cli.add_argument(
        "--max-vocab-size",
        type=int,
        default=None,
        help="Optional vocab cap including special tokens",
    )
    cli.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Convert only the first N records",
    )
    cli.add_argument(
        "--progress",
        type=int,
        default=50_000,
        help="Print progress every N records",
    )
    return cli.parse_args()
def _char_length_stats(lengths: list[int]) -> dict[str, float | int]:
    """Summarize per-row character counts for the manifest; sorts *lengths* in place.

    Sorting once makes min/max O(1). Each ``percentile`` call then re-sorts an
    already ordered list, which is linear for Timsort, instead of doing four
    full O(n log n) sorts. An empty list yields 0 for every statistic.
    """
    if not lengths:
        return {"min": 0, "mean": 0, "p50": 0, "p90": 0, "p95": 0, "p99": 0, "max": 0}
    lengths.sort()
    return {
        "min": lengths[0],
        "mean": mean(lengths),
        "p50": percentile(lengths, 50),
        "p90": percentile(lengths, 90),
        "p95": percentile(lengths, 95),
        "p99": percentile(lengths, 99),
        "max": lengths[-1],
    }


def main() -> None:
    """Run the conversion end to end from command-line arguments.

    Streams ``--input`` through ``convert_record`` into ``--output``. It then
    writes the character vocab to ``--vocab-output`` and a statistics manifest
    to ``--manifest-output``. The manifest path defaults to the output path
    with its suffix replaced by ``.manifest.json``. The manifest, without its
    ``examples``, is also printed to stdout.

    Raises:
        SystemExit: if ``--limit`` is negative, the input file does not exist,
            or an output path resolves to the input file. All of these are
            checked before any file is created or truncated.
        ValueError: from ``iter_jsonl`` / ``convert_record`` on malformed rows.
    """
    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    vocab_path = Path(args.vocab_output)
    manifest_path = (
        Path(args.manifest_output)
        if args.manifest_output
        else output_path.with_suffix(".manifest.json")
    )

    # Validate before touching the filesystem: opening the output with "w"
    # truncates it, so later failures would still destroy existing files.
    if args.limit is not None and args.limit < 0:
        raise SystemExit(f"--limit must be >= 0, got {args.limit}")
    if not input_path.is_file():
        raise SystemExit(f"input file not found: {input_path}")
    resolved_input = input_path.resolve()
    for role, path in (
        ("output", output_path),
        ("vocab output", vocab_path),
        ("manifest output", manifest_path),
    ):
        if path.resolve() == resolved_input:
            raise SystemExit(f"{role} path {path} would overwrite the input file")

    for path in (output_path, vocab_path, manifest_path):
        path.parent.mkdir(parents=True, exist_ok=True)

    char_counter: Counter[str] = Counter()
    label_counter: Counter[str] = Counter()
    row_count = 0
    source_token_count = 0
    char_token_count = 0
    lengths: list[int] = []
    examples: list[dict] = []
    # islice stops before pulling another record. So --limit N converts exactly
    # N rows (0 means none) and never parses row N + 1; None means no limit.
    records = islice(iter_jsonl(input_path), args.limit)
    with output_path.open("w", encoding="utf-8", newline="\n") as out:
        for record in records:
            converted = convert_record(record)
            out.write(json.dumps(converted, ensure_ascii=False, separators=(",", ":")) + "\n")
            row_count += 1
            source_token_count += converted["source_token_count"]
            char_len = converted["char_token_count"]
            char_token_count += char_len
            lengths.append(char_len)
            char_counter.update(converted["tokens"])
            label_counter.update(converted["labels"])
            if len(examples) < 5:
                examples.append(converted)
            if args.progress and row_count % args.progress == 0:
                print(f"converted {row_count:,} rows; unique chars={len(char_counter):,}")

    vocab = build_vocab(char_counter, args.max_vocab_size)
    vocab_path.write_text(json.dumps(vocab, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    manifest = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "input": str(input_path),
        "output": str(output_path),
        "vocab_output": str(vocab_path),
        "tokenizer_variant": "char",
        "projection": {
            "B-X": "first char keeps B-X; remaining chars become I-X",
            "I-X": "all chars keep I-X",
            "O": "all chars keep O",
        },
        # Recorded so a truncated (--limit) run is distinguishable from a full one.
        "limit": args.limit,
        "row_count": row_count,
        "source_token_count": source_token_count,
        "char_token_count": char_token_count,
        "unique_char_count": len(char_counter),
        "vocab_size": len(vocab),
        "max_vocab_size": args.max_vocab_size,
        "vocab_coverage": coverage(char_counter, vocab),
        "label_counts": dict(label_counter),
        "char_length": _char_length_stats(lengths),
        "examples": examples,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    print(json.dumps({k: v for k, v in manifest.items() if k != "examples"}, ensure_ascii=False, indent=2))
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()