{
"cells": [
{
"cell_type": "markdown",
"id": "cbde861c",
"metadata": {},
"source": [
"# Train A Self-Driving Lab Policy on H100\n",
"\n",
"This notebook trains a GRPO policy for the **same bio-experiment planning task** as `run_agent.py`: choosing structured actions (collect_sample, run_qc, cluster, de_analysis, etc.) step-by-step in the OpenEnv bio-experiment environment.\n",
"\n",
"**Flow:** Build prompts from `BioExperimentEnvironment` rollouts (same env `run_agent.py` uses) → OpenEnv reward scores actions locally → GRPO trains the model. Uses `build_openenv_reward`, `prepare_prompt_examples`, and `build_grpo_trainer` from `training_script.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da2e770c",
"metadata": {},
"outputs": [],
"source": [
"# Environment setup — run once per fresh kernel/VM.\n",
"# NOTE(review): versions are intentionally unpinned (-U pulls latest); pin exact\n",
"# versions here if you need reproducible training runs.\n",
"%pip install -q -U torch transformers datasets trl accelerate matplotlib huggingface_hub\n",
"\n",
"# Optional extras used by some reward-scoring paths.\n",
"%pip install -q -U sentence-transformers gseapy"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4444591",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import importlib\n",
"\n",
"import torch\n",
"import training_script as training_script_module\n",
"\n",
"# Reload so edits to training_script.py are picked up when this cell is re-run\n",
"# without restarting the kernel (no-op on a fresh kernel).\n",
"training_script_module = importlib.reload(training_script_module)\n",
"make_training_args = training_script_module.make_training_args\n",
"prepare_prompt_examples = training_script_module.prepare_prompt_examples\n",
"run_training = training_script_module.run_training\n",
"\n",
"# Sanity-check the accelerator before configuring a GPU-sized run below.\n",
"print(\"CUDA available:\", torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
" print(\"GPU:\", torch.cuda.get_device_name(0))\n",
" print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
"\n",
"# Local directory for checkpoints and plots produced by the cells below.\n",
"Path(\"artifacts\").mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9c472b3",
"metadata": {},
"outputs": [],
"source": [
"# Training hyperparameters sized for a single H100.\n",
"# NOTE(review): verify that \"Qwen/Qwen3.5-9B\" is a real Hugging Face Hub repo id\n",
"# before a long run — the Qwen3 dense checkpoints are published as e.g.\n",
"# Qwen/Qwen3-8B / Qwen/Qwen3-14B; an invalid id fails only at model download.\n",
"args = make_training_args(\n",
" model_id=\"Qwen/Qwen3.5-9B\",\n",
" output_dir=\"artifacts/grpo-h100\",\n",
" dataset_episodes=64, # more data per run\n",
" rollout_steps=12, # slightly longer trajectories\n",
" collection_policy=\"heuristic\",\n",
" reward_backend=\"local\",\n",
" domain_randomise=True,\n",
"\n",
" num_generations=8, # H100 can handle a larger GRPO group\n",
" max_completion_length=192, # small bump if completions are being cut off\n",
" max_prompt_length=1024, # trim a bit unless you truly need 1280\n",
"\n",
" per_device_train_batch_size=8, # first thing to try on H100\n",
" gradient_accumulation_steps=2, # same effective batch as before, fewer sync steps\n",
" learning_rate=1e-5, # slightly more aggressive for LoRA/QLoRA-style RL tuning\n",
" num_train_epochs=1.0,\n",
"\n",
" logging_steps=1,\n",
" save_steps=25,\n",
" trust_remote_code=True,\n",
" dry_run=False,\n",
" seed=42,\n",
")\n",
"\n",
"# Display the resolved args as the cell output for a quick visual check.\n",
"args"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4c3d9c4",
"metadata": {},
"outputs": [],
"source": [
"# Same prompt format run_agent.py sees: SYSTEM_PROMPT + observation\n",
"# Build a throwaway 1-episode dataset (single fixed scenario) just to inspect\n",
"# one prompt/reference-action pair; the full-run `args` above is not modified.\n",
"preview_data = prepare_prompt_examples(\n",
" make_training_args(\n",
" dataset_episodes=1,\n",
" rollout_steps=args.rollout_steps,\n",
" collection_policy=args.collection_policy,\n",
" scenario_name=[\"cardiac_disease_de\"],\n",
" seed=args.seed,\n",
" domain_randomise=args.domain_randomise,\n",
" )\n",
")\n",
"preview_examples = preview_data[\"examples\"]\n",
"\n",
"# Truncate the prompt to keep the cell output readable.\n",
"print(preview_examples[0][\"prompt\"][:3500])\n",
"print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "647663dd",
"metadata": {},
"outputs": [],
"source": [
"# Optional smoke test before a full run.\n",
"# Clone args with dry_run=True via vars() so the full-run `args` object is\n",
"# left untouched for the training cell below.\n",
"dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
"dry_run_result = run_training(dry_run_args)\n",
"len(dry_run_result[\"examples\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f29f456",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image, display\n",
"\n",
"# Full GRPO training run; then render each plot image that run_training\n",
"# reports in its \"plot_paths\" mapping (name -> file path).\n",
"train_result = run_training(args)\n",
"for name, plot_path in train_result[\"plot_paths\"].items():\n",
" print(name, plot_path)\n",
" display(Image(filename=plot_path))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}