# Training entry point: download a tokenized dataset shard and launch training.
from huggingface_hub import hf_hub_download
from LaughLM.config.loader import load_config
from LaughLM.training.trainer import Trainer
from LaughLM.data.memmap_loader import MemmapDataset
import jax

# Set the process-wide default matmul precision before any computation runs.
# NOTE(review): "high" presumably trades a little matmul speed for better
# numerical accuracy during training — confirm against the JAX docs for the
# accelerator in use.
jax.config.update("jax_default_matmul_precision", "high")
def main():
    """Download one tokenized shard, build the dataset, and run training.

    The dataset yields MICRO batches; gradient accumulation is handled
    inside the Trainer, so ``batch_size`` here is the per-device micro
    batch size, not the effective global batch.
    """
    # Fetch a single pre-tokenized training shard from the Hub
    # (cached locally by huggingface_hub on repeat runs).
    shard_path = hf_hub_download(
        repo_id="LaughTaleAI/fineweb-edu-gpt2-tokenized",
        filename="train_00000.bin",
        repo_type="dataset",
    )

    # Runtime/model configuration for the GPU smoke test.
    cfg = load_config("configs/gpu_test.yaml")

    # Memory-mapped token stream over the downloaded shard.
    train_data = MemmapDataset(
        paths=shard_path,
        seq_len=cfg.runtime.seq_len,
        batch_size=cfg.runtime.micro_batch_per_device,
    )

    # Trainer owns the model, optimizer, and gradient accumulation.
    trainer = Trainer(cfg)
    trainer.train(train_data)


if __name__ == "__main__":
    main()