# File size: 17,996 Bytes
# 7140a44
# (Line-number gutter 1-496 left over from the file viewer omitted.)
# coding=utf-8
# Copyright 2024 Sourab Mangrulkar. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Continued pre-training/fine-tuning of code LLMs for code autocompletion.
"""
import gc
import os
import random
import sys
from typing import Optional
from dataclasses import dataclass, field
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
HfArgumentParser,
set_seed,
BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, replace_lora_weights_loftq
import fim
# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Besides the checkpoint identifier, the fields cover the LoRA
    hyper-parameters, the bitsandbytes quantization settings and the feature
    toggles (flash attention, PEFT LoRA, Unsloth, LoftQ) read by the rest of
    this script.
    """

    # --- base checkpoint (required: the field has no default) ---
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )

    # --- LoRA hyper-parameters ---
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )

    # --- 4-bit quantization settings ---
    use_nested_quant: Optional[bool] = field(
        default=False, metadata={"help": "Activate nested quantization for 4bit base models"}
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16", metadata={"help": "Compute dtype for 4bit base models"}
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "Quantization type fp4 or nf4"}
    )

    # --- feature toggles, all off by default ---
    use_flash_attn: Optional[bool] = field(
        default=False, metadata={"help": "Enables Flash attention for training."}
    )
    use_peft_lora: Optional[bool] = field(
        default=False, metadata={"help": "Enables PEFT LoRA for training."}
    )
    # NOTE: the misspelling ("qunatization") is part of the public flag name.
    use_8bit_qunatization: Optional[bool] = field(
        default=False, metadata={"help": "Enables loading model in 8bit."}
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False, metadata={"help": "Enables loading model in 4bit."}
    )
    use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "Gradient Checkpointing param. Refer the related docs"}
    )
    use_unsloth: Optional[bool] = field(
        default=False, metadata={"help": "Enables UnSloth for training."}
    )
    use_loftq: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables LoftQ init for the LoRA adapters when using QLoRA."},
    )
    use_loftq_callback: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Enables LoftQ callback comparing logits of base model to the ones from LoftQ init. Provides better init."
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments describing the training data and how it is turned into sequences.
    """

    dataset_name: Optional[str] = field(
        default="smangrul/hug_stack",
        metadata={"help": "Name or path of the dataset to load with `load_dataset`."},
    )
    dataset_text_field: str = field(
        default="text", metadata={"help": "Dataset field to use as input text."}
    )
    max_seq_length: Optional[int] = field(
        default=4096,
        metadata={"help": "Length of the token sequences produced for training and evaluation."},
    )
    test_size: Optional[float] = field(
        default=0.1,
        metadata={
            "help": "Value passed as `test_size` to `train_test_split`; the resulting test part is used for validation."
        },
    )
    fim_rate: Optional[float] = field(
        default=0.5,
        metadata={"help": "Rate (0.0 to 1.0) at which samples are permuted with FIM."},
    )
    fim_spm_rate: Optional[float] = field(
        default=0.5,
        metadata={"help": "Rate (0.0 to 1.0) of FIM permutations that use SPM."},
    )
    splits: Optional[str] = field(
        default="train",
        metadata={
            "help": "Split(s) of the dataset to load, forwarded to `load_dataset(split=...)`, e.g. 'train' or 'train+test'."
        },
    )
def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    At most ``nb_examples`` rows are read from ``dataset``; for each one the
    text found under ``data_column`` is measured in characters and in tokens
    (as reported by ``tokenizer(text).tokens()``).

    Returns:
        Total characters divided by total tokens over the sampled rows.
    """
    # Pairing with a bounded range caps how many rows are pulled from the dataset.
    sampled_rows = zip(range(nb_examples), iter(dataset))

    char_count = 0
    token_count = 0
    for _, row in tqdm(sampled_rows, total=nb_examples):
        text = row[data_column]
        char_count += len(text)
        token_count += len(tokenizer(text).tokens())

    return char_count / token_count
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from a stream of text files.

    Texts are collected into a character buffer, tokenized in one batch,
    optionally permuted for fill-in-the-middle (FIM), joined with the
    tokenizer's EOS token and cut into ``seq_length``-sized chunks. Tokens
    left over at the end of a buffer that do not fill a whole chunk are dropped.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
        content_field (str): Key under which each dataset example stores its text.
        fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.
        fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permutations that will use SPM.
        seed (int): Seed for the random number generator driving the FIM permutations.
        shuffle (bool): If True the chunks cut from each buffer are shuffled before being yielded.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        content_field="content",
        fim_rate=0.5,
        fim_spm_rate=0.5,
        seed=0,
        shuffle=False,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        # Buffer budget in characters: enough text for roughly `num_of_sequences` chunks.
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field
        self.fim_rate = fim_rate
        self.fim_spm_rate = fim_spm_rate
        self.seed = seed
        self.shuffle = shuffle

        (
            self.bos_token_id,
            self.suffix_tok_id,
            self.prefix_tok_id,
            self.middle_tok_id,
            self.pad_tok_id,
        ) = fim.get_fim_token_ids(self.tokenizer)
        if not self.suffix_tok_id and self.fim_rate > 0:
            print("FIM is not supported by tokenizer, disabling FIM")
            self.fim_rate = 0

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        # Whether the current pass over the dataset produced at least one example.
        # Without this, `infinite=True` on a dataset that yields nothing would
        # restart the iterator forever.
        saw_example = False
        np_rng = np.random.RandomState(seed=self.seed)
        while more_examples:
            buffer, buffer_len = [], 0
            while buffer_len < self.max_buffer_size:
                try:
                    text = next(iterator)[self.content_field]
                except StopIteration:
                    if self.infinite and saw_example:
                        iterator = iter(self.dataset)
                        saw_example = False
                        continue
                    more_examples = False
                    break
                saw_example = True
                buffer.append(text)
                buffer_len += len(text)

            if not buffer:
                # Nothing was collected, so there is nothing to tokenize or yield.
                break

            tokenized_inputs = self.tokenizer(
                buffer, truncation=False, add_special_tokens=False
            )["input_ids"]

            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                # optionally do FIM permutations
                if self.fim_rate > 0:
                    tokenized_input, np_rng = fim.permute(
                        tokenized_input,
                        np_rng,
                        self.suffix_tok_id,
                        self.prefix_tok_id,
                        self.middle_tok_id,
                        self.pad_tok_id,
                        fim_rate=self.fim_rate,
                        fim_spm_rate=self.fim_spm_rate,
                        truncate_or_pad=False,
                        bos_token_id=self.bos_token_id,
                    )
                all_token_ids.extend(tokenized_input + [self.concat_token_id])

            # Keep only full-length chunks; a shorter trailing chunk is discarded.
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }
def create_datasets(tokenizer, args, seed):
    """
    Load the dataset and wrap its train/validation parts in `ConstantLengthDataset`s.

    Args:
        tokenizer: Tokenizer used to estimate the chars-per-token ratio and to
            tokenize the text inside the returned datasets.
        args: Data arguments; reads `dataset_name`, `splits`, `test_size`,
            `dataset_text_field`, `max_seq_length`, `fim_rate` and `fim_spm_rate`.
        seed: Seed for the train/test split and for both datasets.

    Returns:
        Tuple `(train_dataset, valid_dataset)`. The train dataset is infinite
        and shuffled; the validation dataset is finite and unshuffled.
    """
    # `splits` may be written as a comma separated list, but `load_dataset`
    # concatenates splits with "+"; translate so both spellings load.
    split_names = [name.strip() for name in args.splits.split(",") if name.strip()]
    dataset = load_dataset(args.dataset_name, split="+".join(split_names))
    dataset = dataset.train_test_split(
        test_size=args.test_size, seed=seed, shuffle=True
    )
    train_data = dataset["train"]
    valid_data = dataset["test"]
    print(
        f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
    )

    chars_per_token = chars_token_ratio(train_data, tokenizer, args.dataset_text_field)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        infinite=True,
        seq_length=args.max_seq_length,
        chars_per_token=chars_per_token,
        content_field=args.dataset_text_field,
        fim_rate=args.fim_rate,
        fim_spm_rate=args.fim_spm_rate,
        seed=seed,
        shuffle=True,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        infinite=False,
        seq_length=args.max_seq_length,
        chars_per_token=chars_per_token,
        content_field=args.dataset_text_field,
        fim_rate=args.fim_rate,
        fim_spm_rate=args.fim_spm_rate,
        seed=seed,
    )
    # A validation split too small to fill one full sequence yields nothing;
    # fall back to None instead of letting StopIteration abort the run.
    sample = next(iter(valid_dataset), None)
    print(f"A sample of valid dataset: {sample}")
    return train_dataset, valid_dataset
