# LaughLM/scripts/train_gpu_test.py
# Duplicated from Dhiraj45/LaughLM (commit 9639af0); uploaded by dignity045.
from huggingface_hub import hf_hub_download
from LaughLM.config.loader import load_config
from LaughLM.training.trainer import Trainer
from LaughLM.data.memmap_loader import MemmapDataset
import jax
# Use "high" matmul precision globally: on recent GPUs this selects the
# TF32-style fast path rather than the lowest-precision default.
# NOTE(review): acceptable for a smoke test; confirm it matches the
# precision used in the real training runs.
jax.config.update("jax_default_matmul_precision", "high")
def main():
    """Run a GPU smoke-test training job on one pre-tokenized dataset shard.

    Downloads a single FineWeb-Edu shard from the Hugging Face Hub, loads
    the GPU-test configuration, wraps the shard in a memory-mapped dataset
    producing micro batches, and hands everything to the Trainer.
    """
    # Fetch one tokenized shard from the Hub (cached locally by hf_hub).
    shard_path = hf_hub_download(
        repo_id="LaughTaleAI/fineweb-edu-gpt2-tokenized",
        filename="train_00000.bin",
        repo_type="dataset",
    )

    # Experiment configuration for the smoke test.
    cfg = load_config("configs/gpu_test.yaml")

    # IMPORTANT: the dataset must yield MICRO batches — gradient
    # accumulation is handled inside the Trainer, not by the loader.
    train_data = MemmapDataset(
        paths=shard_path,
        seq_len=cfg.runtime.seq_len,
        batch_size=cfg.runtime.micro_batch_per_device,
    )

    # Build the trainer from the config and kick off training.
    Trainer(cfg).train(train_data)
# Script entry point — only run when executed directly, not on import.
if __name__ == "__main__":
    main()