{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "XIyP_0r6zuVc" }, "source": [ "# Training Large Language Models in 2bit with `aqlm`, `transformers` and `PEFT`\n", "\n", "Welcome to this notebook, which walks through the recent `aqlm` integration: a 2-bit quantization technique with minimal performance degradation.\n", "\n", "In this notebook, we will learn how to load a large model in 2bit (`Meta-Llama-3-8B-Instruct`) and train it on Google Colab using the PEFT library from Hugging Face 🤗.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "A_VgSpl4Dsr3" }, "source": [ "**Install the `aqlm` library**\n", "- It's the only extra dependency needed to run AQLM models.\n", "- Add `[gpu]` to install the required CUDA-specific dependencies.\n", "- Install the latest `accelerate` and `transformers` releases to properly support it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FuXIFTFapAMI" }, "outputs": [], "source": [ "%%capture\n", "# Quote version specifiers so the shell does not treat `>` as a redirect\n", "!pip install \"aqlm[gpu]>=1.1.0\"\n", "!pip install git+https://github.com/huggingface/peft.git@main\n", "!pip install \"accelerate>=0.27.0\"\n", "!pip install git+https://github.com/huggingface/transformers.git@main\n", "!pip install datasets\n", "# bitsandbytes is only needed for the 8-bit optimizer\n", "!pip install bitsandbytes" ] }, { "cell_type": "markdown", "metadata": { "id": "MJ-5idQwzvg-" }, "source": [ "First, let's load the model we are going to use: `Meta-Llama-3-8B-Instruct`, quantized to 2 bits with AQLM! Note that the original model is around 16GB in half precision, while the 2-bit checkpoint is much smaller." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "E0Nl5mWL0k2T" }, "outputs": [], "source": [ "import torch\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n", "\n", "model_id = \"ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = AutoModelForCausalLM.from_pretrained(model_id, device_map=\"auto\", torch_dtype=\"bfloat16\", low_cpu_mem_usage=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "Mp2gMi1ZzGET" }, "source": [ "**Add LoRA**\n", "\n", "To alter the model's behavior, we have to make it trainable. We can do that by adding a small set of trainable parameters on top of the frozen quantized ones." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ybeyl20n3dYH", "outputId": "0efda156-4886-4718-9877-e93a17dc02d2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 41,943,040 || all params: 2,084,114,432 || trainable%: 2.0125\n" ] } ], "source": [ "from peft import LoraConfig, get_peft_model\n", "\n", "config = LoraConfig(\n", "    r=16,\n", "    lora_alpha=32,\n", "    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj'],\n", "    lora_dropout=0.05,\n", "    bias=\"none\",\n", "    task_type=\"CAUSAL_LM\"\n", ")\n", "\n", "model = get_peft_model(model, config)\n", "model.print_trainable_parameters()\n", "model.enable_input_require_grads()  # needed for gradient checkpointing" ] }, { "cell_type": "markdown", "metadata": { "id": "4xSPH1D_Wv9x" }, "source": [ "Here we add a trainable adapter on top of every attention projection (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and MLP projection (`gate_proj`, `up_proj`, `down_proj`) linear layer." 
] }, { "cell_type": "markdown", "metadata": { "id": "FCc64bfnmd3j" }, "source": [ "**Loading a dataset**\n", "\n", "Let's load a common dataset, English quotes, to fine-tune our model on famous quotes." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "s6f4z8EYmcJ6" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ef07f1bc62e4887817a81d4a3e15da1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Resolving data files: 0%| | 0/114 [00:00 \n", " [ 1100/10000 8:32:29 < 69:14:04, 0.04 it/s, Epoch 0.35/4]\n", 
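| Step | Training Loss |
|------|---------------|
| 1 | 5.558500 |
| 25 | 4.310400 |
| 50 | 1.984600 |
| 75 | 1.548100 |
| 100 | 1.286000 |
| 125 | 1.133400 |
| 150 | 1.040200 |
| 175 | 0.977800 |
| 200 | 0.913900 |
| 225 | 0.909900 |
| 250 | 0.854600 |
| 275 | 0.851700 |
| 300 | 0.832200 |
| 325 | 0.810900 |
| 350 | 0.816500 |
| 375 | 0.796300 |
| 400 | 0.810300 |
| 425 | 0.767200 |
| 450 | 0.767100 |
| 475 | 0.772500 |
| 500 | 0.788000 |
| 525 | 0.741900 |
| 550 | 0.757600 |
| 575 | 0.732800 |
| 600 | 0.741600 |
| 625 | 0.749000 |
| 650 | 0.723700 |
| 675 | 0.735200 |
| 700 | 0.731500 |
| 725 | 0.711800 |
| 750 | 0.702200 |
| 775 | 0.714100 |
| 800 | 0.705400 |
| 825 | 0.711800 |
| 850 | 0.687200 |
| 875 | 0.708400 |
| 900 | 0.690700 |
| 925 | 0.697200 |
| 950 | 0.698000 |
| 975 | 0.681700 |
| 1000 | 0.685100 |
| 1025 | 0.684400 |
| 1050 | 0.683500 |
| 1075 | 0.698300 |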

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "No files have been modified since last commit. Skipping to prevent empty commit.\n", "No files have been modified since last commit. Skipping to prevent empty commit.\n", "No files have been modified since last commit. Skipping to prevent empty commit.\n", "No files have been modified since last commit. Skipping to prevent empty commit.\n", "No files have been modified since last commit. Skipping to prevent empty commit.\n", "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " return fn(*args, **kwargs)\n", "No files have been modified since last commit. Skipping to prevent empty commit.\n", "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n", " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n", "No files have been modified since last commit. Skipping to prevent empty commit.\n", "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. 
Refer to docs for more details on the differences between the two variants.\n", " return fn(*args, **kwargs)\n", "No files have been modified since last commit. Skipping to prevent empty commit.\n", "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n", " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n" ] } ], "source": [ "hub_model_id = \"davisrbr/math-lora\"\n", "tokenizer.pad_token = tokenizer.eos_token\n", "torch.cuda.empty_cache()\n", "trainer = transformers.Trainer(\n", "    model=model,\n", "    train_dataset=processed_dataset,\n", "    args=TrainingArguments(\n", "        per_device_train_batch_size=4,\n", "        gradient_accumulation_steps=8,\n", "        gradient_checkpointing=True,\n", "        warmup_steps=200,\n", "        max_steps=10000,\n", "        learning_rate=2e-4,\n", "        bf16=True,\n", "        logging_steps=25,\n", "        output_dir=\".\",\n", "        optim=\"adamw_bnb_8bit\",\n", "        logging_first_step=True,\n", "        push_to_hub=True,\n", "        hub_model_id=hub_model_id,\n", "    ),\n", "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n", ")\n", "model.config.use_cache = False\n", "\n", "push_frequency = 100\n", "trainer.add_callback(PushToHubCallback(trainer, push_frequency))\n", "\n", "trainer.train()\n", "\n", "final_commit_hash = trainer.push_to_hub(\"Training complete\")\n", "print(f\"Training complete. Final commit hash: {final_commit_hash}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "_0MOtwf3zdZp" }, "source": [ "Run the cell below to start training! For the sake of the demo, we run it for only a few steps to showcase how this integration works with existing tools in the HF ecosystem." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 481 }, "id": "jq0nX33BmfaC", "outputId": "7f470980-c49e-4230-b947-ad43510f1bee" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [10/10 13:02, Epoch 0/1]\n", "
\n", 
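| Step | Training Loss |
|------|---------------|
| 1 | 2.042200 |
| 2 | 1.293400 |
| 3 | 1.447500 |
| 4 | 1.433600 |
| 5 | 1.725900 |
| 6 | 1.506400 |
| 7 | 1.549600 |
| 8 | 1.038300 |
| 9 | 1.603300 |
| 10 | 1.676400 |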

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=10, training_loss=1.531658697128296, metrics={'train_runtime': 861.2678, 'train_samples_per_second': 0.046, 'train_steps_per_second': 0.012, 'total_flos': 56809829376000.0, 'train_loss': 1.531658697128296, 'epoch': 0.02})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import transformers\n", "\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "trainer = transformers.Trainer(\n", " model=model,\n", " train_dataset=data[\"train\"],\n", " args=transformers.TrainingArguments(\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=8,\n", " gradient_checkpointing=True,\n", " warmup_steps=2,\n", " max_steps=10,\n", " learning_rate=2e-4,\n", " fp16=True,\n", " logging_steps=1,\n", " output_dir=\"outputs\",\n", " optim=\"adamw_bnb_8bit\",\n", " logging_first_step=True,\n", " ),\n", " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n", ")\n", "model.config.use_cache = False # silence the warnings. Please re-enable for inference!\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "05iBmtP6X3Mq" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }