# AniFileBERT / tools/build_balanced_focus_dataset.py
# Author: ModerRAS -- commit 651ad49 "Add balanced focus dataset generator" (7.35 kB)
"""Build a balanced focus set from real parser failures and nearby DMHY rows.
The goal is to repair boundary mistakes without teaching the model that every
special-like token should dominate title/season/episode context. Reported
failures are resolved back to their authoritative char BIO rows from DMHY when
possible, then mixed with repaired rows, broad boundary-pattern rows, random
context, and a small number of deterministic hard cases.
"""
from __future__ import annotations

import argparse
import json
import random
import re
import sys
from collections import Counter
from pathlib import Path
from typing import Iterable, Sequence

from anifilebert.label_repairs import repair_jsonl_item
from tools.build_path_focus_dataset import build_cases as build_path_cases
from tools.build_repair_focus_dataset import manual_cases as repair_manual_cases
# Filename heuristics for rows that sit near a title/season/episode/special
# boundary. Alternatives, in order:
#   * special-content markers (NCOP/NCED/OP/ED/PV/CM/TVCM/OVA/OAD/SP/Menu),
#     optionally followed by a number, "ep. N" or "ver. N";
#   * disc / box / volume tags (Blu-ray & DVD, BD-BOX, Disc N, Vol N);
#   * "S<n>" + OP/ED/NCOP/NCED and numbered NCOP/NCED (e.g. NCOP1-2);
#   * a Roman numeral followed by a special marker (e.g. "II OVA");
#   * CJK ordinal + ノ/の/之 + 章/期/季/部;
#   * 第<number>季/期/部/章 with Arabic or CJK numerals;
#   * Act/Part followed by a Roman numeral;
#   * video codec tags and AAC(...) audio tags.
# Flags: IGNORECASE and VERBOSE. Under VERBOSE, literal whitespace in the
# pattern is ignored except inside character classes such as ``[-_ ]``.
BOUNDARY_FOCUS_RE = re.compile(
    r"(?ix)"
    r"(?:"
    r"\b(?:NCOP|NCED|OP|ED|PV|CM|TVCM|OVA|OAD|SP|Menu)\s*[_\-.]?\s*(?:\d{0,4}|ep\.?\s*\d{1,4}|ver\.?\s*\d{1,2})\b|"
    r"\b(?:Blu[-_ ]?ray\s*&\s*DVD|BD[-_ ]?BOX|Disc\.?\s*\d+|Vol\.?\s*\d+)\b|"
    r"\b(?:S\d{1,2}[_\-.]?(?:OP|ED|NCOP|NCED)|NC(?:OP|ED)\d+[_\-.]?\d+)\b|"
    r"\b(?:II|III|IV|V|Ⅱ|Ⅲ|Ⅳ|Ⅴ)\s+(?:OVA|OAD|CM|PV|OP|ED|Menu)\b|"
    r"(?:弐|貳|贰|二|三|參|叁|参)\s*(?:ノ|の|之)\s*(?:章|期|季|部)|"
    r"第\s*(?:\d+|[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖]+)\s*[季期部章]|"
    r"\b(?:Act|Part)\s+(?:II|III|IV|V)\b|"
    r"\b(?:h\.?264|x\.?264|h\.?265|x\.?265|AVC[-_ ]?YUV|yuv\d+p?\d*)\b|"
    # No trailing ``\b`` after ``AAC(...)``: ``)`` is a non-word character, so
    # a boundary there would demand a letter/digit immediately after the tag
    # and reject the common ``[AAC(2.0)]`` / ``AAC(5.1) `` spellings.
    r"\bAAC\([^)]*\)"
    r")"
)
def parse_args() -> argparse.Namespace:
    """Read the builder's command-line options.

    ``--input`` and ``--output`` are mandatory. ``--failure-report`` may be
    given several times and accumulates into a list. The remaining options are
    integers controlling sample caps, per-bucket repeat counts and the RNG seed.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--input", required=True, help="Authoritative char JSONL dataset")
    cli.add_argument("--output", required=True, help="Output focus JSONL")
    cli.add_argument(
        "--failure-report",
        action="append",
        default=[],
        help="Parse/case metrics JSON with failures to resolve back to DMHY rows",
    )
    # Plain integer knobs share one shape, so register them from a table,
    # in the same order as before so --help output is unchanged.
    integer_options = (
        ("--context-samples", 70000),
        ("--max-boundary-rows", 90000),
        ("--repeat-failure", 18),
        ("--repeat-repaired", 2),
        ("--repeat-boundary", 1),
        ("--repeat-manual", 8),
        ("--repeat-path", 24),
        ("--seed", 42),
    )
    for flag, fallback in integer_options:
        cli.add_argument(flag, type=int, default=fallback)
    return cli.parse_args()
def iter_jsonl(path: Path) -> Iterable[dict]:
    """Lazily decode a UTF-8 JSON Lines file, one object per non-blank line.

    Lines are stripped before decoding and whitespace-only lines are skipped.
    The file stays open until the generator is exhausted or closed.
    """
    with path.open("r", encoding="utf-8") as stream:
        for text in map(str.strip, stream):
            if text:
                yield json.loads(text)
def reservoir_add(rows: list[dict], item: dict, limit: int, rng: random.Random, seen_count: int) -> None:
    """Offer *item* to a bounded reservoir sample kept in *rows* (in place).

    ``seen_count`` is the 1-based number of items offered so far, including
    *item*. While the reservoir holds fewer than *limit* rows the item is
    appended. After that, one ``rng.randrange(seen_count)`` draw picks a slot,
    and the item overwrites it only when the slot is below *limit*, i.e. with
    probability ``limit / seen_count`` (Algorithm R). A non-positive *limit*
    leaves *rows* untouched and draws nothing from *rng*.
    """
    if limit > 0:
        if len(rows) >= limit:
            slot = rng.randrange(seen_count)
            if slot < limit:
                rows[slot] = item
        else:
            rows.append(item)
def failure_filenames(report_paths: Sequence[str]) -> set[str]:
    """Collect the filenames of failed cases from metrics report files.

    Each report is read as JSON. Its ``"modes"`` mapping holds one dict per
    mode, and from each mode:

    * every entry of ``"failures"`` contributes its ``"filename"``;
    * every entry of ``"results"`` whose ``"ok"`` flag is falsy contributes its
      ``"filename"`` (a missing ``"ok"`` counts as success).

    Empty filenames are ignored and all others are stringified. Entries,
    modes, or report roots that are not JSON objects are skipped instead of
    raising.

    Args:
        report_paths: Report file paths, as passed via ``--failure-report``.

    Returns:
        The union of failing filenames across all reports.
    """
    filenames: set[str] = set()
    for value in report_paths:
        path = Path(value)
        if not path.exists():
            # Still best-effort, but say so: a mistyped path would otherwise
            # silently yield a focus set without any failure rows. Goes to
            # stderr so the JSON summary on stdout stays machine-readable.
            print(f"warning: failure report not found, skipping: {path}", file=sys.stderr)
            continue
        report = json.loads(path.read_text(encoding="utf-8"))
        modes = report.get("modes") if isinstance(report, dict) else None
        if not isinstance(modes, dict):
            continue
        for mode in modes.values():
            if not isinstance(mode, dict):
                continue
            for failure in mode.get("failures") or []:
                if isinstance(failure, dict) and failure.get("filename"):
                    filenames.add(str(failure["filename"]))
            for result in mode.get("results") or []:
                if not isinstance(result, dict) or result.get("ok", True):
                    continue
                if result.get("filename"):
                    filenames.add(str(result["filename"]))
    return filenames
def clone_with_source(item: dict, source: str) -> dict:
    """Return a shallow copy of *item* whose ``"source"`` key is *source*.

    *item* itself is not modified. If it already has a ``"source"`` key, the
    key keeps its position and only its value changes; otherwise the key is
    added last. Nested values are shared with the original, not copied.
    """
    return dict(item, source=source)
def main() -> None:
    """Build the balanced focus dataset and print a JSON summary to stdout.

    The input JSONL is streamed once. Rows without a ``filename`` are skipped,
    and only the first row seen for each filename is used. That row goes to
    the first matching bucket:

    1. ``balanced_report_failure``: the filename is listed in a failure report;
    2. ``balanced_repaired_context``: ``repair_jsonl_item`` reports repairs;
    3. ``balanced_boundary_pattern``: the filename matches ``BOUNDARY_FOCUS_RE``
       (reservoir-sampled down to ``--max-boundary-rows``);
    4. ``balanced_random_context``: everything else (reservoir-sampled down
       to ``--context-samples``).

    The failure, repaired and boundary buckets are repeated by their
    ``--repeat-*`` counts, each clamped to at least 1. Context rows are written
    once. The manual repair and path cases are appended, and the result is
    shuffled with the seeded RNG and written as compact UTF-8 JSONL.
    """
    args = parse_args()
    rng = random.Random(args.seed)
    input_path = Path(args.input)
    output_path = Path(args.output)
    targets = failure_filenames(args.failure_report)

    failure_rows: list[dict] = []
    repaired_rows: list[dict] = []
    boundary_rows: list[dict] = []
    context_rows: list[dict] = []
    seen_filenames: set[str] = set()
    source_counts: Counter[str] = Counter()
    total_rows = 0
    boundary_seen = 0
    context_seen = 0

    for item in iter_jsonl(input_path):
        total_rows += 1
        filename = str(item.get("filename") or "")
        # Only the first row per filename is eligible. Checking this up front
        # also avoids running the label repairer on duplicates that would be
        # dropped anyway.
        if not filename or filename in seen_filenames:
            continue
        # Claim the filename for whichever bucket takes this row, random
        # context included, so later duplicates cannot be sampled a second
        # time or resurface in another bucket.
        seen_filenames.add(filename)

        if filename in targets:
            failure_rows.append(clone_with_source(item, "balanced_report_failure"))
            continue

        _repaired_item, repairs = repair_jsonl_item(item)
        if repairs:
            # NOTE(review): the original, unrepaired ``item`` is emitted and the
            # repaired copy is discarded. Presumably repairs are re-applied when
            # the dataset is loaded. Confirm, otherwise emit ``_repaired_item``.
            repaired_rows.append(clone_with_source(item, "balanced_repaired_context"))
            continue

        if BOUNDARY_FOCUS_RE.search(filename):
            boundary_seen += 1
            reservoir_add(
                boundary_rows,
                clone_with_source(item, "balanced_boundary_pattern"),
                args.max_boundary_rows,
                rng,
                boundary_seen,
            )
            continue

        context_seen += 1
        reservoir_add(
            context_rows,
            clone_with_source(item, "balanced_random_context"),
            args.context_samples,
            rng,
            context_seen,
        )

    # Repeats are clamped to >= 1, so each bucket appears at least once.
    rows: list[dict] = []
    rows.extend(failure_rows * max(1, args.repeat_failure))
    rows.extend(repaired_rows * max(1, args.repeat_repaired))
    rows.extend(boundary_rows * max(1, args.repeat_boundary))
    rows.extend(context_rows)
    for item in repair_manual_cases():
        rows.extend([clone_with_source(item, "balanced_manual_repair")] * max(1, args.repeat_manual))
    for item in build_path_cases("balanced_manual_path"):
        rows.extend([item] * max(1, args.repeat_path))
    rng.shuffle(rows)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        for item in rows:
            handle.write(json.dumps(item, ensure_ascii=False, separators=(",", ":")) + "\n")
            source_counts[str(item.get("source", "unknown"))] += 1

    print(
        json.dumps(
            {
                "input": str(input_path),
                "output": str(output_path),
                "total_rows": total_rows,
                "failure_targets": len(targets),
                "matched_failure_rows": len(failure_rows),
                # A target is matched on its first sighting, so any target never
                # added to seen_filenames had no row in the input.
                "unmatched_failure_targets": len(targets - seen_filenames),
                "repaired_rows": len(repaired_rows),
                "boundary_rows": len(boundary_rows),
                "context_rows": len(context_rows),
                "written_rows": len(rows),
                "source_counts": dict(source_counts),
            },
            ensure_ascii=False,
            indent=2,
        )
    )
# Script entry point: call main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()