# AniFileBERT / tools/diagnose_pipeline.py
# Last change by ModerRAS: "Organize parser modules and tools" (commit 8c50d16)
"""Diagnostics for the anime filename NER pipeline.
The checks focus on structured filename parsing failure modes:
- train/inference tokenizer mismatch
- BIO legality and boundary drift
- tokenizer split and vocabulary coverage
- label/entity distribution
- optional model confusion on a sampled validation split
"""
from __future__ import annotations
import argparse
import json
import math
import os
import random
import re
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from transformers import BertForTokenClassification
from anifilebert.config import Config
from anifilebert.dataset import labels_for_tokenizer
from anifilebert.inference import constrained_bio_decode, postprocess
from anifilebert.tokenizer import AnimeTokenizer, create_tokenizer, load_tokenizer
def iter_jsonl(path: Path, limit: Optional[int] = None) -> Iterable[dict]:
    """Lazily yield the JSON records stored one per line in *path*.

    Blank lines are skipped and do not count toward *limit*, so ``limit=N``
    yields up to N parsed records instead of stopping after N physical lines.

    Args:
        path: JSONL file to read (decoded as UTF-8).
        limit: Maximum number of records to yield. ``None`` reads the whole
            file; zero or a negative value yields nothing.

    Raises:
        ValueError: If a non-blank line is not valid JSON. The message carries
            the 1-based physical line number of the offending line.
    """
    yielded = 0
    with path.open("r", encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, 1):
            # Check the budget before touching the next line so nothing past
            # the requested number of records is ever parsed.
            if limit is not None and yielded >= limit:
                break
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{line_no}: invalid JSON") from exc
            yield record
            yielded += 1
def detect_dataset_variant(samples: List[dict], vocab_file: Optional[str]) -> str:
    """Guess which tokenizer variant produced the dataset's ``tokens``.

    Resolution order:
      1. Explicit ``tokenizer_variant`` metadata on the samples (``"mixed"``
         when more than one distinct value is present).
      2. A vocab file whose base name contains ``.char``.
      3. A heuristic: at least 95% of the samples that carry a ``filename``
         have tokens equal to the filename split into single characters.

    Returns ``"regex"`` when none of the above applies.
    """
    declared = set()
    for sample in samples:
        variant = sample.get("tokenizer_variant")
        if variant:
            declared.add(variant)
    if declared:
        return declared.pop() if len(declared) == 1 else "mixed"

    if vocab_file and ".char" in os.path.basename(vocab_file).lower():
        return "char"

    named = [sample for sample in samples if sample.get("filename") is not None]
    char_like = sum(
        1 for sample in named if sample.get("tokens") == list(sample.get("filename"))
    )
    if named and char_like / len(named) >= 0.95:
        return "char"
    return "regex"
def entity_type(label: str) -> Optional[str]:
    """Return the part of a BIO label after its first ``-``.

    Labels without a dash (such as ``"O"``) have no entity type and give
    ``None``.
    """
    _prefix, dash, kind = label.partition("-")
    return kind if dash else None
def bio_violations(tokens: List[str], labels: List[str]) -> List[dict]:
    """Find labels that break the BIO scheme.

    Two kinds of violation are reported:
      * ``ORPHAN_I`` - an ``I-X`` label that opens the sequence, follows
        ``O``, or follows a label whose entity type is not ``X``.
      * ``UNKNOWN_LABEL`` - a label that is neither ``O`` nor prefixed with
        ``B-`` / ``I-``.

    Args:
        tokens: Tokens of the sample; may be shorter than *labels*.
        labels: BIO labels to validate.

    Returns:
        One dict per violation with ``type``, ``index``, ``prev_label``,
        ``label`` and ``token`` (``None`` when *tokens* has no item at that
        index).
    """

    def kind_of(label: str) -> Optional[str]:
        # Same rule as ``entity_type``: text after the first dash, if any.
        _prefix, dash, kind = label.partition("-")
        return kind if dash else None

    violations: List[dict] = []
    previous_label = "O"  # a sequence implicitly starts outside any entity
    for idx, label in enumerate(labels):
        violation_type: Optional[str] = None
        if label.startswith("I-"):
            if previous_label == "O" or kind_of(previous_label) != kind_of(label):
                violation_type = "ORPHAN_I"
        elif label != "O" and not label.startswith("B-"):
            violation_type = "UNKNOWN_LABEL"

        if violation_type is not None:
            violations.append(
                {
                    "type": violation_type,
                    "index": idx,
                    "prev_label": previous_label,
                    "label": label,
                    "token": tokens[idx] if idx < len(tokens) else None,
                }
            )
        previous_label = label
    return violations
