from copy import deepcopy
import gc
import logging
import os
import sys
import time
from contextlib import ExitStack
from dataclasses import asdict, dataclass, field
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, Dict, List, Optional

import numpy as np
from omegaconf import OmegaConf
import torch
import torch.distributed
import torch.nn.functional as F
import xformers.profiler
from torch.optim import lr_scheduler
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed._tensor import DTensor

from lingua.args import dataclass_from_dict, dump_config, flatten_dict
from lingua.checkpoint import CheckpointArgs, CheckpointManager, load_from_checkpoint
from lingua.data import (
    DataArgs,
    PackTokensState,
    build_dataloader_from_args,
    init_dataloader_state_from_args,
)
from lingua.distributed import (
    DistributedArgs,
    EnvironmentArgs,
    init_signal_handler,
    dist_mean_dict,
    get_device_mesh,
    get_is_master,
    get_world_size,
    parallelize_model,
    setup_env,
    setup_torch_distributed,
    clean_env,
    requeue_slurm_job,
    check_model_value_range,
)
from lingua.logger import init_logger
from lingua.metrics import (
    GPUMemoryMonitor,
    LoggingArgs,
    MetricLogger,
    get_num_params,
)
from lingua.optim import OptimArgs, build_optimizer
from lingua.profiling import ProfilerArgs, maybe_run_profiler
from lingua.tokenizer import build_tokenizer

from apps.fastRNN.minGRU.mingru import (
    LMMinGRU,
    LMMinGRUArgs,
)
from apps.fastRNN.minLSTM.minlstm import (
    LMMinLSTM,
    LMMinLSTMArgs,
)
from apps.fastRNN.hawk.hawk import (
    LMHawk,
    LMHawkArgs,
)

from lingua.probe import AutoProbeD
from lingua.stool import StoolArgs, launch_job

import wandb

logger = logging.getLogger()

@dataclass
class TrainArgs:
    name: str = "lingua"
    dump_dir: str = ""

    seed: int = 42

    # number of micro-batches accumulated into each optimizer step
    grad_acc_steps: int = 1

    gc_collect_freq: int = 1000
    probe_freq: Optional[int] = None

    # total number of optimizer steps to run
    steps: int = 1000

    data: DataArgs = field(default_factory=DataArgs)
    optim: OptimArgs = field(default_factory=OptimArgs)

    model_type: str = "minGRU"  # one of "minGRU", "minLSTM", "hawk"
    model: dict = field(default_factory=lambda: asdict(LMMinGRUArgs()))
    distributed: DistributedArgs = field(default_factory=DistributedArgs)
    env: EnvironmentArgs = field(default_factory=EnvironmentArgs)

    checkpoint: CheckpointArgs = field(default_factory=CheckpointArgs)
    profiling: ProfilerArgs = field(default_factory=ProfilerArgs)
    logging: LoggingArgs = field(default_factory=LoggingArgs)

    # if None, evals run inside the training job; otherwise a separate eval job
    # is launched on this many GPUs
    async_eval_gpus: Optional[int] = None
    eval: Optional[Any] = None

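# A minimal YAML config overriding the defaults above might look like the
# following (field names come from TrainArgs; the values and paths are purely
# illustrative):
#
#   name: hawk_small
#   dump_dir: /tmp/hawk_small
#   steps: 5000
#   grad_acc_steps: 2
#   model_type: hawk
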
@dataclass
class TrainState(Stateful):
    step: int
    acc_step: int
    scheduler: lr_scheduler.LambdaLR
    data_loader_state: PackTokensState

    def state_dict(self) -> Dict[str, Any]:
        return {
            "step": self.step,
            "acc_step": self.acc_step,
            "data_loader_state": self.data_loader_state,
            "scheduler": self.scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.step = state_dict["step"]
        self.acc_step = state_dict["acc_step"]
        self.data_loader_state = PackTokensState(**state_dict["data_loader_state"])
        self.scheduler.load_state_dict(state_dict["scheduler"])

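# TrainState above follows the torch.distributed.checkpoint Stateful protocol
# (state_dict / load_state_dict), which is what lets the CheckpointManager in
# train() save and restore it alongside the model and optimizer. A sketch of
# the round trip, assuming a scheduler and dataloader state built elsewhere:
#
#   state = TrainState(step=0, acc_step=0, scheduler=scheduler, data_loader_state=dl_state)
#   saved = state.state_dict()      # plain dict, safe to serialize
#   state.load_state_dict(saved)    # restores step counters, dataloader and scheduler state
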
def validate_train_args(args: TrainArgs, output_size: int):
    if args.model.vocab_size < 0:
        logger.info(f"Setting model output size to {output_size}")
        args.model.vocab_size = output_size
    assert (
        args.model.vocab_size == output_size
    ), "Vocab size should be the same as output size"

    assert args.dump_dir, "Dump dir not set"

    if args.checkpoint.path is None:
        logger.info(f"Setting checkpoint path to {str(Path(args.dump_dir) / 'checkpoints')}")
        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")

    for source in args.data.sources:
        data_path = os.path.join(args.data.root_dir, source)
        assert os.path.exists(data_path), f"{data_path} doesn't exist"

    # If the requested parallel sizes don't multiply out to the world size,
    # rescale dp_replicate so that dp_replicate * dp_shard * tp_size == world_size.
    if (
        args.distributed.dp_replicate
        * args.distributed.dp_shard
        * args.distributed.tp_size
        != get_world_size()
    ):
        assert get_world_size() % args.distributed.dp_shard == 0
        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard

        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
        args.distributed.dp_replicate = (
            args.distributed.dp_replicate // args.distributed.tp_size
        )

        logger.warning(
            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
        )
        assert (
            args.distributed.dp_replicate
            * args.distributed.dp_shard
            * args.distributed.tp_size
            == get_world_size()
        )

        if args.distributed.fsdp_type == "no_shard":
            assert (
                args.distributed.dp_shard == 1
                and args.distributed.dp_replicate == get_world_size()
            )

    args.model.max_seqlen = args.data.seq_len

    if args.distributed.tp_size > 1:
        logger.warning(
            "Tensor parallelism has not been tested for a while, use at your own risk"
        )

    assert (
        args.probe_freq != args.profiling.mem_steps
    ), "Don't profile during probe step"
    assert (
        args.probe_freq != args.profiling.profile_steps
    ), "Don't profile during probe step"
    if args.logging.wandb is not None:
        args.logging.wandb.name = args.name

    if args.probe_freq is not None:
        assert (
            args.distributed.tp_size == 1
        ), "Probing not supported with tensor parallelism"
        assert (
            args.distributed.selective_activation_checkpointing is False
        ), "Probing not supported with selective activation checkpointing"

preemption_flag = dict(flag=False)


def set_preemption_flag(signum, frame):
    logger.warning("Signal handler called with signal " + str(signum))
    logger.warning("Preemption! Checkpointing asap and exiting.")
    preemption_flag["flag"] = True

def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
    test = train_state.step % freq == 0
    if acc_step is not None:
        test = test and (train_state.acc_step == acc_step)
    elif acc_freq is not None:
        test = test and ((train_state.acc_step % acc_freq) == 0)
    return test

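# For example, with grad_acc_steps=4 the optimizer only steps on the micro-batch
# where acc_step wraps back to 0, so every_n_steps(train_state, freq=100, acc_step=0)
# fires once every 100 optimizer steps rather than once every 100 micro-batches.
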
def train(args: TrainArgs):
    with ExitStack() as context_stack:
        tokenizer = build_tokenizer(args.data.tokenizer.name, args.data.tokenizer.path)
        validate_train_args(
            args,
            tokenizer.n_words,
        )
        dump_args = deepcopy(args)
        dump_args.model = asdict(dump_args.model)
        if get_is_master():
            os.makedirs(args.dump_dir, exist_ok=True)
            dump_config(dump_args, Path(args.dump_dir) / "config.yaml")
        init_logger(Path(args.dump_dir) / "train.log")
        init_signal_handler(set_preemption_flag)
        setup_env(args.env)
        setup_torch_distributed(args.distributed)
        world_mesh = get_device_mesh(args.distributed)
        logger.info(f"Starting job: {args.name}")

        # data-parallel rank and size used to shard the dataloader
        dp_mesh = world_mesh["dp_replicate"]
        dp_degree = dp_mesh.size()
        dp_rank = dp_mesh.get_local_rank()
        if args.distributed.dp_shard > 1:
            dp_rank = dp_rank * world_mesh["dp_shard"].size() + world_mesh["dp_shard"].get_local_rank()
            dp_degree *= world_mesh["dp_shard"].size()

        logger.info(f"Running on dp rank : {dp_rank}")
        logger.info(f"Running on dp size : {dp_degree}")

        torch.manual_seed(args.seed)
        logger.info("Building model")

        # build the model on the meta device so no parameter storage is
        # allocated until after parallelization
        with torch.device("meta"):
            if args.model_type == "minGRU":
                model = LMMinGRU(args.model)
            elif args.model_type == "minLSTM":
                model = LMMinLSTM(args.model)
            elif args.model_type == "hawk":
                model = LMHawk(args.model)
            else:
                raise ValueError(f"Model type {args.model_type} not recognized")

        logger.info("Model is built!")

        model_param_count = get_num_params(model)

        model = parallelize_model(
            model,
            world_mesh,
            args.model,
            args.distributed,
            fsdp_grouping_plan=None,
            tp_parallelize=None,
            no_recompute_ops=model._get_no_recompute_ops(),
        )

        # allocate real (uninitialized) storage on the GPU now that the model is sharded
        model = model.to_empty(device="cuda")

        if args.checkpoint.init_ckpt_path:
            logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
            load_from_checkpoint(args.checkpoint.init_ckpt_path, model, model_key="model")
        else:
            with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
                torch.manual_seed(args.model.seed)
                model.init_weights()
        check_model_value_range(model, range=10.0, std=1.0)
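
        # The sequence above is the usual meta-device initialization pattern:
        # parameters are created without storage, parallelize_model() decides how
        # they are sharded across the device mesh, to_empty() then allocates the
        # local shards on the GPU, and only afterwards are weights either loaded
        # from a checkpoint or freshly initialized under a forked, model-seeded RNG.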

        logger.info(f"Model size: {model_param_count:,} total parameters")

        gpu_memory_monitor = GPUMemoryMonitor("cuda")
        logger.info(
            f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
            f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
        )
        logger.info(f"GPU memory usage: {gpu_memory_monitor}")

        # build the optimizer only after the model has been parallelized
        optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
        data_loader_state = init_dataloader_state_from_args(
            args.data, dp_rank, dp_degree
        )

        train_state = TrainState(
            step=0,
            acc_step=0,
            data_loader_state=data_loader_state,
            scheduler=scheduler,
        )

        checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
        checkpoint.load(model, optimizer, train_state, world_mesh)

        if args.probe_freq is not None:
            if get_is_master():
                os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
            torch.distributed.barrier()
            # only one dp rank out of every 128 writes a probe log file
            probe = AutoProbeD(
                model,
                (
                    Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
                    if (dp_rank % 128 == 0)
                    else None
                ),
            )

        # garbage collection is triggered manually below (see gc_collect_freq)
        gc.disable()

        # training loop
        model.train()
        metric_logger = context_stack.enter_context(
            MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
        )
        data_loader = context_stack.enter_context(
            build_dataloader_from_args(
                args.data,
                state=train_state.data_loader_state,
            )
        )
        torch_profiler = context_stack.enter_context(
            maybe_run_profiler(args.dump_dir, model, args.profiling)
        )

        nwords_since_last_log = 0
        time_last_log = timer()
        gc.collect()
        while train_state.step < args.steps:
            # acc_step cycles through 0 .. grad_acc_steps - 1
            train_state.acc_step += 1
            train_state.acc_step = train_state.acc_step % args.grad_acc_steps

            # fetch the next batch
            curr_lr = float(optimizer.param_groups[0]["lr"])
            data_load_start = timer()
            batch, train_state.data_loader_state = next(data_loader)
            # batch has shape (batch, seq_len, 2): inputs in channel 0, labels in channel 1
            batch = torch.tensor(
                batch,
                dtype=torch.long,
            )

            if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
                logger.info("garbage collection")
                # automatic GC is disabled above, so collect explicitly at a
                # fixed step frequency on all ranks
                gc.collect()

            input_ids = batch[:, :, 0].cuda()
            labels = batch[:, :, 1].cuda()
            data_load_time = round(timer() - data_load_start, 4)
            nwords_since_last_log += input_ids.numel()

            bsz, seqlen = labels.shape

            # time the iteration on the GPU with CUDA events
            start_timer = torch.cuda.Event(enable_timing=True)
            end_timer = torch.cuda.Event(enable_timing=True)
            start_timer.record()

            # optionally run a throwaway probed forward/backward on a reduced
            # batch before the real step
            if (args.probe_freq is not None) and every_n_steps(
                train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
            ):
                assert (
                    next(model.parameters()).grad is None
                ), "Can't probe model if grads are not reset"

                with probe:
                    probe.metadata = {
                        "it": train_state.step,
                        "global_step": train_state.step,
                        "loop": "lingua",
                    }
                    # probe on a smaller batch (or shorter sequence) to limit memory use
                    probe_bsz = max(1, bsz // 2)
                    probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
                    probe_loss = model(
                        input_ids[:probe_bsz, :probe_seq],
                        labels[:probe_bsz, :probe_seq],
                    )
                    probe_loss.backward()
                    # discard the gradients produced by the probe pass
                    optimizer.zero_grad()

                assert (
                    next(model.parameters()).grad is None
                ), "Probe model shouldn't have grads at this point"

            loss = model(input_ids, labels)

            if args.grad_acc_steps > 1:
                # only sync gradients across data-parallel ranks on the last micro-batch
                model.set_requires_gradient_sync(train_state.acc_step == 0)

            # scale the loss so the accumulated gradient averages over micro-batches
            loss = loss / args.grad_acc_steps
            loss.backward()
            # undo the scaling so the logged loss stays comparable
            loss = loss.detach() * args.grad_acc_steps
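
            # Worked example of the scaling, assuming grad_acc_steps = 4: each of
            # the 4 micro-batches backpropagates loss/4, so the summed gradient
            # equals the gradient of the mean micro-batch loss, and the detached
            # loss is multiplied back by 4 so the logged value is the current
            # micro-batch's unscaled loss.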

            # optimizer step: only once per accumulation cycle
            grad_norm = -1.0
            if train_state.acc_step == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.optim.clip, foreach=True
                )

                grad_norm = (
                    grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
                ).item()

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                train_state.step += 1

            end_timer.record()

            torch.cuda.synchronize()

            # Event.elapsed_time returns milliseconds; convert to seconds
            curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)

            # advance the profiler if it is active
            if torch_profiler:
                xformers.profiler.step()

            # log metrics
            if every_n_steps(
                train_state,
                args.logging.freq,
                acc_step=None if args.logging.acc_freq else 0,
                acc_freq=args.logging.acc_freq,
            ):
                time_delta = timer() - time_last_log
                wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                total_acc_steps = (
                    args.grad_acc_steps * train_state.step + train_state.acc_step
                )
                tokens_per_gpu = (
                    total_acc_steps * args.data.batch_size * args.data.seq_len
                )
                total_tokens = dp_degree * tokens_per_gpu
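
                # For example, with batch_size=4, seq_len=4096, grad_acc_steps=1
                # and 8 data-parallel ranks, every optimizer step advances
                # total_tokens by 8 * 4 * 4096 = 131,072 tokens.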

                # FLOPS is not computed here; log 0 as a placeholder
                FLOPS = 0
                metrics = flatten_dict(
                    {
                        "global_step": train_state.step,
                        "acc_step": train_state.acc_step,
                        "speed": {
                            "wps": wps,
                            "FLOPS": FLOPS,
                            "curr_iter_time": curr_iter_time,
                            "data_load_time": data_load_time,
                        },
                        "optim": {
                            "grad_norm": grad_norm,
                            "lr": curr_lr,
                            "total_tokens": total_tokens,
                        },
                        "memory": gpu_mem_stats._asdict(),
                    },
                    sep="/",
                )

                # the loss is averaged across data-parallel ranks before logging
                to_sync = {}
                to_sync["loss/out"] = loss.item()
                metrics.update(dist_mean_dict(to_sync))

                if get_is_master():
                    metric_logger.log(metrics)

                gpu_memory_monitor.reset_peak_stats()
                nwords_since_last_log = 0
                time_last_log = timer()
                logger.info(
                    f"step: {train_state.step}"
                    f" acc: {train_state.acc_step}"
                    f" loss: {round(loss.item(), 4):>7}"
                    f" grad: {grad_norm:.2e}"
                    f" flops: {FLOPS:.2e}"
                    f" wps: {wps:.2e}"
                    f" iter: {curr_iter_time:>7}"
                    f" data: {data_load_time:>5}"
                    f" lr: {curr_lr:.2e}"
                    f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                    f" pow: {gpu_mem_stats.power_draw / 1000} W"
                )

            saved = False
            if every_n_steps(
                train_state, args.checkpoint.dump.every, acc_step=0
            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
                saved = checkpoint.save(
                    model,
                    optimizer,
                    train_state,
                    dump_args,
                    device_mesh=world_mesh,
                )

            if args.eval is not None and every_n_steps(
                train_state, args.checkpoint.eval.every, acc_step=0
            ):
                from apps.fastRNN.eval import (
                    launch_eval,
                    EVAL_FOLDER_NAME,
                    EvalArgs,
                )

                eval_args = dataclass_from_dict(EvalArgs, args.eval)

                eval_args.global_step = train_state.step
                eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
                eval_args.dump_dir = str(
                    os.path.join(
                        args.dump_dir,
                        "evals",
                        EVAL_FOLDER_NAME.format(train_state.step),
                    )
                )
                eval_args.metric_log_dir = args.dump_dir
                if args.async_eval_gpus is None:
                    # run the eval inside the current job
                    launch_eval(eval_args)
                elif get_is_master():
                    # launch a separate eval job on async_eval_gpus GPUs (8 GPUs per node)
                    if wandb.run is not None and args.logging.wandb is not None:
                        eval_args.wandb = deepcopy(args.logging.wandb)
                    assert args.async_eval_gpus > 0
                    logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
                    with clean_env():
                        launch_job(
                            StoolArgs(
                                asdict(eval_args),
                                script="apps.fastRNN.eval",
                                copy_code=False,
                                nodes=args.async_eval_gpus // 8,
                                qos="lowest",
                            )
                        )

            if preemption_flag["flag"]:
                if not saved:
                    checkpoint.save(
                        model,
                        optimizer,
                        train_state,
                        dump_args,
                        device_mesh=world_mesh,
                    )
                requeue_slurm_job()
                sys.exit(0)

        if not saved:
            checkpoint.save(
                model,
                optimizer,
                train_state,
                dump_args,
                device_mesh=world_mesh,
            )
    gc.collect()

def main():
    """
    The command line interface here uses OmegaConf
    https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments

    It accepts arguments as a dot list, so if the dataclasses look like

        @dataclass
        class DummyArgs:
            name: str
            model: LMRNNArgs

        @dataclass
        class LMRNNArgs:
            dim: int

    then you can pass model.dim=32 to change values in LMRNNArgs,
    or just name=tictac for top-level attributes.

    The behavior here is as follows:
    1. We instantiate TrainArgs with its default values
    2. We override those default values with the ones in the provided config file
    3. We override the result with the additional arguments provided through the command line

    For example, if the config is the following

        model:
            dim: 128
            n_layers: 4

    and you call train.py with train.py model.dim=64,

    then the final TrainArgs will have

        model:
            dim: 64
            n_layers: 4

    plus all the default values in the TrainArgs dataclass.
    """
    cli_args = OmegaConf.from_cli()
    file_cfg = OmegaConf.load(cli_args.config)
    # drop the 'config' key: TrainArgs has no such field
    del cli_args.config

    default_cfg = OmegaConf.structured(TrainArgs())
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
    cfg = OmegaConf.to_object(cfg)

    # the model sub-config is stored as a dict in TrainArgs and converted here
    # into the dataclass matching model_type
    if cfg.model_type == "minGRU":
        cfg.model = dataclass_from_dict(LMMinGRUArgs, cfg.model)
    elif cfg.model_type == "minLSTM":
        cfg.model = dataclass_from_dict(LMMinLSTMArgs, cfg.model)
    elif cfg.model_type == "hawk":
        cfg.model = dataclass_from_dict(LMHawkArgs, cfg.model)
    else:
        raise ValueError(f"Model type {cfg.model_type} not recognized")

    train(cfg)


if __name__ == "__main__":
    main()