| import glob |
| import os |
| import shutil |
| from typing import Optional |
|
|
| import torch.multiprocessing as mp |
| from nemo.core.config import hydra_runner |
| from nemo.utils import logging |
| from nemo.utils.exp_manager import exp_manager |
| from nemo.utils.get_rank import is_global_rank_zero |
| from nemo_aligner.algorithms.supervised import SupervisedTrainer |
| from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel |
| from nemo_aligner.utils.distributed import Timer |
| from nemo_aligner.utils.train_script_utils import ( |
| CustomLoggerWrapper, |
| add_custom_checkpoint_callback, |
| extract_optimizer_scheduler_from_ptl_model, |
| init_distributed, |
| init_peft, |
| init_using_ptl, |
| resolve_and_create_trainer, |
| retrieve_custom_trainer_state_dict, |
| ) |
| from nemo_aligner.utils.utils import load_from_nemo |
| from omegaconf.omegaconf import OmegaConf, open_dict |
| from pytorch_lightning import Trainer |
|
|
| from dataset import LLMJPSFTDataset, build_dataloader, load_datasets |
|
|
# Custom OmegaConf resolvers made available to the YAML configs as
# ``${multiply:a,b}`` (a * b) and ``${int_div:a,b}`` (a // b).
# ``replace=True`` avoids an error if a resolver with that name is already registered.
for _resolver_name, _resolver_fn in (
    ("multiply", lambda x, y: x * y),
    ("int_div", lambda x, y: x // y),
):
    OmegaConf.register_new_resolver(_resolver_name, _resolver_fn, replace=True)
del _resolver_name, _resolver_fn

# Force the "spawn" start method (force=True overrides any method already set),
# presumably because forked child processes cannot safely re-initialise CUDA.
mp.set_start_method("spawn", force=True)
|
|
|
|
def _checkpoint_step(dirname: str) -> Optional[int]:
    """Return the step number encoded in a ``step=<N>[-...]`` entry name, else ``None``.

    Only the decimal digits directly after ``step=`` (up to the first ``-``) are
    used, so names carrying extra fields (``step=10-consumed_samples=320.0``) or a
    suffix (``step=10-last``) are accepted. Anything else (``model_config.yaml``,
    ``step=10.nemo``, ...) yields ``None`` instead of raising.
    """
    prefix = "step="
    if not dirname.startswith(prefix):
        return None
    step_text = dirname[len(prefix):].split("-", 1)[0]
    # isdecimal() (not isdigit()) guarantees int() below cannot raise.
    return int(step_text) if step_text.isdecimal() else None


def get_latest_checkpoint(checkpoints_dir: str) -> Optional[str]:
    """Return the path of the checkpoint entry with the highest step in ``checkpoints_dir``.

    Args:
        checkpoints_dir: Directory whose checkpoint entries are named ``step=<N>...``.

    Returns:
        ``os.path.join(checkpoints_dir, <entry>)`` for the entry with the largest
        step, or ``None`` if the directory does not exist (or is not a directory)
        or contains no parsable ``step=<N>`` entries. Ties on the step number are
        broken by entry name so the result is deterministic.
    """
    if not os.path.isdir(checkpoints_dir):
        return None

    candidates: list = []
    for entry in os.listdir(checkpoints_dir):
        step = _checkpoint_step(entry)
        if step is not None:
            candidates.append((step, entry))

    if not candidates:
        return None
    _, latest_entry = max(candidates)
    return os.path.join(checkpoints_dir, latest_entry)
|
|
|
|
def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
    """Overlay fine-tuning settings from ``cfg`` onto the pre-training config ``gpt_cfg``.

    ``gpt_cfg`` is put into struct mode and then edited inside ``open_dict`` so new
    keys may be added. ``cfg`` is resolved in place (``OmegaConf.resolve``) as a
    side effect.

    Args:
        gpt_cfg: Model config from the pre-trained checkpoint; mutated in place.
        cfg: Full fine-tuning config. Reads ``mbs``, ``gbs``, ``data``,
            ``trainer.precision`` and keys under ``model``.
        add_cfg_to_tree: If True, resolve ``gpt_cfg`` and attach it to itself under
            ``cfg``; this top-level ``cfg`` node is needed for ``hparams.yaml``
            files passed to ``load_from_checkpoint()``.

    Returns:
        The modified ``gpt_cfg``.
    """
    OmegaConf.set_struct(gpt_cfg, True)
    OmegaConf.resolve(cfg)
    model_cfg = cfg.model

    with open_dict(gpt_cfg):
        gpt_cfg.megatron_amp_O2 = model_cfg.get("megatron_amp_O2", False)
        gpt_cfg.micro_batch_size = cfg.mbs
        gpt_cfg.global_batch_size = cfg.gbs
        gpt_cfg.data = cfg.data
        gpt_cfg.sequence_parallel = model_cfg.get("sequence_parallel", False)
        gpt_cfg.activations_checkpoint_granularity = model_cfg.get(
            "activations_checkpoint_granularity", None
        )
        gpt_cfg.activations_checkpoint_num_layers = model_cfg.get(
            "activations_checkpoint_num_layers", None
        )
        gpt_cfg.activations_checkpoint_method = model_cfg.get(
            "activations_checkpoint_method", None
        )
        gpt_cfg.activations_checkpoint_layers_per_pipeline = model_cfg.get(
            "activations_checkpoint_layers_per_pipeline", None
        )
        gpt_cfg.peft = model_cfg.peft
        gpt_cfg.optim = model_cfg.optim
        gpt_cfg.precision = cfg.trainer.precision
        gpt_cfg.restore_from_path = model_cfg.restore_from_path
        gpt_cfg.resume_from_checkpoint = model_cfg.resume_from_checkpoint
        gpt_cfg.save_nemo_on_validation_end = model_cfg.save_nemo_on_validation_end
        gpt_cfg.gradient_as_bucket_view = model_cfg.gradient_as_bucket_view
        # All three dropout rates are optional and default to 0.0; a hard attribute
        # access would fail on a struct config that omits the key.
        gpt_cfg.hidden_dropout = model_cfg.get("hidden_dropout", 0.0)
        gpt_cfg.attention_dropout = model_cfg.get("attention_dropout", 0.0)
        gpt_cfg.ffn_dropout = model_cfg.get("ffn_dropout", 0.0)
        # get() already yields the configured value when the key is set, so this
        # single assignment is the only one needed.
        gpt_cfg.use_flash_attention = model_cfg.get("use_flash_attention", False)

        # Non-positive TP/PP sizes (e.g. -1) keep the value stored in the
        # pre-trained checkpoint's config.
        tensor_parallel_size = model_cfg.get("tensor_model_parallel_size", 1)
        if tensor_parallel_size > 0:
            gpt_cfg.tensor_model_parallel_size = tensor_parallel_size
        pipeline_parallel_size = model_cfg.get("pipeline_model_parallel_size", 1)
        if pipeline_parallel_size > 0:
            gpt_cfg.pipeline_model_parallel_size = pipeline_parallel_size
        gpt_cfg.pipeline_model_parallel_split_rank = model_cfg.get(
            "pipeline_model_parallel_split_rank", 0
        )
        gpt_cfg.use_loss_mask = model_cfg.use_loss_mask

        # Restore the model as the SFT class rather than the pre-training class.
        gpt_cfg.target = f"{GPTSFTModel.__module__}.{GPTSFTModel.__name__}"

        if model_cfg.get("seq_len_interpolation_factor", None) is not None:
            gpt_cfg.seq_len_interpolation_factor = model_cfg.seq_len_interpolation_factor

        # Must stay inside open_dict: adding the new ``cfg`` key to a struct
        # config outside of it would raise.
        if add_cfg_to_tree:
            OmegaConf.resolve(gpt_cfg)
            gpt_cfg.cfg = gpt_cfg

    return gpt_cfg
|
|
|
|
def _export_mpi_rank_env(use_slurm) -> None:
    """Copy Open MPI rank/world-size variables to the torch (and optionally SLURM) names.

    Reads ``OMPI_COMM_WORLD_RANK`` / ``OMPI_COMM_WORLD_LOCAL_RANK`` /
    ``OMPI_COMM_WORLD_SIZE`` (defaults 0 / 0 / 1) and writes them to ``RANK`` /
    ``LOCAL_RANK`` / ``WORLD_SIZE``. When ``use_slurm`` is truthy they are also
    written to ``SLURM_PROCID`` / ``SLURM_LOCALID`` / ``SLURM_NTASKS``.
    """
    global_rank = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
    local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", 0))
    world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
    os.environ["RANK"] = str(global_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    if use_slurm:
        os.environ["SLURM_PROCID"] = str(global_rank)
        os.environ["SLURM_LOCALID"] = str(local_rank)
        os.environ["SLURM_NTASKS"] = str(world_size)
    logging.info(
        f"global_rank: {global_rank}, local_rank: {local_rank}, world_size: {world_size}"
    )


def _delete_optimizer_states(log_dir) -> None:
    """Delete ``optimizer.state.*`` entries from every ``step*`` checkpoint under ``log_dir``.

    Directories are removed recursively and any other entry with ``os.remove``.
    Failures are logged and the remaining entries are still processed. After this,
    resuming from these checkpoints cannot restore optimizer state (presumably
    intentional, to save disk space once training is done).
    """
    for state_path in glob.glob(f"{log_dir}/checkpoints/step*/optimizer.state.*"):
        is_dir = os.path.isdir(state_path) and not os.path.islink(state_path)
        try:
            if is_dir:
                shutil.rmtree(state_path)
            else:
                os.remove(state_path)
        except OSError as e:
            # str(e) is always informative, whereas e.strerror may be None.
            logging.error(f"Error: {state_path} : {e}")
            continue
        logging.info(f"Deleted {'directory' if is_dir else 'file'}: {state_path}")


@hydra_runner(config_path="configs", config_name="sft")
def main(cfg):
    """Run GPT supervised fine-tuning (SFT) with NeMo-Aligner.

    Pipeline:
    1. Optionally export Open MPI rank variables.
    2. Build the trainer and experiment manager.
    3. Restore the model from ``cfg.model.restore_from_path``, using
       ``_modify_config`` to overlay the SFT settings, then apply PEFT.
    4. Pick up state from the latest ``step=`` checkpoint, if any.
    5. Build the train/validation dataloaders and run ``SupervisedTrainer.fit()``.
    6. Delete optimizer states from the saved checkpoints.

    Args:
        cfg: Hydra config composed from ``configs/sft`` plus command-line overrides.
    """
    if is_global_rank_zero():
        logging.info("\n\n************** Experiment configuration ***********")
        logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

    if cfg.use_mpi:
        # Export rank variables before the trainer / distributed setup reads them.
        _export_mpi_rank_env(cfg.use_slurm)

    trainer: Trainer = resolve_and_create_trainer(cfg, "sft")
    log_dir = exp_manager(trainer, cfg.exp_manager)
    logger = CustomLoggerWrapper(trainer.loggers)

    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    ptl_model, updated_cfg = load_from_nemo(
        GPTSFTModel,
        model_cfg=cfg,
        trainer=trainer,
        strict=True,
        modify_config_fn=_modify_config,
        restore_path=cfg.model.restore_from_path,
        return_updated_cfg=True,
    )
    init_peft(ptl_model, updated_cfg)

    latest_checkpoint: Optional[str] = get_latest_checkpoint(f"{log_dir}/checkpoints")
    custom_trainer_state_dict = None
    consumed_samples = 0
    if latest_checkpoint is None:
        logging.info("No checkpoint found. Starting from scratch.")
    else:
        logging.info(f"Resuming from checkpoint: {latest_checkpoint}")
        custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
        # NOTE(review): the helper may return None (presumably when the trainer has
        # no ckpt_path to restore from); guard instead of indexing None.
        if custom_trainer_state_dict is None:
            logging.warning(
                f"Checkpoint {latest_checkpoint} exists but no trainer state was "
                "restored from it; starting from consumed_samples=0."
            )
        else:
            consumed_samples = custom_trainer_state_dict["consumed_samples"]

    if is_global_rank_zero():
        updated_config_path: str = f"{log_dir}/checkpoints/model_config.yaml"
        os.makedirs(os.path.dirname(updated_config_path), exist_ok=True)
        OmegaConf.save(updated_cfg, updated_config_path)

    with open_dict(cfg):
        # Copy the restored model's sequence length into cfg.model
        # (presumably consumed by the dataset helpers below).
        cfg.model.encoder_seq_length = ptl_model.cfg.encoder_seq_length

    train_examples, dev_examples = load_datasets(cfg)

    init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))

    train_dataset = LLMJPSFTDataset(
        loaded_examples=train_examples,
        tokenizer=ptl_model.tokenizer,
        use_loss_mask=cfg.model.use_loss_mask,
        max_seq_length=cfg.model.max_seq_length,
    )
    val_dataset = LLMJPSFTDataset(
        loaded_examples=dev_examples,
        tokenizer=ptl_model.tokenizer,
        use_loss_mask=cfg.model.use_loss_mask,
        max_seq_length=cfg.model.max_seq_length,
    )

    # The training loader skips samples already consumed before a resume.
    train_dataloader = build_dataloader(
        dataset=train_dataset,
        consumed_samples=consumed_samples,
        micro_batch_size=cfg.data.train_ds.micro_batch_size,
        global_batch_size=cfg.data.train_ds.global_batch_size,
        collate_fn=train_dataset.collate_fn,
        seed=cfg.seed,
    )
    val_dataloader = build_dataloader(
        dataset=val_dataset,
        consumed_samples=0,
        micro_batch_size=cfg.data.validation_ds.micro_batch_size,
        global_batch_size=cfg.data.validation_ds.global_batch_size,
        collate_fn=val_dataset.collate_fn,
    )

    init_using_ptl(trainer, ptl_model, train_dataloader, train_dataset)
    optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)

    ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)

    logger.log_hyperparams(OmegaConf.to_container(cfg))
    timer = Timer(cfg.exp_manager.get("max_time_per_run"))

    sft_trainer = SupervisedTrainer(
        cfg=cfg.trainer.sft,
        model=ptl_model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=None,
        logger=logger,
        ckpt_callback=ckpt_callback,
        run_timer=timer,
    )

    if custom_trainer_state_dict is not None:
        sft_trainer.load_state_dict(custom_trainer_state_dict)

    sft_trainer.fit()

    # Only global rank 0 edits the checkpoint directory, like the model_config.yaml
    # save above. Deleting from every rank races on the same paths (assuming the
    # checkpoint directory is on storage shared by all ranks).
    if is_global_rank_zero():
        _delete_optimizer_states(log_dir)


if __name__ == "__main__":
    main()
|
|