def bio_boundary_warnings(tokens: List[str], labels: List[str]) -> List[dict]:
    """Collect legal-but-suspicious boundary patterns separately from BIO errors.

    Currently flags one pattern: an ``O`` label directly after a ``B-*`` label,
    i.e. an entity made of a single token. The reported ``index`` and ``token``
    are those of the ``O`` position that closes the entity.
    """
    return [
        {
            "type": "SINGLE_TOKEN_ENTITY",
            "index": position,
            "prev_label": before,
            "label": current,
            "token": tokens[position] if position < len(tokens) else None,
        }
        for position, (before, current) in enumerate(zip(labels, labels[1:]), 1)
        if current == "O" and before.startswith("B-")
    ]
def spans_from_labels(tokens: List[str], labels: List[str]) -> List[dict]:
    """Aggregate BIO labels into entity spans.

    A span is opened by ``B-X`` and also by an ``I-X`` that does not continue
    the current entity (orphan ``I-`` or a type switch). It is extended by
    ``I-X`` of the same type and closed by any other label.

    Tokens and labels are paired with ``zip``, so only the first
    ``min(len(tokens), len(labels))`` positions are considered. A span that is
    still open at the end is closed at that position, never beyond it.

    Returns:
        Dicts with ``type``, ``start`` (inclusive token index), ``end``
        (exclusive token index) and ``text`` (the span's tokens concatenated).
    """
    spans: List[dict] = []
    start: Optional[int] = None
    current_type: Optional[str] = None
    current_tokens: List[str] = []

    def close_span(end: int) -> None:
        # Emit the span being built, if any; the caller resets the state.
        if current_type is not None and start is not None:
            spans.append(
                {
                    "type": current_type,
                    "start": start,
                    "end": end,
                    "text": "".join(current_tokens),
                }
            )

    for idx, (token, label) in enumerate(zip(tokens, labels)):
        is_begin = label.startswith("B-")
        is_inside = label.startswith("I-")
        # For "B-"/"I-" labels the first dash is at index 1, so the entity
        # type is everything after the two-character prefix.
        kind = label[2:] if (is_begin or is_inside) else None

        if is_inside and current_type == kind:
            current_tokens.append(token)
            continue

        close_span(idx)
        if is_begin or is_inside:
            current_type, start, current_tokens = kind, idx, [token]
        else:
            current_type, start, current_tokens = None, None, []

    # Close at the number of pairs actually consumed, not at len(labels),
    # which overshoots when there are more labels than tokens.
    close_span(min(len(tokens), len(labels)))
    return spans
def count_entities(samples: List[dict]) -> Counter:
    """Count entity spans per entity type across all *samples*.

    Every sample must provide ``"tokens"`` and ``"labels"``; spans are taken
    from ``spans_from_labels``.
    """
    return Counter(
        span["type"]
        for sample in samples
        for span in spans_from_labels(sample["tokens"], sample["labels"])
    )
def percentile(values: List[int], pct: float) -> int:
    """Return the nearest-rank percentile of *values*.

    The rank is ``round(pct / 100 * (n - 1))`` on the sorted values, using
    Python's built-in ``round``. It is clamped to the valid index range, so
    ``pct`` below 0 gives the minimum and ``pct`` above 100 gives the maximum.

    Returns 0 for an empty list.
    """
    if not values:
        return 0
    ordered = sorted(values)
    last = len(ordered) - 1
    rank = round((pct / 100) * last)
    # Clamp both ends: a negative rank would otherwise wrap around and index
    # from the top of the list.
    return ordered[max(0, min(last, rank))]
def token_mismatch(sample: dict, tokenizer: AnimeTokenizer) -> Optional[dict]:
    """Compare a sample's stored tokens with a fresh tokenization of its filename.

    Returns ``None`` when the sample has no ``filename`` or when both token
    lists are equal. Otherwise returns a summary holding the length of the
    common prefix, the first 40 tokens of each side and both full lengths.
    """
    filename = sample.get("filename")
    if filename is None:
        return None

    produced = tokenizer.tokenize(filename)
    stored = sample.get("tokens", [])
    if produced == stored:
        return None

    # Index of the first differing position; if one list is a strict prefix of
    # the other, the whole shorter list is shared.
    shared = next(
        (pos for pos, (ours, theirs) in enumerate(zip(produced, stored)) if ours != theirs),
        min(len(produced), len(stored)),
    )
    return {
        "file_id": sample.get("file_id"),
        "filename": filename,
        "common_prefix": shared,
        "dataset_tokens": stored[:40],
        "tokenizer_tokens": produced[:40],
        "dataset_len": len(stored),
        "tokenizer_len": len(produced),
    }
