File size: 5,592 Bytes
db03c40
 
 
 
 
 
 
 
 
ad39f2a
 
 
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
db03c40
 
ad39f2a
db03c40
ad39f2a
 
 
 
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
db03c40
ad39f2a
 
db03c40
 
 
ad39f2a
 
 
 
 
 
 
 
db03c40
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
 
 
 
 
 
 
 
 
 
db03c40
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "cbde861c",
      "metadata": {},
      "source": [
        "# Train A Self-Driving Lab Policy on H100\n",
        "\n",
        "This notebook trains a GRPO policy for the **same bio-experiment planning task** as `run_agent.py`: choosing structured actions (collect_sample, run_qc, cluster, de_analysis, etc.) step-by-step in the OpenEnv bio-experiment environment.\n",
        "\n",
        "**Flow:** Build prompts from `BioExperimentEnvironment` rollouts (same env `run_agent.py` uses) → OpenEnv reward scores actions locally → GRPO trains the model. Uses `make_training_args`, `prepare_prompt_examples`, and `run_training` from `training_script.py` (which wire up `build_openenv_reward` and `build_grpo_trainer` internally)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "da2e770c",
      "metadata": {},
      "outputs": [],
      "source": [
        "%pip install -q -U torch transformers datasets trl accelerate matplotlib huggingface_hub\n",
        "\n",
        "# Optional extras used by some reward-scoring paths.\n",
        "%pip install -q -U sentence-transformers gseapy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f4444591",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "import importlib\n",
        "\n",
        "import torch\n",
        "import training_script as training_script_module\n",
        "\n",
        "training_script_module = importlib.reload(training_script_module)\n",
        "make_training_args = training_script_module.make_training_args\n",
        "prepare_prompt_examples = training_script_module.prepare_prompt_examples\n",
        "run_training = training_script_module.run_training\n",
        "\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "if torch.cuda.is_available():\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
        "\n",
        "Path(\"artifacts\").mkdir(exist_ok=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c9c472b3",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Training configuration for the H100 run.\n",
        "args = make_training_args(\n",
        "    # NOTE(review): confirm this checkpoint id exists on the Hugging Face Hub before a\n",
        "    # long run — Qwen3 dense checkpoints are published as 0.6B/1.7B/4B/8B/14B/32B,\n",
        "    # so a 9B id may be a typo.\n",
        "    model_id=\"Qwen/Qwen3.5-9B\",\n",
        "    output_dir=\"artifacts/grpo-h100\",\n",
        "    dataset_episodes=64,                 # more data per run\n",
        "    rollout_steps=12,                    # slightly longer trajectories\n",
        "    collection_policy=\"heuristic\",\n",
        "    reward_backend=\"local\",\n",
        "    domain_randomise=True,\n",
        "\n",
        "    num_generations=8,                   # H100 can handle a larger GRPO group\n",
        "    max_completion_length=192,           # small bump if completions are being cut off\n",
        "    max_prompt_length=1024,              # trim a bit unless you truly need 1280\n",
        "\n",
        "    per_device_train_batch_size=8,       # first thing to try on H100\n",
        "    gradient_accumulation_steps=2,       # same effective batch as before, fewer sync steps\n",
        "    learning_rate=1e-5,                  # on the aggressive side for full-parameter RL; fine if adapter-style (LoRA/QLoRA) tuning is used\n",
        "    num_train_epochs=1.0,\n",
        "\n",
        "    logging_steps=1,\n",
        "    save_steps=25,\n",
        "    trust_remote_code=True,\n",
        "    dry_run=False,\n",
        "    seed=42,\n",
        ")\n",
        "\n",
        "args"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d4c3d9c4",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Same prompt format run_agent.py sees: SYSTEM_PROMPT + observation\n",
        "preview_data = prepare_prompt_examples(\n",
        "    make_training_args(\n",
        "        dataset_episodes=1,\n",
        "        rollout_steps=args.rollout_steps,\n",
        "        collection_policy=args.collection_policy,\n",
        "        scenario_name=[\"cardiac_disease_de\"],\n",
        "        seed=args.seed,\n",
        "        domain_randomise=args.domain_randomise,\n",
        "    )\n",
        ")\n",
        "preview_examples = preview_data[\"examples\"]\n",
        "\n",
        "print(preview_examples[0][\"prompt\"][:3500])\n",
        "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "647663dd",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional smoke test before a full run.\n",
        "dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
        "dry_run_result = run_training(dry_run_args)\n",
        "len(dry_run_result[\"examples\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5f29f456",
      "metadata": {},
      "outputs": [],
      "source": [
        "from IPython.display import Image, display\n",
        "\n",
        "train_result = run_training(args)\n",
        "for name, plot_path in train_result[\"plot_paths\"].items():\n",
        "    print(name, plot_path)\n",
        "    display(Image(filename=plot_path))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}