"""Rebuild AnimeName weak labels from each stored filename.""" from __future__ import annotations import argparse import json from collections import Counter from datetime import datetime, timezone from pathlib import Path from statistics import mean from typing import Iterable from tools.dmhy_dataset import weak_label_filename from anifilebert.label_repairs import repair_jsonl_item from anifilebert.tokenizer import AnimeTokenizer def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Relabel a JSONL dataset from filename strings") parser.add_argument("--input", required=True, help="Input JSONL containing filename fields") parser.add_argument("--output", required=True, help="Output relabeled regex-token JSONL") parser.add_argument("--manifest-output", default=None, help="Relabel manifest JSON") parser.add_argument("--vocab-output", default=None, help="Optional regex vocab JSON") parser.add_argument("--base-vocab", default=None, help="Optional regex vocab whose IDs should be preserved") parser.add_argument("--max-vocab-size", type=int, default=3000) parser.add_argument("--limit", type=int, default=None) parser.add_argument("--progress", type=int, default=50000) parser.add_argument("--example-count", type=int, default=20) return parser.parse_args() def iter_jsonl(path: Path) -> Iterable[dict]: with path.open("r", encoding="utf-8") as handle: for line_no, line in enumerate(handle, 1): line = line.strip() if not line: continue try: yield json.loads(line) except json.JSONDecodeError as exc: raise ValueError(f"{path}:{line_no}: invalid JSON") from exc def length_stats(values: list[int]) -> dict: if not values: return {"min": 0, "mean": 0, "p50": 0, "p90": 0, "p95": 0, "p99": 0, "max": 0} ordered = sorted(values) def percentile(pct: float) -> int: index = min(len(ordered) - 1, round((pct / 100) * (len(ordered) - 1))) return ordered[index] return { "min": min(values), "mean": mean(values), "p50": percentile(50), "p90": percentile(90), "p95": percentile(95), "p99": percentile(99), "max": max(values), } def main() -> None: args = parse_args() input_path = Path(args.input) output_path = Path(args.output) manifest_path = Path(args.manifest_output) if args.manifest_output else output_path.with_suffix(".manifest.json") vocab_path = Path(args.vocab_output) if args.vocab_output else None output_path.parent.mkdir(parents=True, exist_ok=True) manifest_path.parent.mkdir(parents=True, exist_ok=True) if vocab_path: vocab_path.parent.mkdir(parents=True, exist_ok=True) tokenizer = AnimeTokenizer() rows_in = 0 rows_written = 0 rows_failed = 0 rows_repaired_after_relabel = 0 label_counter: Counter[str] = Counter() failure_counter: Counter[str] = Counter() token_lists: list[list[str]] = [] lengths: list[int] = [] examples: list[dict] = [] failures: list[dict] = [] with output_path.open("w", encoding="utf-8", newline="\n") as out: for item in iter_jsonl(input_path): rows_in += 1 filename = item.get("filename") if not filename: rows_failed += 1 failure_counter["missing_filename"] += 1 continue sample = weak_label_filename(str(filename), tokenizer) if sample is None: rows_failed += 1 failure_counter["weak_label_failed"] += 1 if len(failures) < args.example_count: failures.append({"file_id": item.get("file_id"), "filename": filename}) continue record = dict(item) record.pop("tokenizer_variant", None) record.pop("source_token_count", None) record.pop("char_token_count", None) record["tokens"] = sample["tokens"] record["labels"] = sample["labels"] repaired, repairs = repair_jsonl_item(record) if repairs: rows_repaired_after_relabel += 1 record = repaired out.write(json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n") rows_written += 1 label_counter.update(record["labels"]) token_lists.append(record["tokens"]) lengths.append(len(record["tokens"])) if len(examples) < args.example_count: examples.append(record) if args.limit is not None and rows_written >= args.limit: break if args.progress and rows_written % args.progress == 0: print(f"relabeled {rows_written:,} rows; failed={rows_failed:,}") base_vocab = None if args.base_vocab: with Path(args.base_vocab).open("r", encoding="utf-8") as handle: base_vocab = json.load(handle) tokenizer.build_vocab(token_lists, max_size=args.max_vocab_size, base_vocab=base_vocab) if vocab_path: vocab_path.write_text(json.dumps(tokenizer.get_vocab(), ensure_ascii=False, indent=2) + "\n", encoding="utf-8") manifest = { "created_at": datetime.now(timezone.utc).isoformat(), "input": str(input_path), "output": str(output_path), "vocab_output": str(vocab_path) if vocab_path else None, "row_count": rows_written, "input_rows": rows_in, "failed_rows": rows_failed, "repaired_after_relabel_rows": rows_repaired_after_relabel, "failure_counts": dict(failure_counter), "label_counts": dict(label_counter), "token_length": length_stats(lengths), "vocab_size": tokenizer.vocab_size, "examples": examples, "failures": failures, } manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") print(json.dumps({k: v for k, v in manifest.items() if k not in {"examples", "failures"}}, ensure_ascii=False, indent=2)) if __name__ == "__main__": main()