| |
| import os |
| from typing import List |
|
|
| import fire |
| import torch |
| import transformers |
| from datasets import load_dataset |
| from transformers import BertTokenizerFast |
|
|
| """ |
| Unused imports: |
| import torch.nn as nn |
| import bitsandbytes as bnb |
| """ |
|
|
| from peft import ( |
|     LoraConfig, |
|     get_peft_model, |
| ) |
| from transformers import LlamaForCausalLM, LlamaTokenizer |
|
|
| from utils.prompter import Prompter |
|
|
| def train( |
| |
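|     # model/data params |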
| base_model: str = "", |
| data_path: str = "", |
| output_dir: str = "", |
| |
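|     # training hyperparams |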
| batch_size: int = 128, |
| micro_batch_size: int = 4, |
| num_epochs: int = 3, |
| learning_rate: float = 3e-4, |
| cutoff_len: int = 256, |
| val_set_size: int = 2000, |
| |
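|     # LoRA hyperparams |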
| lora_r: int = 8, |
| lora_alpha: int = 16, |
| lora_dropout: float = 0.05, |
| lora_target_modules: List[str] = [ |
| "q_proj", |
| "v_proj", |
| ], |
| |
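|     # LLM hyperparams |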
| train_on_inputs: bool = True, |
| add_eos_token: bool = False, |
| group_by_length: bool = False, |
| |
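|     # wandb, checkpointing, and module-selection params |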
| wandb_project: str = "gama", |
| wandb_run_name: str = "", |
| wandb_watch: str = "false", |
| wandb_log_model: str = "false", |
| resume_from_checkpoint: str = None, |
| prompt_template_name: str = "alpaca_short", |
| save_steps: int = 100, |
|     trainable_params: str = 'all', |
| ): |
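|     """LoRA fine-tuning entry point for a GAMA-style LLaMA model with an audio Q-Former branch, |
|     following the Alpaca-LoRA training recipe (prompt templating, LoRA adapters, HF Trainer).""" |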
| if int(os.environ.get("LOCAL_RANK", 0)) == 0: |
| print( |
| f"Training Alpaca-LoRA model with params:\n" |
| f"base_model: {base_model}\n" |
| f"data_path: {data_path}\n" |
| f"output_dir: {output_dir}\n" |
| f"batch_size: {batch_size}\n" |
| f"micro_batch_size: {micro_batch_size}\n" |
| f"num_epochs: {num_epochs}\n" |
| f"learning_rate: {learning_rate}\n" |
| f"cutoff_len: {cutoff_len}\n" |
| f"val_set_size: {val_set_size}\n" |
| f"lora_r: {lora_r}\n" |
| f"lora_alpha: {lora_alpha}\n" |
| f"lora_dropout: {lora_dropout}\n" |
| f"lora_target_modules: {lora_target_modules}\n" |
| f"train_on_inputs: {train_on_inputs}\n" |
| f"add_eos_token: {add_eos_token}\n" |
| f"group_by_length: {group_by_length}\n" |
| f"wandb_project: {wandb_project}\n" |
| f"wandb_run_name: {wandb_run_name}\n" |
| f"wandb_watch: {wandb_watch}\n" |
| f"wandb_log_model: {wandb_log_model}\n" |
| f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" |
| f"prompt template: {prompt_template_name}\n" |
| ) |
| assert ( |
| base_model |
| ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'" |
|
|
| |
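|     # NOTE: hard-coded path to the Q-Former-augmented LLaMA checkpoint; environment-specific. |
|     # The model is first instantiated from this checkpoint, and the user-supplied base_model |
|     # weights (if different) are loaded on top after PEFT wrapping (see below). |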
| if '/fs/nexus-projects/brain_project/acl_sk_24/GAMA/src/Llama-2-7b-chat-hf-qformer' not in base_model: |
| |
| |
| start_model = base_model |
| |
| base_model = '/fs/nexus-projects/brain_project/acl_sk_24/GAMA/src/Llama-2-7b-chat-hf-qformer' |
|         print('Will load weights from {:s} later; for implementation purposes, first loading from {:s}'.format(start_model, base_model)) |
| else: |
| start_model = None |
|
|
| gradient_accumulation_steps = batch_size // micro_batch_size |
| prompter = Prompter(prompt_template_name) |
|
|
| device_map = "auto" |
|     # Detect distributed training from the launcher-provided WORLD_SIZE rather than the GPU count. |
|     world_size = int(os.environ.get("WORLD_SIZE", 1)) |
|     ddp = world_size != 1 |
| if ddp: |
| device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} |
| gradient_accumulation_steps = gradient_accumulation_steps // world_size |
|
|
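|     # Only enable wandb reporting when a project name is provided via argument or environment. |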
| use_wandb = len(wandb_project) > 0 or ( |
| "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 |
| ) |
| |
| if len(wandb_project) > 0: |
| os.environ["WANDB_PROJECT"] = wandb_project |
| if len(wandb_watch) > 0: |
| os.environ["WANDB_WATCH"] = wandb_watch |
| if len(wandb_log_model) > 0: |
| os.environ["WANDB_LOG_MODEL"] = wandb_log_model |
| |
| |
|
|
| model = LlamaForCausalLM.from_pretrained( |
| base_model, |
| load_in_8bit=False, |
| |
| device_map=device_map, |
| ) |
|
|
| tokenizer = LlamaTokenizer.from_pretrained(base_model) |
| |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
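|     # Left-pad batches, which is required for correct batched generation with causal LMs. |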
| tokenizer.padding_side = "left" |
|
|
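|     # BERT tokenizer; loaded here but not used elsewhere in this script. |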
| bert_tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased") |
|
|
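|     # Tokenize a prompt up to cutoff_len, optionally appending EOS, and mirror input_ids into labels. |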
| def tokenize(prompt, add_eos_token=True): |
| result = tokenizer( |
| prompt, |
| truncation=True, |
| max_length=cutoff_len, |
| padding=False, |
| return_tensors=None, |
| ) |
| if ( |
| result["input_ids"][-1] != tokenizer.eos_token_id |
| and len(result["input_ids"]) < cutoff_len |
| and add_eos_token |
| ): |
| result["input_ids"].append(tokenizer.eos_token_id) |
| result["attention_mask"].append(1) |
|
|
| result["labels"] = result["input_ids"].copy() |
| return result |
| |
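|     # Merge a nested 'tokenized_full_prompt' dict back into the top-level example; not used below. |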
| def flatten_c(example): |
| if 'tokenized_full_prompt' in example: |
| example.update(example['tokenized_full_prompt']) |
| del example['tokenized_full_prompt'] |
| return example |
|
|
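|     # Build the full Alpaca-style prompt from a data point and tokenize it; when |
|     # train_on_inputs is False, mask the user-prompt tokens with -100 so loss is |
|     # computed only on the response tokens. |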
| def generate_and_tokenize_prompt(data_point): |
| |
| full_prompt = prompter.generate_prompt( |
| data_point["instruction"], |
| data_point["input"], |
| data_point["output"] |
| ) |
| tokenized_full_prompt = tokenize(full_prompt) |
| if not train_on_inputs: |
| user_prompt = prompter.generate_prompt( |
| data_point["instruction"], data_point["input"] |
| ) |
| tokenized_user_prompt = tokenize( |
| user_prompt, add_eos_token=add_eos_token |
| ) |
| user_prompt_len = len(tokenized_user_prompt["input_ids"]) |
|
|
| if add_eos_token: |
| user_prompt_len -= 1 |
|
|
| tokenized_full_prompt["labels"] = [ |
| -100 |
| ] * user_prompt_len + tokenized_full_prompt["labels"][ |
| user_prompt_len: |
| ] |
| |
| return tokenized_full_prompt |
| |
|
|
|
|
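|     # Attach LoRA adapters to the configured target modules (q_proj/v_proj by default). |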
| config = LoraConfig( |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| target_modules=lora_target_modules, |
| lora_dropout=lora_dropout, |
| bias="none", |
| task_type="CAUSAL_LM", |
| ) |
| model = get_peft_model(model, config) |
|
|
| |
|
|
| |
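|     # Un-freeze the requested audio-branch parameters on top of the LoRA adapters, |
|     # selected by trainable_params ('all', 'proj', 'qformer', or 'qformer_all'). |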
|     audio_qformer_modules = ( |
|         "audio_aggregator_layer_1", "audio_aggregator_layer_2", |
|         "audio_proj_qformer", "audio_proj_audioenc", |
|         "audio_proj_norm_qformer", "audio_proj_norm_audioenc", |
|     ) |
|     audio_qformer_all_modules = audio_qformer_modules + ( |
|         "audio_encoder", "Qformer", "query_tokens", "qformer_proj_norm", |
|     ) |
|     for name, param in model.named_parameters(): |
|         if trainable_params == 'all' and "audio" in name: |
|             param.requires_grad = True |
|         elif trainable_params == 'proj' and "audio_proj" in name: |
|             param.requires_grad = True |
|         elif trainable_params == 'qformer' and any(m in name for m in audio_qformer_modules): |
|             param.requires_grad = True |
|         elif trainable_params == 'qformer_all' and any(m in name for m in audio_qformer_all_modules): |
|             param.requires_grad = True |
|
|
| if data_path.endswith(".json") or data_path.endswith(".jsonl"): |
| data = load_dataset("json", data_files=data_path) |
| else: |
| data = load_dataset(data_path) |
|
|
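|     # When resuming, prefer a full training checkpoint; otherwise fall back to loading only |
|     # the adapter weights, in which case the Trainer does not resume its own optimizer state. |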
| if resume_from_checkpoint: |
| |
| checkpoint_name = os.path.join( |
| resume_from_checkpoint, "pytorch_model.bin" |
| ) |
| if not os.path.exists(checkpoint_name): |
| checkpoint_name = os.path.join( |
| resume_from_checkpoint, "adapter_model.bin" |
| ) |
|             resume_from_checkpoint = False  # so the Trainer won't try to resume its own state |
| |
| if os.path.exists(checkpoint_name): |
| state_dict = torch.load(checkpoint_name, map_location='cpu') |
| msg = model.load_state_dict(state_dict, strict=False) |
| else: |
| print(f"Checkpoint {checkpoint_name} not found") |
|
|
| |
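|     # If a different base_model path was supplied, treat it as a state-dict file and load it non-strictly now. |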
|     if start_model is not None and (resume_from_checkpoint is None or resume_from_checkpoint is False): |
| state_dict = torch.load(start_model, map_location='cpu') |
| msg = model.load_state_dict(state_dict, strict=False) |
| |
|
|
| model.print_trainable_parameters() |
|
|
| if val_set_size > 0: |
| train_val = data["train"].train_test_split( |
| test_size=val_set_size, shuffle=True, seed=42 |
| ) |
| train_data = ( |
| train_val["train"].shuffle().map(generate_and_tokenize_prompt) |
| ) |
| val_data = ( |
| train_val["test"].shuffle().map(generate_and_tokenize_prompt) |
| ) |
| else: |
| train_data = data["train"].shuffle().map(generate_and_tokenize_prompt) |
| val_data = None |
|
|
| |
| |
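|     # With multiple GPUs but no DDP launcher, fall back to naive model parallelism. |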
| if not ddp and torch.cuda.device_count() > 1: |
| |
| model.is_parallelizable = True |
| model.model_parallel = True |
| |
| from transformers import TrainerCallback |
|
|
|     class PrecisionLoggingCallback(TrainerCallback): |
|         def on_log(self, args, state, control, logs=None, **kwargs): |
|             # Re-print the training loss with extra decimal places on top of the default logs. |
|             if logs is not None and 'loss' in logs: |
|                 high_precision_loss = format(logs['loss'], '.10f') |
|                 print(f"step {state.global_step}: loss = {high_precision_loss}") |
| |
|
|
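|     # Hugging Face Trainer with a seq2seq collator that dynamically pads batches to a multiple of 8. |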
| trainer = transformers.Trainer( |
| model=model, |
| train_dataset=train_data, |
| eval_dataset=val_data, |
|         callbacks=[PrecisionLoggingCallback()], |
| args=transformers.TrainingArguments( |
| per_device_train_batch_size=micro_batch_size, |
| gradient_accumulation_steps=gradient_accumulation_steps, |
| warmup_steps=100, |
| num_train_epochs=num_epochs, |
| learning_rate=learning_rate, |
| bf16=True, |
| logging_steps=10, |
| optim="adamw_torch", |
| evaluation_strategy="no", |
| save_strategy="steps", |
| eval_steps=None, |
| save_steps=save_steps, |
| dataloader_num_workers=8, |
| output_dir=output_dir, |
| save_total_limit=50, |
| load_best_model_at_end=False, |
| ddp_find_unused_parameters=True, |
| group_by_length=group_by_length, |
| report_to="wandb" if use_wandb else None, |
| run_name=wandb_run_name if use_wandb else None, |
|             remove_unused_columns=False, |
|         ), |
| data_collator=transformers.DataCollatorForSeq2Seq( |
| tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True |
| ), |
| ) |
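|     # The KV cache is only useful for generation; disable it during training. |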
| model.config.use_cache = False |
| trainer.train(resume_from_checkpoint=resume_from_checkpoint) |
|
|
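|     # Save the PEFT-wrapped model; for a PeftModel this writes the LoRA adapter weights to output_dir. |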
| model.save_pretrained(output_dir) |
|
|
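| # Example launch via python-fire (file name and paths below are illustrative, not repo defaults): |
| #   python finetune.py --base_model='huggyllama/llama-7b' --data_path='./train.json' --output_dir='./output' |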
| if __name__ == "__main__": |
| fire.Fire(train) |