#!/usr/bin/env python3
"""
Main training entry point for Vortex models.
"""
import argparse
import sys
from pathlib import Path
import torch
from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG
from configs.training_config import TRAINING_CONFIG, TRAINING_CONFIG_7B_CUDA, TRAINING_CONFIG_13B_CUDA, TRAINING_CONFIG_MPS
from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
from training.trainer import VortexTrainer, VortexDataset


def parse_args():
    parser = argparse.ArgumentParser(description="Train Vortex scientific language model")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size to train")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"],
                        help="Device to train on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    parser.add_argument("--data_dir", type=str, default="./data/processed",
                        help="Directory with processed data shards")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to pretrained tokenizer")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Resume training from checkpoint")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Override max training steps")
    parser.add_argument("--micro_batch_size", type=int, default=None,
                        help="Override micro batch size")
    parser.add_argument("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
                        help="Quantization for 13B on 8GB")
    return parser.parse_args()
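

# Configuration precedence in main(): the size-specific CUDA training config is
# selected first, the MPS config replaces it when --use_mps or --device mps is
# given, and individual CLI flags (--max_steps, --micro_batch_size,
# --quantization) override single fields last.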
def main():
    args = parse_args()

    # Load configs
    if args.model_size == "7b":
        model_config = VORTEX_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()
    else:
        model_config = VORTEX_13B_CONFIG.copy()
        train_config = TRAINING_CONFIG_13B_CUDA.copy()

    # Override with the MPS config if requested
    if args.use_mps or args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Apply CLI overrides (explicit None checks so a value of 0 is still honored)
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size
    if args.quantization is not None:
        train_config["quantization"] = args.quantization

    # Set device
    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training Vortex-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")

    # Create tokenizer
    print("Loading tokenizer...")
    tokenizer = VortexScienceTokenizer(
        model_config,
        tokenizer_path=args.tokenizer_path,
    )
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

    # Create model
    print("Creating model...")
    model = VortexModel(model_config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Estimate memory
    mem = model.estimate_memory_usage(
        train_config["micro_batch_size"],
        model_config["max_seq_len"],
    )
    print("Memory estimate:")
    for k, v in mem.items():
        print(f"  {k}: {v:.2f} GB")

    # Load dataset
    print("Loading dataset...")
    data_dir = Path(args.data_dir)
    shard_files = sorted(data_dir.glob("train_*.parquet"))
    if not shard_files:
        print(f"No training shards found in {data_dir}")
        print("Please run the data pipeline first.")
        sys.exit(1)

    train_dataset = VortexDataset(
        shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )
    print(f"Training dataset size: {len(train_dataset)} samples")

    # Create eval dataset (use the first shard only)
    eval_shard_files = shard_files[:1]
    eval_dataset = VortexDataset(
        eval_shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )

    # Create trainer
    trainer = VortexTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    # Train
    trainer.train()
    print("Training complete!")


if __name__ == "__main__":
    main()