#!/usr/bin/env python3
"""
OmniCoreX Training CLI Script
Launches the training process for OmniCoreX with command-line options
to configure parameters. Supports distributed and mixed precision training.
"""
import argparse
from utils import set_seed, load_config_file, setup_logging
from model import OmniCoreXModel
from data_loader import create_omncorex_dataloader
from trainer import Trainer


def parse_args():
    parser = argparse.ArgumentParser(description="Train OmniCoreX Model")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
    parser.add_argument("--epochs", type=int, help="Number of epochs to train")
    parser.add_argument("--batch_size", type=int, help="Batch size override")
    parser.add_argument("--seed", type=int, help="Random seed override")
    parser.add_argument("--no_mixed_precision", action="store_true", help="Disable mixed precision training")
    parser.add_argument("--log_file", type=str, help="Log file path")
    return parser.parse_args()


def main():
    args = parse_args()
    config = load_config_file(args.config)

    # Check `is not None` rather than truthiness so that --seed 0 is honored.
    if args.seed is not None:
        set_seed(args.seed)
    else:
        set_seed(config.get("seed", 42))

    logger = setup_logging(log_file=args.log_file)
    logger.info("Starting OmniCoreX training")

    data_cfg = config["data"]
    training_cfg = config.get("training", {})
    batch_size = args.batch_size if args.batch_size is not None else data_cfg.get("batch_size", 16)
    # The CLI flag wins; otherwise fall back to the config value (default on).
    mixed_precision = not args.no_mixed_precision and training_cfg.get("mixed_precision", True)

    train_loader = create_omncorex_dataloader(
        metadata_path=data_cfg["metadata_path"],
        modalities=data_cfg["modalities"],
        tokenizer=None,  # placeholder until a custom tokenizer is wired in
        batch_size=batch_size,
        shuffle=data_cfg.get("shuffle", True),
        num_workers=data_cfg.get("num_workers", 4),
        augmentation=data_cfg.get("augmentation", False),
    )
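    # The keys read above imply a `data` section of config.yaml shaped
    # roughly like this (values are illustrative assumptions, not from
    # the repo):
    #
    #   data:
    #     metadata_path: data/metadata.json
    #     modalities: [text, image, audio]
    #     batch_size: 16
    #     shuffle: true
    #     num_workers: 4
    #     augmentation: false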

    model_cfg = config["model"]
    arch_cfg = model_cfg.get("architecture", {})
    model = OmniCoreXModel(
        stream_configs=model_cfg["streams"],
        embed_dim=arch_cfg.get("embed_dim", 768),
        num_layers=arch_cfg.get("num_layers", 24),
        num_heads=arch_cfg.get("num_heads", 12),
        dropout=arch_cfg.get("dropout", 0.1),
    )
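    # Matching `model` section sketch; the per-stream schema is an
    # assumption, since the stream_configs format expected by
    # OmniCoreXModel is not shown in this file:
    #
    #   model:
    #     streams:
    #       text: {input_dim: 512}
    #       image: {input_dim: 1024}
    #     architecture:
    #       embed_dim: 768
    #       num_layers: 24
    #       num_heads: 12
    #       dropout: 0.1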

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        valid_loader=None,  # no validation split is wired up yet
        save_dir=training_cfg.get("save_dir", "./checkpoints"),
        lr=training_cfg.get("learning_rate", 5e-5),
        weight_decay=training_cfg.get("weight_decay", 0.01),
        max_grad_norm=training_cfg.get("max_grad_norm", 1.0),
        accumulation_steps=training_cfg.get("accumulation_steps", 1),
        total_steps=training_cfg.get("total_steps", 100000),
        warmup_steps=training_cfg.get("warmup_steps", 1000),
        device=None,  # let the Trainer choose the device
        mixed_precision=mixed_precision,
    )
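    # Matching `training` section sketch; the values shown are simply the
    # script's own fallback defaults:
    #
    #   training:
    #     save_dir: ./checkpoints
    #     learning_rate: 5.0e-5
    #     weight_decay: 0.01
    #     max_grad_norm: 1.0
    #     accumulation_steps: 1
    #     total_steps: 100000
    #     warmup_steps: 1000
    #     epochs: 1
    #     mixed_precision: true
    #     seed: 42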

    epochs = args.epochs if args.epochs is not None else training_cfg.get("epochs", 1)
    trainer.fit(epochs=epochs)


if __name__ == "__main__":
    main()