# /// script
# dependencies = ["torch", "transformers>=5.0", "datasets", "huggingface_hub", "accelerate", "rich"]
# ///
"""FastMTP Head Training on HuggingFace A100.

Self-contained: MTP Head architecture + training loop + Magpie dataset.
Uploads checkpoint to HF repo when done.
"""
import os, sys, json, time, random, math

import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from huggingface_hub import HfApi

# ============================================================
# Config
# ============================================================
MODEL_ID = "google/gemma-4-E2B-it"
HF_TOKEN = os.environ.get("HF_TOKEN", "")
UPLOAD_REPO = "Cytrex/fastmtp-e2b-poc"  # personal account (org has no create rights)

K = 3        # number of draft tokens predicted per step
BETA = 0.6   # geometric decay for per-position loss weights
LR = 5e-5
BATCH = 8
EPOCHS = 5
MAX_SEQ = 1024
N_SAMPLES = 20000
OUTPUT = "/tmp/mtp_checkpoint"

# ============================================================
# MTP Head (inline — no external dependencies)
# ============================================================
class MTPHead(nn.Module):
    """Single pre-norm transformer block that fuses the base model's last hidden
    state with the embedding of the previously drafted token and predicts the
    next draft token. Embedding and lm_head are tied to the base model."""

    def __init__(self, hidden_size, intermediate_size, num_attention_heads,
                 num_key_value_heads, vocab_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_kv_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.fusion_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.fusion_norm = nn.RMSNorm(hidden_size, eps=1e-6)

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=False)
        self.attn_norm = nn.RMSNorm(hidden_size, eps=1e-6)

        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.ffn_norm = nn.RMSNorm(hidden_size, eps=1e-6)

        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states, shifted_token_ids):
        # Fuse base-model hidden states with embeddings of the shifted tokens
        tok_embed = self.embed_tokens(shifted_token_ids)
        fused = self.fusion_proj(torch.cat([hidden_states, tok_embed], dim=-1))
        fused = self.fusion_norm(fused)

        # Pre-norm attention
        B, T, _ = fused.shape
        normed = self.attn_norm(fused)
        q = self.q_proj(normed).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if self.num_kv_heads < self.num_heads:
            # GQA: expand KV heads to match the number of query heads
            n_rep = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)
        x = fused + self.o_proj(attn_out)

        # Pre-norm FFN (SwiGLU)
        normed = self.ffn_norm(x)
        x = x + self.down_proj(F.silu(self.gate_proj(normed)) * self.up_proj(normed))

        # Return logits and the updated hidden state (fed back in for the next draft step)
        return self.lm_head(x), x

    def trainable_params(self):
        return [p for p in self.parameters() if p.requires_grad]


# ============================================================
# Loss
# ============================================================
def compute_alphas(k=3, beta=0.6):
    # Geometric decay weights over the K draft positions, normalized to sum to 1
    raw = [beta ** i for i in range(k)]
    total = sum(raw)
    return [w / total for w in raw]


def mtp_loss(draft_logits, target_ids, k=3, beta=0.6):
    # Weighted cross-entropy over the K draft positions; pad token id 0 is ignored
    alphas = compute_alphas(k, beta)
    loss = torch.tensor(0.0, device=draft_logits[0].device)
    for i in range(k):
        ce = F.cross_entropy(
            draft_logits[i].reshape(-1, draft_logits[i].size(-1)),
            target_ids[i].reshape(-1),
            ignore_index=0,
            reduction="mean",
        )
        loss = loss + alphas[i] * ce
    return loss


