LLM-GO / src /llm_go /scripts /train.py
LesterCerioli's picture
Building first version from Golang LLM
a58ece3
Raw
History Blame Contribute Delete
2.19 kB
"""CLI: run the full training loop."""
import click
import tensorflow as tf
from llm_go.config import ModelConfig, TrainingConfig
from llm_go.data.dataset import GoDataset
from llm_go.training.trainer import Trainer
def _select_strategy(gpus):
    """Translate the ``--gpus`` count into a ``tf.distribute`` strategy.

    Args:
        gpus: ``-1`` to let ``MirroredStrategy`` pick its default devices
            (documented by the CLI as "all GPUs"), ``0`` for CPU only,
            ``1`` for ``/gpu:0`` alone, or ``N > 1`` for ``/gpu:0`` through
            ``/gpu:N-1``.

    Returns:
        The distribution strategy to hand to ``Trainer``.

    Raises:
        click.BadParameter: If ``gpus`` is smaller than ``-1``.
    """
    # range(gpus) is empty for zero or negative counts, so without this check
    # such values would silently produce an empty device list instead of
    # reporting the bad request to the user.
    if gpus < -1:
        raise click.BadParameter(
            f"expected -1, 0 or a positive GPU count, got {gpus}",
            param_hint="--gpus",
        )
    if gpus == -1:
        return tf.distribute.MirroredStrategy()
    if gpus == 0:
        # Explicit CPU-only run rather than an empty mirrored device list.
        return tf.distribute.OneDeviceStrategy("/cpu:0")
    if gpus == 1:
        # A single device needs no cross-replica mirroring.
        return tf.distribute.OneDeviceStrategy("/gpu:0")
    devices = [f"/gpu:{i}" for i in range(gpus)]
    return tf.distribute.MirroredStrategy(devices=devices)


@click.command()
@click.option("--model-size", default="medium", type=click.Choice(["small","medium","large","xl"]))
@click.option("--data-dir", default="data/processed", show_default=True)
@click.option("--ckpt-dir", default="checkpoints", show_default=True)
@click.option("--log-dir", default="logs", show_default=True)
@click.option("--batch-size", default=32, show_default=True)
@click.option("--max-steps", default=100_000, show_default=True)
@click.option("--lr", default=3e-4, show_default=True)
@click.option("--warmup-steps", default=2000, show_default=True)
@click.option("--grad-accum", default=4, show_default=True)
@click.option("--precision", default="bfloat16", type=click.Choice(["float32","float16","bfloat16"]))
@click.option("--gpus", default=-1, help="-1 = all GPUs, 0 = CPU only, N = first N GPUs")
def main(model_size, data_dir, ckpt_dir, log_dir, batch_size, max_steps,
         lr, warmup_steps, grad_accum, precision, gpus):
    """Train GoLLM from scratch."""
    # Each preset name maps to the ModelConfig factory of the same name;
    # click.Choice has already restricted model_size to these four keys.
    presets = {
        "small": ModelConfig.small,
        "medium": ModelConfig.medium,
        "large": ModelConfig.large,
        "xl": ModelConfig.xl,
    }
    mc = presets[model_size]()

    tc = TrainingConfig(
        learning_rate=lr,
        warmup_steps=warmup_steps,
        max_steps=max_steps,
        batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        mixed_precision=precision,
        checkpoint_dir=ckpt_dir,
        log_dir=log_dir,
    )

    # The dataset sequence length follows the chosen model configuration.
    ds = GoDataset(data_dir, seq_len=mc.max_seq_len, batch_size=batch_size)

    strategy = _select_strategy(gpus)

    trainer = Trainer(mc, tc, ds.train(), ds.val(), strategy=strategy)
    trainer.train()