File size: 1,629 Bytes
9639af0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

import sys

from huggingface_hub import hf_hub_download

from LaughLM.config.loader import load_config
from LaughLM.data.memmap_loader import MemmapDataset
from LaughLM.training.trainer import Trainer

import jax
# NOTE(review): raises the default matmul precision for every jax matmul in
# this process — presumably to trade speed for numerical stability on
# accelerators whose default matmul precision is lower; confirm intent.
jax.config.update("jax_default_matmul_precision", "high")

def main():
    """Download a tokenized dataset shard, load the run config, and train.

    Steps:
      1. Fetch one FineWeb-Edu GPT-2 tokenized shard from the Hugging Face
         Hub (``hf_hub_download`` caches it locally on repeat runs).
      2. Load the YAML run configuration. The path may be overridden by
         passing it as the first command-line argument; with no argument
         the original default ``configs/gpu_test.yaml`` is used.
      3. Build a memmap-backed dataset that yields MICRO batches —
         gradient accumulation is handled inside the Trainer, not here.
      4. Run training.
    """
    import sys  # stdlib; local import keeps the script self-contained

    # Optional CLI override for the config path; the default preserves the
    # previous hard-coded behavior so existing invocations are unchanged.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/gpu_test.yaml"

    # ------------------------------------------------------------
    # Download dataset shard (no-op if already in the local HF cache)
    # ------------------------------------------------------------
    path = hf_hub_download(
        repo_id="LaughTaleAI/fineweb-edu-gpt2-tokenized",
        filename="train_00000.bin",
        repo_type="dataset",
    )

    # ------------------------------------------------------------
    # Load configuration
    # ------------------------------------------------------------
    config = load_config(config_path)

    # ------------------------------------------------------------
    # Dataset
    #
    # IMPORTANT:
    # The dataset should produce MICRO batches.
    # Gradient accumulation is handled inside the Trainer.
    # ------------------------------------------------------------
    dataset = MemmapDataset(
        paths=path,
        seq_len=config.runtime.seq_len,
        batch_size=config.runtime.micro_batch_per_device,
    )

    # ------------------------------------------------------------
    # Trainer + train
    # ------------------------------------------------------------
    trainer = Trainer(config)
    trainer.train(dataset)


if __name__ == "__main__":
    # Script entry point: run the full download → configure → train pipeline.
    main()