%%writefile train_model.py
import os
# Both variables are set before `import torch` below; presumably so they are
# already in place when the CUDA runtime/allocator initialises — keep this
# ordering.
# NOTE(review): older PyTorch releases spell the allocator variable
# PYTORCH_CUDA_ALLOC_CONF — confirm the installed version honours this name.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
# Restrict the process to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print("[*] Loading libraries...")
import torch
import math
import numpy as np
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
LlamaConfig,
LlamaForCausalLM,
PreTrainedTokenizerFast,
Trainer,
TrainingArguments,
)
from torch.utils.data import Dataset
from tqdm import tqdm
print("[*] Loading tokenizer...")

# Raw byte-level BPE tokenizer, loaded from the vocab/merges files that are
# expected next to this script. It is also what build_token_bin() encodes with.
fast_tokenizer = ByteLevelBPETokenizer(
    "./custom_llama_tokenizer-vocab.json",
    "./custom_llama_tokenizer-merges.txt",
)

# Transformers-side wrapper; it supplies the special-token ids for the model
# config and is saved alongside the final model.
# NOTE(review): all four special tokens are empty strings. This looks like
# markup such as "<s>" was stripped somewhere — confirm the intended tokens
# against the tokenizer's training script before relying on the derived ids.
# NOTE(review): `tokenizer_object` receives the ByteLevelBPETokenizer wrapper
# itself — confirm the installed transformers version accepts that.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=fast_tokenizer,
    bos_token="",
    eos_token="",
    unk_token="",
    pad_token="",
)
# Data-pipeline settings shared by build_token_bin() and MemmapDataset.
TOKEN_BIN = "/kaggle/working/tokens.bin"  # flat file of uint16 token ids
TARGET_TOKENS = 10**9   # number of token ids to collect
SEQ_LEN = 256           # tokens per training chunk
BATCH_TEXTS = 1_000     # documents handed to each encode_batch() call
FLUSH_EVERY = 10**6     # buffered ids that trigger a copy into the memmap
def build_token_bin(fast_tokenizer, path=TOKEN_BIN, target_tokens=TARGET_TOKENS):
    """Stream FineWeb-Edu, tokenize it and cache the ids in a flat uint16 file.

    Documents are read from the streaming ``HuggingFaceFW/fineweb-edu``
    (``sample-10BT``) train split, encoded in batches of ``BATCH_TEXTS`` and
    copied into a ``np.memmap`` until ``target_tokens`` ids have been written.
    If ``path`` already exists and is large enough, nothing is done.

    ``np.memmap(mode="w+")`` creates the file at its full size up front, so
    the output is built under ``<path>.partial`` and only renamed to ``path``
    once writing has finished. An interrupted run therefore never leaves a
    full-size, partly filled file at ``path`` that the size-based reuse check
    would accept on the next run.

    Args:
        fast_tokenizer: Tokenizer exposing ``encode_batch(texts)`` whose
            results carry an ``ids`` list.
        path: Destination of the token file.
        target_tokens: Number of token ids to store (two bytes each).

    Raises:
        ValueError: If a token id does not fit into ``uint16``.
    """
    # Two bytes per uint16 token id.
    if os.path.exists(path) and os.path.getsize(path) >= target_tokens * 2:
        print(f"[=] Reusing existing token file: {path}")
        return

    print(f"[*] Streaming + tokenizing {target_tokens:,} tokens → {path}")
    tmp_path = f"{path}.partial"
    mm = np.memmap(tmp_path, dtype=np.uint16, mode="w+", shape=(target_tokens,))
    dataset = load_dataset(
        "HuggingFaceFW/fineweb-edu", "sample-10BT",
        split="train", streaming=True,
    )
    max_id = np.iinfo(np.uint16).max
    written = 0
    buf = []    # token ids waiting to be copied into the memmap
    texts = []  # raw documents waiting to be tokenized
    pbar = tqdm(total=target_tokens, desc="[*] Gathering tokens", unit="tok")

    def flush_buf():
        """Copy buffered ids into the memmap; return True once it is full."""
        nonlocal written
        if not buf:
            return False
        # Never write past the end of the file; surplus ids stay in `buf`.
        n = min(len(buf), target_tokens - written)
        chunk = np.asarray(buf[:n], dtype=np.int64)
        if n and int(chunk.max()) > max_id:
            raise ValueError(
                f"token id {int(chunk.max())} does not fit in uint16; "
                "the token file format cannot store this vocabulary"
            )
        mm[written:written + n] = chunk
        written += n
        pbar.update(n)
        del buf[:n]
        return written >= target_tokens

    def encode_pending():
        """Tokenize the queued documents and move their ids into `buf`."""
        for enc in fast_tokenizer.encode_batch(texts):
            buf.extend(enc.ids)
        texts.clear()

    try:
        for example in dataset:
            texts.append(example["text"])
            if len(texts) < BATCH_TEXTS:
                continue
            encode_pending()
            if len(buf) >= FLUSH_EVERY and flush_buf():
                break

        # The stream ended before the target was reached: write whatever is
        # still queued (a partial text batch and/or a partial buffer).
        if written < target_tokens:
            if texts:
                encode_pending()
            flush_buf()
        mm.flush()
    finally:
        pbar.close()
        del mm

    if written < target_tokens:
        print(f"[!] Dataset exhausted after {written:,} of {target_tokens:,} "
              "tokens; the rest of the file is zero-filled")

    # Publish the finished file in one step.
    os.replace(tmp_path, path)
    print(f"[+] Wrote {written:,} tokens to {path} "
          f"({os.path.getsize(path)/1e6:.1f} MB)")
