import os

# Configure the CUDA allocator and visible device before torch is imported.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

print("[*] Loading libraries...")
import torch
import math
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)
from torch.utils.data import Dataset
from tqdm import tqdm

print("[*] Loading tokenizer...")
fast_tokenizer = ByteLevelBPETokenizer(
    "./custom_llama_tokenizer-vocab.json",
    "./custom_llama_tokenizer-merges.txt",
)
# Special-token strings were missing in the source; standard Llama-style
# markers are assumed here to match the custom tokenizer's training.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=fast_tokenizer,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
)


class ChunkedDataset(Dataset):
    """Streams the corpus, tokenizes it, and packs ids into fixed-length chunks."""

    def __init__(self, fast_tokenizer, target_tokens=400_000_000, seq_len=256):
        self.seq_len = seq_len
        self.chunks = []
        dataset = load_dataset(
            "HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", streaming=True
        )
        buffer = []
        collected = 0
        pbar = tqdm(total=target_tokens, desc="[*] Gathering tokens", unit="tok")
        for example in dataset:
            ids = fast_tokenizer.encode(example["text"]).ids
            buffer.extend(ids)
            # Slice the running buffer into seq_len-sized chunks.
            while len(buffer) >= seq_len:
                chunk = buffer[:seq_len]
                buffer = buffer[seq_len:]
                self.chunks.append(chunk)
                collected += seq_len
                pbar.update(seq_len)
                if collected >= target_tokens:
                    pbar.close()
                    print(f"[+] Collected {collected:,} tokens → {len(self.chunks):,} chunks.")
                    return
        # Reached only if the stream is exhausted before target_tokens.
        pbar.close()
        print(f"[+] Collected {collected:,} tokens → {len(self.chunks):,} chunks.")

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        ids = torch.tensor(self.chunks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone()}


def collate_fn(batch):
    input_ids = torch.stack([b["input_ids"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])
    return {"input_ids": input_ids, "labels": labels}


print("[*] Gathering 400 million tokens by streaming dataset...")
dataset = ChunkedDataset(fast_tokenizer, target_tokens=400_000_000, seq_len=256)

print("[*] Setting up model...")
config = LlamaConfig(
    vocab_size=500,
    hidden_size=96,
    intermediate_size=192,
    num_hidden_layers=4,
    num_attention_heads=4,
    max_position_embeddings=256,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = LlamaForCausalLM(config)
print(f"[*] Model parameters: {model.num_parameters():,}")

print("[*] Defining training arguments...")
training_args = TrainingArguments(
    output_dir="./llama-sub-1m",
    num_train_epochs=3,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=2,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    learning_rate=5e-4,
    weight_decay=0.01,
    warmup_steps=500,
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    report_to="none",
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

print("[*] Starting training...")
trainer.train()

trainer.save_model("./llama-sub-1m-final")
tokenizer.save_pretrained("./llama-sub-1m-final")
print("[*] Training finished.")
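
# --- Optional sanity check (added sketch, not part of the original run) ---
# Reloads the saved artifacts from ./llama-sub-1m-final and generates a short
# continuation to confirm the checkpoint and tokenizer round-trip correctly.
# The prompt text and sampling settings below are illustrative assumptions.
print("[*] Running generation sanity check...")
check_tokenizer = PreTrainedTokenizerFast.from_pretrained("./llama-sub-1m-final")
check_model = LlamaForCausalLM.from_pretrained("./llama-sub-1m-final")
check_model.eval()

inputs = check_tokenizer("The ocean is", return_tensors="pt")
with torch.no_grad():
    out = check_model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=True,
        temperature=0.8,
    )
print(check_tokenizer.decode(out[0], skip_special_tokens=True))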