#!/usr/bin/env python3
"""Two-phase grid search for Viterbi transition biases.

Phase 1: Search background_to_start and inside_to_continue over
         {-2,-1,0,1,2} while all other biases are 0.  (25 combos)
Phase 2: Fix those two at their best values, search the remaining
         4 biases over {-1,0,1}.  (81 combos)

Each combo runs `opf eval --metrics-out <path>` and parses
detection.span.f1 from the JSON output.
"""

from __future__ import annotations

import argparse
import itertools
import json
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Sequence

BIAS_KEYS = [
    "transition_bias_background_stay",
    "transition_bias_background_to_start",
    "transition_bias_inside_to_continue",
    "transition_bias_inside_to_end",
    "transition_bias_end_to_background",
    "transition_bias_end_to_start",
]

# Which metric to optimise (from opf eval JSON output)
TARGET_METRIC = "detection.span.f1"
FALLBACK_METRIC = "detection.f1"


def make_calibration(biases: dict[str, float]) -> dict:
    """Build a calibration artifact dict.

    Args:
        biases: Mapping that contains a value for every name in ``BIAS_KEYS``.
            Extra keys are ignored.

    Returns:
        ``{"operating_points": {"default": {"biases": {...}}}}`` holding
        exactly the ``BIAS_KEYS`` entries, in ``BIAS_KEYS`` order.

    Raises:
        KeyError: If any name in ``BIAS_KEYS`` is missing from ``biases``.
    """
    return {
        "operating_points": {
            "default": {
                "biases": {k: biases[k] for k in BIAS_KEYS},
            },
        },
    }


def _read_f1(metrics_path: Path) -> float | None:
    """Extract the target F1 from a metrics JSON file, or None if unusable.

    Every malformed-output case (unreadable file, invalid JSON, unexpected
    structure, non-numeric metric) is reported on stderr and mapped to
    ``None`` so that one bad evaluation cannot abort the whole search.
    """
    try:
        payload = json.loads(metrics_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f" [PARSE ERROR] {exc}", file=sys.stderr, flush=True)
        return None

    metrics = payload.get("metrics", {}) if isinstance(payload, dict) else None
    if not isinstance(metrics, dict):
        print(
            " [PARSE ERROR] unexpected metrics JSON structure",
            file=sys.stderr,
            flush=True,
        )
        return None

    # The metric names are looked up as literal keys of the "metrics" object.
    # NOTE(review): assumes opf emits flat dotted keys rather than nested
    # objects - confirm against the opf eval output format.
    f1 = metrics.get(TARGET_METRIC)
    if f1 is None:
        f1 = metrics.get(FALLBACK_METRIC)
    if f1 is None:
        print(" [WARN] metric not found in output", file=sys.stderr, flush=True)
        return None

    try:
        return float(f1)
    except (TypeError, ValueError) as exc:
        print(f" [PARSE ERROR] {exc}", file=sys.stderr, flush=True)
        return None


def run_eval(
    checkpoint: str,
    val_data: str,
    calibration_path: str,
    device: str,
    timeout: int,
) -> float | None:
    """Run one `opf eval` and return the target F1, or None on failure.

    Args:
        checkpoint: Value passed to ``--checkpoint``.
        val_data: Positional dataset argument passed to ``opf eval``.
        calibration_path: Value passed to ``--viterbi-calibration-path``.
        device: Value passed to ``--device``.
        timeout: Seconds to wait for the subprocess before giving up.

    Returns:
        The ``TARGET_METRIC`` value (or ``FALLBACK_METRIC`` when the target is
        absent) as a float. ``None`` if the subprocess times out, cannot be
        started, exits non-zero, or produces metrics that cannot be parsed.
    """
    # Reserve a path for the child process to write its metrics to.
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
        metrics_path = Path(f.name)

    cmd = [
        sys.executable, "-m", "opf", "eval",
        val_data,
        "--checkpoint", checkpoint,
        "--device", device,
        "--viterbi-calibration-path", calibration_path,
        "--metrics-out", str(metrics_path),
    ]

    try:
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            print(" [TIMEOUT]", file=sys.stderr, flush=True)
            return None
        except Exception as exc:
            # Deliberately broad: any launch failure is one failed combo,
            # not a reason to stop the grid search.
            print(f" [ERROR] {exc}", file=sys.stderr, flush=True)
            return None

        if result.returncode != 0:
            print(
                f" [FAIL rc={result.returncode}] {result.stderr[:200]}",
                file=sys.stderr,
                flush=True,
            )
            return None

        return _read_f1(metrics_path)
    finally:
        # Single cleanup point for every exit path above.
        metrics_path.unlink(missing_ok=True)