def get_mae(x, y):
    """Return the mean absolute difference between tensors ``x`` and ``y``."""
    return torch.mean(torch.abs(x - y))
def get_mse(x, y):
    """Return the mean squared difference between tensors ``x`` and ``y``."""
    difference = x - y
    return difference.pow(2).mean()
def error_report(x, y):
    """Print the mean absolute error and mean squared error between ``x`` and ``y``."""
    mae, mse = get_mae(x, y), get_mse(x, y)
    report_lines = (
        f"Mean absolute error: {mae:>8.5f}",
        f"Mean squared error: {mse:>8.5f}",
    )
    print("\n".join(report_lines))
def loftq_init(model, tokenizer, train_dataset, max_seq_length, args):
    """
    Replace the LoRA weights of ``model`` with a LoftQ initialization.

    Without ``args.use_loftq_callback`` the weights are replaced
    unconditionally. With it, an unquantized copy of the base model produces
    reference logits for one random training text, and each module's
    replacement is accepted only if it lowers the MSE between the model's
    logits and those reference logits.

    Args:
        model: The quantized PEFT model whose LoRA weights are replaced.
        tokenizer: Tokenizer used to encode the sampled text.
        train_dataset: Either an indexable dataset of examples, or a wrapper
            exposing the indexable data as ``.dataset`` and the text column as
            ``.content_field`` (as ``ConstantLengthDataset`` does).
        max_seq_length: Maximum token length of the sampled text.
        args: Model arguments; reads ``use_loftq_callback``,
            ``bnb_4bit_compute_dtype`` and ``model_name_or_path``.
    """
    if not args.use_loftq_callback:
        replace_lora_weights_loftq(model)
        return

    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
    base_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=compute_dtype)
    # NOTE(review): only the reference model is resized here; if this changes
    # the vocabulary size, its logits will not match `model`'s shape - confirm.
    base_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

    # The iterable wrapper used for training has neither a length nor item
    # access, so sample from the dataset it wraps and use its text column.
    raw_dataset = getattr(train_dataset, "dataset", train_dataset)
    text_field = getattr(train_dataset, "content_field", "content")
    random_input_ids = torch.randint(0, len(raw_dataset), size=(1,)).tolist()
    random_inputs = [raw_dataset[i][text_field] for i in random_input_ids]
    # `truncation=True` cuts to `max_length`; "max_length" is a padding
    # strategy, not a truncation strategy.
    random_inputs = tokenizer(
        random_inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_length,
    )
    # The reference logits are only compared against, never back-propagated.
    with torch.no_grad():
        logits_base = base_model(**random_inputs).logits
    del base_model
    gc.collect()

    # Put the inputs and the reference logits where the quantized model lives.
    device = next(model.parameters()).device
    random_inputs = {name: tensor.to(device) for name, tensor in random_inputs.items()}
    logits_base = logits_base.to(device)

    # Best MSE so far; starts at infinity so the first replacement is always accepted.
    current_mse = float("inf")

    def loftq_callback(model, module_name):
        """Callable to replace weights with LoFTQ if the mse is lower than the current best one."""
        nonlocal current_mse
        logits = model(**random_inputs).logits
        mse = get_mse(logits_base, logits)
        if mse < current_mse:
            current_mse = mse
            print(f"MSE improved for module {module_name}")
            return True
        print(f"MSE did not improve for module {module_name}")
        return False

    replace_lora_weights_loftq(model, callback=loftq_callback)
    with torch.no_grad():
        logits_loftq_callback = model(**random_inputs).logits
    error_report(logits_base, logits_loftq_callback)
