import logging
import os
import sys
import typing as tp

import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset, load_dataset
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    EarlyStoppingCallback,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers.trainer_utils import PredictionOutput

from data import load_alpaca
from logTrainer import LogTrainer
from lora_plus import LoraPlusTrainer, LoraPlusTrainingArguments
|
|
# Logger shared by every function in this module, named after the module.
log: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
def causalLMEncode(example, tokenizer, max_length=-1, ignore_masked_token=True):
    """Tokenize prompt/response pairs for causal-LM training.

    Each prompt ``x`` and response ``y`` are joined as ``x + " " + y + eos`` and
    tokenized together. Label positions covering the prompt are set to -100 so
    that only the response (and its EOS token) contributes to the loss.

    Args:
        example: Mapping with ``"x"`` and ``"y"`` entries, either single strings
            or equally long lists of strings (a batch).
        tokenizer: Callable tokenizer with an ``eos_token`` attribute
            (Hugging Face tokenizer API).
        max_length: Truncation length for the combined text. ``-1``, ``None``
            or any non-positive value disables truncation.
        ignore_masked_token: If True, padded positions (attention mask 0) are
            also set to -100 in the labels.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors of
        shape (batch, seq_len). A single-string example yields batch size 1.
    """
    is_list_input = isinstance(example["x"], list)
    prompts = example["x"] if is_list_input else [example["x"]]
    responses = example["y"] if is_list_input else [example["y"]]

    # Callers pass -1 (this function's default) or None (the default of
    # train_text_to_text_model) to mean "no limit".
    limit = max_length if max_length is not None and max_length > 0 else None

    encodings = tokenizer(
        [x + " " + y + tokenizer.eos_token for x, y in zip(prompts, responses)],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=limit,
    )

    # Token count of each prompt on its own; those leading positions are
    # masked out of the labels below.
    prompt_lengths = [len(ids) for ids in tokenizer(prompts)["input_ids"]]

    if limit is not None:
        too_long = [n for n in prompt_lengths if n >= limit]
        if too_long:
            log.warning(
                "%d prompt(s) have token length >= max_length (%d); longest is %d. "
                "Consider increasing max_length to avoid truncation.",
                len(too_long),
                limit,
                max(too_long),
            )

    labels = encodings["input_ids"].clone()
    for row, n_prompt in enumerate(prompt_lengths):
        labels[row, :n_prompt] = -100
    if ignore_masked_token:
        labels[encodings["attention_mask"] == 0] = -100

    return {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "labels": labels,
    }
