| | import glob |
| | import os |
| | import shutil |
| | from typing import Optional |
| |
|
| | import torch.multiprocessing as mp |
| | from nemo.core.config import hydra_runner |
| | from nemo.utils import logging |
| | from nemo.utils.exp_manager import exp_manager |
| | from nemo.utils.get_rank import is_global_rank_zero |
| | from nemo_aligner.algorithms.supervised import SupervisedTrainer |
| | from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel |
| | from nemo_aligner.utils.distributed import Timer |
| | from nemo_aligner.utils.train_script_utils import ( |
| | CustomLoggerWrapper, |
| | add_custom_checkpoint_callback, |
| | extract_optimizer_scheduler_from_ptl_model, |
| | init_distributed, |
| | init_peft, |
| | init_using_ptl, |
| | resolve_and_create_trainer, |
| | retrieve_custom_trainer_state_dict, |
| | ) |
| | from nemo_aligner.utils.utils import load_from_nemo |
| | from omegaconf.omegaconf import OmegaConf, open_dict |
| | from pytorch_lightning import Trainer |
| |
|
| | from dataset import LLMJPSFTDataset, build_dataloader, load_datasets |
| |
|
# Register two custom OmegaConf resolvers: "multiply" returns x * y and
# "int_div" returns x // y (floor division). replace=True is passed so that
# registering again under the same name does not fail.
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True)

