#!/usr/bin/env python3
"""
Main training entry point for Vortex models.
"""
import argparse
import sys
from pathlib import Path

import torch

from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG
from configs.training_config import (
    TRAINING_CONFIG_7B_CUDA,
    TRAINING_CONFIG_13B_CUDA,
    TRAINING_CONFIG_MPS,
)
from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
from training.trainer import VortexTrainer, VortexDataset


def parse_args():
    parser = argparse.ArgumentParser(description="Train Vortex scientific language model")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size to train")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "mps", "cpu"],
                        help="Device to train on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon); shorthand for --device mps")
    parser.add_argument("--data_dir", type=str, default="./data/processed",
                        help="Directory with processed data shards")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to pretrained tokenizer")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Resume training from checkpoint")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Override max training steps")
    parser.add_argument("--micro_batch_size", type=int, default=None,
                        help="Override micro batch size")
    # Omitting the flag leaves quantization disabled (default=None); None itself
    # is not a selectable command-line choice.
    parser.add_argument("--quantization", type=str, choices=["int8", "int4"], default=None,
                        help="Quantization for running 13B on 8 GB GPUs")
    return parser.parse_args()


def main():
    args = parse_args()

    # Load model and training configs for the requested size
    if args.model_size == "7b":
        model_config = VORTEX_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()
    else:
        model_config = VORTEX_13B_CONFIG.copy()
        train_config = TRAINING_CONFIG_13B_CUDA.copy()

    # Override with the MPS config if needed; normalize args.device so the
    # torch.device created below matches the config when only --use_mps is passed
    if args.use_mps or args.device == "mps":
        args.device = "mps"
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Apply command-line overrides ("is not None" so an explicit 0 still counts)
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size
    if args.quantization is not None:
        train_config["quantization"] = args.quantization
    # Forward the checkpoint directory (key name assumed to be read by VortexTrainer)
    train_config["output_dir"] = args.output_dir

    # Set device
    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training Vortex-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")

    # Create tokenizer
    print("Loading tokenizer...")
    tokenizer = VortexScienceTokenizer(
        model_config,
        tokenizer_path=args.tokenizer_path,
    )
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

    # Create model
    print("Creating model...")
    model = VortexModel(model_config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Estimate memory for the configured micro batch and sequence length
    mem = model.estimate_memory_usage(
        train_config["micro_batch_size"],
        model_config["max_seq_len"],
    )
    print("Memory estimate:")
    for k, v in mem.items():
        print(f"  {k}: {v:.2f} GB")

    # Load dataset
    print("Loading dataset...")
    data_dir = Path(args.data_dir)
    shard_files = sorted(data_dir.glob("train_*.parquet"))
    if not shard_files:
        print(f"No training shards found in {data_dir}")
        print("Please run the data pipeline first.")
        sys.exit(1)

    train_dataset = VortexDataset(
        shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )
print(f"Training dataset size: {len(train_dataset)} samples") # Create eval dataset (use first few shards) eval_shard_files = shard_files[:1] # Use first shard for eval eval_dataset = VortexDataset( eval_shard_files, tokenizer, max_seq_len=model_config["max_seq_len"], ) # Create trainer trainer = VortexTrainer( model=model, tokenizer=tokenizer, train_dataset=train_dataset, config=train_config, eval_dataset=eval_dataset, ) # Resume from checkpoint if specified if args.resume_from_checkpoint: trainer.load_checkpoint(args.resume_from_checkpoint) # Train trainer.train() print("Training complete!") if __name__ == "__main__": main()