{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ii5Zkit6eSqU" }, "source": [ "# Teaching Tool Calling with Supervised Fine-Tuning (SFT) using TRL on a Free Colab Notebook\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_tool_calling.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "gJVcVKOteSqV" }, "source": [ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" ] }, { "cell_type": "markdown", "metadata": { "id": "hzt0BrvoeSqW" }, "source": [ "Learn how to teach a language model to perform **tool calling** using **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** and the [**TRL**](https://github.com/huggingface/trl) library.\n", "\n", "The model used in this notebook does not have native tool-calling support. We extend its Jinja2 chat template (via `tiny_aya_chat_template.jinja`) to serialize tool schemas into the system preamble and render tool calls as structured XML inside the model's native `<|START_RESPONSE|>` / `<|END_RESPONSE|>` delimiters. 
The modified template is saved with the tokenizer, making inference reproducible: just load the tokenizer from the output directory and call `apply_chat_template` with `tools=TOOLS`.\n", "\n", "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project!\n", "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview)\n", "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" ] }, { "cell_type": "markdown", "metadata": { "id": "3PfX1aj5eSqW" }, "source": [ "## Key concepts\n", "\n", "- **SFT**: Trains a model on example input-output pairs to align its behavior with a desired task.\n", "- **Tool Calling**: The ability of a model to respond with a structured function call instead of free-form text.\n", "- **LoRA**: Updates only a small set of low-rank parameters, reducing training cost and memory usage.\n", "- **QLoRA**: A quantized variant of LoRA that enables fine-tuning larger models on limited hardware.\n", "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient." ] }, { "cell_type": "markdown", "metadata": { "id": "QDMcKeoEeSqW" }, "source": [ "## Install dependencies\n", "\n", "We'll install **TRL** with the **PEFT** extra, which brings in all main dependencies such as **Transformers** and **PEFT** (parameter-efficient fine-tuning). 
We also install **trackio** for experiment logging, **bitsandbytes** for 4-bit quantization, and **liger-kernel** for optimized training kernels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ey-TuYPrXTLG", "outputId": "a4fd8cfe-624e-4185-ab59-e6901514cb96" }, "outputs": [], "source": [ "!pip install -Uq \"trl[peft]\" trackio bitsandbytes liger-kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "Aw8_T-Z0eSqW" }, "source": [ "### Log in to Hugging Face\n", "\n", "Log in to your Hugging Face account to push the fine-tuned model to the Hub and access gated models. You can find your access token on your [account settings page](https://huggingface.co/settings/tokens)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_qaeDZwXXTLG" }, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": { "id": "XPnDpJgIeSqX" }, "source": [ "## Load Dataset\n", "\n", "We load the [**bebechien/SimpleToolCalling**](https://huggingface.co/datasets/bebechien/SimpleToolCalling) dataset, which contains user queries paired with the correct tool call to handle each request. Each sample provides a `user_content`, a `tool_name`, and `tool_arguments`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zfJY_8AzXTLG" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset_name = \"bebechien/SimpleToolCalling\"\n", "dataset = load_dataset(dataset_name, split=\"train\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ygeMXzKGXTLH", "outputId": "a1ed3a8b-f515-4cda-eeb2-db0355ed2c02" }, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['user_content', 'tool_name', 'tool_arguments'],\n", " num_rows: 40\n", "})" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "O_GkvqtReSqX" }, "source": [ "## Prepare Tool-Calling Data\n", "\n", "We define two tools: `search_knowledge_base` for internal company documents and `search_google` for public information. We then write a custom Jinja2 chat template that extends the model's default template with two additions:\n", "\n", "1. A **Tool Use** section is appended to the system preamble when `tools` is passed to `apply_chat_template`.\n", "2. 
Assistant turns with `tool_calls` render the call as structured XML inside the model's existing `<|START_RESPONSE|>` / `<|END_RESPONSE|>` delimiters.\n", "\n", "Each training sample uses the standard `tool_calls` message format and carries a `tools` key, which SFTTrainer passes to `apply_chat_template` automatically." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jaAgXeWtXTLH" }, "outputs": [], "source": [ "import json\n", "\n", "# These are the tool schemas that are used in the dataset\n", "TOOLS = [\n", " {\n", " \"type\": \"function\",\n", " \"function\": {\n", " \"name\": \"search_knowledge_base\",\n", " \"description\": \"Search internal company documents, policies and project data.\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\"query\": {\"type\": \"string\", \"description\": \"query string\"}},\n", " \"required\": [\"query\"],\n", " },\n", " \"return\": {\"type\": \"string\"},\n", " },\n", " },\n", " {\n", " \"type\": \"function\",\n", " \"function\": {\n", " \"name\": \"search_google\",\n", " \"description\": \"Search public information.\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\"query\": {\"type\": \"string\", \"description\": \"query string\"}},\n", " \"required\": [\"query\"],\n", " },\n", " \"return\": {\"type\": \"string\"},\n", " },\n", " },\n", "]\n", "\n", "def create_conversation(sample):\n", " return {\n", " \"prompt\": [{\"role\": \"user\", \"content\": sample[\"user_content\"]}],\n", " \"completion\": [\n", " {\n", " \"role\": \"assistant\",\n", " \"tool_calls\": [\n", " {\n", " \"type\": \"function\",\n", " \"function\": {\n", " \"name\": sample[\"tool_name\"],\n", " \"arguments\": json.loads(sample[\"tool_arguments\"]),\n", " },\n", " }\n", " ],\n", " },\n", " ],\n", " \"tools\": TOOLS,\n", " }" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "32p512R2XTLH" }, "outputs": [], "source": [ "dataset = dataset.map(create_conversation, 
remove_columns=dataset.features)\n", "\n", "# Split dataset into 50% training samples and 50% test samples\n", "dataset = dataset.train_test_split(test_size=0.5, shuffle=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "Plnjef-PeSqX" }, "source": [ "Let's inspect an example from the training set to verify the format:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f4QI6wJjXTLH", "outputId": "2156adb4-7bed-4e29-84c5-54e6d45e5500" }, "outputs": [ { "data": { "text/plain": [ "{'messages': [{'content': 'How do I configure the VPN for the New York office?',\n", " 'role': 'user',\n", " 'tool_calls': None},\n", " {'content': None,\n", " 'role': 'assistant',\n", " 'tool_calls': [{'function': {'arguments': {'query': 'VPN configuration guide New York office'},\n", " 'name': 'search_knowledge_base'},\n", " 'type': 'function'}]}],\n", " 'tools': [{'function': {'description': 'Search internal company documents, policies and project data.',\n", " 'name': 'search_knowledge_base',\n", " 'parameters': {'properties': {'query': {'description': 'query string',\n", " 'type': 'string'}},\n", " 'required': ['query'],\n", " 'type': 'object'},\n", " 'return': {'type': 'string'}},\n", " 'type': 'function'},\n", " {'function': {'description': 'Search public information.',\n", " 'name': 'search_google',\n", " 'parameters': {'properties': {'query': {'description': 'query string',\n", " 'type': 'string'}},\n", " 'required': ['query'],\n", " 'type': 'object'},\n", " 'return': {'type': 'string'}},\n", " 'type': 'function'}]}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset['train'][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fBIGKl_UXTLH", "outputId": "edd8e968-c7e4-418d-b9e9-26773aee1366" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['messages', 'tools'],\n", " num_rows: 20\n", " })\n", " test: Dataset({\n", " features: 
['messages', 'tools'],\n", " num_rows: 20\n", " })\n", "})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "aud6U3c2eSqX" }, "source": [ "## Load Model and Configure LoRA/QLoRA\n", "\n", "Choose the model you want to fine-tune. This notebook uses [`CohereLabs/tiny-aya-global`](https://huggingface.co/CohereLabs/tiny-aya-global) by default." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_j_LF12IXTLH" }, "outputs": [], "source": [ "model_id, output_dir = \"CohereLabs/tiny-aya-global\", \"tiny-aya-global-SFT\" # ✅ ~9.1 GB VRAM" ] }, { "cell_type": "markdown", "metadata": { "id": "gpTZHjpJeSqX" }, "source": [ "Load the model with 4-bit quantization using `BitsAndBytesConfig` (QLoRA). To use standard LoRA without quantization, comment out the `quantization_config` parameter. We also load the tokenizer separately so we can install the custom chat template before training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "referenced_widgets": [ "680888237b78477ea653adb2ecea7fa8" ] }, "id": "jGpTDV6sXTLH", "outputId": "fc33f7a6-bfd0-4228-80cd-e0aeb67bbd42" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "680888237b78477ea653adb2ecea7fa8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading weights: 0%| | 0/290 [00:00" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "* GPU detected, enabling automatic GPU metrics logging\n", "* Created new run: sergiopaniego-1771428231\n" ] }, { "data": { "text/html": [ "\n", "
\n", "<div>[15/15 00:52, Epoch 3/3]</div>\n", "<table>\n", "<thead><tr><th>Step</th><th>Training Loss</th></tr></thead>\n", "<tbody>\n", "<tr><td>1</td><td>3.095131</td></tr>\n", "<tr><td>2</td><td>3.083373</td></tr>\n", "<tr><td>3</td><td>2.951535</td></tr>\n", "<tr><td>4</td><td>2.625918</td></tr>\n", "<tr><td>5</td><td>2.254464</td></tr>\n", "<tr><td>6</td><td>1.939976</td></tr>\n", "<tr><td>7</td><td>1.694891</td></tr>\n", "<tr><td>8</td><td>1.558982</td></tr>\n", "<tr><td>9</td><td>1.430660</td></tr>\n", "<tr><td>10</td><td>1.305176</td></tr>\n", "<tr><td>11</td><td>1.192725</td></tr>\n", "<tr><td>12</td><td>1.120383</td></tr>\n", "<tr><td>13</td><td>1.052859</td></tr>\n", "<tr><td>14</td><td>0.985858</td></tr>\n", "<tr><td>15</td><td>0.970833</td></tr>\n", "</tbody>\n", "</table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "* Run finished. Uploading logs to Trackio (please wait...)\n" ] } ], "source": [ "trainer_stats = trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "id": "4MGKFi1-eSqY" }, "source": [ "Show memory stats after training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3f68GA6TXTLI", "outputId": "321e90ee-757a-41fc-c6a2-4ba40a6e6b3c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "59.2841 seconds used for training.\n", "0.99 minutes used for training.\n", "Peak reserved memory = 11.928 GB.\n", "Peak reserved memory for training = 7.28 GB.\n", "Peak reserved memory % of max memory = 30.202 %.\n", "Peak reserved memory for training % of max memory = 18.433 %.\n" ] } ], "source": [ "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", "used_percentage = round(used_memory / max_memory * 100, 3)\n", "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", "\n", "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", "print(f\"Peak reserved memory = {used_memory} GB.\")\n", "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ONWy4NOAeSqY" }, "source": [ "## Save the Fine-Tuned Model\n", "\n", "Save the trained LoRA adapter locally and push it to the Hugging Face Hub." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9qz-fRZyXTLI", "outputId": "9ff41250-0786-4ec6-fe41-dfc6b611d0b5" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/tiny-aya-global-SFT/commit/c59baa62c6bb5a3c3be2d33b482522a00783a5b4', commit_message='End of training', commit_description='', oid='c59baa62c6bb5a3c3be2d33b482522a00783a5b4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/tiny-aya-global-SFT', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/tiny-aya-global-SFT'), pr_revision=None, pr_num=None)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.save_model(output_dir)\n", 
"trainer.push_to_hub(dataset_name=dataset_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "wNA4AIE4SiUg" }, "source": [ "## Load the Fine-Tuned Model and Run Inference\n", "\n", "Load the trained LoRA adapter on top of the base model and merge it into the weights for efficient inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "referenced_widgets": [ "9d6a109e605d440ab2c115d969796859" ] }, "id": "b5CmxYtpXTLI", "outputId": "10ebe012-9ffe-4096-f155-648af855aa80" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9d6a109e605d440ab2c115d969796859", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading weights: 0%| | 0/290 [00:00\n", "\n", "node.js latest version\n", "\n", "\n", "\n" ] } ], "source": [ "sample_test_data = dataset[\"test\"][0] # Get a sample from the test set\n", "\n", "user_content = sample_test_data[\"prompt\"]\n", "\n", "print(f\"User Query: {user_content}\")\n", "\n", "predicted_output = generate_prediction(user_content)\n", "print(f\"Predicted Output: {predicted_output}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "-r85c-aa7C7k" }, "source": [ "You can still use the model's strong multilingual capabilities:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UGePqQGVXTLI", "outputId": "adcd21ca-ca45-43d5-a3cc-02a47377e51b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "User Query: [{'role': 'user', 'content': \"Explica en español qué significa la palabra japonesa 'ikigai' y da un ejemplo práctico.\"}]\n", "Predicted Output: \n", "\n", "ikigai significado y ejemplo\n", "\n", "\n", "\n" ] } ], "source": [ "user_content = \"Explica en español qué significa la palabra japonesa 'ikigai' y da un ejemplo práctico.\" # Spanish question\n", "user_content = [{\"role\": \"user\", \"content\": user_content}]\n", "\n", "print(f\"User Query: {user_content}\")\n", "\n", "predicted_output = 
generate_prediction(user_content)\n", "print(f\"Predicted Output: {predicted_output}\")" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }