| |
| """ |
| AAM Diffusion LLM — Training Script |
| |
| Main entry point for training the AAM Diffusion Model. |
| |
| Usage: |
| # Train with default config (base model) |
| python scripts/train.py |
| |
| # Train with specific model size |
| python scripts/train.py --model_size small |
| |
| # Train with custom config |
| python scripts/train.py --config path/to/config.json |
| |
| # Train with specific data |
| python scripts/train.py --train_data path/to/train.jsonl --val_data path/to/val.jsonl |
| |
| Analogi: Seperti Jin Soun memulai latihan fisiknya — |
| ini adalah titik awal di mana "tubuh" AAM mulai dilatih. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config |
| from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel |
| from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer |
| from diffusion_llm.training.trainer import AamTrainer |
| from diffusion_llm.training.dataset import GraphNarrativeDataset |
| from diffusion_llm.data.data_pipeline import DataPipeline |
|
|
# Script-wide logging: INFO level, each record prefixed with a timestamp,
# its level and the emitting logger's name.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"

logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
# Output directory used for preset (--model_size) runs when --output_dir
# is not given on the command line.
_DEFAULT_OUTPUT_DIR = "./output"
# Random seed used for preset runs when --seed is not given.
_DEFAULT_SEED = 42


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``output_dir``, ``seed`` and every
        hyperparameter override are ``None`` when not given on the command
        line, so callers can tell "not passed" apart from an explicit value.
    """
    parser = argparse.ArgumentParser(
        description="Train AAM Diffusion LLM",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- Model selection -------------------------------------------------
    parser.add_argument(
        "--model_size", type=str, default="base",
        choices=["tiny", "small", "base", "medium"],
        help="Model size preset",
    )
    parser.add_argument(
        "--config", type=str, default=None,
        help="Path to custom config JSON (overrides --model_size)",
    )

    # --- Data and output -------------------------------------------------
    parser.add_argument(
        "--train_data", type=str, default=None,
        help="Path to training data (JSONL)",
    )
    parser.add_argument(
        "--val_data", type=str, default=None,
        help="Path to validation data (JSONL)",
    )
    parser.add_argument(
        "--output_dir", type=str, default=None,
        help=(
            "Output directory for checkpoints and logs "
            f"(when omitted: {_DEFAULT_OUTPUT_DIR} for presets, "
            "or the value in --config)"
        ),
    )
    parser.add_argument(
        "--force_regenerate", action="store_true",
        help="Force regenerate synthetic data",
    )

    # --- Hyperparameter overrides (None = keep the config's value) -------
    parser.add_argument(
        "--batch_size", type=int, default=None,
        help="Override the training batch size",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=None,
        help="Override the training learning rate",
    )
    parser.add_argument(
        "--max_steps", type=int, default=None,
        help="Override the maximum number of training steps",
    )
    parser.add_argument(
        "--n_timesteps", type=int, default=None,
        help="Override the number of diffusion timesteps",
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help=(
            f"Random seed (when omitted: {_DEFAULT_SEED} for presets, "
            "or the value in --config)"
        ),
    )

    return parser.parse_args(argv)


def _load_config(args: argparse.Namespace) -> "AamDiffusionConfig":
    """Build the base config from ``--config`` if given, else from the ``--model_size`` preset."""
    if args.config:
        config = AamDiffusionConfig.from_json(args.config)
        logger.info("Loaded config from %s", args.config)
        return config

    config = get_default_config(args.model_size)
    logger.info("Using %s model config", args.model_size)
    return config


def _apply_overrides(config: "AamDiffusionConfig", args: argparse.Namespace) -> None:
    """Copy explicitly-given CLI values onto *config* in place.

    Only values that were actually passed (i.e. are not ``None``) are
    applied. An explicit ``0`` / ``0.0`` is therefore honoured. A
    ``--config`` file's own ``output_dir`` and ``seed`` survive unless they
    are overridden on the command line. Preset runs fall back to
    ``_DEFAULT_OUTPUT_DIR`` and ``_DEFAULT_SEED``.
    """
    from_json = bool(args.config)

    if args.output_dir is not None:
        config.output_dir = args.output_dir
    elif not from_json:
        config.output_dir = _DEFAULT_OUTPUT_DIR

    if args.train_data is not None:
        config.training.train_data_path = args.train_data
    if args.val_data is not None:
        config.training.val_data_path = args.val_data
    if args.batch_size is not None:
        config.training.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.training.learning_rate = args.learning_rate
    if args.max_steps is not None:
        config.training.max_steps = args.max_steps
    if args.n_timesteps is not None:
        config.diffusion.n_timesteps = args.n_timesteps

    if args.seed is not None:
        config.seed = args.seed
    elif not from_json:
        config.seed = _DEFAULT_SEED


def main() -> None:
    """Main training entry point.

    Resolve the config, apply CLI overrides, save the resolved config to
    ``<output_dir>/config.json``, prepare data, build the model and train.
    """
    args = parse_args()

    config = _load_config(args)
    _apply_overrides(config, args)

    print(config.summary())

    # Persist the resolved config (after CLI overrides) next to the run's
    # outputs. Create the directory first in case to_json does not create
    # parent directories itself.
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    config_path = output_dir / "config.json"
    config.to_json(config_path)
    logger.info("Config saved to %s", config_path)

    pipeline = DataPipeline(config)
    tokenizer, train_loader, val_loader = pipeline.prepare(
        force_regenerate=args.force_regenerate,
    )

    model = AamDiffusionModel(config)
    # NOTE(review): relies on a private helper of AamDiffusionModel.
    logger.info(
        "Model created: %s parameters",
        model._format_params(model.get_num_params()),
    )

    # Test for None explicitly rather than truthiness. If the loader defines
    # __len__ (e.g. a torch DataLoader), an empty loader would otherwise be
    # treated as "no validation loader", and a length-less one could raise.
    val_dataset = val_loader.dataset if val_loader is not None else None

    trainer = AamTrainer(
        config=config,
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_loader.dataset,
        val_dataset=val_dataset,
    )

    # Hand the pipeline-built loaders to the trainer directly.
    # NOTE(review): this overwrites attributes after construction — confirm
    # AamTrainer uses train_loader/val_loader rather than rebuilding them.
    trainer.train_loader = train_loader
    trainer.val_loader = val_loader

    trainer.train()

    logger.info("Training complete! Output saved to %s", config.output_dir)


if __name__ == "__main__":
    main()
|
|