"""FastMTP Head Training on HuggingFace A100.

Self-contained: MTP Head architecture + training loop + Magpie dataset.
Uploads checkpoint to HF repo when done.
"""
import os, json, time, random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from huggingface_hub import HfApi

# ----- Config -----
MODEL_ID = "google/gemma-3n-E2B-it"
HF_TOKEN = os.environ.get("HF_TOKEN", "")
UPLOAD_REPO = "Cytrex/fastmtp-e2b-poc"
K = 3        # number of future tokens each position drafts
BETA = 0.6   # geometric decay of the per-draft loss weights
LR = 5e-5
BATCH = 8
EPOCHS = 5
MAX_SEQ = 1024
N_SAMPLES = 20000
OUTPUT = "/tmp/mtp_checkpoint"

# ----- MTP head: one shared transformer block applied recursively -----
class MTPHead(nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, vocab_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_kv_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads

        # Fusion of the base model's hidden state with the shifted token embedding.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.fusion_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.fusion_norm = nn.RMSNorm(hidden_size, eps=1e-6)

        # Grouped-query self-attention.
        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=False)
        self.attn_norm = nn.RMSNorm(hidden_size, eps=1e-6)

        # SwiGLU feed-forward network.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.ffn_norm = nn.RMSNorm(hidden_size, eps=1e-6)

        # Output projection; weight is later tied to the frozen base lm_head.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    def forward(self, hidden_states, shifted_token_ids):
        # Fuse the incoming hidden states with embeddings of the shifted tokens.
        tok_embed = self.embed_tokens(shifted_token_ids)
        fused = self.fusion_proj(torch.cat([hidden_states, tok_embed], dim=-1))
        fused = self.fusion_norm(fused)

        # Causal self-attention; KV heads are repeated to match the query heads (GQA).
        B, T, _ = fused.shape
        normed = self.attn_norm(fused)
        q = self.q_proj(normed).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if self.num_kv_heads < self.num_heads:
            n_rep = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)
        x = fused + self.o_proj(attn_out)

        # SwiGLU FFN with residual connection.
        normed = self.ffn_norm(x)
        x = x + self.down_proj(F.silu(self.gate_proj(normed)) * self.up_proj(normed))

        return self.lm_head(x), x

    def trainable_params(self):
        return [p for p in self.parameters() if p.requires_grad]
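
# Shape sketch: forward() maps (hidden_states [B, T, H], shifted_token_ids [B, T])
# to (logits [B, T, vocab], new_hidden [B, T, H]); the returned hidden states are
# fed back in, so the single block serves all K draft steps.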

# ----- Loss -----
def compute_alphas(k=3, beta=0.6):
    # Geometrically decaying weights, normalized to sum to 1.
    raw = [beta ** i for i in range(k)]
    total = sum(raw)
    return [w / total for w in raw]
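
# With the defaults (k=3, beta=0.6): raw = [1.0, 0.6, 0.36], so the normalized
# weights come out to roughly [0.510, 0.306, 0.184] -- nearer drafts count more.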

def mtp_loss(draft_logits, target_ids, k=3, beta=0.6):
    alphas = compute_alphas(k, beta)
    loss = torch.tensor(0.0, device=draft_logits[0].device)
    for i in range(k):
        # ignore_index=0 masks positions zero-padded by collate() below.
        ce = F.cross_entropy(
            draft_logits[i].reshape(-1, draft_logits[i].size(-1)),
            target_ids[i].reshape(-1),
            ignore_index=0,
            reduction="mean",
        )
        loss = loss + alphas[i] * ce
    return loss
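
# Equivalently loss = sum_i alpha_i * CE_i; since the alphas sum to 1 this is a
# convex combination, keeping its scale comparable to a single-step CE loss.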

# ----- Training -----
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)

    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16, device_map="auto",
        token=HF_TOKEN, trust_remote_code=True,
    )
    # The base model stays frozen; only the MTP head is trained.
    for p in base_model.parameters():
        p.requires_grad_(False)
    base_model.eval()

    print("Creating MTP head...")
    config = {"hidden_size": 1536, "intermediate_size": 6144, "num_attention_heads": 8, "num_key_value_heads": 2, "vocab_size": 262144}
    mtp_head = MTPHead(**config)
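    # Size check from this config: fusion 4.7M + attention 5.9M + SwiGLU FFN 28.3M
    # linear weights come to ~38.9M trainable parameters (norms are negligible).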

    # Locate the base model's embedding and output-head weights so they can be
    # shared with (and stay frozen in) the MTP head.
    if hasattr(base_model, 'model') and hasattr(base_model.model, 'language_model'):
        lm = base_model.model.language_model
        embed_w = lm.model.embed_tokens.weight if hasattr(lm, 'model') else lm.embed_tokens.weight
        lm_head_w = lm.lm_head.weight
    elif hasattr(base_model, 'model'):
        embed_w = base_model.model.embed_tokens.weight
        lm_head_w = base_model.lm_head.weight
    else:
        raise RuntimeError("Cannot find embed/lm_head")

    # Tie the shared weights and keep them out of the optimizer.
    mtp_head.embed_tokens.weight = embed_w
    mtp_head.lm_head.weight = lm_head_w
    mtp_head.embed_tokens.weight.requires_grad = False
    mtp_head.lm_head.weight.requires_grad = False

    base_dtype = next(base_model.parameters()).dtype
    mtp_head = mtp_head.to(device=device, dtype=base_dtype)
    n_trainable = sum(p.numel() for p in mtp_head.trainable_params())
    print(f"MTP head: {n_trainable:,} trainable params")

    print("Loading Magpie-Pro-300K...")
    ds = load_dataset("Magpie-Align/Magpie-Pro-300K-Filtered", split="train")

    print(f"Tokenizing {N_SAMPLES} samples...")
    samples = []
    indices = list(range(len(ds)))
    random.seed(42)
    random.shuffle(indices)
    for idx in indices:
        if len(samples) >= N_SAMPLES:
            break
        conv = ds[idx]["conversations"]
        if len(conv) < 2:
            continue
        human = conv[0]["value"] if conv[0]["from"] == "human" else ""
        gpt = conv[1]["value"] if conv[1]["from"] == "gpt" else ""
        if not human or not gpt or len(gpt) < 50:
            continue
        # Render the first user/model exchange in the Gemma chat format.
        text = "<start_of_turn>user\n" + human + "<end_of_turn>\n<start_of_turn>model\n" + gpt + "<end_of_turn>"
        ids = tokenizer.encode(text, max_length=MAX_SEQ, truncation=True)
        if len(ids) >= K + 4:
            samples.append(torch.tensor(ids, dtype=torch.long))
    print(f"Tokenized: {len(samples)} valid samples")

    def collate(batch):
        # Right-pad with token id 0; mtp_loss masks these via ignore_index=0.
        mx = max(len(s) for s in batch)
        padded = torch.zeros(len(batch), mx, dtype=torch.long)
        for i, s in enumerate(batch):
            padded[i, :len(s)] = s
        return padded

    loader = DataLoader(samples, batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
    optimizer = torch.optim.AdamW(mtp_head.trainable_params(), lr=LR, betas=(0.9, 0.95), weight_decay=0.01)

    total_steps = len(loader) * EPOCHS
    print(f"\nTraining: {EPOCHS} epochs, {len(loader)} steps/epoch, {total_steps} total")

    t0 = time.time()
    best_loss = float("inf")
    for epoch in range(EPOCHS):
        epoch_loss = 0.0
        for step, batch in enumerate(loader):
            input_ids = batch.to(device)
            B, S = input_ids.shape
            # Head i needs targets shifted by i+2, so the last K+1 positions
            # have no complete target and are excluded.
            valid_len = S - K - 1
            if valid_len <= 0:
                continue

            # Last-layer hidden states from the frozen base model.
            with torch.no_grad():
                outputs = base_model(input_ids=input_ids, output_hidden_states=True)
                hidden = outputs.hidden_states[-1][:, :valid_len, :]

            # Targets: head i predicts the token i+2 positions ahead.
            targets = []
            for i in range(K):
                shift = i + 2
                t = input_ids[:, shift:shift + valid_len]
                if t.shape[1] < valid_len:
                    pad = torch.zeros(B, valid_len - t.shape[1], dtype=torch.long, device=device)
                    t = torch.cat([t, pad], dim=1)
                targets.append(t)

            # Recursive draft passes: each step reuses the shared head with the
            # previous step's hidden output and the teacher-forced next token.
            draft_logits = []
            h = hidden
            for i in range(K):
                shifted_ids = input_ids[:, i + 1:i + 1 + valid_len]
                if shifted_ids.shape[1] < valid_len:
                    pad = torch.zeros(B, valid_len - shifted_ids.shape[1], dtype=torch.long, device=device)
                    shifted_ids = torch.cat([shifted_ids, pad], dim=1)
                logits, h = mtp_head(h, shifted_ids)
                draft_logits.append(logits)

            loss = mtp_loss(draft_logits, targets, K, BETA)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(mtp_head.trainable_params(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()
            if (step + 1) % 50 == 0:
                avg = epoch_loss / (step + 1)
                elapsed = time.time() - t0
                steps_done = epoch * len(loader) + step + 1
                eta = (elapsed / steps_done) * (total_steps - steps_done) / 60
                print(f"  E{epoch+1} S{step+1}/{len(loader)} | loss={loss.item():.4f} avg={avg:.4f} | {elapsed:.0f}s | ETA {eta:.0f}min")

        avg_loss = epoch_loss / max(len(loader), 1)
        elapsed = time.time() - t0
        print(f"Epoch {epoch+1}/{EPOCHS} | avg_loss={avg_loss:.4f} | {elapsed:.0f}s")

        # Save an epoch checkpoint; the tied embed/lm_head weights are excluded
        # since they are frozen copies of the base model's.
        os.makedirs(OUTPUT, exist_ok=True)
        ckpt = {
            "mtp_head_state_dict": {k: v.cpu() for k, v in mtp_head.state_dict().items()
                                    if not k.startswith("embed_tokens") and not k.startswith("lm_head")},
            "epoch": epoch + 1,
            "loss": avg_loss,
            "k": K,
            "beta": BETA,
            "config": config,
        }
        torch.save(ckpt, f"{OUTPUT}/mtp_checkpoint_e{epoch+1}.pt")
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(ckpt, f"{OUTPUT}/mtp_best.pt")
            print(f"  New best: {best_loss:.4f}")

    elapsed = time.time() - t0
    print(f"\nDONE in {elapsed:.0f}s ({elapsed/60:.1f}min), best loss: {best_loss:.4f}")

    # Upload checkpoints plus metadata to the Hub when a token is available.
    if HF_TOKEN:
        print(f"\nUploading to {UPLOAD_REPO}...")
        api = HfApi(token=HF_TOKEN)
        try:
            api.create_repo(UPLOAD_REPO, exist_ok=True)
        except Exception as e:
            print(f"Repo create: {e}")

        meta = {
            "type": "fastmtp_head",
            "base_model": MODEL_ID,
            "k": K, "beta": BETA,
            "epochs": EPOCHS, "samples": len(samples),
            "best_loss": best_loss,
            "architecture": "shared_weight_transformer_block",
            "reference": "arXiv:2509.18362",
            "trainable_params": n_trainable,
        }
        with open(f"{OUTPUT}/mtp_config.json", "w") as f:
            json.dump(meta, f, indent=2)

        api.upload_folder(folder_path=OUTPUT, repo_id=UPLOAD_REPO, commit_message=f"FastMTP E2B PoC — {EPOCHS} epochs, loss={best_loss:.4f}")
        print(f"Uploaded to https://huggingface.co/{UPLOAD_REPO}")
    else:
        print("No HF_TOKEN — skipping upload")


if __name__ == "__main__":
    main()
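
# Loading sketch for downstream use (hypothetical consumer code, not run here):
# the checkpoint omits the tied weights, so re-attach them from the same base model.
#   ckpt = torch.load(f"{OUTPUT}/mtp_best.pt", map_location="cpu")
#   head = MTPHead(**ckpt["config"])
#   head.load_state_dict(ckpt["mtp_head_state_dict"], strict=False)  # tied weights absent
#   head.embed_tokens.weight = embed_w   # shared weights, extracted as in main()
#   head.lm_head.weight = lm_head_w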