class MemmapDataset(Dataset):
    """Map-style dataset of fixed-length chunks read from a flat uint16 file.

    Chunk ``i`` covers token positions ``[i * seq_len, (i + 1) * seq_len)``.
    Tokens after the last complete chunk are ignored.

    Args:
        path: Token file to read (uint16 ids, as written by build_token_bin).
        total_tokens: Number of tokens in the file to expose.
        seq_len: Tokens per chunk.
    """

    def __init__(self, path, total_tokens, seq_len=SEQ_LEN):
        self.path = path
        self.seq_len = seq_len
        self.n_chunks = total_tokens // seq_len
        # Opened on first access instead of here (original author's note:
        # lazy open keeps the dataset multiprocessing-safe).
        self._data = None

    @property
    def data(self):
        """Read-only memmap over the usable part of the file, opened lazily."""
        if self._data is None:
            self._data = np.memmap(
                self.path, dtype=np.uint16, mode="r",
                shape=(self.n_chunks * self.seq_len,),
            )
        return self._data

    def __len__(self):
        return self.n_chunks

    def __getitem__(self, idx):
        """Return chunk ``idx`` as int64 ``input_ids`` plus a copy as ``labels``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without an IndexError, out-of-range indices would yield short or
        # empty tensors and plain `for item in dataset` would never stop.
        if idx < 0:
            idx = idx + self.n_chunks
        if not 0 <= idx < self.n_chunks:
            raise IndexError(
                f"chunk index out of range for dataset of {self.n_chunks} chunks"
            )
        s = idx * self.seq_len
        arr = np.asarray(self.data[s:s + self.seq_len], dtype=np.int64)
        ids = torch.from_numpy(arr)
        return {"input_ids": ids, "labels": ids.clone()}
def collate_fn(batch):
    """Stack per-sample ``input_ids`` and ``labels`` tensors into batch tensors.

    Args:
        batch: Sequence of dicts, each holding equally shaped tensors under
            the keys ``"input_ids"`` and ``"labels"``.

    Returns:
        Dict with the same two keys, each stacked along a new first dimension.
    """
    return {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("input_ids", "labels")
    }
print(f"[*] Preparing {TARGET_TOKENS:,} tokens (streaming, memmap-backed)...")
# Build (or reuse) the on-disk token file, then expose it as fixed-length
# training chunks.
build_token_bin(fast_tokenizer, path=TOKEN_BIN, target_tokens=TARGET_TOKENS)
dataset = MemmapDataset(
    path=TOKEN_BIN,
    total_tokens=TARGET_TOKENS,
    seq_len=SEQ_LEN,
)
print(f"[+] Dataset ready: {len(dataset):,} chunks of {SEQ_LEN} tokens")
print("[*] Setting up model...")
# Small Llama-style decoder trained from scratch (random initialisation).
# NOTE(review): vocab_size is hard-coded; it has to cover every id stored in
# the token file as well as the special-token ids below — confirm it matches
# the tokenizer's actual vocabulary size.
config = LlamaConfig(
    vocab_size=500,
    hidden_size=96,
    intermediate_size=192,
    num_hidden_layers=4,
    num_attention_heads=4,
    # Tied to the training chunk length instead of repeating the literal, so
    # the two cannot drift apart when SEQ_LEN is changed.
    max_position_embeddings=SEQ_LEN,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = LlamaForCausalLM(config)
print(f"[*] Model parameters: {model.num_parameters():,}")
print("[*] Defining training arguments...")
training_args = TrainingArguments(
    # Output, checkpointing and logging.
    output_dir="./quark-v2",
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    report_to="none",
    push_to_hub=False,
    # Schedule and batch geometry.
    num_train_epochs=3,
    per_device_train_batch_size=256,
    gradient_accumulation_steps=1,
    # Optimisation.
    learning_rate=6e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    # Mixed precision is requested only when a CUDA device is available.
    fp16=torch.cuda.is_available(),
    # Data loading.
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
)
# Wire the model, arguments, memmap-backed dataset and collator together.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=collate_fn,
    train_dataset=dataset,
)

print("[*] Starting training...")
trainer.train()

# Save the final weights and the tokenizer into the same directory.
trainer.save_model("./quark-v2-final")
tokenizer.save_pretrained("./quark-v2-final")
print("[*] Training finished.")