# Force the "spawn" start method for torch.multiprocessing at import time.
# NOTE(review): presumably required so that child processes can use CUDA
# safely -- confirm with the data-loading / training setup.
mp.set_start_method("spawn", force=True)
| |
|
| |
|
def get_latest_checkpoint(checkpoints_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered ``step=<N>`` entry.

    Args:
        checkpoints_dir: Directory that holds the per-step checkpoint entries.

    Returns:
        ``os.path.join(checkpoints_dir, name)`` for the entry whose name starts
        with ``step=<N>`` and has the largest ``N``, or ``None`` when
        ``checkpoints_dir`` does not exist or holds no such entry.
    """
    import re  # local import: the module's import block does not provide `re`

    if not os.path.exists(checkpoints_dir):
        return None

    # Only the leading integer after "step=" is parsed, so names that carry a
    # suffix (e.g. "step=100-last") are ranked by their step instead of raising
    # ValueError, and names without a step number (e.g. "step=abc") are skipped.
    step_pattern = re.compile(r"step=(\d+)")
    candidates = []
    for entry_name in os.listdir(checkpoints_dir):
        matched = step_pattern.match(entry_name)
        if matched is not None:
            candidates.append((int(matched.group(1)), entry_name))

    if not candidates:
        return None

    # Tuples compare by step first and by name second, which makes the choice
    # deterministic even when two entries share the same step number.
    _, latest_name = max(candidates)
    return os.path.join(checkpoints_dir, latest_name)
| |
|
| |
|
def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
    """Overlay the fine-tuning settings in ``cfg`` onto the restored ``gpt_cfg``.

    Args:
        gpt_cfg: Config of the pre-trained GPT model; modified in place.
        cfg: Fine-tuning config; it is resolved in place before being read.
        add_cfg_to_tree: When true, ``gpt_cfg`` is resolved and stored under its
            own ``cfg`` key, which is needed for ``hparams.yaml`` files that are
            passed to ``load_from_checkpoint()``.

    Returns:
        The same ``gpt_cfg`` object after modification.
    """
    OmegaConf.set_struct(gpt_cfg, True)
    OmegaConf.resolve(cfg)
    model_cfg = cfg.model

    # gpt_cfg is in struct mode, so all writes happen inside open_dict().
    # The assignment order is kept as-is because it determines the order of
    # any keys that are newly added to gpt_cfg.
    with open_dict(gpt_cfg):
        gpt_cfg.megatron_amp_O2 = model_cfg.get("megatron_amp_O2", False)
        gpt_cfg.micro_batch_size = cfg.mbs
        gpt_cfg.global_batch_size = cfg.gbs
        gpt_cfg.data = cfg.data
        gpt_cfg.sequence_parallel = model_cfg.get("sequence_parallel", False)

        # Activation-checkpointing options all default to None when absent.
        for checkpoint_key in (
            "activations_checkpoint_granularity",
            "activations_checkpoint_num_layers",
            "activations_checkpoint_method",
            "activations_checkpoint_layers_per_pipeline",
        ):
            gpt_cfg[checkpoint_key] = model_cfg.get(checkpoint_key, None)

        gpt_cfg.peft = model_cfg.peft
        gpt_cfg.optim = model_cfg.optim
        gpt_cfg.precision = cfg.trainer.precision
        gpt_cfg.restore_from_path = model_cfg.restore_from_path
        gpt_cfg.resume_from_checkpoint = model_cfg.resume_from_checkpoint
        gpt_cfg.save_nemo_on_validation_end = model_cfg.save_nemo_on_validation_end
        gpt_cfg.gradient_as_bucket_view = model_cfg.gradient_as_bucket_view
        gpt_cfg.hidden_dropout = model_cfg.get("hidden_dropout", 0.0)
        gpt_cfg.attention_dropout = model_cfg.get("attention_dropout", 0.0)
        gpt_cfg.ffn_dropout = model_cfg.ffn_dropout
        gpt_cfg.use_flash_attention = model_cfg.get("use_flash_attention", False)

        # Parallelism sizes are only copied over when they are positive.
        tensor_parallel_size = model_cfg.get("tensor_model_parallel_size", 1)
        if tensor_parallel_size > 0:
            gpt_cfg.tensor_model_parallel_size = tensor_parallel_size
        pipeline_parallel_size = model_cfg.get("pipeline_model_parallel_size", 1)
        if pipeline_parallel_size > 0:
            gpt_cfg.pipeline_model_parallel_size = pipeline_parallel_size
        gpt_cfg.pipeline_model_parallel_split_rank = model_cfg.get(
            "pipeline_model_parallel_split_rank", 0
        )
        gpt_cfg.use_loss_mask = model_cfg.use_loss_mask

        # Record the fully qualified name of the SFT model class as the target.
        gpt_cfg.target = f"{GPTSFTModel.__module__}.{GPTSFTModel.__name__}"

        flash_attention = model_cfg.get("use_flash_attention", None)
        if flash_attention is not None:
            gpt_cfg.use_flash_attention = flash_attention

        interpolation_factor = model_cfg.get("seq_len_interpolation_factor", None)
        if interpolation_factor is not None:
            gpt_cfg.seq_len_interpolation_factor = interpolation_factor

        if add_cfg_to_tree:
            OmegaConf.resolve(gpt_cfg)
            gpt_cfg.cfg = gpt_cfg

    return gpt_cfg
| |
|
| |
|
def _remove_optimizer_states(log_dir) -> None:
    """Delete every ``optimizer.state.*`` entry below ``<log_dir>/checkpoints/step*``.

    Failures are logged and do not abort the loop, so one entry that cannot be
    removed does not prevent the remaining ones from being cleaned up.
    """
    for optimizer_state_path in glob.glob(
        f"{log_dir}/checkpoints/step*/optimizer.state.*"
    ):
        try:
            # shutil.rmtree() only accepts directories; plain files matched by
            # the pattern have to be removed with os.remove() instead.
            if os.path.isdir(optimizer_state_path):
                shutil.rmtree(optimizer_state_path)
                logging.info(f"Deleted directory: {optimizer_state_path}")
            else:
                os.remove(optimizer_state_path)
                logging.info(f"Deleted file: {optimizer_state_path}")
        except OSError as e:
            # strerror is None for OSErrors raised without an errno (e.g. when
            # rmtree refuses a symbolic link); fall back to the exception text.
            logging.error(f"Error: {optimizer_state_path} : {e.strerror or e}")


@hydra_runner(config_path="configs", config_name="sft")
def main(cfg):
    """Run supervised fine-tuning of a GPT model as described by ``cfg``.

    Steps, in order: export MPI rank information to the environment (when
    ``cfg.use_mpi`` is set), create the trainer and experiment directory,
    restore the model from ``cfg.model.restore_from_path``, detect an existing
    ``step=*`` checkpoint to resume from, build the train/validation datasets
    and dataloaders, run ``SupervisedTrainer.fit()``, and finally delete the
    ``optimizer.state.*`` entries of all saved checkpoints.

    Args:
        cfg: Configuration object supplied by ``hydra_runner``.
    """
    # Export the launcher-provided rank information first so that the
    # rank-zero check below, and everything after it, can see it.
    # NOTE(review): assumes is_global_rank_zero() consults RANK / SLURM_PROCID
    # style environment variables -- confirm against the NeMo implementation.
    if cfg.use_mpi:
        global_rank = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", 0))
        world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        os.environ["RANK"] = str(global_rank)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        if cfg.use_slurm:
            os.environ["SLURM_PROCID"] = str(global_rank)
            os.environ["SLURM_LOCALID"] = str(local_rank)
            os.environ["SLURM_NTASKS"] = str(world_size)
        logging.info(
            f"global_rank: {global_rank}, local_rank: {local_rank}, world_size: {world_size}"
        )

    if is_global_rank_zero():
        logging.info("\n\n************** Experiment configuration ***********")
        logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

    trainer: Trainer = resolve_and_create_trainer(cfg, "sft")
    log_dir = exp_manager(trainer, cfg.exp_manager)
    logger = CustomLoggerWrapper(trainer.loggers)

    # Mirror the trainer precision into the model section of the config.
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    ptl_model, updated_cfg = load_from_nemo(
        GPTSFTModel,
        model_cfg=cfg,
        trainer=trainer,
        strict=True,
        modify_config_fn=_modify_config,
        restore_path=cfg.model.restore_from_path,
        return_updated_cfg=True,
    )
    init_peft(ptl_model, updated_cfg)

    # Resume when a "step=*" entry already exists in the checkpoint directory.
    # NOTE(review): this relies on retrieve_custom_trainer_state_dict()
    # returning a dict whenever such an entry exists -- confirm.
    latest_checkpoint: Optional[str] = get_latest_checkpoint(f"{log_dir}/checkpoints")
    if latest_checkpoint is not None:
        logging.info(f"Resuming from checkpoint: {latest_checkpoint}")
        custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
        consumed_samples: int = custom_trainer_state_dict["consumed_samples"]
    else:
        logging.info("No checkpoint found. Starting from scratch.")
        custom_trainer_state_dict = None
        consumed_samples = 0

    # Only rank zero writes the updated model config next to the checkpoints.
    if is_global_rank_zero():
        updated_config_path: str = f"{log_dir}/checkpoints/model_config.yaml"
        os.makedirs(os.path.dirname(updated_config_path), exist_ok=True)
        OmegaConf.save(updated_cfg, updated_config_path)

    # Copy the restored model's encoder sequence length into the run config.
    with open_dict(cfg):
        cfg.model.encoder_seq_length = ptl_model.cfg.encoder_seq_length

    train_examples, dev_examples = load_datasets(cfg)

    init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))

    train_dataset = LLMJPSFTDataset(
        loaded_examples=train_examples,
        tokenizer=ptl_model.tokenizer,
        use_loss_mask=cfg.model.use_loss_mask,
        max_seq_length=cfg.model.max_seq_length,
    )
    val_dataset = LLMJPSFTDataset(
        loaded_examples=dev_examples,
        tokenizer=ptl_model.tokenizer,
        use_loss_mask=cfg.model.use_loss_mask,
        max_seq_length=cfg.model.max_seq_length,
    )

    # The training dataloader starts at the consumed-sample count taken from
    # the resumed state; validation always starts from zero.
    train_dataloader = build_dataloader(
        dataset=train_dataset,
        consumed_samples=consumed_samples,
        micro_batch_size=cfg.data.train_ds.micro_batch_size,
        global_batch_size=cfg.data.train_ds.global_batch_size,
        collate_fn=train_dataset.collate_fn,
        seed=cfg.seed,
    )
    val_dataloader = build_dataloader(
        dataset=val_dataset,
        consumed_samples=0,
        micro_batch_size=cfg.data.validation_ds.micro_batch_size,
        global_batch_size=cfg.data.validation_ds.global_batch_size,
        collate_fn=val_dataset.collate_fn,
    )

    init_using_ptl(trainer, ptl_model, train_dataloader, train_dataset)
    optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)

    ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)

    logger.log_hyperparams(OmegaConf.to_container(cfg))
    timer = Timer(cfg.exp_manager.get("max_time_per_run"))

    sft_trainer = SupervisedTrainer(
        cfg=cfg.trainer.sft,
        model=ptl_model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=None,
        logger=logger,
        ckpt_callback=ckpt_callback,
        run_timer=timer,
    )

    if custom_trainer_state_dict is not None:
        sft_trainer.load_state_dict(custom_trainer_state_dict)

    sft_trainer.fit()

    # Remove the optimizer states of every saved checkpoint once training ends.
    _remove_optimizer_states(log_dir)
|
| |
|
# Start the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
| |
|