# Quark-0.5M / train_model.py

import os

# Reduce CUDA allocator fragmentation (older PyTorch releases use
# PYTORCH_CUDA_ALLOC_CONF for this setting) and pin training to the first GPU.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

print("[*] Loading libraries...")

import math

import torch
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)
print("[*] Loading tokenizer...")
fast_tokenizer = ByteLevelBPETokenizer(
"./custom_llama_tokenizer-vocab.json",
"./custom_llama_tokenizer-merges.txt"
)
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=fast_tokenizer,
bos_token="<s>",
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
)
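
# Optional sanity check: round-trip a sample string to confirm the vocab and
# merges files loaded correctly (the exact output depends on the custom vocab).
_ids = tokenizer("hello world")["input_ids"]
print(f"[*] Tokenizer round-trip: {tokenizer.decode(_ids)!r}")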


class ChunkedDataset(Dataset):
    """Streams fineweb-edu, tokenizes it, and packs the ids into fixed-length
    chunks. Note: all chunks are held in RAM, which at 400M tokens can take
    several GB."""

    def __init__(self, fast_tokenizer, target_tokens=400_000_000, seq_len=256):
        self.seq_len = seq_len
        self.chunks = []
        dataset = load_dataset(
            "HuggingFaceFW/fineweb-edu", "sample-10BT",
            split="train", streaming=True,
        )
        buffer = []
        collected = 0
        pbar = tqdm(total=target_tokens, desc="[*] Gathering tokens", unit="tok")
        for example in dataset:
            ids = fast_tokenizer.encode(example["text"]).ids
            buffer.extend(ids)
            # Slice the running buffer into seq_len-sized chunks; leftover ids
            # carry over to the next document.
            while len(buffer) >= seq_len:
                chunk = buffer[:seq_len]
                buffer = buffer[seq_len:]
                self.chunks.append(chunk)
                collected += seq_len
                pbar.update(seq_len)
                if collected >= target_tokens:
                    pbar.close()
                    print(f"[+] Collected {collected:,} tokens → {len(self.chunks):,} chunks.")
                    return
        pbar.close()
        print(f"[+] Collected {collected:,} tokens → {len(self.chunks):,} chunks.")

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        ids = torch.tensor(self.chunks[idx], dtype=torch.long)
        # Causal LM training: labels equal the inputs; the model shifts them
        # internally when computing the loss.
        return {"input_ids": ids, "labels": ids.clone()}


def collate_fn(batch):
    # Every chunk already has length seq_len, so a plain stack suffices (no padding).
    input_ids = torch.stack([b["input_ids"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])
    return {"input_ids": input_ids, "labels": labels}
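
# Quick shape check (illustrative): two fixed-length items collate to (2, 256) tensors.
_item = {"input_ids": torch.zeros(256, dtype=torch.long), "labels": torch.zeros(256, dtype=torch.long)}
assert collate_fn([_item, _item])["input_ids"].shape == (2, 256)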
print("[*] Gathering 400 million tokens by streaming dataset...")
dataset = ChunkedDataset(fast_tokenizer, target_tokens=400_000_000, seq_len=256)
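
# Rough schedule estimate (assumes a single GPU): 400M tokens / 256 tokens per
# chunk ≈ 1.56M chunks; with a per-device batch of 64 and 2 accumulation steps,
# one epoch is about 12.2k optimizer steps.
effective_batch = 64 * 2
steps_per_epoch = math.ceil(len(dataset) / effective_batch)
print(f"[*] ≈{steps_per_epoch:,} optimizer steps per epoch")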
print("[*] Setting up model...")
config = LlamaConfig(
vocab_size=500,
hidden_size=96,
intermediate_size=192,
num_hidden_layers=4,
num_attention_heads=4,
max_position_embeddings=256,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
model = LlamaForCausalLM(config)
print(f"[*] Model parameters: {model.num_parameters():,}")
print("[*] Defining training arguments...")
training_args = TrainingArguments(
output_dir="./llama-sub-1m",
num_train_epochs=3,
per_device_train_batch_size=64,
gradient_accumulation_steps=2,
save_steps=500,
save_total_limit=2,
logging_steps=100,
learning_rate=5e-4,
weight_decay=0.01,
warmup_steps=500,
fp16=torch.cuda.is_available(),
push_to_hub=False,
report_to="none",
dataloader_num_workers=2,
dataloader_pin_memory=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)
print("[*] Starting training...")
trainer.train()
trainer.save_model("./llama-sub-1m-final")
tokenizer.save_pretrained("./llama-sub-1m-final")
print("[*] Training finished.")