Ern You committed on
Commit 131e609 · 1 Parent(s): 6159095

initial commit

Files changed (4)
  1. default_config.yaml +25 -0
  2. train.ipynb +742 -0
  3. train.py +392 -0
  4. use_model.py +88 -0
default_config.yaml ADDED
@@ -0,0 +1,25 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: FSDP
4
+ downcast_bf16: 'no'
5
+ fsdp_config:
6
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
7
+ fsdp_backward_prefetch: BACKWARD_PRE
8
+ fsdp_cpu_ram_efficient_loading: true
9
+ fsdp_forward_prefetch: false
10
+ fsdp_offload_params: true
11
+ fsdp_sharding_strategy: FULL_SHARD
12
+ fsdp_state_dict_type: SHARDED_STATE_DICT
13
+ fsdp_sync_module_states: true
14
+ fsdp_use_orig_params: true
15
+ machine_rank: 0
16
+ main_training_function: train
17
+ mixed_precision: 'no'
18
+ num_machines: 1
19
+ num_processes: 8
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
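This config is the Accelerate launcher file for the training script: with it in place, training starts with `accelerate launch --config_file default_config.yaml train.py`. A minimal pre-launch sanity check is sketched below; it assumes PyYAML is installed, and the field names simply mirror the YAML above:

```python
# Sketch: sanity-check default_config.yaml before an FSDP launch.
# Assumes PyYAML is installed; field names mirror the YAML above.
import yaml
import torch

with open("default_config.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["distributed_type"] == "FSDP"
assert cfg["fsdp_config"]["fsdp_sharding_strategy"] == "FULL_SHARD"

# num_processes (8 here) should not exceed the GPUs visible on this machine.
visible = torch.cuda.device_count()
if cfg["num_processes"] > visible:
    print(f"Config requests {cfg['num_processes']} processes but only "
          f"{visible} GPU(s) are visible; lower num_processes.")
```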
train.ipynb ADDED
@@ -0,0 +1,742 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "6c066f8c-5b3b-486b-b958-76cc9d380146",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "os.environ[\"TORCH_USE_CUDA_DSA\"] = \"1\" # Enable CUDA device-side assertions for clearer error reports\n",
12
+ "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\""
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "c20b0689-1347-47fb-a259-48ab1e1c1420",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "import sys\n",
23
+ "!{sys.executable} -m pip install datasets transformers accelerate==0.30.0 peft flash-attn"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 12,
29
+ "id": "a4b1142a-c51e-4794-94d7-805e70fb308d",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stdout",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "0.45.5\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "import bitsandbytes as bab\n",
42
+ "print(bab.__version__)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 1,
48
+ "id": "b3900693-64f2-4f48-803a-196de8f616f7",
49
+ "metadata": {},
50
+ "outputs": [
51
+ {
52
+ "name": "stdout",
53
+ "output_type": "stream",
54
+ "text": [
55
+ "NVIDIA L4\n",
56
+ "True\n"
57
+ ]
58
+ }
59
+ ],
60
+ "source": [
61
+ "import torch\n",
62
+ "print(torch.cuda.get_device_name(0))\n",
63
+ "print(torch.cuda.is_bf16_supported())"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 2,
69
+ "id": "de679cdb-4fb6-4bd6-be66-6379c4131312",
70
+ "metadata": {},
71
+ "outputs": [
72
+ {
73
+ "name": "stdout",
74
+ "output_type": "stream",
75
+ "text": [
76
+ "(8, 9)\n"
77
+ ]
78
+ }
79
+ ],
80
+ "source": [
81
+ "print(torch.cuda.get_device_capability())"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 2,
87
+ "id": "b0b4c7df-aa9e-40ed-9827-e5436e33168c",
88
+ "metadata": {},
89
+ "outputs": [
90
+ {
91
+ "name": "stdout",
92
+ "output_type": "stream",
93
+ "text": [
94
+ "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.45.5)\n",
95
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (1.26.4)\n",
96
+ "Requirement already satisfied: torch<3,>=2.0 in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (2.6.0)\n",
97
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (10.3.5.147)\n",
98
+ "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (4.12.2)\n",
99
+ "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (0.6.2)\n",
100
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n",
101
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (2024.9.0)\n",
102
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (1.13.1)\n",
103
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.4.2)\n",
104
+ "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.5.8)\n",
105
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n",
106
+ "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (11.2.1.3)\n",
107
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.3.1.170)\n",
108
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (9.1.0.70)\n",
109
+ "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.2.0)\n",
110
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (11.6.1.9)\n",
111
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.1.5)\n",
112
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n",
113
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (2.21.5)\n",
114
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (3.17.0)\n",
115
+ "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n",
116
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/dist-packages (from torch<3,>=2.0->bitsandbytes) (12.4.127)\n",
117
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch<3,>=2.0->bitsandbytes) (1.3.0)\n",
118
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch<3,>=2.0->bitsandbytes) (3.0.2)\n",
119
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
120
+ "\u001b[0m"
121
+ ]
122
+ }
123
+ ],
124
+ "source": [
125
+ "import sys\n",
126
+ "!{sys.executable} -m pip install -U bitsandbytes "
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 13,
132
+ "id": "08edce1c-0326-4688-a557-5f80b33cb077",
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "device_map = (\n",
137
+ " int(os.environ.get(\"LOCAL_RANK\", -1))\n",
138
+ " if torch.distributed.is_available() and torch.distributed.is_initialized()\n",
139
+ " else \"auto\"\n",
140
+ ") # {\"\": 0}"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 14,
146
+ "id": "ca502d8f-a0e2-421c-b615-5bd232236fb2",
147
+ "metadata": {},
148
+ "outputs": [
149
+ {
150
+ "name": "stdout",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "auto\n"
154
+ ]
155
+ }
156
+ ],
157
+ "source": [
158
+ "print(device_map)"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 1,
164
+ "id": "6d4c31bf-9087-40e2-8cf1-50d793a50cc4",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "MODEL = \"bigcode/starcoderbase-1b\" # Model checkpoint on the Hugging Face Hub\n",
169
+ "DATASET = \"smangrul/hf-stack-v1\" # Dataset on the Hugging Face Hub\n",
170
+ "DATA_COLUMN = \"content\" # Column name containing the code content\n",
171
+ "\n",
172
+ "SEQ_LENGTH = 2048 # Sequence length\n",
173
+ "\n",
174
+ "MAX_STEPS = 2000 # max_steps\n",
175
+ "BATCH_SIZE = 1 # batch_size\n",
176
+ "GR_ACC_STEPS = 1 # gradient_accumulation_steps\n",
177
+ "LR = 5e-4 # learning_rate\n",
178
+ "LR_SCHEDULER_TYPE = \"cosine\" # lr_scheduler_type\n",
179
+ "WEIGHT_DECAY = 0.01 # weight_decay\n",
180
+ "NUM_WARMUP_STEPS = 30 # num_warmup_steps\n",
181
+ "EVAL_FREQ = 100 # eval_freq\n",
182
+ "SAVE_FREQ = 100 # save_freq\n",
183
+ "LOG_FREQ = 25 # log_freq\n",
184
+ "OUTPUT_DIR = \"peft-starcoder-lora-a100\" # output_dir\n",
185
+ "BF16 = True # bf16\n",
186
+ "FP16 = False # no_fp16\n",
187
+ "\n",
188
+ "# FIM transformation arguments\n",
189
+ "FIM_RATE = 0.5 # fim_rate\n",
190
+ "FIM_SPM_RATE = 0.5 # fim_spm_rate\n",
191
+ "\n",
192
+ "# LORA\n",
193
+ "LORA_R = 8 # lora_r\n",
194
+ "LORA_ALPHA = 32 # lora_alpha\n",
195
+ "LORA_DROPOUT = 0.0 # lora_dropout\n",
196
+ "LORA_TARGET_MODULES = \"c_proj,c_attn,q_attn,c_fc\" # lora_target_modules (c_proj was listed twice)\n",
197
+ "\n",
198
+ "# bitsandbytes config\n",
199
+ "USE_NESTED_QUANT = True # use_nested_quant\n",
200
+ "BNB_4BIT_COMPUTE_DTYPE = \"bfloat16\" # bnb_4bit_compute_dtype\n",
201
+ "\n",
202
+ "SEED = 0"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 2,
208
+ "id": "2e60f3f7-c90f-41ec-91d3-98b6532e9446",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "from huggingface_hub import login\n",
213
+ "from transformers import (\n",
214
+ " AutoModelForCausalLM,\n",
215
+ " AutoTokenizer,\n",
216
+ " Trainer,\n",
217
+ " TrainingArguments,\n",
218
+ " logging,\n",
219
+ " set_seed,\n",
220
+ " BitsAndBytesConfig,\n",
221
+ ")\n",
222
+ "\n",
223
+ "from datasets import load_dataset\n",
224
+ "import torch\n",
225
+ "from tqdm import tqdm\n",
226
+ "\n",
227
+ "#Prepare Data\n",
228
+ "dataset = load_dataset(\n",
229
+ " DATASET,\n",
230
+ " data_dir=\"data\",\n",
231
+ " split=\"train\",\n",
232
+ " streaming=True,\n",
233
+ ")\n",
234
+ "\n",
235
+ "valid_data = dataset.take(4000)\n",
236
+ "train_data = dataset.skip(4000)\n",
237
+ "train_data = train_data.shuffle(buffer_size=5000, seed=SEED)\n",
238
+ "\n",
239
+ "set_seed(SEED)"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 5,
245
+ "id": "88201294-50c4-44b0-9209-1873feba0dae",
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "name": "stderr",
250
+ "output_type": "stream",
251
+ "text": [
252
+ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
253
+ " warnings.warn(\n",
254
+ "100%|██████████| 400/400 [00:03<00:00, 109.96it/s]"
255
+ ]
256
+ },
257
+ {
258
+ "name": "stdout",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "The character to token ratio of the dataset is: 2.43\n"
262
+ ]
263
+ },
264
+ {
265
+ "name": "stderr",
266
+ "output_type": "stream",
267
+ "text": [
268
+ "\n"
269
+ ]
270
+ }
271
+ ],
272
+ "source": [
273
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)\n",
274
+ "\n",
275
+ "\n",
276
+ "def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):\n",
277
+ " \"\"\"\n",
278
+ " Estimate the average number of characters per token in the dataset.\n",
279
+ " \"\"\"\n",
280
+ "\n",
281
+ " total_characters, total_tokens = 0, 0\n",
282
+ " for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):\n",
283
+ " total_characters += len(example[data_column])\n",
284
+ " total_tokens += len(tokenizer(example[data_column]).tokens())\n",
285
+ "\n",
286
+ " return total_characters / total_tokens\n",
287
+ "\n",
288
+ "\n",
289
+ "chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)\n",
290
+ "print(f\"The character to token ratio of the dataset is: {chars_per_token:.2f}\")"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": 6,
296
+ "id": "eee2c81e-f0df-435e-b4b9-e2a1d4f8c853",
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": [
300
+ "import functools\n",
301
+ "import numpy as np\n",
302
+ "\n",
303
+ "\n",
304
+ "# Helper function to get token ids of the special tokens for prefix, suffix and middle for FIM transformations.\n",
305
+ "@functools.lru_cache(maxsize=None)\n",
306
+ "def get_fim_token_ids(tokenizer):\n",
307
+ " try:\n",
308
+ " FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map[\"additional_special_tokens\"][1:5]\n",
309
+ " suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (\n",
310
+ " tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]\n",
311
+ " )\n",
312
+ " except KeyError:\n",
313
+ " suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None\n",
314
+ " return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id\n",
315
+ "\n",
316
+ "\n",
317
+ "## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py\n",
318
+ "def permute(\n",
319
+ " sample,\n",
320
+ " np_rng,\n",
321
+ " suffix_tok_id,\n",
322
+ " prefix_tok_id,\n",
323
+ " middle_tok_id,\n",
324
+ " pad_tok_id,\n",
325
+ " fim_rate=0.5,\n",
326
+ " fim_spm_rate=0.5,\n",
327
+ " truncate_or_pad=False,\n",
328
+ "):\n",
329
+ " \"\"\"\n",
330
+ " Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:\n",
331
+ " PSM and SPM (with a probability of fim_spm_rate).\n",
332
+ " \"\"\"\n",
333
+ "\n",
334
+ " # The if condition will trigger with the probability of fim_rate\n",
335
+ " # This means FIM transformations will apply to samples with a probability of fim_rate\n",
336
+ " if np_rng.binomial(1, fim_rate):\n",
337
+ "\n",
338
+ " # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.\n",
339
+ " boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))\n",
340
+ " boundaries.sort()\n",
341
+ "\n",
342
+ " prefix = np.array(sample[: boundaries[0]], dtype=np.int64)\n",
343
+ " middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)\n",
344
+ " suffix = np.array(sample[boundaries[1] :], dtype=np.int64)\n",
345
+ "\n",
346
+ " if truncate_or_pad:\n",
347
+ " # calculate the new total length of the sample, taking into account tokens indicating prefix, middle, and suffix\n",
348
+ " new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3\n",
349
+ " diff = new_length - len(sample)\n",
350
+ "\n",
351
+ "            # truncate or pad if there's a difference in length between the new length and the original\n",
352
+ " if diff > 0:\n",
353
+ " if suffix.shape[0] <= diff:\n",
354
+ " return sample, np_rng\n",
355
+ " suffix = suffix[: suffix.shape[0] - diff]\n",
356
+ " elif diff < 0:\n",
357
+ " suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])\n",
358
+ "\n",
359
+ "        # With the probability of fim_spm_rate, apply the SPM variant of FIM transformations\n",
360
+ " # SPM: suffix, prefix, middle\n",
361
+ " if np_rng.binomial(1, fim_spm_rate):\n",
362
+ " new_sample = np.concatenate(\n",
363
+ " [\n",
364
+ " [prefix_tok_id, suffix_tok_id],\n",
365
+ " suffix,\n",
366
+ " [middle_tok_id],\n",
367
+ " prefix,\n",
368
+ " middle,\n",
369
+ " ]\n",
370
+ " )\n",
371
+ " # Otherwise, apply the PSM variant of FIM transformations\n",
372
+ " # PSM: prefix, suffix, middle\n",
373
+ " else:\n",
374
+ "\n",
375
+ " new_sample = np.concatenate(\n",
376
+ " [\n",
377
+ " [prefix_tok_id],\n",
378
+ " prefix,\n",
379
+ " [suffix_tok_id],\n",
380
+ " suffix,\n",
381
+ " [middle_tok_id],\n",
382
+ " middle,\n",
383
+ " ]\n",
384
+ " )\n",
385
+ " else:\n",
386
+ " # don't apply FIM transformations\n",
387
+ " new_sample = sample\n",
388
+ "\n",
389
+ " return list(new_sample), np_rng"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 7,
395
+ "id": "3a9ebdef-4178-44af-9b35-12c8189c27f7",
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": [
399
+ "from torch.utils.data import IterableDataset\n",
400
+ "from torch.utils.data.dataloader import DataLoader\n",
401
+ "import random\n",
402
+ "\n",
403
+ "# Create an Iterable dataset that returns constant-length chunks of tokens from a stream of text files.\n",
404
+ "\n",
405
+ "\n",
406
+ "class ConstantLengthDataset(IterableDataset):\n",
407
+ " \"\"\"\n",
408
+ "    Iterable dataset that returns constant-length chunks of tokens from a stream of text files.\n",
409
+ " Args:\n",
410
+ "        tokenizer (Tokenizer): The processor used for processing the data.\n",
411
+ " dataset (dataset.Dataset): Dataset with text files.\n",
412
+ "        infinite (bool): If True, the iterator restarts when the dataset is exhausted; otherwise iteration stops.\n",
413
+ " seq_length (int): Length of token sequences to return.\n",
414
+ " num_of_sequences (int): Number of token sequences to keep in buffer.\n",
415
+ "        chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.\n",
416
+ " fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.\n",
417
+ "        fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permutations that will use SPM.\n",
418
+ " seed (int): Seed for random number generator.\n",
419
+ " \"\"\"\n",
420
+ "\n",
421
+ " def __init__(\n",
422
+ " self,\n",
423
+ " tokenizer,\n",
424
+ " dataset,\n",
425
+ " infinite=False,\n",
426
+ " seq_length=1024,\n",
427
+ " num_of_sequences=1024,\n",
428
+ " chars_per_token=3.6,\n",
429
+ " content_field=\"content\",\n",
430
+ " fim_rate=0.5,\n",
431
+ " fim_spm_rate=0.5,\n",
432
+ " seed=0,\n",
433
+ " ):\n",
434
+ " self.tokenizer = tokenizer\n",
435
+ " self.concat_token_id = tokenizer.eos_token_id\n",
436
+ " self.dataset = dataset\n",
437
+ " self.seq_length = seq_length\n",
438
+ " self.infinite = infinite\n",
439
+ " self.current_size = 0\n",
440
+ " self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n",
441
+ " self.content_field = content_field\n",
442
+ " self.fim_rate = fim_rate\n",
443
+ " self.fim_spm_rate = fim_spm_rate\n",
444
+ " self.seed = seed\n",
445
+ "\n",
446
+ " (\n",
447
+ " self.suffix_tok_id,\n",
448
+ " self.prefix_tok_id,\n",
449
+ " self.middle_tok_id,\n",
450
+ " self.pad_tok_id,\n",
451
+ " ) = get_fim_token_ids(self.tokenizer)\n",
452
+ " if not self.suffix_tok_id and self.fim_rate > 0:\n",
453
+ " print(\"FIM is not supported by tokenizer, disabling FIM\")\n",
454
+ " self.fim_rate = 0\n",
455
+ "\n",
456
+ " def __iter__(self):\n",
457
+ " iterator = iter(self.dataset)\n",
458
+ " more_examples = True\n",
459
+ " np_rng = np.random.RandomState(seed=self.seed)\n",
460
+ " while more_examples:\n",
461
+ " buffer, buffer_len = [], 0\n",
462
+ " while True:\n",
463
+ " if buffer_len >= self.max_buffer_size:\n",
464
+ " break\n",
465
+ " try:\n",
466
+ " buffer.append(next(iterator)[self.content_field])\n",
467
+ " buffer_len += len(buffer[-1])\n",
468
+ " except StopIteration:\n",
469
+ " if self.infinite:\n",
470
+ " iterator = iter(self.dataset)\n",
471
+ " else:\n",
472
+ " more_examples = False\n",
473
+ " break\n",
474
+ " tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n",
475
+ " all_token_ids = []\n",
476
+ "\n",
477
+ " for tokenized_input in tokenized_inputs:\n",
478
+ " # optionally do FIM permutations\n",
479
+ " if self.fim_rate > 0:\n",
480
+ " tokenized_input, np_rng = permute(\n",
481
+ " tokenized_input,\n",
482
+ " np_rng,\n",
483
+ " self.suffix_tok_id,\n",
484
+ " self.prefix_tok_id,\n",
485
+ " self.middle_tok_id,\n",
486
+ " self.pad_tok_id,\n",
487
+ " fim_rate=self.fim_rate,\n",
488
+ " fim_spm_rate=self.fim_spm_rate,\n",
489
+ " truncate_or_pad=False,\n",
490
+ " )\n",
491
+ "\n",
492
+ " all_token_ids.extend(tokenized_input + [self.concat_token_id])\n",
493
+ " examples = []\n",
494
+ " for i in range(0, len(all_token_ids), self.seq_length):\n",
495
+ " input_ids = all_token_ids[i : i + self.seq_length]\n",
496
+ " if len(input_ids) == self.seq_length:\n",
497
+ " examples.append(input_ids)\n",
498
+ " random.shuffle(examples)\n",
499
+ " for example in examples:\n",
500
+ " self.current_size += 1\n",
501
+ " yield {\n",
502
+ " \"input_ids\": torch.LongTensor(example),\n",
503
+ " \"labels\": torch.LongTensor(example),\n",
504
+ " }\n",
505
+ "\n",
506
+ "\n",
507
+ "train_dataset = ConstantLengthDataset(\n",
508
+ " tokenizer,\n",
509
+ " train_data,\n",
510
+ " infinite=True,\n",
511
+ " seq_length=SEQ_LENGTH,\n",
512
+ " chars_per_token=chars_per_token,\n",
513
+ " content_field=DATA_COLUMN,\n",
514
+ " fim_rate=FIM_RATE,\n",
515
+ " fim_spm_rate=FIM_SPM_RATE,\n",
516
+ " seed=SEED,\n",
517
+ ")\n",
518
+ "eval_dataset = ConstantLengthDataset(\n",
519
+ " tokenizer,\n",
520
+ " valid_data,\n",
521
+ " infinite=False,\n",
522
+ " seq_length=SEQ_LENGTH,\n",
523
+ " chars_per_token=chars_per_token,\n",
524
+ " content_field=DATA_COLUMN,\n",
525
+ " fim_rate=FIM_RATE,\n",
526
+ " fim_spm_rate=FIM_SPM_RATE,\n",
527
+ " seed=SEED,\n",
528
+ ")"
529
+ ]
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": 8,
534
+ "id": "5021e686-2e1c-4608-9477-1f07adf2de35",
535
+ "metadata": {},
536
+ "outputs": [
537
+ {
538
+ "data": {
539
+ "application/vnd.jupyter.widget-view+json": {
540
+ "model_id": "e4ea2b63e63448938020858c06143ca5",
541
+ "version_major": 2,
542
+ "version_minor": 0
543
+ },
544
+ "text/plain": [
545
+ "model.safetensors: 0%| | 0.00/4.55G [00:00<?, ?B/s]"
546
+ ]
547
+ },
548
+ "metadata": {},
549
+ "output_type": "display_data"
550
+ },
551
+ {
552
+ "ename": "ImportError",
553
+ "evalue": "FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.",
554
+ "output_type": "error",
555
+ "traceback": [
556
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
557
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
558
+ "Cell \u001b[0;32mIn[8], line 26\u001b[0m\n\u001b[1;32m 17\u001b[0m bnb_config \u001b[38;5;241m=\u001b[39m BitsAndBytesConfig(\n\u001b[1;32m 18\u001b[0m load_in_4bit\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 19\u001b[0m bnb_4bit_quant_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnf4\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 20\u001b[0m bnb_4bit_compute_dtype\u001b[38;5;241m=\u001b[39mcompute_dtype,\n\u001b[1;32m 21\u001b[0m bnb_4bit_use_double_quant\u001b[38;5;241m=\u001b[39mUSE_NESTED_QUANT,\n\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 24\u001b[0m device_map \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0\u001b[39m}\n\u001b[0;32m---> 26\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mAutoModelForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mMODEL\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[43mquantization_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbnb_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# We will be using gradient checkpointing\u001b[39;49;00m\n\u001b[1;32m 31\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 32\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_implementation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mflash_attention_2\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 33\u001b[0m \u001b[43m)\u001b[49m\n",
559
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py:564\u001b[0m, in \u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 562\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(config) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 563\u001b[0m model_class \u001b[38;5;241m=\u001b[39m _get_model_class(config, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping)\n\u001b[0;32m--> 564\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 565\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 566\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 568\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnrecognized configuration class \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for this kind of AutoModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 569\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel type should be one of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(c\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mc\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys())\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 570\u001b[0m )\n",
560
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:3804\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3801\u001b[0m init_contexts\u001b[38;5;241m.\u001b[39mappend(init_empty_weights())\n\u001b[1;32m 3803\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config) \u001b[38;5;66;03m# We do not want to modify the config inplace in from_pretrained.\u001b[39;00m\n\u001b[0;32m-> 3804\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_autoset_attn_implementation\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3805\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_flash_attention_2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\n\u001b[1;32m 3806\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3808\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ContextManagers(init_contexts):\n\u001b[1;32m 3809\u001b[0m \u001b[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001b[39;00m\n\u001b[1;32m 3810\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(config, \u001b[38;5;241m*\u001b[39mmodel_args, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n",
561
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:1534\u001b[0m, in \u001b[0;36mPreTrainedModel._autoset_attn_implementation\u001b[0;34m(cls, config, use_flash_attention_2, torch_dtype, device_map, check_device_map)\u001b[0m\n\u001b[1;32m 1531\u001b[0m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1533\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config\u001b[38;5;241m.\u001b[39m_attn_implementation \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attention_2\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 1534\u001b[0m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_and_enable_flash_attn_2\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1535\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1536\u001b[0m \u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1537\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1538\u001b[0m \u001b[43m \u001b[49m\u001b[43mhard_check_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1539\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheck_device_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msdpa\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available():\n\u001b[1;32m 1542\u001b[0m \u001b[38;5;66;03m# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.\u001b[39;00m\n\u001b[1;32m 1543\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_check_and_enable_sdpa(\n\u001b[1;32m 1544\u001b[0m config,\n\u001b[1;32m 1545\u001b[0m hard_check_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m requested_attn_implementation \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 1546\u001b[0m )\n",
562
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:1636\u001b[0m, in \u001b[0;36mPreTrainedModel._check_and_enable_flash_attn_2\u001b[0;34m(cls, config, torch_dtype, device_map, check_device_map, hard_check_only)\u001b[0m\n\u001b[1;32m 1633\u001b[0m install_message \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1635\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mutil\u001b[38;5;241m.\u001b[39mfind_spec(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1636\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpreface\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m the package flash_attn seems to be not installed. \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minstall_message\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1638\u001b[0m flash_attention_version \u001b[38;5;241m=\u001b[39m version\u001b[38;5;241m.\u001b[39mparse(importlib\u001b[38;5;241m.\u001b[39mmetadata\u001b[38;5;241m.\u001b[39mversion(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflash_attn\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 1639\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39mcuda:\n",
563
+ "\u001b[0;31mImportError\u001b[0m: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
564
+ ]
565
+ }
566
+ ],
567
+ "source": [
568
+ "import torch\n",
569
+ "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
570
+ "from peft.tuners.lora import LoraLayer\n",
571
+ "\n",
572
+ "load_in_8bit = False\n",
573
+ "\n",
574
+ "# 4-bit quantization\n",
575
+ "compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)\n",
576
+ "\n",
577
+ "bnb_config2 = BitsAndBytesConfig(\n",
578
+ " load_in_4bit=True,\n",
579
+ " bnb_4bit_use_double_quant=True,\n",
580
+ " bnb_4bit_quant_type=\"nf4\",\n",
581
+ " bnb_4bit_compute_dtype=torch.bfloat16\n",
582
+ ")\n",
583
+ "\n",
584
+ "bnb_config = BitsAndBytesConfig(\n",
585
+ " load_in_4bit=True,\n",
586
+ " bnb_4bit_quant_type=\"nf4\",\n",
587
+ " bnb_4bit_compute_dtype=compute_dtype,\n",
588
+ " bnb_4bit_use_double_quant=USE_NESTED_QUANT,\n",
589
+ ")\n",
590
+ "\n",
591
+ "device_map = {\"\": 0}\n",
592
+ "\n",
593
+ "model = AutoModelForCausalLM.from_pretrained(\n",
594
+ " MODEL,\n",
595
+ " quantization_config=bnb_config,\n",
596
+ " device_map=device_map,\n",
597
+ " use_cache=False, # We will be using gradient checkpointing\n",
598
+ " trust_remote_code=True,\n",
599
+ " attn_implementation=\"flash_attention_2\",\n",
600
+ ")"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": null,
606
+ "id": "7c0a9fdb-4087-4ec5-aefa-fc5f413252e4",
607
+ "metadata": {},
608
+ "outputs": [],
609
+ "source": [
610
+ "model = prepare_model_for_kbit_training(model)"
611
+ ]
612
+ },
613
+ {
614
+ "cell_type": "code",
615
+ "execution_count": null,
616
+ "id": "b5fd3293-76ef-48eb-8241-478e311ec947",
617
+ "metadata": {},
618
+ "outputs": [],
619
+ "source": [
620
+ "# Set up lora\n",
621
+ "peft_config = LoraConfig(\n",
622
+ " lora_alpha=LORA_ALPHA,\n",
623
+ " lora_dropout=LORA_DROPOUT,\n",
624
+ " r=LORA_R,\n",
625
+ " bias=\"none\",\n",
626
+ " task_type=\"CAUSAL_LM\",\n",
627
+ " target_modules=LORA_TARGET_MODULES.split(\",\"),\n",
628
+ ")\n",
629
+ "\n",
630
+ "model = get_peft_model(model, peft_config)\n",
631
+ "model.print_trainable_parameters()"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": null,
637
+ "id": "082c6a7b-db61-4800-8a94-419331b1fd22",
638
+ "metadata": {},
639
+ "outputs": [],
640
+ "source": [
641
+ "train_data.start_iteration = 0\n",
642
+ "\n",
643
+ "\n",
644
+ "training_args = TrainingArguments(\n",
645
+ " output_dir=f\"ernyou/{OUTPUT_DIR}\",\n",
646
+ " dataloader_drop_last=True,\n",
647
+ " eval_strategy=\"steps\",\n",
648
+ " save_strategy=\"steps\",\n",
649
+ " max_steps=MAX_STEPS,\n",
650
+ " eval_steps=EVAL_FREQ,\n",
651
+ " save_steps=SAVE_FREQ,\n",
652
+ " logging_steps=LOG_FREQ,\n",
653
+ " per_device_train_batch_size=BATCH_SIZE,\n",
654
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
655
+ " learning_rate=LR,\n",
656
+ " lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
657
+ " warmup_steps=NUM_WARMUP_STEPS,\n",
658
+ " gradient_accumulation_steps=GR_ACC_STEPS,\n",
659
+ " gradient_checkpointing_kwargs={\"use_reentrant\": True},\n",
660
+ " gradient_checkpointing=True,\n",
661
+ " fp16=FP16,\n",
662
+ " bf16=BF16,\n",
663
+ " weight_decay=WEIGHT_DECAY,\n",
664
+ " push_to_hub=True,\n",
665
+ " include_tokens_per_second=True,\n",
666
+ ")"
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": 10,
672
+ "id": "2c302ded-f017-433c-9622-55ecb45141bd",
673
+ "metadata": {},
674
+ "outputs": [
675
+ {
676
+ "name": "stdout",
677
+ "output_type": "stream",
678
+ "text": [
679
+ "1.3.0\n"
680
+ ]
681
+ }
682
+ ],
683
+ "source": [
684
+ "import accelerate as ac\n",
685
+ "print(ac.__version__)"
686
+ ]
687
+ },
688
+ {
689
+ "cell_type": "code",
690
+ "execution_count": null,
691
+ "id": "5318efa4-83da-41fa-9123-50b505e9a615",
692
+ "metadata": {},
693
+ "outputs": [],
694
+ "source": [
695
+ "import torch\n",
696
+ "torch.cuda.empty_cache()"
697
+ ]
698
+ },
699
+ {
700
+ "cell_type": "code",
701
+ "execution_count": null,
702
+ "id": "ceddceb0-8e1e-493f-99a1-b77c6b0c40b6",
703
+ "metadata": {},
704
+ "outputs": [],
705
+ "source": [
706
+ "trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)\n",
707
+ "\n",
708
+ "print(\"Training...\")\n",
709
+ "trainer.train()"
710
+ ]
711
+ },
712
+ {
713
+ "cell_type": "code",
714
+ "execution_count": null,
715
+ "id": "76fb1530-5e16-4b4b-a30f-b689df3483f3",
716
+ "metadata": {},
717
+ "outputs": [],
718
+ "source": []
719
+ }
720
+ ],
721
+ "metadata": {
722
+ "kernelspec": {
723
+ "display_name": "Python 3 (ipykernel)",
724
+ "language": "python",
725
+ "name": "python3"
726
+ },
727
+ "language_info": {
728
+ "codemirror_mode": {
729
+ "name": "ipython",
730
+ "version": 3
731
+ },
732
+ "file_extension": ".py",
733
+ "mimetype": "text/x-python",
734
+ "name": "python",
735
+ "nbconvert_exporter": "python",
736
+ "pygments_lexer": "ipython3",
737
+ "version": "3.10.12"
738
+ }
739
+ },
740
+ "nbformat": 4,
741
+ "nbformat_minor": 5
742
+ }
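To see concretely what the notebook's `permute` does, the toy rerun below reproduces its PSM/SPM rearrangement on a ten-token sample. The sentinel ids 900-902 are made up for illustration; in the real pipeline they come from the tokenizer's additional special tokens via `get_fim_token_ids`:

```python
# Toy illustration of the PSM/SPM fill-in-the-middle rearrangement.
# Sentinel ids 900-902 are placeholders, not StarCoder's real token ids.
import numpy as np

sample = list(range(10))              # pretend token ids 0..9
np_rng = np.random.RandomState(0)
prefix_id, suffix_id, middle_id = 900, 901, 902

lo, hi = sorted(np_rng.randint(0, len(sample) + 1, size=2))
prefix, middle, suffix = sample[:lo], sample[lo:hi], sample[hi:]

# PSM: <prefix> prefix <suffix> suffix <middle> middle
psm = [prefix_id] + prefix + [suffix_id] + suffix + [middle_id] + middle
# SPM: <prefix> <suffix> suffix <middle> prefix middle
spm = [prefix_id, suffix_id] + suffix + [middle_id] + prefix + middle

print("boundaries:", (lo, hi))
print("PSM:", psm)
print("SPM:", spm)
```

The model learns to emit the middle span after seeing prefix and suffix, which is what powers infilling at inference time (the `<fim_prefix>…<fim_suffix>…<fim_middle>` prompt in use_model.py).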
train.py ADDED
@@ -0,0 +1,392 @@
1
+ MODEL = "bigcode/starcoderbase-1b" # Model checkpoint on the Hugging Face Hub
2
+ DATASET = "smangrul/hf-stack-v1" # Dataset on the Hugging Face Hub
3
+ DATA_COLUMN = "content" # Column name containing the code content
4
+
5
+ SEQ_LENGTH = 2048 # Sequence length
6
+
7
+ MAX_STEPS = 2000 # max_steps
8
+ BATCH_SIZE = 8 # batch_size
9
+ GR_ACC_STEPS = 1 # gradient_accumulation_steps
10
+ LR = 5e-4 # learning_rate
11
+ LR_SCHEDULER_TYPE = "cosine" # lr_scheduler_type
12
+ WEIGHT_DECAY = 0.01 # weight_decay
13
+ NUM_WARMUP_STEPS = 30 # num_warmup_steps
14
+ EVAL_FREQ = 100 # eval_freq
15
+ SAVE_FREQ = 100 # save_freq
16
+ LOG_FREQ = 25 # log_freq
17
+ OUTPUT_DIR = "peft-starcoder-lora-a100" # output_dir
18
+ BF16 = False # bf16
19
+ FP16 = False # no_fp16
20
+
21
+ # FIM transformation arguments
22
+ FIM_RATE = 0.5 # fim_rate
23
+ FIM_SPM_RATE = 0.5 # fim_spm_rate
24
+
25
+ # LORA
26
+ LORA_R = 8 # lora_r
27
+ LORA_ALPHA = 32 # lora_alpha
28
+ LORA_DROPOUT = 0.0 # lora_dropout
29
+ LORA_TARGET_MODULES = "c_proj,c_attn,q_attn,c_fc" # lora_target_modules (c_proj was listed twice)
30
+
31
+ # bitsandbytes config
32
+ #USE_NESTED_QUANT = True # use_nested_quant
33
+ #BNB_4BIT_COMPUTE_DTYPE = "bfloat16" # bnb_4bit_compute_dtype
34
+
35
+ SEED = 0
36
+
37
+ from huggingface_hub import login
38
+ from transformers import (
39
+ AutoModelForCausalLM,
40
+ AutoTokenizer,
41
+ Trainer,
42
+ TrainingArguments,
43
+ logging,
44
+ set_seed,
45
+ BitsAndBytesConfig,
46
+ )
47
+
48
+ from datasets import load_dataset
49
+ import torch
50
+ from tqdm import tqdm
51
+
52
+ #Prepare Data
53
+ dataset = load_dataset(
54
+ DATASET,
55
+ data_dir="data",
56
+ split="train",
57
+ streaming=True,
58
+ )
59
+
60
+ valid_data = dataset.take(4000)
61
+ train_data = dataset.skip(4000)
62
+ train_data = train_data.shuffle(buffer_size=5000, seed=SEED)
63
+
64
+ set_seed(SEED)
65
+
66
+ tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
67
+
68
+
69
+ def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
70
+ """
71
+ Estimate the average number of characters per token in the dataset.
72
+ """
73
+
74
+ total_characters, total_tokens = 0, 0
75
+ for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
76
+ total_characters += len(example[data_column])
77
+ total_tokens += len(tokenizer(example[data_column]).tokens())
78
+
79
+ return total_characters / total_tokens
80
+
81
+
82
+ chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)
83
+ print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
84
+
85
+ import functools
86
+ import numpy as np
87
+
88
+
89
+ # Helper function to get token ids of the special tokens for prefix, suffix and middle for FIM transformations.
90
+ @functools.lru_cache(maxsize=None)
91
+ def get_fim_token_ids(tokenizer):
92
+ try:
93
+ FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map["additional_special_tokens"][1:5]
94
+ suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
95
+ tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
96
+ )
97
+ except KeyError:
98
+ suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None
99
+ return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id
100
+
101
+
102
+ ## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
103
+ def permute(
104
+ sample,
105
+ np_rng,
106
+ suffix_tok_id,
107
+ prefix_tok_id,
108
+ middle_tok_id,
109
+ pad_tok_id,
110
+ fim_rate=0.5,
111
+ fim_spm_rate=0.5,
112
+ truncate_or_pad=False,
113
+ ):
114
+ """
115
+ Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:
116
+ PSM and SPM (with a probability of fim_spm_rate).
117
+ """
118
+
119
+ # The if condition will trigger with the probability of fim_rate
120
+ # This means FIM transformations will apply to samples with a probability of fim_rate
121
+ if np_rng.binomial(1, fim_rate):
122
+
123
+ # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.
124
+ boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))
125
+ boundaries.sort()
126
+
127
+ prefix = np.array(sample[: boundaries[0]], dtype=np.int64)
128
+ middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)
129
+ suffix = np.array(sample[boundaries[1] :], dtype=np.int64)
130
+
131
+ if truncate_or_pad:
132
+ # calculate the new total length of the sample, taking into account tokens indicating prefix, middle, and suffix
133
+ new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3
134
+ diff = new_length - len(sample)
135
+
136
+ # truncate or pad if there's a difference in length between the new length and the original
137
+ if diff > 0:
138
+ if suffix.shape[0] <= diff:
139
+ return sample, np_rng
140
+ suffix = suffix[: suffix.shape[0] - diff]
141
+ elif diff < 0:
142
+ suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])
143
+
144
+ # With the probability of fim_spm_rate, apply the SPM variant of FIM transformations
145
+ # SPM: suffix, prefix, middle
146
+ if np_rng.binomial(1, fim_spm_rate):
147
+ new_sample = np.concatenate(
148
+ [
149
+ [prefix_tok_id, suffix_tok_id],
150
+ suffix,
151
+ [middle_tok_id],
152
+ prefix,
153
+ middle,
154
+ ]
155
+ )
156
+ # Otherwise, apply the PSM variant of FIM transformations
157
+ # PSM: prefix, suffix, middle
158
+ else:
159
+
160
+ new_sample = np.concatenate(
161
+ [
162
+ [prefix_tok_id],
163
+ prefix,
164
+ [suffix_tok_id],
165
+ suffix,
166
+ [middle_tok_id],
167
+ middle,
168
+ ]
169
+ )
170
+ else:
171
+ # don't apply FIM transformations
172
+ new_sample = sample
173
+
174
+ return list(new_sample), np_rng
175
+
176
+ from torch.utils.data import IterableDataset
177
+ from torch.utils.data.dataloader import DataLoader
178
+ import random
179
+
180
+ # Create an Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
181
+
182
+
183
+ class ConstantLengthDataset(IterableDataset):
184
+ """
185
+ Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
186
+ Args:
187
+ tokenizer (Tokenizer): The processor used for processing the data.
188
+ dataset (dataset.Dataset): Dataset with text files.
189
+ infinite (bool): If True, the iterator restarts when the dataset is exhausted; otherwise iteration stops.
190
+ seq_length (int): Length of token sequences to return.
191
+ num_of_sequences (int): Number of token sequences to keep in buffer.
192
+ chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
193
+ fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.
194
+ fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permutations that will use SPM.
195
+ seed (int): Seed for random number generator.
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ tokenizer,
201
+ dataset,
202
+ infinite=False,
203
+ seq_length=1024,
204
+ num_of_sequences=1024,
205
+ chars_per_token=3.6,
206
+ content_field="content",
207
+ fim_rate=0.5,
208
+ fim_spm_rate=0.5,
209
+ seed=0,
210
+ ):
211
+ self.tokenizer = tokenizer
212
+ self.concat_token_id = tokenizer.eos_token_id
213
+ self.dataset = dataset
214
+ self.seq_length = seq_length
215
+ self.infinite = infinite
216
+ self.current_size = 0
217
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
218
+ self.content_field = content_field
219
+ self.fim_rate = fim_rate
220
+ self.fim_spm_rate = fim_spm_rate
221
+ self.seed = seed
222
+
223
+ (
224
+ self.suffix_tok_id,
225
+ self.prefix_tok_id,
226
+ self.middle_tok_id,
227
+ self.pad_tok_id,
228
+ ) = get_fim_token_ids(self.tokenizer)
229
+ if not self.suffix_tok_id and self.fim_rate > 0:
230
+ print("FIM is not supported by tokenizer, disabling FIM")
231
+ self.fim_rate = 0
232
+
233
+ def __iter__(self):
234
+ iterator = iter(self.dataset)
235
+ more_examples = True
236
+ np_rng = np.random.RandomState(seed=self.seed)
237
+ while more_examples:
238
+ buffer, buffer_len = [], 0
239
+ while True:
240
+ if buffer_len >= self.max_buffer_size:
241
+ break
242
+ try:
243
+ buffer.append(next(iterator)[self.content_field])
244
+ buffer_len += len(buffer[-1])
245
+ except StopIteration:
246
+ if self.infinite:
247
+ iterator = iter(self.dataset)
248
+ else:
249
+ more_examples = False
250
+ break
251
+ tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
252
+ all_token_ids = []
253
+
254
+ for tokenized_input in tokenized_inputs:
255
+ # optionally do FIM permutations
256
+ if self.fim_rate > 0:
257
+ tokenized_input, np_rng = permute(
258
+ tokenized_input,
259
+ np_rng,
260
+ self.suffix_tok_id,
261
+ self.prefix_tok_id,
262
+ self.middle_tok_id,
263
+ self.pad_tok_id,
264
+ fim_rate=self.fim_rate,
265
+ fim_spm_rate=self.fim_spm_rate,
266
+ truncate_or_pad=False,
267
+ )
268
+
269
+ all_token_ids.extend(tokenized_input + [self.concat_token_id])
270
+ examples = []
271
+ for i in range(0, len(all_token_ids), self.seq_length):
272
+ input_ids = all_token_ids[i : i + self.seq_length]
273
+ if len(input_ids) == self.seq_length:
274
+ examples.append(input_ids)
275
+ random.shuffle(examples)
276
+ for example in examples:
277
+ self.current_size += 1
278
+ yield {
279
+ "input_ids": torch.LongTensor(example),
280
+ "labels": torch.LongTensor(example),
281
+ }
282
+
283
+
284
+ train_dataset = ConstantLengthDataset(
285
+ tokenizer,
286
+ train_data,
287
+ infinite=True,
288
+ seq_length=SEQ_LENGTH,
289
+ chars_per_token=chars_per_token,
290
+ content_field=DATA_COLUMN,
291
+ fim_rate=FIM_RATE,
292
+ fim_spm_rate=FIM_SPM_RATE,
293
+ seed=SEED,
294
+ )
295
+ eval_dataset = ConstantLengthDataset(
296
+ tokenizer,
297
+ valid_data,
298
+ infinite=False,
299
+ seq_length=SEQ_LENGTH,
300
+ chars_per_token=chars_per_token,
301
+ content_field=DATA_COLUMN,
302
+ fim_rate=FIM_RATE,
303
+ fim_spm_rate=FIM_SPM_RATE,
304
+ seed=SEED,
305
+ )
306
+
307
+ import torch
308
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
309
+ from peft.tuners.lora import LoraLayer
310
+
311
+ #load_in_8bit = False
312
+
313
+ # 4-bit quantization
314
+ #compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)
315
+ #compute_float32 = torch.float32
316
+
317
+ #bnb_config = BitsAndBytesConfig(
318
+ # load_in_4bit=True,
319
+ # bnb_4bit_quant_type="nf4",
320
+ # bnb_4bit_compute_dtype=compute_float32,
321
+ # bnb_4bit_use_double_quant=USE_NESTED_QUANT,
322
+ # bnb_4bit_quant_storage=compute_float32
323
+ #)
324
+
325
+ #import os
326
+ #device_map = int(os.environ.get("LOCAL_RANK", -1))
327
+
328
+ model = AutoModelForCausalLM.from_pretrained(
329
+ MODEL,
330
+ #quantization_config=bnb_config,
331
+ device_map=None,
332
+ use_cache=False, # We will be using gradient checkpointing
333
+ trust_remote_code=True,
334
+ torch_dtype=torch.float32,
335
+ )
336
+
337
+ #from collections import Counter
338
+ #print(Counter(p.dtype for p in model.parameters()))
339
+
340
+ #model = prepare_model_for_kbit_training(model)
341
+
342
+ #from collections import Counter
343
+ #print("after prepare_model_for_kbit_training ", Counter(p.dtype for p in model.parameters()))
344
+
345
+ peft_config = LoraConfig(
346
+ lora_alpha=LORA_ALPHA,
347
+ lora_dropout=LORA_DROPOUT,
348
+ r=LORA_R,
349
+ bias="none",
350
+ task_type="CAUSAL_LM",
351
+ target_modules=LORA_TARGET_MODULES.split(","),
352
+ )
353
+
354
+ model = get_peft_model(model, peft_config)
355
+ model.print_trainable_parameters()
356
+ #from collections import Counter
357
+ #print("after get_peft_model ", Counter(p.dtype for p in model.parameters()))
358
+
359
+ train_data.start_iteration = 0
360
+
361
+ training_args = TrainingArguments(
362
+ output_dir=f"limernyou/{OUTPUT_DIR}",
363
+ dataloader_drop_last=True,
364
+ eval_strategy="steps",
365
+ save_strategy="steps",
366
+ max_steps=MAX_STEPS,
367
+ eval_steps=EVAL_FREQ,
368
+ save_steps=SAVE_FREQ,
369
+ logging_steps=LOG_FREQ,
370
+ per_device_train_batch_size=BATCH_SIZE,
371
+ per_device_eval_batch_size=BATCH_SIZE,
372
+ learning_rate=LR,
373
+ lr_scheduler_type=LR_SCHEDULER_TYPE,
374
+ warmup_steps=NUM_WARMUP_STEPS,
375
+ gradient_accumulation_steps=GR_ACC_STEPS,
376
+ gradient_checkpointing_kwargs={"use_reentrant": False},
377
+ gradient_checkpointing=True,
378
+ fp16=FP16,
379
+ bf16=BF16,
380
+ weight_decay=WEIGHT_DECAY,
381
+ push_to_hub=True,
382
+ include_tokens_per_second=True,
383
+ )
384
+
385
+ #from trl import SFTConfig, SFTTrainer
386
+ trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
387
+
388
+ print("Training...")
389
+ trainer.train()
390
+ # Only meaningful when launched under Accelerate's FSDP plugin (see default_config.yaml);
+ # switches to a full (unsharded) state dict so the saved checkpoint can be loaded standalone.
+ trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
391
+ trainer.save_model()
392
+ trainer.push_to_hub()
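Before committing to a 2000-step run, the data pipeline can be smoke-tested on CPU. A sketch, assuming `ConstantLengthDataset` (with its `permute` and `get_fim_token_ids` helpers) has been copied into a scratch script, since importing train.py directly would execute the whole training run, and assuming access to the gated `bigcode/starcoderbase-1b` tokenizer; `seq_length=64` and the toy texts are arbitrary test values:

```python
# Sketch: smoke-test the ConstantLengthDataset pipeline on CPU.
# Paste ConstantLengthDataset (plus permute and get_fim_token_ids)
# into this scratch script first; importing train.py would start training.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-1b")
tiny = Dataset.from_dict({"content": ["def add(a, b):\n    return a + b\n"] * 50})

ds = ConstantLengthDataset(
    tokenizer, tiny, infinite=False, seq_length=64,
    content_field="content", fim_rate=0.5, fim_spm_rate=0.5, seed=0,
)

batch = next(iter(ds))
assert batch["input_ids"].shape == (64,)              # constant-length chunks
assert (batch["input_ids"] == batch["labels"]).all()  # causal-LM labels
print("pipeline ok")
```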
use_model.py ADDED
@@ -0,0 +1,88 @@
1
+ from peft import PeftModel
2
+ from huggingface_hub import login
3
+ import torch
4
+
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ Trainer,
9
+ TrainingArguments,
10
+ logging,
11
+ set_seed,
12
+ BitsAndBytesConfig,
13
+ )
14
+
15
+ MODEL = "bigcode/starcoderbase-1b" # Model checkpoint on the Hugging Face Hub
16
+
17
+ # load the original model first
18
+ print("Load Tokenizer")
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
20
+ tokenizer.pad_token = tokenizer.eos_token
21
+ tokenizer.padding_side = "left"
22
+ print("Load Model")
23
+ base_model = AutoModelForCausalLM.from_pretrained(
24
+ MODEL,
25
+ quantization_config=None,
26
+ device_map=None,
27
+ trust_remote_code=True,
28
+ torch_dtype=torch.float32,
29
+ ).cuda()
30
+
31
+ # merge fine-tuned weights with the base model
32
+ peft_model_id = "limernyou/starcoder-peft-conti"
33
+ model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name="personal_copilot")
34
+ #model.add_weighted_adapter(["personal_copilot"], [0.8], "best_personal_copilot")
35
+ #model.set_adapter("best_personal_copilot")
36
+
37
+ model = model.merge_and_unload()  # merges the LoRA weights into the base model and returns it
38
+
39
+
40
+ #if not hasattr(model, "hf_device_map"):
41
+ # model.cuda()
42
+
43
+ def get_code_completion(prefix, suffix):
44
+ prompt = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
45
+ base_model.eval()  # base_model carries the merged LoRA weights after merge_and_unload
46
+ outputs = base_model.generate(
47
+ input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
48
+ #attention_mask=tokenizer(prompt, return_tensors="pt").to("cuda")["attention_mask"],
49
+ max_new_tokens=128,
50
+ temperature=0.2,
51
+ top_k=50,
52
+ top_p=0.95,
53
+ do_sample=True,
54
+ repetition_penalty=1.0,
55
+ )
56
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
57
+
58
+ def get_code_completion1(prefix, suffix):
59
+ prompt = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"  # FIM-formatted prompt so the <fim_middle> split below works
60
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
61
+ model.eval()
62
+ with torch.no_grad():
63
+ outputs = model.generate(
64
+ input_ids=inputs["input_ids"],
65
+ attention_mask=inputs["attention_mask"],
66
+ max_new_tokens=128,
67
+ temperature=0.2,
68
+ top_k=50,
69
+ top_p=0.95,
70
+ do_sample=True,
71
+ repetition_penalty=1.0,
72
+ pad_token_id=tokenizer.eos_token_id,
73
+ )
74
+ output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)  # keep special tokens so <fim_middle> survives for the split
75
+ completion = output_text.split("<fim_middle>")[-1].strip()
76
+ return completion
77
+
78
+ prefix = """from peft import LoraConfig, TaskType, get_peft_model
79
+ from transformers import AutoModelForCausalLM
80
+ peft_config = LoraConfig(
81
+ """
82
+ suffix = """"""
83
+
84
+ print("Starcoder generating response")
85
+ #print(tokenizer.special_tokens_map)
86
+ print(get_code_completion(prefix, suffix))
87
+
88
+ print("Successful")
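`get_code_completion1` expects a FIM-formatted prompt, so it is most useful with a non-empty suffix. An illustrative query (the fibonacci snippet is invented for the example and not part of the committed script):

```python
# Illustrative fill-in-the-middle query: the model is asked for the
# recursive return statement between the base case and the print call.
prefix = "def fibonacci(n):\n    if n < 2:\n        return n\n    "
suffix = "\n\nprint(fibonacci(10))\n"

print(get_code_completion1(prefix, suffix))
```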