CDKA / utils.py
rainstonee's picture
Upload 33 files
c5d3e8d verified
Raw
History Blame Contribute Delete
15.1 kB
import torch
import os
import typing as tp
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
EarlyStoppingCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from transformers.trainer_utils import PredictionOutput
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from lora_plus import LoraPlusTrainingArguments, LoraPlusTrainer
from logTrainer import LogTrainer
import logging
import wandb
from peft import PeftModel
from data import load_alpaca
log = logging.getLogger(__name__)
def causalLMEncode(example, tokenizer, max_length=-1, ignore_masked_token=True):
    """Tokenize prompt/answer text for causal-LM fine-tuning.

    The prompt ``example["x"]`` and the answer ``example["y"]`` are joined as
    ``x + " " + y + eos_token`` and tokenized together. Labels are a copy of
    the input ids in which the prompt positions (and, optionally, the padded
    positions) are replaced by ``-100``.

    Args:
        example: Mapping with keys ``"x"`` and ``"y"``. Each value is either a
            single string or a list of strings (a batch).
        tokenizer: Callable tokenizer exposing ``eos_token``.
        max_length: Truncation length in tokens. Both ``-1`` and ``None`` mean
            "no explicit limit".
        ignore_masked_token: When true, positions whose attention mask is 0
            are also set to ``-100`` in the labels.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
    """
    is_list_input = isinstance(example["x"], list)
    # Normalise the two "no limit" spellings so later comparisons are safe.
    limit = None if max_length is None or max_length == -1 else max_length

    # Combine text and add EOS token
    if is_list_input:
        combined_text = [
            x + " " + y + tokenizer.eos_token
            for x, y in zip(example["x"], example["y"])
        ]
    else:
        combined_text = example["x"] + " " + example["y"] + tokenizer.eos_token

    # Tokenize combined text
    encodings = tokenizer(
        combined_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=limit,
    )

    # Prompt length in tokens, one entry per row (a single example is row 0).
    prompts = example["x"] if is_list_input else [example["x"]]
    prompt_lengths = [
        len(tokenizer(prompt, return_tensors="pt")["input_ids"][0])
        for prompt in prompts
    ]

    # Warn only when a real limit is set, and look at every row of the batch.
    if limit is not None and any(length >= limit for length in prompt_lengths):
        input_text_length = prompt_lengths if is_list_input else prompt_lengths[0]
        log.warning(
            f"Input text length >= max_length: {input_text_length} >= {max_length}. "
            "Consider increasing max_length to avoid truncation."
        )

    # Create labels: the loss must not be computed on the prompt tokens.
    # NOTE(review): this masks the first `length` positions of each row, which
    # presumes the prompt tokens sit at the start (right padding) - confirm
    # against the tokenizer's padding side.
    labels = encodings["input_ids"].clone()
    for row, length in enumerate(prompt_lengths):
        labels[row, :length] = -100
    if ignore_masked_token:
        labels[encodings["attention_mask"] == 0] = -100

    return {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "labels": labels,
    }