def format_counter(counter: Counter, total: Optional[int] = None, limit: Optional[int] = None) -> str:
    """Render *counter* as a Markdown bullet list, most common entries first.

    Args:
        counter: Counts to render.
        total: Denominator for the percentage column; defaults to the sum of
            all counts. A zero denominator renders every share as 0.00%.
        limit: Maximum number of entries to list (``None`` lists all).

    Returns:
        One ``- `key`: count (pct%)`` line per entry, or ``"- none"`` when
        there is nothing to list.
    """
    denominator = sum(counter.values()) if total is None else total

    def share(count: int) -> float:
        return count / denominator * 100 if denominator else 0.0

    lines = [
        f"- `{key}`: {count:,} ({share(count):.2f}%)"
        for key, count in counter.most_common(limit)
    ]
    if not lines:
        return "- none"
    return "\n".join(lines)
def token_id_stats(samples: List[dict], tokenizer: AnimeTokenizer) -> dict:
    """Measure how many aligned tokens map to the tokenizer's unknown id.

    Each sample is aligned with ``labels_for_tokenizer`` and converted to ids;
    a token counts as unknown when its id equals ``tokenizer.unk_token_id``.

    Returns:
        Dict with ``total`` tokens seen, ``unk`` count, ``unk_rate`` (0.0 when
        no tokens were seen) and ``top_unk`` - the 25 most frequent unknown
        tokens as ``(token, count)`` pairs.
    """
    unknown_tokens: Counter = Counter()
    seen = 0
    for sample in samples:
        tokens, _ = labels_for_tokenizer(sample, tokenizer)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        pairs = list(zip(tokens, ids))
        seen += len(pairs)
        unknown_tokens.update(
            token for token, token_id in pairs if token_id == tokenizer.unk_token_id
        )

    unknown_total = sum(unknown_tokens.values())
    return {
        "total": seen,
        "unk": unknown_total,
        "unk_rate": unknown_total / seen if seen else 0.0,
        "top_unk": unknown_tokens.most_common(25),
    }
def prepare_inputs(
    sample: dict,
    tokenizer: AnimeTokenizer,
    label2id: Dict[str, int],
    max_length: int,
) -> Tuple[List[int], List[int], List[int], List[str]]:
    """Build fixed-length model inputs for one sample.

    The aligned tokens are wrapped in the tokenizer's CLS/SEP ids, then
    truncated or padded to *max_length*. Truncation keeps the first and last
    positions and drops content from the tail. Special and padding positions
    get the label id ``-100``; labels missing from *label2id* map to id 0.

    Returns:
        ``(input_ids, attention_mask, label_ids, tokens)`` where ``tokens`` is
        the aligned token list before any truncation.
    """
    tokens, labels = labels_for_tokenizer(sample, tokenizer)

    input_ids = [
        tokenizer.cls_token_id,
        *tokenizer.convert_tokens_to_ids(tokens),
        tokenizer.sep_token_id,
    ]
    label_ids = [-100, *(label2id.get(label, 0) for label in labels), -100]

    if len(input_ids) > max_length:
        # Keep the leading and trailing special positions and cut the middle
        # so the sequence fits.
        body = slice(1, max_length - 1)
        input_ids = [input_ids[0], *input_ids[body], input_ids[-1]]
        label_ids = [label_ids[0], *label_ids[body], label_ids[-1]]

    attention_mask = [1] * len(input_ids)

    missing = max_length - len(input_ids)
    if missing > 0:
        input_ids = input_ids + [tokenizer.pad_token_id] * missing
        label_ids = label_ids + [-100] * missing
        attention_mask = attention_mask + [0] * missing
    return input_ids, attention_mask, label_ids, tokens
def normalize_field_value(field: str, value) -> Optional[str]:
    """Normalise a parsed field value so gold and predicted parses compare fairly.

    Rules:
      * ``None`` stays ``None``.
      * ``episode`` / ``season``: values that convert to a whole number are
        rendered as that integer (``"01"`` -> ``"1"``, ``12.0`` -> ``"12"``).
        Anything else, including fractional floats such as ``11.5``, is kept
        as stripped, lower-cased text so it is not collapsed onto a
        neighbouring integer.
      * ``resolution`` / ``source``: stripped, lower-cased, ``_`` -> ``-``.
      * every other field: whitespace runs collapsed to one space, stripped,
        lower-cased.
    """
    if value is None:
        return None
    if field in {"episode", "season"}:
        # int() would silently truncate 11.5 to 11 and raise OverflowError on
        # infinities; treat non-integral floats as plain text instead.
        if isinstance(value, float) and not value.is_integer():
            return str(value).strip().lower()
        try:
            return str(int(value))
        except (TypeError, ValueError):
            return str(value).strip().lower()
    text = str(value).strip()
    if field in {"resolution", "source"}:
        return text.lower().replace("_", "-")
    return re.sub(r"\s+", " ", text).strip().lower()
