{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cbde861c",
   "metadata": {},
   "source": [
    "# Train A Self-Driving Lab Policy on H100\n",
    "\n",
    "This notebook trains a GRPO policy for the **same bio-experiment planning task** as `run_agent.py`: choosing structured actions (collect_sample, run_qc, cluster, de_analysis, etc.) step-by-step in the OpenEnv bio-experiment environment.\n",
    "\n",
    "**Flow:** Build prompts from `BioExperimentEnvironment` rollouts (same env `run_agent.py` uses) → OpenEnv reward scores actions locally → GRPO trains the model. Uses `build_openenv_reward`, `prepare_prompt_examples`, and `build_grpo_trainer` from `training_script.py`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0install",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da2e770c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q -U torch transformers datasets trl accelerate matplotlib huggingface_hub\n",
    "\n",
    "# Optional extras used by some reward-scoring paths.\n",
    "%pip install -q -U sentence-transformers gseapy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1imports",
   "metadata": {},
   "source": [
    "## Imports and environment check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4444591",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import importlib\n",
    "\n",
    "import torch\n",
    "import training_script as training_script_module\n",
    "\n",
    "# Reload so edits to training_script.py are picked up when this cell re-runs.\n",
    "training_script_module = importlib.reload(training_script_module)\n",
    "make_training_args = training_script_module.make_training_args\n",
    "prepare_prompt_examples = training_script_module.prepare_prompt_examples\n",
    "run_training = training_script_module.run_training\n",
    "\n",
    "print(\"CUDA available:\", torch.cuda.is_available())\n",
    "if torch.cuda.is_available():\n",
    "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
    "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
    "\n",
    "Path(\"artifacts\").mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2config",
   "metadata": {},
   "source": [
    "## Training configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9c472b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = make_training_args(\n",
    "    # NOTE(review): confirm this model id exists on the HF Hub — verify before a full run.\n",
    "    model_id=\"Qwen/Qwen3.5-9B\",\n",
    "    output_dir=\"artifacts/grpo-h100\",\n",
    "\n",
    "    dataset_episodes=64,            # more data per run\n",
    "    rollout_steps=12,               # slightly longer trajectories\n",
    "    collection_policy=\"heuristic\",\n",
    "    reward_backend=\"local\",\n",
    "    domain_randomise=True,\n",
    "\n",
    "    num_generations=8,              # H100 can handle a larger GRPO group\n",
    "    max_completion_length=192,      # small bump if completions are being cut off\n",
    "    max_prompt_length=1024,         # trim a bit unless you truly need 1280\n",
    "\n",
    "    per_device_train_batch_size=8,  # first thing to try on H100\n",
    "    gradient_accumulation_steps=2,  # same effective batch as before, fewer sync steps\n",
    "    learning_rate=1e-5,             # slightly more aggressive for LoRA/QLoRA-style RL tuning\n",
    "    num_train_epochs=1.0,\n",
    "\n",
    "    logging_steps=1,\n",
    "    save_steps=25,\n",
    "    trust_remote_code=True,\n",
    "    dry_run=False,\n",
    "    seed=42,\n",
    ")\n",
    "\n",
    "args"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3preview",
   "metadata": {},
   "source": [
    "## Preview a training prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4c3d9c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same prompt format run_agent.py sees: SYSTEM_PROMPT + observation\n",
    "preview_data = prepare_prompt_examples(\n",
    "    make_training_args(\n",
    "        dataset_episodes=1,\n",
    "        rollout_steps=args.rollout_steps,\n",
    "        collection_policy=args.collection_policy,\n",
    "        scenario_name=[\"cardiac_disease_de\"],\n",
    "        seed=args.seed,\n",
    "        domain_randomise=args.domain_randomise,\n",
    "    )\n",
    ")\n",
    "preview_examples = preview_data[\"examples\"]\n",
    "\n",
    "print(preview_examples[0][\"prompt\"][:3500])\n",
    "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4dryrun",
   "metadata": {},
   "source": [
    "## Dry-run smoke test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "647663dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional smoke test before a full run.\n",
    "dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
    "dry_run_result = run_training(dry_run_args)\n",
    "len(dry_run_result[\"examples\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5train",
   "metadata": {},
   "source": [
    "## Train and inspect plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f29f456",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "train_result = run_training(args)\n",
    "\n",
    "# Show each training plot; skip gracefully if a plot file was not written,\n",
    "# so a missing figure cannot crash the final cell of a long run.\n",
    "for name, plot_path in train_result[\"plot_paths\"].items():\n",
    "    print(name, plot_path)\n",
    "    if Path(plot_path).exists():\n",
    "        display(Image(filename=plot_path))\n",
    "    else:\n",
    "        print(f\"  (missing plot file: {plot_path})\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}