File size: 880 Bytes
17758b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""
train — LLM pretraining package.

Public API:
    TrainConfig   : Dataclass of training hyper-parameters.
    Trainer       : Core training loop with gradient accumulation, AMP, and logging.

Utility functions (re-exported from train.utils):
    get_cosine_schedule_with_warmup
    save_checkpoint
    load_checkpoint
    get_grad_norm
    setup_ddp
    cleanup_ddp
    is_main_process
"""

from train.trainer import TrainConfig, Trainer
from train.utils import (
    cleanup_ddp,
    get_cosine_schedule_with_warmup,
    get_grad_norm,
    is_main_process,
    load_checkpoint,
    save_checkpoint,
    setup_ddp,
)

# Public API of the ``train`` package: exactly the names bound by
# ``from train import *``. Core classes are listed first, then the
# utility helpers, in the same order as the module docstring.
__all__ = [
    # Core classes, imported from train.trainer.
    "TrainConfig",
    "Trainer",
]
__all__ += [
    # Utility functions, re-exported from train.utils.
    "get_cosine_schedule_with_warmup",
    "save_checkpoint",
    "load_checkpoint",
    "get_grad_norm",
    "setup_ddp",
    "cleanup_ddp",
    "is_main_process",
]