"""
Main training entry point for Vortex models.
"""
|
| |
|
| | import argparse
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| | import torch
|
| |
|
| | from configs.vortex_7b_config import VORTEX_7B_CONFIG
|
| | from configs.vortex_13b_config import VORTEX_13B_CONFIG
|
| | from configs.training_config import TRAINING_CONFIG, TRAINING_CONFIG_7B_CUDA, TRAINING_CONFIG_13B_CUDA, TRAINING_CONFIG_MPS
|
| |
|
| | from models.vortex_model import VortexModel
|
| | from tokenizer.vortex_tokenizer import VortexScienceTokenizer
|
| | from training.trainer import VortexTrainer, VortexDataset
|
| |
|
| |
|
def parse_args(argv=None):
    """Parse command-line options for Vortex training.

    Args:
        argv: Optional list of argument strings; when None, argparse reads
            sys.argv[1:] (standard behavior). Added for testability —
            existing zero-argument callers are unaffected.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Train Vortex scientific language model")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size to train")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"],
                        help="Device to train on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    parser.add_argument("--data_dir", type=str, default="./data/processed",
                        help="Directory with processed data shards")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to pretrained tokenizer")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Resume training from checkpoint")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Override max training steps")
    parser.add_argument("--micro_batch_size", type=int, default=None,
                        help="Override micro batch size")
    # Bug fix: the original listed None in `choices`, but a provided value is
    # always a string so None could never match; omitting the flag already
    # yields None via `default`.
    parser.add_argument("--quantization", type=str, choices=["int8", "int4"], default=None,
                        help="Quantization for 13B on 8GB")
    return parser.parse_args(argv)
|
| |
|
| |
|
def main():
    """Configure, build, and launch a Vortex training run.

    Selects model/training configs from CLI flags, loads the tokenizer,
    model, and training/eval datasets, then hands off to VortexTrainer.
    Exits with status 1 when no processed data shards are found.
    """
    args = parse_args()

    # Select base configs by model size.
    if args.model_size == "7b":
        model_config = VORTEX_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()
    else:
        model_config = VORTEX_13B_CONFIG.copy()
        train_config = TRAINING_CONFIG_13B_CUDA.copy()

    # MPS (Apple Silicon) replaces the CUDA training config entirely.
    # Bug fix: also force the device string to "mps" — previously
    # `--use_mps` with the default `--device cuda` left torch.device and
    # train_config["device"] pointing at CUDA while using the MPS config.
    if args.use_mps or args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True
        args.device = "mps"

    # CLI overrides. Explicit None checks so 0 is still a valid override
    # (the original truthiness tests silently ignored 0).
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size
    if args.quantization is not None:
        train_config["quantization"] = args.quantization

    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training Vortex-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")

    print("Loading tokenizer...")
    tokenizer = VortexScienceTokenizer(
        model_config,
        tokenizer_path=args.tokenizer_path,
    )
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

    print("Creating model...")
    model = VortexModel(model_config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Rough memory estimate for the chosen batch size / sequence length.
    mem = model.estimate_memory_usage(
        train_config["micro_batch_size"],
        model_config["max_seq_len"],
    )
    print("Memory estimate:")
    for k, v in mem.items():
        print(f"  {k}: {v:.2f} GB")

    print("Loading dataset...")
    data_dir = Path(args.data_dir)
    # sorted() accepts any iterable; no need to materialize a list first.
    shard_files = sorted(data_dir.glob("train_*.parquet"))
    if not shard_files:
        print(f"No training shards found in {data_dir}")
        print("Please run data pipeline first.")
        sys.exit(1)

    train_dataset = VortexDataset(
        shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )
    print(f"Training dataset size: {len(train_dataset)} samples")

    # NOTE(review): eval reuses the first *training* shard, so eval loss is
    # an in-sample metric rather than held-out — confirm this is intentional.
    eval_dataset = VortexDataset(
        shard_files[:1],
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )

    trainer = VortexTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    trainer.train()

    print("Training complete!")
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | main()
|
| |
|