%%writefile train_model.py import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" os.environ["CUDA_VISIBLE_DEVICES"] = "0" print("[*] Loading libraries...") import torch import math import numpy as np from datasets import load_dataset from tokenizers import ByteLevelBPETokenizer from transformers import ( LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerFast, Trainer, TrainingArguments, ) from torch.utils.data import Dataset from tqdm import tqdm print("[*] Loading tokenizer...") fast_tokenizer = ByteLevelBPETokenizer( "./custom_llama_tokenizer-vocab.json", "./custom_llama_tokenizer-merges.txt" ) tokenizer = PreTrainedTokenizerFast( tokenizer_object=fast_tokenizer, bos_token="", eos_token="", unk_token="", pad_token="", ) TOKEN_BIN = "/kaggle/working/tokens.bin" TARGET_TOKENS = 1_000_000_000 SEQ_LEN = 256 BATCH_TEXTS = 1000 FLUSH_EVERY = 1_000_000 def build_token_bin(fast_tokenizer, path=TOKEN_BIN, target_tokens=TARGET_TOKENS): if os.path.exists(path) and os.path.getsize(path) >= target_tokens * 2: print(f"[=] Reusing existing token file: {path}") return print(f"[*] Streaming + tokenizing {target_tokens:,} tokens → {path}") mm = np.memmap(path, dtype=np.uint16, mode="w+", shape=(target_tokens,)) dataset = load_dataset( "HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", streaming=True ) written = 0 buf = [] texts = [] pbar = tqdm(total=target_tokens, desc="[*] Gathering tokens", unit="tok") def flush_buf(): nonlocal written, buf if not buf: return False n = min(len(buf), target_tokens - written) mm[written:written + n] = np.asarray(buf[:n], dtype=np.uint16) written += n pbar.update(n) del buf[:n] return written >= target_tokens for example in dataset: texts.append(example["text"]) if len(texts) >= BATCH_TEXTS: encs = fast_tokenizer.encode_batch(texts) texts.clear() for e in encs: buf.extend(e.ids) if len(buf) >= FLUSH_EVERY: if flush_buf(): break if written < target_tokens and texts: encs = fast_tokenizer.encode_batch(texts) for e in encs: buf.extend(e.ids) if written < target_tokens: flush_buf() pbar.close() mm.flush() del mm print(f"[+] Wrote {written:,} tokens to {path} " f"({os.path.getsize(path)/1e6:.1f} MB)") class MemmapDataset(Dataset): def __init__(self, path, total_tokens, seq_len=SEQ_LEN): self.path = path self.seq_len = seq_len self.n_chunks = total_tokens // seq_len self._data = None # lazy open (Multiprocessing-safe) @property def data(self): if self._data is None: self._data = np.memmap( self.path, dtype=np.uint16, mode="r", shape=(self.n_chunks * self.seq_len,) ) return self._data def __len__(self): return self.n_chunks def __getitem__(self, idx): s = idx * self.seq_len arr = np.asarray(self.data[s:s + self.seq_len], dtype=np.int64) ids = torch.from_numpy(arr) return {"input_ids": ids, "labels": ids.clone()} def collate_fn(batch): input_ids = torch.stack([b["input_ids"] for b in batch]) labels = torch.stack([b["labels"] for b in batch]) return {"input_ids": input_ids, "labels": labels} print(f"[*] Preparing {TARGET_TOKENS:,} tokens (streaming, memmap-backed)...") build_token_bin(fast_tokenizer, TOKEN_BIN, TARGET_TOKENS) dataset = MemmapDataset(TOKEN_BIN, TARGET_TOKENS, seq_len=SEQ_LEN) print(f"[+] Dataset ready: {len(dataset):,} chunks of {SEQ_LEN} tokens") print("[*] Setting up model...") config = LlamaConfig( vocab_size=500, hidden_size=96, intermediate_size=192, num_hidden_layers=4, num_attention_heads=4, max_position_embeddings=256, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, ) model = LlamaForCausalLM(config) print(f"[*] Model parameters: {model.num_parameters():,}") print("[*] Defining training arguments...") training_args = TrainingArguments( output_dir="./quark-v2", num_train_epochs=3, per_device_train_batch_size=256, gradient_accumulation_steps=1, save_steps=500, save_total_limit=2, logging_steps=100, weight_decay=0.01, fp16=torch.cuda.is_available(), push_to_hub=False, report_to="none", dataloader_num_workers=2, dataloader_pin_memory=True, learning_rate=6e-4, lr_scheduler_type="cosine", warmup_ratio=0.05, ) trainer = Trainer( model=model, args=training_args, train_dataset=dataset, data_collator=collate_fn, ) print("[*] Starting training...") trainer.train() trainer.save_model("./quark-v2-final") tokenizer.save_pretrained("./quark-v2-final") print("[*] Training finished.")