File size: 4,905 Bytes
ad39f2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Train a Self-Driving Lab Policy with Unsloth\n",
        "\n",
        "This notebook uses **Unsloth** for fast quantized training on GPU nodes (e.g. H100). It mirrors `train.ipynb` but loads the model via Unsloth's optimized path with 4-bit quantization and LoRA adapters.\n",
        "\n",
        "**Model**: Uses **`Qwen/Qwen3-4B-Base`** by default (a base model with no chat template). Alternative pre-quantized instruct models:\n",
        "- `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`\n",
        "- `unsloth/Qwen2.5-7B-Instruct-bnb-4bit`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Install Unsloth and training dependencies (run once per session).\n",
        "# After installing, restart the kernel so the freshly installed packages\n",
        "# (especially unsloth, which must be imported first) are actually used.\n",
        "# Option A: uv (if using uv-managed venv)\n",
        "#   !uv sync --extra train\n",
        "#   !uv pip install unsloth unsloth_zoo --no-deps\n",
        "\n",
        "# Option B: pip\n",
        "%pip install -q -U torch transformers datasets trl accelerate bitsandbytes unsloth unsloth_zoo matplotlib huggingface_hub\n",
        "\n",
        "# Optional extras used by some reward-scoring paths.\n",
        "%pip install -q -U sentence-transformers gseapy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Unsloth must be imported before trl, transformers, peft\n",
        "import unsloth  # noqa: F401\n",
        "\n",
        "from pathlib import Path\n",
        "import torch\n",
        "\n",
        "from training_unsloth import make_training_args, run_training\n",
        "import training_script as base\n",
        "\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "if torch.cuda.is_available():\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
        "\n",
        "Path(\"artifacts\").mkdir(exist_ok=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "args = make_training_args(\n",
        "    model_id=\"Qwen/Qwen3-4B-Base\",\n",
        "    output_dir=\"artifacts/grpo-unsloth-qwen3-4b\",\n",
        "    dataset_episodes=32,\n",
        "    rollout_steps=10,\n",
        "    collection_policy=\"heuristic\",\n",
        "    reward_backend=\"local\",\n",
        "    domain_randomise=True,\n",
        "    num_generations=4,\n",
        "    max_completion_length=160,\n",
        "    max_prompt_length=1280,\n",
        "    max_seq_length=2048,\n",
        "    per_device_train_batch_size=4,\n",
        "    gradient_accumulation_steps=4,\n",
        "    learning_rate=5e-6,\n",
        "    num_train_epochs=1.0,\n",
        "    logging_steps=1,\n",
        "    save_steps=25,\n",
        "    trust_remote_code=True,\n",
        "    dry_run=False,\n",
        "    seed=42,\n",
        ")\n",
        "\n",
        "args"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "preview_examples = base.build_prompt_examples(\n",
        "    dataset_episodes=1,\n",
        "    rollout_steps=args.rollout_steps,\n",
        "    collection_policy=args.collection_policy,\n",
        "    scenario_names=[\"cardiac_disease_de\"],\n",
        "    seed=args.seed,\n",
        "    domain_randomise=args.domain_randomise,\n",
        ")\n",
        "\n",
        "print(preview_examples[0][\"prompt\"][:3500])\n",
        "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional smoke test before a full run.\n",
        "dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
        "dry_run_result = run_training(dry_run_args)\n",
        "len(dry_run_result[\"examples\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from IPython.display import Image, display\n",
        "\n",
        "train_result = run_training(args)\n",
        "for name, plot_path in train_result[\"plot_paths\"].items():\n",
        "    print(name, plot_path)\n",
        "    display(Image(filename=plot_path))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.10.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}