import os

# Set before importing torch so allocator behaviour and device visibility take effect:
# expandable_segments reduces CUDA allocator fragmentation; only GPU 0 is exposed.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

| print("[*] Loading libraries...") |
| import torch |
| import math |
| from datasets import load_dataset |
| from tokenizers import ByteLevelBPETokenizer |
| from transformers import ( |
| LlamaConfig, |
| LlamaForCausalLM, |
| PreTrainedTokenizerFast, |
| Trainer, |
| TrainingArguments, |
| ) |
| from torch.utils.data import Dataset |
| from tqdm import tqdm |
|
|
| print("[*] Loading tokenizer...") |
| fast_tokenizer = ByteLevelBPETokenizer( |
| "./custom_llama_tokenizer-vocab.json", |
| "./custom_llama_tokenizer-merges.txt" |
| ) |
| tokenizer = PreTrainedTokenizerFast( |
| tokenizer_object=fast_tokenizer, |
| bos_token="<s>", |
| eos_token="</s>", |
| unk_token="<unk>", |
| pad_token="<pad>", |
| ) |
|
|
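# Quick sanity check (assumption: the BPE tokenizer was trained with a vocabulary of
# at most 500 tokens, matching vocab_size in the LlamaConfig below; adjust or remove
# this if your tokenizer files differ).
assert fast_tokenizer.get_vocab_size() <= 500, (
    f"Tokenizer vocab size {fast_tokenizer.get_vocab_size()} exceeds the model's vocab_size of 500"
)
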
class ChunkedDataset(Dataset):
    """Streams the corpus, tokenizes it, and packs the token ids into fixed-length chunks."""

    def __init__(self, fast_tokenizer, target_tokens=400_000_000, seq_len=256):
        self.seq_len = seq_len
        self.chunks = []

        # Stream the corpus so the raw text never has to fit in memory.
        dataset = load_dataset(
            "HuggingFaceFW/fineweb-edu", "sample-10BT",
            split="train", streaming=True
        )

        buffer = []
        collected = 0
        pbar = tqdm(total=target_tokens, desc="[*] Gathering tokens", unit="tok")

        for example in dataset:
            ids = fast_tokenizer.encode(example["text"]).ids
            buffer.extend(ids)

            # Carve the running token buffer into fixed-length chunks.
            while len(buffer) >= seq_len:
                chunk = buffer[:seq_len]
                buffer = buffer[seq_len:]
                self.chunks.append(chunk)
                collected += seq_len
                pbar.update(seq_len)

            # Stop once enough tokens have been gathered.
            if collected >= target_tokens:
                pbar.close()
                print(f"[+] Collected {collected:,} tokens → {len(self.chunks):,} chunks.")
                return

        pbar.close()
        print(f"[+] Collected {collected:,} tokens → {len(self.chunks):,} chunks.")

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        ids = torch.tensor(self.chunks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone()}


def collate_fn(batch):
    input_ids = torch.stack([b["input_ids"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])
    return {"input_ids": input_ids, "labels": labels}

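# Note: because every chunk has exactly seq_len tokens, this plain stacking collator
# should behave the same as transformers.default_data_collator; it is kept explicit
# here for clarity.
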
| print("[*] Gathering 400 million tokens by streaming dataset...") |
| dataset = ChunkedDataset(fast_tokenizer, target_tokens=400_000_000, seq_len=256) |
|
|
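# Optional, purely illustrative: decode the start of the first chunk to eyeball that
# tokenization round-trips to readable text before spending time on training.
sample_ids = dataset[0]["input_ids"][:48].tolist()
print(f"[*] Sample chunk preview: {tokenizer.decode(sample_ids)!r}")
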
| print("[*] Setting up model...") |
| config = LlamaConfig( |
| vocab_size=500, |
| hidden_size=96, |
| intermediate_size=192, |
| num_hidden_layers=4, |
| num_attention_heads=4, |
| max_position_embeddings=256, |
| pad_token_id=tokenizer.pad_token_id, |
| bos_token_id=tokenizer.bos_token_id, |
| eos_token_id=tokenizer.eos_token_id, |
| ) |
| model = LlamaForCausalLM(config) |
| print(f"[*] Model parameters: {model.num_parameters():,}") |
|
|
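# Optional smoke test (illustrative, safe to remove): one forward pass on random token
# ids to confirm the config is wired up. An untrained model's loss should sit near
# ln(vocab_size) = ln(500) ≈ 6.2.
with torch.no_grad():
    dummy = torch.randint(0, config.vocab_size, (2, 16))
    smoke = model(input_ids=dummy, labels=dummy)
print(f"[*] Untrained loss: {smoke.loss.item():.3f} (expected ≈ {math.log(config.vocab_size):.3f})")
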
| print("[*] Defining training arguments...") |
| training_args = TrainingArguments( |
| output_dir="./llama-sub-1m", |
| num_train_epochs=3, |
| per_device_train_batch_size=64, |
| gradient_accumulation_steps=2, |
| save_steps=500, |
| save_total_limit=2, |
| logging_steps=100, |
| learning_rate=5e-4, |
| weight_decay=0.01, |
| warmup_steps=500, |
| fp16=torch.cuda.is_available(), |
| push_to_hub=False, |
| report_to="none", |
| dataloader_num_workers=2, |
| dataloader_pin_memory=True, |
| ) |
|
|
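# With these settings a single GPU sees 64 × 2 = 128 sequences (≈ 32,768 tokens) per
# optimizer step, so one epoch over the ~1.56M chunks gathered above works out to
# roughly 12,200 optimizer steps.
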
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

| print("[*] Starting training...") |
| trainer.train() |
| trainer.save_model("./llama-sub-1m-final") |
| tokenizer.save_pretrained("./llama-sub-1m-final") |
| print("[*] Training finished.") |