| |
| |
| |
| |
| |
|
|
| import json |
| import os |
| import time |
| from datetime import timedelta |
| from collections import defaultdict |
| import dataclasses |
|
|
| import torch |
| from datasets import interleave_datasets, load_dataset |
| from torch.distributed.elastic.multiprocessing.errors import record |
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
|
|
| import fla |
| from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss |
| from fla.ops.common.utils import prepare_position_ids |
| from flame.components.checkpoint import TrainState |
| from flame.config_manager import JobConfig |
| from flame.data import build_dataloader, shuffle |
| from flame.models.parallelize_fla import parallelize_fla |
| from flame.models.pipeline_fla import pipeline_fla |
| from flame.tools.utils import get_nparams_and_flops |
| from flame.utils.checkpoint import cleanup_local_checkpoints |
| from flame.utils.convert_dcp_to_hf import save_pretrained |
| from flame.utils.hf_utils import upload_checkpoint_to_hf |
| from datetime import datetime |
| from torchtitan.components.checkpoint import CheckpointManager |
| from torchtitan.components.ft import FTParallelDims, init_ft_manager |
| from torchtitan.components.loss import build_cross_entropy_loss |
| from torchtitan.components.lr_scheduler import build_lr_schedulers |
| from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible |
| from torchtitan.components.optimizer import build_optimizers |
| from torchtitan.distributed import ParallelDims |
| from torchtitan.distributed import utils as dist_utils |
| from torchtitan.protocols.model_converter import build_model_converters |
| from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec |
| from torchtitan.tools import utils |
| from torchtitan.tools.logging import init_logger, logger |
| from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling |
|
|
from dotenv import load_dotenv

# Load environment variables (WANDB_API_KEY, HF_TOKEN, ...) from a local .env
# file before any service logins below read them.
load_dotenv()


import wandb

# NOTE(review): these logins run at import time and hit the network; a missing
# WANDB_API_KEY / HF_TOKEN env var raises KeyError here before main() starts —
# confirm this fail-fast behavior is intended.
wandb.login(key=os.environ["WANDB_API_KEY"])


import huggingface_hub

huggingface_hub.login(token=os.environ["HF_TOKEN"])
|
|
|
|
def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
    """Instantiate the tokenizer named by ``job_config.model.tokenizer_path``."""
    tokenizer_path = job_config.model.tokenizer_path
    return AutoTokenizer.from_pretrained(tokenizer_path)
