"""
train — LLM pretraining package.

Public API:
    TrainConfig : Dataclass of training hyper-parameters.
    Trainer     : Core training loop with gradient accumulation, AMP, and logging.

Utility functions (re-exported from train.utils):
    get_cosine_schedule_with_warmup
    save_checkpoint
    load_checkpoint
    get_grad_norm
    setup_ddp
    cleanup_ddp
    is_main_process
"""
from train.trainer import TrainConfig, Trainer
from train.utils import (
cleanup_ddp,
get_cosine_schedule_with_warmup,
get_grad_norm,
is_main_process,
load_checkpoint,
save_checkpoint,
setup_ddp,
)
# Explicit public surface of the package: these are the names bound by
# ``from train import *``.
__all__ = [
    # Classes imported from train.trainer
    "TrainConfig",
    "Trainer",
    # Helpers re-exported from train.utils
    "get_cosine_schedule_with_warmup",
    "save_checkpoint",
    "load_checkpoint",
    "get_grad_norm",
    "setup_ddp",
    "cleanup_ddp",
    "is_main_process",
]