|
|
|
|
|
""" |
|
|
OmniCoreX Training CLI Script |
|
|
|
|
|
Launches the training process for OmniCoreX with command-line options |
|
|
to configure parameters. Supports distributed and mixed precision training. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
from utils import set_seed, load_config_file, setup_logging |
|
|
from model import OmniCoreXModel |
|
|
from data_loader import create_omncorex_dataloader |
|
|
from trainer import Trainer |
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the OmniCoreX training script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour existing callers rely on.

    Returns:
        argparse.Namespace with the attributes ``config``, ``epochs``,
        ``batch_size``, ``seed``, ``log_file`` and ``no_mixed_precision``.
        Every option except ``config`` (default ``"config.yaml"``) and
        ``no_mixed_precision`` (default ``False``) is ``None`` when not
        supplied.
    """
    parser = argparse.ArgumentParser(description="Train OmniCoreX Model")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
    parser.add_argument("--epochs", type=int, help="Number of epochs to train")
    parser.add_argument("--batch_size", type=int, help="Batch size override")
    parser.add_argument("--seed", type=int, help="Random seed override")
    parser.add_argument("--log_file", type=str, help="Log file path")
    parser.add_argument("--no_mixed_precision", action="store_true", help="Disable mixed precision training")
    # Passing argv through keeps the function usable from tests and other
    # entry points without touching sys.argv.
    return parser.parse_args(argv)
|
|
|
|
|
def main():
    """Run OmniCoreX training as configured by the CLI and the config file.

    Steps, in order: parse CLI arguments, load the config file, seed the
    RNGs, set up logging, build the training dataloader, build the model,
    build the trainer, and call ``trainer.fit``.

    CLI options take precedence over values from the config file. The
    config object is indexed with ``config["data"]`` and ``config["model"]``
    and the keys ``metadata_path``, ``modalities`` and ``streams`` are read
    without defaults, so a missing entry surfaces as a lookup error
    (``KeyError`` if the config is a plain dict).
    """
    args = parse_args()
    config = load_config_file(args.config)

    # Compare against None instead of relying on truthiness so that an
    # explicit `--seed 0` is honoured rather than silently replaced by the
    # config/default seed.
    if args.seed is not None:
        set_seed(args.seed)
    else:
        set_seed(config.get("seed", 42))

    logger = setup_logging(log_file=args.log_file)
    logger.info("Starting OmniCoreX training")

    data_cfg = config["data"]
    training_cfg = config.get("training", {})

    # A CLI batch size of 0 is falsy and therefore falls back to the config.
    batch_size = args.batch_size or data_cfg.get("batch_size", 16)
    # Mixed precision is used only if the CLI flag did not disable it AND the
    # config enables it (enabled by default).
    mixed_precision = not args.no_mixed_precision and training_cfg.get("mixed_precision", True)

    train_loader = create_omncorex_dataloader(
        metadata_path=data_cfg["metadata_path"],
        modalities=data_cfg["modalities"],
        tokenizer=None,
        batch_size=batch_size,
        shuffle=data_cfg.get("shuffle", True),
        num_workers=data_cfg.get("num_workers", 4),
        augmentation=data_cfg.get("augmentation", False),
    )

    model_cfg = config["model"]
    # Look up the optional architecture section once instead of per field.
    arch_cfg = model_cfg.get("architecture", {})
    model = OmniCoreXModel(
        stream_configs=model_cfg["streams"],
        embed_dim=arch_cfg.get("embed_dim", 768),
        num_layers=arch_cfg.get("num_layers", 24),
        num_heads=arch_cfg.get("num_heads", 12),
        dropout=arch_cfg.get("dropout", 0.1),
    )

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        valid_loader=None,  # no validation loader is built by this script
        save_dir=training_cfg.get("save_dir", "./checkpoints"),
        lr=training_cfg.get("learning_rate", 5e-5),
        weight_decay=training_cfg.get("weight_decay", 0.01),
        max_grad_norm=training_cfg.get("max_grad_norm", 1.0),
        accumulation_steps=training_cfg.get("accumulation_steps", 1),
        total_steps=training_cfg.get("total_steps", 100000),
        warmup_steps=training_cfg.get("warmup_steps", 1000),
        device=None,  # NOTE(review): presumably lets Trainer pick the device - confirm
        mixed_precision=mixed_precision,
    )

    # A CLI epoch count of 0 is falsy and therefore falls back to the config.
    epochs = args.epochs or training_cfg.get("epochs", 1)
    trainer.fit(epochs=epochs)
|
|
|
|
|
# Start training only when this file is executed as a script; importing the
# module must not launch a run.
if __name__ == "__main__":
    main()
|
|
|