# ============================================================
# Main
# ============================================================
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Load tokenizer + model
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)

    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map="auto",
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    for p in base_model.parameters():
        p.requires_grad_(False)
    base_model.eval()

    # Create MTP head, tie embeddings
    print("Creating MTP head...")
    config = {
        "hidden_size": 1536,
        "intermediate_size": 6144,
        "num_attention_heads": 8,
        "num_key_value_heads": 2,
        "vocab_size": 262144,
    }
    mtp_head = MTPHead(**config)

    # Tie embed + lm_head from base model
    if hasattr(base_model, 'model') and hasattr(base_model.model, 'language_model'):
        lm = base_model.model.language_model
        embed_w = lm.model.embed_tokens.weight if hasattr(lm, 'model') else lm.embed_tokens.weight
        lm_head_w = lm.lm_head.weight
    elif hasattr(base_model, 'model'):
        embed_w = base_model.model.embed_tokens.weight
        lm_head_w = base_model.lm_head.weight
    else:
        raise RuntimeError("Cannot find embed/lm_head")
    mtp_head.embed_tokens.weight = embed_w
    mtp_head.lm_head.weight = lm_head_w
    mtp_head.embed_tokens.weight.requires_grad = False
    mtp_head.lm_head.weight.requires_grad = False

    base_dtype = next(base_model.parameters()).dtype
    mtp_head = mtp_head.to(device=device, dtype=base_dtype)
    n_trainable = sum(p.numel() for p in mtp_head.trainable_params())
    print(f"MTP head: {n_trainable:,} trainable params")

    # Load Magpie dataset
    print("Loading Magpie-Pro-300K...")
    ds = load_dataset("Magpie-Align/Magpie-Pro-300K-Filtered", split="train")

    print(f"Tokenizing {N_SAMPLES} samples...")
    samples = []
    indices = list(range(len(ds)))
    random.seed(42)
    random.shuffle(indices)
    for idx in indices:
        if len(samples) >= N_SAMPLES:
            break
        conv = ds[idx]["conversations"]
        if len(conv) < 2:
            continue
        human = conv[0]["value"] if conv[0]["from"] == "human" else ""
        gpt = conv[1]["value"] if conv[1]["from"] == "gpt" else ""
        if not human or not gpt or len(gpt) < 50:
            continue
        # Gemma chat format with turn markers
        text = ("<start_of_turn>user\n" + human + "<end_of_turn>\n"
                "<start_of_turn>model\n" + gpt + "<end_of_turn>")
        ids = tokenizer.encode(text, max_length=MAX_SEQ, truncation=True)
        if len(ids) >= K + 4:
            samples.append(torch.tensor(ids, dtype=torch.long))
    print(f"Tokenized: {len(samples)} valid samples")

    def collate(batch):
        # Right-pad each sequence with 0 (pad id) to the batch max length
        mx = max(len(s) for s in batch)
        padded = torch.zeros(len(batch), mx, dtype=torch.long)
        for i, s in enumerate(batch):
            padded[i, :len(s)] = s
        return padded

    loader = DataLoader(samples, batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
    optimizer = torch.optim.AdamW(mtp_head.trainable_params(), lr=LR, betas=(0.9, 0.95), weight_decay=0.01)
    total_steps = len(loader) * EPOCHS
    print(f"\nTraining: {EPOCHS} epochs, {len(loader)} steps/epoch, {total_steps} total")
{total_steps} total") t0 = time.time() best_loss = float("inf") for epoch in range(EPOCHS): epoch_loss = 0 for step, batch in enumerate(loader): input_ids = batch.to(device) B, S = input_ids.shape valid_len = S - K - 1 if valid_len <= 0: continue # Extract hidden states from base model with torch.no_grad(): outputs = base_model(input_ids=input_ids, output_hidden_states=True) hidden = outputs.hidden_states[-1][:, :valid_len, :] # Prepare shifted targets targets = [] for i in range(K): shift = i + 2 t = input_ids[:, shift:shift + valid_len] if t.shape[1] < valid_len: pad = torch.zeros(B, valid_len - t.shape[1], dtype=torch.long, device=device) t = torch.cat([t, pad], dim=1) targets.append(t) # Forward MTP head recursively draft_logits = [] h = hidden for i in range(K): shifted_ids = input_ids[:, i + 1:i + 1 + valid_len] if shifted_ids.shape[1] < valid_len: pad = torch.zeros(B, valid_len - shifted_ids.shape[1], dtype=torch.long, device=device) shifted_ids = torch.cat([shifted_ids, pad], dim=1) logits, h = mtp_head(h, shifted_ids) draft_logits.append(logits) loss = mtp_loss(draft_logits, targets, K, BETA) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(mtp_head.trainable_params(), 1.0) optimizer.step() epoch_loss += loss.item() if (step + 1) % 50 == 0: avg = epoch_loss / (step + 1) elapsed = time.time() - t0 steps_done = epoch * len(loader) + step + 1 eta = (elapsed / steps_done) * (total_steps - steps_done) / 60 print(f" E{epoch+1} S{step+1}/{len(loader)} | loss={loss.item():.4f} avg={avg:.4f} | {elapsed:.0f}s | ETA {eta:.0f}min") avg_loss = epoch_loss / max(len(loader), 1) elapsed = time.time() - t0 print(f"Epoch {epoch+1}/{EPOCHS} | avg_loss={avg_loss:.4f} | {elapsed:.0f}s") # Save checkpoint os.makedirs(OUTPUT, exist_ok=True) ckpt = { "mtp_head_state_dict": {k: v.cpu() for k, v in mtp_head.state_dict().items() if not k.startswith("embed_tokens") and not k.startswith("lm_head")}, "epoch": epoch + 1, "loss": avg_loss, "k": K, "beta": BETA, "config": config, } torch.save(ckpt, f"{OUTPUT}/mtp_checkpoint_e{epoch+1}.pt") if avg_loss < best_loss: best_loss = avg_loss torch.save(ckpt, f"{OUTPUT}/mtp_best.pt") print(f" New best: {best_loss:.4f}") elapsed = time.time() - t0 print(f"\nDONE in {elapsed:.0f}s ({elapsed/60:.1f}min), best loss: {best_loss:.4f}") # Upload to HF if HF_TOKEN: print(f"\nUploading to {UPLOAD_REPO}...") api = HfApi(token=HF_TOKEN) try: api.create_repo(UPLOAD_REPO, exist_ok=True) except Exception as e: print(f"Repo create: {e}") # Save metadata meta = { "type": "fastmtp_head", "base_model": MODEL_ID, "k": K, "beta": BETA, "epochs": EPOCHS, "samples": len(samples), "best_loss": best_loss, "architecture": "shared_weight_transformer_block", "reference": "arXiv:2509.18362", "trainable_params": n_trainable, } with open(f"{OUTPUT}/mtp_config.json", "w") as f: json.dump(meta, f, indent=2) api.upload_folder(folder_path=OUTPUT, repo_id=UPLOAD_REPO, commit_message=f"FastMTP E2B PoC — {EPOCHS} epochs, loss={best_loss:.4f}") print(f"Uploaded to https://huggingface.co/{UPLOAD_REPO}") else: print("No HF_TOKEN — skipping upload") if __name__ == "__main__": main()