#!/usr/bin/env python3
"""
Train a separate REINFORCE agent for each protocol set (e.g. PCR, ELISA).
Each protocol has its own presets and outcome model. Training one agent per
protocol gives you a policy tailored to that protocol's action/observation
space. Checkpoints are saved under checkpoints/<workflow_id>.pt.
Usage:
python scripts/train_per_protocol.py --workflows pcr-amplification elisa-readout
python scripts/train_per_protocol.py --workflows pcr-amplification --train-episodes 1000
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lab_env.env import LabEnv
from lab_env.spec import get_spec_for_workflow
from agents.rl_agent import ReinforceAgent
def main() -> None:
    """Train and evaluate one REINFORCE agent per requested workflow.

    Each workflow ID gets its own spec, environment, and agent; the trained
    policy is checkpointed to ``<checkpoint-dir>/<workflow_id>.pt`` and then
    evaluated greedily (``train=False``) for ``--eval-episodes`` episodes.
    """
    parser = argparse.ArgumentParser(
        description="Train one RL agent per protocol set (different presets / specs)"
    )
    parser.add_argument(
        "--workflows",
        nargs="+",
        default=["pcr-amplification", "elisa-readout"],
        help="Workflow IDs to train (each gets its own agent and checkpoint)",
    )
    parser.add_argument("--train-episodes", type=int, default=1500)
    parser.add_argument("--eval-episodes", type=int, default=50)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--max-trials", type=int, default=4)
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Build the checkpoint directory path once; it is reused for every workflow.
    checkpoint_dir = Path(args.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for workflow_id in args.workflows:
        _train_one(workflow_id, args, checkpoint_dir)

    print("\nDone. Use each checkpoint with LabEnv(spec=<same_spec>) and ReinforceAgent(spec=spec).load(path).")


def _train_one(workflow_id: str, args: argparse.Namespace, checkpoint_dir: Path) -> None:
    """Train, checkpoint, and evaluate the agent for a single workflow."""
    spec = get_spec_for_workflow(workflow_id)
    env = LabEnv(spec=spec)
    try:
        agent = ReinforceAgent(
            lr=args.lr,
            max_trials=args.max_trials,
            spec=spec,
        )
        print(f"\n{'='*60}")
        print(f" Training for protocol: {workflow_id} (presets={spec.num_presets}, obs_dim={spec.obs_dim})")
        print("=" * 60)

        # Distinct seed per episode so trajectories vary while runs stay reproducible.
        for ep in range(1, args.train_episodes + 1):
            result = agent.run_episode(env, seed=args.seed + ep, train=True)
            if ep % 200 == 0 or ep == args.train_episodes:
                print(f" Episode {ep:5d} | reward: {result['reward']:7.1f} | success: {result['success']}")

        checkpoint_path = checkpoint_dir / f"{workflow_id}.pt"
        agent.save(str(checkpoint_path))
        print(f" Saved checkpoint: {checkpoint_path}")

        # Quick eval: seed range 999_000+ is disjoint from the training seeds above.
        successes = 0
        for i in range(args.eval_episodes):
            r = agent.run_episode(env, seed=999_000 + i, train=False)
            successes += r["success"]
        print(f" Eval success rate: {successes / args.eval_episodes:.0%}")
    finally:
        # Always release env resources, even if training or eval raises.
        env.close()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()