def create_and_prepare_model(args, data_args, training_args):
    """
    Load the base model and optionally quantize it and attach LoRA adapters.

    Args:
        args: Model arguments (checkpoint, LoRA, quantization and feature flags).
        data_args: Data arguments; only ``max_seq_length`` is read, for Unsloth.
        training_args: Training arguments; ``gradient_checkpointing`` and
            ``seed`` are read.

    Returns:
        The loaded model, wrapped with LoRA adapters when ``args.use_peft_lora``
        is set. When both 4-bit and 8-bit loading are requested, 4-bit is used.
    """
    device_map = None
    bnb_config = None

    if args.use_unsloth:
        from unsloth import FastLanguageModel

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
        )

        # Only query the GPU when one is present; the hint is purely informational.
        if compute_dtype == torch.float16 and torch.cuda.is_available():
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print(
                    "Your GPU supports bfloat16, you can accelerate training with the argument --bf16"
                )
                print("=" * 80)
    elif args.use_8bit_qunatization:
        # Express 8-bit loading through the quantization config as well, so a
        # `load_in_8bit` keyword never has to be passed next to `quantization_config`.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

    if args.use_4bit_quantization or args.use_8bit_qunatization:
        device_map = (
            int(os.environ.get("LOCAL_RANK", -1))
            if torch.distributed.is_available() and torch.distributed.is_initialized()
            else "auto"
        )  # {"": 0}

    if args.use_unsloth:
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=data_args.max_seq_length,
            dtype=None,
            load_in_4bit=args.use_4bit_quantization,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=bnb_config,
            device_map=device_map,
            # NOTE(security): runs code shipped with the checkpoint; only use trusted models.
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
        )

    if (
        (args.use_4bit_quantization or args.use_8bit_qunatization)
        and args.use_peft_lora
        and not args.use_unsloth
    ):
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            # Read from the `args` parameter, not a module-level `model_args`,
            # so the function also works when called from outside this script.
            gradient_checkpointing_kwargs={"use_reentrant": args.use_reentrant},
        )

    if args.use_peft_lora:
        # "all-linear" is passed through as-is; anything else is a comma separated list.
        target_modules = args.lora_target_modules
        if target_modules != "all-linear":
            target_modules = target_modules.split(",")

        if args.use_unsloth:
            # Do model patching and add fast LoRA weights
            model = FastLanguageModel.get_peft_model(
                model,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                r=args.lora_r,
                target_modules=target_modules,
                use_gradient_checkpointing=training_args.gradient_checkpointing,
                random_state=training_args.seed,
                max_seq_length=data_args.max_seq_length,
            )
        else:
            peft_config = LoraConfig(
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                r=args.lora_r,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=target_modules,
            )
            model = get_peft_model(model, peft_config)

    return model
