Ern You committed
Commit 131e609 · 1 Parent(s): 6159095

initial commit

Browse files:
- default_config.yaml  +25 -0
- train.ipynb  +742 -0
- train.py  +392 -0
- use_model.py  +88 -0
default_config.yaml
ADDED
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: train
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
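This config is consumed by the Accelerate launcher. A minimal usage sketch (assuming train.py defines the `train` entry point named in `main_training_function`, and that 8 GPUs are available to match `num_processes: 8`):

    accelerate launch --config_file default_config.yaml train.py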
train.ipynb
ADDED
# In[1]:
import os
os.environ["TORCH_USE_CUDA_DSA"] = "1"  # Enable CUDA device-side assertions for clearer CUDA error reports
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# In[ ]:
import sys
!{sys.executable} -m pip install datasets transformers accelerate==0.30.0 peft flash-attn

# In[12]:
import bitsandbytes as bab
print(bab.__version__)
# Output: 0.45.5

# In[1]:
import torch
print(torch.cuda.get_device_name(0))   # Output: NVIDIA L4
print(torch.cuda.is_bf16_supported())  # Output: True

# In[2]:
print(torch.cuda.get_device_capability())
# Output: (8, 9)

# In[2]:
import sys
!{sys.executable} -m pip install -U bitsandbytes
# Output: Requirement already satisfied: bitsandbytes (0.45.5) and its
# dependencies (numpy 1.26.4, torch 2.6.0, triton 3.2.0, sympy 1.13.1, the
# CUDA 12.4 runtime libraries, etc.), plus pip's usual warning about running
# as the root user.

# In[13]:
device_map = (
    int(os.environ.get("LOCAL_RANK", -1))
    if torch.distributed.is_available() and torch.distributed.is_initialized()
    else "auto"
)  # {"": 0}
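A note on how the expression above resolves: distributed launchers such as `accelerate launch` or `torchrun` export a per-worker `LOCAL_RANK` environment variable, so each initialized worker would pin the whole model to its own GPU. In a plain notebook session `torch.distributed` is never initialized, so the expression falls back to `"auto"` and transformers places layers automatically, as the next cell confirms.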
# In[14]:
print(device_map)
# Output: auto

# In[1]:
MODEL = "bigcode/starcoderbase-1b"  # Model checkpoint on the Hugging Face Hub
DATASET = "smangrul/hf-stack-v1"  # Dataset on the Hugging Face Hub
DATA_COLUMN = "content"  # Column name containing the code content

SEQ_LENGTH = 2048  # Sequence length

MAX_STEPS = 2000  # max_steps
BATCH_SIZE = 1  # batch_size
GR_ACC_STEPS = 1  # gradient_accumulation_steps
LR = 5e-4  # learning_rate
LR_SCHEDULER_TYPE = "cosine"  # lr_scheduler_type
WEIGHT_DECAY = 0.01  # weight_decay
NUM_WARMUP_STEPS = 30  # num_warmup_steps
EVAL_FREQ = 100  # eval_freq
SAVE_FREQ = 100  # save_freq
LOG_FREQ = 25  # log_freq
OUTPUT_DIR = "peft-starcoder-lora-a100"  # output_dir
BF16 = True  # bf16
FP16 = False  # no_fp16

# FIM transformation arguments
FIM_RATE = 0.5  # fim_rate
FIM_SPM_RATE = 0.5  # fim_spm_rate

# LoRA
LORA_R = 8  # lora_r
LORA_ALPHA = 32  # lora_alpha
LORA_DROPOUT = 0.0  # lora_dropout
LORA_TARGET_MODULES = "c_proj,c_attn,q_attn,c_fc,c_proj"  # lora_target_modules ("c_proj" is listed twice; the duplicate is harmless)

# bitsandbytes config
USE_NESTED_QUANT = True  # use_nested_quant
BNB_4BIT_COMPUTE_DTYPE = "bfloat16"  # bnb_4bit_compute_dtype

SEED = 0

# In[2]:
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
    BitsAndBytesConfig,
)

from datasets import load_dataset
import torch
from tqdm import tqdm

# Prepare data
dataset = load_dataset(
    DATASET,
    data_dir="data",
    split="train",
    streaming=True,
)

valid_data = dataset.take(4000)
train_data = dataset.skip(4000)
train_data = train_data.shuffle(buffer_size=5000, seed=SEED)

set_seed(SEED)

# In[5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)


def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        total_characters += len(example[data_column])
        total_tokens += len(tokenizer(example[data_column]).tokens())

    return total_characters / total_tokens


chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
# Output: The character to token ratio of the dataset is: 2.43
# (stderr also shows a transformers FutureWarning about
# `clean_up_tokenization_spaces` and the tqdm bar: 400/400 [00:03<00:00, 109.96it/s])

# In[6]:
import functools
import numpy as np


# Helper function to get the token ids of the special tokens for prefix, suffix, and middle, used in FIM transformations.
@functools.lru_cache(maxsize=None)
def get_fim_token_ids(tokenizer):
    try:
        FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map["additional_special_tokens"][1:5]
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
            tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
        )
    except KeyError:
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None
    return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id


## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
def permute(
    sample,
    np_rng,
    suffix_tok_id,
    prefix_tok_id,
    middle_tok_id,
    pad_tok_id,
    fim_rate=0.5,
    fim_spm_rate=0.5,
    truncate_or_pad=False,
):
    """
    Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate,
    using two FIM modes: PSM and SPM (with a probability of fim_spm_rate).
    """
    # The condition triggers with probability fim_rate, so FIM transformations
    # apply to a fim_rate fraction of samples.
    if np_rng.binomial(1, fim_rate):
        # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.
        boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))
        boundaries.sort()

        prefix = np.array(sample[: boundaries[0]], dtype=np.int64)
        middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)
        suffix = np.array(sample[boundaries[1] :], dtype=np.int64)

        if truncate_or_pad:
            # Calculate the new total length of the sample, accounting for the prefix, middle, and suffix marker tokens.
            new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3
            diff = new_length - len(sample)

            # Truncate or pad if the new length differs from the original.
            if diff > 0:
                if suffix.shape[0] <= diff:
                    return sample, np_rng
                suffix = suffix[: suffix.shape[0] - diff]
            elif diff < 0:
                suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])

        # With probability fim_spm_rate, apply the SPM variant of FIM transformations.
        # SPM: suffix, prefix, middle
        if np_rng.binomial(1, fim_spm_rate):
            new_sample = np.concatenate(
                [
                    [prefix_tok_id, suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    prefix,
                    middle,
                ]
            )
        # Otherwise, apply the PSM variant of FIM transformations.
        # PSM: prefix, suffix, middle
        else:
            new_sample = np.concatenate(
                [
                    [prefix_tok_id],
                    prefix,
                    [suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    middle,
                ]
            )
    else:
        # Don't apply FIM transformations.
        new_sample = sample

    return list(new_sample), np_rng
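A toy walk-through of `permute` (not part of the original notebook; the marker ids 100 to 103 are made-up stand-ins for the tokenizer's real FIM special tokens):

    rng = np.random.RandomState(0)
    toy = list(range(10))  # stand-in token ids 0..9
    out, rng = permute(
        toy,
        rng,
        suffix_tok_id=100,
        prefix_tok_id=101,
        middle_tok_id=102,
        pad_tok_id=103,
        fim_rate=1.0,      # force the transformation
        fim_spm_rate=0.0,  # force the PSM layout
    )
    # PSM layout: [101, *prefix, 100, *suffix, 102, *middle] - the model is
    # trained to produce the middle span given the prefix and suffix around it.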
# In[7]:
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
import random

# Create an iterable dataset that returns constant-length chunks of tokens from a stream of text files.


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True, the iterator is reset after the dataset reaches its end; otherwise it stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in the buffer.
        chars_per_token (float): Number of characters per token, used to estimate the number of tokens in the text buffer.
        fim_rate (float): Rate (0.0 to 1.0) at which a sample will be permuted with FIM.
        fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permutations that will use SPM.
        seed (int): Seed for the random number generator.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        content_field="content",
        fim_rate=0.5,
        fim_spm_rate=0.5,
        seed=0,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field
        self.fim_rate = fim_rate
        self.fim_spm_rate = fim_spm_rate
        self.seed = seed

        (
            self.suffix_tok_id,
            self.prefix_tok_id,
            self.middle_tok_id,
            self.pad_tok_id,
        ) = get_fim_token_ids(self.tokenizer)
        if not self.suffix_tok_id and self.fim_rate > 0:
            print("FIM is not supported by tokenizer, disabling FIM")
            self.fim_rate = 0

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        np_rng = np.random.RandomState(seed=self.seed)
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []

            for tokenized_input in tokenized_inputs:
                # Optionally do FIM permutations.
                if self.fim_rate > 0:
                    tokenized_input, np_rng = permute(
                        tokenized_input,
                        np_rng,
                        self.suffix_tok_id,
                        self.prefix_tok_id,
                        self.middle_tok_id,
                        self.pad_tok_id,
                        fim_rate=self.fim_rate,
                        fim_spm_rate=self.fim_spm_rate,
                        truncate_or_pad=False,
                    )

                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }


train_dataset = ConstantLengthDataset(
    tokenizer,
    train_data,
    infinite=True,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED,
)
eval_dataset = ConstantLengthDataset(
    tokenizer,
    valid_data,
    infinite=False,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED,
)
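A sanity check one could run at this point (not in the original notebook). Each buffer fill targets roughly `seq_length * chars_per_token * num_of_sequences ≈ 2048 * 2.43 * 1024 ≈ 5.1M` characters, so the first pull streams and tokenizes for a moment, and every yielded example should be exactly `SEQ_LENGTH` tokens:

    example = next(iter(train_dataset))
    print(example["input_ids"].shape)  # expected: torch.Size([2048]), i.e. SEQ_LENGTH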
# In[8]:
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer

load_in_8bit = False

# 4-bit quantization
compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)

bnb_config2 = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=USE_NESTED_QUANT,
)

device_map = {"": 0}

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    quantization_config=bnb_config,
    device_map=device_map,
    use_cache=False,  # We will be using gradient checkpointing
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
# Output: downloads model.safetensors (4.55 GB), then raises
# ImportError: FlashAttention2 has been toggled on, but it cannot be used due
# to the following error: the package flash_attn seems to be not installed.
# Please refer to the documentation of
# https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2
# to install Flash Attention 2.
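The failure above is only a missing package: the earlier install cell that pulls in `flash-attn` was evidently not run in this environment. One workaround, sketched here rather than taken from the notebook, is PyTorch's built-in scaled-dot-product attention, which transformers supports with no extra dependency:

    model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        quantization_config=bnb_config,
        device_map=device_map,
        use_cache=False,  # gradient checkpointing is enabled later
        trust_remote_code=True,
        attn_implementation="sdpa",  # fallback when flash_attn is unavailable
    )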
# In[ ]:
model = prepare_model_for_kbit_training(model)

# In[ ]:
# Set up LoRA
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=LORA_TARGET_MODULES.split(","),
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# In[ ]:
train_data.start_iteration = 0


training_args = TrainingArguments(
    output_dir=f"ernyou/{OUTPUT_DIR}",
    dataloader_drop_last=True,
    eval_strategy="steps",
    save_strategy="steps",
    max_steps=MAX_STEPS,
    eval_steps=EVAL_FREQ,
    save_steps=SAVE_FREQ,
    logging_steps=LOG_FREQ,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LR,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    warmup_steps=NUM_WARMUP_STEPS,
    gradient_accumulation_steps=GR_ACC_STEPS,
    gradient_checkpointing_kwargs={"use_reentrant": True},
    gradient_checkpointing=True,
    fp16=FP16,
    bf16=BF16,
    weight_decay=WEIGHT_DECAY,
    push_to_hub=True,
    include_tokens_per_second=True,
)

# In[10]:
import accelerate as ac
print(ac.__version__)
# Output: 1.3.0

# In[ ]:
import torch
torch.cuda.empty_cache()

# In[ ]:
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)

print("Training...")
trainer.train()
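Since `push_to_hub=True` is set, checkpoints are uploaded to the Hub repo during training. A sketch of the follow-up steps once training completes (not in the original notebook; it assumes a prior Hub `login()`, which is imported above but never called):

    trainer.push_to_hub()              # upload the final state to the Hub repo
    model.save_pretrained(OUTPUT_DIR)  # or save just the LoRA adapter locally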
# (trailing empty cell; notebook kernel: Python 3 (ipykernel), Python 3.10.12)
train.py
ADDED
| 1 |
+
MODEL = "bigcode/starcoderbase-1b" # Model checkpoint on the Hugging Face Hub
|
| 2 |
+
DATASET = "smangrul/hf-stack-v1" # Dataset on the Hugging Face Hub
|
| 3 |
+
DATA_COLUMN = "content" # Column name containing the code content
|
| 4 |
+
|
| 5 |
+
SEQ_LENGTH = 2048 # Sequence length
|
| 6 |
+
|
| 7 |
+
MAX_STEPS = 2000 # max_steps
|
| 8 |
+
BATCH_SIZE = 8 # batch_size
|
| 9 |
+
GR_ACC_STEPS = 1 # gradient_accumulation_steps
|
| 10 |
+
LR = 5e-4 # learning_rate
|
| 11 |
+
LR_SCHEDULER_TYPE = "cosine" # lr_scheduler_type
|
| 12 |
+
WEIGHT_DECAY = 0.01 # weight_decay
|
| 13 |
+
NUM_WARMUP_STEPS = 30 # num_warmup_steps
|
| 14 |
+
EVAL_FREQ = 100 # eval_freq
|
| 15 |
+
SAVE_FREQ = 100 # save_freq
|
| 16 |
+
LOG_FREQ = 25 # log_freq
|
| 17 |
+
OUTPUT_DIR = "peft-starcoder-lora-a100" # output_dir
|
| 18 |
+
BF16 = False # bf16
|
| 19 |
+
FP16 = False # no_fp16
|
| 20 |
+
|
| 21 |
+
# FIM trasformations arguments
|
| 22 |
+
FIM_RATE = 0.5 # fim_rate
|
| 23 |
+
FIM_SPM_RATE = 0.5 # fim_spm_rate
|
| 24 |
+
|
| 25 |
+
# LORA
|
| 26 |
+
LORA_R = 8 # lora_r
|
| 27 |
+
LORA_ALPHA = 32 # lora_alpha
|
| 28 |
+
LORA_DROPOUT = 0.0 # lora_dropout
|
| 29 |
+
LORA_TARGET_MODULES = "c_proj,c_attn,q_attn,c_fc,c_proj" # lora_target_modules
|
| 30 |
+
|
| 31 |
+
# bitsandbytes config
|
| 32 |
+
#USE_NESTED_QUANT = True # use_nested_quant
|
| 33 |
+
#BNB_4BIT_COMPUTE_DTYPE = "bfloat16" # bnb_4bit_compute_dtype
|
| 34 |
+
|
| 35 |
+
SEED = 0
|
| 36 |
+
|
| 37 |
+
from huggingface_hub import login
|
| 38 |
+
from transformers import (
|
| 39 |
+
AutoModelForCausalLM,
|
| 40 |
+
AutoTokenizer,
|
| 41 |
+
Trainer,
|
| 42 |
+
TrainingArguments,
|
| 43 |
+
logging,
|
| 44 |
+
set_seed,
|
| 45 |
+
BitsAndBytesConfig,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
from datasets import load_dataset
|
| 49 |
+
import torch
|
| 50 |
+
from tqdm import tqdm
|
| 51 |
+
|
| 52 |
+
#Prepare Data
|
| 53 |
+
dataset = load_dataset(
|
| 54 |
+
DATASET,
|
| 55 |
+
data_dir="data",
|
| 56 |
+
split="train",
|
| 57 |
+
streaming=True,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
valid_data = dataset.take(4000)
|
| 61 |
+
train_data = dataset.skip(4000)
|
| 62 |
+
train_data = train_data.shuffle(buffer_size=5000, seed=SEED)
|
| 63 |
+
|
| 64 |
+
set_seed(SEED)
|
| 65 |
+
|
| 66 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
|
| 70 |
+
"""
|
| 71 |
+
Estimate the average number of characters per token in the dataset.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
total_characters, total_tokens = 0, 0
|
| 75 |
+
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
|
| 76 |
+
total_characters += len(example[data_column])
|
| 77 |
+
total_tokens += len(tokenizer(example[data_column]).tokens())
|
| 78 |
+
|
| 79 |
+
return total_characters / total_tokens
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)
|
| 83 |
+
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
|
| 84 |
+
|
| 85 |
+
import functools
|
| 86 |
+
import numpy as np
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Helper function to get token ids of the special tokens for prefix, suffix and middle for FIM transformations.
|
| 90 |
+
@functools.lru_cache(maxsize=None)
|
| 91 |
+
def get_fim_token_ids(tokenizer):
|
| 92 |
+
try:
|
| 93 |
+
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map["additional_special_tokens"][1:5]
|
| 94 |
+
suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
|
| 95 |
+
tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
|
| 96 |
+
)
|
| 97 |
+
except KeyError:
|
| 98 |
+
suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None
|
| 99 |
+
return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
def permute(
    sample,
    np_rng,
    suffix_tok_id,
    prefix_tok_id,
    middle_tok_id,
    pad_tok_id,
    fim_rate=0.5,
    fim_spm_rate=0.5,
    truncate_or_pad=False,
):
    """
    Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:
    PSM and SPM (with a probability of fim_spm_rate).
    """

    # The if condition triggers with a probability of fim_rate,
    # i.e. FIM transformations are applied to a sample with probability fim_rate.
    if np_rng.binomial(1, fim_rate):
        # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.
        boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))
        boundaries.sort()

        prefix = np.array(sample[: boundaries[0]], dtype=np.int64)
        middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)
        suffix = np.array(sample[boundaries[1] :], dtype=np.int64)

        if truncate_or_pad:
            # Calculate the new total length of the sample, taking into account the tokens marking prefix, middle, and suffix.
            new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3
            diff = new_length - len(sample)

            # Truncate or pad if there's a difference between the new length and the original.
            if diff > 0:
                if suffix.shape[0] <= diff:
                    return sample, np_rng
                suffix = suffix[: suffix.shape[0] - diff]
            elif diff < 0:
                suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])

        # With a probability of fim_spm_rate, apply the SPM variant of the FIM transformation.
        # SPM: suffix, prefix, middle
        if np_rng.binomial(1, fim_spm_rate):
            new_sample = np.concatenate(
                [
                    [prefix_tok_id, suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    prefix,
                    middle,
                ]
            )
        # Otherwise, apply the PSM variant of the FIM transformation.
        # PSM: prefix, suffix, middle
        else:
            new_sample = np.concatenate(
                [
                    [prefix_tok_id],
                    prefix,
                    [suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    middle,
                ]
            )
    else:
        # Don't apply FIM transformations.
        new_sample = sample

    return list(new_sample), np_rng

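# Worked example (illustrative token ids): for sample = [t0, t1, t2, t3] and
# boundaries = [1, 3], permute() splits prefix=[t0], middle=[t1, t2], suffix=[t3] and emits
#   PSM: [prefix_tok_id, t0, suffix_tok_id, t3, middle_tok_id, t1, t2]
#   SPM: [prefix_tok_id, suffix_tok_id, t3, middle_tok_id, t0, t1, t2]
# so the model learns to generate the middle conditioned on both sides.
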
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
import random

# Create an iterable dataset that returns constant-length chunks of tokens from a stream of text files.


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from a stream of text files.
    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True, the iterator restarts from the beginning of the dataset when it is exhausted; otherwise iteration stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
        content_field (str): Name of the dataset column containing the text.
        fim_rate (float): Rate (0.0 to 1.0) at which a sample will be permuted with FIM.
        fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permutations that will use SPM.
        seed (int): Seed for random number generator.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        content_field="content",
        fim_rate=0.5,
        fim_spm_rate=0.5,
        seed=0,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field
        self.fim_rate = fim_rate
        self.fim_spm_rate = fim_spm_rate
        self.seed = seed

        (
            self.suffix_tok_id,
            self.prefix_tok_id,
            self.middle_tok_id,
            self.pad_tok_id,
        ) = get_fim_token_ids(self.tokenizer)
        if not self.suffix_tok_id and self.fim_rate > 0:
            print("FIM is not supported by tokenizer, disabling FIM")
            self.fim_rate = 0

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        np_rng = np.random.RandomState(seed=self.seed)
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []

            for tokenized_input in tokenized_inputs:
                # optionally do FIM permutations
                if self.fim_rate > 0:
                    tokenized_input, np_rng = permute(
                        tokenized_input,
                        np_rng,
                        self.suffix_tok_id,
                        self.prefix_tok_id,
                        self.middle_tok_id,
                        self.pad_tok_id,
                        fim_rate=self.fim_rate,
                        fim_spm_rate=self.fim_spm_rate,
                        truncate_or_pad=False,
                    )

                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }

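# Quick smoke test (a sketch; assumes train_data yields dicts with a DATA_COLUMN field):
#   sample = next(iter(ConstantLengthDataset(tokenizer, train_data, seq_length=SEQ_LENGTH,
#                                            content_field=DATA_COLUMN)))
#   assert sample["input_ids"].shape == (SEQ_LENGTH,)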
train_dataset = ConstantLengthDataset(
    tokenizer,
    train_data,
    infinite=True,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED,
)
eval_dataset = ConstantLengthDataset(
    tokenizer,
    valid_data,
    infinite=False,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED,
)

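# DataLoader is imported above but the Trainer builds its own loaders; for a manual
# sanity check you could do, e.g.:
#   loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
#   batch = next(iter(loader))  # batch["input_ids"].shape == (BATCH_SIZE, SEQ_LENGTH)
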
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer

#load_in_8bit = False

# 4-bit quantization
#compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)
#compute_float32 = torch.float32

#bnb_config = BitsAndBytesConfig(
#    load_in_4bit=True,
#    bnb_4bit_quant_type="nf4",
#    bnb_4bit_compute_dtype=compute_float32,
#    bnb_4bit_use_double_quant=USE_NESTED_QUANT,
#    bnb_4bit_quant_storage=compute_float32,
#)

#import os
#device_map = int(os.environ.get("LOCAL_RANK", -1))

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    #quantization_config=bnb_config,
    device_map=None,
    use_cache=False,  # We will be using gradient checkpointing
    trust_remote_code=True,
    torch_dtype=torch.float32,
)

#from collections import Counter
#print(Counter(p.dtype for p in model.parameters()))

#model = prepare_model_for_kbit_training(model)

#from collections import Counter
#print("after prepare_model_for_kbit_training ", Counter(p.dtype for p in model.parameters()))

peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=LORA_TARGET_MODULES.split(","),
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
#from collections import Counter
#print("after get_peft_model ", Counter(p.dtype for p in model.parameters()))

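# print_trainable_parameters() reports a line of the form
# "trainable params: ... || all params: ... || trainable%: ..."; with LoRA only the
# small adapter matrices are trainable while the base model stays frozen.
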
train_data.start_iteration = 0

training_args = TrainingArguments(
    output_dir=f"limernyou/{OUTPUT_DIR}",
    dataloader_drop_last=True,
    eval_strategy="steps",
    save_strategy="steps",
    max_steps=MAX_STEPS,
    eval_steps=EVAL_FREQ,
    save_steps=SAVE_FREQ,
    logging_steps=LOG_FREQ,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LR,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    warmup_steps=NUM_WARMUP_STEPS,
    gradient_accumulation_steps=GR_ACC_STEPS,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    gradient_checkpointing=True,
    fp16=FP16,
    bf16=BF16,
    weight_decay=WEIGHT_DECAY,
    push_to_hub=True,
    include_tokens_per_second=True,
)

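# Note: the effective global batch size is
# per_device_train_batch_size * gradient_accumulation_steps * number of processes.
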
#from trl import SFTConfig, SFTTrainer
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)

print("Training...")
trainer.train()
# Gather the FSDP-sharded parameters into a full state dict so that save_model()
# and push_to_hub() write a single, directly loadable checkpoint.
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model()
trainer.push_to_hub()
use_model.py
ADDED
@@ -0,0 +1,88 @@
from peft import PeftModel
from huggingface_hub import login
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
    BitsAndBytesConfig,
)

MODEL = "bigcode/starcoderbase-1b"  # Model checkpoint on the Hugging Face Hub

# Load the original model first.
print("Load Tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
print("Load Model")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    quantization_config=None,
    device_map=None,
    trust_remote_code=True,
    torch_dtype=torch.float32,
).cuda()

# Merge the fine-tuned weights with the base model.
peft_model_id = "limernyou/starcoder-peft-conti"
model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name="personal_copilot")
#model.add_weighted_adapter(["personal_copilot"], [0.8], "best_personal_copilot")
#model.set_adapter("best_personal_copilot")

# merge_and_unload() folds the adapter into the base weights and returns the merged model.
model = model.merge_and_unload()


#if not hasattr(model, "hf_device_map"):
#    model.cuda()

def get_code_completion(prefix, suffix):
    # FIM-style prompt: the model is asked to fill in the middle between prefix and suffix.
    prompt = f"""<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"""
    model.eval()
    outputs = model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        #attention_mask=tokenizer(prompt, return_tensors="pt").to("cuda")["attention_mask"],
        max_new_tokens=128,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.0,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

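# Note: with skip_special_tokens=True the decoded string is prefix + suffix + generated
# middle concatenated (the FIM markers are dropped). To isolate just the completion,
# decode with skip_special_tokens=False and split on "<fim_middle>".
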
def get_code_completion1(prefix, suffix):
    # Plain left-to-right completion: no FIM tokens, the model simply continues the prompt.
    prompt = prefix + suffix
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=128,
            temperature=0.2,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.0,
            pad_token_id=tokenizer.eos_token_id,
        )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The prompt contains no FIM markers, so this split is a no-op and returns the full text.
    completion = output_text.split("<fim_middle>")[-1].strip()
    return completion

prefix = """from peft import LoraConfig, TaskType, get_peft_model
|
| 79 |
+
from transformers import AutoModelForCausalLM
|
| 80 |
+
peft_config = LoraConfig(
|
| 81 |
+
"""
|
| 82 |
+
suffix = """"""
|
| 83 |
+
|
| 84 |
+
print("Starcoder generating response")
|
| 85 |
+
#print(tokenizer.special_tokens_map)
|
| 86 |
+
print(get_code_completion(prefix, suffix))
|
| 87 |
+
|
| 88 |
+
print("Successful")
|