#!/usr/bin/env python3
"""
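Command-line entry point for training the Supernova model.

Run this file directly: it puts its own directory on sys.path so the local
supernova package imports without being installed, then passes the parsed
arguments to supernova.train.train().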
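
Example invocations (the script name and all paths are placeholders). The
second halves --batch-size and doubles --grad-accum, which keeps the tokens
per optimizer step unchanged while each forward pass holds fewer sequences,
typically lowering peak memory:

    python <script>.py --config path/to/model.json --data path/to/data.yaml
    python <script>.py --config path/to/model.json --data path/to/data.yaml --batch-size 8 --grad-accum 16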
"""

import argparse
import sys
import os

# Prepend this script's directory to sys.path so the local supernova package
# can be imported even when it has not been installed.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from supernova.train import train


def main():
    parser = argparse.ArgumentParser(description="Train Supernova 25M model")
    parser.add_argument("--config", required=True, help="Path to model config JSON")
    parser.add_argument("--data", required=True, help="Path to data config YAML")
    parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length in tokens")
    parser.add_argument("--batch-size", type=int, default=16, help="Micro-batch size in sequences (before gradient accumulation)")
    parser.add_argument("--grad-accum", type=int, default=8, help="Micro-batches accumulated per optimizer step")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=2000, help="Learning-rate warmup steps")
    parser.add_argument("--max-steps", type=int, default=100000, help="Maximum number of training steps")
    parser.add_argument("--save-every", type=int, default=10000, help="Save a checkpoint every N steps")
    parser.add_argument("--out-dir", default="checkpoints", help="Output directory for checkpoints")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--validate-every", type=int, default=1000, help="Run validation every N steps")
    parser.add_argument("--val-steps", type=int, default=100, help="Number of steps per validation run")
    parser.add_argument("--clip-grad-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--no-ema", action="store_true", help="Disable the exponential moving average (EMA) of weights")
    parser.add_argument("--ema-decay", type=float, default=0.9999, help="EMA decay rate")
    parser.add_argument("--resume-from", help="Path to a checkpoint to resume from")
    parser.add_argument("--no-tensorboard", action="store_true", help="Disable TensorBoard logging")
    parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel (DDP) training")
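    # torchrun sets LOCAL_RANK in the environment instead of passing a flag;
    # older launchers pass --local_rank, newer ones --local-rank.
    parser.add_argument(
        "--local-rank",
        "--local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
        help="Local rank for DDP (defaults to $LOCAL_RANK, else 0)",
    )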
    parser.add_argument("--num-workers", type=int, default=4, help="Number of DataLoader worker processes")
    parser.add_argument("--no-pin-memory", action="store_true", help="Disable pin_memory in the DataLoader")
    parser.add_argument("--compile-model", action="store_true", help="Use torch.compile")
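
    # EMA note, assuming supernova.train applies the usual update
    # ema = decay * ema + (1 - decay) * param once per optimizer step:
    # the average spans roughly 1 / (1 - decay) steps, so the default
    # --ema-decay 0.9999 averages over about 1 / 0.0001 = 10,000 steps,
    # a tenth of the default --max-steps.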
    
    args = parser.parse_args()
    
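    # Forward the CLI options to train(); each --no-* flag is inverted into
    # the matching boolean (use_ema, use_tensorboard, pin_memory).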
    train(
        config_path=args.config,
        data_config_path=args.data,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        save_every=args.save_every,
        out_dir=args.out_dir,
        seed=args.seed,
        validate_every=args.validate_every,
        val_steps=args.val_steps,
        clip_grad_norm=args.clip_grad_norm,
        use_ema=not args.no_ema,
        ema_decay=args.ema_decay,
        resume_from=args.resume_from,
        use_tensorboard=not args.no_tensorboard,
        ddp=args.ddp,
        local_rank=args.local_rank,
        num_workers=args.num_workers,
        pin_memory=not args.no_pin_memory,
        compile_model=args.compile_model,
    )

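
# Hypothetical helper, not called by main(): a sketch of the token budget
# implied by --seq-len, --batch-size and --grad-accum. It assumes that
# --batch-size counts sequences per process and that every sequence is
# packed to the full --seq-len; adjust if supernova.train differs.
def tokens_per_optimizer_step(seq_len: int, batch_size: int, grad_accum: int, world_size: int = 1) -> int:
    """Return the tokens consumed by one optimizer step, summed over all processes.

    With the CLI defaults on one process: tokens_per_optimizer_step(1024, 16, 8) == 131072.
    """
    if min(seq_len, batch_size, grad_accum, world_size) < 1:
        raise ValueError("all arguments must be >= 1")
    # Each process runs grad_accum micro-batches of batch_size sequences
    # before the optimizer steps; DDP multiplies that by world_size.
    return seq_len * batch_size * grad_accum * world_size

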
if __name__ == "__main__":
    main()