"""Supervised fine-tuning (SFT) script built on NeMo-Aligner.

Restores a pretrained GPT model from a .nemo checkpoint, optionally applies PEFT,
builds LLMJPSFTDataset train/validation dataloaders from the local `dataset` module,
and runs NeMo-Aligner's SupervisedTrainer. Configured via Hydra (configs/sft).
"""

import glob
import os
import shutil
from typing import Optional

import torch.multiprocessing as mp
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.get_rank import is_global_rank_zero
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
    CustomLoggerWrapper,
    add_custom_checkpoint_callback,
    extract_optimizer_scheduler_from_ptl_model,
    init_distributed,
    init_peft,
    init_using_ptl,
    resolve_and_create_trainer,
    retrieve_custom_trainer_state_dict,
)
from nemo_aligner.utils.utils import load_from_nemo
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer

from dataset import LLMJPSFTDataset, build_dataloader, load_datasets

OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True)

mp.set_start_method("spawn", force=True)


def get_latest_checkpoint(checkpoints_dir: str) -> Optional[str]:
    """Return the checkpoint directory with the largest step under `checkpoints_dir`.

    Checkpoint directories are expected to be named `step=<N>`; returns None if the
    directory does not exist or contains no such entries.
    """
    if not os.path.exists(checkpoints_dir):
        return None
    checkpoint_dirs: list[str] = [
        d for d in os.listdir(checkpoints_dir) if d.startswith("step=")
    ]
    if not checkpoint_dirs:
        return None
    latest_checkpoint = max(checkpoint_dirs, key=lambda d: int(d.split("=")[1]))
    return os.path.join(checkpoints_dir, latest_checkpoint)


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
    """
    This function modifies the original GPT pre-training config (gpt_cfg) with attributes
    from the fine-tuning config (cfg).
    The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree, which is needed for
    all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
    """
    OmegaConf.set_struct(gpt_cfg, True)
    OmegaConf.resolve(cfg)
    with open_dict(gpt_cfg):
        gpt_cfg.megatron_amp_O2 = cfg.model.get("megatron_amp_O2", False)
        gpt_cfg.micro_batch_size = cfg.mbs
        gpt_cfg.global_batch_size = cfg.gbs
        gpt_cfg.data = cfg.data
        gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
        gpt_cfg.activations_checkpoint_granularity = cfg.model.get(
            "activations_checkpoint_granularity", None
        )
        gpt_cfg.activations_checkpoint_num_layers = cfg.model.get(
            "activations_checkpoint_num_layers", None
        )
        gpt_cfg.activations_checkpoint_method = cfg.model.get(
            "activations_checkpoint_method", None
        )
        gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
            "activations_checkpoint_layers_per_pipeline", None
        )
        gpt_cfg.peft = cfg.model.peft
        gpt_cfg.optim = cfg.model.optim
        gpt_cfg.precision = cfg.trainer.precision
        gpt_cfg.restore_from_path = cfg.model.restore_from_path
        gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
        gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end
        gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
        gpt_cfg.hidden_dropout = cfg.model.get("hidden_dropout", 0.0)
        gpt_cfg.attention_dropout = cfg.model.get("attention_dropout", 0.0)
        gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
        gpt_cfg.use_flash_attention = cfg.model.get("use_flash_attention", False)

        if cfg.model.get("tensor_model_parallel_size", 1) > 0:
            gpt_cfg.tensor_model_parallel_size = cfg.model.get(
                "tensor_model_parallel_size", 1
            )
        if cfg.model.get("pipeline_model_parallel_size", 1) > 0:
            gpt_cfg.pipeline_model_parallel_size = cfg.model.get(
                "pipeline_model_parallel_size", 1
            )
            gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get(
                "pipeline_model_parallel_split_rank", 0
            )
        gpt_cfg.use_loss_mask = cfg.model.use_loss_mask

        sft_cls = GPTSFTModel
        gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

        if cfg.model.get("use_flash_attention", None) is not None:
            gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

        if cfg.model.get("seq_len_interpolation_factor", None) is not None:
            gpt_cfg.seq_len_interpolation_factor = (
                cfg.model.seq_len_interpolation_factor
            )

        if add_cfg_to_tree:
            OmegaConf.resolve(gpt_cfg)
            gpt_cfg.cfg = gpt_cfg

    return gpt_cfg


@hydra_runner(config_path="configs", config_name="sft")
def main(cfg):
    if is_global_rank_zero():
        logging.info("\n\n************** Experiment configuration ***********")
        logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
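
    # When launched with mpirun, translate the Open MPI rank variables into the
    # RANK / LOCAL_RANK / WORLD_SIZE variables expected by torch.distributed.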
    if cfg.use_mpi:
        global_rank = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", 0))
        world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        os.environ["RANK"] = str(global_rank)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        if cfg.use_slurm:
            os.environ["SLURM_PROCID"] = str(global_rank)
            os.environ["SLURM_LOCALID"] = str(local_rank)
            os.environ["SLURM_NTASKS"] = str(world_size)
        logging.info(
            f"global_rank: {global_rank}, local_rank: {local_rank}, world_size: {world_size}"
        )

    trainer: Trainer = resolve_and_create_trainer(cfg, "sft")
    log_dir = exp_manager(trainer, cfg.exp_manager)
    logger = CustomLoggerWrapper(trainer.loggers)

    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    ptl_model, updated_cfg = load_from_nemo(
        GPTSFTModel,
        model_cfg=cfg,
        trainer=trainer,
        strict=True,
        modify_config_fn=_modify_config,
        restore_path=cfg.model.restore_from_path,
        return_updated_cfg=True,
    )
    init_peft(ptl_model, updated_cfg)
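
    # If this experiment directory already contains checkpoints, resume from the
    # latest one and restore the aligner trainer state (e.g. consumed_samples).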
    latest_checkpoint: Optional[str] = get_latest_checkpoint(f"{log_dir}/checkpoints")
    if latest_checkpoint is not None:
        logging.info(f"Resuming from checkpoint: {latest_checkpoint}")
        custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
        consumed_samples: int = custom_trainer_state_dict["consumed_samples"]
    else:
        logging.info("No checkpoint found. Starting from scratch.")
        custom_trainer_state_dict = None
        consumed_samples = 0
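
    # Persist the updated model config next to the checkpoints for later inspection.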
    if is_global_rank_zero():
        updated_config_path: str = f"{log_dir}/checkpoints/model_config.yaml"
        os.makedirs(os.path.dirname(updated_config_path), exist_ok=True)
        OmegaConf.save(updated_cfg, updated_config_path)

    with open_dict(cfg):
        cfg.model.encoder_seq_length = ptl_model.cfg.encoder_seq_length

    train_examples, dev_examples = load_datasets(cfg)
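
    # Initialize the distributed / model-parallel state before building the datasets
    # and dataloaders.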
    init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))

    train_dataset = LLMJPSFTDataset(
        loaded_examples=train_examples,
        tokenizer=ptl_model.tokenizer,
        use_loss_mask=cfg.model.use_loss_mask,
        max_seq_length=cfg.model.max_seq_length,
    )
    val_dataset = LLMJPSFTDataset(
        loaded_examples=dev_examples,
        tokenizer=ptl_model.tokenizer,
        use_loss_mask=cfg.model.use_loss_mask,
        max_seq_length=cfg.model.max_seq_length,
    )

    train_dataloader = build_dataloader(
        dataset=train_dataset,
        consumed_samples=consumed_samples,
        micro_batch_size=cfg.data.train_ds.micro_batch_size,
        global_batch_size=cfg.data.train_ds.global_batch_size,
        collate_fn=train_dataset.collate_fn,
        seed=cfg.seed,
    )
    val_dataloader = build_dataloader(
        dataset=val_dataset,
        consumed_samples=0,
        micro_batch_size=cfg.data.validation_ds.micro_batch_size,
        global_batch_size=cfg.data.validation_ds.global_batch_size,
        collate_fn=val_dataset.collate_fn,
    )

    init_using_ptl(trainer, ptl_model, train_dataloader, train_dataset)
    optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)

    ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)

    logger.log_hyperparams(OmegaConf.to_container(cfg))
    timer = Timer(cfg.exp_manager.get("max_time_per_run"))

    sft_trainer = SupervisedTrainer(
        cfg=cfg.trainer.sft,
        model=ptl_model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=None,
        logger=logger,
        ckpt_callback=ckpt_callback,
        run_timer=timer,
    )

    if custom_trainer_state_dict is not None:
        sft_trainer.load_state_dict(custom_trainer_state_dict)

    sft_trainer.fit()
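
    # Clean up the optimizer state saved inside each checkpoint; each
    # "optimizer.state.*" entry is a directory, hence shutil.rmtree.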
    for optimizer_state_file in glob.glob(
        f"{log_dir}/checkpoints/step*/optimizer.state.*"
    ):
        try:
            shutil.rmtree(optimizer_state_file)
            logging.info(f"Deleted directory: {optimizer_state_file}")
        except OSError as e:
            logging.error(f"Error: {optimizer_state_file} : {e.strerror}")


if __name__ == "__main__":
    main()