#!/usr/bin/env python3
"""
Train a separate REINFORCE agent for each protocol set (e.g. PCR, ELISA).

Each protocol has its own presets and outcome model. Training one agent per
protocol gives you a policy tailored to that protocol's action/observation
space. Checkpoints are saved under checkpoints/<workflow_id>.pt.

Usage:
    python scripts/train_per_protocol.py --workflows pcr-amplification elisa-readout
    python scripts/train_per_protocol.py --workflows pcr-amplification --train-episodes 1000
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Make the repo root importable when this script is run directly from scripts/.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from lab_env.env import LabEnv
from lab_env.spec import get_spec_for_workflow
from agents.rl_agent import ReinforceAgent


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Train one RL agent per protocol set (different presets / specs)"
    )
    parser.add_argument(
        "--workflows",
        nargs="+",
        default=["pcr-amplification", "elisa-readout"],
        help="Workflow IDs to train (each gets its own agent and checkpoint)",
    )
    parser.add_argument("--train-episodes", type=int, default=1500)
    parser.add_argument("--eval-episodes", type=int, default=50)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--max-trials", type=int, default=4)
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
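
    # Train one agent per protocol: each workflow gets its own spec, env,
    # agent, and checkpoint.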
    for workflow_id in args.workflows:
        spec = get_spec_for_workflow(workflow_id)
        env = LabEnv(spec=spec)
        agent = ReinforceAgent(
            lr=args.lr,
            max_trials=args.max_trials,
            spec=spec,
        )

        print(f"\n{'=' * 60}")
        print(f" Training for protocol: {workflow_id} (presets={spec.num_presets}, obs_dim={spec.obs_dim})")
        print("=" * 60)
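
        # Per-episode seeds (base seed + episode index) keep training rollouts
        # varied but reproducible across runs.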
        for ep in range(1, args.train_episodes + 1):
            result = agent.run_episode(env, seed=args.seed + ep, train=True)
            if ep % 200 == 0 or ep == args.train_episodes:
                print(f" Episode {ep:5d} | reward: {result['reward']:7.1f} | success: {result['success']}")
        checkpoint_path = Path(args.checkpoint_dir) / f"{workflow_id}.pt"
        agent.save(str(checkpoint_path))
        print(f" Saved checkpoint: {checkpoint_path}")

        # Quick eval on held-out seeds; train=False runs episodes without
        # learning updates.
        successes = 0
        for i in range(args.eval_episodes):
            r = agent.run_episode(env, seed=999_000 + i, train=False)
            successes += r["success"]
        print(f" Eval success rate: {successes / args.eval_episodes:.0%}")
        env.close()

    print("\nDone. Use each checkpoint with LabEnv(spec=<same_spec>) and ReinforceAgent(spec=spec).load(path).")
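

# Example of reloading a checkpoint later. A minimal sketch, not called by this
# script: it assumes ReinforceAgent's constructor defaults match the training
# run and that load() restores weights in place, per the "Done" message above.
def load_trained_agent(workflow_id: str, checkpoint_dir: str = "checkpoints") -> ReinforceAgent:
    """Rebuild the agent for a workflow and restore its saved weights."""
    spec = get_spec_for_workflow(workflow_id)
    agent = ReinforceAgent(spec=spec)
    agent.load(str(Path(checkpoint_dir) / f"{workflow_id}.pt"))
    return agent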


if __name__ == "__main__":
    main()