triagesieve_env / scripts /validate_episode_bank.py
Angshuman28's picture
Upload folder using huggingface_hub
a7effbb verified
"""Episode bank validator.
Usage:
python scripts/validate_episode_bank.py --input data/seeded_episodes.jsonl [--verbose]
Three-pass validation:
1. Parse — every line is valid JSON with required top-level keys.
2. Determinism — re-render a sample via EpisodeEngine and compare.
3. Solvability — scripted expert meets the per-difficulty score threshold on each episode (easy >= 0.90, medium >= 0.75, hard >= 0.20).
Exit code 0 on PASS, 1 on FAIL.
"""
from __future__ import annotations
import argparse
import dataclasses
import json
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Any
from triagesieve_env.baseline.scripted_expert import ScriptedExpert
from triagesieve_env.models import TaskDifficulty
from triagesieve_env.server.episode_engine import EpisodeEngine
from triagesieve_env.server.triagesieve_env_environment import TriageSieveEnvironment
logger = logging.getLogger(__name__)

# Top-level keys every episode record in the bank must carry.
_REQUIRED_KEYS = {
    "action_budget",
    "base_time",
    "episode_id",
    "seed",
    "task_difficulty",
    "tickets",
}

# Minimum scripted-expert score required, keyed by task difficulty.
_SOLVABILITY_THRESHOLDS: dict[str, float] = {
    "easy": 0.90,
    "medium": 0.75,
    "hard": 0.20,
}

