AniFileBERT / evaluate_parser_cases.py
ModerRAS's picture
Remove structural parser rule assists
116c87c
raw
history blame
6.76 kB
"""Evaluate parser checkpoints on fixed real-world filename cases."""
import argparse
import json
import os
from typing import Dict, List, Optional
import torch
from transformers import BertForTokenClassification
from config import Config
from inference import parse_filename
from tokenizer import load_tokenizer
# Default regression-case file (a relative path, so it resolves against the
# current working directory); used when --case-file is not given.
DEFAULT_CASE_FILE = os.path.join("data", "parser_regression_cases.json")
def normalize_field_value(field: str, value) -> Optional[str]:
    """Canonicalize a field value so expected and predicted values compare fairly.

    Args:
        field: Field name; it selects which normalization rule applies.
        value: Raw value of any type; ``None`` means the field is absent.

    Returns:
        ``None`` when ``value`` is ``None``; otherwise a normalized string:

        * ``episode`` / ``season``: whole numbers are rendered as a plain
          integer (``"01"`` -> ``"1"``, ``12.0`` -> ``"12"``); anything that
          is not a whole number falls back to its stripped, lower-cased text.
        * ``resolution`` / ``source``: stripped, lower-cased, with ``_``
          replaced by ``-``.
        * any other field: lower-cased with whitespace runs collapsed to a
          single space.
    """
    if value is None:
        return None
    if field in {"episode", "season"}:
        # int() silently truncates floats, which would make 12.5 compare equal
        # to 12, and it raises OverflowError on infinities. Keep every
        # non-integral float (including inf/nan) as text instead.
        if isinstance(value, float) and not value.is_integer():
            return str(value).strip().lower()
        try:
            return str(int(value))
        except (TypeError, ValueError):
            # Non-numeric markers (e.g. "OVA") are compared as plain text.
            return str(value).strip().lower()
    text = str(value).strip()
    if field in {"resolution", "source"}:
        return text.lower().replace("_", "-")
    return " ".join(text.lower().split())
