taohu's picture
Upload folder using huggingface_hub
0839907 verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
The training entry script for the FastGen project. Works for both DDP and FSDP training.
"""
import argparse
import warnings
from fastgen.configs.config import BaseConfig
from fastgen.utils import instantiate
from fastgen.trainer import Trainer
import fastgen.utils.logging_utils as logger
from fastgen.utils.distributed import synchronize, clean_up
from fastgen.utils.scripts import parse_args, setup
warnings.filterwarnings(
"ignore", "Grad strides do not match bucket view strides"
) # False warning printed by PyTorch 2.6.
def main(config: BaseConfig):
    """Build the model and the trainer from ``config``, then run training.

    Args:
        config: Run configuration. This function reads ``config.model`` and
            ``config.model_class`` directly and passes the whole object on
            to ``Trainer``.
    """
    # --- Model ---------------------------------------------------------
    # The model config is attached to the class spec only while
    # ``instantiate`` runs and is cleared again right after.
    # NOTE(review): presumably ``instantiate`` reads ``.config`` -- confirm.
    config.model_class.config = config.model
    model = instantiate(config.model_class)
    config.model_class.config = None
    # NOTE(review): ``synchronize`` comes from fastgen.utils.distributed and
    # is presumably a cross-process barrier -- confirm against that helper.
    synchronize()

    # --- Trainer -------------------------------------------------------
    logger.info("Initializing trainer...")
    trainer = Trainer(config)
    logger.success("Trainer initialized successfully")
    synchronize()

    # --- Training ------------------------------------------------------
    trainer.run(model)
    synchronize()
if __name__ == "__main__":
    # Script entry point: parse the command line, build the run config,
    # train, then tear down via ``clean_up`` before the final log line.
    config = setup(parse_args(argparse.ArgumentParser(description="Training")))

    main(config)

    clean_up()
    logger.info("Training finished.")