koichi12's picture
Add files using upload-large-folder tool
799d677 verified
import glob
import os
import shutil
from typing import Optional
import torch.multiprocessing as mp
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.get_rank import is_global_rank_zero
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
extract_optimizer_scheduler_from_ptl_model,
init_distributed,
init_peft,
init_using_ptl,
resolve_and_create_trainer,
retrieve_custom_trainer_state_dict,
)
from nemo_aligner.utils.utils import load_from_nemo
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from dataset import LLMJPSFTDataset, build_dataloader, load_datasets
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True)
mp.set_start_method("spawn", force=True)
def get_latest_checkpoint(checkpoints_dir: str) -> Optional[str]:
if not os.path.exists(checkpoints_dir):
return None
checkpoint_dirs: list[str] = [
d for d in os.listdir(checkpoints_dir) if d.startswith("step=")
]
if not checkpoint_dirs:
return None
latest_checkpoint = max(checkpoint_dirs, key=lambda d: int(d.split("=")[1]))
return os.path.join(checkpoints_dir, latest_checkpoint)
def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(gpt_cfg, True)
OmegaConf.resolve(cfg)
with open_dict(gpt_cfg):
gpt_cfg.megatron_amp_O2 = cfg.model.get("megatron_amp_O2", False)
gpt_cfg.micro_batch_size = cfg.mbs
gpt_cfg.global_batch_size = cfg.gbs
gpt_cfg.data = cfg.data
gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
gpt_cfg.activations_checkpoint_granularity = cfg.model.get(
"activations_checkpoint_granularity", None
)
gpt_cfg.activations_checkpoint_num_layers = cfg.model.get(
"activations_checkpoint_num_layers", None
)
gpt_cfg.activations_checkpoint_method = cfg.model.get(
"activations_checkpoint_method", None
)
gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
"activations_checkpoint_layers_per_pipeline", None
)
gpt_cfg.peft = cfg.model.peft
gpt_cfg.optim = cfg.model.optim
gpt_cfg.precision = cfg.trainer.precision
gpt_cfg.restore_from_path = cfg.model.restore_from_path
gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end
gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
gpt_cfg.hidden_dropout = cfg.model.get("hidden_dropout", 0.0)
gpt_cfg.attention_dropout = cfg.model.get("attention_dropout", 0.0)
gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
gpt_cfg.use_flash_attention = cfg.model.get("use_flash_attention", False)
# if TP/PP size is -1, use default TP/PP size as original model
if cfg.model.get("tensor_model_parallel_size", 1) > 0:
gpt_cfg.tensor_model_parallel_size = cfg.model.get(
"tensor_model_parallel_size", 1
)
if cfg.model.get("pipeline_model_parallel_size", 1) > 0:
gpt_cfg.pipeline_model_parallel_size = cfg.model.get(
"pipeline_model_parallel_size", 1
)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get(
"pipeline_model_parallel_split_rank", 0
)
gpt_cfg.use_loss_mask = cfg.model.use_loss_mask
sft_cls = GPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"
if cfg.model.get("use_flash_attention", None) is not None:
gpt_cfg.use_flash_attention = cfg.model.use_flash_attention
if cfg.model.get("seq_len_interpolation_factor", None) is not None:
gpt_cfg.seq_len_interpolation_factor = (
cfg.model.seq_len_interpolation_factor
)
# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(gpt_cfg)
gpt_cfg.cfg = gpt_cfg
return gpt_cfg
@hydra_runner(config_path="configs", config_name="sft")
def main(cfg):
if is_global_rank_zero():
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
if cfg.use_mpi:
global_rank = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", 0))
world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
os.environ["RANK"] = str(global_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
if cfg.use_slurm:
os.environ["SLURM_PROCID"] = str(global_rank)
os.environ["SLURM_LOCALID"] = str(local_rank)
os.environ["SLURM_NTASKS"] = str(world_size)
logging.info(
f"global_rank: {global_rank}, local_rank: {local_rank}, world_size: {world_size}"
)
trainer: Trainer = resolve_and_create_trainer(cfg, "sft")
log_dir = exp_manager(trainer, cfg.exp_manager)
logger = CustomLoggerWrapper(trainer.loggers)
# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision
ptl_model, updated_cfg = load_from_nemo(
GPTSFTModel,
model_cfg=cfg,
trainer=trainer,
strict=True,
modify_config_fn=_modify_config,
restore_path=cfg.model.restore_from_path,
return_updated_cfg=True,
)
init_peft(ptl_model, updated_cfg)
latest_checkpoint: Optional[str] = get_latest_checkpoint(f"{log_dir}/checkpoints")
if latest_checkpoint is not None:
logging.info(f"Resuming from checkpoint: {latest_checkpoint}")
custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
consumed_samples: int = custom_trainer_state_dict["consumed_samples"]
else:
logging.info("No checkpoint found. Starting from scratch.")
custom_trainer_state_dict = None
consumed_samples = 0
# save the updated config to the log directory
if is_global_rank_zero():
updated_config_path: str = f"{log_dir}/checkpoints/model_config.yaml"
os.makedirs(os.path.dirname(updated_config_path), exist_ok=True)
OmegaConf.save(updated_cfg, updated_config_path)
with open_dict(cfg):
# overwrite the model config with the config from the checkpoint
cfg.model.encoder_seq_length = ptl_model.cfg.encoder_seq_length
train_examples, dev_examples = load_datasets(cfg)
init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))
train_dataset = LLMJPSFTDataset(
loaded_examples=train_examples,
tokenizer=ptl_model.tokenizer,
use_loss_mask=cfg.model.use_loss_mask,
max_seq_length=cfg.model.max_seq_length,
)
val_dataset = LLMJPSFTDataset(
loaded_examples=dev_examples,
tokenizer=ptl_model.tokenizer,
use_loss_mask=cfg.model.use_loss_mask,
max_seq_length=cfg.model.max_seq_length,
)
train_dataloader = build_dataloader(
dataset=train_dataset,
consumed_samples=consumed_samples,
micro_batch_size=cfg.data.train_ds.micro_batch_size,
global_batch_size=cfg.data.train_ds.global_batch_size,
collate_fn=train_dataset.collate_fn,
seed=cfg.seed,
)
val_dataloader = build_dataloader(
dataset=val_dataset,
consumed_samples=0,
micro_batch_size=cfg.data.validation_ds.micro_batch_size,
global_batch_size=cfg.data.validation_ds.global_batch_size,
collate_fn=val_dataset.collate_fn,
)
init_using_ptl(trainer, ptl_model, train_dataloader, train_dataset)
optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)
ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)
logger.log_hyperparams(OmegaConf.to_container(cfg))
timer = Timer(cfg.exp_manager.get("max_time_per_run"))
sft_trainer = SupervisedTrainer(
cfg=cfg.trainer.sft,
model=ptl_model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
test_dataloader=None,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)
if custom_trainer_state_dict is not None:
sft_trainer.load_state_dict(custom_trainer_state_dict)
sft_trainer.fit()
# remove optimizer state files
for optimizer_state_file in glob.glob(
f"{log_dir}/checkpoints/step*/optimizer.state.*"
):
try:
shutil.rmtree(optimizer_state_file)
logging.info(f"Deleted directory: {optimizer_state_file}")
except OSError as e:
logging.error(f"Error: {optimizer_state_file} : {e.strerror}")
if __name__ == "__main__":
main()