import os
os.environ["WANDB_DISABLED"] = "true"
import sys
from typing import List
import argparse, logging

import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset, Dataset
import transformers
import json
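
# Fail fast if the installed transformers build does not ship the LLaMA classes.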
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

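
# Build a logger that writes to <output_dir>/log.txt and mirrors messages to the console.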
def get_logger(logger_name, output_dir):
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    os.makedirs(output_dir, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(output_dir, 'log.txt'), mode='w')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S')
    )
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S')
    )
    logger.addHandler(console_handler)
    return logger

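
# Main training entry point. All model, data, and optimization settings come from the
# JSON file passed via --model_config_file; LoRA hyperparameters come from
# --lora_hyperparams_file when --use_lora is set.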
def train(
    train_on_inputs: bool = False,
    group_by_length: bool = True,
    resume_from_checkpoint: str = None,
):
    model_config = json.load(open(args.model_config_file))
    model_type = model_config['model_type']
    model_name_or_path = model_config['model_name_or_path']
    data_path = model_config['data_path']
    output_dir = model_config['output_dir']
    cutoff_len = model_config['cutoff_len']

    logger = get_logger("train", model_config['output_dir'])
    logger.info("args.__dict__ : {}".format(args.__dict__))
    for key, value in model_config.items():
        logger.info("{} : {}".format(key, value))
    assert (
        model_name_or_path
    ), "Please specify 'model_name_or_path' in the model config file, e.g. 'decapoda-research/llama-7b-hf'"
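
    # Derive gradient_accumulation_steps from the global batch size unless it is given
    # explicitly in the config; under distributed training (WORLD_SIZE > 1), pin each
    # process to its local GPU and split the accumulation steps across processes.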
    gradient_accumulation_steps = (
        model_config['gradient_accumulation_steps']
        if "gradient_accumulation_steps" in model_config
        else model_config['batch_size'] // model_config['per_device_train_batch_size']
    )

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = max(gradient_accumulation_steps // world_size, 1)
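
    # Load the base model and tokenizer. With --use_lora the weights are loaded in 8-bit
    # to reduce memory use; LLaMA checkpoints use the dedicated Llama* classes, other
    # model types go through the Auto* classes with trust_remote_code enabled.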
    load_in_8bit = args.use_lora
    if model_type.lower() == "llama":
        model = LlamaForCausalLM.from_pretrained(
            model_name_or_path,
            load_in_8bit=load_in_8bit,
            device_map=device_map,
        )
        tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            load_in_8bit=load_in_8bit,
            device_map=device_map,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)

    tokenizer.pad_token_id = 0  # pad with token id 0 so padding stays distinct from eos
    tokenizer.padding_side = "left"
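
    # Tokenize a raw prompt: truncate to cutoff_len + 1 tokens, then drop the final
    # token so the result is at most cutoff_len long.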
    def tokenize(prompt):
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len + 1,
            padding=False,
        )
        return {
            "input_ids": result["input_ids"][:-1],
            "attention_mask": result["attention_mask"][:-1],
        }

    def generate_and_tokenize_prompt(data_point):
        return tokenize(data_point["input"])
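
    # Optionally wrap the model for LoRA training: prepare it for int8 training and
    # attach the adapter described in the LoRA hyperparameter file.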
    if args.use_lora:
        model = prepare_model_for_int8_training(model)
        lora_hyperparams = json.load(open(args.lora_hyperparams_file))
        for key, value in lora_hyperparams.items():
            logger.info("{} : {}".format(key, value))
        config = LoraConfig(
            r=lora_hyperparams['lora_r'],
            lora_alpha=lora_hyperparams['lora_alpha'],
            target_modules=(
                lora_hyperparams['lora_target_modules']
                if model_type.lower() == "llama"
                else ["query_key_value"]
            ),
            lora_dropout=lora_hyperparams['lora_dropout'],
            bias="none",
            task_type="CAUSAL_LM",
        )
        print(config)
        model = get_peft_model(model, config)
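
    # Load the training data; each JSON record must provide an "input" field. When
    # val_set_size > 0, a held-out validation split is carved out of the training set.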
    data = load_dataset("json", data_files=data_path)
    print(data)

    val_set_size = model_config['val_set_size']
    if val_set_size > 0:
        val_set_size = min(val_set_size, int(len(data['train']) * model_config['val_set_rate']))
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None
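
    # Build the Trainer; most TrainingArguments come straight from the model config.
    # The causal-LM data collator derives labels from input_ids, so tokenization does
    # not need to emit labels explicitly.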
| print("start train...") |
| trainer = transformers.Trainer( |
| model=model, |
| train_dataset=train_data, |
| eval_dataset=val_data, |
| args=transformers.TrainingArguments( |
| per_device_train_batch_size=model_config['per_device_train_batch_size'], |
| gradient_accumulation_steps=gradient_accumulation_steps, |
| warmup_steps=model_config['warmup_steps'], |
| num_train_epochs=model_config['num_epochs'], |
| learning_rate=model_config['learning_rate'], |
| fp16=True, |
| logging_steps=model_config['logging_steps'], |
| evaluation_strategy="steps" if val_set_size > 0 else "no", |
| save_strategy="no", |
| eval_steps=model_config["eval_steps"] if val_set_size > 0 else None, |
| save_steps=model_config["save_steps"], |
| output_dir=output_dir, |
| save_total_limit=3, |
| load_best_model_at_end=False, |
| ddp_find_unused_parameters=False if ddp else None, |
| deepspeed=args.deepspeed if not args.use_lora else None, |
| group_by_length=group_by_length |
| ), |
| data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), |
| ) |
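
    # Disable the KV cache for training. With LoRA, patch state_dict so that only the
    # adapter weights are collected when the model is saved.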
    model.config.use_cache = False
    if args.use_lora:
        old_state_dict = model.state_dict
        model.state_dict = (
            lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
        ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
    print("trainer.train")
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    logger.info("Saving checkpoint...")

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print("\n If there's a warning about missing keys above when using LoRA to train, please disregard :)")
    logger.info("Training succeeded")

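
# Example invocation (illustrative only; the script name and config paths are placeholders):
#   python train.py --model_config_file configs/model_config.json \
#       --use_lora --lora_hyperparams_file configs/lora_hyperparams.json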
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--model_config_file", type=str, required=True) |
| parser.add_argument("--deepspeed", type=str, help="deepspeed config") |
| parser.add_argument("--resume_from_checkpoint", action="store_true", default=False) |
| parser.add_argument("--lora_hyperparams_file", default="", type=str, help="Provide it when use_lora=True") |
| parser.add_argument("--use_lora", action="store_true", default=False, help="Use lora") |
| parser.add_argument("--local_rank", type=int) |
| args = parser.parse_args() |
| fire.Fire(train) |