# AniFileBERT / tools/relabel_dataset_from_filenames.py
# ModerRAS - commit 8c50d16: Organize parser modules and tools
"""Rebuild AnimeName weak labels from each stored filename."""
from __future__ import annotations
import argparse
import json
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
from typing import Iterable
from tools.dmhy_dataset import weak_label_filename
from anifilebert.label_repairs import repair_jsonl_item
from anifilebert.tokenizer import AnimeTokenizer
def parse_args() -> argparse.Namespace:
    """Parse the relabeling tool's command-line options from ``sys.argv``.

    Returns:
        Namespace with ``input`` and ``output`` (both required),
        ``manifest_output``, ``vocab_output`` and ``base_vocab`` (optional,
        default ``None``), and the integer options ``max_vocab_size``
        (3000), ``limit`` (``None``), ``progress`` (50000) and
        ``example_count`` (20).
    """
    cli = argparse.ArgumentParser(description="Relabel a JSONL dataset from filename strings")

    # Mandatory dataset locations.
    required_paths = (
        ("--input", "Input JSONL containing filename fields"),
        ("--output", "Output relabeled regex-token JSONL"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, required=True, help=description)

    # Optional side artifacts; each stays None unless given.
    optional_paths = (
        ("--manifest-output", "Relabel manifest JSON"),
        ("--vocab-output", "Optional regex vocab JSON"),
        ("--base-vocab", "Optional regex vocab whose IDs should be preserved"),
    )
    for flag, description in optional_paths:
        cli.add_argument(flag, default=None, help=description)

    # Integer options, registered in the same order as they appear in --help.
    integer_options = (
        ("--max-vocab-size", 3000),
        ("--limit", None),
        ("--progress", 50000),
        ("--example-count", 20),
    )
    for flag, fallback in integer_options:
        cli.add_argument(flag, type=int, default=fallback)

    return cli.parse_args()
def iter_jsonl(path: Path) -> Iterable[dict]:
    """Lazily yield the decoded JSON value of every non-blank line in *path*.

    The file is read as UTF-8. Lines that are empty after stripping
    whitespace are skipped, but they still count towards the line number
    reported in errors.

    Raises:
        ValueError: if a line is not valid JSON. The message has the form
            ``<path>:<line number>: invalid JSON`` and the original
            ``json.JSONDecodeError`` is chained as the cause.
    """
    with path.open("r", encoding="utf-8") as stream:
        for number, raw in enumerate(stream, start=1):
            text = raw.strip()
            if text:
                try:
                    yield json.loads(text)
                except json.JSONDecodeError as err:
                    raise ValueError(f"{path}:{number}: invalid JSON") from err
def length_stats(values: list[int]) -> dict:
    """Summarize *values* as min, mean, selected percentiles and max.

    Each percentile is the element of the sorted values at index
    ``round(pct / 100 * (n - 1))``, clamped to the last index; no
    interpolation is performed. An empty input yields a dict with every
    statistic set to 0.

    Returns:
        Dict with keys ``min``, ``mean``, ``p50``, ``p90``, ``p95``, ``p99``
        and ``max``, in that order.
    """
    if not values:
        return dict.fromkeys(("min", "mean", "p50", "p90", "p95", "p99", "max"), 0)

    ranked = sorted(values)
    last_index = len(ranked) - 1

    summary = {"min": min(values), "mean": mean(values)}
    for pct in (50, 90, 95, 99):
        position = min(last_index, round((pct / 100) * last_index))
        summary[f"p{pct}"] = ranked[position]
    summary["max"] = max(values)
    return summary
def main() -> None:
    """Relabel every row of the input JSONL from its ``filename`` field.

    For each input row the filename is passed to ``weak_label_filename``;
    the resulting tokens and labels replace the row's own, the row is run
    through ``repair_jsonl_item``, and the result is written as one compact
    JSON line. Rows with no filename, or for which weak labeling returns
    ``None``, are counted as failures and skipped.

    After the pass a vocabulary is built from all written token lists
    (optionally seeded from ``--base-vocab``), optionally saved to
    ``--vocab-output``, and a manifest with counts, token-length statistics
    and example rows is written. The manifest, minus its ``examples`` and
    ``failures`` entries, is also printed to stdout.

    Raises:
        ValueError: if the output, manifest or vocab path resolves to the
            input file, or (from ``iter_jsonl``) if an input line is not
            valid JSON.
    """
    args = parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    manifest_path = Path(args.manifest_output) if args.manifest_output else output_path.with_suffix(".manifest.json")
    vocab_path = Path(args.vocab_output) if args.vocab_output else None

    # The output is opened in "w" mode (truncating it) before the input is
    # read, and the manifest/vocab are written with write_text afterwards.
    # If any of them is the input file the dataset would be destroyed, so
    # refuse before touching the filesystem.
    resolved_input = input_path.resolve()
    for destination in (output_path, manifest_path, vocab_path):
        if destination is not None and destination.resolve() == resolved_input:
            raise ValueError(f"{destination}: refusing to overwrite the input file {input_path}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    if vocab_path:
        vocab_path.parent.mkdir(parents=True, exist_ok=True)

    tokenizer = AnimeTokenizer()
    rows_in = 0
    rows_written = 0
    rows_failed = 0
    rows_repaired_after_relabel = 0
    label_counter: Counter[str] = Counter()
    failure_counter: Counter[str] = Counter()
    # Every written token list is kept in memory because the vocabulary is
    # built from the full set once the pass is finished.
    token_lists: list[list[str]] = []
    lengths: list[int] = []
    examples: list[dict] = []
    failures: list[dict] = []

    # newline="\n" keeps LF line endings in the output on every platform.
    with output_path.open("w", encoding="utf-8", newline="\n") as out:
        for item in iter_jsonl(input_path):
            rows_in += 1
            filename = item.get("filename")
            if not filename:
                rows_failed += 1
                failure_counter["missing_filename"] += 1
                continue

            sample = weak_label_filename(str(filename), tokenizer)
            if sample is None:
                rows_failed += 1
                failure_counter["weak_label_failed"] += 1
                # Only the first --example-count failures are kept for the manifest.
                if len(failures) < args.example_count:
                    failures.append({"file_id": item.get("file_id"), "filename": filename})
                continue

            # Copy the row, drop these three keys, then overwrite tokens and
            # labels. Presumably the dropped keys describe the previous
            # tokenization and would be stale - confirm against the producer.
            record = dict(item)
            record.pop("tokenizer_variant", None)
            record.pop("source_token_count", None)
            record.pop("char_token_count", None)
            record["tokens"] = sample["tokens"]
            record["labels"] = sample["labels"]

            # Adopt the repaired record only when the repair pass reports changes.
            repaired, repairs = repair_jsonl_item(record)
            if repairs:
                rows_repaired_after_relabel += 1
                record = repaired

            out.write(json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n")
            rows_written += 1
            label_counter.update(record["labels"])
            token_lists.append(record["tokens"])
            lengths.append(len(record["tokens"]))
            if len(examples) < args.example_count:
                examples.append(record)

            # NOTE: the limit is checked after the write, so a limit of 0 or
            # less still emits one row.
            if args.limit is not None and rows_written >= args.limit:
                break
            # A --progress of 0 disables progress output.
            if args.progress and rows_written % args.progress == 0:
                print(f"relabeled {rows_written:,} rows; failed={rows_failed:,}")

    base_vocab = None
    if args.base_vocab:
        with Path(args.base_vocab).open("r", encoding="utf-8") as handle:
            base_vocab = json.load(handle)

    # The vocab is built even without --vocab-output because the manifest
    # records tokenizer.vocab_size.
    tokenizer.build_vocab(token_lists, max_size=args.max_vocab_size, base_vocab=base_vocab)
    if vocab_path:
        vocab_path.write_text(json.dumps(tokenizer.get_vocab(), ensure_ascii=False, indent=2) + "\n", encoding="utf-8")

    manifest = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "input": str(input_path),
        "output": str(output_path),
        "vocab_output": str(vocab_path) if vocab_path else None,
        "row_count": rows_written,
        "input_rows": rows_in,
        "failed_rows": rows_failed,
        "repaired_after_relabel_rows": rows_repaired_after_relabel,
        "failure_counts": dict(failure_counter),
        "label_counts": dict(label_counter),
        "token_length": length_stats(lengths),
        "vocab_size": tokenizer.vocab_size,
        "examples": examples,
        "failures": failures,
    }
    manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    # The console summary leaves out the two bulky list entries.
    print(json.dumps({k: v for k, v in manifest.items() if k not in {"examples", "failures"}}, ensure_ascii=False, indent=2))
# Run the relabeling CLI only when this file is executed as a script, not
# when it is imported.
if __name__ == "__main__":
    main()