#!/usr/bin/env python3
"""Read the Phase 2 **`routing`** object from a classifier checkpoint's **`eval_report.json`**.
Used by Horizon 1 glue, **rag_faq_smoke**, **embeddings_smoke_test**, **routing_policy** (**`--from-checkpoint`**), **horizon1_route_then_retrieve**, and training/report CLIs so training notes and runtime gates stay aligned."""
from __future__ import annotations
import json
import sys
from pathlib import Path
def load_routing_from_eval_report(model_path: str | Path) -> dict | None:
    """Return the top-level ``routing`` dict from ``<model_path>/eval_report.json``.

    Args:
        model_path: Path to a local checkpoint directory. Anything that is not
            an existing directory yields ``None``.

    Returns:
        The value stored under ``routing`` when the report exists, can be read
        and decoded as UTF-8, parses as a JSON object, and that value is a
        dict. ``None`` in every other case; this function never raises for a
        missing, unreadable, or malformed report.
    """
    checkpoint_dir = Path(model_path)
    if not checkpoint_dir.is_dir():
        return None
    report_path = checkpoint_dir / "eval_report.json"
    if not report_path.is_file():
        return None
    try:
        data = json.loads(report_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: the file vanished or is unreadable between the check and the read.
        # ValueError: covers both json.JSONDecodeError and UnicodeDecodeError.
        return None
    # A valid JSON document may be a list, string, number, ... with no ``.get``.
    if not isinstance(data, dict):
        return None
    routing = data.get("routing")
    return routing if isinstance(routing, dict) else None
def format_checkpoint_tip_path(
    output_dir: str | Path,
    *,
    cwd: Path | None = None,
) -> str:
    """Render ``output_dir`` as a POSIX-style path, relative to ``cwd`` when possible.

    Both paths are resolved first. If the resolved ``output_dir`` lies under the
    resolved base (``cwd``, or the process working directory when ``cwd`` is
    ``None``), the relative path is returned; otherwise the resolved absolute
    path is returned.
    """
    target = Path(output_dir).resolve()
    root = Path.cwd() if cwd is None else cwd
    root = root.resolve()
    try:
        relative = target.relative_to(root)
    except ValueError:
        # ``target`` is not located under ``root``; fall back to the absolute form.
        return target.as_posix()
    return relative.as_posix()
def format_routing_policy_from_checkpoint_command(
    output_dir: str | Path,
    *,
    cwd: Path | None = None,
) -> str:
    """Build the ``python scripts/routing_policy.py --from-checkpoint …`` command line.

    The checkpoint argument comes from ``format_checkpoint_tip_path`` and is
    inserted verbatim: no shell quoting is applied.
    """
    command_prefix = "python scripts/routing_policy.py --from-checkpoint"
    checkpoint_arg = format_checkpoint_tip_path(output_dir, cwd=cwd)
    return f"{command_prefix} {checkpoint_arg}"
def print_routing_policy_from_checkpoint_tip(
    output_dir: str | Path,
    *,
    headline: str = "Tip: dump Phase 2 `routing` JSON (no model load):",
    cwd: Path | None = None,
) -> None:
    """Print ``headline`` followed by the ``routing_policy`` command on the next line.

    The command is built by ``format_routing_policy_from_checkpoint_command``;
    the output goes to stdout and is flushed immediately.
    """
    command = format_routing_policy_from_checkpoint_command(output_dir, cwd=cwd)
    tip_text = f"{headline}\n {command}"
    print(tip_text, flush=True)
def maybe_print_routing_section(model_path: str, *, enabled: bool, prog: str) -> None:
    """Print the checkpoint's ``routing`` JSON, or a hint on stderr if there is none.

    Does nothing when ``enabled`` is false. Otherwise the routing dict is loaded
    via ``load_routing_from_eval_report``; when found it is written to stdout as
    indented JSON under a banner, and when absent a one-line message prefixed
    with ``prog`` is written to stderr.
    """
    if not enabled:
        return
    routing = load_routing_from_eval_report(model_path)
    if routing is not None:
        print("=== eval_report.json routing (Phase 2 training notes) ===\n")
        print(json.dumps(routing, indent=2))
        print()
        return
    hint = (
        f"{prog}: no eval_report.json with top-level `routing` "
        "(Hub id or missing artifact)."
    )
    print(hint, file=sys.stderr)
|