# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
import time
import logging
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    default_data_collator,
)
import wandb
from peft import LoraConfig, get_peft_model
from core.arguments import parse_args
from core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from core.datasets.gpt_dataset import GPTDatasetConfig, GPTDataset
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
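# Module-level tokenizer handle: populated by _build_tokenizer() and read later by
# core_gpt_dataset_config_from_args() to look up the <EOD> token id.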
_GLOBAL_TOKENIZER = None
def is_dataset_built_on_rank():
    # return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0
    return True
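# Maps the parsed CLI arguments onto the Megatron-style GPTDataset configuration.
# Assumes the tokenizer vocabulary contains a literal '<EOD>' token.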
def core_gpt_dataset_config_from_args(args):
    return GPTDatasetConfig(
        is_built_on_rank=is_dataset_built_on_rank,
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=args.data_path,
        blend_per_split=[args.train_data_path, args.valid_data_path, args.test_data_path],
        split=args.split,
        path_to_cache=args.data_cache_path,
        return_document_ids=args.retro_return_doc_ids,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        eod_id=_GLOBAL_TOKENIZER.vocab['<EOD>'],
        enable_shuffle=args.enable_shuffle,
    )
def _build_tokenizer(args):
    """Initialize the tokenizer."""
    global _GLOBAL_TOKENIZER
    logger.info(f"Loading tokenizer from {args.model_name_or_path}")
    _GLOBAL_TOKENIZER = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    return _GLOBAL_TOKENIZER
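# Sample budgeting: the train budget is either --train-samples directly or
# train_iters * global_batch_size; the eval and test budgets are derived from
# eval_iters in the function below.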
def build_train_valid_test_datasets(args):
    """Build the train, validation, and test datasets."""
    # Number of train/valid/test samples.
    if args.train_samples:
        train_samples = args.train_samples
    else:
        train_samples = args.train_iters * args.global_batch_size
    # Evaluation runs every eval_interval iterations plus once after training,
    # and each evaluation pass consumes eval_iters batches.
    eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
    test_iters = args.eval_iters
    train_val_test_num_samples = [
        train_samples,
        eval_iters * args.global_batch_size,
        test_iters * args.global_batch_size,
    ]
    logger.info("> Building train, validation, and test datasets...")
    try:
        train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
            GPTDataset,
            train_val_test_num_samples,
            core_gpt_dataset_config_from_args(args),
        ).build()
        logger.info("> Finished creating datasets")
        return train_ds, valid_ds, test_ds
    except Exception as e:
        logger.error(f"Failed to build datasets: {e}")
        raise
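# The Megatron-style dataset index builder ships as C++ helpers compiled lazily at
# startup. Only rank 0 compiles; the other ranks are assumed to pick up the
# already-built artifact (or to be synchronized inside compile_helpers itself).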
def _compile_dependencies():
    """Compile the dataset index builder's C++ helpers (rank 0 only)."""
    # Guard against single-process runs where the process group was never
    # initialized: get_rank() raises unless init_process_group() has been called.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        start_time = time.time()
        logger.info("> Compiling dataset index builder...")
        try:
            from core.datasets.utils import compile_helpers
            compile_helpers()
            logger.info(
                f">>> Done with dataset index builder. Compilation time: {time.time() - start_time:.3f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to compile helpers: {e}")
            raise
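# Assumes a torchrun-style launcher that exports LOCAL_RANK and WORLD_SIZE;
# absent those variables, the script falls back to single-GPU execution on device 0.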
def setup_distributed_training():
    """Set up the distributed training environment."""
    try:
        # Initialize the process group for distributed training
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        if world_size > 1:
            # Multi-GPU setup
            torch.cuda.set_device(local_rank)
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend="nccl")
            logger.info(f"Distributed training initialized with world size: {world_size}, local rank: {local_rank}")
        else:
            # Single-GPU setup
            logger.info(f"Running on a single GPU (device {local_rank})")
            torch.cuda.set_device(local_rank)
        return local_rank
    except Exception as e:
        logger.error(f"Failed to set up distributed training: {e}")
        raise
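# get_peft_model() freezes the base model's weights and injects trainable low-rank
# adapters on the configured target modules, so only the LoRA matrices receive
# gradient updates during fine-tuning.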
def create_and_configure_model(args):
    """Create and configure the model with LoRA."""
    try:
        # Resolve the parameter dtype; fp16 and bf16 are mutually exclusive.
        # Default to fp32 unless parse_args already supplied a dtype.
        args.params_dtype = getattr(args, "params_dtype", torch.float32)
        if args.fp16:
            assert not args.bf16, "--fp16 and --bf16 are mutually exclusive"
            args.params_dtype = torch.half
        if args.bf16:
            assert not args.fp16, "--fp16 and --bf16 are mutually exclusive"
            args.params_dtype = torch.bfloat16
        logger.info(f"Loading base model from {args.model_name_or_path}")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            torch_dtype=args.params_dtype,
            cache_dir=args.cache_dir,
        )
        logger.info(f"Configuring LoRA with r={args.lora_r}, alpha={args.lora_alpha}")
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=args.lora_target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Number of trainable parameters: {trainable_params:,}")
        return model
    except Exception as e:
        logger.error(f"Failed to create and configure model: {e}")
        raise
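# End-to-end pipeline: distributed setup -> helper compilation -> argument parsing ->
# tokenizer -> Megatron datasets -> LoRA model -> HF Trainer. TrainingArguments is
# re-parsed from the same namespace via HfArgumentParser.parse_dict, so any Trainer
# option must also be exposed by core.arguments.parse_args.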
def main():
    # Setup distributed training
    local_rank = setup_distributed_training()
    # Compile dependencies after initializing the distributed group
    _compile_dependencies()
    # Parse arguments
    args = parse_args()
    # Build tokenizer (must precede dataset construction, which reads the <EOD> id)
    _build_tokenizer(args)
    # Build datasets
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(args)
    # Create and configure model
    model = create_and_configure_model(args)
    # Setup training arguments
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_dict(args.__dict__, allow_extra_keys=True)[0]
    # Initialize wandb if specified (main process only; guard get_rank() for
    # single-process runs where the process group was never initialized)
    is_main_process = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    if args.report_to == "wandb" and is_main_process:
        try:
            wandb.init(
                project=args.wandb_project or "YuE-finetune",
                config=vars(args),
                name=args.run_name,
            )
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {e}. Continuing without wandb.")
    # Create trainer
    trainer = Trainer(
        model=model,
        tokenizer=_GLOBAL_TOKENIZER,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=default_data_collator,
    )
    # Start training
    logger.info("Starting training...")
    trainer.train()
    # Save model and tokenizer
    if is_main_process:
        logger.info(f"Saving model to {training_args.output_dir}")
        trainer.save_model(training_args.output_dir)
        _GLOBAL_TOKENIZER.save_pretrained(training_args.output_dir)
    logger.info("Training completed successfully")
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        raise