def update_parse_metrics(counter: Counter, gold: dict, pred: dict) -> None:
    """Fold one gold/predicted parse pair into the running *counter*.

    For every compared field the counter receives ``"<field>_total"`` and, on
    a match after normalisation, ``"<field>_correct"``. A mismatch is recorded
    under the tuple key ``(field, gold_value, pred_value)``.
    ``"full_match_total"`` is always incremented, ``"full_match_correct"``
    only when every field matched.
    """
    mismatches = 0
    for field in ("group", "title", "season", "episode", "resolution", "source", "special"):
        expected = normalize_field_value(field, gold.get(field))
        actual = normalize_field_value(field, pred.get(field))
        if expected != actual:
            mismatches += 1
            counter[(field, expected, actual)] += 1
        else:
            counter[f"{field}_correct"] += 1
        counter[f"{field}_total"] += 1

    if mismatches == 0:
        counter["full_match_correct"] += 1
    counter["full_match_total"] += 1
def collect_field_failures(gold: dict, pred: dict) -> Dict[str, Dict[str, Optional[str]]]:
    """Return the fields whose normalised gold and predicted values differ.

    The result maps each differing field name to
    ``{"gold": <normalised gold>, "pred": <normalised prediction>}``; fields
    that agree are left out.
    """
    failures: Dict[str, Dict[str, Optional[str]]] = {}
    for field in ("group", "title", "season", "episode", "resolution", "source", "special"):
        expected = normalize_field_value(field, gold.get(field))
        actual = normalize_field_value(field, pred.get(field))
        if expected != actual:
            failures[field] = {"gold": expected, "pred": actual}
    return failures
def evaluate_model(
    samples: List[dict],
    model_dir: Path,
    tokenizer: AnimeTokenizer,
    max_length: int,
    limit: int,
    seed: int,
) -> dict:
    """Run the model on a seeded random subset of *samples* and collect errors.

    A copy of *samples* is shuffled with ``random.Random(seed)`` and the first
    *limit* entries are evaluated. For each one the logits of the labelled
    positions are decoded with ``constrained_bio_decode``, compared label by
    label with the gold labels, and both label sequences are passed through
    ``postprocess`` to compare the resulting field parses.

    Returns:
        Dict with the sample count, seqeval precision/recall/F1 and text
        report, the 30 most common token-label and entity-type confusions,
        boundary error classes, field-level parse metrics and up to 30
        examples of failed parses.
    """
    fallback = Config()
    model = BertForTokenClassification.from_pretrained(str(model_dir))
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    eval_samples = shuffled[:limit]

    # Prefer the label map stored with the checkpoint; fall back to the
    # project defaults when the checkpoint carries none.
    configured = getattr(model.config, "id2label", fallback.id2label)
    id2label = {int(idx): name for idx, name in configured.items()}
    label2id = {name: int(idx) for idx, name in id2label.items()}
    if not label2id:
        label2id = fallback.label2id
        id2label = fallback.id2label

    true_sequences: List[List[str]] = []
    pred_sequences: List[List[str]] = []
    confusion: Counter = Counter()
    entity_confusion: Counter = Counter()
    boundary_errors: Counter = Counter()
    parse_metrics: Counter = Counter()
    field_failures: List[dict] = []

    def error_class(gold_label: str, guess_label: str) -> str:
        # Bucket a token-level disagreement by what went wrong.
        if gold_label.startswith("B-") or guess_label.startswith("B-"):
            return "B-boundary"
        if entity_type(gold_label) != entity_type(guess_label):
            return "entity-type"
        return "BIO-prefix"

    with torch.no_grad():
        for sample in eval_samples:
            input_ids, attention_mask, label_ids, sample_tokens = prepare_inputs(
                sample, tokenizer, label2id, max_length
            )
            logits = model(
                input_ids=torch.tensor([input_ids], dtype=torch.long, device=device),
                attention_mask=torch.tensor([attention_mask], dtype=torch.long, device=device),
            ).logits

            # Positions labelled -100 (special tokens, padding) are not scored.
            gold_ids = [label_id for label_id in label_ids if label_id != -100]
            # Scored positions start right after the leading special token.
            pred_ids = constrained_bio_decode(logits[0, 1:1 + len(gold_ids), :], id2label)

            true_labels: List[str] = []
            pred_labels: List[str] = []
            for position, gold_id in enumerate(gold_ids):
                pred_id = pred_ids[position]
                true_label = id2label.get(gold_id, "O")
                pred_label = id2label.get(pred_id, "O")
                true_labels.append(true_label)
                pred_labels.append(pred_label)

                confusion[(true_label, pred_label)] += 1
                entity_pair = (entity_type(true_label) or "O", entity_type(pred_label) or "O")
                entity_confusion[entity_pair] += 1
                if true_label != pred_label:
                    boundary_errors[error_class(true_label, pred_label)] += 1

            true_sequences.append(true_labels)
            pred_sequences.append(pred_labels)

            # Tokens beyond the scored positions were truncated away.
            active_tokens = sample_tokens[:len(true_labels)]
            gold_parse = postprocess(active_tokens, true_labels, tokenizer=tokenizer)
            pred_parse = postprocess(active_tokens, pred_labels, tokenizer=tokenizer)
            update_parse_metrics(parse_metrics, gold_parse, pred_parse)

            failures = collect_field_failures(gold_parse, pred_parse)
            if failures and len(field_failures) < 30:
                field_failures.append(
                    {
                        "filename": sample.get("filename"),
                        "errors": failures,
                        "gold": gold_parse,
                        "pred": pred_parse,
                    }
                )

    # Keep only the off-diagonal entries (actual mistakes).
    token_errors = Counter(
        {pair: count for pair, count in confusion.items() if pair[0] != pair[1]}
    )
    entity_errors = Counter(
        {pair: count for pair, count in entity_confusion.items() if pair[0] != pair[1]}
    )
    return {
        "sample_count": len(eval_samples),
        "precision": precision_score(true_sequences, pred_sequences),
        "recall": recall_score(true_sequences, pred_sequences),
        "f1": f1_score(true_sequences, pred_sequences),
        "classification_report": classification_report(true_sequences, pred_sequences, digits=4),
        "top_token_confusions": token_errors.most_common(30),
        "top_entity_confusions": entity_errors.most_common(30),
        "boundary_errors": boundary_errors,
        "parse_metrics": parse_metrics,
        "field_failures": field_failures,
    }
