"""Evaluate parser checkpoints on fixed real-world filename cases."""

import argparse
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import BertForTokenClassification

from config import Config
from inference import parse_filename
from tokenizer import load_tokenizer

DEFAULT_CASE_FILE = os.path.join("data", "parser_regression_cases.json")

# Fields whose values are compared as integers, so "01", 1 and 1.0 all match.
_INTEGER_FIELDS = frozenset({"episode", "season"})

# Fields compared case-insensitively with "_" and "-" treated as the same character.
_TAG_FIELDS = frozenset({"resolution", "source"})

# Keyword settings forwarded to parse_filename() for each named evaluation mode.
_EVALUATION_MODES = {
    "model_only": {"use_rules": False, "constrain_bio": False},
    "normalized_only": {"use_rules": False, "constrain_bio": True},
    "rule_assisted": {"use_rules": True, "constrain_bio": True},
}


def normalize_field_value(field: str, value) -> Optional[str]:
    """Return a canonical string for *value* so expected and predicted values compare equal.

    Args:
        field: Name of the field the value belongs to; selects the normalization rule.
        value: Raw expected or predicted value (any type, may be ``None``).

    Returns:
        ``None`` when *value* is ``None``. For ``episode``/``season`` the decimal
        integer form when the value is integral, otherwise the stripped,
        lower-cased text. For ``resolution``/``source`` the stripped, lower-cased
        text with ``_`` replaced by ``-``. For every other field the lower-cased
        text with runs of whitespace collapsed to single spaces.
    """
    if value is None:
        return None

    if field in _INTEGER_FIELDS:
        # int() truncates floats, which would make 12.5 compare equal to 12.
        # Keep fractional (and nan/inf) floats as text instead.
        if isinstance(value, float) and not value.is_integer():
            return str(value).strip().lower()
        try:
            return str(int(value))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int() of an infinite numeric value (e.g. Decimal("Infinity")).
            return str(value).strip().lower()

    text = str(value).strip()
    if field in _TAG_FIELDS:
        return text.lower().replace("_", "-")
    return " ".join(text.lower().split())


def load_cases(path: str) -> List[Dict]:
    """Load the regression cases stored as a JSON list at *path*.

    Args:
        path: Path to a UTF-8 encoded JSON file.

    Returns:
        The decoded list of case objects.

    Raises:
        ValueError: If the top-level JSON value is not a list, or if an entry is
            not an object containing a ``filename`` key.
    """
    with open(path, "r", encoding="utf-8") as f:
        cases = json.load(f)
    if not isinstance(cases, list):
        raise ValueError(f"{path} must contain a JSON list")
    # Every case is later indexed with case["filename"]; reject bad entries here
    # so the error names the offending entry instead of surfacing mid-evaluation.
    for index, case in enumerate(cases):
        if not isinstance(case, dict) or "filename" not in case:
            raise ValueError(
                f"{path}: case #{index} must be a JSON object with a 'filename' key"
            )
    return cases


