arcspan / scripts /viterbi_grid_search.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Two-phase grid search for Viterbi transition biases.
Phase 1: Search background_to_start and inside_to_continue over {-2,-1,0,1,2}
while all other biases are 0. (25 combos)
Phase 2: Fix those two at their best values, search the remaining 4 biases
over {-1,0,1}. (81 combos)
Each combo runs `opf eval --metrics-out <tmpfile>` and parses detection.span.f1
from the JSON output.
"""
from __future__ import annotations
import argparse
import itertools
import json
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Sequence
# Names of the six transition biases written into the calibration artifact.
# The list order is the order in which the keys are emitted and reported.
BIAS_KEYS = [
    # Transitions leaving the background state.
    "transition_bias_background_stay",
    "transition_bias_background_to_start",
    # Transitions leaving the inside state.
    "transition_bias_inside_to_continue",
    "transition_bias_inside_to_end",
    # Transitions leaving the end state.
    "transition_bias_end_to_background",
    "transition_bias_end_to_start",
]

# Metric to optimise, looked up in the "metrics" object of the opf eval JSON.
TARGET_METRIC = "detection.span.f1"
# Metric used instead when TARGET_METRIC is absent from the output.
FALLBACK_METRIC = "detection.f1"
def make_calibration(biases: dict[str, float]) -> dict:
    """Wrap a bias mapping in the calibration artifact layout.

    Args:
        biases: Mapping that contains a value for every name in ``BIAS_KEYS``.
            Any additional entries are ignored.

    Returns:
        A nested dict of the form
        ``{"operating_points": {"default": {"biases": {...}}}}`` whose inner
        mapping lists the biases in ``BIAS_KEYS`` order.

    Raises:
        KeyError: If ``biases`` lacks one of the names in ``BIAS_KEYS``.
    """
    selected = {}
    for name in BIAS_KEYS:
        selected[name] = biases[name]
    default_point = {"biases": selected}
    return {"operating_points": {"default": default_point}}
def run_eval(
    checkpoint: str,
    val_data: str,
    calibration_path: str,
    device: str,
    timeout: int,
) -> float | None:
    """Run one `opf eval` subprocess and return the target F1.

    The command is ``<python> -m opf eval <val_data>`` with the checkpoint,
    device, calibration path and a temporary ``--metrics-out`` file. The
    metrics file is expected to be a JSON object with a ``"metrics"`` object;
    ``TARGET_METRIC`` is read from it, falling back to ``FALLBACK_METRIC``.

    Args:
        checkpoint: Value passed to ``--checkpoint``.
        val_data: Positional data argument of ``opf eval``.
        calibration_path: Value passed to ``--viterbi-calibration-path``.
        device: Value passed to ``--device``.
        timeout: Maximum run time of the subprocess, in seconds.

    Returns:
        The metric as a float, or ``None`` when the subprocess times out,
        cannot be started, exits non-zero, or its metrics file is missing,
        malformed, or lacks both metrics. Failures are reported on stderr and
        never raised, so a single bad evaluation cannot abort a whole search.
    """
    # Only reserve a unique path here; the handle is closed immediately so the
    # child process is free to (re)write the file.
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as handle:
        metrics_path = Path(handle.name)

    cmd = [
        sys.executable, "-m", "opf", "eval", val_data,
        "--checkpoint", checkpoint,
        "--device", device,
        "--viterbi-calibration-path", calibration_path,
        "--metrics-out", str(metrics_path),
    ]

    # A single outer try/finally guarantees the temp file is removed on every
    # exit path, instead of repeating the cleanup before each return.
    try:
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            print(" [TIMEOUT]", file=sys.stderr, flush=True)
            return None
        except Exception as exc:
            # Deliberately broad: any failure to launch the child counts as a
            # failed evaluation rather than a fatal error.
            print(f" [ERROR] {exc}", file=sys.stderr, flush=True)
            return None

        if result.returncode != 0:
            # Show the tail of stderr: the last lines of a traceback carry the
            # actual error, the first lines are usually just frame listings.
            detail = result.stderr.strip()[-200:]
            print(f" [FAIL rc={result.returncode}] {detail}", file=sys.stderr, flush=True)
            return None

        try:
            payload = json.loads(metrics_path.read_text(encoding="utf-8"))
            metrics = payload.get("metrics", {})
            f1 = metrics.get(TARGET_METRIC)
            if f1 is None:
                f1 = metrics.get(FALLBACK_METRIC)
            if f1 is None:
                print(" [WARN] metric not found in output", file=sys.stderr, flush=True)
                return None
            return float(f1)
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
            # OSError: metrics file unreadable.
            # ValueError: invalid JSON/UTF-8 or a non-numeric metric value.
            # AttributeError/TypeError: payload or "metrics" is not an object.
            print(f" [PARSE ERROR] {exc}", file=sys.stderr, flush=True)
            return None
    finally:
        metrics_path.unlink(missing_ok=True)
def search_phase(
    combos: list[dict[str, float]],
    checkpoint: str,
    val_data: str,
    device: str,
    timeout: int,
    phase_name: str,
) -> tuple[dict[str, float], float]:
    """Evaluate every bias combination and report the best one.

    Each combination is written to one shared temporary calibration file and
    scored with ``run_eval``; a progress line is printed per combination.

    Args:
        combos: Bias mappings to evaluate, in order.
        checkpoint: Forwarded to ``run_eval``.
        val_data: Forwarded to ``run_eval``.
        device: Forwarded to ``run_eval``.
        timeout: Forwarded to ``run_eval``.
        phase_name: Label printed in the phase banner.

    Returns:
        ``(best_biases, best_f1)``. The first combination with the highest F1
        wins ties. If no evaluation succeeds, the result is an all-zero bias
        mapping together with ``-1.0``.
    """
    top_score = -1.0
    top_biases: dict[str, float] = dict.fromkeys(BIAS_KEYS, 0.0)
    n_combos = len(combos)

    print(f"\n{'='*60}")
    print(f"{phase_name}: {n_combos} combinations")
    print(f"{'='*60}")

    # Reserve a single scratch path that is rewritten for every combination.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as scratch:
        cal_path = scratch.name

    try:
        for idx, candidate in enumerate(combos, 1):
            Path(cal_path).write_text(json.dumps(make_calibration(candidate), indent=2))

            started = time.time()
            score = run_eval(checkpoint, val_data, cal_path, device, timeout)
            took = time.time() - started

            # Only the non-zero biases are shown, with the common prefix dropped.
            shown = {
                name.replace("transition_bias_", ""): value
                for name, value in candidate.items()
                if value != 0.0
            }

            if score is None:
                print(f" [{idx:3d}/{n_combos}] FAILED {shown} ({took:.1f}s)", flush=True)
                continue

            improved = score > top_score
            marker = " ** NEW BEST **" if improved else ""
            print(f" [{idx:3d}/{n_combos}] F1={score:.4f} {shown} ({took:.1f}s){marker}", flush=True)
            if improved:
                top_score = score
                top_biases = dict(candidate)
    finally:
        Path(cal_path).unlink(missing_ok=True)

    print(f"\n Best F1: {top_score:.4f}")
    return top_biases, top_score
def main(argv: Sequence[str] | None = None) -> None:
    """Run the two-phase grid search and save the best calibration.

    Steps: evaluate an all-zero baseline, search ``background_to_start`` x
    ``inside_to_continue`` (phase 1), fix those two at their best values and
    search the remaining biases (phase 2), then write the winning calibration
    to ``--output`` and to ``viterbi_calibration.json`` under the checkpoint
    path, and print a summary.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read ``sys.argv``.

    Raises:
        SystemExit: On invalid arguments (raised by argparse), or with status 1
            when no evaluation in either phase succeeded. In that case nothing
            is written, so an existing calibration is never replaced by a
            meaningless all-zero one.
    """
    parser = argparse.ArgumentParser(description="Viterbi transition bias grid search")
    parser.add_argument("--checkpoint", required=True, help="Model checkpoint path")
    parser.add_argument("--val-data", required=True, help="Validation JSONL file")
    parser.add_argument("--device", default="cuda", help="Device (default: cuda)")
    parser.add_argument("--output", required=True, help="Where to save best calibration JSON")
    parser.add_argument("--timeout", type=int, default=300, help="Per-eval timeout in seconds")
    args = parser.parse_args(argv)

    checkpoint = str(Path(args.checkpoint).resolve())
    val_data = str(Path(args.val_data).resolve())

    bts_key = "transition_bias_background_to_start"
    itc_key = "transition_bias_inside_to_continue"
    zero_biases = {k: 0.0 for k in BIAS_KEYS}

    # --- Baseline (all zeros) ---
    print("Running baseline (all biases = 0) ...")
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        baseline_cal = f.name
        json.dump(make_calibration(zero_biases), f, indent=2)
    try:
        baseline_f1 = run_eval(checkpoint, val_data, baseline_cal, args.device, args.timeout)
    finally:
        Path(baseline_cal).unlink(missing_ok=True)
    if baseline_f1 is not None:
        print(f"Baseline F1: {baseline_f1:.4f}")
    else:
        # Keep None instead of substituting 0.0 so the summary cannot report a
        # made-up baseline or a meaningless relative improvement.
        print("WARNING: baseline eval failed, continuing anyway", file=sys.stderr)

    # --- Phase 1: Search background_to_start x inside_to_continue ---
    phase1_values = [-2.0, -1.0, 0.0, 1.0, 2.0]
    phase1_combos: list[dict[str, float]] = [
        {**zero_biases, bts_key: bts, itc_key: itc}
        for bts, itc in itertools.product(phase1_values, phase1_values)
    ]
    best1, best1_f1 = search_phase(
        phase1_combos, checkpoint, val_data, args.device, args.timeout,
        "Phase 1 (background_to_start × inside_to_continue)",
    )

    # --- Phase 2: Fix best two, search remaining 4 ---
    fixed = {**zero_biases, bts_key: best1[bts_key], itc_key: best1[itc_key]}
    # Derived from BIAS_KEYS so the two lists cannot drift apart.
    remaining_keys = [k for k in BIAS_KEYS if k not in (bts_key, itc_key)]
    phase2_values = [-1.0, 0.0, 1.0]
    phase2_combos: list[dict[str, float]] = [
        {**fixed, **dict(zip(remaining_keys, vals))}
        for vals in itertools.product(phase2_values, repeat=len(remaining_keys))
    ]
    best2, best2_f1 = search_phase(
        phase2_combos, checkpoint, val_data, args.device, args.timeout,
        "Phase 2 (remaining 4 biases)",
    )

    # --- Pick overall best ---
    if best2_f1 >= best1_f1:
        final_biases, final_f1 = best2, best2_f1
    else:
        final_biases, final_f1 = best1, best1_f1

    # search_phase reports -1.0 when none of its evaluations succeeded. If
    # that holds for both phases there is no result worth saving.
    if final_f1 < 0.0:
        print(
            "ERROR: no evaluation succeeded; not writing a calibration file",
            file=sys.stderr,
        )
        sys.exit(1)

    # --- Save ---
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    calibration = make_calibration(final_biases)
    serialized = json.dumps(calibration, indent=2) + "\n"
    output_path.write_text(serialized)
    print(f"\nSaved best calibration to: {output_path}")

    # Copy into checkpoint dir (best effort: a failure only produces a warning).
    ckpt_cal = Path(checkpoint) / "viterbi_calibration.json"
    try:
        ckpt_cal.write_text(serialized)
        print(f"Copied to checkpoint dir: {ckpt_cal}")
    except OSError as exc:
        print(f"WARNING: could not copy to checkpoint dir: {exc}", file=sys.stderr)

    # --- Summary ---
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    if baseline_f1 is None:
        print("Baseline F1: n/a (baseline eval failed)")
        print(f"Best F1: {final_f1:.4f}")
        print("Improvement: n/a")
    else:
        print(f"Baseline F1: {baseline_f1:.4f}")
        print(f"Best F1: {final_f1:.4f}")
        improvement = final_f1 - baseline_f1
        if baseline_f1 > 0.0:
            print(f"Improvement: {improvement:+.4f} ({improvement / baseline_f1 * 100:+.2f}%)")
        else:
            # A relative change against a zero baseline is undefined.
            print(f"Improvement: {improvement:+.4f}")
    print("\nBest biases:")
    for k in BIAS_KEYS:
        print(f" {k}: {final_biases[k]}")
# Run the grid search only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()