"""Supervised fine-tuning entry point for the tiny model.

Reads ``config/train_tiny.yaml`` (relative to this file's parent directory),
builds the tokenizer, model and data, trains with AdamW plus a linear-warmup /
cosine LR schedule, and writes checkpoints under ``training.output_dir``.
Checkpoints are optionally uploaded to the Hugging Face Hub when
``training.hf_repo_id`` and the ``HF_TOKEN`` environment variable are set.
"""
import itertools
import json
import math
import os
import shutil
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset, IterableDataset

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from scripts.model_tiny import TinyModel, restore_from_v2  # noqa: E402  (needs the sys.path entry above)

# Hub repo hosting the starting checkpoints fetched by ``_build_model``.
_HF_CHECKPOINT_REPO = "samcheng0/lumia-tiny"

# Jinja chat template producing the same role markers ``_format_item`` emits.
_CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'system' %}<|system|>\n{{ message['content'] }}\n{% elif message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}"

# Key prefix that torch.compile's wrapper module adds to every state_dict entry.
_COMPILE_PREFIX = "_orig_mod."


def _format_item(item, tokenizer):
    """Render one dataset record as a single training string.

    Record shapes are checked in this order:

    * ``{"text": ...}`` -- returned verbatim;
    * ``{"messages": [...]}`` -- rendered with the tokenizer's chat template;
    * otherwise Alpaca-style ``system`` / ``instruction`` / ``input`` / ``output``
      fields (each optional, default ``""``) joined with the ``<|system|>``,
      ``<|user|>`` and ``<|assistant|>`` markers.
    """
    if "text" in item:
        return item["text"]
    if "messages" in item:
        return tokenizer.apply_chat_template(item["messages"], tokenize=False)
    system = item.get("system", "")
    inp = item.get("input", "")
    instruction = item.get("instruction", "")
    output = item.get("output", "")
    user_msg = instruction + ("\n" + inp if inp else "")
    return f"<|system|>\n{system}\n<|user|>\n{user_msg}\n<|assistant|>\n{output}"


def _encode(text, tokenizer, max_seq_len):
    """Tokenize *text*, keep at most *max_seq_len* ids, and return a 1-D long tensor."""
    ids = tokenizer.encode(text)[:max_seq_len]
    return torch.tensor(ids, dtype=torch.long)


class StreamingSFTDataset(IterableDataset):
    """Stream, format and tokenize records from a Hugging Face dataset on the fly.

    Args:
        hf_repo: dataset repo id passed to ``datasets.load_dataset``.
        tokenizer: tokenizer used by ``_format_item`` / ``_encode``.
        max_seq_len: truncation length in tokens.
        split: dataset split to stream.
        hf_name: optional dataset config name.
        skip: number of leading records dropped on every pass. It keeps
            held-out eval records, taken from the head of the same split, out
            of training.
    """

    def __init__(self, hf_repo, tokenizer, max_seq_len=2048, split="train", hf_name=None, skip=0):
        self.hf_repo = hf_repo
        self.hf_name = hf_name
        self.split = split
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.skip = skip

    def __iter__(self):
        from datasets import load_dataset

        ds = load_dataset(self.hf_repo, name=self.hf_name, split=self.split, streaming=True)
        for item in itertools.islice(ds, max(0, self.skip), None):
            yield _encode(_format_item(item, self.tokenizer), self.tokenizer, self.max_seq_len)


