{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Self-Driving Lab Policy with Unsloth\n",
    "\n",
    "This notebook uses **Unsloth** for fast quantized training on GPU nodes (e.g. H100). It mirrors `train.ipynb` but loads the model via Unsloth's optimized path with 4-bit quantization and LoRA adapters.\n",
    "\n",
    "**Model**: Uses **Qwen3-4B-Base** by default. Alternatives:\n",
    "- `Qwen/Qwen3-4B-Base` (base, no chat template)\n",
    "- `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`\n",
    "- `unsloth/Qwen2.5-7B-Instruct-bnb-4bit`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Unsloth and training dependencies (run once per session)\n",
    "# Option A: uv (if using uv-managed venv)\n",
    "# !uv sync --extra train\n",
    "# !uv pip install unsloth unsloth_zoo --no-deps\n",
    "\n",
    "# Option B: pip\n",
    "# NOTE(review): unpinned `-U torch` can replace a CUDA-matched torch build on\n",
    "# managed GPU nodes -- pin versions here if the node ships its own torch.\n",
    "%pip install -q -U torch transformers datasets trl accelerate bitsandbytes unsloth unsloth_zoo matplotlib huggingface_hub\n",
    "\n",
    "# Optional extras used by some reward-scoring paths.\n",
    "%pip install -q -U sentence-transformers gseapy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports live in this one cell so a fresh-kernel Run All works top-down.\n",
    "# Unsloth must be imported before trl, transformers, peft\n",
    "import unsloth # noqa: F401\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from IPython.display import Image, display\n",
    "\n",
    "from training_unsloth import make_training_args, run_training\n",
    "import training_script as base\n",
    "\n",
    "print(\"CUDA available:\", torch.cuda.is_available())\n",
    "if torch.cuda.is_available():\n",
    "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
    "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
    "\n",
    "# parents=True keeps this safe if a nested artifacts path is configured later.\n",
    "Path(\"artifacts\").mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training configuration -- everything a reader might tune is in this one call.\n",
    "args = make_training_args(\n",
    "    model_id=\"Qwen/Qwen3-4B-Base\",\n",
    "    output_dir=\"artifacts/grpo-unsloth-qwen3-4b\",\n",
    "    dataset_episodes=32,\n",
    "    rollout_steps=10,\n",
    "    collection_policy=\"heuristic\",\n",
    "    reward_backend=\"local\",\n",
    "    domain_randomise=True,\n",
    "    num_generations=4,\n",
    "    max_completion_length=160,\n",
    "    max_prompt_length=1280,\n",
    "    max_seq_length=2048,\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=4,\n",
    "    learning_rate=5e-6,\n",
    "    num_train_epochs=1.0,\n",
    "    logging_steps=1,\n",
    "    save_steps=25,\n",
    "    trust_remote_code=True,\n",
    "    dry_run=False,\n",
    "    seed=42,\n",
    ")\n",
    "\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview one prompt/reference pair before committing to a full run.\n",
    "preview_examples = base.build_prompt_examples(\n",
    "    dataset_episodes=1,\n",
    "    rollout_steps=args.rollout_steps,\n",
    "    collection_policy=args.collection_policy,\n",
    "    scenario_names=[\"cardiac_disease_de\"],\n",
    "    seed=args.seed,\n",
    "    domain_randomise=args.domain_randomise,\n",
    ")\n",
    "\n",
    "print(preview_examples[0][\"prompt\"][:3500])\n",
    "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional smoke test before a full run.\n",
    "# NOTE(review): assumes every field of `args` round-trips through vars() into\n",
    "# make_training_args -- confirm if make_training_args adds derived fields.\n",
    "dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
    "dry_run_result = run_training(dry_run_args)\n",
    "len(dry_run_result[\"examples\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full training run; Image/display come from the import cell at the top.\n",
    "train_result = run_training(args)\n",
    "for name, plot_path in train_result[\"plot_paths\"].items():\n",
    "    print(name, plot_path)\n",
    "    display(Image(filename=plot_path))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}