|
|
|
|
def SeqToSeqEncode(example, tokenizer, max_length=None, ignore_masked_token=False):
    """Tokenize source/target text for encoder-decoder (seq2seq) training.

    Args:
        example: Mapping with ``"x"`` (source) and ``"y"`` (target) text, a
            single string or a list of strings each.
        tokenizer: Callable tokenizer (Hugging Face tokenizer API).
        max_length: Truncation length applied separately to source and target.
        ignore_masked_token: If True, padded target positions are set to -100
            in ``labels``.

    Returns:
        Dict with the source ``input_ids``/``attention_mask``, the target ids
        as ``labels`` and the target mask as ``decoder_attention_mask``.
    """
    shared_kwargs = dict(
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    source = tokenizer(example["x"], **shared_kwargs)
    target = tokenizer(example["y"], **shared_kwargs)

    label_ids = target["input_ids"]
    if ignore_masked_token:
        # -100 is the default ignore_index of PyTorch's CrossEntropyLoss, so
        # padded target positions drop out of the loss.
        label_ids[target["attention_mask"] == 0] = -100

    return {
        "input_ids": source["input_ids"],
        "attention_mask": source["attention_mask"],
        "labels": label_ids,
        "decoder_attention_mask": target["attention_mask"],
    }
|
|
|
|
def preprocess_dataset(
    dataset: tp.Union[
        Dataset,
        tp.List[tp.Tuple[str, str]],
        tp.List[tp.Dict[str, str]],
        tp.Dict[str, tp.List[str]],
    ]
) -> Dataset:
    """Normalize supported input formats into a ``datasets.Dataset``.

    Accepted inputs:
        * a list of ``(x, y)`` tuples -> columns ``"x"`` and ``"y"``;
        * a list of dicts -> one column per key of the first dict;
        * a dict of columns -> passed to ``Dataset.from_dict``;
        * an existing ``Dataset`` -> returned unchanged (same object);
        * an empty list -> an empty dataset with columns ``"x"`` and ``"y"``.

    Raises:
        ValueError: for any other input type or list element type.
    """
    if isinstance(dataset, list):
        if not dataset:
            # No first element to inspect; produce the standard empty x/y
            # layout instead of failing with an IndexError.
            return Dataset.from_dict({"x": [], "y": []})
        first = dataset[0]
        if isinstance(first, tuple):
            return Dataset.from_pandas(pd.DataFrame(dataset, columns=["x", "y"]))
        if isinstance(first, dict):
            return Dataset.from_dict({key: [row[key] for row in dataset] for key in first})
        raise ValueError("Wrong format")
    if isinstance(dataset, dict):
        return Dataset.from_dict(dataset)
    if isinstance(dataset, Dataset):
        return dataset
    raise ValueError("Wrong format")
|
|
|
|
def initialize_text_to_text_model(
    model_name: str,
    model_type: str,
    bf16: bool = False,
    use_peft: bool = True,
    tokenizer: tp.Optional[str] = None,
    flash_attention: bool = False,
):
    """Load a pretrained model and its tokenizer.

    Args:
        model_name: Model identifier or path passed to ``from_pretrained``.
        model_type: ``"CausalLM"`` or ``"ConditionalGeneration"``.
        bf16: Load weights as ``torch.bfloat16`` instead of ``torch.float32``.
        use_peft: When True the model is loaded with ``device_map="auto"``.
        tokenizer: Optional tokenizer identifier; defaults to ``model_name``.
        flash_attention: Request ``attn_implementation="flash_attention_2"``
            (only applied to ``"CausalLM"`` models).

    Returns:
        ``(model, tokenizer)``. If the tokenizer has no EOS token,
        ``"<|endoftext|>"`` is added and the model's embeddings are resized;
        if it has no pad token, the EOS token is reused as padding.

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    # Validate up front so an unknown type fails clearly instead of leaving
    # ``model`` unbound.
    if model_type not in ("CausalLM", "ConditionalGeneration"):
        raise ValueError(f"Wrong model type: {model_type!r}")

    load_kwargs = {
        "torch_dtype": torch.bfloat16 if bf16 else torch.float32,
        "device_map": "auto" if use_peft else None,
    }
    if model_type == "CausalLM":
        load_kwargs["trust_remote_code"] = True
        if flash_attention:
            log.info("Using flash attention 2")
            load_kwargs["attn_implementation"] = "flash_attention_2"
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **load_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer or model_name)
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"})
        model.resize_token_embeddings(len(tokenizer))
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
|
|
|
|
def compute_metrics(p: "PredictionOutput"):
    """Binary accuracy from the logits at the first decoded position.

    For each example, the logit of the gold first label token is compared
    with the logit of the "other" label token. The prediction counts as
    correct when the gold token scores strictly higher. The other token is
    derived as ``sum(unique_labels) - label``, which identifies the opposite
    class only when exactly two distinct first label tokens occur.

    Args:
        p: Trainer prediction output. ``p.predictions[0]`` is indexed as
            (batch, seq, vocab) logits and ``p.label_ids`` as (batch, seq)
            token ids. This layout is assumed; TODO confirm per model type.

    Returns:
        ``{"accuracy": float}``.
    """
    first_labels = p.label_ids[:, 0]
    unique_labels = np.unique(first_labels)
    if unique_labels.size != 2:
        log.warning(
            "compute_metrics expects exactly 2 distinct first label tokens, got %d; "
            "the reported accuracy is not meaningful.",
            unique_labels.size,
        )
    other_labels = unique_labels.sum() - first_labels

    first_step_logits = p.predictions[0][:, 0, :]
    rows = np.arange(len(first_step_logits))
    gold_logits = first_step_logits[rows, first_labels]
    other_logits = first_step_logits[rows, other_labels]

    accuracy = np.count_nonzero(gold_logits > other_logits) / len(gold_logits)
    return {"accuracy": accuracy}
|
|
|
|
def transform_dataset(model_type, tokenizer, dataset, max_length):
    """Attach the tokenizing transform that matches ``model_type`` to ``dataset``.

    ``dataset.set_transform`` is called on the object that was passed in,
    and that same object is returned.

    Raises:
        ValueError: if ``model_type`` is neither ``"CausalLM"`` nor
            ``"ConditionalGeneration"``.
    """
    if model_type == "CausalLM":
        def encode(batch):
            return causalLMEncode(batch, tokenizer, max_length)
    elif model_type == "ConditionalGeneration":
        def encode(batch):
            return SeqToSeqEncode(batch, tokenizer, max_length)
    else:
        raise ValueError("Wrong model type")

    dataset.set_transform(encode)
    return dataset
|
|
|
|
def train_text_to_text_model(
    run_name: str,
    train_dataset: Dataset,
    valid_dataset: Dataset,
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    model_type: str,
    per_device_batch_size: int = 1,
    real_batch_size: int = 32,
    max_length: tp.Optional[int] = None,
    **kwargs,
) -> torch.nn.Module:
    """Fine-tune ``model`` on text-to-text pairs with a Hugging Face trainer.

    Args:
        run_name: Used in the output directory ``./results/<run_name>/<seed>``.
            When it contains ``"llama"``, no ``compute_metrics`` is attached.
        train_dataset / valid_dataset: Any format accepted by
            ``preprocess_dataset``.
        model, tokenizer: As returned by ``initialize_text_to_text_model``.
        model_type: ``"CausalLM"`` or ``"ConditionalGeneration"``.
        per_device_batch_size: Micro-batch size per device.
        real_batch_size: Target batch size per optimizer step per device,
            reached via gradient accumulation. Must be divisible by
            ``per_device_batch_size``.
        max_length: Tokenization length limit forwarded to the encoders.
        **kwargs: Optional overrides (defaults in parentheses): ``seed`` (42),
            ``num_train_epochs`` (3), ``eval_epochs`` (1), ``logging_steps``
            (10), ``bf16`` (False), ``gradient_checkpointing`` (False),
            ``optim`` ("adamw_torch"), ``learning_rate`` (5e-5),
            ``metric_for_best_model`` ("eval_loss"), ``greater_is_better``
            (False), ``eval_accumulation_steps`` (real_batch_size),
            ``weight_decay`` (5e-4), ``warmup_ratio`` (0.03),
            ``lr_scheduler_type`` ("cosine"), ``use_loraplus`` (False),
            ``loraplus_lr_ratio`` (1.0).

    Returns:
        The ``model`` object that was passed in, after ``trainer.train()``.
    """
    train_dataset = preprocess_dataset(train_dataset)
    valid_dataset = preprocess_dataset(valid_dataset)

    assert (
        real_batch_size % per_device_batch_size == 0
    ), "real_batch_size must be divisible by per_device_batch_size"
    # per_device_batch_size * accu_step == real_batch_size
    accu_step = real_batch_size // per_device_batch_size

    # Attach the model-type-specific tokenizing transform to both splits.
    train_dataset = transform_dataset(model_type, tokenizer, train_dataset, max_length)
    valid_dataset = transform_dataset(model_type, tokenizer, valid_dataset, max_length)

    # Checkpoint interval in optimizer steps (once every `eval_epochs` epochs).
    # Clamped to >= 1: a dataset smaller than one real batch would otherwise
    # give 0, which is not a usable step interval for save_strategy="steps".
    eval_steps = max(
        1, int(len(train_dataset) * kwargs.get("eval_epochs", 1)) // real_batch_size
    )

    # One seed value for both the output directory and TrainingArguments, so
    # the directory name always matches the seed actually used.
    seed = kwargs.get("seed", 42)

    if kwargs.get("use_loraplus", False):
        TrainingArgumentsClass = LoraPlusTrainingArguments
        TrainerClass = LoraPlusTrainer
        additional_kwargs = {"loraplus_lr_ratio": kwargs.get("loraplus_lr_ratio", 1.0)}
        log.info(
            "Begin training using LoraPlusTrainer with additional kwargs: %s",
            additional_kwargs,
        )
    else:
        TrainingArgumentsClass = Seq2SeqTrainingArguments
        TrainerClass = LogTrainer
        additional_kwargs = {}
        log.info("Begin training using LogTrainer")

    training_args = TrainingArgumentsClass(
        output_dir=f"./results/{run_name}/{seed}",
        num_train_epochs=kwargs.get("num_train_epochs", 3),
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=accu_step,
        logging_dir="./logs",
        logging_steps=kwargs.get("logging_steps", 10),
        bf16=kwargs.get("bf16", False),
        gradient_checkpointing=kwargs.get("gradient_checkpointing", False),
        optim=kwargs.get("optim", "adamw_torch"),
        # Evaluation is off during training; eval_steps only sets the save cadence.
        evaluation_strategy="no",
        eval_steps=eval_steps,
        save_steps=eval_steps,
        save_strategy="steps",
        save_total_limit=1,
        load_best_model_at_end=False,
        metric_for_best_model=kwargs.get("metric_for_best_model", "eval_loss"),
        greater_is_better=kwargs.get("greater_is_better", False),
        do_eval=False,
        learning_rate=kwargs.get("learning_rate", 5e-5),
        # Presumably keeps the Trainer from dropping the raw "x"/"y" columns
        # that the on-the-fly transform reads.
        remove_unused_columns=False,
        eval_accumulation_steps=kwargs.get("eval_accumulation_steps", real_batch_size),
        label_names=["labels"],
        weight_decay=kwargs.get("weight_decay", 5e-4),
        warmup_ratio=kwargs.get("warmup_ratio", 0.03),
        lr_scheduler_type=kwargs.get("lr_scheduler_type", "cosine"),
        seed=seed,
        **additional_kwargs,
    )

    trainer = TrainerClass(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        compute_metrics=compute_metrics if "llama" not in run_name else None,
    )
    trainer.train()
    return model
|
|
|
|
def model_inference(
    model: torch.nn.Module,
    tokenizer: "AutoTokenizer",
    input_text: str,
    model_type: str,
    max_source_length: int = 768,
    max_target_length: int = 256,
):
    """Generate a completion for a single input string.

    Args:
        model: Model exposing ``generate``. Inputs are moved to
            ``model.device`` when available, otherwise to ``"cuda"``.
        tokenizer: Tokenizer matching ``model``.
        input_text: Prompt (CausalLM) or source text (ConditionalGeneration).
        model_type: ``"CausalLM"`` or ``"ConditionalGeneration"``.
        max_source_length: Prompt truncation length (CausalLM only).
        max_target_length: Maximum number of newly generated tokens.

    Returns:
        The decoded generation with special tokens removed. For CausalLM, the
        echoed prompt tokens are excluded.

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    if model_type not in ("CausalLM", "ConditionalGeneration"):
        raise ValueError(f"Wrong model type: {model_type!r}")
    device = getattr(model, "device", "cuda")

    if model_type == "CausalLM":
        # The trailing space mirrors the "x + ' ' + y" layout used by
        # causalLMEncode during training.
        encoded = tokenizer(
            input_text + " ",
            return_tensors="pt",
            max_length=max_source_length,
            truncation=True,
            return_token_type_ids=False,
        )
        inputs = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            # NOTE(review): top_p/temperature only take effect when sampling is
            # enabled (do_sample), which is not set here; confirm intent.
            outputs = model.generate(
                **inputs,
                return_dict_in_generate=True,
                output_scores=False,
                max_new_tokens=max_target_length,
                eos_token_id=tokenizer.eos_token_id,
                top_p=0.95,
                temperature=0.8,
            )
        prompt_len = len(inputs["input_ids"][0])
        # Decode only the continuation, not the prompt echoed in `sequences`.
        return tokenizer.decode(outputs.sequences[0][prompt_len:], skip_special_tokens=True)

    encoded = tokenizer(input_text, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in encoded.items()}
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_target_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
def load_peft_model(model, peft_path: str):
    """Load every PEFT adapter directory under ``peft_path`` and merge it into ``model``.

    Sub-directories whose name contains ``"merge"`` (e.g. a previously written
    ``merged_checkpoint``) are skipped, as are plain files. Each remaining
    directory is loaded with ``PeftModel.from_pretrained`` and folded in with
    ``merge_and_unload``. Adapters are applied in sorted name order, so
    repeated runs merge in the same order.

    Returns:
        The model after all adapters have been merged.
    """
    adapter_dirs = [
        f"{peft_path}/{name}"
        for name in sorted(os.listdir(peft_path))
        if "merge" not in name and os.path.isdir(os.path.join(peft_path, name))
    ]
    for adapter_dir in adapter_dirs:
        print(f"loading and merging from {adapter_dir}")
        model = PeftModel.from_pretrained(model, adapter_dir)
        model = model.merge_and_unload()
    return model
|
|
|
|
def test_train():
    """Smoke-test fine-tuning t5-small on the ``emo`` dataset and print predictions.

    Downloads the dataset and model, trains for one epoch, then prints the
    input, target and model prediction for the first ten test examples.
    """
    dataset = load_dataset("emo")
    label_map = {0: "others", 1: "happy", 2: "sad", 3: "angry"}
    dataset = dataset.map(lambda e: {"x": e["text"], "y": label_map[e["label"]]})
    train_set = dataset["train"]
    test_set = dataset["test"]
    # Snapshot raw rows before training: preprocess_dataset returns the same
    # Dataset object and transform_dataset calls set_transform on it, so after
    # training, test_set rows are tokenized tensors without "x"/"y".
    samples = [test_set[i] for i in range(10)]

    model_name = "t5-small"
    model_type = "ConditionalGeneration"
    model, tokenizer = initialize_text_to_text_model(model_name, model_type, bf16=False)

    model = train_text_to_text_model(
        "t5-small-emo",
        train_set,
        test_set,
        model,
        tokenizer,
        model_type,
        num_train_epochs=1,
        per_device_batch_size=64,
        real_batch_size=64,
    )

    for sample in samples:
        print("Input:", sample["x"])
        print("Target:", sample["y"])
        print(
            "Prediction:",
            model_inference(model, tokenizer, sample["x"], model_type),
        )
        print()
|
|
|
|
def test_llama_alpaca():
    """Print Llama-2-7B predictions for the first ten Alpaca test prompts.

    The base model is loaded with bf16=True, and the PEFT adapters found under
    a fixed results directory are merged into it before generating.
    """
    model_type = "CausalLM"
    llama, tokenizer = initialize_text_to_text_model(
        "meta-llama/Llama-2-7b-hf", model_type, True
    )
    llama = load_peft_model(llama, "results/llama-alpaca_alpaca/gradient-ArB2r-adam/0")
    _, _, test_set = load_alpaca()

    for idx in range(10):
        prompt = test_set[idx]["x"]
        print("Input:", prompt)
        prediction = model_inference(llama, tokenizer, prompt, model_type)
        print("Prediction:", prediction)
        print()
|
|
|
|
def merge_llama(peft_path):
    """Merge the PEFT adapters under ``peft_path`` into Llama-2-7B and save the result.

    The merged model and its tokenizer are written to
    ``<peft_path>/merged_checkpoint``.
    """
    save_dir = os.path.join(peft_path, "merged_checkpoint")
    llama, tokenizer = initialize_text_to_text_model(
        "meta-llama/Llama-2-7b-hf", "CausalLM", True
    )
    llama = load_peft_model(llama, peft_path)
    print("Save model to ", save_dir)
    llama.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    # Explicitly drop the large objects (they would also be released on return).
    del llama, tokenizer
|
|
|
|
if __name__ == "__main__":
    # An optional first CLI argument overrides the adapter directory to merge.
    merge_llama(
        sys.argv[1] if len(sys.argv) > 1 else "results/llama-alpaca_alpaca/default/0"
    )
|
|