class ListDataset(Dataset):
    """Map-style dataset over an in-memory list of pre-encoded samples."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def make_eval_dataset(hf_repo, tokenizer, max_seq_len, num_eval=500, hf_name=None, split="train"):
    """Materialise the first *num_eval* records of a streamed HF split as a ``ListDataset``."""
    from datasets import load_dataset

    ds = load_dataset(hf_repo, name=hf_name, split=split, streaming=True)
    head = itertools.islice(ds, max(0, num_eval))
    return ListDataset([_encode(_format_item(item, tokenizer), tokenizer, max_seq_len) for item in head])


def collate_fn(batch):
    """Right-pad a list of 1-D id tensors into ``(input_ids, labels)``.

    Inputs are padded with id 0. Labels copy the ids and hold -100 on padding,
    so ``compute_metrics`` ignores those positions. The model's own loss
    presumably ignores -100 as well; verify in ``scripts.model_tiny``.
    """
    max_len = max(len(seq) for seq in batch)
    input_ids = torch.full((len(batch), max_len), fill_value=0, dtype=torch.long)
    labels = torch.full((len(batch), max_len), fill_value=-100, dtype=torch.long)
    for row, seq in enumerate(batch):
        length = len(seq)
        input_ids[row, :length] = seq
        labels[row, :length] = seq
    return input_ids, labels


@torch.no_grad()
def compute_metrics(logits, labels, ignore_index=-100):
    """Compute next-token metrics for one batch.

    Args:
        logits: ``(B, T, V)`` scores; position ``t`` is scored against ``labels[:, t + 1]``.
        labels: ``(B, T)`` target ids, ``ignore_index`` where no target applies.
        ignore_index: label value excluded from every metric.

    Returns:
        ``(loss, ppl, entropy, acc)`` as Python floats. Each is averaged over
        the shifted positions whose label is not ``ignore_index``. If there
        are no such positions, loss/ppl/entropy are ``nan`` and acc is ``0.0``.
    """
    vocab = logits.size(-1)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    mask = shift_labels != ignore_index
    if not mask.any():
        nan = float("nan")
        return nan, nan, nan, 0.0

    loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1),
                           ignore_index=ignore_index, reduction="mean")
    ppl = torch.exp(loss.double())
    # Predictive entropy, averaged only over real targets (not padding).
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_entropy = -(log_probs.exp() * log_probs).sum(-1)
    entropy = token_entropy[mask].mean()
    preds = shift_logits.argmax(dim=-1)
    acc = (preds == shift_labels)[mask].float().mean()
    return loss.item(), ppl.item(), entropy.item(), acc.item()


@torch.no_grad()
def evaluate(model, loader, device):
    """Run *model* over *loader* and return token-weighted evaluation metrics.

    Returns:
        A dict with these keys:

        * ``loss``, ``entropy``, ``acc``: averaged over all predicted
          (shifted, non-ignored) positions;
        * ``ppl``: ``exp(loss)``;
        * ``tokens``: count of non-ignored label positions.

        With no predicted positions, loss/ppl/entropy are ``nan`` and acc is 0.

    The model's previous train/eval mode is restored before returning. This
    function is called mid-epoch, so leaving the model in eval mode would
    change the rest of training.
    """
    was_training = model.training
    model.eval()
    loss_sum = entropy_sum = acc_sum = 0.0
    n_pred = 0
    total_tokens = 0
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x, labels=y)
            total_tokens += (y != -100).sum().item()
            # Number of positions compute_metrics averages over (labels shifted by one).
            n = (y[..., 1:] != -100).sum().item()
            if n == 0:
                continue
            loss_val, _, ent_val, acc_val = compute_metrics(logits, y)
            loss_sum += loss_val * n
            entropy_sum += ent_val * n
            acc_sum += acc_val * n
            n_pred += n
    finally:
        model.train(was_training)

    if n_pred == 0:
        nan = float("nan")
        return {"loss": nan, "ppl": nan, "entropy": nan, "acc": 0.0, "tokens": total_tokens}
    mean_loss = loss_sum / n_pred
    try:
        ppl = math.exp(mean_loss)
    except OverflowError:
        ppl = float("inf")
    return {
        "loss": mean_loss,
        "ppl": ppl,
        "entropy": entropy_sum / n_pred,
        "acc": acc_sum / n_pred,
        "tokens": total_tokens,
    }


def upload_to_hf(local_dir, repo_id, token, files):
    """Upload each of *files* (paths relative to *local_dir*) to the Hub model repo *repo_id*.

    Each file keeps its relative path in the repo. Files that do not exist
    locally are skipped.
    """
    from huggingface_hub import HfApi

    api = HfApi()
    for f in files:
        path = os.path.join(local_dir, f)
        if os.path.exists(path):
            api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=repo_id, token=token)
            print(f" Uploaded {f}")


def _unwrap_model(model):
    """Return the original module when *model* is a ``torch.compile`` wrapper, else *model*."""
    return getattr(model, "_orig_mod", model)


def _strip_compile_prefix(state_dict):
    """Remove the ``_orig_mod.`` key prefix left by saving a compiled model's state_dict."""
    if not any(k.startswith(_COMPILE_PREFIX) for k in state_dict):
        return state_dict
    return {(k[len(_COMPILE_PREFIX):] if k.startswith(_COMPILE_PREFIX) else k): v
            for k, v in state_dict.items()}


