# /// script
# dependencies = ["torch", "transformers>=5.0", "datasets", "huggingface_hub", "accelerate", "rich"]
# ///
"""FastMTP Head Training on HuggingFace A100.
Self-contained: MTP Head architecture + training loop + Magpie dataset.
Uploads checkpoint to HF repo when done.
"""
import os, sys, json, time, random, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from huggingface_hub import HfApi
# ============================================================
# Config
# ============================================================
MODEL_ID = "google/gemma-4-E2B-it"
HF_TOKEN = os.environ.get("HF_TOKEN", "")
UPLOAD_REPO = "Cytrex/fastmtp-e2b-poc" # personal account (org has no create rights)
K = 3
BETA = 0.6
LR = 5e-5
BATCH = 8
EPOCHS = 5
MAX_SEQ = 1024
N_SAMPLES = 20000
OUTPUT = "/tmp/mtp_checkpoint"
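# With the defaults above: 20000 samples / batch 8 -> 2500 steps per epoch, i.e.
# 12500 optimizer steps over 5 epochs (assuming all sampled conversations pass the
# length filters; the actual count is printed at tokenization time).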
# ============================================================
# MTP Head (inline — no external dependencies)
# ============================================================
class MTPHead(nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, vocab_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_kv_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.fusion_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.fusion_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=False)
        self.attn_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.ffn_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states, shifted_token_ids):
        tok_embed = self.embed_tokens(shifted_token_ids)
        fused = self.fusion_proj(torch.cat([hidden_states, tok_embed], dim=-1))
        fused = self.fusion_norm(fused)
        # Pre-norm attention
        B, T, _ = fused.shape
        normed = self.attn_norm(fused)
        q = self.q_proj(normed).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if self.num_kv_heads < self.num_heads:
            n_rep = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)
        x = fused + self.o_proj(attn_out)
        # Pre-norm FFN
        normed = self.ffn_norm(x)
        x = x + self.down_proj(F.silu(self.gate_proj(normed)) * self.up_proj(normed))
        return self.lm_head(x), x

    def trainable_params(self):
        return [p for p in self.parameters() if p.requires_grad]
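# Shape sketch (illustrative only, using the head config hard-coded in main()):
#   head = MTPHead(1536, 6144, 8, 2, 262144)
#   h = torch.randn(2, 16, 1536)               # last-layer hidden states from the base model
#   ids = torch.randint(0, 262144, (2, 16))    # tokens shifted one step ahead of each position
#   logits, h_next = head(h, ids)              # logits: (2, 16, 262144), h_next: (2, 16, 1536)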
# ============================================================
# Loss
# ============================================================
def compute_alphas(k=3, beta=0.6):
    raw = [beta ** i for i in range(k)]
    total = sum(raw)
    return [w / total for w in raw]
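# Worked example: compute_alphas(3, 0.6) gives raw weights [1, 0.6, 0.36] (sum 1.96),
# normalized to roughly [0.51, 0.31, 0.18], so the first draft head dominates the loss
# and each further-ahead head gets geometrically less weight.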
def mtp_loss(draft_logits, target_ids, k=3, beta=0.6):
    alphas = compute_alphas(k, beta)
    loss = torch.tensor(0.0, device=draft_logits[0].device)
    for i in range(k):
        ce = F.cross_entropy(
            draft_logits[i].reshape(-1, draft_logits[i].size(-1)),
            target_ids[i].reshape(-1),
            ignore_index=0,
            reduction="mean",
        )
        loss = loss + alphas[i] * ce
    return loss
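# Usage sketch (shapes assumed from the training loop below): draft_logits is a list of
# K tensors of shape (B, T, vocab) and target_ids a list of K tensors of shape (B, T);
# padded positions carry token id 0 and are excluded via ignore_index=0.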
# ============================================================
# Main
# ============================================================
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Load tokenizer + model
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16, device_map="auto",
        token=HF_TOKEN, trust_remote_code=True,
    )
    for p in base_model.parameters():
        p.requires_grad_(False)
    base_model.eval()
    # Create MTP head, tie embeddings
    print("Creating MTP head...")
    config = {"hidden_size": 1536, "intermediate_size": 6144, "num_attention_heads": 8, "num_key_value_heads": 2, "vocab_size": 262144}
    mtp_head = MTPHead(**config)
    # Tie embed + lm_head from base model
    if hasattr(base_model, 'model') and hasattr(base_model.model, 'language_model'):
        lm = base_model.model.language_model
        embed_w = lm.model.embed_tokens.weight if hasattr(lm, 'model') else lm.embed_tokens.weight
        lm_head_w = lm.lm_head.weight
    elif hasattr(base_model, 'model'):
        embed_w = base_model.model.embed_tokens.weight
        lm_head_w = base_model.lm_head.weight
    else:
        raise RuntimeError("Cannot find embed/lm_head")
    mtp_head.embed_tokens.weight = embed_w
    mtp_head.lm_head.weight = lm_head_w
    mtp_head.embed_tokens.weight.requires_grad = False
    mtp_head.lm_head.weight.requires_grad = False
    base_dtype = next(base_model.parameters()).dtype
    mtp_head = mtp_head.to(device=device, dtype=base_dtype)
    n_trainable = sum(p.numel() for p in mtp_head.trainable_params())
    print(f"MTP head: {n_trainable:,} trainable params")
    # Load Magpie dataset
    print("Loading Magpie-Pro-300K...")
    ds = load_dataset("Magpie-Align/Magpie-Pro-300K-Filtered", split="train")
    print(f"Tokenizing {N_SAMPLES} samples...")
    samples = []
    indices = list(range(len(ds)))
    random.seed(42)
    random.shuffle(indices)
    for idx in indices:
        if len(samples) >= N_SAMPLES:
            break
        conv = ds[idx]["conversations"]
        if len(conv) < 2:
            continue
        human = conv[0]["value"] if conv[0]["from"] == "human" else ""
        gpt = conv[1]["value"] if conv[1]["from"] == "gpt" else ""
        if not human or not gpt or len(gpt) < 50:
            continue
        text = "<start_of_turn>user\n" + human + "<end_of_turn>\n<start_of_turn>model\n" + gpt + "<end_of_turn>"
        ids = tokenizer.encode(text, max_length=MAX_SEQ, truncation=True)
        if len(ids) >= K + 4:
            samples.append(torch.tensor(ids, dtype=torch.long))
    print(f"Tokenized: {len(samples)} valid samples")
    def collate(batch):
        mx = max(len(s) for s in batch)
        padded = torch.zeros(len(batch), mx, dtype=torch.long)
        for i, s in enumerate(batch):
            padded[i, :len(s)] = s
        return padded
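    # e.g. a batch with sequences of length 120 and 80 is padded to shape (2, 120);
    # the trailing positions hold id 0, which mtp_loss later masks via ignore_index=0.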
    loader = DataLoader(samples, batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
    optimizer = torch.optim.AdamW(mtp_head.trainable_params(), lr=LR, betas=(0.9, 0.95), weight_decay=0.01)
    total_steps = len(loader) * EPOCHS
    print(f"\nTraining: {EPOCHS} epochs, {len(loader)} steps/epoch, {total_steps} total")
    t0 = time.time()
    best_loss = float("inf")
    for epoch in range(EPOCHS):
        epoch_loss = 0
        for step, batch in enumerate(loader):
            input_ids = batch.to(device)
            B, S = input_ids.shape
            valid_len = S - K - 1
            if valid_len <= 0:
                continue
            # Extract hidden states from base model
            with torch.no_grad():
                outputs = base_model(input_ids=input_ids, output_hidden_states=True)
                hidden = outputs.hidden_states[-1][:, :valid_len, :]
            # Prepare shifted targets
            targets = []
            for i in range(K):
                shift = i + 2
                t = input_ids[:, shift:shift + valid_len]
                if t.shape[1] < valid_len:
                    pad = torch.zeros(B, valid_len - t.shape[1], dtype=torch.long, device=device)
                    t = torch.cat([t, pad], dim=1)
                targets.append(t)
            # Forward MTP head recursively
            draft_logits = []
            h = hidden
            for i in range(K):
                shifted_ids = input_ids[:, i + 1:i + 1 + valid_len]
                if shifted_ids.shape[1] < valid_len:
                    pad = torch.zeros(B, valid_len - shifted_ids.shape[1], dtype=torch.long, device=device)
                    shifted_ids = torch.cat([shifted_ids, pad], dim=1)
                logits, h = mtp_head(h, shifted_ids)
                draft_logits.append(logits)
            loss = mtp_loss(draft_logits, targets, K, BETA)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(mtp_head.trainable_params(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()
            if (step + 1) % 50 == 0:
                avg = epoch_loss / (step + 1)
                elapsed = time.time() - t0
                steps_done = epoch * len(loader) + step + 1
                eta = (elapsed / steps_done) * (total_steps - steps_done) / 60
                print(f" E{epoch+1} S{step+1}/{len(loader)} | loss={loss.item():.4f} avg={avg:.4f} | {elapsed:.0f}s | ETA {eta:.0f}min")
        avg_loss = epoch_loss / max(len(loader), 1)
        elapsed = time.time() - t0
        print(f"Epoch {epoch+1}/{EPOCHS} | avg_loss={avg_loss:.4f} | {elapsed:.0f}s")
        # Save checkpoint
        os.makedirs(OUTPUT, exist_ok=True)
        ckpt = {
            "mtp_head_state_dict": {k: v.cpu() for k, v in mtp_head.state_dict().items()
                                    if not k.startswith("embed_tokens") and not k.startswith("lm_head")},
            "epoch": epoch + 1,
            "loss": avg_loss,
            "k": K,
            "beta": BETA,
            "config": config,
        }
        torch.save(ckpt, f"{OUTPUT}/mtp_checkpoint_e{epoch+1}.pt")
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(ckpt, f"{OUTPUT}/mtp_best.pt")
            print(f" New best: {best_loss:.4f}")
    elapsed = time.time() - t0
    print(f"\nDONE in {elapsed:.0f}s ({elapsed/60:.1f}min), best loss: {best_loss:.4f}")
    # Upload to HF
    if HF_TOKEN:
        print(f"\nUploading to {UPLOAD_REPO}...")
        api = HfApi(token=HF_TOKEN)
        try:
            api.create_repo(UPLOAD_REPO, exist_ok=True)
        except Exception as e:
            print(f"Repo create: {e}")
        # Save metadata
        meta = {
            "type": "fastmtp_head",
            "base_model": MODEL_ID,
            "k": K, "beta": BETA,
            "epochs": EPOCHS, "samples": len(samples),
            "best_loss": best_loss,
            "architecture": "shared_weight_transformer_block",
            "reference": "arXiv:2509.18362",
            "trainable_params": n_trainable,
        }
        with open(f"{OUTPUT}/mtp_config.json", "w") as f:
            json.dump(meta, f, indent=2)
        api.upload_folder(folder_path=OUTPUT, repo_id=UPLOAD_REPO, commit_message=f"FastMTP E2B PoC — {EPOCHS} epochs, loss={best_loss:.4f}")
        print(f"Uploaded to https://huggingface.co/{UPLOAD_REPO}")
    else:
        print("No HF_TOKEN — skipping upload")
if __name__ == "__main__":
    main()