{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d36e1e93-ae93-4a4e-93c6-68fd868d2882",
   "metadata": {},
   "source": [
    "# Using VB-LoRA for sequence classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddfc0610-55f6-4343-a950-125ccf0f45ac",
   "metadata": {},
   "source": [
    "In this example, we fine-tune RoBERTa on the GLUE MRPC sequence classification task using VB-LoRA.\n",
    "\n",
    "This notebook is adapted from `examples/sequence_classification/VeRA.ipynb`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45addd81-d4f3-4dfd-960d-3920d347f0a6",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9935ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from peft import (\n",
    "    get_peft_model,\n",
    "    VBLoRAConfig,\n",
    "    PeftType,\n",
    ")\n",
    "\n",
    "import evaluate\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62c959bf-7cc2-49e0-b97e-4c10ec3b9bf3",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b13308",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "model_name_or_path = \"roberta-large\"\n",
    "task = \"mrpc\"\n",
    "peft_type = PeftType.VBLORA\n",
    "# `torch.accelerator.current_accelerator()` returns None when no accelerator is available,\n",
    "# so fall back to CUDA if present, otherwise CPU, instead of failing on `None.type`.\n",
    "accel_device = torch.accelerator.current_accelerator() if hasattr(torch, \"accelerator\") else None\n",
    "if accel_device is not None:\n",
    "    device = accel_device.type\n",
    "else:\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "num_epochs = 20\n",
    "rank = 4\n",
    "max_length = 128\n",
    "num_vectors = 90\n",
    "vector_length = 256\n",
    "# Seed for reproducibility; the trailing ';' suppresses the repr of the returned Generator.\n",
    "torch.manual_seed(0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0526f571",
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_config = VBLoRAConfig(\n",
    "    task_type=\"SEQ_CLS\",\n",
    "    r=rank,\n",
    "    topk=2,\n",
    "    target_modules=['key', 'value', 'query', 'output.dense', 'intermediate.dense'],\n",
    "    num_vectors=num_vectors,\n",
    "    vector_length=vector_length,\n",
    "    save_only_topk_weights=True,  # Set to True to reduce storage space. Note that the saved parameters cannot be used to resume training from checkpoints.\n",
    "    vblora_dropout=0.0,\n",
    ")\n",
    "# Separate learning rates for the optimizer parameter groups: the remaining trainable params\n",
    "# (e.g. the classification head), the shared `vector_bank`, and the top-k selection `logits`.\n",
    "head_lr = 4e-3\n",
    "vector_bank_lr = 1e-3\n",
    "logits_lr = 1e-2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c075c5d2-a457-4f37-a7f1-94fd0d277972",
   "metadata": {},
   "source": [
    "## Loading data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7bb52cb4-d1c3-4b04-8bf0-f39ca88af139",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPT/OPT/BLOOM-style models are padded on the left; all others on the right.\n",
    "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
    "    padding_side = \"left\"\n",
    "else:\n",
    "    padding_side = \"right\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
    "# Some tokenizers define no pad token; reuse EOS so batches can still be padded.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e69c5e1f-d27b-4264-a41e-fc9b99d025e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = load_dataset(\"glue\", task)\n",
    "metric = evaluate.load(\"glue\", task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0209f778-c93b-40eb-a4e0-24c25db03980",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    # Truncate each sentence pair to at most `max_length` tokens; padding is applied later, per batch, in `collate_fn`.\n",
    "    outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=max_length)\n",
    "    return outputs\n",
    "\n",
    "\n",
    "tokenized_datasets = datasets.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
    ")\n",
    "\n",
    "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
    "# transformers library\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7453954e-982c-46f0-b09c-589776e6d6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(examples):\n",
    "    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
    "\n",
    "\n",
    "# Instantiate dataloaders.\n",
    "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3b9b2e8-f415-4d0f-9fb4-436f1a3585ea",
   "metadata": {},
   "source": [
    "## Preparing the VB-LoRA model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2ed5ac74",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1,696,770 || all params: 357,058,564 || trainable%: 0.4752\n",
      "VB-LoRA params to-be-saved (float32-equivalent): 33,408 || total params to-be-saved: 1,085,058\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, max_length=None)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model.print_savable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0d2d0381",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\n",
    "from transformers.trainer_pt_utils import get_parameter_names\n",
    "\n",
"decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)\n", "decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n", "vector_bank_parameters = [name for name, _ in model.named_parameters() if \"vector_bank\" in name]\n", "logits_parameters = [name for name, _ in model.named_parameters() if \"logits\" in name ]\n", "\n", "optimizer_grouped_parameters = [\n", " {\n", " \"params\": [p for n, p in model.named_parameters() if n in decay_parameters and \\\n", " n not in logits_parameters and n not in vector_bank_parameters],\n", " \"weight_decay\": 0.1,\n", " \"lr\": head_lr,\n", " },\n", " {\n", " \"params\": [p for n, p in model.named_parameters() if n not in decay_parameters and \\\n", " n not in logits_parameters and n not in vector_bank_parameters],\n", " \"weight_decay\": 0.0,\n", " \"lr\": head_lr,\n", " },\n", " {\n", " \"params\": [p for n, p in model.named_parameters() if n in vector_bank_parameters],\n", " \"lr\": vector_bank_lr,\n", " \"weight_decay\": 0.0,\n", " },\n", " {\n", " \"params\": [p for n, p in model.named_parameters() if n in logits_parameters],\n", " \"lr\": logits_lr,\n", " \"weight_decay\": 0.0,\n", " },\n", "]\n", "\n", "optimizer = AdamW(optimizer_grouped_parameters)\n", "lr_scheduler = get_linear_schedule_with_warmup(\n", " optimizer=optimizer,\n", " num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n", " num_training_steps=(len(train_dataloader) * num_epochs),\n", ")" ] }, { "cell_type": "markdown", "id": "c0dd5aa8-977b-4ac0-8b96-884b17bcdd00", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 10, "id": "fa0e73be", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/115 [00:00