File size: 4,905 Bytes
ad39f2a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | {
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train a Self-Driving Lab Policy with Unsloth\n",
"\n",
"This notebook uses **Unsloth** for fast quantized training on GPU nodes (e.g. H100). It mirrors `train.ipynb` but loads the model via Unsloth's optimized path with 4-bit quantization and LoRA adapters.\n",
"\n",
"**Model**: Uses **`Qwen/Qwen3-4B-Base`** (base model, no chat template) by default. Alternatives:\n",
"- `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`\n",
"- `unsloth/Qwen2.5-7B-Instruct-bnb-4bit`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install Unsloth and training dependencies (run once per session)\n",
"# NOTE(review): versions are unpinned; pin them for reproducible environments.\n",
"# After a fresh install, restart the kernel so the new packages are picked up.\n",
"# Option A: uv (if using uv-managed venv)\n",
"# !uv sync --extra train\n",
"# !uv pip install unsloth unsloth_zoo --no-deps\n",
"\n",
"# Option B: pip\n",
"%pip install -q -U torch transformers datasets trl accelerate bitsandbytes unsloth unsloth_zoo matplotlib huggingface_hub\n",
"\n",
"# Optional extras used by some reward-scoring paths.\n",
"%pip install -q -U sentence-transformers gseapy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Environment setup. Import order matters: Unsloth patches trl/transformers/peft\n",
"# at import time, so it must come before any of them.\n",
"import unsloth # noqa: F401\n",
"\n",
"from pathlib import Path\n",
"import torch\n",
"\n",
"from training_unsloth import make_training_args, run_training\n",
"import training_script as base\n",
"\n",
"# Surface GPU availability early so a CPU-only kernel is caught before training.\n",
"print(\"CUDA available:\", torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
" print(\"GPU:\", torch.cuda.get_device_name(0))\n",
" print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
"\n",
"# Local output directory for checkpoints and plots (idempotent).\n",
"Path(\"artifacts\").mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Training configuration. NOTE(review): field semantics and units are defined in\n",
"# training_unsloth.make_training_args — confirm defaults there.\n",
"args = make_training_args(\n",
" model_id=\"Qwen/Qwen3-4B-Base\",\n",
" output_dir=\"artifacts/grpo-unsloth-qwen3-4b\",\n",
" dataset_episodes=32,\n",
" rollout_steps=10,\n",
" collection_policy=\"heuristic\",\n",
" reward_backend=\"local\",\n",
" domain_randomise=True,\n",
" num_generations=4,\n",
" max_completion_length=160,\n",
" max_prompt_length=1280,\n",
" # max_seq_length covers prompt (1280) + completion (160) with headroom.\n",
" max_seq_length=2048,\n",
" # Effective batch size = 4 (per device) * 4 (grad accumulation) = 16.\n",
" per_device_train_batch_size=4,\n",
" gradient_accumulation_steps=4,\n",
" learning_rate=5e-6,\n",
" num_train_epochs=1.0,\n",
" logging_steps=1,\n",
" save_steps=25,\n",
" trust_remote_code=True,\n",
" dry_run=False,\n",
" seed=42,\n",
")\n",
"\n",
"# Last expression: rich display of the resolved config.\n",
"args"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build a one-episode preview so the prompt format can be inspected\n",
"# before committing to a full training run.\n",
"preview_examples = base.build_prompt_examples(\n",
" dataset_episodes=1,\n",
" rollout_steps=args.rollout_steps,\n",
" collection_policy=args.collection_policy,\n",
" scenario_names=[\"cardiac_disease_de\"],\n",
" seed=args.seed,\n",
" domain_randomise=args.domain_randomise,\n",
")\n",
"\n",
"first_example = preview_examples[0]\n",
"print(first_example[\"prompt\"][:3500])\n",
"print(\"\\nReference action:\\n\", first_example[\"reference_action\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional smoke test before a full run: same config with dry_run flipped on.\n",
"overrides = dict(vars(args))\n",
"overrides[\"dry_run\"] = True\n",
"dry_run_args = make_training_args(**overrides)\n",
"dry_run_result = run_training(dry_run_args)\n",
"len(dry_run_result[\"examples\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image, display\n",
"\n",
"# Full training run; training plots are written to disk and rendered inline.\n",
"train_result = run_training(args)\n",
"for plot_name, plot_file in train_result[\"plot_paths\"].items():\n",
" print(plot_name, plot_file)\n",
" display(Image(filename=plot_file))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|