Spaces:
Sleeping
Sleeping
| """CLI wrapper for the fixed-suite harness.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import yaml | |
# ROOT is the parent of the directory that contains this file (presumably the
# repository root -- confirm). It is placed at the front of sys.path, unless it
# is already listed, before the project-local imports that follow.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
| from evacos_ma.models import DisasterType | |
| from training.checkpoint import load_checkpoint | |
| from training.policy_adapter import StubPolicy, hf_policy_factory | |
| from evaluation.fixed_suite import run_fixed_suite | |
def _split_csv(raw: str) -> list[str]:
    """Split *raw* on commas and return the trimmed, non-empty fields in order."""
    fields: list[str] = []
    for piece in raw.split(","):
        trimmed = piece.strip()
        # Skip fields that are empty or whitespace-only (e.g. from "a,,b" or "a, ").
        if trimmed:
            fields.append(trimmed)
    return fields
def _load_snapshot_from_path(path: Path) -> dict:
    """Decode the UTF-8 JSON file at *path* and return the parsed value.

    The return annotation says ``dict``, but the decoded value is returned
    as-is without checking its type.
    """
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def _load_latest_checkpoint_snapshot(config_path: Path = Path("training/config.yaml")) -> dict:
    """Return the normalizer snapshot of the checkpoint found under the configured root.

    The checkpoint directory is read from ``checkpoint.root_dir`` in the YAML
    file at *config_path*. When the file is empty, has no ``checkpoint``
    section, or leaves ``root_dir`` unset/null, ``outputs/checkpoints`` is used.
    Relative paths are resolved against the current working directory.

    Args:
        config_path: YAML training config to read.

    Returns:
        The ``normalizer_snapshot`` attribute of the bundle returned by
        ``load_checkpoint``.

    Raises:
        FileNotFoundError: If ``load_checkpoint`` returns ``None`` for the
            checkpoint root (also raised by the read itself when
            *config_path* does not exist).
    """
    # yaml.safe_load returns None for an empty document; treat that as "no settings"
    # instead of failing on None.get(...).
    data = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    # A bare "checkpoint:" key parses to None, so fall back to an empty mapping as well.
    checkpoint_cfg = data.get("checkpoint") or {}
    root_dir = checkpoint_cfg.get("root_dir")
    if root_dir is None:
        # Without this, a null root_dir would become the literal path "None".
        root_dir = "outputs/checkpoints"
    checkpoint_root = Path(str(root_dir))

    bundle = load_checkpoint(checkpoint_root)
    if bundle is None:
        raise FileNotFoundError(f"No checkpoint with normalizer snapshot found under {checkpoint_root}")
    return bundle.normalizer_snapshot
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the fixed-suite CLI."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tiers", default="easy")
    parser.add_argument("--seeds", default="42,123,456,789,1024")
    parser.add_argument("--families", default="fire,flood,gas")
    parser.add_argument("--rationale-mode", default="linear_capped")
    parser.add_argument("--label", default="trained")
    parser.add_argument("--checkpoint")
    parser.add_argument(
        "--model-name",
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Base model to pair with --checkpoint. Defaults to the 3B specialist base.",
    )
    parser.add_argument("--output-dir", default="outputs/evals")
    parser.add_argument("--max-rounds", type=int, default=50)
    # The two normalizer sources cannot be combined; giving neither leaves the snapshot as None.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--normalizer-snapshot")
    group.add_argument("--use-latest-checkpoint-normalizer", action="store_true")
    return parser


def main(argv: list[str] | None = None) -> None:
    """Parse the command line and run the fixed evaluation suite.

    ``--tiers``, ``--seeds`` and ``--families`` are comma-separated lists.
    With ``--checkpoint`` the policy is built by ``hf_policy_factory`` from
    ``--model-name`` plus the checkpoint as LoRA adapter path; otherwise a
    ``StubPolicy(seed=0)`` is used.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so calling ``main()`` behaves as before.

    Raises:
        SystemExit: Via argparse, on invalid options, a non-integer entry in
            ``--seeds``, or a ``--families`` entry that ``DisasterType``
            rejects with ``ValueError``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    tiers = _split_csv(args.tiers)
    # Report malformed list values as usage errors rather than raw tracebacks.
    try:
        seeds = [int(item) for item in _split_csv(args.seeds)]
    except ValueError:
        parser.error(f"--seeds must be a comma-separated list of integers, got {args.seeds!r}")
    try:
        families = [DisasterType(item) for item in _split_csv(args.families)]
    except ValueError as exc:
        # Assumes DisasterType raises ValueError for unknown values (enum-style) -- confirm.
        parser.error(f"invalid --families value: {exc}")

    # run_fixed_suite receives a zero-argument callable that builds the policy.
    if args.checkpoint:
        adapter_path = str(Path(args.checkpoint))

        def policy_factory():
            return hf_policy_factory(args.model_name, lora_adapter_path=adapter_path)

    else:

        def policy_factory():
            return StubPolicy(seed=0)

    normalizer_snapshot = None
    if args.normalizer_snapshot:
        normalizer_snapshot = _load_snapshot_from_path(Path(args.normalizer_snapshot))
    elif args.use_latest_checkpoint_normalizer:
        normalizer_snapshot = _load_latest_checkpoint_snapshot()

    run_fixed_suite(
        policy_factory,
        tiers=tiers,
        seeds=seeds,
        disaster_families=families,
        max_rounds=args.max_rounds,
        rationale_mode=args.rationale_mode,
        label=args.label,
        output_dir=Path(args.output_dir),
        normalizer_snapshot=normalizer_snapshot,
    )
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()