def search_phase(
    combos: list[dict[str, float]],
    checkpoint: str,
    val_data: str,
    device: str,
    timeout: int,
    phase_name: str,
) -> tuple[dict[str, float], float]:
    """Run eval for each combo, return (best_biases, best_f1).

    Args:
        combos: Bias assignments to evaluate, each covering ``BIAS_KEYS``.
        checkpoint: Forwarded to ``run_eval``.
        val_data: Forwarded to ``run_eval``.
        device: Forwarded to ``run_eval``.
        timeout: Per-evaluation timeout in seconds, forwarded to ``run_eval``.
        phase_name: Label printed in the progress banner.

    Returns:
        A copy of the best-scoring combo and its F1. When no evaluation
        succeeds (or ``combos`` is empty) the result is all-zero biases with
        the sentinel score ``-1.0``. Ties keep the earliest combo.
    """
    best_f1 = -1.0
    best_biases: dict[str, float] = {k: 0.0 for k in BIAS_KEYS}
    total = len(combos)

    print(f"\n{'='*60}")
    print(f"{phase_name}: {total} combinations")
    print(f"{'='*60}")

    # One calibration file is reused (overwritten) for every combo.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as cal_file:
        cal_path = Path(cal_file.name)

    try:
        for i, biases in enumerate(combos, 1):
            cal_path.write_text(json.dumps(make_calibration(biases), indent=2))

            # perf_counter is monotonic, so system clock adjustments cannot
            # distort the reported duration.
            t0 = time.perf_counter()
            f1 = run_eval(checkpoint, val_data, str(cal_path), device, timeout)
            elapsed = time.perf_counter() - t0

            # Compact display of non-zero biases
            nonzero = {
                k.replace("transition_bias_", ""): v
                for k, v in biases.items()
                if v != 0.0
            }

            if f1 is None:
                print(
                    f" [{i:3d}/{total}] FAILED {nonzero} ({elapsed:.1f}s)",
                    flush=True,
                )
                continue

            improved = f1 > best_f1
            tag = " ** NEW BEST **" if improved else ""
            print(
                f" [{i:3d}/{total}] F1={f1:.4f} {nonzero} ({elapsed:.1f}s){tag}",
                flush=True,
            )
            if improved:
                best_f1 = f1
                best_biases = dict(biases)
    finally:
        cal_path.unlink(missing_ok=True)

    print(f"\n Best F1: {best_f1:.4f}")
    return best_biases, best_f1


def _phase1_combos() -> list[dict[str, float]]:
    """Return the 25 phase-1 combos: two biases over {-2..2}, the rest 0."""
    values = [-2.0, -1.0, 0.0, 1.0, 2.0]
    combos: list[dict[str, float]] = []
    for bts, itc in itertools.product(values, values):
        biases = {k: 0.0 for k in BIAS_KEYS}
        biases["transition_bias_background_to_start"] = bts
        biases["transition_bias_inside_to_continue"] = itc
        combos.append(biases)
    return combos


def _phase2_combos(fixed_bts: float, fixed_itc: float) -> list[dict[str, float]]:
    """Return the 81 phase-2 combos: four biases over {-1,0,1}, two fixed."""
    remaining_keys = [
        "transition_bias_background_stay",
        "transition_bias_inside_to_end",
        "transition_bias_end_to_background",
        "transition_bias_end_to_start",
    ]
    values = [-1.0, 0.0, 1.0]
    combos: list[dict[str, float]] = []
    for vals in itertools.product(values, repeat=len(remaining_keys)):
        biases = {k: 0.0 for k in BIAS_KEYS}
        biases["transition_bias_background_to_start"] = fixed_bts
        biases["transition_bias_inside_to_continue"] = fixed_itc
        biases.update(zip(remaining_keys, vals))
        combos.append(biases)
    return combos


