{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "6c066f8c-5b3b-486b-b958-76cc9d380146", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"TORCH_USE_CUDA_DSA\"] = \"1\" # Enable CUDA Dynamic Shared Allocation\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"" ] }, { "cell_type": "code", "execution_count": null, "id": "c20b0689-1347-47fb-a259-48ab1e1c1420", "metadata": {}, "outputs": [], "source": [ "import sys\n", "!{sys.executable} -m pip install datasets transformers accelerate==0.30.0 peft flash-attn" ] }, { "cell_type": "code", "execution_count": 12, "id": "a4b1142a-c51e-4794-94d7-805e70fb308d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.45.5\n" ] } ], "source": [ "import bitsandbytes as bab\n", "print(bab.__version__)" ] }, { "cell_type": "code", "execution_count": 1, "id": "b3900693-64f2-4f48-803a-196de8f616f7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NVIDIA L4\n", "True\n" ] } ], "source": [ "import torch\n", "print(torch.cuda.get_device_name(0))\n", "print(torch.cuda.is_bf16_supported())" ] }, { "cell_type": "code", "execution_count": 2, "id": "de679cdb-4fb6-4bd6-be66-6379c4131312", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(8, 9)\n" ] } ], "source": [ "print(torch.cuda.get_device_capability())" ] }, { "cell_type": "code", "execution_count": 2, "id": "b0b4c7df-aa9e-40ed-9827-e5436e33168c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.45.5)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (1.26.4)\n", "Requirement already satisfied: torch<3,>=2.0 in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (2.6.0)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (10.3.5.147)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (4.12.2)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (0.6.2)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (2024.9.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (1.13.1)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.4.2)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (11.2.1.3)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (9.1.0.70)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.2.0)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (11.6.1.9)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.1.5)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (2.21.5)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.17.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch<3,>=2.0->bitsandbytes) (1.3.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch<3,>=2.0->bitsandbytes) (3.0.2)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "import sys\n", "!{sys.executable} -m pip install -U bitsandbytes " ] }, { "cell_type": "code", "execution_count": 13, "id": "08edce1c-0326-4688-a557-5f80b33cb077", "metadata": {}, "outputs": [], "source": [ "device_map = (\n", " int(os.environ.get(\"LOCAL_RANK\", -1))\n", " if torch.distributed.is_available() and torch.distributed.is_initialized()\n", " else \"auto\"\n", ") # {\"\": 0}" ] }, { "cell_type": "code", "execution_count": 14, "id": "ca502d8f-a0e2-421c-b615-5bd232236fb2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "auto\n" ] } ], "source": [ "print(device_map)" ] }, { "cell_type": "code", "execution_count": 1, "id": "6d4c31bf-9087-40e2-8cf1-50d793a50cc4", "metadata": {}, "outputs": [], "source": [ "MODEL = \"bigcode/starcoderbase-1b\" # Model checkpoint on the Hugging Face Hub\n", "DATASET = \"smangrul/hf-stack-v1\" # Dataset on the Hugging Face Hub\n", "DATA_COLUMN = \"content\" # Column name containing the code content\n", "\n", "SEQ_LENGTH = 2048 # Sequence length\n", "\n", "MAX_STEPS = 2000 # max_steps\n", "BATCH_SIZE = 1 # batch_size\n", "GR_ACC_STEPS = 1 # gradient_accumulation_steps\n", "LR = 5e-4 # learning_rate\n", "LR_SCHEDULER_TYPE = \"cosine\" # lr_scheduler_type\n", "WEIGHT_DECAY = 0.01 # weight_decay\n", "NUM_WARMUP_STEPS = 30 # num_warmup_steps\n", "EVAL_FREQ = 100 # eval_freq\n", "SAVE_FREQ = 100 # save_freq\n", "LOG_FREQ = 25 # log_freq\n", "OUTPUT_DIR = \"peft-starcoder-lora-a100\" # output_dir\n", "BF16 = True # bf16\n", "FP16 = False # no_fp16\n", "\n", "# FIM trasformations arguments\n", "FIM_RATE = 0.5 # fim_rate\n", "FIM_SPM_RATE = 0.5 # fim_spm_rate\n", "\n", "# LORA\n", "LORA_R = 8 # lora_r\n", "LORA_ALPHA = 32 # lora_alpha\n", "LORA_DROPOUT = 0.0 # lora_dropout\n", "LORA_TARGET_MODULES = \"c_proj,c_attn,q_attn,c_fc,c_proj\" # lora_target_modules\n", "\n", "# bitsandbytes config\n", "USE_NESTED_QUANT = True # use_nested_quant\n", "BNB_4BIT_COMPUTE_DTYPE = \"bfloat16\" # bnb_4bit_compute_dtype\n", "\n", "SEED = 0" ] }, { "cell_type": "code", "execution_count": 2, "id": "2e60f3f7-c90f-41ec-91d3-98b6532e9446", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import login\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " Trainer,\n", " TrainingArguments,\n", " logging,\n", " set_seed,\n", " BitsAndBytesConfig,\n", ")\n", "\n", "from datasets import load_dataset\n", "import torch\n", "from tqdm import tqdm\n", "\n", "#Prepare Data\n", "dataset = load_dataset(\n", " DATASET,\n", " data_dir=\"data\",\n", " split=\"train\",\n", " streaming=True,\n", ")\n", "\n", "valid_data = dataset.take(4000)\n", "train_data = dataset.skip(4000)\n", "train_data = train_data.shuffle(buffer_size=5000, seed=SEED)\n", "\n", "set_seed(SEED)" ] }, { "cell_type": "code", "execution_count": 5, "id": "88201294-50c4-44b0-9209-1873feba0dae", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n", "100%|██████████| 400/400 [00:03<00:00, 109.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The character to token ratio of the dataset is: 2.43\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)\n", "\n", "\n", "def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):\n", " \"\"\"\n", " Estimate the average number of characters per token in the dataset.\n", " \"\"\"\n", "\n", " total_characters, total_tokens = 0, 0\n", " for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):\n", " total_characters += len(example[data_column])\n", " total_tokens += len(tokenizer(example[data_column]).tokens())\n", "\n", " return total_characters / total_tokens\n", "\n", "\n", "chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)\n", "print(f\"The character to token ratio of the dataset is: {chars_per_token:.2f}\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "eee2c81e-f0df-435e-b4b9-e2a1d4f8c853", "metadata": {}, "outputs": [], "source": [ "import functools\n", "import numpy as np\n", "\n", "\n", "# Helper function to get token ids of the special tokens for prefix, suffix and middle for FIM transformations.\n", "@functools.lru_cache(maxsize=None)\n", "def get_fim_token_ids(tokenizer):\n", " try:\n", " FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map[\"additional_special_tokens\"][1:5]\n", " suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (\n", " tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]\n", " )\n", " except KeyError:\n", " suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None\n", " return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id\n", "\n", "\n", "## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py\n", "def permute(\n", " sample,\n", " np_rng,\n", " suffix_tok_id,\n", " prefix_tok_id,\n", " middle_tok_id,\n", " pad_tok_id,\n", " fim_rate=0.5,\n", " fim_spm_rate=0.5,\n", " truncate_or_pad=False,\n", "):\n", " \"\"\"\n", " Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:\n", " PSM and SPM (with a probability of fim_spm_rate).\n", " \"\"\"\n", "\n", " # The if condition will trigger with the probability of fim_rate\n", " # This means FIM transformations will apply to samples with a probability of fim_rate\n", " if np_rng.binomial(1, fim_rate):\n", "\n", " # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.\n", " boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))\n", " boundaries.sort()\n", "\n", " prefix = np.array(sample[: boundaries[0]], dtype=np.int64)\n", " middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)\n", " suffix = np.array(sample[boundaries[1] :], dtype=np.int64)\n", "\n", " if truncate_or_pad:\n", " # calculate the new total length of the sample, taking into account tokens indicating prefix, middle, and suffix\n", " new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3\n", " diff = new_length - len(sample)\n", "\n", " # trancate or pad if there's a difference in length between the new length and the original\n", " if diff > 0:\n", " if suffix.shape[0] <= diff:\n", " return sample, np_rng\n", " suffix = suffix[: suffix.shape[0] - diff]\n", " elif diff < 0:\n", " suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])\n", "\n", " # With the probability of fim_spm_rateapply SPM variant of FIM transformations\n", " # SPM: suffix, prefix, middle\n", " if np_rng.binomial(1, fim_spm_rate):\n", " new_sample = np.concatenate(\n", " [\n", " [prefix_tok_id, suffix_tok_id],\n", " suffix,\n", " [middle_tok_id],\n", " prefix,\n", " middle,\n", " ]\n", " )\n", " # Otherwise, apply the PSM variant of FIM transformations\n", " # PSM: prefix, suffix, middle\n", " else:\n", "\n", " new_sample = np.concatenate(\n", " [\n", " [prefix_tok_id],\n", " prefix,\n", " [suffix_tok_id],\n", " suffix,\n", " [middle_tok_id],\n", " middle,\n", " ]\n", " )\n", " else:\n", " # don't apply FIM transformations\n", " new_sample = sample\n", "\n", " return list(new_sample), np_rng" ] }, { "cell_type": "code", "execution_count": 7, "id": "3a9ebdef-4178-44af-9b35-12c8189c27f7", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import IterableDataset\n", "from torch.utils.data.dataloader import DataLoader\n", "import random\n", "\n", "# Create an Iterable dataset that returns constant-length chunks of tokens from a stream of text files.\n", "\n", "\n", "class ConstantLengthDataset(IterableDataset):\n", " \"\"\"\n", " Iterable dataset that returns constant length chunks of tokens from stream of text files.\n", " Args:\n", " tokenizer (Tokenizer): The processor used for proccessing the data.\n", " dataset (dataset.Dataset): Dataset with text files.\n", " infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n", " seq_length (int): Length of token sequences to return.\n", " num_of_sequences (int): Number of token sequences to keep in buffer.\n", " chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.\n", " fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.\n", " fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM.\n", " seed (int): Seed for random number generator.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " tokenizer,\n", " dataset,\n", " infinite=False,\n", " seq_length=1024,\n", " num_of_sequences=1024,\n", " chars_per_token=3.6,\n", " content_field=\"content\",\n", " fim_rate=0.5,\n", " fim_spm_rate=0.5,\n", " seed=0,\n", " ):\n", " self.tokenizer = tokenizer\n", " self.concat_token_id = tokenizer.eos_token_id\n", " self.dataset = dataset\n", " self.seq_length = seq_length\n", " self.infinite = infinite\n", " self.current_size = 0\n", " self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n", " self.content_field = content_field\n", " self.fim_rate = fim_rate\n", " self.fim_spm_rate = fim_spm_rate\n", " self.seed = seed\n", "\n", " (\n", " self.suffix_tok_id,\n", " self.prefix_tok_id,\n", " self.middle_tok_id,\n", " self.pad_tok_id,\n", " ) = get_fim_token_ids(self.tokenizer)\n", " if not self.suffix_tok_id and self.fim_rate > 0:\n", " print(\"FIM is not supported by tokenizer, disabling FIM\")\n", " self.fim_rate = 0\n", "\n", " def __iter__(self):\n", " iterator = iter(self.dataset)\n", " more_examples = True\n", " np_rng = np.random.RandomState(seed=self.seed)\n", " while more_examples:\n", " buffer, buffer_len = [], 0\n", " while True:\n", " if buffer_len >= self.max_buffer_size:\n", " break\n", " try:\n", " buffer.append(next(iterator)[self.content_field])\n", " buffer_len += len(buffer[-1])\n", " except StopIteration:\n", " if self.infinite:\n", " iterator = iter(self.dataset)\n", " else:\n", " more_examples = False\n", " break\n", " tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n", " all_token_ids = []\n", "\n", " for tokenized_input in tokenized_inputs:\n", " # optionally do FIM permutations\n", " if self.fim_rate > 0:\n", " tokenized_input, np_rng = permute(\n", " tokenized_input,\n", " np_rng,\n", " self.suffix_tok_id,\n", " self.prefix_tok_id,\n", " self.middle_tok_id,\n", " self.pad_tok_id,\n", " fim_rate=self.fim_rate,\n", " fim_spm_rate=self.fim_spm_rate,\n", " truncate_or_pad=False,\n", " )\n", "\n", " all_token_ids.extend(tokenized_input + [self.concat_token_id])\n", " examples = []\n", " for i in range(0, len(all_token_ids), self.seq_length):\n", " input_ids = all_token_ids[i : i + self.seq_length]\n", " if len(input_ids) == self.seq_length:\n", " examples.append(input_ids)\n", " random.shuffle(examples)\n", " for example in examples:\n", " self.current_size += 1\n", " yield {\n", " \"input_ids\": torch.LongTensor(example),\n", " \"labels\": torch.LongTensor(example),\n", " }\n", "\n", "\n", "train_dataset = ConstantLengthDataset(\n", " tokenizer,\n", " train_data,\n", " infinite=True,\n", " seq_length=SEQ_LENGTH,\n", " chars_per_token=chars_per_token,\n", " content_field=DATA_COLUMN,\n", " fim_rate=FIM_RATE,\n", " fim_spm_rate=FIM_SPM_RATE,\n", " seed=SEED,\n", ")\n", "eval_dataset = ConstantLengthDataset(\n", " tokenizer,\n", " valid_data,\n", " infinite=False,\n", " seq_length=SEQ_LENGTH,\n", " chars_per_token=chars_per_token,\n", " content_field=DATA_COLUMN,\n", " fim_rate=FIM_RATE,\n", " fim_spm_rate=FIM_SPM_RATE,\n", " seed=SEED,\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "id": "5021e686-2e1c-4608-9477-1f07adf2de35", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e4ea2b63e63448938020858c06143ca5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/4.55G [00:00 26\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mAutoModelForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mMODEL\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[43mquantization_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbnb_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# We will be using gradient checkpointing\u001b[39;49;00m\n\u001b[1;32m 31\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 32\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_implementation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mflash_attention_2\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 33\u001b[0m \u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py:564\u001b[0m, in \u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 562\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(config) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 563\u001b[0m model_class \u001b[38;5;241m=\u001b[39m _get_model_class(config, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping)\n\u001b[0;32m--> 564\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 565\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 566\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 568\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnrecognized configuration class \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for this kind of AutoModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 569\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel type should be one of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(c\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mc\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys())\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 570\u001b[0m )\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:3804\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3801\u001b[0m init_contexts\u001b[38;5;241m.\u001b[39mappend(init_empty_weights())\n\u001b[1;32m 3803\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config) \u001b[38;5;66;03m# We do not want to modify the config inplace in from_pretrained.\u001b[39;00m\n\u001b[0;32m-> 3804\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_autoset_attn_implementation\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3805\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\n\u001b[1;32m 3806\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3808\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ContextManagers(init_contexts):\n\u001b[1;32m 3809\u001b[0m \u001b[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001b[39;00m\n\u001b[1;32m 3810\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(config, \u001b[38;5;241m*\u001b[39mmodel_args, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:1534\u001b[0m, in \u001b[0;36mPreTrainedModel._autoset_attn_implementation\u001b[0;34m(cls, config, use_flash_attention_2, torch_dtype, device_map, check_device_map)\u001b[0m\n\u001b[1;32m 1531\u001b[0m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1533\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 1534\u001b[0m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_and_enable_flash_attn_2\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1535\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1536\u001b[0m \u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1537\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1538\u001b[0m \u001b[43m \u001b[49m\u001b[43mhard_check_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1539\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheck_device_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msdpa\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available():\n\u001b[1;32m 1542\u001b[0m \u001b[38;5;66;03m# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.\u001b[39;00m\n\u001b[1;32m 1543\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_check_and_enable_sdpa(\n\u001b[1;32m 1544\u001b[0m config,\n\u001b[1;32m 1545\u001b[0m hard_check_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 1546\u001b[0m )\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:1636\u001b[0m, in \u001b[0;36mPreTrainedModel._check_and_enable_flash_attn_2\u001b[0;34m(cls, config, torch_dtype, device_map, check_device_map, hard_check_only)\u001b[0m\n\u001b[1;32m 1633\u001b[0m install_message \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1635\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mutil\u001b[38;5;241m.\u001b[39mfind_spec(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1636\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpreface\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m the package flash_attn seems to be not installed. \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minstall_message\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1638\u001b[0m flash_attention_version \u001b[38;5;241m=\u001b[39m version\u001b[38;5;241m.\u001b[39mparse(importlib\u001b[38;5;241m.\u001b[39mmetadata\u001b[38;5;241m.\u001b[39mversion(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 1639\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39mcuda:\n", "\u001b[0;31mImportError\u001b[0m: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." ] } ], "source": [ "import torch\n", "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n", "from peft.tuners.lora import LoraLayer\n", "\n", "load_in_8bit = False\n", "\n", "# 4-bit quantization\n", "compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)\n", "\n", "bnb_config2 = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_use_double_quant=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=torch.bfloat16\n", ")\n", "\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=compute_dtype,\n", " bnb_4bit_use_double_quant=USE_NESTED_QUANT,\n", ")\n", "\n", "device_map = {\"\": 0}\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " MODEL,\n", " quantization_config=bnb_config,\n", " device_map=device_map,\n", " use_cache=False, # We will be using gradient checkpointing\n", " trust_remote_code=True,\n", " attn_implementation=\"flash_attention_2\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "7c0a9fdb-4087-4ec5-aefa-fc5f413252e4", "metadata": {}, "outputs": [], "source": [ "model = prepare_model_for_kbit_training(model)" ] }, { "cell_type": "code", "execution_count": null, "id": "b5fd3293-76ef-48eb-8241-478e311ec947", "metadata": {}, "outputs": [], "source": [ "# Set up lora\n", "peft_config = LoraConfig(\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT,\n", " r=LORA_R,\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " target_modules=LORA_TARGET_MODULES.split(\",\"),\n", ")\n", "\n", "model = get_peft_model(model, peft_config)\n", "model.print_trainable_parameters()" ] }, { "cell_type": "code", "execution_count": null, "id": "082c6a7b-db61-4800-8a94-419331b1fd22", "metadata": {}, "outputs": [], "source": [ "train_data.start_iteration = 0\n", "\n", "\n", "training_args = TrainingArguments(\n", " output_dir=f\"ernyou/{OUTPUT_DIR}\",\n", " dataloader_drop_last=True,\n", " eval_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", " max_steps=MAX_STEPS,\n", " eval_steps=EVAL_FREQ,\n", " save_steps=SAVE_FREQ,\n", " logging_steps=LOG_FREQ,\n", " per_device_train_batch_size=BATCH_SIZE,\n", " per_device_eval_batch_size=BATCH_SIZE,\n", " learning_rate=LR,\n", " lr_scheduler_type=LR_SCHEDULER_TYPE,\n", " warmup_steps=NUM_WARMUP_STEPS,\n", " gradient_accumulation_steps=GR_ACC_STEPS,\n", " gradient_checkpointing_kwargs={\"use_reentrant\": True},\n", " gradient_checkpointing=True,\n", " fp16=FP16,\n", " bf16=BF16,\n", " weight_decay=WEIGHT_DECAY,\n", " push_to_hub=True,\n", " include_tokens_per_second=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "id": "2c302ded-f017-433c-9622-55ecb45141bd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.3.0\n" ] } ], "source": [ "import accelerate as ac\n", "print(ac.__version__)" ] }, { "cell_type": "code", "execution_count": null, "id": "5318efa4-83da-41fa-9123-50b505e9a615", "metadata": {}, "outputs": [], "source": [ "import torch\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": null, "id": "ceddceb0-8e1e-493f-99a1-b77c6b0c40b6", "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)\n", "\n", "print(\"Training...\")\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "76fb1530-5e16-4b4b-a30f-b689df3483f3", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }