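"""Finetune a Cosmos-1.0 autoregressive world model (4B or 12B) with NeMo.

If --model_path is one of the official Hugging Face model IDs, the NeMo
checkpoint is downloaded from the hub; otherwise it is treated as a local path.

Example launch (the script name and data path below are illustrative; use the
number of GPUs given by --tensor_model_parallel_size):

    torchrun --nproc-per-node 2 post_train.py \
        --data_path /path/to/dataset_prefix \
        --model_path nvidia/Cosmos-1.0-Autoregressive-4B
"""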

import os
from argparse import ArgumentParser

import torch
from huggingface_hub import snapshot_download
from lightning.pytorch.loggers import WandbLogger
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback
from nemo.lightning.pytorch.strategies.utils import RestoreConfig

from cosmos1.models.autoregressive.nemo.cosmos import CosmosConfig4B, CosmosConfig12B, CosmosModel


def main(args):
    # Pick the model configuration that matches the checkpoint named in --model_path.
    if "4B" in args.model_path:
        config = CosmosConfig4B()
    elif "12B" in args.model_path:
        config = CosmosConfig12B()
    else:
        raise NotImplementedError(f"Unable to infer a model size (4B or 12B) from --model_path: {args.model_path}")

    # When given an official Hugging Face model ID, download the NeMo checkpoint
    # and point model_path at the downloaded "nemo" subdirectory.
    if args.model_path in ["nvidia/Cosmos-1.0-Autoregressive-4B", "nvidia/Cosmos-1.0-Autoregressive-12B"]:
        args.model_path = os.path.join(snapshot_download(args.model_path, allow_patterns=["nemo/*"]), "nemo")

    model = CosmosModel(config)

    # Megatron-style pre-training data module over tokenized .bin/.idx files;
    # each sample is a fixed-length sequence of 12,800 video tokens.
    data_module = llm.PreTrainingDataModule(
        paths=[args.data_path],
        seq_length=12800,
        global_batch_size=args.global_batch_size,
        micro_batch_size=args.micro_batch_size,
        tokenizer=None,
        split=args.split_string,
        num_workers=1,
        index_mapping_dir=args.index_mapping_dir,
    )
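
    # Note (a general Megatron constraint, not specific to this script):
    # global_batch_size must be divisible by micro_batch_size times the
    # data-parallel size.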
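
    # llm.api.train wires the model, data, trainer, logger, optimizer, and
    # resume behavior together. With num_nodes=1, pipeline and context
    # parallel sizes of 1, and devices == tensor_model_parallel_size, every
    # GPU is consumed by tensor parallelism, so the data-parallel size is 1.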
    llm.api.train(
        model=model,
        data=data_module,
        trainer=nl.Trainer(
            devices=args.tensor_model_parallel_size,
            num_nodes=1,
            max_steps=args.max_steps,
            accelerator="gpu",
            strategy=nl.MegatronStrategy(
                tensor_model_parallel_size=args.tensor_model_parallel_size,
                pipeline_model_parallel_size=1,
                context_parallel_size=1,
                sequence_parallel=False,
                pipeline_dtype=torch.bfloat16,
            ),
            plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
            # Validation is disabled; progress is tracked via the training loss.
            num_sanity_val_steps=0,
            limit_val_batches=0,
            max_epochs=args.max_epochs,
            log_every_n_steps=1,
            callbacks=[
                # Keep the two best checkpoints by training loss, saving on a step interval.
                ModelCheckpoint(
                    monitor="reduced_train_loss",
                    filename="{epoch}-{step}",
                    every_n_train_steps=args.save_every_n_steps,
                    save_top_k=2,
                ),
                # Checkpoint gracefully if the job is preempted (e.g., on managed clusters).
                PreemptionCallback(),
            ],
        ),
        # Log to Weights & Biases only when an API key is present in the environment.
        log=nl.NeMoLogger(wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None), log_dir=args.log_dir),
        optim=nl.MegatronOptimizerModule(
            config=OptimizerConfig(
                lr=args.lr,
                bf16=True,
                params_dtype=torch.bfloat16,
                use_distributed_optimizer=False,
            )
        ),
        tokenizer=None,
        # Restore the initial weights from --model_path, then resume from the most
        # recent local checkpoint if one already exists in the log directory.
        resume=nl.AutoResume(
            restore_config=RestoreConfig(path=args.model_path),
            resume_if_exists=True,
            resume_ignore_no_checkpoint=False,
            resume_past_end=True,
        ),
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--data_path", required=True, type=str, help="Path prefix of the tokenized .bin/.idx dataset files"
    )
    parser.add_argument(
        "--model_path", default="nvidia/Cosmos-1.0-Autoregressive-4B", type=str, help="The path to the NeMo model"
    )
    parser.add_argument(
        "--index_mapping_dir", default="./index_mapping", type=str, help="The directory to store mapped indices"
    )
    parser.add_argument("--log_dir", default="./log_dir", type=str, help="The path to the logs")
    parser.add_argument("--split_string", default="98,1,1", type=str, help="The train/validation/test split")
    parser.add_argument("--tensor_model_parallel_size", default=2, type=int, help="Tensor model parallel size")
    parser.add_argument("--max_steps", default=100, type=int, help="The max number of steps to run finetuning")
    parser.add_argument("--save_every_n_steps", default=100, type=int, help="How often to save a checkpoint")
    parser.add_argument("--global_batch_size", default=2, type=int, help="The global batch size")
    parser.add_argument("--micro_batch_size", default=1, type=int, help="The per-GPU micro batch size")
    parser.add_argument("--lr", default=5e-5, type=float, help="The learning rate")
    parser.add_argument("--max_epochs", default=10, type=int, help="Max number of epochs")

    args = parser.parse_args()
    main(args)