_DETERMINISM_SAMPLE = 10  # spot-check first N unless --verbose
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Argument list to parse. None lets argparse read ``sys.argv``.

    Returns:
        Namespace with ``input`` (str) and ``verbose`` (bool) attributes.
    """
    cli = argparse.ArgumentParser(description="Validate the seeded episode bank.")
    # Declarative table of (flag, add_argument keyword options).
    option_specs = (
        (
            "--input",
            {"type": str, "required": True, "help": "Path to seeded_episodes.jsonl."},
        ),
        (
            "--verbose",
            {
                "action": "store_true",
                "help": "Check ALL episodes for determinism (not just 10).",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args(argv)
# ---------------------------------------------------------------------------
# Pass 1 — parse
# ---------------------------------------------------------------------------
def _pass_parse(lines: list[str]) -> tuple[list[dict[str, Any]], list[str]]:
    """Parse every line as JSON and check required keys.

    Blank lines are skipped. A line is rejected (and reported) when it is not
    valid JSON, is valid JSON but not an object, lacks any of the required
    top-level keys, or has an empty/falsy ``tickets`` value. A bank with no
    lines, or with only blank lines, is reported as empty.

    Args:
        lines: Raw lines of the JSONL episode bank.

    Returns:
        Tuple of (parsed episodes, error messages).
    """
    episodes: list[dict[str, Any]] = []
    errors: list[str] = []
    if not lines:
        errors.append("Episode bank is empty (0 lines).")
        return episodes, errors
    for i, line in enumerate(lines, 1):
        line = line.strip()
        if not line:
            continue
        try:
            ep = json.loads(line)
        except json.JSONDecodeError as exc:
            errors.append(f"Line {i}: invalid JSON — {exc}")
            continue
        # json.loads also accepts arrays, numbers, strings, etc.; only an
        # object can carry the required keys, so report anything else
        # instead of crashing on ep.keys().
        if not isinstance(ep, dict):
            errors.append(f"Line {i}: expected a JSON object, got {type(ep).__name__}.")
            continue
        missing = _REQUIRED_KEYS - set(ep)
        if missing:
            errors.append(f"Line {i} ({ep.get('episode_id', '?')}): missing keys {sorted(missing)}")
            continue
        if not ep.get("tickets"):
            errors.append(f"Line {i} ({ep['episode_id']}): zero tickets.")
            continue
        episodes.append(ep)
    # Every line was blank: without this, an effectively empty bank would
    # yield no episodes and no errors, and so pass validation vacuously.
    if not episodes and not errors:
        errors.append("Episode bank is empty (only blank lines).")
    return episodes, errors
# ---------------------------------------------------------------------------
# Pass 2 — determinism
# ---------------------------------------------------------------------------
def _enum_serializer(obj: Any) -> Any:
    """``json.dumps`` ``default`` hook that replaces an Enum member by its value.

    Raises:
        TypeError: If ``obj`` is not an Enum member.
    """
    if not isinstance(obj, Enum):
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
    return obj.value
def _pass_determinism(episodes: list[dict[str, Any]], max_check: int | None = None) -> list[str]:
    """Re-render episodes from (seed, difficulty) and compare them to the bank.

    Each checked episode is rendered again with ``EpisodeEngine``, passed
    through a JSON round-trip (enum members replaced by their values), and
    compared for equality with the stored dict.

    Args:
        episodes: Parsed episode dicts from the bank.
        max_check: Maximum episodes to spot-check. None = all.

    Returns:
        List of error messages (empty = pass).
    """
    renderer = EpisodeEngine()
    failures: list[str] = []
    sample = episodes if max_check is None else episodes[:max_check]
    for expected in sample:
        seed = expected["seed"]
        level = TaskDifficulty(expected["task_difficulty"])
        fresh = renderer.render_episode(seed=seed, difficulty=level)
        # Serialize rendered episode the same way generate_episodes does
        as_json = json.dumps(dataclasses.asdict(fresh), default=_enum_serializer)
        actual = json.loads(as_json)
        if actual == expected:
            continue
        # Pick the most specific diagnostic available for the divergence.
        if actual["episode_id"] != expected["episode_id"]:
            detail = f"episode_id {actual['episode_id']!r} vs {expected['episode_id']!r}"
        elif len(actual["tickets"]) != len(expected["tickets"]):
            detail = f"ticket count {len(actual['tickets'])} vs {len(expected['tickets'])}"
        else:
            detail = "content mismatch (same id and ticket count)"
        failures.append(f"Determinism FAIL: {expected['episode_id']} (seed={seed}): {detail}")
    return failures
# ---------------------------------------------------------------------------
# Pass 3 — solvability
# ---------------------------------------------------------------------------
def _pass_solvability(episodes: list[dict[str, Any]]) -> tuple[list[float], list[str]]:
    """Run the scripted expert on every episode and check its final score.

    The required score is looked up by the episode's ``task_difficulty``
    string; a difficulty with no configured threshold falls back to 0.90.

    Args:
        episodes: Parsed episode dicts from the bank.

    Returns:
        Tuple of (score list, error messages).
    """
    runner = ScriptedExpert(TriageSieveEnvironment())
    all_scores: list[float] = []
    failures: list[str] = []
    for record in episodes:
        seed, diff_str = record["seed"], record["task_difficulty"]
        difficulty = TaskDifficulty(diff_str)
        threshold = _SOLVABILITY_THRESHOLDS.get(diff_str, 0.90)
        outcome = runner.run_episode(seed=seed, difficulty=difficulty)
        score = outcome["final_score"]
        all_scores.append(score)
        if score < threshold:
            message = (
                f"Solvability FAIL: {record['episode_id']} (seed={seed}, "
                f"difficulty={diff_str}) scored {score:.4f} "
                f"< {threshold}"
            )
            failures.append(message)
    return all_scores, failures
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main(argv: list[str] | None = None) -> int:
    """Entry point. Returns 0 on PASS, 1 on FAIL.

    Runs the three validation passes in order. A missing or unreadable input
    file, any parse error, or a bank with no episodes aborts immediately with
    exit code 1. Determinism and solvability are both always run so that all
    of their failures are reported in one invocation.

    Args:
        argv: Command-line arguments. None lets argparse read ``sys.argv``.

    Returns:
        0 if every pass succeeded, 1 otherwise.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    args = _parse_args(argv)
    input_path = Path(args.input)

    # Read file
    if not input_path.exists():
        logger.error("File not found: %s", input_path)
        return 1
    try:
        lines = input_path.read_text(encoding="utf-8").splitlines()
    except (OSError, UnicodeDecodeError) as exc:
        # The path exists but is a directory, unreadable, or not UTF-8:
        # report it as a validation failure instead of a traceback.
        logger.error("Cannot read %s: %s", input_path, exc)
        return 1

    all_ok = True

    # Pass 1: Parse
    logger.info("Pass 1/3: Parse (%d lines)...", len(lines))
    episodes, parse_errors = _pass_parse(lines)
    if parse_errors:
        for err in parse_errors:
            logger.error(" %s", err)
        logger.error("Pass 1 FAIL: %d parse errors.", len(parse_errors))
        return 1
    if not episodes:
        # No errors but nothing parsed (e.g. only blank lines): an empty bank
        # must not be reported as a PASS with 0 episodes validated.
        logger.error("Pass 1 FAIL: no episodes found in %s.", input_path)
        return 1
    logger.info(" Pass 1 OK: %d episodes parsed.", len(episodes))

    # Pass 2: Determinism
    max_check = None if args.verbose else min(_DETERMINISM_SAMPLE, len(episodes))
    logger.info(
        "Pass 2/3: Determinism (checking %s episodes)...",
        "all" if max_check is None else max_check,
    )
    det_errors = _pass_determinism(episodes, max_check=max_check)
    if det_errors:
        for err in det_errors:
            logger.error(" %s", err)
        logger.error("Pass 2 FAIL: %d determinism errors.", len(det_errors))
        all_ok = False
    else:
        logger.info(" Pass 2 OK.")

    # Pass 3: Solvability
    logger.info("Pass 3/3: Solvability (%d episodes)...", len(episodes))
    scores, solve_errors = _pass_solvability(episodes)
    if solve_errors:
        for err in solve_errors:
            logger.error(" %s", err)
        logger.error("Pass 3 FAIL: %d episodes below threshold.", len(solve_errors))
        all_ok = False
    else:
        logger.info(" Pass 3 OK.")

    # Summary
    if scores:
        min_s, avg_s, max_s = min(scores), sum(scores) / len(scores), max(scores)
        thresholds_str = ", ".join(f"{k}={v}" for k, v in _SOLVABILITY_THRESHOLDS.items())
        logger.info(
            "Scores: min=%.4f avg=%.4f max=%.4f (thresholds: %s)",
            min_s,
            avg_s,
            max_s,
            thresholds_str,
        )

    if all_ok:
        logger.info("RESULT: PASS (%d episodes validated)", len(episodes))
        return 0
    logger.error("RESULT: FAIL")
    return 1
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())