#!/usr/bin/env python3 """ Train a separate REINFORCE agent for each protocol set (e.g. PCR, ELISA). Each protocol has its own presets and outcome model. Training one agent per protocol gives you a policy tailored to that protocol's action/observation space. Checkpoints are saved under checkpoints/.pt. Usage: python scripts/train_per_protocol.py --workflows pcr-amplification elisa-readout python scripts/train_per_protocol.py --workflows pcr-amplification --train-episodes 1000 """ from __future__ import annotations import argparse import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from lab_env.env import LabEnv from lab_env.spec import get_spec_for_workflow from agents.rl_agent import ReinforceAgent def main() -> None: parser = argparse.ArgumentParser( description="Train one RL agent per protocol set (different presets / specs)" ) parser.add_argument( "--workflows", nargs="+", default=["pcr-amplification", "elisa-readout"], help="Workflow IDs to train (each gets its own agent and checkpoint)", ) parser.add_argument("--train-episodes", type=int, default=1500) parser.add_argument("--eval-episodes", type=int, default=50) parser.add_argument("--lr", type=float, default=3e-3) parser.add_argument("--max-trials", type=int, default=4) parser.add_argument("--checkpoint-dir", type=str, default="checkpoints") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True) for workflow_id in args.workflows: spec = get_spec_for_workflow(workflow_id) env = LabEnv(spec=spec) agent = ReinforceAgent( lr=args.lr, max_trials=args.max_trials, spec=spec, ) print(f"\n{'='*60}") print(f" Training for protocol: {workflow_id} (presets={spec.num_presets}, obs_dim={spec.obs_dim})") print("=" * 60) for ep in range(1, args.train_episodes + 1): result = agent.run_episode(env, seed=args.seed + ep, train=True) if ep % 200 == 0 or ep == args.train_episodes: print(f" Episode {ep:5d} | reward: {result['reward']:7.1f} | success: {result['success']}") checkpoint_path = Path(args.checkpoint_dir) / f"{workflow_id}.pt" agent.save(str(checkpoint_path)) print(f" Saved checkpoint: {checkpoint_path}") # Quick eval successes = 0 for i in range(args.eval_episodes): r = agent.run_episode(env, seed=999_000 + i, train=False) successes += r["success"] print(f" Eval success rate: {successes / args.eval_episodes:.0%}") env.close() print("\nDone. Use each checkpoint with LabEnv(spec=) and ReinforceAgent(spec=spec).load(path).") if __name__ == "__main__": main()