|
|
|
|
|
""" |
|
|
Main training script - can be run directly without import issues. |
|
|
This script imports and runs the training function from the supernova package. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
# Put this script's own directory first on the import path so that the
# ``supernova`` package imported below resolves even when the script is
# launched from a different working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _SCRIPT_DIR)
|
|
|
|
|
from supernova.train import train |
|
|
|
|
|
def _default_local_rank():
    """Return ``$LOCAL_RANK`` as an int, or 0 if it is unset or not an integer.

    ``torchrun`` passes each worker's local rank through the ``LOCAL_RANK``
    environment variable instead of a command-line flag. Using it as the
    default for ``--local-rank`` stops every DDP worker from silently
    receiving rank 0.
    """
    try:
        return int(os.environ.get("LOCAL_RANK", "0"))
    except ValueError:
        # Best-effort default: a malformed value falls back to the old default.
        return 0


def _build_parser():
    """Create the argument parser for the training command line."""
    parser = argparse.ArgumentParser(description="Train Supernova 25M model")
    parser.add_argument("--config", required=True, help="Path to model config JSON")
    parser.add_argument("--data", required=True, help="Path to data config YAML")
    parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("--max-steps", type=int, default=100000, help="Maximum training steps")
    parser.add_argument("--save-every", type=int, default=10000, help="Save checkpoint every N steps")
    parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--validate-every", type=int, default=1000, help="Validate every N steps")
    parser.add_argument("--val-steps", type=int, default=100, help="Validation steps")
    parser.add_argument("--clip-grad-norm", type=float, default=1.0, help="Gradient clipping norm")
    parser.add_argument("--no-ema", action="store_true", help="Disable EMA")
    parser.add_argument("--ema-decay", type=float, default=0.9999, help="EMA decay rate")
    parser.add_argument("--resume-from", help="Resume from checkpoint")
    parser.add_argument("--no-tensorboard", action="store_true", help="Disable tensorboard")
    parser.add_argument("--ddp", action="store_true", help="Use distributed training")
    # ``--local_rank`` is accepted as an alias because older
    # ``torch.distributed.launch`` versions pass the underscore spelling.
    # argparse derives ``dest`` from the first long option, so the result is
    # still stored as ``args.local_rank``.
    parser.add_argument(
        "--local-rank",
        "--local_rank",
        type=int,
        default=_default_local_rank(),
        help="Local rank for DDP (defaults to $LOCAL_RANK, else 0)",
    )
    parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--no-pin-memory", action="store_true", help="Disable pin memory")
    parser.add_argument("--compile-model", action="store_true", help="Use torch.compile")
    return parser


def main(argv=None):
    """Parse command-line arguments and run ``supernova.train.train``.

    Args:
        argv: Arguments to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, as before.
    """
    args = _build_parser().parse_args(argv)

    # The ``--no-*`` flags are inverted into the positive ``use_*`` /
    # ``pin_memory`` options that ``train`` takes.
    train(
        config_path=args.config,
        data_config_path=args.data,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        save_every=args.save_every,
        out_dir=args.out_dir,
        seed=args.seed,
        validate_every=args.validate_every,
        val_steps=args.val_steps,
        clip_grad_norm=args.clip_grad_norm,
        use_ema=not args.no_ema,
        ema_decay=args.ema_decay,
        resume_from=args.resume_from,
        use_tensorboard=not args.no_tensorboard,
        ddp=args.ddp,
        local_rank=args.local_rank,
        num_workers=args.num_workers,
        pin_memory=not args.no_pin_memory,
        compile_model=args.compile_model,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Only start training when executed as a script, not when imported.
    main()