#!/usr/bin/env python3
"""
Train a separate REINFORCE agent for each protocol set (e.g. PCR, ELISA).

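Each protocol has its own presets and outcome model, so training one agent per
protocol yields a policy tailored to that protocol's action/observation space.
After training, each agent is saved as <checkpoint-dir>/<workflow_id>.pt
(default: checkpoints/<workflow_id>.pt) and then evaluated for --eval-episodes
episodes with train=False.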

Usage:
  python scripts/train_per_protocol.py --workflows pcr-amplification elisa-readout
  python scripts/train_per_protocol.py --workflows pcr-amplification --train-episodes 1000
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

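# Put the repository root on sys.path so `lab_env` and `agents` import when this
# file is run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))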

from lab_env.env import LabEnv
from lab_env.spec import get_spec_for_workflow
from agents.rl_agent import ReinforceAgent


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Train one RL agent per protocol set (different presets / specs)"
    )
    parser.add_argument(
        "--workflows",
        nargs="+",
        default=["pcr-amplification", "elisa-readout"],
        help="Workflow IDs to train (each gets its own agent and checkpoint)",
    )
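    parser.add_argument("--train-episodes", type=int, default=1500,
                        help="Training episodes per workflow")
    parser.add_argument("--eval-episodes", type=int, default=50,
                        help="Evaluation episodes per workflow (run with train=False)")
    parser.add_argument("--lr", type=float, default=3e-3,
                        help="Learning rate passed to ReinforceAgent")
    parser.add_argument("--max-trials", type=int, default=4,
                        help="max_trials value passed to ReinforceAgent")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints",
                        help="Directory where <workflow_id>.pt checkpoints are written")
    parser.add_argument("--seed", type=int, default=42,
                        help="Base seed; training episode N is seeded with seed + N")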
    args = parser.parse_args()

    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    for workflow_id in args.workflows:
        spec = get_spec_for_workflow(workflow_id)
        env = LabEnv(spec=spec)
        try:
            # A fresh agent per protocol, sized to that protocol's spec.
            agent = ReinforceAgent(
                lr=args.lr,
                max_trials=args.max_trials,
                spec=spec,
            )

            print("\n" + "=" * 60)
            print(
                f"  Training for protocol: {workflow_id} "
                f"(presets={spec.num_presets}, obs_dim={spec.obs_dim})"
            )
            print("=" * 60)

            # Training episode `ep` is seeded with args.seed + ep.
            for ep in range(1, args.train_episodes + 1):
                result = agent.run_episode(env, seed=args.seed + ep, train=True)
                if ep % 200 == 0 or ep == args.train_episodes:
                    print(
                        f"  Episode {ep:5d} | reward: {result['reward']:7.1f} "
                        f"| success: {result['success']}"
                    )

            checkpoint_path = Path(args.checkpoint_dir) / f"{workflow_id}.pt"
            agent.save(str(checkpoint_path))
            print(f"  Saved checkpoint: {checkpoint_path}")

            # Quick eval on a fixed seed range (999_000 + i) with training off.
            # Skipped when --eval-episodes is 0 to avoid dividing by zero.
            if args.eval_episodes > 0:
                successes = 0
                for i in range(args.eval_episodes):
                    r = agent.run_episode(env, seed=999_000 + i, train=False)
                    successes += r["success"]
                print(f"  Eval success rate: {successes / args.eval_episodes:.0%}")
        finally:
            # Release the environment even if training or evaluation raises.
            env.close()

    print(
        "\nDone. To reuse a checkpoint, build LabEnv(spec=spec) and "
        "ReinforceAgent(spec=spec) from the same "
        "get_spec_for_workflow(<workflow_id>) spec, then call agent.load(path)."
    )


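# Hypothetical helper, not called by main() above: a minimal sketch for re-running
# the same quick evaluation on an agent restored with agent.load(path). It assumes
# only what main() already relies on: agent.run_episode(env, seed=..., train=False)
# returns a dict whose "success" entry is truthy when the episode succeeds.
# Example (assumed usage): evaluate_success_rate(agent, LabEnv(spec=spec), 50)
def evaluate_success_rate(agent, env, episodes: int, base_seed: int = 999_000) -> float:
    """Return the fraction of `episodes` evaluation episodes that succeed (0.0 if none)."""
    if episodes <= 0:
        return 0.0
    successes = 0
    for i in range(episodes):
        # Same fixed seed scheme as the quick eval in main(): base_seed + i.
        r = agent.run_episode(env, seed=base_seed + i, train=False)
        successes += int(bool(r["success"]))
    return successes / episodes

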
if __name__ == "__main__":
    main()