def load_cases(path: str) -> List[Dict]:
    """Load the regression cases stored as a JSON array in ``path``.

    Args:
        path: Location of a UTF-8 encoded JSON file.

    Returns:
        The decoded top-level list. Individual entries are not validated here.

    Raises:
        ValueError: If the top-level JSON value is not a list.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    if isinstance(payload, list):
        return payload
    raise ValueError(f"{path} must contain a JSON list")
def evaluate_cases(
    model_dir: str,
    case_file: str,
    tokenizer_variant: Optional[str],
    max_length: Optional[int],
    constrain_bio: bool,
) -> Dict:
    """Run the checkpoint in ``model_dir`` over every case and score the output.

    Each case's ``filename`` is passed to ``parse_filename`` and every field
    listed under the case's ``expected`` mapping is compared with the
    prediction after both sides go through ``normalize_field_value``.

    Args:
        model_dir: Directory handed to ``load_tokenizer`` and
            ``BertForTokenClassification.from_pretrained``.
        case_file: Path of the JSON case list read by ``load_cases``.
        tokenizer_variant: Forwarded unchanged to ``load_tokenizer``.
        max_length: Sequence length passed to ``parse_filename``. A falsy
            value falls back to the model config's ``max_seq_length``, or 64
            when that attribute is missing.
        constrain_bio: Forwarded unchanged to ``parse_filename``.

    Returns:
        A dict holding the run settings, ``case_count``, ``full_correct`` /
        ``full_accuracy`` (cases with no field mismatch; a case without
        expected fields therefore counts as correct), per-field
        ``field_correct`` / ``field_total`` / ``field_accuracy``, the
        per-case ``results`` and the failing subset under ``failures``.
    """
    cfg = Config()
    tokenizer = load_tokenizer(model_dir, tokenizer_variant)
    model = BertForTokenClassification.from_pretrained(model_dir)
    use_cuda = torch.cuda.is_available()
    model.to(torch.device("cuda" if use_cuda else "cpu"))
    model.eval()

    # Label ids may be stored as strings in the model config; force int keys.
    raw_labels = getattr(model.config, "id2label", cfg.id2label)
    id2label = {int(label_id): label for label_id, label in raw_labels.items()}

    if max_length:
        resolved_max_length = max_length
    else:
        resolved_max_length = int(getattr(model.config, "max_seq_length", 64))

    cases = load_cases(case_file)

    totals: Dict[str, int] = {}
    hits: Dict[str, int] = {}
    records = []
    failures = []
    for case in cases:
        expected = case.get("expected", {})
        pred = parse_filename(
            case["filename"],
            model,
            tokenizer,
            id2label,
            max_length=resolved_max_length,
            debug=False,
            constrain_bio=constrain_bio,
        )

        mismatches = {}
        for field, wanted in expected.items():
            totals[field] = totals.get(field, 0) + 1
            got = pred.get(field)
            if normalize_field_value(field, wanted) == normalize_field_value(field, got):
                hits[field] = hits.get(field, 0) + 1
                continue
            # Report the raw (un-normalized) values so failures are readable.
            mismatches[field] = {"expected": wanted, "pred": got}

        passed = not mismatches
        record = {
            "id": case.get("id"),
            "filename": case["filename"],
            "ok": passed,
            "errors": mismatches,
            "expected": expected,
            "pred": {field: pred.get(field) for field in sorted(expected)},
        }
        records.append(record)
        if not passed:
            failures.append(record)

    full_correct = len(records) - len(failures)
    field_accuracy = {
        field: hits.get(field, 0) / totals[field] for field in sorted(totals)
    }
    return {
        "model_dir": model_dir,
        "case_file": case_file,
        "tokenizer_variant": getattr(tokenizer, "tokenizer_variant", "regex"),
        "max_length": resolved_max_length,
        "constrain_bio": constrain_bio,
        "case_count": len(cases),
        "full_correct": full_correct,
        "full_accuracy": full_correct / len(cases) if cases else 0.0,
        "field_correct": hits,
        "field_total": totals,
        "field_accuracy": field_accuracy,
        "failures": failures,
        "results": records,
    }
def evaluate_case_modes(
    model_dir: str,
    case_file: str,
    tokenizer_variant: Optional[str],
    max_length: Optional[int],
) -> Dict:
    """Evaluate the same cases once per mode and bundle the reports.

    Two runs of ``evaluate_cases`` are made: ``model_only`` with
    ``constrain_bio=False`` and ``normalized_only`` with
    ``constrain_bio=True``. All other arguments are forwarded unchanged.

    Returns:
        ``{"primary_metric": "normalized_only", "modes": {...}}`` where
        ``modes`` maps each mode name to its ``evaluate_cases`` report.
    """
    mode_flags = (
        ("model_only", False),
        ("normalized_only", True),
    )
    per_mode = {}
    for mode_name, use_constraint in mode_flags:
        per_mode[mode_name] = evaluate_cases(
            model_dir=model_dir,
            case_file=case_file,
            tokenizer_variant=tokenizer_variant,
            max_length=max_length,
            constrain_bio=use_constraint,
        )
    return {"primary_metric": "normalized_only", "modes": per_mode}
def print_metrics(name: str, metrics: Dict) -> None:
    """Print a human-readable summary of one evaluation report to stdout.

    Args:
        name: Label prefixed to the summary and failure headings.
        metrics: Report containing ``full_correct``, ``case_count``,
            ``full_accuracy``, ``field_total``, ``field_correct`` and
            ``failures`` (the shape returned by ``evaluate_cases``).

    The output is the overall accuracy line, one line per field in
    ``field_total`` and, when there are failures, one JSON line per failure.
    """
    summary = (
        f"{name} full case accuracy: {metrics['full_correct']}/{metrics['case_count']} "
        f"({metrics['full_accuracy']:.4f})"
    )
    print(summary)

    for field, total in metrics["field_total"].items():
        # A field with no correct predictions is absent from field_correct.
        hits = metrics["field_correct"].get(field, 0)
        ratio = hits / total
        print(f" {field}: {hits}/{total} ({ratio:.4f})")

    failures = metrics["failures"]
    if not failures:
        return
    print(f"\n{name} failures:")
    for entry in failures:
        print(json.dumps(entry, ensure_ascii=False))
def main() -> None:
    """Command-line entry point: evaluate a checkpoint and report the results.

    With ``--mode all`` (the default) and constrained BIO enabled, both modes
    are evaluated via ``evaluate_case_modes`` and printed one after another.
    Otherwise a single ``evaluate_cases`` run is made; ``constrain_bio`` is
    on only when ``--no-constrained-bio`` is absent and the mode is not
    ``model-only``. When ``--output`` is given, the metrics are also written
    there as JSON, creating the parent directory if needed.
    """
    parser = argparse.ArgumentParser(description="Evaluate parser on fixed filename regression cases")
    parser.add_argument("--model-dir", required=True)
    parser.add_argument("--case-file", default=DEFAULT_CASE_FILE)
    parser.add_argument("--tokenizer", choices=["regex", "char"], default=None)
    parser.add_argument("--max-length", type=int, default=None)
    parser.add_argument("--output", default=None, help="Optional JSON output path")
    parser.add_argument("--mode", choices=["all", "model-only", "normalized-only"], default="all")
    parser.add_argument(
        "--no-constrained-bio",
        action="store_true",
        help="Run a single evaluation with constrain_bio disabled, whatever --mode says",
    )
    args = parser.parse_args()

    if args.mode == "all" and not args.no_constrained_bio:
        metrics = evaluate_case_modes(
            model_dir=args.model_dir,
            case_file=args.case_file,
            tokenizer_variant=args.tokenizer,
            max_length=args.max_length,
        )
        for name in ("model_only", "normalized_only"):
            print_metrics(name, metrics["modes"][name])
            print()
    else:
        constrain_bio = not args.no_constrained_bio and args.mode != "model-only"
        metrics = evaluate_cases(
            model_dir=args.model_dir,
            case_file=args.case_file,
            tokenizer_variant=args.tokenizer,
            max_length=args.max_length,
            constrain_bio=constrain_bio,
        )
        # Label the report by what actually ran rather than by the raw --mode
        # value: "--mode all --no-constrained-bio" used to be printed as "all"
        # and "--mode normalized-only --no-constrained-bio" as
        # "normalized-only", although both run without the BIO constraint.
        # The names match the mode keys used by evaluate_case_modes.
        label = "normalized_only" if constrain_bio else "model_only"
        print_metrics(label, metrics)

    if args.output:
        # dirname is "" for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(metrics, f, ensure_ascii=False, indent=2)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()