def _load_inference_bundle(
    model_dir: str,
    tokenizer_variant: Optional[str],
    max_length: Optional[int],
) -> Tuple[Any, Any, Dict[int, str], int]:
    """Load tokenizer and model from *model_dir* and resolve the inference settings.

    Returns:
        ``(tokenizer, model, id2label, resolved_max_length)``. The model is moved
        to CUDA when available (CPU otherwise) and switched to eval mode.
    """
    cfg = Config()
    tokenizer = load_tokenizer(model_dir, tokenizer_variant)
    model = BertForTokenClassification.from_pretrained(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Label ids may be stored as strings in the model config; coerce them to int.
    raw_id2label = getattr(model.config, "id2label", cfg.id2label)
    id2label = {int(k): v for k, v in raw_id2label.items()}
    # An explicit max_length wins; otherwise use the model config value, else 64.
    resolved_max_length = max_length or int(getattr(model.config, "max_seq_length", 64))
    return tokenizer, model, id2label, resolved_max_length


def _score_cases(
    cases: List[Dict],
    model_dir: str,
    case_file: str,
    tokenizer,
    model,
    id2label: Dict[int, str],
    max_length: int,
    use_rules: bool,
    constrain_bio: bool,
) -> Dict:
    """Run parse_filename() on every case with already-loaded artifacts and build the report."""
    field_totals: Dict[str, int] = {}
    field_correct: Dict[str, int] = {}
    results = []
    full_correct = 0

    for case in cases:
        # A missing or null "expected" means the case has no field expectations.
        expected = case.get("expected") or {}
        pred = parse_filename(
            case["filename"],
            model,
            tokenizer,
            id2label,
            max_length=max_length,
            debug=False,
            use_rules=use_rules,
            constrain_bio=constrain_bio,
        )

        errors = {}
        for field, expected_value in expected.items():
            field_totals[field] = field_totals.get(field, 0) + 1
            expected_norm = normalize_field_value(field, expected_value)
            pred_norm = normalize_field_value(field, pred.get(field))
            if expected_norm == pred_norm:
                field_correct[field] = field_correct.get(field, 0) + 1
            else:
                # Report the raw values, not the normalized ones.
                errors[field] = {
                    "expected": expected_value,
                    "pred": pred.get(field),
                }

        if not errors:
            full_correct += 1
        results.append(
            {
                "id": case.get("id"),
                "filename": case["filename"],
                "ok": not errors,
                "errors": errors,
                "expected": expected,
                # Only the fields the case makes a claim about are recorded.
                "pred": {field: pred.get(field) for field in sorted(expected)},
            }
        )

    field_accuracy = {
        field: field_correct.get(field, 0) / total
        for field, total in sorted(field_totals.items())
    }
    return {
        "model_dir": model_dir,
        "case_file": case_file,
        "tokenizer_variant": getattr(tokenizer, "tokenizer_variant", "regex"),
        "max_length": max_length,
        "use_rules": use_rules,
        "constrain_bio": constrain_bio,
        "case_count": len(cases),
        "full_correct": full_correct,
        "full_accuracy": full_correct / len(cases) if cases else 0.0,
        "field_correct": field_correct,
        "field_total": field_totals,
        "field_accuracy": field_accuracy,
        "failures": [result for result in results if not result["ok"]],
        "results": results,
    }


def evaluate_cases(
    model_dir: str,
    case_file: str,
    tokenizer_variant: Optional[str],
    max_length: Optional[int],
    use_rules: bool,
    constrain_bio: bool,
) -> Dict:
    """Evaluate one checkpoint on the case file with a single parser configuration.

    Args:
        model_dir: Directory holding the model checkpoint and tokenizer files.
        case_file: Path to the JSON list of regression cases.
        tokenizer_variant: Variant passed to ``load_tokenizer`` (``None`` lets it decide).
        max_length: Sequence length override; falsy values fall back to the model
            config's ``max_seq_length`` or 64.
        use_rules: Forwarded to ``parse_filename``.
        constrain_bio: Forwarded to ``parse_filename``.

    Returns:
        A JSON-style report dict with the run settings, per-case results, the
        failing cases, and full-case and per-field accuracy counts.

    Raises:
        ValueError: If the case file is malformed (see ``load_cases``).
    """
    # Read the cases first so a malformed case file fails before the checkpoint is loaded.
    cases = load_cases(case_file)
    tokenizer, model, id2label, resolved_max_length = _load_inference_bundle(
        model_dir, tokenizer_variant, max_length
    )
    return _score_cases(
        cases,
        model_dir,
        case_file,
        tokenizer,
        model,
        id2label,
        resolved_max_length,
        use_rules,
        constrain_bio,
    )


def evaluate_case_modes(
    model_dir: str,
    case_file: str,
    tokenizer_variant: Optional[str],
    max_length: Optional[int],
) -> Dict:
    """Evaluate the checkpoint in the model-only, normalized-only and rule-assisted modes.

    The case file, tokenizer and model are loaded once and shared by all modes.

    Returns:
        ``{"primary_metric": "normalized_only", "modes": {name: report}}`` where
        each report has the same shape as the result of ``evaluate_cases``.

    Raises:
        ValueError: If the case file is malformed (see ``load_cases``).
    """
    cases = load_cases(case_file)
    tokenizer, model, id2label, resolved_max_length = _load_inference_bundle(
        model_dir, tokenizer_variant, max_length
    )
    results = {
        name: _score_cases(
            cases,
            model_dir,
            case_file,
            tokenizer,
            model,
            id2label,
            resolved_max_length,
            settings["use_rules"],
            settings["constrain_bio"],
        )
        for name, settings in _EVALUATION_MODES.items()
    }
    return {
        "primary_metric": "normalized_only",
        "modes": results,
    }


def print_metrics(name: str, metrics: Dict) -> None:
    """Print the full-case accuracy, per-field accuracy and failures of one report.

    Args:
        name: Label printed in front of the summary lines.
        metrics: A report dict as returned by ``evaluate_cases``.
    """
    print(
        f"{name} full case accuracy: {metrics['full_correct']}/{metrics['case_count']} "
        f"({metrics['full_accuracy']:.4f})"
    )
    for field, total in metrics["field_total"].items():
        # Fields with no correct prediction have no entry in "field_correct".
        correct = metrics["field_correct"].get(field, 0)
        print(f"  {field}: {correct}/{total} ({correct / total:.4f})")
    if metrics["failures"]:
        print(f"\n{name} failures:")
        for failure in metrics["failures"]:
            print(json.dumps(failure, ensure_ascii=False))


def main() -> None:
    """Parse command-line options, run the evaluation, print it and optionally save it as JSON."""
    parser = argparse.ArgumentParser(
        description="Evaluate parser on fixed filename regression cases"
    )
    parser.add_argument("--model-dir", required=True)
    parser.add_argument("--case-file", default=DEFAULT_CASE_FILE)
    parser.add_argument("--tokenizer", choices=["regex", "char"], default=None)
    parser.add_argument("--max-length", type=int, default=None)
    parser.add_argument("--output", default=None, help="Optional JSON output path")
    parser.add_argument(
        "--mode",
        choices=["all", "model-only", "normalized-only", "rule-assisted"],
        default="all",
    )
    parser.add_argument(
        "--rule-assist", action="store_true", help="Shortcut for --mode rule-assisted"
    )
    parser.add_argument("--no-rule-assist", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--no-constrained-bio", action="store_true")
    args = parser.parse_args()

    # The shortcut flags rewrite --mode; --no-rule-assist is applied last, so it
    # wins when both shortcuts are given.
    if args.rule_assist:
        args.mode = "rule-assisted"
    if args.no_rule_assist and args.mode == "rule-assisted":
        args.mode = "normalized-only"

    if args.mode == "all" and not args.no_constrained_bio:
        metrics = evaluate_case_modes(
            model_dir=args.model_dir,
            case_file=args.case_file,
            tokenizer_variant=args.tokenizer,
            max_length=args.max_length,
        )
        for name, mode_metrics in metrics["modes"].items():
            print_metrics(name, mode_metrics)
            print()
    else:
        # NOTE(review): "--mode all --no-constrained-bio" lands here and runs a
        # single pass with use_rules=False and constrain_bio=False, printed under
        # the label "all" - confirm that this is the intended behavior.
        use_rules = args.mode == "rule-assisted"
        constrain_bio = not args.no_constrained_bio and args.mode != "model-only"
        metrics = evaluate_cases(
            model_dir=args.model_dir,
            case_file=args.case_file,
            tokenizer_variant=args.tokenizer,
            max_length=args.max_length,
            use_rules=use_rules,
            constrain_bio=constrain_bio,
        )
        print_metrics(args.mode, metrics)

    if args.output:
        # dirname is "" for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(metrics, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()