File size: 5,067 Bytes
2d7e335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#!/usr/bin/env python3
"""
AAM Diffusion LLM — Training Script

Main entry point for training the AAM Diffusion Model.

Usage:
    # Train with default config (base model)
    python scripts/train.py

    # Train with specific model size
    python scripts/train.py --model_size small

    # Train with custom config
    python scripts/train.py --config path/to/config.json

    # Train with specific data
    python scripts/train.py --train_data path/to/train.jsonl --val_data path/to/val.jsonl

Analogy: Like Jin Soun beginning his physical training —
this is the starting point where the "body" of AAM starts being trained.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.training.trainer import AamTrainer
from diffusion_llm.training.dataset import GraphNarrativeDataset
from diffusion_llm.data.data_pipeline import DataPipeline

# Configure root logging once at import time: timestamped, level-tagged
# records for every logger in the process.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-scoped logger named after this module (standard logging practice).
logger = logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    """Build the CLI parser for the training script and return parsed args."""
    cli = argparse.ArgumentParser(
        description="Train AAM Diffusion LLM",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- Model configuration ---------------------------------------------
    cli.add_argument(
        "--model_size",
        type=str,
        default="base",
        choices=["tiny", "small", "base", "medium"],
        help="Model size preset",
    )
    cli.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to custom config JSON (overrides --model_size)",
    )

    # --- Data -------------------------------------------------------------
    cli.add_argument(
        "--train_data",
        type=str,
        default=None,
        help="Path to training data (JSONL)",
    )
    cli.add_argument(
        "--val_data",
        type=str,
        default=None,
        help="Path to validation data (JSONL)",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="Output directory for checkpoints and logs",
    )
    cli.add_argument(
        "--force_regenerate",
        action="store_true",
        help="Force regenerate synthetic data",
    )

    # --- Training overrides (None = keep the config's value) --------------
    for flag, caster in (
        ("--batch_size", int),
        ("--learning_rate", float),
        ("--max_steps", int),
        ("--n_timesteps", int),
    ):
        cli.add_argument(flag, type=caster, default=None)
    cli.add_argument("--seed", type=int, default=42)

    return cli.parse_args()


def main() -> None:
    """Main training entry point.

    Loads (or builds) the model config, applies CLI overrides, prepares the
    data pipeline, constructs the model, and runs the trainer.
    """
    args = parse_args()

    # Load or create config
    if args.config:
        config = AamDiffusionConfig.from_json(args.config)
        logger.info("Loaded config from %s", args.config)
    else:
        config = get_default_config(args.model_size)
        logger.info("Using %s model config", args.model_size)

    # Apply CLI overrides.  The numeric overrides default to None in the
    # parser, so compare against None explicitly — a plain truthiness check
    # would silently ignore legitimate zero values (e.g. --max_steps 0).
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.train_data:
        config.training.train_data_path = args.train_data
    if args.val_data:
        config.training.val_data_path = args.val_data
    if args.batch_size is not None:
        config.training.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.training.learning_rate = args.learning_rate
    if args.max_steps is not None:
        config.training.max_steps = args.max_steps
    if args.n_timesteps is not None:
        config.diffusion.n_timesteps = args.n_timesteps
    config.seed = args.seed

    # Print config summary
    print(config.summary())

    # Save config — make sure the output directory exists first; on a fresh
    # run it typically does not, and to_json would fail on a missing parent.
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    config_path = output_dir / "config.json"
    config.to_json(config_path)
    logger.info("Config saved to %s", config_path)

    # Step 1: Prepare data (tokenizer + DataLoaders come from the pipeline)
    pipeline = DataPipeline(config)
    tokenizer, train_loader, val_loader = pipeline.prepare(
        force_regenerate=args.force_regenerate,
    )

    # Step 2: Create model
    model = AamDiffusionModel(config)
    logger.info(
        "Model created: %s parameters",
        model._format_params(model.get_num_params()),
    )

    # Step 3: Extract datasets from the pre-built loaders so the trainer
    # sees the same data the pipeline produced.
    train_dataset = train_loader.dataset
    val_dataset = val_loader.dataset if val_loader else None

    # Step 4: Create trainer and train
    trainer = AamTrainer(
        config=config,
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )

    # Override the trainer's loaders with the pipeline's (avoids rebuilding
    # them from the datasets and keeps any pipeline-specific collation).
    trainer.train_loader = train_loader
    trainer.val_loader = val_loader

    # Start training
    trainer.train()

    logger.info("Training complete! Output saved to %s", config.output_dir)


# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()