import dataclasses
import json
import os
import time
from collections import defaultdict
from datetime import datetime, timedelta

import torch
from datasets import interleave_datasets, load_dataset
from torch.distributed.elastic.multiprocessing.errors import record
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import fla
from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
from fla.ops.common.utils import prepare_position_ids
from flame.components.checkpoint import TrainState
from flame.config_manager import JobConfig
from flame.data import build_dataloader, shuffle
from flame.models.parallelize_fla import parallelize_fla
from flame.models.pipeline_fla import pipeline_fla
from flame.tools.utils import get_nparams_and_flops
from flame.utils.checkpoint import cleanup_local_checkpoints
from flame.utils.convert_dcp_to_hf import save_pretrained
from flame.utils.hf_utils import upload_checkpoint_to_hf
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import FTParallelDims, init_ft_manager
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.metrics import (
    build_device_memory_monitor,
    build_metrics_processor,
    ensure_pp_loss_visible,
)
from torchtitan.components.optimizer import build_optimizers
from torchtitan.distributed import ParallelDims
from torchtitan.distributed import utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

from dotenv import load_dotenv
load_dotenv()

import wandb
wandb.login(key=os.environ["WANDB_API_KEY"])

import huggingface_hub
huggingface_hub.login(token=os.environ["HF_TOKEN"])


def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
    return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)


# Register the "fla" train spec so that its components (parallelization, optimizers,
# LR schedulers, dataloader, loss) can later be looked up via
# get_train_spec(job_config.model.name).
register_train_spec(
    TrainSpec(
        name="fla",
        cls=AutoModelForCausalLM,
        config=AutoConfig,
        parallelize_fn=parallelize_fla,
        pipelining_fn=pipeline_fla,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_dataloader,
        build_tokenizer_fn=build_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)


@record
def main(job_config: JobConfig):
    logger.info(f"Starting job: {job_config.job.description}")

    if job_config.experimental.custom_model_path:
        utils.import_module_from_path(job_config.experimental.custom_model_path)

    color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color

    if job_config.job.print_args:
        logger.info(
            f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
        )

    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    device_module, device_type = utils.device_module, utils.device_type
    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
    # the device has to be set before creating the TorchFT manager
    device_module.set_device(device)
    ft_manager = init_ft_manager(job_config)

    run_specific_repo_id = None
    if getattr(job_config.checkpoint, "hf_upload_enabled", False):
        hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
        if hf_repo_base:
            # give every run its own repository by appending a timestamp
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
            logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
        else:
            logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
            # without a base repo name there is nowhere to upload to, so disable the feature
            job_config.checkpoint.hf_upload_enabled = False

    world_size = int(os.environ["WORLD_SIZE"])
    if not ft_manager.enabled:
        parallel_dims = ParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
        )
    else:
        parallel_dims = FTParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
            ft_manager=ft_manager,
        )
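    # The configured parallelism degrees must multiply out to the world size.
    # Illustrative example (hypothetical values, not from any particular config):
    # with WORLD_SIZE=8, tp=2, pp=1, cp=1 and data_parallel_shard_degree=-1, the shard
    # degree is inferred as 8 / (2 * 1 * 1) = 4, i.e. FSDP over 4 ranks combined with
    # 2-way tensor parallelism.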

    dist_utils.init_distributed(job_config)

    # initialize device memory monitor and get peak flops for MFU calculation
    device_memory_monitor = build_device_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
    logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

    # build the device mesh for all enabled parallelisms
    world_mesh = parallel_dims.build_mesh(device_type=device_type)
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        raise NotImplementedError(
            "Pipeline parallelism is not supported in this version"
        )
        """
        ! TODO[flame]: We need to fix pipeline parallelism for flame.
        [x] Match the keys of the models' components with the actual naming
        [ ] Fix post-init and embedding tying for pipeline parallelism; HF's transformers
            automatically ties the embeddings if the head is None, so we need to handle
            that case explicitly
        """
        pp_mesh = world_mesh["pp"]

    dist_utils.set_determinism(
        world_mesh, device, job_config.training.seed, job_config.training.deterministic
    )
    train_spec = get_train_spec(job_config.model.name)

    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        job_config.model.tokenizer_path,
        trust_remote_code=True,
        model_max_length=int(1e10),
    )
    logger.info(f"{tokenizer}")
    logger.info(
        f"Loading dataset {job_config.training.dataset}"
        + (
            f":{job_config.training.dataset_name}"
            if job_config.training.dataset_name is not None
            else ""
        )
    )

    min_num_shards = dp_degree * job_config.training.num_workers
    if len(job_config.training.dataset.split(",")) == 1:
        dataset = load_dataset(
            path=job_config.training.dataset,
            name=getattr(job_config.training, "dataset_name", None),
            data_dir=getattr(job_config.training, "data_dir", None),
            data_files=getattr(job_config.training, "data_files", None),
            split=job_config.training.dataset_split or "train",
            trust_remote_code=True,
            streaming=job_config.training.streaming,
            num_proc=(
                job_config.training.num_workers
                if not job_config.training.streaming
                else None
            ),
        )
        logger.info(f"{dataset}")

        logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
        if not job_config.training.streaming:
            dataset = dataset.shuffle(
                seed=job_config.training.seed
            ).to_iterable_dataset(num_shards=min_num_shards)
        else:
            if dataset.num_shards < min_num_shards:
                logger.warning(
                    f"{color.red}"
                    f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
                    f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                    f"{job_config.training.num_workers} dataloader workers. "
                    f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
                    f"{color.reset}"
                )
                dataset = (
                    load_dataset(
                        path=job_config.training.dataset,
                        name=getattr(job_config.training, "dataset_name", None),
                        data_dir=getattr(job_config.training, "data_dir", None),
                        data_files=getattr(job_config.training, "data_files", None),
                        split=job_config.training.dataset_split or "train",
                        trust_remote_code=True,
                        streaming=False,
                        num_proc=job_config.training.num_workers,
                    )
                    .shuffle(seed=job_config.training.seed)
                    .to_iterable_dataset(num_shards=min_num_shards)
                )
            else:
                dataset = shuffle(dataset, seed=job_config.training.seed)
    else:
        datasets = job_config.training.dataset.split(",")
        if job_config.training.dataset_name is not None:
            dataset_names = [
                name or None for name in job_config.training.dataset_name.split(",")
            ]
            assert len(dataset_names) == len(datasets), (
                "The number of dataset names must match the number of datasets"
            )
        else:
            dataset_names = [None] * len(datasets)
        if job_config.training.dataset_split is not None:
            dataset_splits = [
                split or "train"
                for split in job_config.training.dataset_split.split(",")
            ]
            assert len(dataset_splits) == len(datasets), (
                "The number of dataset splits must match the number of datasets"
            )
        else:
            dataset_splits = ["train"] * len(datasets)
        if job_config.training.data_dir is not None:
            data_dirs = [
                data_dir or None for data_dir in job_config.training.data_dir.split(",")
            ]
            assert len(data_dirs) == len(datasets), (
                "The number of data dirs must match the number of datasets"
            )
        else:
            data_dirs = [None] * len(datasets)
        if job_config.training.data_files is not None:
            data_files = job_config.training.data_files.split(",")
            assert len(data_files) == len(datasets), (
                "The number of data files must match the number of datasets"
            )
        else:
            data_files = [None] * len(datasets)
        if job_config.training.data_probs is not None:
            data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
            assert len(data_probs) == len(datasets), (
                "The number of data probabilities must match the number of datasets"
            )
        else:
            raise ValueError(
                "Data sampling probabilities are required if using multiple datasets"
            )

        subsets = []
        for i, prob in enumerate(data_probs):
            subset = load_dataset(
                path=datasets[i],
                name=dataset_names[i],
                data_dir=data_dirs[i],
                data_files=data_files[i],
                split=dataset_splits[i],
                trust_remote_code=True,
                streaming=job_config.training.streaming,
                num_proc=(
                    job_config.training.num_workers
                    if not job_config.training.streaming
                    else None
                ),
            )
            logger.info(
                f"Subset {color.cyan}{datasets[i]}"
                + (f":{dataset_names[i]} " if dataset_names[i] else " ")
                + f"(p = {prob:.3f}){color.reset}:\n"
                + f"{subset}"
            )

            logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
            if not job_config.training.streaming:
                subset = subset.shuffle(
                    seed=job_config.training.seed
                ).to_iterable_dataset(num_shards=min_num_shards)
            else:
                if subset.num_shards < min_num_shards:
                    logger.warning(
                        f"{color.red}"
                        f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
                        f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                        f"{job_config.training.num_workers} dataloader workers. "
                        f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
                        f"{color.reset}"
                    )
                    subset = (
                        load_dataset(
                            path=datasets[i],
                            name=dataset_names[i],
                            data_dir=data_dirs[i],
                            data_files=data_files[i],
                            split=dataset_splits[i],
                            trust_remote_code=True,
                            streaming=False,
                            num_proc=job_config.training.num_workers,
                        )
                        .shuffle(seed=job_config.training.seed)
                        .to_iterable_dataset(num_shards=min_num_shards)
                    )
                else:
                    subset = shuffle(
                        subset,
                        seed=job_config.training.seed,
                        buffer_size=max(128, 1024 // len(datasets)),
                    )

            if "text" in subset.column_names:
                subset = subset.select_columns("text")
            elif "content" in subset.column_names:
                subset = subset.select_columns("content")
            else:
                raise ValueError(
                    f"Subset {datasets[i]} has no 'text' or 'content' column"
                )
            subsets.append(subset)

        logger.info(
            f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
        )
        dataset = interleave_datasets(
            datasets=subsets,
            probabilities=data_probs,
            stopping_strategy="all_exhausted",
            seed=job_config.training.seed,
        )
        logger.info(f"{dataset}")
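    # Illustrative example (hypothetical values, not from the source): passing
    #   --training.dataset "HuggingFaceFW/fineweb-edu,allenai/c4"
    #   --training.data_probs "0.9,0.1"
    # loads two subsets and interleaves them, drawing roughly 90% of samples from the
    # first dataset and 10% from the second until all datasets are exhausted.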

    logger.info(f"Loading model config from {job_config.model.config}")
    model_config = AutoConfig.from_pretrained(job_config.model.config)

    logger.info("Building dataloader...")
    dataloader = build_dataloader(
        dataset=dataset,
        tokenizer=tokenizer,
        rank=dp_rank,
        world_size=dp_degree,
        batch_size=job_config.training.batch_size,
        seq_len=job_config.training.seq_len * 2,
        context_len=job_config.training.context_len,
        varlen=job_config.training.varlen,
        num_workers=job_config.training.num_workers,
        pin_memory=job_config.training.pin_memory,
        persistent_workers=job_config.training.persistent_workers,
        snapshot_every_n_steps=job_config.checkpoint.interval,
    )
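    # When varlen=True, the collator is expected to pack multiple documents into one
    # sequence and emit a "cu_seqlens" tensor with cumulative sequence boundaries;
    # the training loop below checks for that key and builds per-document position ids
    # from it.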

    # adjust the model config to the chosen parallelism settings before building the model
    if parallel_dims.tp_enabled:
        if model_config.fuse_norm:
            logger.warning(
                f"{color.red}"
                f"Fused norm is not compatible with tensor parallelism. "
                f"Disabling it for now."
                f"{color.reset}"
            )
            model_config.fuse_norm = False
    if parallel_dims.loss_parallel_enabled:
        if model_config.fuse_cross_entropy:
            logger.warning(
                f"{color.red}"
                f"Loss parallel enabled. Disabling fused cross entropy for now."
                f"{color.reset}"
            )
            model_config.fuse_cross_entropy = False
    model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)

    logger.info(
        f"Building model from the config\n{color.green}{model_config}{color.reset}"
    )
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(model_config)
        if (
            getattr(model_config, "fuse_cross_entropy", False)
            and FusedLinearCrossEntropyLoss is not None
        ):
            model.criterion = FusedLinearCrossEntropyLoss(
                num_chunks=8 // parallel_dims.tp
            )
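            # The fused linear + cross-entropy loss avoids materializing the full
            # [tokens, vocab] logits tensor at once, which keeps peak memory bounded.
            # num_chunks is presumably scaled down when tensor parallelism already splits
            # the projection across ranks, e.g. tp=2 gives 8 // 2 = 4 chunks.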
        # defer weight initialization until after parallelisms are applied
        model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
    logger.info(f"{color.blue}\n{model}{color.reset}\n")

    # convert the model according to the enabled converters (e.g. float8)
    model_converters = build_model_converters(job_config, parallel_dims)
    model_converters.convert(model)

    # log model size and estimated FLOPs per token
    model_param_count, num_flops_per_token = get_nparams_and_flops(
        model, model_config, job_config.training.context_len
    )

    # pick the device on which parameters are materialized
    if job_config.checkpoint.create_seed_checkpoint:
        init_device = "cpu"
    elif job_config.training.enable_cpu_offload:
        init_device = "cpu"
    else:
        init_device = device_type

    # apply parallelisms and initialization
    if parallel_dims.pp_enabled:
        # apply PT-D Pipeline Parallel
        (
            pp_schedule,
            model_parts,
            has_first_stage,
            has_last_stage,
        ) = train_spec.pipelining_fn(
            model,
            pp_mesh,
            parallel_dims,
            job_config,
            device,
            model_config,
            train_spec.loss_fn,
        )
        # when PP is enabled, `model` is no longer used after this point;
        # `model_parts` is used instead
        del model

        # for PP with looped schedules, each entry in model_parts is one stage-model-chunk;
        # iterate through model_parts to apply SPMD parallelisms, compilation,
        # optimizer, and checkpointing
        for m in model_parts:
            # apply SPMD-style PT-D techniques
            train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
            m.to_empty(device=init_device)
            with torch.no_grad():
                m.post_init()
            m.train()

        # confirm that the user will be able to view loss metrics on the console
        ensure_pp_loss_visible(parallel_dims, job_config, color)
    else:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
        model.to_empty(device=init_device)
        with torch.no_grad():
            model.post_init()
        model.train()

        model_parts = [model]

    device_mem_stats = device_memory_monitor.get_peak_stats()
    logger.info(
        f"{device_type.upper()} memory usage for model: "
        f"{device_mem_stats.max_reserved_gib:.2f}GiB "
        f"({device_mem_stats.max_reserved_pct:.2f}%)"
    )

    # build optimizers and LR schedulers after applying parallelisms to the model
    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
    lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
    # post-optimizer-step model converters hook,
    # e.g. calculate float8 dynamic amax/scale for all parameters for FSDP2,
    # which issues a single all-reduce for all parameters at once for better performance
    optimizers.register_step_post_hook(
        lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
    )

    train_state = TrainState()

    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    if job_config.checkpoint.create_seed_checkpoint:
        assert world_size == 1, (
            "Must create seed checkpoint using a single device, to disable sharding"
        )
        assert job_config.checkpoint.enable_checkpoint, (
            "Must enable checkpointing when creating a seed checkpoint"
        )
        checkpoint.save(curr_step=0, force=True)
        logger.info("Created seed checkpoint")
        return

    checkpoint.load(step=job_config.checkpoint.load_step)
    metric_logger = build_metrics_processor(job_config, parallel_dims)
    # attach extra state the metrics processor needs (FLOPs per token for MFU,
    # optimizers and LR schedulers for logging)
    metric_logger.num_flops_per_token = num_flops_per_token
    metric_logger.optimizers = optimizers
    metric_logger.lr_schedulers = lr_schedulers

    # replay loss metrics restored from the checkpoint (if any) so the logs are
    # continuous across restarts
    if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
        for idx, step in enumerate(train_state.log_steps):
            metric_logger.log(
                step,
                global_avg_loss=train_state.global_avg_losses[idx],
                global_max_loss=train_state.global_max_losses[idx],
            )

    data_iterator = iter(dataloader)

    train_context = dist_utils.get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )

    device_memory_monitor.reset_peak_stats()

    global_batch_size = (
        job_config.training.batch_size
        * dp_degree
        * job_config.training.gradient_accumulation_steps
    )
    num_tokens_per_step = global_batch_size * job_config.training.seq_len
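    # Worked example with hypothetical values: batch_size=8 per device, dp_degree=4 and
    # gradient_accumulation_steps=4 give a global batch of 8 * 4 * 4 = 128 sequences;
    # with seq_len=2048 that is 128 * 2048 = 262,144 tokens per optimization step.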

    logger.info(f"{color.red}***** Running training *****{color.reset}")
    logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
    logger.info(
        f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
    )
    logger.info(
        f"{color.green} Gradient accumulation steps = {job_config.training.gradient_accumulation_steps}"
    )
    logger.info(
        f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
    )
    logger.info(
        f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
        f" ({num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Total optimization steps = {job_config.training.steps:,} "
        f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
        f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
    )

    with (
        maybe_enable_profiling(
            job_config, global_step=train_state.step
        ) as torch_profiler,
        maybe_enable_memory_snapshot(
            job_config, global_step=train_state.step
        ) as memory_profiler,
    ):
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            optimizers.zero_grad()

            losses = defaultdict(list)
            actual_loss = []
            # accumulate gradients over several micro-batches before stepping
            for _ in range(job_config.training.gradient_accumulation_steps):
                # fetch the next batch and record the data loading latency
                data_load_start = time.perf_counter()
                batch = next(data_iterator)
                input_ids, labels = (
                    batch["input_ids"][:, : job_config.training.seq_len],
                    batch["labels"],
                )

                metric_logger.ntokens_since_last_log += input_ids.numel()
                metric_logger.data_loading_times.append(
                    time.perf_counter() - data_load_start
                )

                input_ids = input_ids.to(device_type)

                """
                TODO[flame]: We need to carefully handle the position_ids for TP/CP.
                Depending on the model's positional encoding, the position_ids differ.

                e.g. for TP
                    For RoPE, all ranks have the same position_ids. [for HF models]
                    For sinusoidal embeddings, each rank has its corresponding chunk of
                    position_ids. [for HF models]

                e.g. for CP [optional_context_parallel_ctx should automatically
                    distribute the position_ids]
                    Each rank has its corresponding chunk of position_ids. [for all models]
                """
                labels = labels.to(device_type)
                cu_seqlens = (
                    batch["cu_seqlens"].to(device_type)
                    if "cu_seqlens" in batch
                    else None
                )
                if cu_seqlens is not None:
                    position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
                else:
                    position_ids = (
                        torch.arange(0, input_ids.shape[1], device=device_type)
                        .repeat(input_ids.shape[0], 1)
                        .to(torch.int32)
                    )
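                # Illustration of the two cases (hypothetical shapes): with a packed
                # varlen batch and cu_seqlens = [0, 3, 5], prepare_position_ids is
                # expected to restart positions at every document boundary, giving
                # [0, 1, 2, 0, 1]; without cu_seqlens, every row simply gets
                # [0, 1, ..., seq_len - 1].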

                # apply context parallelism if cp is enabled:
                # the context manager shards the listed buffers along their sequence dims
                optional_context_parallel_ctx = (
                    dist_utils.create_context_parallel_ctx(
                        cp_mesh=world_mesh["cp"],
                        cp_buffers=[input_ids, labels, position_ids],
                        cp_seq_dims=[1, 1, 1],
                        cp_no_restore_buffers={input_ids, labels, position_ids},
                        cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
                    )
                    if parallel_dims.cp_enabled
                    else None
                )

                if parallel_dims.pp_enabled:
                    raise NotImplementedError(
                        "Pipeline parallelism is not supported in this version"
                    )
                    # pipeline parallel forward / backward inside step() call
                    with train_context(optional_context_parallel_ctx):
                        targets, losses = (
                            (labels, []) if has_last_stage else (None, None)
                        )

                        if has_first_stage:
                            pp_schedule.step(input_ids, target=targets, losses=losses)
                        else:
                            pp_schedule.step(target=targets, losses=losses)

                    # accumulate losses across pipeline microbatches;
                    # ranks without the last stage report a placeholder loss
                    loss = (
                        torch.mean(torch.stack(losses)).to(device)
                        if has_last_stage
                        else torch.tensor([-1.0], device=device)
                    )
                else:
                    # standard forward / backward without pipeline parallelism
                    with train_context(optional_context_parallel_ctx):
                        output = model(
                            input_ids=input_ids,
                            labels=labels,
                            position_ids=position_ids,
                            cu_seqlens=cu_seqlens,
                        )
                        output_attributes = [
                            field.name for field in dataclasses.fields(output)
                        ]
                        losses_attributes = [
                            x for x in output_attributes if "loss" in x and x != "loss"
                        ]
                        loss = (
                            output.loss
                            / job_config.training.gradient_accumulation_steps
                        )
                        loss.backward()

                actual_loss.append(loss)
                for loss_attr in losses_attributes:
                    custom_loss = getattr(output, loss_attr, None)
                    if custom_loss is not None:
                        custom_loss = custom_loss / job_config.training.gradient_accumulation_steps
                        losses[loss_attr].append(custom_loss)

            loss = sum(actual_loss)
            for loss_attr, loss_values in losses.items():
                losses[loss_attr] = sum(loss_values)
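            # Since each micro-batch loss was already divided by
            # gradient_accumulation_steps before backward(), summing the scaled losses
            # here recovers the mean micro-batch loss, and the accumulated gradients
            # approximately match those of a single large-batch step of the same size.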

            # clip gradients
            grad_norm = dist_utils.clip_grad_norm_(
                [p for m in model_parts for p in m.parameters()],
                job_config.training.max_norm,
                foreach=True,
                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
            )

            # optimizer step
            checkpoint.maybe_wait_for_staging()
            if job_config.training.skip_nan_inf and (
                grad_norm.isnan() or grad_norm.isinf()
            ):
                logger.warning(
                    f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
                )
                optimizers.zero_grad()
                train_state.skipped_step += 1
            else:
                optimizers.step()
            lr_schedulers.step()

            # log metrics
            global_avg_custom_loss = {}
            global_max_custom_loss = {}
            if metric_logger.should_log(train_state.step):
                if (
                    parallel_dims.dp_replicate_enabled
                    or parallel_dims.dp_shard_enabled
                    or parallel_dims.cp_enabled
                ):
                    loss = loss.detach()
                    # reduce the loss across data-parallel / context-parallel ranks
                    global_avg_loss, global_max_loss = (
                        dist_utils.dist_mean(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                        dist_utils.dist_max(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                    )
                    for loss_attr, loss_value in losses.items():
                        global_avg_custom_loss[loss_attr] = dist_utils.dist_mean(
                            loss_value, world_mesh["dp_cp"]
                        )
                        global_max_custom_loss[loss_attr] = dist_utils.dist_max(
                            loss_value, world_mesh["dp_cp"]
                        )
                else:
                    global_avg_loss = global_max_loss = loss.item()
                    for loss_attr, loss_value in losses.items():
                        global_avg_custom_loss[loss_attr] = global_max_custom_loss[
                            loss_attr
                        ] = loss_value.item()

                # update train state
                time_now = time.perf_counter()
                time_delta = time_now - metric_logger.time_last_log
                train_state.token += (
                    metric_logger.ntokens_since_last_log
                    * parallel_dims.world_size
                    / parallel_dims.non_data_parallel_size
                )
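                # ntokens_since_last_log counts tokens seen by this rank's dataloader;
                # scaling by world_size / non_data_parallel_size converts that to the
                # number of distinct tokens consumed globally, since TP/PP/CP ranks
                # within one data-parallel group all read the same batch.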
                train_state.elapsed += timedelta(seconds=time_delta)
                train_state.log_steps.append(train_state.step)
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)

                last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
                eta = (
                    train_state.elapsed
                    * (job_config.training.steps - train_state.step)
                    / train_state.step
                )
                extra_metrics = {
                    "optimizer/lr": last_lr,
                    "optimizer/grad_norm": grad_norm.item(),
                    "optimizer/skipped_step": train_state.skipped_step,
                }
                for loss_attr, loss_value in global_avg_custom_loss.items():
                    extra_metrics[f"loss_metrics/global_avg_{loss_attr}"] = (
                        loss_value.item()
                        if isinstance(loss_value, torch.Tensor)
                        else loss_value
                    )
                metric_logger.log(
                    train_state.step,
                    global_avg_loss,
                    global_max_loss,
                    extra_metrics=extra_metrics,
                )

                logger.info(
                    f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
                    f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
                )

            checkpoint.save(
                train_state.step, force=(train_state.step == job_config.training.steps)
            )

            # checkpoint conversion and Hugging Face Hub upload happen only on rank 0
            if torch.distributed.get_rank() == 0:
                if job_config.checkpoint.enable_checkpoint:
                    hf_target_path = None
                    dcp_save_path = os.path.join(
                        job_config.job.dump_folder,
                        job_config.checkpoint.folder,
                        f"step-{train_state.step}",
                    )

                    if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
                        try:
                            # convert in place: target the same step directory as the DCP checkpoint
                            hf_target_path = dcp_save_path

                            logger.info(
                                f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}"
                            )
                            save_pretrained(
                                path=hf_target_path,
                                step=train_state.step,
                                config=job_config.model.config,
                                tokenizer=job_config.model.tokenizer_path,
                            )
                            logger.info(f"Successfully converted step {train_state.step} to HF format.")
                        except Exception as e:
                            logger.error(
                                f"Failed to convert checkpoint step {train_state.step} to HF format: {e}",
                                exc_info=True,
                            )

                    base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
                    if getattr(job_config.checkpoint, "hf_upload_enabled", False) and run_specific_repo_id:
                        upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
                        keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)

                        local_path_to_upload = None
                        if upload_format == "hf":
                            if hf_target_path and os.path.isdir(hf_target_path):
                                local_path_to_upload = hf_target_path
                        elif upload_format == "dcp":
                            if dcp_save_path and os.path.isdir(dcp_save_path):
                                local_path_to_upload = dcp_save_path

                        if local_path_to_upload:
                            try:
                                upload_checkpoint_to_hf(
                                    local_path=local_path_to_upload,
                                    step=train_state.step,
                                    hf_repo_id_for_run=run_specific_repo_id,
                                    upload_format=upload_format,
                                    hf_keep_latest_k=keep_k_hub,
                                )
                            except Exception as e:
                                logger.error(
                                    f"Failed during HF Hub upload for step {train_state.step}: {e}",
                                    exc_info=True,
                                )

            # signal the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # reduce timeout after first train step for faster signal
            # (assuming lazy init and compilation are finished)
            if train_state.step == 1:
                dist_utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    metric_logger.close()
    logger.info("Training completed")


if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()