def main(argv: Sequence[str] | None = None) -> None:
    """Run the baseline and both search phases, then save the best biases.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read
            ``sys.argv``.

    The best calibration is written to ``--output`` and, best-effort, to
    ``viterbi_calibration.json`` under the resolved checkpoint path.
    """
    parser = argparse.ArgumentParser(description="Viterbi transition bias grid search")
    parser.add_argument("--checkpoint", required=True, help="Model checkpoint path")
    parser.add_argument("--val-data", required=True, help="Validation JSONL file")
    parser.add_argument("--device", default="cuda", help="Device (default: cuda)")
    parser.add_argument("--output", required=True, help="Where to save best calibration JSON")
    parser.add_argument("--timeout", type=int, default=300, help="Per-eval timeout in seconds")
    args = parser.parse_args(argv)

    checkpoint = str(Path(args.checkpoint).resolve())
    val_data = str(Path(args.val_data).resolve())

    # --- Baseline (all zeros) ---
    print("Running baseline (all biases = 0) ...")
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        baseline_cal = f.name
        json.dump(make_calibration({k: 0.0 for k in BIAS_KEYS}), f, indent=2)
    try:
        baseline_f1 = run_eval(
            checkpoint, val_data, baseline_cal, args.device, args.timeout
        )
    finally:
        # Remove the temp file even if the evaluation is interrupted.
        Path(baseline_cal).unlink(missing_ok=True)

    if baseline_f1 is not None:
        print(f"Baseline F1: {baseline_f1:.4f}")
    else:
        print("WARNING: baseline eval failed, continuing anyway", file=sys.stderr)
        baseline_f1 = 0.0

    # --- Phase 1: Search background_to_start × inside_to_continue ---
    best1, best1_f1 = search_phase(
        _phase1_combos(),
        checkpoint,
        val_data,
        args.device,
        args.timeout,
        "Phase 1 (background_to_start × inside_to_continue)",
    )

    # --- Phase 2: Fix best two, search remaining 4 ---
    best2, best2_f1 = search_phase(
        _phase2_combos(
            best1["transition_bias_background_to_start"],
            best1["transition_bias_inside_to_continue"],
        ),
        checkpoint,
        val_data,
        args.device,
        args.timeout,
        "Phase 2 (remaining 4 biases)",
    )

    # --- Pick overall best ---
    if best2_f1 >= best1_f1:
        final_biases, final_f1 = best2, best2_f1
    else:
        final_biases, final_f1 = best1, best1_f1

    if final_f1 < 0:
        # search_phase reports -1.0 only when no combo produced a score.
        print(
            "WARNING: every grid-search evaluation failed; "
            "saving all-zero biases",
            file=sys.stderr,
        )

    # --- Save ---
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    calibration = make_calibration(final_biases)
    serialized = json.dumps(calibration, indent=2) + "\n"
    output_path.write_text(serialized)
    print(f"\nSaved best calibration to: {output_path}")

    # Copy into checkpoint dir (best-effort: fails harmlessly when the
    # checkpoint path is not a writable directory).
    ckpt_cal = Path(checkpoint) / "viterbi_calibration.json"
    try:
        ckpt_cal.write_text(serialized)
        print(f"Copied to checkpoint dir: {ckpt_cal}")
    except OSError as exc:
        print(f"WARNING: could not copy to checkpoint dir: {exc}", file=sys.stderr)

    # --- Summary ---
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"Baseline F1: {baseline_f1:.4f}")
    print(f"Best F1: {final_f1:.4f}")
    improvement = final_f1 - baseline_f1
    # max(..., 1e-9) avoids dividing by zero when the baseline is 0.
    print(f"Improvement: {improvement:+.4f} ({improvement / max(baseline_f1, 1e-9) * 100:+.2f}%)")
    print("\nBest biases:")
    for k in BIAS_KEYS:
        print(f" {k}: {final_biases[k]}")


if __name__ == "__main__":
    main()