AniFileBERT / tools /convert_annotated_dmhy_dataset.py
ModerRAS's picture
Add DMHY prefix graph annotation workflow
33bb11c
raw
history blame
10.2 kB
"""Convert annotated DMHY graph JSONL into the character-tokenized dataset.
The annotated graph workflow is expected to produce records compatible with
``dmhy_weak.jsonl``: each row has ``filename``, ``tokens``, and ``labels``.
This wrapper validates that contract, then reuses ``tools.convert_to_char_dataset``
for the token-to-character projection and manifest statistics.
"""
from __future__ import annotations
import argparse
import json
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
from typing import Iterable
from tools.convert_to_char_dataset import (
build_vocab,
convert_record,
coverage,
percentile,
)
# All default artifacts of this workflow live side by side in one dataset folder.
_DATASET_DIR = Path("datasets/AnimeName")

# Default source: token-level JSONL produced by the annotated graph workflow.
DEFAULT_INPUT = _DATASET_DIR / "dmhy_weak.generated.jsonl"
# Default destinations: character-level rows, character vocab, and run manifest.
DEFAULT_OUTPUT = _DATASET_DIR / "dmhy_weak.generated_char.jsonl"
DEFAULT_VOCAB_OUTPUT = _DATASET_DIR / "vocab.generated.char.json"
DEFAULT_MANIFEST_OUTPUT = _DATASET_DIR / "dmhy_weak.generated_char.manifest.json"

# Keys every input record must carry before it is handed to the converter.
REQUIRED_FIELDS = ("filename", "tokens", "labels")
def is_separator_or_space(char: str) -> bool:
    """Return True when *char* is whitespace or otherwise non-alphanumeric.

    The decision is delegated to ``str.isspace`` and ``str.isalnum``, so
    letters and digits of any script count as ordinary text while
    punctuation, symbols and whitespace count as separators. An empty string
    is reported as a separator because ``"".isalnum()`` is False.
    """
    if char.isspace():
        return True
    return not char.isalnum()
def token_has_embedded_separator(token: str) -> bool:
    """Return True when a multi-character token contains a separator or space.

    Tokens of length zero or one are never flagged: a lone separator is
    exactly the standalone form this check is meant to allow.
    """
    if len(token) <= 1:
        return False
    for char in token:
        if is_separator_or_space(char):
            return True
    return False
def is_bioish_label(label: object) -> bool:
    """Return True when *label* is ``"O"`` or ``B-<entity>`` / ``I-<entity>``.

    Non-string values are rejected. The entity part must be non-empty, and it
    may itself contain further dashes (only the first dash is structural).
    """
    if isinstance(label, str):
        if label == "O":
            return True
        # "B-"/"I-" followed by at least one more character.
        return label.startswith(("B-", "I-")) and len(label) > 2
    return False
def validate_record(
    record: object,
    path: Path,
    line_no: int,
    *,
    check_punctuation: bool = True,
) -> dict:
    """Check one parsed JSONL row against the input contract and return it.

    Args:
        record: The value decoded from one input line.
        path: Source file, used only to prefix error messages.
        line_no: 1-based line number, used only to prefix error messages.
        check_punctuation: When True, reject multi-character tokens that
            contain punctuation, symbols or whitespace.

    Returns:
        The same ``record`` object, unchanged, once every check has passed.

    Raises:
        ValueError: On the first violated rule; the message starts with
            ``"<path>:<line_no>: "``.
    """
    where = f"{path}:{line_no}"

    if not isinstance(record, dict):
        raise ValueError(f"{where}: record must be a JSON object")

    absent = [name for name in REQUIRED_FIELDS if name not in record]
    if absent:
        raise ValueError(
            f"{where}: missing required field(s): {', '.join(absent)}"
        )

    filename = record["filename"]
    tokens = record["tokens"]
    labels = record["labels"]

    if not (isinstance(filename, str) and filename):
        raise ValueError(f"{where}: filename must be a non-empty string")

    # tokens is checked before labels so the first reported problem is stable.
    for field_name, value in (("tokens", tokens), ("labels", labels)):
        if not isinstance(value, list):
            raise ValueError(f"{where}: {field_name} must be a list")

    token_total, label_total = len(tokens), len(labels)
    if token_total != label_total:
        raise ValueError(
            f"{where}: token/label length mismatch: "
            f"{token_total} tokens, {label_total} labels"
        )

    for position, token in enumerate(tokens):
        if not isinstance(token, str):
            raise ValueError(f"{where}: tokens[{position}] must be a string")
        if not check_punctuation:
            continue
        if token_has_embedded_separator(token):
            raise ValueError(
                f"{where}: tokens[{position}] contains punctuation, symbol, or "
                f"whitespace that should be a standalone token: {token!r}"
            )

    for position, label in enumerate(labels):
        if is_bioish_label(label):
            continue
        raise ValueError(
            f"{where}: labels[{position}] is not BIO-ish: {label!r}"
        )

    return record
