import os
import json
import time
import shutil
import logging
from dataclasses import dataclass, field
from typing import Optional
from functools import partial

import torch
from torch.utils.data import DataLoader, DistributedSampler
import deepspeed
import wandb
from transformers import HfArgumentParser, AutoTokenizer, AutoModelForCausalLM

from data.utils.llm_dataset import load_jsonl_dataset, collate_fn
from model.llm.llm import RWKV7LM
from train_scripts.train_functions import train_step, alter_emb_and_head

logger = logging.getLogger(__name__)
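
# A typical invocation, assuming the DeepSpeed launcher and single-node training
# (script path, data paths, and flag values below are illustrative, not
# prescribed by this repo):
#
#   deepspeed --num_gpus=8 train_scripts/train_llm.py \
#       --data_file data/train.jsonl \
#       --model_name /path/to/pretrained-rwkv7 \
#       --output_dir checkpoints/run1 \
#       --deepspeed_config configs/ds_zero3.json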


@dataclass
class ScriptArguments:
    """Command-line arguments for the training script."""
    data_file: str = field(
        default=None,
        metadata={"help": "Path to training data file (JSONL format)"}
    )
    model_name: str = field(
        default=None,
        metadata={"help": "Path or name of pretrained model"}
    )
    output_dir: str = field(
        default=None,
        metadata={"help": "Directory to save trained model"}
    )
    deepspeed_config: Optional[str] = field(
        default=None,
        metadata={"help": "Path to DeepSpeed config file"}
    )
    num_epochs: int = field(
        default=3,
        metadata={"help": "Number of training epochs"}
    )
    per_device_train_batch_size: int = field(
        default=1,
        metadata={"help": "Training batch size per device"}
    )
    learning_rate: float = field(
        default=1e-5,
        metadata={"help": "Initial learning rate"}
    )
    learning_rate_final: float = field(
        default=1e-6,
        metadata={"help": "Final learning rate at the end of training"}
    )
    weight_decay: float = field(
        default=0.01,
        metadata={"help": "Weight decay"}
    )
    warmup_steps: int = field(
        default=100,
        metadata={"help": "Number of warmup steps"}
    )
    max_length: int = field(
        default=2048,
        metadata={"help": "Maximum length of input sequence"}
    )
    logging_steps: int = field(
        default=10,
        metadata={"help": "Number of steps between logging"}
    )
    save_steps: int = field(
        default=500,
        metadata={"help": "Number of steps between saving checkpoints"}
    )
    local_rank: int = field(
        default=-1,
        metadata={"help": "Local rank for distributed training"}
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed"}
    )
    wandb_project: str = field(
        default="grpo-training",
        metadata={"help": "Name of W&B project"}
    )
    wandb_run_name: str = field(
        default=None,
        metadata={"help": "Name of W&B run"}
    )
    gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "Use gradient checkpointing"}
    )
    chunk_size: int = field(
        default=1024,
        metadata={"help": "Chunk size"}
    )
    batch_chunk_size: int = field(
        default=2,
        metadata={"help": "Batch chunk size"}
    )
    ds_stage: int = field(
        default=3,
        metadata={"help": "DeepSpeed ZeRO stage"}
    )
    ds_param_offload: bool = field(
        default=True,
        metadata={"help": "Offload DeepSpeed parameters to CPU"}
    )
    ds_optimizer_offload: bool = field(
        default=True,
        metadata={"help": "Offload DeepSpeed optimizer states to CPU"}
    )
    speech_token_size: int = field(
        default=6561,
        metadata={"help": "Speech token vocabulary size"}
    )
    drop_out: float = field(
        default=0.02,
        metadata={"help": "Dropout ratio"}
    )
    drop_prompt_ratio: float = field(
        default=0.5,
        metadata={"help": "Ratio of samples whose prompt audio is dropped"}
    )
    ckpt_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to model checkpoint file"}
    )


def setup_logging(local_rank):
    """Configure logging"""
    if local_rank <= 0:
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=os.environ.get("LOG_LEVEL", logging.INFO),
        )


def configure_optimizer(model, args):
    """Collect all trainable parameters into a single optimizer group."""
    lr_1x = set()
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        lr_1x.add(n)

    lr_1x = sorted(lr_1x)
    param_dict = {n: p for n, p in model.named_parameters()}
    # The per-group weight_decay of 0.0 takes precedence over the optimizer-level
    # weight_decay passed below; my_lr_scale follows the RWKV training convention
    # and is carried along in the group without being consumed by Adam itself.
    optim_groups = [{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}]

    if args.ds_optimizer_offload:
        # CPU Adam is required when optimizer states are offloaded to the CPU.
        from deepspeed.ops.adam import DeepSpeedCPUAdam
        optimizer = DeepSpeedCPUAdam(optim_groups, lr=args.learning_rate, betas=(0.9, 0.95), eps=1e-18,
                                     bias_correction=True, adamw_mode=True, amsgrad=False,
                                     weight_decay=args.weight_decay)
    else:
        from deepspeed.ops.adam import FusedAdam
        optimizer = FusedAdam(optim_groups, lr=args.learning_rate, betas=(0.9, 0.95), eps=1e-18,
                              bias_correction=True, adam_w_mode=True, amsgrad=False,
                              weight_decay=args.weight_decay)

    return optimizer


def save_checkpoint(model_engine, output_dir, epoch, step, logger):
    """Save model checkpoint"""
    if os.path.exists(output_dir):
        if model_engine.local_rank == 0:
            # Rotate old checkpoints: keep only the two most recent directories.
            checkpoints = os.listdir(output_dir)
            checkpoints = [f for f in checkpoints if os.path.isdir(os.path.join(output_dir, f))]
            checkpoints.sort(key=lambda x: os.path.getctime(os.path.join(output_dir, x)))
            if len(checkpoints) > 2:
                print(f'deleting older checkpoint {checkpoints[0]}')
                shutil.rmtree(os.path.join(output_dir, checkpoints[0]))
    output_dir = f"{output_dir}/epoch_{epoch}_step_{step}"
    print(f'saving checkpoint to {output_dir}')
    if model_engine.local_rank == 0 and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Every rank must call save_checkpoint; DeepSpeed coordinates the collective save.
    model_engine.save_checkpoint(output_dir)
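
# Resuming later is the mirror operation (a sketch using DeepSpeed's standard
# checkpoint API; the directory name is illustrative):
#
#   load_path, client_state = model_engine.load_checkpoint(
#       "checkpoints/run1/epoch_0_step_500")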


def get_lr_scheduler(optimizer, total_steps, warmup_steps, learning_rate, learning_rate_final):
    """Create a linear learning rate scheduler that goes from learning_rate to learning_rate_final"""

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup from 0 up to the base learning rate.
            return float(current_step) / float(max(1, warmup_steps))
        # Linear decay toward learning_rate_final, expressed as a multiplier
        # of the base rate and floored at learning_rate_final / learning_rate.
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(learning_rate_final / learning_rate,
                   1.0 - progress * (1.0 - learning_rate_final / learning_rate))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
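
# With the default arguments (learning_rate=1e-5, learning_rate_final=1e-6,
# warmup_steps=100) the multiplier returned by lr_lambda works out to:
#
#   step t < 100:   t / 100                        -> lr ramps from 0 to 1e-5
#   step t >= 100:  max(0.1, 1 - 0.9 * progress)   -> lr decays from 1e-5 to 1e-6,
#                   where progress = (t - 100) / (total_steps - 100)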


def main():
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    is_main_process = local_rank == 0
    device = torch.device(f'cuda:{local_rank}')

    setup_logging(local_rank)
    logger = logging.getLogger(__name__)

    if is_main_process:
        logger.info(f"Arguments: {args}")

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if is_main_process:
        logger.info(f"Loading tokenizer from {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)

    special_tokens = {
        'pad_token': '<|rwkv_tokenizer_end_of_text|>',
        'additional_special_tokens': [
            '<|endofprompt|>',
            '[breath]', '<strong>', '</strong>', '[noise]',
            '[laughter]', '[cough]', '[clucking]', '[accent]',
            '[quick_breath]',
            "<laughter>", "</laughter>",
            "[hissing]", "[sigh]", "[vocalized-noise]",
            "[lipsmack]", "[mn]"
        ]
    }
    tokenizer.add_special_tokens(special_tokens)
    # len(tokenizer) counts the special tokens just added; tokenizer.vocab_size would not.
    vocab_size = len(tokenizer)

    if is_main_process:
        logger.info(f"Loading dataset from {args.data_file}")
    dataset = load_jsonl_dataset(args.data_file, tokenizer)

    if is_main_process:
        logger.info(f"Creating DataLoader with batch size {args.per_device_train_batch_size}, world size {world_size}")
    # Note: local_rank is used as the global rank here, which assumes single-node training.
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=local_rank,
        shuffle=True,
        seed=args.seed
    )

    data_collator = partial(collate_fn, tokenizer=tokenizer, max_length=args.max_length,
                            pad_to_max_length=False, drop_prompt_audio_rate=args.drop_prompt_ratio)
    dataloader = DataLoader(
        dataset,
        batch_size=args.per_device_train_batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
        collate_fn=data_collator
    )

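    # The training loop below assumes each collated batch is a dict with (at
    # least) the keys used there -- 'text_token', 'speech_token', and 'skip' --
    # i.e. something of roughly this shape (an illustration, not the exact
    # contract of collate_fn):
    #
    #   {'text_token': LongTensor[B, T_text],
    #    'speech_token': LongTensor[B, T_speech],
    #    'skip': bool}
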
    if args.deepspeed_config:
        if is_main_process:
            logger.info(f"Loading DeepSpeed config from {args.deepspeed_config}")
        with open(args.deepspeed_config, 'r') as f:
            ds_config = json.load(f)
    else:
        if is_main_process:
            logger.info("Using default DeepSpeed config")
        # train_batch_size = micro batch size * world size * gradient accumulation steps (1).
        train_batch_size = args.per_device_train_batch_size * world_size * 1
        ds_config = {
            "distributed_backend": "nccl",
            "train_batch_size": train_batch_size,
            "bf16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": args.ds_stage,
                "stage3_max_live_parameters": 1e9,
                "stage3_max_reuse_distance": 1e9,
                "stage3_prefetch_bucket_size": 5e6,
                "memory_efficient_linear": True,
                "stage3_param_persistence_threshold": 1e4,
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True,
                    "buffer_count": 4,
                    "buffer_size": 1e8
                },
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                    "buffer_count": 4
                },
                "allgather_partitions": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e6,
                "overlap_comm": True,
                "contiguous_gradients": True
            },
            "zero_force_ds_cpu_initialization": True,
            "gradient_checkpointing": args.gradient_checkpointing,
            "dump_state": True
        }

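    # The same settings can instead be supplied as a JSON file via
    # --deepspeed_config. A minimal sketch (file contents illustrative,
    # mirroring the defaults above rather than defining them):
    #
    #   {
    #     "train_batch_size": 8,
    #     "bf16": {"enabled": true},
    #     "zero_optimization": {
    #       "stage": 3,
    #       "offload_param": {"device": "cpu", "pin_memory": true},
    #       "offload_optimizer": {"device": "cpu", "pin_memory": true}
    #     }
    #   }
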
    if is_main_process:
        logger.info("Initializing model with DeepSpeed config")
    model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)
    # Resize the embedding and output head to cover the text vocab plus the speech tokens.
    model = alter_emb_and_head(model, vocab_size, args.speech_token_size)
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    model.train()
    llm_input_size = model.config.hidden_size
    llm_output_size = model.config.hidden_size
    model = RWKV7LM(llm_input_size, llm_output_size, args.speech_token_size, model, None, drop_ratio=args.drop_out)
    if args.ckpt_file is not None:
        if is_main_process:
            logger.info(f"Loading checkpoint from {args.ckpt_file}")
        info = model.load_state_dict(torch.load(args.ckpt_file))
        if is_main_process:
            logger.info(f"Loaded checkpoint info: {info}")
    model.train()
    if is_main_process:
        logger.info(f'Enable gradient checkpointing: {args.gradient_checkpointing}')
    # Full fine-tuning: train every parameter.
    for n, p in model.named_parameters():
        p.requires_grad = True
    if is_main_process:
        for n, p in model.named_parameters():
            print(f'{n} requires grad: {p.requires_grad}')
        logger.info('start configuring optimizer')
    optimizer = configure_optimizer(model, args)

    model_ds_config = ds_config.copy()
    # copy() is shallow, so these deletions also touch ds_config; that is fine
    # because ds_config is not reused afterwards.
    if not args.ds_param_offload:
        del model_ds_config["zero_optimization"]["offload_param"]
    if not args.ds_optimizer_offload:
        del model_ds_config["zero_optimization"]["offload_optimizer"]

    total_steps = len(dataloader) * args.num_epochs

    lr_scheduler = get_lr_scheduler(
        optimizer,
        total_steps,
        args.warmup_steps,
        args.learning_rate,
        args.learning_rate_final
    )
    model_engine, optimizer, _, scheduler = deepspeed.initialize(
        model=model,
        config=model_ds_config,
        model_parameters=model.parameters(),
        optimizer=optimizer,
        lr_scheduler=lr_scheduler
    )
    if is_main_process:
        logger.info("Model initialized")
    del model
    if is_main_process:
        from tqdm import tqdm
        pbar = tqdm(total=len(dataloader))
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            config=vars(args)
        )

    # Start from a clean output directory.
    if os.path.exists(args.output_dir) and model_engine.local_rank == 0:
        shutil.rmtree(args.output_dir)
    total_loss = 0.0
    global_step = 0  # logging step counter (distinct from the scheduler's total_steps)
    total_acc = 0.0
    all_tokens = 0
    for epoch in range(args.num_epochs):
        if is_main_process:
            update_time = time.time()
            logger.info(f"Epoch {epoch} starts training")
        # set_epoch must receive the same value on every rank so the shuffles
        # agree; a time-derived seed can desynchronize ranks, so the epoch
        # index is used here.
        sampler.set_epoch(epoch)

        for batch_idx, batch in enumerate(dataloader):
            if is_main_process:
                speech_token_shape = batch['speech_token'].shape
                text_token_shape = batch['text_token'].shape
                logger.debug(f'speech_token_shape: {speech_token_shape} text_token_shape: {text_token_shape} at batch_idx: {batch_idx}')
            skip = batch['skip']
            if skip:
                all_length = batch['text_token'].shape[1] + batch['speech_token'].shape[1]
                if all_length > args.max_length:
                    # Truncate the speech tokens so text + speech fits within max_length.
                    truncated_length = args.max_length - batch['text_token'].shape[1]
                    speech_token = batch['speech_token']
                    batch['speech_token'] = speech_token[:, :truncated_length]
            batch.pop('skip')
            output = train_step(model_engine, batch)
            loss = output['loss']
            acc = output['acc']
            if is_main_process:
                logger.debug(f'loss: {loss} acc: {acc}')

            # Check for NaN/Inf on any rank. All ranks must agree on whether
            # to use the real loss, otherwise the collective ops inside
            # backward()/step() would go out of sync.
            is_nan_loss = torch.isnan(loss) or torch.isinf(loss)
            is_nan_loss_tensor = torch.tensor([1.0 if is_nan_loss else 0.0], device=model_engine.device)
            torch.distributed.all_reduce(is_nan_loss_tensor, op=torch.distributed.ReduceOp.MAX)
            is_nan_loss = bool(is_nan_loss_tensor.item())

            if is_nan_loss:
                # Multiplying by zero keeps the autograd graph (and hence the
                # backward/step collectives) intact while contributing no update.
                safe_loss = loss * 0.0
                if is_main_process:
                    logger.warning(f"NaN loss detected at batch {batch_idx}, using safe zero loss instead")
                    logger.debug(f'batch data is {batch}')
                    wandb.log({
                        "nan_detected": 1,
                        "epoch": epoch,
                        "step": global_step
                    })
                model_engine.backward(safe_loss)
                model_engine.step()
            else:
                model_engine.backward(loss)
                model_engine.step()

            # Periodic checkpointing.
            if batch_idx % args.save_steps == 0 and batch_idx > 0:
                if args.ds_stage == 3 or args.ds_stage == 2:
                    save_checkpoint(model_engine, args.output_dir, epoch, batch_idx, logger)

            if is_main_process:
                elapsed_time = time.time() - update_time
                total_loss += loss.item()
                total_acc += acc.item()
                global_step += 1

                avg_loss = total_loss / global_step
                avg_acc = total_acc / global_step
                # Tokens processed this step across all devices.
                tokens = (batch['speech_token'].shape[1] + batch['text_token'].shape[1]) * args.per_device_train_batch_size * world_size
                all_tokens += tokens
                kts = tokens / elapsed_time / 1e3  # throughput in kilotokens per second

                current_lr = optimizer.param_groups[0]['lr']
                wandb.log({
                    "loss": loss.item(),
                    "avg_loss": avg_loss,
                    "epoch": epoch,
                    "step": global_step,
                    "acc": acc.item(),
                    "avg_acc": avg_acc,
                    "KT/s": kts,
                    "Gtokens": all_tokens / 1e9,
                    "learning_rate": current_lr
                })

                pbar.update(1)
                pbar.set_postfix({
                    'loss': loss.item(),
                    'avg_loss': avg_loss,
                    'acc': acc.item(),
                    'avg_acc': avg_acc,
                    'lr': current_lr
                })
                # Reset the timer so elapsed_time measures a single step.
                update_time = time.time()

        # Save a checkpoint at the end of each epoch.
        if args.ds_stage == 3 or args.ds_stage == 2:
            epoch_checkpoint_dir = f"{args.output_dir}/epoch_{epoch}"
            os.makedirs(epoch_checkpoint_dir, exist_ok=True)
            print(f'saving checkpoint to {epoch_checkpoint_dir}')
            model_engine.save_checkpoint(epoch_checkpoint_dir)

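    # DeepSpeed writes a zero_to_fp32.py helper into each checkpoint directory;
    # a ZeRO-partitioned checkpoint can later be consolidated into a single
    # fp32 state dict with (paths illustrative):
    #
    #   python checkpoints/run1/epoch_0/zero_to_fp32.py \
    #       checkpoints/run1/epoch_0 pytorch_model.bin
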
    if is_main_process:
        wandb.finish()


if __name__ == "__main__":
    main()