def tokenizer_split_examples(samples: List[dict], tokenizers: Dict[str, AnimeTokenizer], limit: int = 8) -> List[dict]:
    """Show how each tokenizer splits the first few filenames.

    Samples without a non-empty ``filename`` are skipped. Each returned row
    holds ``file_id``, ``filename``, the first 80 dataset tokens and, for every
    entry of *tokenizers*, ``"<name>_tokens"`` with the first 80 tokens it
    produces. Collection stops once *limit* rows have been gathered.
    """
    rows: List[dict] = []
    for sample in samples:
        filename = sample.get("filename")
        if not filename:
            continue
        rows.append(
            {
                "file_id": sample.get("file_id"),
                "filename": filename,
                "dataset_tokens": sample.get("tokens", [])[:80],
                **{
                    f"{name}_tokens": candidate.tokenize(filename)[:80]
                    for name, candidate in tokenizers.items()
                },
            }
        )
        if len(rows) >= limit:
            break
    return rows
def write_report(path: Path, title: str, sections: List[Tuple[str, str]]) -> None:
    """Write a Markdown report made of a title and ``##`` sections.

    Args:
        path: Destination file (written as UTF-8, overwriting any existing
            file). Missing parent directories are created.
        title: Text of the top-level ``#`` heading.
        sections: ``(heading, body)`` pairs in output order. Bodies are
            stripped; an empty body is rendered as ``_No data._``.
    """
    parts = [f"# {title}", ""]
    for heading, body in sections:
        text = body.strip()
        parts.extend([f"## {heading}", "", text if text else "_No data._", ""])
    # The report is written last, after all the expensive work; do not lose it
    # just because the output directory has not been created yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(parts), encoding="utf-8")
def markdown_json(value) -> str:
    """Return *value* pretty-printed as JSON inside a fenced ``json`` code block.

    Non-ASCII characters are written as-is rather than escaped.
    """
    payload = json.dumps(value, ensure_ascii=False, indent=2)
    return f"```json\n{payload}\n```"
def markdown_table(headers: List[str], rows: List[List[str]], limit: Optional[int] = None) -> str:
    """Render *rows* as a Markdown pipe table.

    Args:
        headers: Column titles, used verbatim.
        rows: Table rows; cells are converted with ``str``. Newlines inside a
            cell become spaces and ``|`` is escaped as ``\\|`` so cell content
            cannot break the row or add columns.
        limit: When given, only the first *limit* rows are rendered.

    Returns:
        The table as a single string without a trailing newline.
    """
    if limit is not None:
        rows = rows[:limit]

    def render_row(cells) -> str:
        return "| " + " | ".join(cells) + " |"

    def clean(cell) -> str:
        return str(cell).replace("\n", " ").replace("|", "\\|")

    table = [render_row(headers), render_row("---" for _ in headers)]
    table.extend(render_row(clean(cell) for cell in row) for row in rows)
    return "\n".join(table)
_PARSE_FIELDS = ("group", "title", "season", "episode", "resolution", "source", "special")


def _length_summary(values: List[int]) -> Dict[str, int]:
    """Return min / percentile / max statistics for *values* (zeros when empty)."""
    return {
        "min": min(values, default=0),
        "p50": percentile(values, 50),
        "p90": percentile(values, 90),
        "p95": percentile(values, 95),
        "p99": percentile(values, 99),
        "max": max(values, default=0),
    }


