#!/usr/bin/env python3
"""
AAM Diffusion LLM — Training Script
Main entry point for training the AAM Diffusion Model.
Usage:
# Train with default config (base model)
python scripts/train.py
# Train with specific model size
python scripts/train.py --model_size small
# Train with custom config
python scripts/train.py --config path/to/config.json
# Train with specific data
python scripts/train.py --train_data path/to/train.jsonl --val_data path/to/val.jsonl
Analogy: Like Jin Soun beginning his physical training —
this is the starting point where the AAM's "body" begins to be trained.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.training.trainer import AamTrainer
from diffusion_llm.training.dataset import GraphNarrativeDataset
from diffusion_llm.data.data_pipeline import DataPipeline
# Root logging setup: timestamped, level-tagged messages for every logger in
# the process (applies to this script and all imported diffusion_llm modules).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the training script.

    Args:
        argv: Optional explicit argument list to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, which preserves the
            original CLI behavior; passing a list makes the parser
            testable without touching ``sys.argv``.

    Returns:
        Namespace with model, data, and training-override options.
    """
    parser = argparse.ArgumentParser(
        description="Train AAM Diffusion LLM",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Model configuration
    parser.add_argument(
        "--model_size", type=str, default="base",
        choices=["tiny", "small", "base", "medium"],
        help="Model size preset",
    )
    parser.add_argument(
        "--config", type=str, default=None,
        help="Path to custom config JSON (overrides --model_size)",
    )
    # Data
    parser.add_argument(
        "--train_data", type=str, default=None,
        help="Path to training data (JSONL)",
    )
    parser.add_argument(
        "--val_data", type=str, default=None,
        help="Path to validation data (JSONL)",
    )
    parser.add_argument(
        "--output_dir", type=str, default="./output",
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--force_regenerate", action="store_true",
        help="Force regenerate synthetic data",
    )
    # Training overrides: default None means "keep the config's value".
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--max_steps", type=int, default=None)
    parser.add_argument("--n_timesteps", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args(argv)
def main() -> None:
    """Main training entry point.

    Pipeline: load/create config -> apply CLI overrides -> prepare data
    via DataPipeline -> build the model -> train with AamTrainer.
    """
    args = parse_args()

    # Load or create config
    if args.config:
        config = AamDiffusionConfig.from_json(args.config)
        logger.info("Loaded config from %s", args.config)
    else:
        config = get_default_config(args.model_size)
        logger.info("Using %s model config", args.model_size)

    # Apply CLI overrides. Compare against None explicitly so that
    # legitimate falsy values (e.g. --max_steps 0) are not silently
    # ignored, and "flag not given" is the only case that keeps the
    # config's own value.
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.train_data is not None:
        config.training.train_data_path = args.train_data
    if args.val_data is not None:
        config.training.val_data_path = args.val_data
    if args.batch_size is not None:
        config.training.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.training.learning_rate = args.learning_rate
    if args.max_steps is not None:
        config.training.max_steps = args.max_steps
    if args.n_timesteps is not None:
        config.diffusion.n_timesteps = args.n_timesteps
    config.seed = args.seed

    # Print config summary
    print(config.summary())

    # Save config — create the output directory first so a fresh run
    # does not crash on a nonexistent path.
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    config_path = output_dir / "config.json"
    config.to_json(config_path)
    logger.info("Config saved to %s", config_path)

    # Step 1: Prepare data (tokenizer + loaders come from the pipeline)
    pipeline = DataPipeline(config)
    tokenizer, train_loader, val_loader = pipeline.prepare(
        force_regenerate=args.force_regenerate,
    )

    # Step 2: Create model
    model = AamDiffusionModel(config)
    logger.info(
        "Model created: %s parameters",
        model._format_params(model.get_num_params()),
    )

    # Step 3: Reuse the datasets backing the pre-created loaders
    train_dataset = train_loader.dataset
    val_dataset = val_loader.dataset if val_loader else None

    # Step 4: Create trainer and train
    trainer = AamTrainer(
        config=config,
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )
    # Override data loaders (already created by pipeline) so the trainer
    # does not rebuild them from the datasets.
    trainer.train_loader = train_loader
    trainer.val_loader = val_loader

    # Start training
    trainer.train()
    logger.info("Training complete! Output saved to %s", config.output_dir)
# Script entry-point guard: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()