def SeqToSeqEncode(example, tokenizer, max_length=None, ignore_masked_token=False):
    """Tokenize source and target text separately for an encoder-decoder model.

    Args:
        example: Mapping with keys ``"x"`` (source text) and ``"y"`` (target
            text).
        tokenizer: Callable tokenizer returning ``"input_ids"`` and
            ``"attention_mask"``.
        max_length: Truncation length forwarded to both tokenizer calls.
        ignore_masked_token: When true, target positions whose attention mask
            is 0 are overwritten with ``-100`` in the labels.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"``, ``"labels"`` and
        ``"decoder_attention_mask"``.
    """
    shared_options = dict(
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    source = tokenizer(example["x"], **shared_options)
    target = tokenizer(example["y"], **shared_options)

    target_ids = target["input_ids"]
    target_mask = target["attention_mask"]
    if ignore_masked_token:
        # In-place on the tokenizer's own tensor, which is also returned as labels.
        target_ids[target_mask == 0] = -100

    return {
        "input_ids": source["input_ids"],
        "attention_mask": source["attention_mask"],
        "labels": target_ids,
        "decoder_attention_mask": target_mask,
    }
def preprocess_dataset(
    dataset: tp.Union[Dataset, tp.List[tp.Tuple[str, str]], tp.List[tp.Dict[str, str]]]
) -> Dataset:
    """Convert the supported input containers into a ``Dataset``.

    Accepted inputs, as checked by the code:
      * a list whose first element is a tuple - rows become columns ``x``/``y``;
      * a list whose first element is a dict - keys of the first dict become
        the columns;
      * a dict - passed to ``Dataset.from_dict``;
      * a ``Dataset`` - returned unchanged.

    Raises:
        ValueError: For any other input ("Wrong format").
    """
    if isinstance(dataset, list):
        head = dataset[0]
        if isinstance(head, tuple):
            frame = pd.DataFrame(dataset, columns=["x", "y"])
            return Dataset.from_pandas(frame)
        if isinstance(head, dict):
            columns = {key: [row[key] for row in dataset] for key in head}
            return Dataset.from_dict(columns)
        raise ValueError("Wrong format")

    if isinstance(dataset, dict):
        return Dataset.from_dict(dataset)

    if isinstance(dataset, Dataset):
        return dataset

    raise ValueError("Wrong format")
def initialize_text_to_text_model(
    model_name: str,
    model_type: str,
    bf16: bool,
    use_peft: bool = True,
    tokenizer: tp.Optional[str] = None,
    flash_attention: bool = False,
):
    """Load a pretrained model and its tokenizer.

    Args:
        model_name: Name or path passed to ``from_pretrained``.
        model_type: ``"CausalLM"`` or ``"ConditionalGeneration"``.
        bf16: Load weights as ``torch.bfloat16`` when true, else
            ``torch.float32``.
        use_peft: When true the model is loaded with ``device_map="auto"``;
            otherwise no device map is passed.
        tokenizer: Optional tokenizer name or path; falls back to
            ``model_name`` when empty.
        flash_attention: For ``"CausalLM"`` only, request the
            ``"flash_attention_2"`` attention implementation.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # Fail fast: otherwise `model` would be unbound further down and the error
    # would only surface after the tokenizer has been loaded.
    if model_type not in ("CausalLM", "ConditionalGeneration"):
        raise ValueError("Wrong model type")

    load_kwargs = {
        "torch_dtype": torch.bfloat16 if bf16 else torch.float32,
        "device_map": "auto" if use_peft else None,
    }
    if model_type == "CausalLM":
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; only use with trusted model sources.
        load_kwargs["trust_remote_code"] = True
        if flash_attention:
            log.info("Using flash attention 2")
            load_kwargs["attn_implementation"] = "flash_attention_2"
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **load_kwargs)

    loaded_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer if tokenizer else model_name
    )
    if loaded_tokenizer.eos_token is None:
        # A new token enlarges the vocabulary, so the embeddings must follow.
        loaded_tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"})
        model.resize_token_embeddings(len(loaded_tokenizer))
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    return model, loaded_tokenizer
def compute_metrics(p: "PredictionOutput", exact_match: bool = False):
    """Compute accuracy from a prediction output.

    Args:
        p: Object exposing ``predictions`` and ``label_ids``. The first
            element of ``predictions`` is indexed as
            ``(batch, seq_len, vocab)``, and ``label_ids`` as
            ``(batch, seq_len)``.
        exact_match: Selects the metric.
            * ``False`` (default): soft metric. Only the first target
              position is used; a sample is correct when the score of its
              label token exceeds the score of the "other" label token. The
              other label is derived as ``sum(unique labels) - label``, so
              this only makes sense for binary tasks with single-token labels.
            * ``True``: hard metric. A sample is correct only when the
              arg-max token at every position equals the label row.

    Returns:
        ``{"accuracy": value}``.
    """
    predictions = p.predictions
    label_ids = p.label_ids  # shape (batch_size, seq_len)

    if exact_match:
        # Hard metric: the model must output exactly the same as the target.
        pred = np.argmax(predictions[0], axis=-1)
        num_correct = sum(
            np.array_equal(pred_row, label_row)
            for pred_row, label_row in zip(pred, label_ids)
        )
        return {"accuracy": num_correct / len(pred)}

    # Soft metric: restrict the output space to the two label tokens.
    first_labels = label_ids[:, 0]  # first target token only
    unique_labels = np.unique(first_labels)
    flipped_labels = np.ones_like(first_labels) * unique_labels.sum() - first_labels
    first_scores = predictions[0][:, 0, :]  # (batch, vocab) at the first position
    rows = np.arange(len(first_scores))
    label_prob = first_scores[rows, first_labels]
    flipped_label_prob = first_scores[rows, flipped_labels]
    num_correct = int(np.sum(label_prob > flipped_label_prob))
    return {"accuracy": num_correct / len(label_prob)}
def transform_dataset(model_type, tokenizer, dataset, max_length):
    """Attach the on-the-fly tokenization transform matching ``model_type``.

    ``"CausalLM"`` uses ``causalLMEncode`` and ``"ConditionalGeneration"``
    uses ``SeqToSeqEncode``; both receive ``tokenizer`` and ``max_length``.
    The transform is installed through ``dataset.set_transform`` and the same
    dataset object is returned.

    Raises:
        ValueError: For any other ``model_type`` ("Wrong model type").
    """
    if model_type not in ("CausalLM", "ConditionalGeneration"):
        raise ValueError("Wrong model type")

    if model_type == "CausalLM":

        def encode(batch):
            return causalLMEncode(batch, tokenizer, max_length)

    else:

        def encode(batch):
            return SeqToSeqEncode(batch, tokenizer, max_length)

    dataset.set_transform(encode)
    return dataset
def train_text_to_text_model(
    run_name: str,
    train_dataset: Dataset,
    valid_dataset: Dataset,
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    model_type: str,
    per_device_batch_size: int = 1,
    real_batch_size: int = 32,
    max_length: int = None,
    **kwargs,
) -> torch.nn.Module:
    """Fine-tune ``model`` on ``train_dataset`` and return it.

    Args:
        run_name: Used in the output directory (``./results/<run_name>/<seed>``).
            Metrics are disabled when it contains ``"llama"``.
        train_dataset: Training data; anything ``preprocess_dataset`` accepts.
        valid_dataset: Validation data; anything ``preprocess_dataset`` accepts.
        model: Model to train.
        tokenizer: Tokenizer used by the on-the-fly encoding transform.
        model_type: ``"CausalLM"`` or ``"ConditionalGeneration"``.
        per_device_batch_size: Per-device train and eval batch size.
        real_batch_size: Effective batch size; reached through gradient
            accumulation, so it must be a multiple of ``per_device_batch_size``.
        max_length: Truncation length forwarded to the encoding transform.
        **kwargs: Optional settings read with defaults: ``eval_epochs`` (1),
            ``use_loraplus`` (False), ``loraplus_lr_ratio`` (1.0),
            ``seed`` (42), ``num_train_epochs`` (3), ``logging_steps`` (10),
            ``bf16`` (False), ``gradient_checkpointing`` (False),
            ``optim`` ("adamw_torch"), ``metric_for_best_model`` ("eval_loss"),
            ``greater_is_better`` (False), ``learning_rate`` (5e-5),
            ``eval_accumulation_steps`` (``real_batch_size``).

    Returns:
        The trained model (the same object that was passed in).

    Raises:
        ValueError: If ``real_batch_size`` is not divisible by
            ``per_device_batch_size``.
    """
    # Preprocess the dataset
    train_dataset = preprocess_dataset(train_dataset)
    valid_dataset = preprocess_dataset(valid_dataset)

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if real_batch_size % per_device_batch_size != 0:
        raise ValueError("real_batch_size must be divisible by per_device_batch_size")
    accu_step = real_batch_size // per_device_batch_size

    train_dataset = transform_dataset(model_type, tokenizer, train_dataset, max_length)
    valid_dataset = transform_dataset(model_type, tokenizer, valid_dataset, max_length)

    eval_steps = (
        int(len(train_dataset) * kwargs.get("eval_epochs", 1)) // real_batch_size
    )

    # Resolve the seed once so the output directory and the training
    # arguments agree (previously the directory became ".../None" when no
    # seed was given while training used 42).
    seed = kwargs.get("seed", 42)

    # Special for LoRA+
    use_loraplus = kwargs.get("use_loraplus", False)
    TrainingArgumentsClass = (
        LoraPlusTrainingArguments if use_loraplus else Seq2SeqTrainingArguments
    )
    TrainerClass = LoraPlusTrainer if use_loraplus else LogTrainer
    if use_loraplus:
        additional_kwargs = {
            "loraplus_lr_ratio": kwargs.get("loraplus_lr_ratio", 1.0),
        }
        log.info(
            f"Begin training using LoraPlusTrainer with additional kwargs: {additional_kwargs}"
        )
    else:
        additional_kwargs = {}
        log.info("Begin training using Seq2SeqTrainer")

    # Training arguments
    output_dir = f"./results/{run_name}/{seed}"
    training_args = TrainingArgumentsClass(
        output_dir=output_dir,  # output directory
        num_train_epochs=kwargs.get(
            "num_train_epochs", 3
        ),  # total number of training epochs
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=accu_step,
        logging_dir="./logs",  # directory for storing logs
        logging_steps=kwargs.get("logging_steps", 10),  # when to print log
        bf16=kwargs.get("bf16", False),
        gradient_checkpointing=kwargs.get("gradient_checkpointing", False),
        optim=kwargs.get("optim", "adamw_torch"),
        evaluation_strategy="no",
        eval_steps=eval_steps,
        save_steps=eval_steps,
        save_strategy="steps",
        save_total_limit=1,  # No need for saving
        load_best_model_at_end=False,
        metric_for_best_model=kwargs.get("metric_for_best_model", "eval_loss"),
        greater_is_better=kwargs.get("greater_is_better", False),
        do_eval=False,
        learning_rate=kwargs.get("learning_rate", 5e-5),
        remove_unused_columns=False,  # We tokenize the dataset on the fly
        eval_accumulation_steps=kwargs.get("eval_accumulation_steps", real_batch_size),
        label_names=[
            "labels"
        ],  # Peft are not compatible with HF's default label names yet
        # Ref: https://discuss.huggingface.co/t/eval-with-trainer-not-running-with-peft-lora-model/53286
        # weight_decay = 0, # No weight decay
        weight_decay=5e-4,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        seed=seed,
        **additional_kwargs,
    )
    trainer = TrainerClass(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        compute_metrics=compute_metrics if "llama" not in run_name else None,
        # callbacks=[
        #     EarlyStoppingCallback(
        #         early_stopping_patience=kwargs.get("early_stopping_patience", 1)
        #     ),
        # ],
    )
    trainer.train()
    # eval_results = trainer.evaluate()
    # eval_accuracy = eval_results.get("eval_accuracy", 0)
    # print(f"FINAL_EVAL_ACCURACY: {eval_accuracy:.4f}")
    return model
def model_inference(
    model: torch.nn.Module,
    tokenizer: "AutoTokenizer",
    input_text: str,
    model_type: str,
    max_source_length: int = 768,
    max_target_length: int = 256,
):
    """Generate a text prediction for a single input string.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        input_text: Prompt text.
        model_type: ``"CausalLM"`` or ``"ConditionalGeneration"``.
        max_source_length: Prompt truncation length in tokens; applied on the
            ``"CausalLM"`` path only.
        max_target_length: Passed to ``generate`` as ``max_new_tokens``.

    Returns:
        The decoded prediction. For ``"CausalLM"`` the prompt tokens are
        stripped from the generated sequence before decoding.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type not in ("CausalLM", "ConditionalGeneration"):
        raise ValueError("Wrong model type")

    # Follow the model's own placement instead of hard-coding CUDA so that
    # CPU-only inference works too.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    if model_type == "CausalLM":
        # The trailing space mirrors the "x + ' ' + y" layout used in training.
        inputs = tokenizer(
            input_text + " ",
            return_tensors="pt",
            max_length=max_source_length,
            truncation=True,
            return_token_type_ids=False,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_length = len(inputs["input_ids"][0])
        with torch.no_grad():
            # NOTE(review): top_p/temperature are forwarded without
            # do_sample=True; confirm whether sampling was intended.
            outputs = model.generate(
                **inputs,
                return_dict_in_generate=True,
                output_scores=False,
                max_new_tokens=max_target_length,
                eos_token_id=tokenizer.eos_token_id,
                top_p=0.95,
                temperature=0.8,
            )
        return tokenizer.decode(
            outputs.sequences[0][prompt_length:],
            skip_special_tokens=True,
        )

    # ConditionalGeneration
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_target_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def load_peft_model(model, peft_path: str):
    """Load every adapter directory under ``peft_path`` and merge it into ``model``.

    Sub-directories whose name contains ``"merge"`` are skipped, as are plain
    files (only directories can hold an adapter checkpoint). Directories are
    processed in sorted name order so repeated runs behave the same way.

    Args:
        model: Base model to merge the adapters into.
        peft_path: Directory containing one sub-directory per adapter.

    Returns:
        The model after all adapters were merged; the input model unchanged
        if no adapter directory was found.
    """
    adapter_dirs = sorted(
        f"{peft_path}/{entry}"
        for entry in os.listdir(peft_path)
        if "merge" not in entry and os.path.isdir(os.path.join(peft_path, entry))
    )
    for adapter_dir in adapter_dirs:
        print(f"loading and merging from {adapter_dir}")
        peft_model: PeftModel = PeftModel.from_pretrained(model, adapter_dir)
        model = peft_model.merge_and_unload()
    return model
def test_train():
    """Smoke-test training and inference with t5-small on the ``emo`` dataset.

    Maps the dataset to ``x``/``y`` text pairs, trains for one epoch and
    prints the predictions for the first ten test examples.
    """
    # Example usage using emo dataset
    dataset = load_dataset("emo")
    label_map = {0: "others", 1: "happy", 2: "sad", 3: "angry"}
    dataset = dataset.map(lambda e: {"x": e["text"], "y": label_map[e["label"]]})
    train_set = dataset["train"]
    test_set = dataset["test"]
    model_name = "t5-small"
    model_type = "ConditionalGeneration"
    # `bf16` has no default in initialize_text_to_text_model, so pass it.
    model, tokenizer = initialize_text_to_text_model(
        model_name, model_type, bf16=False
    )
    # `run_name` is the first positional parameter of train_text_to_text_model;
    # omitting it shifted every following argument by one.
    model = train_text_to_text_model(
        "emo-t5-small",
        train_set,
        test_set,
        model,
        tokenizer,
        model_type,
        num_train_epochs=1,
        per_device_batch_size=64,
        real_batch_size=64,
    )
    # Use the model for inference in the testset, print the first 10 examples
    for i in range(10):
        print("Input:", test_set[i]["x"])
        print("Target:", test_set[i]["y"])
        print(
            "Prediction:",
            model_inference(model, tokenizer, test_set[i]["x"], model_type),
        )
        print()
def test_llama_alpaca():
    """Print predictions of the adapter-merged Llama-2-7b model on ten Alpaca test prompts."""
    base_model_name = "meta-llama/Llama-2-7b-hf"
    kind = "CausalLM"
    adapter_root = "results/llama-alpaca_alpaca/gradient-ArB2r-adam/0"

    model, tokenizer = initialize_text_to_text_model(base_model_name, kind, True)
    model = load_peft_model(model, adapter_root)

    # load_alpaca returns three values; only the last one is used here.
    _, _, test_set = load_alpaca()
    for idx in range(10):
        print("Input:", test_set[idx]["x"])
        # print("Target:", test_set[idx]["y"])
        prediction = model_inference(model, tokenizer, test_set[idx]["x"], kind)
        print("Prediction:", prediction)
        print()
def merge_llama(peft_path):
    """Merge the adapters under ``peft_path`` into Llama-2-7b and save the result.

    The merged model and its tokenizer are written to
    ``<peft_path>/merged_checkpoint``.
    """
    base_model_name = "meta-llama/Llama-2-7b-hf"
    kind = "CausalLM"
    target_dir = os.path.join(peft_path, "merged_checkpoint")

    model, tokenizer = initialize_text_to_text_model(base_model_name, kind, True)
    model = load_peft_model(model, peft_path)

    print("Save model to ", target_dir)
    model.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)
    # Drop the local references to the model and tokenizer.
    del model, tokenizer
# Script entry point: merge the adapters found under the hard-coded results
# directory into the base model and save the merged checkpoint next to them.
if __name__ == "__main__":
    merge_llama("results/llama-alpaca_alpaca/default/0")
    # merge_llama("results/llama-alpaca_alpaca/gradient-ArB2r-adam/0")