Vortex-7b-V1 / train.py
Zandy-Wandy's picture
Upload Vortex model
bf64b03 verified
#!/usr/bin/env python3
"""
Main training entry point for Vortex models.
"""
import argparse
import sys
from pathlib import Path
import torch
from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG
from configs.training_config import TRAINING_CONFIG, TRAINING_CONFIG_7B_CUDA, TRAINING_CONFIG_13B_CUDA, TRAINING_CONFIG_MPS
from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
from training.trainer import VortexTrainer, VortexDataset
def parse_args(argv=None):
    """Parse command-line options for a Vortex training run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``model_size``, ``device``,
        ``use_mps``, ``data_dir``, ``tokenizer_path``,
        ``resume_from_checkpoint``, ``output_dir``, ``max_steps``,
        ``micro_batch_size`` and ``quantization``.
    """
    parser = argparse.ArgumentParser(
        description="Train Vortex scientific language model",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["7b", "13b"],
        default="7b",
        help="Model size to train",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "mps", "cpu"],
        help="Device to train on",
    )
    parser.add_argument(
        "--use_mps",
        action="store_true",
        help="Use MPS backend (Apple Silicon)",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data/processed",
        help="Directory with processed data shards",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Path to pretrained tokenizer",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from checkpoint",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./checkpoints",
        help="Output directory for checkpoints",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Override max training steps",
    )
    parser.add_argument(
        "--micro_batch_size",
        type=int,
        default=None,
        help="Override micro batch size",
    )
    # Only real quantization modes are listed: a command-line value is always
    # a string, so a ``None`` entry in ``choices`` could never match and only
    # cluttered the help text. Omitting the flag still yields ``None``.
    parser.add_argument(
        "--quantization",
        type=str,
        choices=["int8", "int4"],
        default=None,
        help="Quantization for 13B on 8GB",
    )
    return parser.parse_args(argv)
def main():
    """Run a Vortex training job configured from the command line.

    Selects the model and training configs for the requested model size,
    applies command-line overrides, builds the tokenizer, model and datasets,
    optionally restores a checkpoint, and then hands control to
    ``VortexTrainer.train``.

    Exits the process with status 1 when the data directory contains no
    ``train_*.parquet`` shard.
    """
    args = parse_args()

    # ``--use_mps`` and ``--device mps`` both request the MPS backend. Resolve
    # the device name once so the config, the torch device and the printed
    # summary agree; previously ``--use_mps`` alone left the device at the
    # ``--device`` default of "cuda".
    use_mps = args.use_mps or args.device == "mps"
    device_name = "mps" if use_mps else args.device

    # Load configs
    if args.model_size == "7b":
        model_config = VORTEX_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()
    else:
        model_config = VORTEX_13B_CONFIG.copy()
        train_config = TRAINING_CONFIG_13B_CUDA.copy()

    # The MPS training config replaces the size-specific CUDA one entirely.
    if use_mps:
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Apply overrides. Compare against None rather than truthiness so that an
    # explicitly supplied value is never silently dropped.
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size
    if args.quantization:
        train_config["quantization"] = args.quantization
    # NOTE(review): --output_dir is parsed but not forwarded anywhere here;
    # confirm which config key (if any) the trainer reads for checkpoints.

    # Set device
    device = torch.device(device_name)
    train_config["device"] = device_name

    print(f"Training Vortex-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")

    # Create tokenizer
    print("Loading tokenizer...")
    tokenizer = VortexScienceTokenizer(
        model_config,
        tokenizer_path=args.tokenizer_path,
    )
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

    # Create model
    print("Creating model...")
    model = VortexModel(model_config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Estimate memory
    mem = model.estimate_memory_usage(
        train_config["micro_batch_size"],
        model_config["max_seq_len"],
    )
    print("Memory estimate:")
    for k, v in mem.items():
        print(f" {k}: {v:.2f} GB")

    # Load dataset
    print("Loading dataset...")
    data_dir = Path(args.data_dir)
    # Sorted so the shard order (and the eval shard below) is deterministic.
    shard_files = sorted(data_dir.glob("train_*.parquet"))
    if not shard_files:
        print(f"No training shards found in {data_dir}")
        print("Please run data pipeline first.")
        sys.exit(1)

    train_dataset = VortexDataset(
        shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )
    print(f"Training dataset size: {len(train_dataset)} samples")

    # Create eval dataset from the first training shard.
    # NOTE(review): this shard is also part of train_dataset, so eval metrics
    # are measured on data the model trains on - confirm this is intended.
    eval_shard_files = shard_files[:1]
    eval_dataset = VortexDataset(
        eval_shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )

    # Create trainer
    trainer = VortexTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    # Train
    trainer.train()
    print("Training complete!")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 after a normal run.
    sys.exit(main())