def iter_validated_jsonl(path: Path, *, check_punctuation: bool = True) -> Iterable[dict]:
    """Lazily yield validated records from a JSONL file, one per non-blank line.

    Args:
        path: JSONL file to read.
        check_punctuation: Forwarded to ``validate_record``.

    Yields:
        Each decoded record after it has passed ``validate_record``.

    Raises:
        ValueError: If a line is not valid JSON (chained from the
            ``json.JSONDecodeError``) or a record fails validation.
    """
    # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading BOM.
    # With plain "utf-8" the BOM stays glued to line 1 (str.strip() does not
    # remove U+FEFF) and the first record is rejected as invalid JSON.
    with path.open("r", encoding="utf-8-sig") as handle:
        for line_no, raw_line in enumerate(handle, 1):
            text = raw_line.strip()
            if not text:
                # Blank lines are skipped but still counted in line numbers.
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{line_no}: invalid JSON") from exc
            yield validate_record(
                record,
                path,
                line_no,
                check_punctuation=check_punctuation,
            )
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace. Path-like options are kept as strings; the
        numeric options are ints (or None where no default is set).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Validate annotated DMHY graph JSONL and convert it to the "
            "character-tokenized training format."
        ),
        epilog=(
            "Equivalent projection logic is provided by "
            "tools.convert_to_char_dataset.convert_record."
        ),
    )

    # The four path options share one shape: a string default plus help text.
    path_options = (
        (
            "--input",
            DEFAULT_INPUT,
            f"Input dmhy_weak-compatible JSONL (default: {DEFAULT_INPUT})",
        ),
        (
            "--output",
            DEFAULT_OUTPUT,
            f"Output character-level JSONL (default: {DEFAULT_OUTPUT})",
        ),
        (
            "--vocab-output",
            DEFAULT_VOCAB_OUTPUT,
            f"Output character vocab JSON (default: {DEFAULT_VOCAB_OUTPUT})",
        ),
        (
            "--manifest-output",
            DEFAULT_MANIFEST_OUTPUT,
            "Output conversion manifest JSON "
            f"(default: {DEFAULT_MANIFEST_OUTPUT})",
        ),
    )
    for flag, default_path, help_text in path_options:
        parser.add_argument(flag, default=str(default_path), help=help_text)

    # Integer knobs.
    parser.add_argument(
        "--max-vocab-size",
        type=int,
        default=None,
        help="Optional vocab cap including special tokens",
    )
    parser.add_argument("--limit", type=int, default=None, help="Convert only N rows")
    parser.add_argument(
        "--progress",
        type=int,
        default=50_000,
        help="Print progress every N records",
    )

    # Boolean switches, both off by default.
    parser.add_argument(
        "--validate-only",
        action="store_true",
        help="Validate input records without writing converted outputs",
    )
    parser.add_argument(
        "--allow-embedded-punctuation",
        action="store_true",
        help=(
            "Skip the generated-workflow check that punctuation and whitespace "
            "must be standalone tokens."
        ),
    )

    namespace = parser.parse_args()
    return namespace
