| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """ | |
| The training entry script for the FastGen project. Works for both DDP and FSDP training. | |
| """ | |
| import argparse | |
| import warnings | |
| from fastgen.configs.config import BaseConfig | |
| from fastgen.utils import instantiate | |
| from fastgen.trainer import Trainer | |
| import fastgen.utils.logging_utils as logger | |
| from fastgen.utils.distributed import synchronize, clean_up | |
| from fastgen.utils.scripts import parse_args, setup | |
| warnings.filterwarnings( | |
| "ignore", "Grad strides do not match bucket view strides" | |
| ) # False warning printed by PyTorch 2.6. | |
| def main(config: BaseConfig): | |
| # initiate the model | |
| config.model_class.config = config.model | |
| model = instantiate(config.model_class) | |
| config.model_class.config = None | |
| synchronize() | |
| # initiate the trainer | |
| logger.info("Initializing trainer...") | |
| fastgen_trainer = Trainer(config) | |
| logger.success("Trainer initialized successfully") | |
| synchronize() | |
| # Start training | |
| fastgen_trainer.run(model) | |
| synchronize() | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Training") | |
| args = parse_args(parser) | |
| config = setup(args) | |
| main(config) | |
| clean_up() | |
| logger.info("Training finished.") | |