def main(model_args, data_args, training_args):
    """
    Run the full fine-tuning job: seed, data, model, training and final save.

    Args:
        model_args: Parsed model arguments.
        data_args: Parsed data arguments.
        training_args: Parsed `TrainingArguments`; its `gradient_checkpointing`
            fields are adjusted in place before the trainer is built.
    """
    # Reproducibility first, then tokenizer and data.
    run_seed = training_args.seed
    set_seed(run_seed)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    train_dataset, eval_dataset = create_datasets(tokenizer, data_args, run_seed)
    train_dataset.start_iteration = 0

    model = create_and_prepare_model(model_args, data_args, training_args)

    # Gradient checkpointing: the KV cache is switched off whenever it was
    # requested, and the Trainer-side flag is dropped when Unsloth is used.
    checkpointing_requested = training_args.gradient_checkpointing
    model.config.use_cache = not checkpointing_requested
    training_args.gradient_checkpointing = (
        checkpointing_requested and not model_args.use_unsloth
    )
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {
            "use_reentrant": model_args.use_reentrant
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.accelerator.print(f"{trainer.model}")
    if model_args.use_peft_lora:
        trainer.model.print_trainable_parameters()

    # LoftQ initialization when using QLoRA
    wants_loftq = model_args.use_4bit_quantization and model_args.use_loftq
    if wants_loftq:
        loftq_init(trainer.model, tokenizer, train_dataset, data_args.max_seq_length, model_args)

    # Train, resuming from a checkpoint when one was configured (None otherwise).
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

    # Save the final model; under FSDP the state dict type is set to full first.
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()
if __name__ == "__main__":
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    cli_args = sys.argv[1:]
    if len(cli_args) == 1 and cli_args[0].endswith(".json"):
        # A single argument ending in ".json" is treated as a config file
        # holding all the arguments.
        parsed = parser.parse_json_file(json_file=os.path.abspath(cli_args[0]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    # These stay module-level names, exactly as before.
    model_args, data_args, training_args = parsed
    main(model_args, data_args, training_args)
|