def main() -> None:
    """Validate the input JSONL and convert it to character-level outputs.

    Steps:
      1. Parse CLI arguments and make sure the input file exists.
      2. Stream validated records through ``convert_record``, writing each
         converted row to the output JSONL (skipped with ``--validate-only``)
         while accumulating character, label and length statistics.
      3. Build the vocab and a manifest describing the run; write both unless
         ``--validate-only`` is set.
      4. Print the manifest (without the example rows) to stdout.

    ``--limit N`` stops after N rows. A limit of zero or less processes no
    rows at all.

    Raises:
        FileNotFoundError: If the input path does not exist.
        ValueError: Propagated from ``iter_validated_jsonl`` for a bad row.
    """
    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    vocab_path = Path(args.vocab_output)
    manifest_path = Path(args.manifest_output)

    if not input_path.exists():
        raise FileNotFoundError(f"input JSONL does not exist: {input_path}")

    if not args.validate_only:
        for target in (output_path, vocab_path, manifest_path):
            target.parent.mkdir(parents=True, exist_ok=True)

    char_counter: Counter[str] = Counter()
    label_counter: Counter[str] = Counter()
    row_count = 0
    source_token_count = 0
    char_token_count = 0
    lengths: list[int] = []
    examples: list[dict] = []

    # The in-loop limit check below only runs after a row has been converted,
    # so on its own it would let one row through for ``--limit 0`` (or a
    # negative limit). Skip the input entirely in that case.
    if args.limit is not None and args.limit <= 0:
        records: Iterable[dict] = ()
    else:
        records = iter_validated_jsonl(
            input_path,
            check_punctuation=not args.allow_embedded_punctuation,
        )

    output_handle = None
    try:
        if not args.validate_only:
            # NOTE(review): the output is truncated here, before validation has
            # finished; if a later row is rejected, a partial file remains.
            output_handle = output_path.open("w", encoding="utf-8", newline="\n")
        for record in records:
            converted = convert_record(record)
            if output_handle is not None:
                output_handle.write(
                    json.dumps(converted, ensure_ascii=False, separators=(",", ":"))
                    + "\n"
                )
            row_count += 1
            source_token_count += converted["source_token_count"]
            char_len = converted["char_token_count"]
            char_token_count += char_len
            lengths.append(char_len)
            char_counter.update(converted["tokens"])
            label_counter.update(converted["labels"])
            if len(examples) < 5:
                # Keep the first few converted rows as samples for the manifest.
                examples.append(converted)
            # Break right after the Nth row so row N+1 is never read or validated.
            if args.limit is not None and row_count >= args.limit:
                break
            if args.progress and row_count % args.progress == 0:
                print(f"converted {row_count:,} rows; unique chars={len(char_counter):,}")
    finally:
        if output_handle is not None:
            output_handle.close()

    vocab = build_vocab(char_counter, args.max_vocab_size)
    manifest = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "input": str(input_path),
        "output": None if args.validate_only else str(output_path),
        "vocab_output": None if args.validate_only else str(vocab_path),
        "manifest_output": None if args.validate_only else str(manifest_path),
        "tokenizer_variant": "char",
        "source_workflow": "annotated_dmhy_graph",
        "validation": {
            "required_fields": list(REQUIRED_FIELDS),
            "label_contract": "O or B-*/I-* with a non-empty entity name; B/O-only is accepted",
            "punctuation_standalone": not args.allow_embedded_punctuation,
        },
        "projection": {
            "B-X": "first char keeps B-X; remaining chars become I-X",
            "I-X": "all chars keep I-X",
            "O": "all chars keep O",
        },
        "row_count": row_count,
        "source_token_count": source_token_count,
        "char_token_count": char_token_count,
        "unique_char_count": len(char_counter),
        "vocab_size": len(vocab),
        "max_vocab_size": args.max_vocab_size,
        "vocab_coverage": coverage(char_counter, vocab),
        "label_counts": dict(label_counter),
        "char_length": {
            # min/mean/max raise on an empty sequence, hence the guards.
            "min": min(lengths) if lengths else 0,
            "mean": mean(lengths) if lengths else 0,
            "p50": percentile(lengths, 50),
            "p90": percentile(lengths, 90),
            "p95": percentile(lengths, 95),
            "p99": percentile(lengths, 99),
            "max": max(lengths) if lengths else 0,
        },
        "examples": examples,
    }

    if not args.validate_only:
        vocab_path.write_text(
            json.dumps(vocab, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )
        manifest_path.write_text(
            json.dumps(manifest, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )

    # Echo the summary to stdout; example rows are left out to keep it short.
    print(
        json.dumps(
            {key: value for key, value in manifest.items() if key != "examples"},
            ensure_ascii=False,
            indent=2,
        )
    )
if __name__ == "__main__":
main()