from huggingface_hub import hf_hub_download

from LaughLM.config.loader import load_config
from LaughLM.training.trainer import Trainer
from LaughLM.data.memmap_loader import MemmapDataset

import jax

jax.config.update("jax_default_matmul_precision", "high")


def main():
    # ------------------------------------------------------------
    # Download dataset shard
    # ------------------------------------------------------------
    path = hf_hub_download(
        repo_id="LaughTaleAI/fineweb-edu-gpt2-tokenized",
        filename="train_00000.bin",
        repo_type="dataset",
    )

    # ------------------------------------------------------------
    # Load configuration
    # ------------------------------------------------------------
    config = load_config("configs/gpu_test.yaml")

    # ------------------------------------------------------------
    # Dataset
    #
    # IMPORTANT:
    # The dataset should produce MICRO batches.
    # Gradient accumulation is handled inside the Trainer.
    # ------------------------------------------------------------
    dataset = MemmapDataset(
        paths=path,
        seq_len=config.runtime.seq_len,
        batch_size=config.runtime.micro_batch_per_device,
    )

    # ------------------------------------------------------------
    # Trainer
    # ------------------------------------------------------------
    trainer = Trainer(config)

    # ------------------------------------------------------------
    # Train
    # ------------------------------------------------------------
    trainer.train(dataset)


if __name__ == "__main__":
    main()
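

# ------------------------------------------------------------
# Sizing note (illustrative sketch, not part of LaughLM's API)
#
# Because MemmapDataset yields MICRO batches and the Trainer
# accumulates gradients across them, the effective batch per
# optimizer step is the micro batch scaled by device count and
# accumulation steps. The helper below shows that arithmetic;
# the config field `grad_accum_steps` is an assumed name for the
# Trainer's accumulation count, not a confirmed part of the schema.
# ------------------------------------------------------------
def effective_tokens_per_step(config) -> int:
    """Tokens consumed per optimizer step, under the assumptions above.

    Example: micro_batch_per_device=8, 1 device, grad_accum_steps=16,
    seq_len=1024 -> 8 * 1 * 16 * 1024 = 131,072 tokens per step.
    """
    return (
        config.runtime.micro_batch_per_device  # per-device micro batch
        * jax.device_count()                   # data-parallel replicas
        * config.runtime.grad_accum_steps      # hypothetical field name
        * config.runtime.seq_len               # tokens per sequence
    )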