|
|
|
|
# Register the "fla" training spec with torchtitan so that
# get_train_spec(job_config.model.name) in main() can resolve it. The spec
# wires HF Auto classes to flame's parallelization/pipelining functions and
# torchtitan's optimizer/scheduler/loss builders.
register_train_spec(
    TrainSpec(
        name="fla",
        cls=AutoModelForCausalLM,
        config=AutoConfig,
        parallelize_fn=parallelize_fla,
        pipelining_fn=pipeline_fla,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_dataloader,
        build_tokenizer_fn=build_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
|
|
|
|
| |
@record
def main(job_config: JobConfig):
    """Run one distributed training job described by ``job_config``.

    Sets up the device and parallelism layout, loads tokenizer and dataset(s),
    builds the model on the meta device, parallelizes it, then runs the
    training loop with gradient accumulation, metric logging, checkpointing,
    and optional Hugging Face Hub uploads.
    """
    logger.info(f"Starting job: {job_config.job.description}")

    # Optionally import a user-supplied model module before model construction.
    if job_config.experimental.custom_model_path:
        utils.import_module_from_path(job_config.experimental.custom_model_path)

    color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color

    if job_config.job.print_args:
        logger.info(
            f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
        )

    # Manual, periodic garbage collection to keep GC pauses off the hot path.
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    # Bind this process to its local accelerator before any collective work.
    device_module, device_type = utils.device_module, utils.device_type
    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
    device_module.set_device(device)
    ft_manager = init_ft_manager(job_config)

    # Derive a unique timestamped HF repo id for this run when uploads are on.
    run_specific_repo_id = None
    if getattr(job_config.checkpoint, "hf_upload_enabled", False):
        hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
        if hf_repo_base:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
            logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
        else:
            logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
            # Disable uploads rather than failing later with no target repo.
            job_config.checkpoint.hf_upload_enabled = False

    # Build the parallelism layout; the FT variant carries the fault-tolerance
    # manager alongside the same degrees.
    world_size = int(os.environ["WORLD_SIZE"])
    if not ft_manager.enabled:
        parallel_dims = ParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
        )
    else:
        parallel_dims = FTParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
            ft_manager=ft_manager,
        )
    dist_utils.init_distributed(job_config)

    # Device memory tracking and theoretical peak FLOPS (used for MFU logging).
    device_memory_monitor = build_device_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
    logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

    # Build the device mesh and resolve this rank's data-parallel coordinates.
    world_mesh = parallel_dims.build_mesh(device_type=device_type)
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        raise NotImplementedError(
            "Pipeline parallelism is not supported in this version"
        )
        # NOTE(review): everything below the raise in this branch is
        # unreachable and kept only as a record of the pending PP work.
        """
        ! TODO[flame]: We need to fix the pipeline parallelism for flame
        [x] Match the key of models' components with the actual naming
        [ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically
        forces to tie if head is None, we need to handle this case
        [ ]
        """
        pp_mesh = world_mesh["pp"]

    # Seed all RNGs (and enable deterministic kernels when requested).
    dist_utils.set_determinism(
        world_mesh, device, job_config.training.seed, job_config.training.deterministic
    )
    train_spec = get_train_spec(job_config.model.name)

    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        job_config.model.tokenizer_path,
        trust_remote_code=True,
        # Effectively disable tokenizer-side truncation; sequence length is
        # enforced by the dataloader / training loop instead.
        model_max_length=int(1e10),
    )
    # NOTE(review): assumes "<unk>" exists in the tokenizer vocab — confirm
    # per tokenizer; otherwise padding uses an out-of-vocab token string.
    tokenizer.pad_token = "<unk>"

    logger.info(f"{tokenizer}")
    # NOTE(review): the conditional below binds over the *whole* concatenated
    # f-string, so when dataset_name is None this logs an empty string instead
    # of "Loading dataset <dataset>".
    logger.info(
        f"Loading dataset {job_config.training.dataset}"
        f":{job_config.training.dataset_name}"
        if job_config.training.dataset_name is not None
        else ""
    )

    # Each dataloader worker on each DP rank must own at least one shard.
    min_num_shards = dp_degree * job_config.training.num_workers
    if len(job_config.training.dataset.split(",")) == 1:
        # ---- single-dataset path ----
        dataset = load_dataset(
            path=job_config.training.dataset,
            name=getattr(job_config.training, "dataset_name", None),
            data_dir=getattr(job_config.training, "data_dir", None),
            data_files=getattr(job_config.training, "data_files", None),
            split=job_config.training.dataset_split or "train",
            trust_remote_code=True,
            streaming=job_config.training.streaming,
            num_proc=(
                job_config.training.num_workers
                if not job_config.training.streaming
                else None
            ),
        )
        logger.info(f"{dataset}")

        logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
        if not job_config.training.streaming:
            # Map-style dataset: global shuffle, then convert to an iterable
            # dataset with enough shards for all workers.
            dataset = dataset.shuffle(
                seed=job_config.training.seed
            ).to_iterable_dataset(num_shards=min_num_shards)
        else:
            if dataset.num_shards < min_num_shards:
                # Streaming dataset with too few shards: fall back to a
                # non-streaming reload so we can reshard it ourselves.
                logger.warning(
                    f"{color.red}"
                    f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
                    f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                    f"{job_config.training.num_workers} dataloader workers. "
                    f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
                    f"{color.reset}"
                )
                dataset = (
                    load_dataset(
                        path=job_config.training.dataset,
                        name=getattr(job_config.training, "dataset_name", None),
                        data_dir=getattr(job_config.training, "data_dir", None),
                        data_files=getattr(job_config.training, "data_files", None),
                        split=job_config.training.dataset_split or "train",
                        trust_remote_code=True,
                        streaming=False,
                        num_proc=job_config.training.num_workers,
                    )
                    .shuffle(seed=job_config.training.seed)
                    .to_iterable_dataset(num_shards=min_num_shards)
                )
            else:
                # Streaming with enough shards: approximate (buffered) shuffle.
                dataset = shuffle(dataset, seed=job_config.training.seed)
    else:
        # ---- multi-dataset path: parse per-dataset comma-separated options ----
        datasets = job_config.training.dataset.split(",")
        if job_config.training.dataset_name is not None:
            dataset_names = [
                name or None for name in job_config.training.dataset_name.split(",")
            ]
            assert len(dataset_names) == len(datasets), (
                "The number of dataset names must match the number of datasets"
            )
        else:
            dataset_names = [None] * len(datasets)
        if job_config.training.dataset_split is not None:
            dataset_splits = [
                split or "train"
                for split in job_config.training.dataset_split.split(",")
            ]
            assert len(dataset_splits) == len(datasets), (
                "The number of dataset splits must match the number of datasets"
            )
        else:
            dataset_splits = ["train"] * len(datasets)
        if job_config.training.data_dir is not None:
            data_dirs = [
                data_dir or None for data_dir in job_config.training.data_dir.split(",")
            ]
            assert len(data_dirs) == len(datasets), (
                "The number of data dirs must match the number of datasets"
            )
        else:
            data_dirs = [None] * len(datasets)
        if job_config.training.data_files is not None:
            data_files = job_config.training.data_files.split(",")
            assert len(data_files) == len(datasets), (
                "The number of data files must match the number of datasets"
            )
        else:
            data_files = [None] * len(datasets)
        if job_config.training.data_probs is not None:
            data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
            assert len(data_probs) == len(datasets), (
                "The number of data probabilities must match the number of datasets"
            )
        else:
            raise ValueError(
                "Data sampling probabilities are required if using multiple datasets"
            )

        # Load, shuffle and normalize each subset, then interleave them.
        subsets = []
        for i, prob in enumerate(data_probs):
            subset = load_dataset(
                path=datasets[i],
                name=dataset_names[i],
                data_dir=data_dirs[i],
                data_files=data_files[i],
                split=dataset_splits[i],
                trust_remote_code=True,
                streaming=job_config.training.streaming,
                num_proc=(
                    job_config.training.num_workers
                    if not job_config.training.streaming
                    else None
                ),
            )
            logger.info(
                f"Subset {color.cyan}{datasets[i]}"
                + (f":{dataset_names[i]} " if dataset_names[i] else " ")
                + f"(p = {prob:.3f}){color.reset}:\n"
                + f"{subset}"
            )

            logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
            if not job_config.training.streaming:
                subset = subset.shuffle(
                    seed=job_config.training.seed
                ).to_iterable_dataset(num_shards=min_num_shards)
            else:
                if subset.num_shards < min_num_shards:
                    # Same reshard-by-reloading fallback as the single-dataset
                    # path above.
                    logger.warning(
                        f"{color.red}"
                        f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
                        f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                        f"{job_config.training.num_workers} dataloader workers. "
                        f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
                        f"{color.reset}"
                    )
                    subset = (
                        load_dataset(
                            path=datasets[i],
                            name=dataset_names[i],
                            data_dir=data_dirs[i],
                            data_files=data_files[i],
                            split=dataset_splits[i],
                            trust_remote_code=True,
                            streaming=False,
                            num_proc=job_config.training.num_workers,
                        )
                        .shuffle(seed=job_config.training.seed)
                        .to_iterable_dataset(min_num_shards)
                    )
                else:
                    # Buffered shuffle; buffer is split across subsets so the
                    # total stays bounded, with a floor of 128 per subset.
                    subset = shuffle(
                        subset,
                        seed=job_config.training.seed,
                        buffer_size=max(128, 1024 // len(datasets)),
                    )

            # Keep only the text column so interleaving sees a uniform schema.
            if "text" in subset.column_names:
                subset = subset.select_columns("text")
            elif "content" in subset.column_names:
                subset = subset.select_columns("content")
            else:
                raise ValueError(
                    f"Subset {datasets[i]} has no 'text' or 'content' column"
                )
            subsets.append(subset)

        logger.info(
            f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
        )
        dataset = interleave_datasets(
            datasets=subsets,
            probabilities=data_probs,
            stopping_strategy="all_exhausted",
            seed=job_config.training.seed,
        )
        logger.info(f"{dataset}")

    logger.info(f"Loading model config from {job_config.model.config}")
    model_config = AutoConfig.from_pretrained(job_config.model.config)

    logger.info("Building dataloader...")
    if job_config.training.dataset_mode.lower() == "pretrain":
        # NOTE(review): pretrain mode loads 2x seq_len tokens per sample; the
        # training loop later slices input_ids back to seq_len — presumably to
        # guarantee full-length sequences after packing. Confirm against
        # build_dataloader's contract.
        dataloader_seq_len = job_config.training.seq_len * 2
    else:
        dataloader_seq_len = job_config.training.context_len or job_config.training.seq_len

    dataloader = build_dataloader(
        dataset=dataset,
        tokenizer=tokenizer,
        rank=dp_rank,
        world_size=dp_degree,
        batch_size=job_config.training.batch_size,
        seq_len=dataloader_seq_len,
        context_len=job_config.training.context_len,
        varlen=job_config.training.varlen,
        num_workers=job_config.training.num_workers,
        pin_memory=job_config.training.pin_memory,
        persistent_workers=job_config.training.persistent_workers,
        snapshot_every_n_steps=job_config.checkpoint.interval,
        dataset_mode=job_config.training.dataset_mode,
    )

    # Disable fused kernels that are incompatible with the chosen parallelism.
    if parallel_dims.tp_enabled:
        if model_config.fuse_norm:
            logger.warning(
                f"{color.red}"
                f"Fused norm is not compatible with tensor parallelism. "
                f"Disabling it for now."
                f"{color.reset}"
            )
            model_config.fuse_norm = False
    if parallel_dims.loss_parallel_enabled:
        if model_config.fuse_cross_entropy:
            logger.warning(
                f"{color.red}"
                f"Loss parallel enabled. Disabling fused cross entropy for now."
                f"{color.reset}"
            )
            model_config.fuse_cross_entropy = False
    # Grow the model vocab if the tokenizer has more tokens than the config.
    model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)

    logger.info(
        f"Building model from the config\n{color.green}{model_config}{color.reset}"
    )
    # Materialize on the meta device: no memory allocated until to_empty below.
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(model_config)
        if (
            getattr(model_config, "fuse_cross_entropy", False)
            and FusedLinearCrossEntropyLoss is not None
        ):
            # Chunked fused linear+CE head; fewer chunks per rank under TP.
            model.criterion = FusedLinearCrossEntropyLoss(
                num_chunks=8 // parallel_dims.tp
            )
    # Force HF re-initialization later during post_init on the real device.
    model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
    logger.info(f"{color.blue}\n{model}{color.reset}\n")

    # Apply model converters (e.g. quantization/dtype transforms) if configured.
    model_converters = build_model_converters(job_config, parallel_dims)
    model_converters.convert(model)

    model_param_count, num_flops_per_token = get_nparams_and_flops(
        model, model_config, job_config.training.context_len
    )

    # Seed checkpoints and CPU offload both initialize weights on CPU.
    if job_config.checkpoint.create_seed_checkpoint:
        init_device = "cpu"
    elif job_config.training.enable_cpu_offload:
        init_device = "cpu"
    else:
        init_device = device_type

    if parallel_dims.pp_enabled:
        # NOTE(review): unreachable — pp_enabled raised earlier in this function.
        (
            pp_schedule,
            model_parts,
            has_first_stage,
            has_last_stage,
        ) = train_spec.pipelining_fn(
            model,
            pp_mesh,
            parallel_dims,
            job_config,
            device,
            model_config,
            train_spec.loss_fn,
        )
        # The whole model is no longer needed once split into stages.
        del model

        for m in model_parts:
            train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
            m.to_empty(device=init_device)
            with torch.no_grad():
                m.post_init()
            m.train()

        ensure_pp_loss_visible(parallel_dims, job_config, color)
    else:
        # Shard/replicate, then materialize real storage and initialize weights.
        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
        model.to_empty(device=init_device)
        with torch.no_grad():
            model.post_init()
        model.train()

        model_parts = [model]

    device_mem_stats = device_memory_monitor.get_peak_stats()
    logger.info(
        f"{device_type.upper()} memory usage for model: "
        f"{device_mem_stats.max_reserved_gib:.2f}GiB"
        f"({device_mem_stats.max_reserved_pct:.2f}%)"
    )

    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
    lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
    # Let converters react after every optimizer step (e.g. requantization).
    optimizers.register_step_post_hook(
        lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
    )

    train_state = TrainState()

    # Checkpointing covers dataloader state, model, optimizer, scheduler and
    # the train_state bookkeeping above.
    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    if job_config.checkpoint.create_seed_checkpoint:
        assert world_size == 1, (
            "Must create seed checkpoint using a single device, to disable sharding"
        )
        assert job_config.checkpoint.enable_checkpoint, (
            "Must enable checkpointing when creating a seed checkpoint"
        )
        checkpoint.save(curr_step=0, force=True)
        logger.info("Created seed checkpoint")
        return

    checkpoint.load(step=job_config.checkpoint.load_step)
    metric_logger = build_metrics_processor(job_config, parallel_dims)
    metric_logger.num_flops_per_token = num_flops_per_token
    metric_logger.optimizers = optimizers
    metric_logger.lr_schedulers = (
        lr_schedulers
    )

    # After resuming, replay previously logged steps so dashboards are
    # continuous across restarts.
    if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
        for idx, step in enumerate(train_state.log_steps):
            metric_logger.log(
                step,
                global_avg_loss=train_state.global_avg_losses[idx],
                global_max_loss=train_state.global_max_losses[idx],
            )

    data_iterator = iter(dataloader)

    # Context combining loss-parallel and (optionally) compiled autograd.
    train_context = dist_utils.get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )

    device_memory_monitor.reset_peak_stats()

    global_batch_size = (
        job_config.training.batch_size
        * dp_degree
        * job_config.training.gradient_accumulation_steps
    )
    num_tokens_per_step = global_batch_size * job_config.training.seq_len

    logger.info(f"{color.red}***** Running training *****{color.reset}")
    logger.info(f"{color.green}  Training starts at step {train_state.step + 1}")
    logger.info(
        f"{color.green}  Number of tokens per sequence = {job_config.training.seq_len:,}"
    )
    logger.info(
        f"{color.green}  Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
    )
    logger.info(
        f"{color.green}  Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
    )
    logger.info(
        f"{color.green}  Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
        f" ({num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green}  Total optimization steps = {job_config.training.steps:,} "
        f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green}  Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
        f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green}  Number of parameters = {model_param_count:,} {color.reset}"
    )

    with (
        maybe_enable_profiling(
            job_config, global_step=train_state.step
        ) as torch_profiler,
        maybe_enable_memory_snapshot(
            job_config, global_step=train_state.step
        ) as memory_profiler,
    ):
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            optimizers.zero_grad()

            # Accumulators for this optimizer step: the scalar training loss
            # and any auxiliary "*_loss" fields on the model output.
            losses = defaultdict(list)
            actual_loss = []
            for _ in range(job_config.training.gradient_accumulation_steps):
                data_load_start = time.perf_counter()
                batch = next(data_iterator)
                # Truncate inputs to seq_len (the dataloader may deliver more,
                # see dataloader_seq_len above); labels are used as-is.
                input_ids, labels = batch["input_ids"][:, :job_config.training.seq_len], batch["labels"]

                metric_logger.ntokens_since_last_log += input_ids.numel()
                metric_logger.data_loading_times.append(
                    time.perf_counter() - data_load_start
                )

                input_ids = input_ids.to(device_type)

                """
                TODO[flame]: We need to carefully handle the position_ids for TP/CP
                Depending on the Models'PE, the position_ids might be different.

                e.g. for TP
                For RoPE, all ranks have the same position_ids. [FOR HF model]
                For sinusoidal, each rank has the coresponding chunked position_ids. [FOR HF model]

                e.g. for CP, [optional_context_parallel_ctx shoudl automatically distbute the position_ids]
                Each rank has the coresponding chunked position_ids. [FOR All model]

                """
                labels = labels.to(device_type)
                cu_seqlens = (
                    batch["cu_seqlens"].to(device_type)
                    if "cu_seqlens" in batch
                    else None
                )
                if cu_seqlens is not None:
                    # Varlen/packed batch: positions restart at each sequence
                    # boundary encoded in cu_seqlens.
                    position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
                else:
                    # Dense batch: simple 0..L-1 positions per row.
                    position_ids = (
                        torch.arange(0, input_ids.shape[1], device=device_type)
                        .repeat(input_ids.shape[0], 1)
                        .to(torch.int32)
                    )

                # Context-parallel sharding context for inputs/labels/positions
                # (sequence dim = 1 for all three buffers).
                optional_context_parallel_ctx = (
                    dist_utils.create_context_parallel_ctx(
                        cp_mesh=world_mesh["cp"],
                        cp_buffers=[input_ids, labels, position_ids],
                        cp_seq_dims=[1, 1, 1],
                        cp_no_restore_buffers={input_ids, labels, position_ids},
                        cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
                    )
                    if parallel_dims.cp_enabled
                    else None
                )

                if parallel_dims.pp_enabled:
                    raise NotImplementedError(
                        "Pipeline parallelism is not supported in this version"
                    )
                    # NOTE(review): unreachable; also rebinds `losses`,
                    # shadowing the defaultdict above.
                    with train_context(optional_context_parallel_ctx):
                        targets, losses = (
                            (labels, []) if has_last_stage else (None, None)
                        )

                        if has_first_stage:
                            pp_schedule.step(input_ids, target=targets, losses=losses)
                        else:
                            pp_schedule.step(target=targets, losses=losses)

                    loss = (
                        torch.mean(torch.stack(losses)).to(device)
                        if has_last_stage
                        else torch.tensor([-1.0], device=device)
                    )
                else:
                    with train_context(optional_context_parallel_ctx):
                        output = model(
                            input_ids=input_ids,
                            labels=labels,
                            position_ids=position_ids,
                            cu_seqlens=cu_seqlens,
                        )
                        # Discover auxiliary loss fields on the output
                        # dataclass (anything containing "loss" except the
                        # main `loss` itself).
                        output_attributes = [field.name for field in dataclasses.fields(output)]
                        losses_atributes = [x for x in output_attributes if "loss" in x and x != "loss"]
                        # Scale for gradient accumulation before backward.
                        loss = (
                            output.loss
                            / job_config.training.gradient_accumulation_steps
                        )
                        loss.backward()

                actual_loss.append(loss)
                for loss_attr in losses_atributes:
                    custom_loss = getattr(output, loss_attr, None)
                    if custom_loss is not None:
                        custom_loss = custom_loss / job_config.training.gradient_accumulation_steps
                        custom_loss = custom_loss  # NOTE(review): no-op assignment
                        losses[loss_attr].append(custom_loss)

            # Collapse the micro-batch contributions into per-step totals.
            loss = sum(actual_loss)
            for loss_attr, loss_values in losses.items():
                losses[loss_attr] = sum(loss_values)

            # Global grad-norm clipping across all model parts.
            grad_norm = dist_utils.clip_grad_norm_(
                [p for m in model_parts for p in m.parameters()],
                job_config.training.max_norm,
                foreach=True,
                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
            )

            # Ensure async checkpoint staging finished before touching weights.
            checkpoint.maybe_wait_for_staging()
            if job_config.training.skip_nan_inf and (
                grad_norm.isnan() or grad_norm.isinf()
            ):
                # Skip the update entirely on NaN/Inf gradients.
                logger.warning(
                    f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
                )
                optimizers.zero_grad()
                train_state.skipped_step += 1
            else:
                optimizers.step()
                lr_schedulers.step()

            # ---- metric logging (only at logging intervals) ----
            global_avg_custom_loss = {}
            global_max_custom_loss = {}
            if metric_logger.should_log(train_state.step):
                if (
                    parallel_dims.dp_replicate_enabled
                    or parallel_dims.dp_shard_enabled
                    or parallel_dims.cp_enabled
                ):
                    loss = loss.detach()
                    # Reduce losses across the dp+cp mesh for global stats.
                    global_avg_loss, global_max_loss = (
                        dist_utils.dist_mean(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                        dist_utils.dist_max(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                    )
                    for loss_attr, loss_value in losses.items():
                        global_avg_custom_loss[loss_attr] = dist_utils.dist_mean(
                            loss_value, world_mesh["dp_cp"]
                        )
                        global_max_custom_loss[loss_attr] = dist_utils.dist_max(
                            loss_value, world_mesh["dp_cp"]
                        )
                else:
                    # Single-rank data path: local values are already global.
                    global_avg_loss = global_max_loss = loss.item()
                    for loss_attr, loss_value in losses.items():
                        global_avg_custom_loss[loss_attr] = global_max_custom_loss[
                            loss_attr
                        ] = loss_value.item()

                # Track throughput and wall-clock progress since the last log.
                time_now = time.perf_counter()
                time_delta = (
                    time_now - metric_logger.time_last_log
                )
                train_state.token += (
                    metric_logger.ntokens_since_last_log
                    * parallel_dims.world_size
                    / parallel_dims.non_data_parallel_size
                )
                train_state.elapsed += timedelta(seconds=time_delta)
                train_state.log_steps.append(train_state.step)
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)

                last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
                # Linear ETA estimate from average time per completed step.
                eta = (
                    train_state.elapsed
                    * (job_config.training.steps - train_state.step)
                    / train_state.step
                )
                extra_metrics = {
                    "optimizer/lr": last_lr,
                    "optimizer/grad_norm": grad_norm.item(),
                    "optimizer/skipped_step": train_state.skipped_step,
                }
                for loss_attr, loss_value in global_avg_custom_loss.items():
                    extra_metrics[f"loss_metrics/global_avg_{loss_attr}"] = loss_value.item() if isinstance(loss_value, torch.Tensor) else loss_value
                metric_logger.log(
                    train_state.step,
                    global_avg_loss,
                    global_max_loss,
                    extra_metrics=extra_metrics,
                )

                logger.info(
                    f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f}  "
                    f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
                )

            # ---- checkpointing + optional HF conversion/upload (rank 0) ----
            checkpoint.save(
                train_state.step, force=(train_state.step == job_config.training.steps)
            )
            if torch.distributed.get_rank() == 0:
                if job_config.checkpoint.enable_checkpoint:
                    hf_target_path = None
                    dcp_save_path = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder, f"step-{train_state.step}")

                    if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
                        try:
                            # HF files are written alongside the DCP checkpoint.
                            hf_target_path = f"{dcp_save_path}"

                            logger.info(f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}")
                            save_pretrained(
                                path=hf_target_path,
                                step=train_state.step,
                                config=job_config.model.config,
                                tokenizer=job_config.model.tokenizer_path
                            )
                            logger.info(f"Successfully converted step {train_state.step} to HF format.")

                        except Exception as e:
                            # Conversion failure must not kill training.
                            logger.error(f"Failed to convert checkpoint step {train_state.step} to HF format: {e}", exc_info=True)

                    base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
                    # NOTE(review): default True here contradicts the default
                    # False used for the same flag at startup — confirm which
                    # is intended.
                    if getattr(job_config.checkpoint, "hf_upload_enabled", True):
                        upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
                        # NOTE(review): keep_k_hub is computed but never used;
                        # the upload call below passes keep_latest_k instead.
                        keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)

                        local_path_to_upload = None
                        if upload_format == "hf":
                            if hf_target_path and os.path.isdir(hf_target_path):
                                local_path_to_upload = hf_target_path
                        elif upload_format == "dcp":
                            if dcp_save_path and os.path.isdir(dcp_save_path):
                                local_path_to_upload = dcp_save_path

                        if local_path_to_upload:
                            try:
                                upload_checkpoint_to_hf(
                                    local_path=local_path_to_upload,
                                    step=train_state.step,
                                    hf_repo_id_for_run=run_specific_repo_id,
                                    upload_format=upload_format,
                                    hf_keep_latest_k=job_config.checkpoint.keep_latest_k,
                                )
                            except Exception as e:
                                # Upload failure must not kill training.
                                logger.error(f"Failed during HF Hub upload for step {train_state.step}: {e}", exc_info=True)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # After the first (slow, compile/warmup) step, shrink the process
            # group timeout to catch real hangs quickly.
            if train_state.step == 1:
                dist_utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    # Give non-zero ranks a moment to finish before rank 0 tears down.
    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    metric_logger.close()
    logger.info("Training completed")
|
|
|
|
if __name__ == "__main__":
    # Parse CLI/TOML configuration, run training, then tear down the default
    # process group initialized inside main().
    init_logger()
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()
|
|