#!/usr/bin/env python3
"""

Main training entry point for Vortex models.

"""

import argparse
import sys
from pathlib import Path

import torch

from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG
from configs.training_config import (
    TRAINING_CONFIG_7B_CUDA,
    TRAINING_CONFIG_13B_CUDA,
    TRAINING_CONFIG_MPS,
)

from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
from training.trainer import VortexTrainer, VortexDataset


def parse_args():
    parser = argparse.ArgumentParser(description="Train Vortex scientific language model")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size to train")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"],
                        help="Device to train on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    parser.add_argument("--data_dir", type=str, default="./data/processed",
                        help="Directory with processed data shards")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to pretrained tokenizer")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Resume training from checkpoint")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Override max training steps")
    parser.add_argument("--micro_batch_size", type=int, default=None,
                        help="Override micro batch size")
    parser.add_argument("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
                        help="Quantization for 13B on 8GB")
    return parser.parse_args()


def main():
    args = parse_args()

    # --use_mps is shorthand for --device mps; normalize early so device
    # selection and config choice below stay consistent.
    if args.use_mps:
        args.device = "mps"

    # Load configs
    if args.model_size == "7b":
        model_config = VORTEX_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()
    else:
        model_config = VORTEX_13B_CONFIG.copy()
        train_config = TRAINING_CONFIG_13B_CUDA.copy()

    # Override with the MPS config when targeting Apple Silicon
    if args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Apply CLI overrides (explicit None checks so 0 remains a valid value)
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size
    if args.quantization is not None:
        train_config["quantization"] = args.quantization

    # Assumption: the trainer reads its checkpoint directory from the
    # config; without this, --output_dir is parsed but never used.
    train_config["output_dir"] = args.output_dir

    # Set device
    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training Vortex-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")

    # Create tokenizer
    print("Loading tokenizer...")
    tokenizer = VortexScienceTokenizer(
        model_config,
        tokenizer_path=args.tokenizer_path,
    )
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

    # Create model
    print("Creating model...")
    model = VortexModel(model_config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Estimate memory
    mem = model.estimate_memory_usage(
        train_config["micro_batch_size"],
        model_config["max_seq_len"],
    )
    print("Memory estimate:")
    for k, v in mem.items():
        print(f"  {k}: {v:.2f} GB")

    # Load dataset
    print("Loading dataset...")
    data_dir = Path(args.data_dir)
    shard_files = sorted(data_dir.glob("train_*.parquet"))
    if not shard_files:
        print(f"No training shards found in {data_dir}")
        print("Please run data pipeline first.")
        sys.exit(1)

    train_dataset = VortexDataset(
        shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )
    print(f"Training dataset size: {len(train_dataset)} samples")

    # Build the eval dataset from the first shard. Note that this shard is
    # also part of the training set, so eval loss will understate
    # generalization error; a held-out shard would be a cleaner choice.
    eval_shard_files = shard_files[:1]
    eval_dataset = VortexDataset(
        eval_shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )

    # Create trainer
    trainer = VortexTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    # Train
    trainer.train()

    print("Training complete!")


if __name__ == "__main__":
    main()