def _save_checkpoint(path, model, optimizer, scheduler, **state):
    """Atomically save model/optimizer/scheduler state plus bookkeeping *state* to *path*.

    The file is written to ``<path>.tmp`` first and then moved into place. A
    crash mid-save therefore cannot leave a truncated checkpoint that breaks
    the next resume.
    """
    tmp_path = path + ".tmp"
    torch.save({
        "model": _unwrap_model(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        **state,
    }, tmp_path)
    os.replace(tmp_path, path)


def _load_hf_token_from_kaggle():
    """Set ``HF_TOKEN`` from Kaggle secrets if it is not already in the environment.

    This is best effort: outside Kaggle, or without that secret, it does nothing.
    """
    if os.environ.get("HF_TOKEN"):
        return
    try:
        from kaggle_secrets import UserSecretsClient

        secret = UserSecretsClient().get_secret("HF_TOKEN")
    except Exception:  # not running on Kaggle / secret missing -- deliberately ignored
        return
    if secret:
        os.environ["HF_TOKEN"] = secret


def _build_tokenizer():
    """Load ``tokenizer/tokenizer.json`` as a fast HF tokenizer with the chat template set."""
    from tokenizers import Tokenizer as Tk
    from transformers import PreTrainedTokenizerFast

    tok_path = os.path.join(os.path.dirname(__file__), "..", "tokenizer", "tokenizer.json")
    tok = PreTrainedTokenizerFast(tokenizer_object=Tk.from_file(tok_path))
    # NOTE(review): every special-token string below is empty. They look like
    # placeholder tags (e.g. "<pad>") that were lost somewhere; confirm
    # against tokenizer.json. Left unchanged here.
    tok.add_special_tokens({"pad_token": "", "bos_token": "", "eos_token": "", "unk_token": ""})
    tok.chat_template = _CHAT_TEMPLATE
    return tok


def _download_from_hub(filename, dest):
    """Download *filename* from ``_HF_CHECKPOINT_REPO`` and copy it to local path *dest*."""
    from huggingface_hub import hf_hub_download

    cached = hf_hub_download(_HF_CHECKPOINT_REPO, filename, repo_type="model")
    dest_dir = os.path.dirname(dest)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    shutil.copy(cached, dest)


def _build_model(cft_enabled, cft_cfg, vocab_size, max_seq):
    """Create the model on CPU.

    CFT mode: build a fresh ``TinyModel``, optionally load
    ``cft.resume_checkpoint`` (downloaded from the Hub if missing), and
    optionally re-initialise the embeddings.
    Otherwise: restore the V2 ``checkpoint.pt``, downloading it if absent.
    """
    if not cft_enabled:
        if not os.path.exists("checkpoint.pt"):
            print(f"[Restore] Downloading checkpoint.pt from {_HF_CHECKPOINT_REPO} ...")
            _download_from_hub("checkpoint.pt", "checkpoint.pt")
            print("[Restore] Downloaded to checkpoint.pt")
        return restore_from_v2("checkpoint.pt")

    resume_ckpt = cft_cfg.get("resume_checkpoint", "")
    if resume_ckpt and not os.path.exists(resume_ckpt):
        print(f"[CFT] Downloading {resume_ckpt} from HF ...")
        _download_from_hub(resume_ckpt, resume_ckpt)
        print(f"[CFT] Downloaded to {resume_ckpt}")

    model = TinyModel(vocab_size=vocab_size, hidden=128, code_dim=96, num_layers=6, num_heads=8,
                      num_kv_heads=4, max_seq_len=max_seq, tie_weights=True)
    model.reset_weights()
    if resume_ckpt and os.path.exists(resume_ckpt):
        loaded = torch.load(resume_ckpt, map_location="cpu", weights_only=True)
        sd = loaded["model"] if "model" in loaded else loaded
        # strict=False would silently skip every key of a checkpoint saved
        # from a compiled model, so strip that prefix first.
        missing, unexpected = model.load_state_dict(_strip_compile_prefix(sd), strict=False)
        if missing:
            print(f" Missing keys: {len(missing)} (e.g. {missing[:3]})")
        if unexpected:
            print(f" Unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})")
        print(f"[CFT] Loaded checkpoint from {resume_ckpt}")

    if cft_cfg.get("reset_embeddings", False):
        nn.init.normal_(model.token_embed.weight, std=0.02)
        print("[CFT] Reset token embeddings (incompatible vocab)")
        if model.lm_head.weight is not model.token_embed.weight:
            nn.init.normal_(model.lm_head.weight, std=0.02)
    return model


def _build_datasets(d_cfg, tok, max_seq):
    """Return ``(train_ds, eval_ds, is_streaming)`` from the ``data`` config section.

    Streaming mode (``hf_repo`` set): the eval set is the first ``hf_num_eval``
    records of ``hf_eval_split`` (default ``"train"``). When training streams
    the same split, those records are skipped so eval data is never trained on.
    File mode: ``train_file`` is read as JSONL and split by ``eval_split_ratio``.
    """
    hf_repo = d_cfg.get("hf_repo")
    if hf_repo:
        hf_name = d_cfg.get("hf_repo_name")
        train_split = d_cfg.get("hf_split", "train")
        eval_split = d_cfg.get("hf_eval_split", "train")
        eval_ds = make_eval_dataset(hf_repo, tok, max_seq, num_eval=d_cfg.get("hf_num_eval", 500),
                                    hf_name=hf_name, split=eval_split)
        skip = len(eval_ds) if eval_split == train_split else 0
        train_ds = StreamingSFTDataset(hf_repo, tok, max_seq, split=train_split, hf_name=hf_name, skip=skip)
        print(f"Train: streaming from {hf_repo} Eval: {len(eval_ds.samples)} samples")
        return train_ds, eval_ds, True

    path = d_cfg["train_file"]
    with open(path, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    split = int(len(records) * (1 - d_cfg.get("eval_split_ratio", 0.1)))
    train_ds = ListDataset([_encode(_format_item(x, tok), tok, max_seq) for x in records[:split]])
    eval_ds = ListDataset([_encode(_format_item(x, tok), tok, max_seq) for x in records[split:]])
    print(f"Train: {len(train_ds.samples)} Eval: {len(eval_ds.samples)}")
    return train_ds, eval_ds, False


def _param_groups(model, weight_decay):
    """Split trainable params into weight-decayed and non-decayed (norm / LoRA) AdamW groups."""
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "norm" in name or "ln_" in name or "lora_" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    if not decay_params and not no_decay_params:
        print("[WARNING] No trainable parameters found!")
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


def train():
    """Run supervised fine-tuning as configured by ``config/train_tiny.yaml``.

    Config sections read:

    * ``training``: optimiser, batching, logging/saving cadence, output dir,
      optional Hub upload;
    * ``data``: ``hf_repo`` for streaming, otherwise a JSONL ``train_file``;
    * ``qlora`` and ``cft``: optional.

    Resumes from ``<output_dir>/checkpoint.pt`` if it exists. A checkpoint
    written mid-epoch resumes that epoch from its start, because the data
    iterator position is not saved. The global step counter, and with it
    ``max_steps``, continues from the saved value.

    Returns:
        Path of the ``final`` directory containing ``model.pt``.
    """
    config_path = os.path.join(os.path.dirname(__file__), "..", "config", "train_tiny.yaml")
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    t_cfg = cfg["training"]
    d_cfg = cfg["data"]

    device = "cuda" if torch.cuda.is_available() and not t_cfg.get("use_cpu", False) else "cpu"
    print(f"Device: {device}")

    _load_hf_token_from_kaggle()

    # Use actual vocab size from tokenizer
    tok = _build_tokenizer()
    vocab_size = tok.vocab_size
    if len(tok) != vocab_size:
        # ``vocab_size`` excludes tokens added via add_special_tokens. Any such
        # id emitted during encoding would index past the model's embedding table.
        print(f"[WARNING] Tokenizer has {len(tok)} tokens incl. added ones but vocab_size={vocab_size}")
    max_seq = d_cfg.get("max_seq_length", 2048)

    # ── QLoRA / CFT config ──────────────────────────────────────────────────
    qlora_cfg = cfg.get("qlora", {})
    cft_cfg = cfg.get("cft", {})
    qlora_enabled = qlora_cfg.get("enabled", False)
    cft_enabled = cft_cfg.get("enabled", False)

    # ── Create model ────────────────────────────────────────────────────────
    model = _build_model(cft_enabled, cft_cfg, vocab_size, max_seq).to(device)

    # ── Apply QLoRA ─────────────────────────────────────────────────────────
    if qlora_enabled:
        from scripts.model_tiny import apply_qlora

        model = apply_qlora(model, r=qlora_cfg.get("r", 8), alpha=qlora_cfg.get("alpha", 16),
                            dropout=qlora_cfg.get("dropout", 0.0))
        # Unfreeze embeddings if they were reset (new vocab needs training)
        if cft_cfg.get("reset_embeddings", False):
            for name, param in model.named_parameters():
                if "token_embed" in name or "lm_head" in name:
                    param.requires_grad = True

    # ── torch.compile ───────────────────────────────────────────────────────
    if t_cfg.get("compile", False):
        try:
            model = torch.compile(model, dynamic=True)
            print("[Compile] torch.compile enabled (dynamic=True)")
        except Exception as e:
            # torch.compile compiles lazily on first call, so many compile
            # errors surface at the first forward pass rather than here.
            print(f"[Compile] Failed: {e} — continuing without compile")
    # Save/load through the original module so checkpoint keys have no "_orig_mod." prefix.
    raw_model = _unwrap_model(model)

    # ── Data ────────────────────────────────────────────────────────────────
    train_ds, eval_ds, is_streaming = _build_datasets(d_cfg, tok, max_seq)

    bs = t_cfg.get("per_device_train_batch_size", 8)
    eval_bs = t_cfg.get("per_device_eval_batch_size", 8)
    ga_steps = t_cfg.get("gradient_accumulation_steps", 4)
    lr = t_cfg.get("learning_rate", 3e-4)
    epochs = t_cfg.get("num_train_epochs", 1)
    max_grad_norm = t_cfg.get("max_grad_norm", 1.0)
    log_steps = t_cfg.get("logging_steps", 5)
    save_steps = t_cfg.get("save_steps", 200)
    output_dir = t_cfg.get("output_dir", "outputs/tiny-sft")
    max_steps_cfg = t_cfg.get("max_steps", -1)

    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=not is_streaming,
                              collate_fn=collate_fn, num_workers=0)
    eval_loader = DataLoader(eval_ds, batch_size=eval_bs, shuffle=False,
                             collate_fn=collate_fn, num_workers=0)

    # HF upload config
    hf_repo_id = t_cfg.get("hf_repo_id", "")
    hf_token = os.environ.get("HF_TOKEN", "")

    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # ── Optimiser & LR schedule ─────────────────────────────────────────────
    optimizer = AdamW(_param_groups(model, t_cfg.get("weight_decay", 0.1)), lr=lr, betas=(0.9, 0.95))

    warmup_ratio = t_cfg.get("warmup_ratio", 0.1)
    if is_streaming:
        total_model_steps = max_steps_cfg if max_steps_cfg > 0 else 50000
    else:
        # DataLoader keeps the last partial batch, hence ceil.
        total_model_steps = math.ceil(len(train_ds) / bs) * epochs
        if max_steps_cfg > 0:
            total_model_steps = min(total_model_steps, max_steps_cfg)
    total_opt_steps = max(1, total_model_steps // ga_steps)
    warmup_opt_steps = int(total_opt_steps * warmup_ratio)

    def _lr_lambda(current_step):
        # Linear warmup, then cosine decay to 0. Progress is clamped so the
        # LR stays at 0 instead of rising again past total_opt_steps.
        if current_step < warmup_opt_steps:
            return float(current_step) / max(1, warmup_opt_steps)
        progress = float(current_step - warmup_opt_steps) / max(1, total_opt_steps - warmup_opt_steps)
        progress = min(1.0, progress)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = LambdaLR(optimizer, lr_lambda=_lr_lambda)

    # ── Resume ──────────────────────────────────────────────────────────────
    ckpt_path = os.path.join(output_dir, "checkpoint.pt")
    step = 0  # global micro-batch counter, persisted across resumes
    opt_step = 0
    start_epoch = 1
    best_loss = float("inf")
    best_acc = 0.0
    cumul_tokens = 0
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        raw_model.load_state_dict(_strip_compile_prefix(ckpt["model"]))
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        # Mid-epoch checkpoints (and ones predating the flag) re-run their epoch.
        start_epoch = ckpt["epoch"] + 1 if ckpt.get("epoch_complete", False) else ckpt["epoch"]
        step = ckpt["step"]
        opt_step = ckpt.get("opt_step", 0)
        best_loss = ckpt.get("best_loss", float("inf"))
        best_acc = ckpt.get("best_acc", 0.0)
        cumul_tokens = ckpt.get("cumul_tokens", 0)
        print(f"Resumed: epoch={ckpt['epoch']}, step={ckpt['step']}, best_loss={best_loss:.4f}, "
              f"best_acc={best_acc:.3f}, total_tok={cumul_tokens}")

    os.makedirs(output_dir, exist_ok=True)
    t_start = time.time()
    t_epoch = time.time()
    eff_bs = bs * ga_steps

    print()
    print("=" * 72)
    print(" PCT-V3 Training")
    print(f" Model: {n_params:,} params ({n_trainable:,} trainable) | Device: {device.upper()}")
    print(f" Vocab: {vocab_size} | Depth: 6 | Hidden: 128 | Code: 96 | Heads: 8/4 | RPW+GPP+VCR")
    print(f" Epochs: {epochs} | Batch: {bs} | GA: {ga_steps} | Eff BS: {eff_bs} | LR: {lr}")
    print(f" Warmup: {warmup_opt_steps} opt steps | Total: {total_opt_steps} opt steps")
    print("=" * 72)
    print(f" {'Time':>8} {'Ep':>2} {'Step':>8} {'Loss':>9} {'PPL':>9} {'Entropy':>9} {'Acc':>6} {'GradNorm':>9} {'LR':>10} {'tok/s':>7} {'smp/s':>7}")
    print(f" {'-'*8:>8} {'-'*2:>2} {'-'*8:>8} {'-'*9:>9} {'-'*9:>9} {'-'*9:>9} {'-'*6:>6} {'-'*9:>9} {'-'*10:>10} {'-'*7:>7} {'-'*7:>7}")
    print()
    rule = f" {'':>8} {'':>2} {'':>8} {'-'*9} {'-'*9} {'-'*9} {'-'*6} {'-'*9} {'-'*10} {'-'*7} {'-'*7}"

    def _max_steps_reached():
        return max_steps_cfg > 0 and step >= max_steps_cfg

    grad_norm = 0.0  # most recent clipped-gradient norm, shown in every log line
    for epoch in range(start_epoch, epochs + 1):
        if _max_steps_reached():
            break
        model.train()
        epoch_loss = 0.0
        epoch_tokens = 0
        epoch_steps = 0
        optimizer.zero_grad()
        for x, y in train_loader:
            if _max_steps_reached():
                break
            step += 1
            epoch_steps += 1
            x, y = x.to(device), y.to(device)
            n_tokens = y.numel()  # includes padding positions
            logits, loss = model(x, labels=y)
            loss_val = loss.item()
            (loss / ga_steps).backward()
            if step % ga_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                opt_step += 1

            epoch_loss += loss_val
            epoch_tokens += n_tokens
            cumul_tokens += n_tokens

            if step % log_steps == 0:
                since_epoch = max(time.time() - t_epoch, 1)
                tok_per_sec = epoch_tokens / since_epoch
                smp_per_sec = epoch_steps * bs / since_epoch
                cur_lr = optimizer.param_groups[0]["lr"]
                ts = time.strftime("%H:%M:%S")
                loss_s, ppl_s, ent_s, acc_s = compute_metrics(logits, y)
                print(f" {ts:>8} {epoch:>2} {step:>8} {loss_s:>9.4f} {ppl_s:>9.2f} {ent_s:>9.4f} {acc_s:>6.3f} {grad_norm:>9.2e} {cur_lr:>10.2e} {tok_per_sec:>7.0f} {smp_per_sec:>7.1f} cum_tok={cumul_tokens}")
                sys.stdout.flush()

            if step % save_steps == 0:
                eval_metrics = evaluate(model, eval_loader, device)  # restores train mode
                elapsed = time.time() - t_start
                pct = min(100, step / max_steps_cfg * 100) if max_steps_cfg > 0 else 0
                print(rule)
                print(f" -- Eval: loss={eval_metrics['loss']:.4f} ppl={eval_metrics['ppl']:.2f} ent={eval_metrics['entropy']:.4f} acc={eval_metrics['acc']:.3f} tokens={eval_metrics['tokens']} | best={best_loss:.4f} | step={step} [{pct:.1f}%] [{elapsed:.0f}s]")
                sys.stdout.flush()

                # Update best metrics *before* writing the checkpoint so a resume sees them.
                improved = eval_metrics["loss"] < best_loss
                if improved:
                    best_loss = eval_metrics["loss"]
                    best_acc = eval_metrics["acc"]
                elif eval_metrics["acc"] > best_acc:
                    best_acc = eval_metrics["acc"]

                _save_checkpoint(ckpt_path, raw_model, optimizer, scheduler,
                                 step=step, opt_step=opt_step, epoch=epoch, epoch_complete=False,
                                 best_loss=best_loss, best_acc=best_acc,
                                 cumul_tokens=cumul_tokens, eval_metrics=eval_metrics)
                print(f" + Checkpoint saved (step={step})")
                sys.stdout.flush()

                if improved:
                    torch.save(raw_model.state_dict(), os.path.join(output_dir, "best.pt"))
                    print(f" + New best! loss={best_loss:.4f} ppl={eval_metrics['ppl']:.2f} acc={best_acc:.3f}")
                    if hf_repo_id and hf_token:
                        upload_to_hf(output_dir, hf_repo_id, hf_token, ["best.pt"])
                        print(f" + Uploaded best.pt to {hf_repo_id}")
                sys.stdout.flush()

        if _max_steps_reached():
            break

        t_epoch_end = time.time()
        epoch_time = t_epoch_end - t_epoch
        avg_loss = epoch_loss / max(epoch_steps, 1)
        eval_metrics = evaluate(model, eval_loader, device)
        print(rule)
        print(f" -- Epoch {epoch}/{epochs} done | loss={avg_loss:.4f} | eval={eval_metrics['loss']:.4f} | ppl={eval_metrics['ppl']:.2f} | ent={eval_metrics['entropy']:.4f} | acc={eval_metrics['acc']:.3f} | tok={epoch_tokens} | cum_tok={cumul_tokens} | time={epoch_time:.0f}s --")
        sys.stdout.flush()
        torch.save(raw_model.state_dict(), os.path.join(output_dir, f"epoch_{epoch}.pt"))
        # Mark the epoch as finished so a resume starts the next one.
        _save_checkpoint(ckpt_path, raw_model, optimizer, scheduler,
                         step=step, opt_step=opt_step, epoch=epoch, epoch_complete=True,
                         best_loss=best_loss, best_acc=best_acc,
                         cumul_tokens=cumul_tokens, eval_metrics=eval_metrics)
        # Restart the throughput clock after eval/saving so they don't skew tok/s.
        t_epoch = time.time()

    final = os.path.join(output_dir, "final")
    os.makedirs(final, exist_ok=True)
    torch.save(raw_model.state_dict(), os.path.join(final, "model.pt"))
    total_time = time.time() - t_start

    print()
    print("=" * 72)
    print(" Training complete!")
    print(f" Total time: {total_time:.0f}s ({total_time/3600:.1f}h)")
    print(f" Total tokens: {cumul_tokens:,}")
    print(f" Total steps: {step}")
    print(f" Best loss: {best_loss:.4f}")
    print(f" Best acc: {best_acc:.3f}")
    print(f" Avg tok/s: {cumul_tokens/max(total_time, 1):.0f}")
    print(f" Saved: {final}/model.pt")
    print("=" * 72)
    print()

    if hf_repo_id and hf_token:
        print(f"Uploading to HF: {hf_repo_id} ...")
        upload_files = ["best.pt", "final/model.pt"]
        for ep in range(start_epoch, epochs + 1):
            if os.path.exists(os.path.join(output_dir, f"epoch_{ep}.pt")):
                upload_files.append(f"epoch_{ep}.pt")
        upload_to_hf(output_dir, hf_repo_id, hf_token, upload_files)
        print(f"Uploaded to https://huggingface.co/{hf_repo_id}")

    return final


if __name__ == "__main__":
    train()