def _attach_context(record: dict, row_idx: int, sample: dict, tokens: List[str], labels: List[str]) -> dict:
    """Add row identity and a +/-5 token window around ``record["index"]`` to *record*."""
    lo = max(0, record["index"] - 5)
    hi = record["index"] + 6
    record.update(
        {
            "row": row_idx,
            "file_id": sample.get("file_id"),
            "filename": sample.get("filename"),
            "context_tokens": tokens[lo:hi],
            "context_labels": labels[lo:hi],
        }
    )
    return record


def _parse_metric_tables(metrics: Counter) -> Tuple[List[List[str]], str, List[List[str]]]:
    """Turn parse metrics into per-field rows, a full-match line and top error rows."""
    field_rows = []
    for field in _PARSE_FIELDS:
        total = metrics.get(f"{field}_total", 0)
        correct = metrics.get(f"{field}_correct", 0)
        acc = correct / total if total else 0.0
        field_rows.append([field, f"{correct:,}/{total:,}", f"{acc:.4f}"])

    full_total = metrics.get("full_match_total", 0)
    full_correct = metrics.get("full_match_correct", 0)
    full_acc = full_correct / full_total if full_total else 0.0
    full_line = f"{full_correct:,}/{full_total:,} ({full_acc:.4f})"

    # Mismatches are stored under (field, gold, pred) tuple keys; everything
    # else in the counter is a string-keyed tally.
    mismatch_counts = Counter(
        {key: count for key, count in metrics.items() if isinstance(key, tuple)}
    )
    error_rows = [
        [field, str(gold), str(pred), f"{count:,}"]
        for (field, gold, pred), count in mismatch_counts.most_common(30)
    ]
    return field_rows, full_line, error_rows


