| <meta charset="utf-8" /><meta name="hf:doc:metadata" content='{"title":"Practical Exercise: GRPO with Unsloth","local":"practical-exercise-grpo-with-unsloth","sections":[{"title":"Install dependencies","local":"install-dependencies","sections":[],"depth":2},{"title":"Setting up Unsloth","local":"setting-up-unsloth","sections":[],"depth":2},{"title":"Data Preparation","local":"data-preparation","sections":[],"depth":2},{"title":"Defining Reward Functions","local":"defining-reward-functions","sections":[],"depth":2},{"title":"Training with GRPO","local":"training-with-grpo","sections":[],"depth":2},{"title":"Testing the Model","local":"testing-the-model","sections":[],"depth":2},{"title":"Saving the Model","local":"saving-the-model","sections":[],"depth":2},{"title":"Pushing to Hugging Face Hub","local":"pushing-to-hugging-face-hub","sections":[],"depth":2},{"title":"Conclusion","local":"conclusion","sections":[],"depth":2}],"depth":1}'> | |
| <link href="/docs/course/pr_1069/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/singletons.bc78d867.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/paths.76894643.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.7cb9c9b8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/0.f5347c47.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/34.f41cea0d.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/Tip.d10b3fc9.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CodeBlock.abae2786.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CourseFloatingBanner.df82c153.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/getInferenceSnippets.f9350a3f.js"> <p></p> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-2-questions" target="_blank"><img alt="Ask a Question" class="!m-0" 
src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/HuggingFace%20Course-Gemma3_(1B)-GRPO.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> </div> <h1 class="relative group"><a id="practical-exercise-grpo-with-unsloth" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#practical-exercise-grpo-with-unsloth"><span><svg class="" 
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Practical Exercise: GRPO with Unsloth</span></h1> <p data-svelte-h="svelte-3dwsf2">In this exercise, you’ll fine-tune a model with GRPO (Group Relative Policy Optimization) using Unsloth to improve its reasoning capabilities. We covered GRPO in <a href="/course/chapter3/3">Chapter 3</a>.</p> <p data-svelte-h="svelte-v44mtq">Unsloth is a library that accelerates LLM fine-tuning, making it possible to train models faster and with fewer computational resources. Unsloth plugs into TRL, so we’ll build on what we learned in the previous sections and adapt it for Unsloth specifics.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1tto42k">This exercise can be run on a free Google Colab T4 GPU. 
For the best experience, follow along with the notebook linked above and try it out yourself.</p></div> <h2 class="relative group"><a id="install-dependencies" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#install-dependencies"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Install dependencies</span></h2> <p data-svelte-h="svelte-7apc0w">First, let’s install the necessary libraries. 
We’ll need Unsloth for the accelerated fine-tuning and vLLM for fast inference.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install unsloth vllm | |
| pip install --upgrade pillow<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="setting-up-unsloth" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#setting-up-unsloth"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Setting up Unsloth</span></h2> <p data-svelte-h="svelte-uhz6q0">Unsloth provides a class (<code>FastLanguageModel</code>) that integrates transformers with Unsloth optimizations. 
Let’s import it:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> unsloth <span class="hljs-keyword">import</span> FastLanguageModel<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1o52er5">Now, let’s load Google’s Gemma 3 1B Instruct model and configure it for fine-tuning:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> unsloth <span class="hljs-keyword">import</span> FastLanguageModel | |
| <span class="hljs-keyword">import</span> torch | |
| max_seq_length = <span class="hljs-number">1024</span> <span class="hljs-comment"># Can increase for longer reasoning traces</span> | |
| lora_rank = <span class="hljs-number">32</span> <span class="hljs-comment"># Larger rank = more adapter capacity, but slower and more memory-hungry</span> | |
| model, tokenizer = FastLanguageModel.from_pretrained( | |
|     model_name=<span class="hljs-string">"google/gemma-3-1b-it"</span>, | |
|     max_seq_length=max_seq_length, | |
|     load_in_4bit=<span class="hljs-literal">True</span>, <span class="hljs-comment"># False for LoRA 16bit</span> | |
|     fast_inference=<span class="hljs-literal">True</span>, <span class="hljs-comment"># Enable vLLM fast inference</span> | |
|     max_lora_rank=lora_rank, | |
|     gpu_memory_utilization=<span class="hljs-number">0.6</span>, <span class="hljs-comment"># Reduce if out of memory</span> | |
| ) | |
| model = FastLanguageModel.get_peft_model( | |
|     model, | |
|     r=lora_rank, <span class="hljs-comment"># Choose any number > 0. Suggested: 8, 16, 32, 64, 128</span> | |
|     target_modules=[ | |
|         <span class="hljs-string">"q_proj"</span>, | |
|         <span class="hljs-string">"k_proj"</span>, | |
|         <span class="hljs-string">"v_proj"</span>, | |
|         <span class="hljs-string">"o_proj"</span>, | |
|         <span class="hljs-string">"gate_proj"</span>, | |
|         <span class="hljs-string">"up_proj"</span>, | |
|         <span class="hljs-string">"down_proj"</span>, | |
|     ], <span class="hljs-comment"># Remove the q/k/v/o projections if out of memory</span> | |
|     lora_alpha=lora_rank, | |
|     use_gradient_checkpointing=<span class="hljs-string">"unsloth"</span>, <span class="hljs-comment"># Enable long context finetuning</span> | |
|     random_state=<span class="hljs-number">3407</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ihivum">This code loads the model with 4-bit quantization to save memory and applies LoRA (Low-Rank Adaptation) for efficient fine-tuning. The <code>target_modules</code> parameter specifies which layers of the model receive LoRA adapters, and <code>use_gradient_checkpointing</code> reduces activation memory, which enables training with longer contexts.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1qfqkww">We won’t cover the details of LoRA in this chapter, but you can learn more in <a href="/course/chapter11/3">Chapter 11</a>.</p></div> <h2 class="relative group"><a id="data-preparation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#data-preparation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Data Preparation</span></h2> <p data-svelte-h="svelte-1jqnhvo">For this exercise, we’ll use the GSM8K dataset, which contains grade school math problems. 
We’ll format the data to encourage the model to show its reasoning before providing an answer.</p> <p data-svelte-h="svelte-222xwa">First, we will define the format of the prompts and answers:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Define the system prompt that instructs the model to use a specific format</span> | |
| SYSTEM_PROMPT = <span class="hljs-string">""" | |
| Respond in the following format: | |
| <reasoning> | |
| ... | |
| </reasoning> | |
| <answer> | |
| ... | |
| </answer> | |
| """</span> | |
| XML_COT_FORMAT = <span class="hljs-string">"""\ | |
| <reasoning> | |
| {reasoning} | |
| </reasoning> | |
| <answer> | |
| {answer} | |
| </answer> | |
| """</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-uuip11">Now, let’s prepare the dataset:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> re | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset, Dataset | |
| <span class="hljs-comment"># Helper functions to extract answers from different formats</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">extract_xml_answer</span>(<span class="hljs-params">text: <span class="hljs-built_in">str</span></span>) -> <span class="hljs-built_in">str</span>: | |
|     answer = text.split(<span class="hljs-string">"<answer>"</span>)[-<span class="hljs-number">1</span>] | |
|     answer = answer.split(<span class="hljs-string">"</answer>"</span>)[<span class="hljs-number">0</span>] | |
|     <span class="hljs-keyword">return</span> answer.strip() | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">extract_hash_answer</span>(<span class="hljs-params">text: <span class="hljs-built_in">str</span></span>) -> <span class="hljs-built_in">str</span> | <span class="hljs-literal">None</span>: | |
|     <span class="hljs-keyword">if</span> <span class="hljs-string">"####"</span> <span class="hljs-keyword">not</span> <span class="hljs-keyword">in</span> text: | |
|         <span class="hljs-keyword">return</span> <span class="hljs-literal">None</span> | |
|     <span class="hljs-keyword">return</span> text.split(<span class="hljs-string">"####"</span>)[<span class="hljs-number">1</span>].strip() | |
| <span class="hljs-comment"># Function to prepare the GSM8K dataset</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">get_gsm8k_questions</span>(<span class="hljs-params">split=<span class="hljs-string">"train"</span></span>) -> Dataset: | |
|     data = load_dataset(<span class="hljs-string">"openai/gsm8k"</span>, <span class="hljs-string">"main"</span>)[split] | |
|     data = data.<span class="hljs-built_in">map</span>( | |
|         <span class="hljs-keyword">lambda</span> x: { | |
|             <span class="hljs-string">"prompt"</span>: [ | |
|                 {<span class="hljs-string">"role"</span>: <span class="hljs-string">"system"</span>, <span class="hljs-string">"content"</span>: SYSTEM_PROMPT}, | |
|                 {<span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>, <span class="hljs-string">"content"</span>: x[<span class="hljs-string">"question"</span>]}, | |
|             ], | |
|             <span class="hljs-string">"answer"</span>: extract_hash_answer(x[<span class="hljs-string">"answer"</span>]), | |
|         } | |
|     ) | |
|     <span class="hljs-keyword">return</span> data | |
| dataset = get_gsm8k_questions()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yuvu2n">Each example is mapped to a chat-style prompt (the system prompt followed by the question), and the final answer is extracted as a string from the text after the <code>####</code> marker in the GSM8K reference solution.</p> <h2 class="relative group"><a id="defining-reward-functions" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#defining-reward-functions"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Defining Reward Functions</span></h2> <p data-svelte-h="svelte-tq7f7j">As we discussed in <a href="/course/chapter13/4">an earlier page</a>, GRPO can use reward functions to guide the model’s learning based on verifiable criteria like length and formatting.</p> <p data-svelte-h="svelte-1bbmyq2">In this exercise, we’ll define several reward functions that encourage different aspects of good reasoning. 
For example, we’ll reward the model for providing an integer answer, and for following the strict format.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Reward function that checks if the answer is correct</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">correctness_reward_func</span>(<span class="hljs-params">prompts, completions, answer, **kwargs</span>) -> <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>]: | |
|     responses = [completion[<span class="hljs-number">0</span>][<span class="hljs-string">"content"</span>] <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions] | |
|     q = prompts[<span class="hljs-number">0</span>][-<span class="hljs-number">1</span>][<span class="hljs-string">"content"</span>] | |
|     extracted_responses = [extract_xml_answer(r) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> responses] | |
|     <span class="hljs-built_in">print</span>( | |
|         <span class="hljs-string">"-"</span> * <span class="hljs-number">20</span>, | |
|         <span class="hljs-string">f"Question:\n<span class="hljs-subst">{q}</span>"</span>, | |
|         <span class="hljs-string">f"\nAnswer:\n<span class="hljs-subst">{answer[<span class="hljs-number">0</span>]}</span>"</span>, | |
|         <span class="hljs-string">f"\nResponse:\n<span class="hljs-subst">{responses[<span class="hljs-number">0</span>]}</span>"</span>, | |
|         <span class="hljs-string">f"\nExtracted:\n<span class="hljs-subst">{extracted_responses[<span class="hljs-number">0</span>]}</span>"</span>, | |
|     ) | |
|     <span class="hljs-keyword">return</span> [<span class="hljs-number">2.0</span> <span class="hljs-keyword">if</span> r == a <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span> <span class="hljs-keyword">for</span> r, a <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(extracted_responses, answer)] | |
| <span class="hljs-comment"># Reward function that checks if the answer is an integer</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">int_reward_func</span>(<span class="hljs-params">completions, **kwargs</span>) -> <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>]: | |
|     responses = [completion[<span class="hljs-number">0</span>][<span class="hljs-string">"content"</span>] <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions] | |
|     extracted_responses = [extract_xml_answer(r) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> responses] | |
|     <span class="hljs-keyword">return</span> [<span class="hljs-number">0.5</span> <span class="hljs-keyword">if</span> r.isdigit() <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span> <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> extracted_responses] | |
| <span class="hljs-comment"># Reward function that checks if the completion follows the strict format</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">strict_format_reward_func</span>(<span class="hljs-params">completions, **kwargs</span>) -> <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>]: | |
|     pattern = <span class="hljs-string">r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"</span> | |
|     responses = [completion[<span class="hljs-number">0</span>][<span class="hljs-string">"content"</span>] <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions] | |
|     <span class="hljs-comment"># re.DOTALL lets "." match newlines, so multi-line reasoning can match</span> | |
|     matches = [re.<span class="hljs-keyword">match</span>(pattern, r, flags=re.DOTALL) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> responses] | |
|     <span class="hljs-keyword">return</span> [<span class="hljs-number">0.5</span> <span class="hljs-keyword">if</span> <span class="hljs-keyword">match</span> <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span> <span class="hljs-keyword">for</span> <span class="hljs-keyword">match</span> <span class="hljs-keyword">in</span> matches] | |
| <span class="hljs-comment"># Reward function that checks if the completion follows a more relaxed format</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">soft_format_reward_func</span>(<span class="hljs-params">completions, **kwargs</span>) -> <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>]: | |
|     pattern = <span class="hljs-string">r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"</span> | |
|     responses = [completion[<span class="hljs-number">0</span>][<span class="hljs-string">"content"</span>] <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions] | |
|     <span class="hljs-comment"># re.DOTALL lets "." match the newlines inside the tags</span> | |
|     matches = [re.<span class="hljs-keyword">match</span>(pattern, r, flags=re.DOTALL) <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> responses] | |
|     <span class="hljs-keyword">return</span> [<span class="hljs-number">0.5</span> <span class="hljs-keyword">if</span> <span class="hljs-keyword">match</span> <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span> <span class="hljs-keyword">for</span> <span class="hljs-keyword">match</span> <span class="hljs-keyword">in</span> matches] | |
| <span class="hljs-comment"># Reward function that counts XML tags and penalizes extra content</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">count_xml</span>(<span class="hljs-params">text</span>) -> <span class="hljs-built_in">float</span>: | |
|     count = <span class="hljs-number">0.0</span> | |
|     <span class="hljs-keyword">if</span> text.count(<span class="hljs-string">"<reasoning>\n"</span>) == <span class="hljs-number">1</span>: | |
|         count += <span class="hljs-number">0.125</span> | |
|     <span class="hljs-keyword">if</span> text.count(<span class="hljs-string">"\n</reasoning>\n"</span>) == <span class="hljs-number">1</span>: | |
|         count += <span class="hljs-number">0.125</span> | |
|     <span class="hljs-keyword">if</span> text.count(<span class="hljs-string">"\n<answer>\n"</span>) == <span class="hljs-number">1</span>: | |
|         count += <span class="hljs-number">0.125</span> | |
|         count -= <span class="hljs-built_in">len</span>(text.split(<span class="hljs-string">"\n</answer>\n"</span>)[-<span class="hljs-number">1</span>]) * <span class="hljs-number">0.001</span> | |
|     <span class="hljs-keyword">if</span> text.count(<span class="hljs-string">"\n</answer>"</span>) == <span class="hljs-number">1</span>: | |
|         count += <span class="hljs-number">0.125</span> | |
|         count -= (<span class="hljs-built_in">len</span>(text.split(<span class="hljs-string">"\n</answer>"</span>)[-<span class="hljs-number">1</span>]) - <span class="hljs-number">1</span>) * <span class="hljs-number">0.001</span> | |
|     <span class="hljs-keyword">return</span> count | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">xmlcount_reward_func</span>(<span class="hljs-params">completions, **kwargs</span>) -> <span class="hljs-built_in">list</span>[<span class="hljs-built_in">float</span>]: | |
| contents = [completion[<span class="hljs-number">0</span>][<span class="hljs-string">"content"</span>] <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions] | |
| <span class="hljs-keyword">return</span> [count_xml(c) <span class="hljs-keyword">for</span> c <span class="hljs-keyword">in</span> contents]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10lwv8u">These reward functions serve different purposes:</p> <table data-svelte-h="svelte-beg5py"><thead><tr><th>Reward Function</th> <th>Purpose</th></tr></thead> <tbody><tr><td><code>correctness_reward_func</code></td> <td>Rewards the model when its answer matches the correct answer</td></tr> <tr><td><code>int_reward_func</code></td> <td>Rewards the model for providing a numeric answer</td></tr> <tr><td><code>strict_format_reward_func</code> and <code>soft_format_reward_func</code></td> <td>Reward the model for following the specified format</td></tr> <tr><td><code>xmlcount_reward_func</code></td> <td>Rewards proper XML tag usage and penalizes extra content after the closing tags</td></tr></tbody></table> <h2 class="relative group"><a id="training-with-grpo" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-with-grpo"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training with GRPO</span></h2> <p 
data-svelte-h="svelte-9dkvch">Now we’ll set up the GRPO trainer with our model, tokenizer, and reward functions. This part follows the same approach as the <a href="/course/chapter12/5">previous exercise</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> GRPOConfig, GRPOTrainer | |
| max_prompt_length = <span class="hljs-number">256</span> | |
| training_args = GRPOConfig( | |
| learning_rate=<span class="hljs-number">5e-6</span>, | |
| adam_beta1=<span class="hljs-number">0.9</span>, | |
| adam_beta2=<span class="hljs-number">0.99</span>, | |
| weight_decay=<span class="hljs-number">0.1</span>, | |
| warmup_ratio=<span class="hljs-number">0.1</span>, | |
| lr_scheduler_type=<span class="hljs-string">"cosine"</span>, | |
| optim=<span class="hljs-string">"paged_adamw_8bit"</span>, | |
| logging_steps=<span class="hljs-number">1</span>, | |
| per_device_train_batch_size=<span class="hljs-number">1</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">1</span>, <span class="hljs-comment"># Increase to 4 for smoother training</span> | |
| num_generations=<span class="hljs-number">6</span>, <span class="hljs-comment"># Decrease if out of memory</span> | |
| max_prompt_length=max_prompt_length, | |
| max_completion_length=max_seq_length - max_prompt_length, | |
| <span class="hljs-comment"># num_train_epochs = 1, # Set to 1 for a full training run</span> | |
| max_steps=<span class="hljs-number">250</span>, | |
| save_steps=<span class="hljs-number">250</span>, | |
| max_grad_norm=<span class="hljs-number">0.1</span>, | |
| report_to=<span class="hljs-string">"none"</span>, <span class="hljs-comment"># Can use Weights & Biases</span> | |
| output_dir=<span class="hljs-string">"outputs"</span>, | |
| ) | |
| trainer = GRPOTrainer( | |
| model=model, | |
| processing_class=tokenizer, | |
| reward_funcs=[ | |
| xmlcount_reward_func, | |
| soft_format_reward_func, | |
| strict_format_reward_func, | |
| int_reward_func, | |
| correctness_reward_func, | |
| ], | |
| args=training_args, | |
| train_dataset=dataset, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1jo17wb">The <code>GRPOConfig</code> sets various hyperparameters for training:</p> <ul data-svelte-h="svelte-eo3csz"><li><code>use_vllm</code>: Enables fast inference with vLLM (optional; it is not set in the configuration above)</li> <li><code>learning_rate</code>: Controls how large each optimizer update is (<code>5e-6</code> here)</li> <li><code>num_generations</code>: Number of completions to generate for each prompt (6 here)</li> <li><code>max_steps</code>: Total number of training steps to perform (250 here)</li></ul> <p data-svelte-h="svelte-c65a8i">Now let’s start the training:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train()<!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white 
dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1q4mcgt">Training may take some time. You might not see rewards increase immediately - it can take 150-200 steps before you start seeing improvements. Be patient!</p></div> <h2 class="relative group"><a id="testing-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#testing-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Testing the Model</span></h2> <p data-svelte-h="svelte-dwjtn">After training, let’s test our model to see how it performs. 
First, we’ll save the LoRA weights:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.save_lora(<span class="hljs-string">"grpo_saved_lora"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6mnmyk">Now, let’s test the model with a new question:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> vllm <span class="hljs-keyword">import</span> SamplingParams | |
| text = tokenizer.apply_chat_template( | |
| [ | |
| {<span class="hljs-string">"role"</span>: <span class="hljs-string">"system"</span>, <span class="hljs-string">"content"</span>: SYSTEM_PROMPT}, | |
| {<span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>, <span class="hljs-string">"content"</span>: <span class="hljs-string">"Calculate pi."</span>}, | |
| ], | |
| tokenize=<span class="hljs-literal">False</span>, | |
| add_generation_prompt=<span class="hljs-literal">True</span>, | |
| ) | |
| sampling_params = SamplingParams( | |
| temperature=<span class="hljs-number">0.8</span>, | |
| top_p=<span class="hljs-number">0.95</span>, | |
| max_tokens=<span class="hljs-number">1024</span>, | |
| ) | |
| output = ( | |
| model.fast_generate( | |
| text, | |
| sampling_params=sampling_params, | |
| lora_request=model.load_lora(<span class="hljs-string">"grpo_saved_lora"</span>), | |
| )[<span class="hljs-number">0</span>] | |
| .outputs[<span class="hljs-number">0</span>] | |
| .text | |
| ) | |
| <span class="hljs-built_in">print</span>(output)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-n6ify3">You should see that the model now follows the specified format, showing its reasoning before providing an answer.</p> <h2 class="relative group"><a id="saving-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#saving-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Saving the Model</span></h2> <p data-svelte-h="svelte-8mgtgx">Unsloth provides several options for saving your fine-tuned model, but we’ll focus on the most common.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Save to 16-bit precision</span> | |
| model.save_pretrained_merged(<span class="hljs-string">"model"</span>, tokenizer, save_method=<span class="hljs-string">"merged_16bit"</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="pushing-to-hugging-face-hub" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pushing-to-hugging-face-hub"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Pushing to Hugging Face Hub</span></h2> <p data-svelte-h="svelte-5e9stk">We’ll push the model to the Hugging Face Hub using the <code>push_to_hub_merged</code> method. 
This method allows us to push the model in multiple quantization formats.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Push to Hugging Face Hub (requires a token)</span> | |
| model.push_to_hub_merged( | |
| <span class="hljs-string">"your-username/model-name"</span>, tokenizer, save_method=<span class="hljs-string">"merged_16bit"</span>, token=<span class="hljs-string">"your-token"</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hwdwrb">Unsloth also supports saving to GGUF format for use with llama.cpp:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.push_to_hub_gguf( | |
| <span class="hljs-string">"your-username/model-name"</span>, | |
| tokenizer, | |
| quantization_method=[<span class="hljs-string">"q4_k_m"</span>, <span class="hljs-string">"q8_0"</span>, <span class="hljs-string">"q5_k_m"</span>], | |
| token=<span class="hljs-string">"your-token"</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-102ov0f">The GGUF files can be used with llama.cpp or UI-based systems like Jan or Open WebUI.</p> <h2 class="relative group"><a id="conclusion" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#conclusion"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Conclusion</span></h2> <p data-svelte-h="svelte-onfwxy">In this exercise, you’ve learned how to:</p> <ol data-svelte-h="svelte-j3tqlf"><li>Set up Unsloth for accelerated fine-tuning</li> <li>Prepare data for GRPO training</li> <li>Define custom reward functions to guide the model’s learning</li> <li>Train a model using GRPO</li> <li>Test the fine-tuned model</li> <li>Save the model in various formats</li></ol> <p data-svelte-h="svelte-ntlp2y">GRPO is a powerful technique for aligning language models with specific behaviors, and Unsloth makes it accessible even on limited hardware. 
By combining multiple reward functions, you can guide the model to follow a specific format while also improving its reasoning capabilities.</p> <p data-svelte-h="svelte-1r9yves">For more information and resources, check out:</p> <ul data-svelte-h="svelte-1fz95oy"><li><a href="https://docs.unsloth.ai/" rel="nofollow">Unsloth Documentation</a></li> <li><a href="https://discord.gg/unsloth" rel="nofollow">Unsloth Discord</a></li> <li><a href="https://github.com/unslothai/unsloth" rel="nofollow">Unsloth GitHub</a></li></ul> | |
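<p>Before launching a long training run, it can help to sanity-check a reward function on a few hand-written completions. The sketch below re-creates <code>count_xml</code> and <code>xmlcount_reward_func</code> from the reward section so that it runs on its own. It assumes the two penalty lines are nested under their <code>if</code> checks, and the sample completions are made up for illustration.</p>

```python
# Standalone sanity check for the XML-tag reward.
# Assumption: the penalty lines belong inside their `if` blocks.

def count_xml(text: str) -> float:
    """Add 0.125 for each tag that appears exactly once on its own line.

    Each of the last two checks also subtracts 0.001 per character found
    after the closing </answer> tag.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion; a completion is a list of chat messages."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


# Hypothetical completions: well-formed, well-formed with trailing text, no tags.
clean = "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"
trailing = clean + "extra"
bare = "4"

completions = [
    [{"role": "assistant", "content": text}] for text in (clean, trailing, bare)
]
scores = xmlcount_reward_func(completions)

for label, score in zip(("clean", "trailing", "bare"), scores):
    print(f"{label:>8}: {score:.3f}")
```

<p>The well-formed completion should receive the full tag reward, the one with trailing text slightly less, and the bare answer none. If the ordering differs, the reward function is worth debugging before it is passed to <code>GRPOTrainer</code>.</p>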
| <script> | |
| { | |
| __sveltekit_1y0degu = { | |
| assets: "/docs/course/pr_1069/en", | |
| base: "/docs/course/pr_1069/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js"), | |
| import("/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 34], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |