# File size: 3,654 Bytes
# Revisions: 8019d79 107cbdc 8019d79
import os
# Process-wide settings, applied before torch is imported further down.
# NOTE(review): presumably set this early so they are already in place when
# CUDA is initialised -- confirm against the installed PyTorch version.
os.environ.update(
    {
        "PYTORCH_ALLOC_CONF": "expandable_segments:True",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)
print("[*] Loading libraries...")
import torch
import math
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
LlamaConfig,
LlamaForCausalLM,
PreTrainedTokenizerFast,
Trainer,
TrainingArguments,
)
from torch.utils.data import Dataset
from tqdm import tqdm
print("[*] Loading tokenizer...")

# Byte-level BPE tokenizer built from the vocab/merges files in the working
# directory. It is used directly to encode the training corpus.
fast_tokenizer = ByteLevelBPETokenizer(
    "./custom_llama_tokenizer-vocab.json",
    "./custom_llama_tokenizer-merges.txt",
)

# Transformers-side wrapper around the same tokenizer; it supplies the special
# token ids for the model config and is what gets saved next to the model.
_special_tokens = {
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
}
tokenizer = PreTrainedTokenizerFast(tokenizer_object=fast_tokenizer, **_special_tokens)
class ChunkedDataset(Dataset):
    """Fixed-length token chunks gathered from a streamed text corpus.

    Streams the ``sample-10BT`` configuration of ``HuggingFaceFW/fineweb-edu``,
    tokenizes each example's ``"text"`` field, concatenates the token ids of
    consecutive examples (no separator token is inserted between them) and
    cuts the resulting stream into chunks of exactly ``seq_len`` ids. Every
    chunk is kept in memory in ``self.chunks`` as a list of ints.

    Args:
        fast_tokenizer: Object whose ``encode(text)`` returns a value with an
            ``ids`` attribute holding the token ids.
        target_tokens: Collection stops as soon as at least this many tokens
            have been stored. Only whole chunks are counted, so the final
            count is a multiple of ``seq_len`` and can exceed the target by
            less than one chunk. Fewer tokens are stored if the stream ends
            first.
        seq_len: Number of token ids per chunk.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """

    def __init__(self, fast_tokenizer, target_tokens=400_000_000, seq_len=256):
        # A non-positive chunk length would never consume the buffer and the
        # chunking loop would spin forever, so reject it before any I/O.
        if seq_len < 1:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len!r}")

        self.seq_len = seq_len
        self.chunks = []

        dataset = load_dataset(
            "HuggingFaceFW/fineweb-edu", "sample-10BT",
            split="train", streaming=True
        )
        token_lists = (
            fast_tokenizer.encode(example["text"]).ids for example in dataset
        )

        collected = 0
        # The context manager closes the bar on every exit path, including
        # errors raised while streaming or tokenizing.
        with tqdm(total=target_tokens, desc="[*] Gathering tokens", unit="tok") as pbar:
            for chunk in self._iter_chunks(token_lists, seq_len):
                self.chunks.append(chunk)
                collected += seq_len
                pbar.update(seq_len)
                if collected >= target_tokens:
                    break
        print(f"[+] Collected {collected:,} tokens → {len(self.chunks):,} chunks.")

    @staticmethod
    def _iter_chunks(token_lists, seq_len):
        """Yield consecutive ``seq_len``-long lists from concatenated id lists.

        Ids left over after the last full chunk of one list are carried over
        and prepended to the next list; a trailing remainder shorter than
        ``seq_len`` is never yielded.
        """
        buffer = []
        for ids in token_lists:
            buffer.extend(ids)
            usable = len(buffer) - len(buffer) % seq_len
            for start in range(0, usable, seq_len):
                yield buffer[start:start + seq_len]
            # Drop the consumed prefix once per document instead of re-copying
            # the whole remainder after every single chunk.
            del buffer[:usable]

    def __len__(self):
        """Return the number of stored chunks."""
        return len(self.chunks)

    def __getitem__(self, idx):
        """Return chunk ``idx`` as ``{"input_ids", "labels"}`` long tensors.

        ``labels`` is an independent copy of ``input_ids``.
        """
        ids = torch.tensor(self.chunks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone()}
def collate_fn(batch):
    """Stack per-example tensors into one batch dictionary.

    Args:
        batch: Sequence of dicts, each holding ``"input_ids"`` and
            ``"labels"`` tensors of identical shape.

    Returns:
        A dict with the same two keys, where each value is the examples'
        tensors stacked along a new leading dimension.
    """
    return {
        key: torch.stack([example[key] for example in batch])
        for key in ("input_ids", "labels")
    }
# Chunk length and the model's position limit must agree, so both read the
# same constant instead of repeating the literal.
SEQ_LEN = 256
TARGET_TOKENS = 400_000_000

print("[*] Gathering 400 million tokens by streaming dataset...")
dataset = ChunkedDataset(fast_tokenizer, target_tokens=TARGET_TOKENS, seq_len=SEQ_LEN)

print("[*] Setting up model...")
config = LlamaConfig(
    # Size the embedding table from the tokenizer rather than a hard-coded
    # number: every token id and the pad id passed below must be a valid row
    # index, otherwise model construction or the first forward pass fails.
    vocab_size=len(tokenizer),
    hidden_size=96,
    intermediate_size=192,
    num_hidden_layers=4,
    num_attention_heads=4,
    max_position_embeddings=SEQ_LEN,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = LlamaForCausalLM(config)
print(f"[*] Model parameters: {model.num_parameters():,}")

print("[*] Defining training arguments...")
# NOTE(review): with dataloader_num_workers > 0 on platforms that start worker
# processes by re-importing the main module, this unguarded top-level code
# would presumably run again in every worker -- consider an
# `if __name__ == "__main__":` guard if the script must run there.
training_args = TrainingArguments(
    output_dir="./llama-sub-1m",
    num_train_epochs=3,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=2,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    learning_rate=5e-4,
    weight_decay=0.01,
    warmup_steps=500,
    # Mixed precision is only requested when a CUDA device is present.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    report_to="none",
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

print("[*] Starting training...")
trainer.train()

# Save the final weights and the tokenizer side by side in one directory.
trainer.save_model("./llama-sub-1m-final")
tokenizer.save_pretrained("./llama-sub-1m-final")
print("[*] Training finished.")