def main() -> None:
    """Command-line entry point: inspect a JSONL dataset and write a Markdown report.

    Steps:
      1. Load up to ``--sample-limit`` rows and pick the tokenizer variant
         (explicit flag, else dataset metadata, else ``regex``).
      2. Collect label, length, BIO-violation, boundary and tokenizer-mismatch
         statistics over the loaded rows.
      3. When ``--model-dir`` is given, evaluate the model on a sampled subset.
      4. Write all sections to ``--output``.

    Raises:
        ValueError: If no samples could be loaded from the data file.
    """
    parser = argparse.ArgumentParser(description="Diagnose anime filename NER data and model pipeline")
    parser.add_argument("--data-file", required=True, help="JSONL dataset with tokens and labels")
    parser.add_argument("--vocab-file", default=None, help="Tokenizer vocab JSON")
    parser.add_argument("--tokenizer", choices=["regex", "char"], default=None,
                        help="Tokenizer variant to diagnose. Defaults to dataset metadata")
    parser.add_argument("--model-dir", default=None, help="Optional model directory for confusion analysis")
    parser.add_argument("--max-length", type=int, default=None, help="Max sequence length for model eval/truncation stats")
    parser.add_argument("--sample-limit", type=int, default=20000, help="Rows to inspect for data diagnostics")
    parser.add_argument("--eval-limit", type=int, default=512, help="Rows to evaluate when --model-dir is provided")
    parser.add_argument("--output", default="diagnostics_report.md", help="Markdown report path")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    data_path = Path(args.data_file)
    samples = list(iter_jsonl(data_path, args.sample_limit))
    if not samples:
        raise ValueError(f"No samples loaded from {data_path}")

    dataset_variant = detect_dataset_variant(samples, args.vocab_file)
    tokenizer_variant = args.tokenizer or (dataset_variant if dataset_variant != "mixed" else "regex")
    vocab_file = args.vocab_file
    if vocab_file is None:
        vocab_file = str(data_path.with_name("vocab.char.json" if tokenizer_variant == "char" else "vocab.json"))
    tokenizer = create_tokenizer(tokenizer_variant, vocab_file=vocab_file)
    model_tokenizer = load_tokenizer(args.model_dir) if args.model_dir else tokenizer

    max_length = args.max_length
    if max_length is None and args.model_dir:
        # Only the configuration is needed here; loading it directly avoids
        # reading the full model weights a second time before evaluation.
        from transformers import BertConfig

        model_config = BertConfig.from_pretrained(args.model_dir)
        max_length = int(getattr(model_config, "max_seq_length", 64))
    max_length = max_length or (128 if tokenizer_variant == "char" else 64)

    label_counter: Counter = Counter()
    length_values: List[int] = []
    aligned_length_values: List[int] = []
    violations: List[dict] = []
    boundary_warnings: List[dict] = []
    mismatch_examples: List[dict] = []
    space_label_counter: Counter = Counter()
    boundary_drift_counter: Counter = Counter()
    truncation_count = 0
    # Release/codec metadata that should never end up inside a TITLE span.
    title_meta_pattern = re.compile(r"\b(?:WEB[-_ ]?DL|WebRip|\d{3,4}[pP]|HEVC|AVC|AAC)\b", re.I)

    for row_idx, sample in enumerate(samples, 1):
        tokens = sample.get("tokens", [])
        labels = sample.get("labels", [])
        if len(tokens) != len(labels):
            # Nothing else can be checked reliably on a misaligned row.
            violations.append(
                {
                    "type": "LENGTH_MISMATCH",
                    "row": row_idx,
                    "file_id": sample.get("file_id"),
                    "token_count": len(tokens),
                    "label_count": len(labels),
                    "filename": sample.get("filename"),
                }
            )
            continue

        label_counter.update(labels)
        length_values.append(len(tokens))
        aligned_tokens, _aligned_labels = labels_for_tokenizer(sample, tokenizer)
        aligned_length_values.append(len(aligned_tokens))
        # +2 accounts for the CLS and SEP positions added at model input time.
        if len(aligned_tokens) + 2 > max_length:
            truncation_count += 1

        for token, label in zip(tokens, labels):
            if token.isspace():
                space_label_counter[label] += 1

        for violation in bio_violations(tokens, labels):
            violations.append(_attach_context(violation, row_idx, sample, tokens, labels))
        for warning in bio_boundary_warnings(tokens, labels):
            boundary_warnings.append(_attach_context(warning, row_idx, sample, tokens, labels))

        for span in spans_from_labels(tokens, labels):
            text = span["text"]
            if span["type"] == "TITLE":
                if text.startswith("[") or text.endswith("[") or "]" in text[:3]:
                    boundary_drift_counter["title_contains_bracket_edge"] += 1
                if title_meta_pattern.search(text):
                    boundary_drift_counter["title_contains_meta"] += 1
            if span["type"] == "GROUP" and ("[" in text or "]" in text):
                boundary_drift_counter["group_contains_bracket"] += 1

        if len(mismatch_examples) < 10:
            mismatch = token_mismatch(sample, tokenizer)
            if mismatch:
                mismatch_examples.append(mismatch)

    entity_counter = count_entities(samples)
    id_stats = token_id_stats(samples, tokenizer)
    split_examples = tokenizer_split_examples(
        samples,
        {
            "diagnosed": tokenizer,
            "regex": create_tokenizer("regex", vocab_file=str(data_path.with_name("vocab.json"))),
            "char": create_tokenizer("char", vocab_file=str(data_path.with_name("vocab.char.json"))),
        },
    )

    model_eval = None
    if args.model_dir:
        model_eval = evaluate_model(
            samples=samples,
            model_dir=Path(args.model_dir),
            tokenizer=model_tokenizer,
            max_length=max_length,
            limit=args.eval_limit,
            seed=args.seed,
        )

    total_labels = sum(label_counter.values())
    o_count = label_counter.get("O", 0)
    sections: List[Tuple[str, str]] = []

    sections.append(
        (
            "Executive Summary",
            "\n".join(
                [
                    f"- Dataset: `{data_path}`",
                    f"- Inspected rows: {len(samples):,}",
                    f"- Dataset tokenizer variant: `{dataset_variant}`",
                    f"- Diagnosed tokenizer variant: `{tokenizer_variant}`",
                    f"- Vocab: `{vocab_file}` ({tokenizer.vocab_size:,} tokens)",
                    f"- Max sequence length checked: {max_length}",
                    f"- O-label ratio: {o_count / total_labels * 100:.2f}%" if total_labels else "- O-label ratio: n/a",
                    f"- Truncation risk: {truncation_count:,}/{len(samples):,} rows ({truncation_count / len(samples) * 100:.2f}%)",
                    f"- UNK rate after selected tokenizer: {id_stats['unk_rate'] * 100:.4f}%",
                    f"- BIO warnings collected: {len(violations):,}",
                    "",
                    "Primary finding: this task is structural filename parsing. Tokenizer/preprocessing identity is more important than lowering token loss.",
                ]
            ),
        )
    )

    sections.append(
        (
            "Label And Entity Statistics",
            "\n".join(
                [
                    "### Label distribution",
                    format_counter(label_counter, total_labels),
                    "",
                    "### Entity count",
                    format_counter(entity_counter),
                    "",
                    "### Length distribution",
                    # The lists are empty when every row was rejected for a
                    # length mismatch; the summary then reports zeros.
                    markdown_json(
                        {
                            "raw_tokens": _length_summary(length_values),
                            "aligned_tokens": _length_summary(aligned_length_values),
                        }
                    ),
                    "",
                    "### Whitespace labels",
                    format_counter(space_label_counter),
                ]
            ),
        )
    )

    violation_counter = Counter(v["type"] for v in violations)
    warning_counter = Counter(w["type"] for w in boundary_warnings)
    sections.append(
        (
            "BIO Violations And Boundary Drift",
            "\n".join(
                [
                    "### True BIO violation counts",
                    format_counter(violation_counter),
                    "",
                    "### Legal boundary warning counts",
                    format_counter(warning_counter),
                    "",
                    "### Boundary drift heuristics",
                    format_counter(boundary_drift_counter),
                    "",
                    "### Sample violations",
                    markdown_json(violations[:30]),
                    "",
                    "### Sample boundary warnings",
                    markdown_json(boundary_warnings[:30]),
                ]
            ),
        )
    )

    sections.append(
        (
            "Tokenizer Split And Alignment",
            "\n".join(
                [
                    "### Dataset tokens vs selected tokenizer mismatches",
                    markdown_json(mismatch_examples),
                    "",
                    "### Split examples",
                    markdown_json(split_examples),
                    "",
                    "### Vocabulary coverage",
                    markdown_json(id_stats),
                ]
            ),
        )
    )

    if args.model_dir:
        model_tokenizer_variant = getattr(model_tokenizer, "tokenizer_variant", "unknown")
        sections.append(
            (
                "Train Inference Tokenizer Comparison",
                "\n".join(
                    [
                        f"- Model dir: `{args.model_dir}`",
                        f"- Model tokenizer variant: `{model_tokenizer_variant}`",
                        f"- Dataset tokenizer variant: `{dataset_variant}`",
                        f"- Diagnostic tokenizer variant: `{tokenizer_variant}`",
                        f"- Model tokenizer vocab size: {model_tokenizer.vocab_size:,}",
                        f"- Diagnostic tokenizer vocab size: {tokenizer.vocab_size:,}",
                        "",
                        "If dataset and model tokenizer variants differ, validation loss can be low while real inference sees different token IDs and boundaries.",
                    ]
                ),
            )
        )

    if model_eval:
        token_rows = [
            [true, pred, f"{count:,}"]
            for (true, pred), count in model_eval["top_token_confusions"]
        ]
        entity_rows = [
            [true, pred, f"{count:,}"]
            for (true, pred), count in model_eval["top_entity_confusions"]
        ]
        parse_field_rows, parse_full_line, parse_error_rows = _parse_metric_tables(model_eval["parse_metrics"])
        sections.append(
            (
                "Model Confusion Analysis",
                "\n".join(
                    [
                        f"- Evaluated samples: {model_eval['sample_count']:,}",
                        f"- Entity precision: {model_eval['precision']:.4f}",
                        f"- Entity recall: {model_eval['recall']:.4f}",
                        f"- Entity F1: {model_eval['f1']:.4f}",
                        "",
                        "### Boundary error classes",
                        format_counter(model_eval["boundary_errors"]),
                        "",
                        "### Top token-label confusions",
                        markdown_table(["true", "pred", "count"], token_rows) if token_rows else "- none",
                        "",
                        "### Top entity-type confusions",
                        markdown_table(["true", "pred", "count"], entity_rows) if entity_rows else "- none",
                        "",
                        "### Field exact-match accuracy (thin runtime)",
                        markdown_table(["field", "correct/total", "accuracy"], parse_field_rows),
                        "",
                        f"Thin-runtime full parse exact match: {parse_full_line}",
                        "",
                        "### Top thin-runtime field parse errors",
                        markdown_table(["field", "gold", "pred", "count"], parse_error_rows) if parse_error_rows else "- none",
                        "",
                        "### Hardest sampled parse failures",
                        markdown_json(model_eval["field_failures"][:10]) if model_eval["field_failures"] else "- none",
                        "",
                        "### Seqeval report",
                        "```text\n" + model_eval["classification_report"] + "\n```",
                    ]
                ),
            )
        )

    sections.append(
        (
            "Recommended Pipeline",
            "\n".join(
                [
                    "1. Use one tokenizer variant end to end and save it in the checkpoint metadata.",
                    "2. Prefer char-level or a deterministic hybrid tokenizer for DMHY filenames; avoid generic subword tokenization for labels.",
                    "3. For char-level runs, use `--tokenizer char --max-seq-length 128` with `vocab.char.json`.",
                    "4. Add CRF decoding or constrained BIO decoding so illegal I-X transitions and impossible boundary jumps are blocked.",
                    "5. Keep runtime post-processing thin: BIO aggregation plus string/number normalization.",
                    "6. Track entity-level F1 and field exact-match on real filenames; do not accept low validation loss alone.",
                ]
            ),
        )
    )

    write_report(Path(args.output), "Anime Filename Parser Diagnostics Report", sections)
    print(f"Wrote diagnostics report: {args.output}")


if __name__ == "__main__":
    main()