Cytrex committed
Commit b2a103b · verified · 1 Parent(s): 043a6bd

Self-distill + train script v2

Files changed (1)
  1. selfdistill_train.py +348 -0
selfdistill_train.py ADDED
@@ -0,0 +1,348 @@
+ """FastMTP: Self-Distill + Train in one job on HF A100.
+
+ 1. Load E4B base model
+ 2. Generate 5k responses (self-distillation)
+ 3. Train MTP head on those responses
+ 4. Upload checkpoint to HF
+ """
+ import os, sys, json, time, random
+ sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', buffering=1)
+ sys.stderr = os.fdopen(sys.stderr.fileno(), 'w', buffering=1)
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from pathlib import Path
+ from torch.utils.data import DataLoader
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from datasets import load_dataset
+ from huggingface_hub import HfApi
+
+ # ============================================================
+ # Config
+ # ============================================================
+ MODEL_ID = "InfinimindCreations/gemma-4-E4B-it-uncensored"
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
+ UPLOAD_REPO = "Cytrex/fastmtp-e4b-selfdistill"
+
+ # Self-distill config
+ N_DISTILL = 5000
+ GEN_MAX_TOKENS = 256
+ GEN_TEMPERATURE = 0.6
+ GEN_TOP_K = 20
+ GEN_TOP_P = 0.95
+
+ # Training config
+ K = 3
+ BETA = 0.6
+ LR = 5e-5
+ BATCH = 2
+ EPOCHS = 3
+ MAX_SEQ = 512
+
+ OUTPUT = "/tmp/mtp_checkpoint"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # ============================================================
+ # MTP Head
+ # ============================================================
+ class MTPHead(nn.Module):
+     def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, vocab_size):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.num_heads = num_attention_heads
+         self.num_kv_heads = num_key_value_heads
+         self.head_dim = hidden_size // num_attention_heads
+
+         self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
+         self.fusion_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)
+         self.fusion_norm = nn.RMSNorm(hidden_size, eps=1e-6)
+
+         self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, hidden_size, bias=False)
+         self.attn_norm = nn.RMSNorm(hidden_size, eps=1e-6)
+
+         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+         self.ffn_norm = nn.RMSNorm(hidden_size, eps=1e-6)
+
+         self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+
+     def forward(self, hidden_states, shifted_token_ids):
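+         # Fuse the base model's last hidden state with the embedding of the
+         # shifted ground-truth token, then run a single transformer block
+         # (causal self-attention + SwiGLU MLP) over the fused sequence.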
+         tok_embed = self.embed_tokens(shifted_token_ids)
+         fused = self.fusion_proj(torch.cat([hidden_states, tok_embed], dim=-1))
+         fused = self.fusion_norm(fused)
+
+         B, T, _ = fused.shape
+         normed = self.attn_norm(fused)
+         q = self.q_proj(normed).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(normed).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
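+         # Grouped-query attention: replicate the K/V heads so they match the
+         # number of query heads before scaled dot-product attention.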
+         if self.num_kv_heads < self.num_heads:
+             n_rep = self.num_heads // self.num_kv_heads
+             k = k.repeat_interleave(n_rep, dim=1)
+             v = v.repeat_interleave(n_rep, dim=1)
+         attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)
+         x = fused + self.o_proj(attn_out)
+
+         normed = self.ffn_norm(x)
+         x = x + self.down_proj(F.silu(self.gate_proj(normed)) * self.up_proj(normed))
+
+         return self.lm_head(x), x
+
+     def trainable_params(self):
+         return [p for p in self.parameters() if p.requires_grad]
+
+ # ============================================================
+ # Loss
+ # ============================================================
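+ # Weighted sum of per-head cross-entropies: head i gets a weight proportional
+ # to BETA**i (normalized to sum to 1), so later, harder-to-predict draft
+ # positions contribute less to the total loss.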
+ def mtp_loss(draft_logits, target_ids, k=3, beta=0.6):
+     raw = [beta ** i for i in range(k)]
+     total = sum(raw)
+     alphas = [w / total for w in raw]
+     loss = torch.tensor(0.0, device=draft_logits[0].device)
+     for i in range(k):
+         ce = F.cross_entropy(
+             draft_logits[i].reshape(-1, draft_logits[i].size(-1)),
+             target_ids[i].reshape(-1),
+             ignore_index=0, reduction="mean",
+         )
+         loss = loss + alphas[i] * ce
+     return loss
+
+ # ============================================================
+ # Phase 1: Self-Distillation
+ # ============================================================
+ def generate_selfdistill(model, tokenizer, prompts, max_tokens=256):
+     """Generate responses from the model itself."""
+     print(f"\n=== PHASE 1: Self-Distill ({len(prompts)} prompts, max_tokens={max_tokens}) ===")
+     samples = []
+     t0 = time.time()
+
+     for i, prompt in enumerate(prompts):
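+         # Wrap the prompt in the Gemma chat-turn format by hand and let the
+         # base model sample its own response (the self-distillation target).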
+         input_text = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+         input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
+
+         with torch.no_grad():
+             output = model.generate(
+                 input_ids,
+                 max_new_tokens=max_tokens,
+                 do_sample=True,
+                 temperature=GEN_TEMPERATURE,
+                 top_k=GEN_TOP_K,
+                 top_p=GEN_TOP_P,
+             )
+
+         response_ids = output[0][input_ids.shape[1]:]
+         response = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
+
+         if len(response) > 20:
+             full_text = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n{response}<end_of_turn>"
+             ids = tokenizer.encode(full_text, max_length=MAX_SEQ, truncation=True)
+             if len(ids) >= K + 4:
+                 samples.append(torch.tensor(ids, dtype=torch.long))
+
+         if (i + 1) % 100 == 0:
+             elapsed = time.time() - t0
+             rate = (i + 1) / elapsed
+             eta = (len(prompts) - i - 1) / rate / 60
+             print(f" [{i+1}/{len(prompts)}] {len(samples)} valid | {rate:.1f} prompts/s | ETA {eta:.1f}min")
+
+     elapsed = time.time() - t0
+     print(f"Self-distill done: {len(samples)} valid samples in {elapsed:.0f}s ({elapsed/60:.1f}min)")
+     return samples
+
+ # ============================================================
+ # Phase 2: Training
+ # ============================================================
+ def train_mtp(model, mtp_head, samples):
+     print(f"\n=== PHASE 2: Training ({len(samples)} samples, {EPOCHS} epochs) ===")
+
+     def collate(batch):
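+         # Right-pad with token id 0; those positions are later skipped via
+         # ignore_index=0 in the cross-entropy loss.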
+         mx = max(len(s) for s in batch)
+         padded = torch.zeros(len(batch), mx, dtype=torch.long)
+         for i, s in enumerate(batch):
+             padded[i, :len(s)] = s
+         return padded
+
+     loader = DataLoader(samples, batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=0)
+     optimizer = torch.optim.AdamW(mtp_head.trainable_params(), lr=LR, betas=(0.9, 0.95), weight_decay=0.01)
+
+     # Freeze base model
+     for p in model.parameters():
+         p.requires_grad_(False)
+     model.eval()
+
+     total_steps = len(loader) * EPOCHS
+     print(f"Steps: {len(loader)}/epoch, {total_steps} total")
+
+     t0 = time.time()
+     best_loss = float("inf")
+
+     for epoch in range(EPOCHS):
+         epoch_loss = 0
+         for step, batch in enumerate(loader):
+             input_ids = batch.to(DEVICE)
+             B, S = input_ids.shape
+             valid_len = S - K - 1
+             if valid_len <= 0:
+                 continue
+
+             with torch.no_grad():
+                 outputs = model(input_ids=input_ids, output_hidden_states=True)
+                 hidden = outputs.hidden_states[-1][:, :valid_len, :]
+
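+             # Head i is trained to predict the token (i + 2) positions after each
+             # kept position t: the base LM head already covers t+1, so the draft
+             # heads start at t+2. valid_len = S - K - 1 trims positions that do not
+             # have all K future targets available.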
+             targets = []
+             for i in range(K):
+                 shift = i + 2
+                 t = input_ids[:, shift:shift + valid_len]
+                 if t.shape[1] < valid_len:
+                     pad = torch.zeros(B, valid_len - t.shape[1], dtype=torch.long, device=DEVICE)
+                     t = torch.cat([t, pad], dim=1)
+                 targets.append(t)
+
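+             # Recursive drafting with teacher forcing: the MTP head is applied K
+             # times, each pass consuming its own output hidden state plus the
+             # ground-truth token shifted one position further.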
+             draft_logits = []
+             h = hidden
+             for i in range(K):
+                 shifted_ids = input_ids[:, i + 1:i + 1 + valid_len]
+                 if shifted_ids.shape[1] < valid_len:
+                     pad = torch.zeros(B, valid_len - shifted_ids.shape[1], dtype=torch.long, device=DEVICE)
+                     shifted_ids = torch.cat([shifted_ids, pad], dim=1)
+                 logits, h = mtp_head(h, shifted_ids)
+                 draft_logits.append(logits)
+
+             loss = mtp_loss(draft_logits, targets, K, BETA)
+             optimizer.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(mtp_head.trainable_params(), 1.0)
+             optimizer.step()
+
+             epoch_loss += loss.item()
+             if (step + 1) % 50 == 0:
+                 avg = epoch_loss / (step + 1)
+                 elapsed = time.time() - t0
+                 steps_done = epoch * len(loader) + step + 1
+                 eta = (elapsed / steps_done) * (total_steps - steps_done) / 60
+                 print(f" E{epoch+1} S{step+1}/{len(loader)} | loss={loss.item():.4f} avg={avg:.4f} | {elapsed:.0f}s | ETA {eta:.0f}min")
+
+         avg_loss = epoch_loss / max(len(loader), 1)
+         print(f"Epoch {epoch+1}/{EPOCHS} | avg_loss={avg_loss:.4f} | {time.time()-t0:.0f}s")
+
+         os.makedirs(OUTPUT, exist_ok=True)
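+         # Save a checkpoint each epoch; the tied embed_tokens / lm_head weights are
+         # omitted since they can be re-tied from the base model at load time.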
+         ckpt = {
+             "mtp_head_state_dict": {k: v.cpu() for k, v in mtp_head.state_dict().items()
+                                     if not k.startswith("embed_tokens") and not k.startswith("lm_head")},
+             "epoch": epoch + 1,
+             "loss": avg_loss,
+             "k": K, "beta": BETA,
+             "config": {"hidden_size": 2560, "intermediate_size": 10240, "num_attention_heads": 8, "num_key_value_heads": 2, "vocab_size": 262144},
+         }
+         torch.save(ckpt, f"{OUTPUT}/mtp_checkpoint_e{epoch+1}.pt")
+         if avg_loss < best_loss:
+             best_loss = avg_loss
+             torch.save(ckpt, f"{OUTPUT}/mtp_best.pt")
+             print(f" New best: {best_loss:.4f}")
+
+     return best_loss
+
+ # ============================================================
+ # Main
+ # ============================================================
+ def main():
+     print(f"Device: {DEVICE}")
+     if DEVICE == "cuda":
+         print(f"GPU: {torch.cuda.get_device_name(0)}")
+
+     # Load model
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
+
+     print("Loading base model...")
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID, dtype=torch.bfloat16, device_map="auto",
+         token=HF_TOKEN, trust_remote_code=True,
+     )
+
+     # Load prompts from Magpie (only the prompts, not responses)
+     print("Loading prompts from Magpie-Pro-300K...")
+     ds = load_dataset("Magpie-Align/Magpie-Pro-300K-Filtered", split="train")
+     prompts = []
+     indices = list(range(len(ds)))
+     random.seed(42)
+     random.shuffle(indices)
+     for idx in indices:
+         if len(prompts) >= N_DISTILL:
+             break
+         conv = ds[idx]["conversations"]
+         if len(conv) >= 1 and conv[0]["from"] == "human" and len(conv[0]["value"]) > 10:
+             prompts.append(conv[0]["value"])
+     print(f"Loaded {len(prompts)} prompts")
+
+     # Phase 1: Self-distill
+     samples = generate_selfdistill(model, tokenizer, prompts, GEN_MAX_TOKENS)
+
+     if len(samples) < 100:
+         print(f"ERROR: Only {len(samples)} valid samples — not enough for training")
+         return
+
+     # Phase 2: Create MTP head and train
+     print("\nCreating MTP head...")
+     config = {"hidden_size": 2560, "intermediate_size": 10240, "num_attention_heads": 8, "num_key_value_heads": 2, "vocab_size": 262144}
+     mtp_head = MTPHead(**config)
+
+     # Tie embed + lm_head
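+     # The head shares the frozen base model's input embedding and output projection,
+     # so only the fusion, attention, and MLP projections receive gradients.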
+     if hasattr(model, 'model') and hasattr(model.model, 'language_model'):
+         embed_w = model.model.language_model.embed_tokens.weight
+     elif hasattr(model, 'model'):
+         embed_w = model.model.embed_tokens.weight
+     else:
+         raise RuntimeError("Cannot find embed_tokens")
+     lm_head_w = model.lm_head.weight
+
+     mtp_head.embed_tokens.weight = embed_w
+     mtp_head.lm_head.weight = lm_head_w
+     mtp_head.embed_tokens.weight.requires_grad = False
+     mtp_head.lm_head.weight.requires_grad = False
+
+     base_dtype = next(model.parameters()).dtype
+     mtp_head = mtp_head.to(device=DEVICE, dtype=base_dtype)
+     n_trainable = sum(p.numel() for p in mtp_head.trainable_params())
+     print(f"MTP head: {n_trainable:,} trainable params, dtype={base_dtype}")
+
+     best_loss = train_mtp(model, mtp_head, samples)
+
+     print(f"\n=== DONE === Best loss: {best_loss:.4f}")
+
+     # Upload
+     if HF_TOKEN:
+         print(f"\nUploading to {UPLOAD_REPO}...")
+         api = HfApi(token=HF_TOKEN)
+         try:
+             api.create_repo(UPLOAD_REPO, exist_ok=True)
+         except Exception as e:
+             print(f"Repo: {e}")
+
+         meta = {
+             "type": "fastmtp_head",
+             "base_model": MODEL_ID,
+             "method": "self-distillation",
+             "distill_samples": len(samples),
+             "k": K, "beta": BETA, "epochs": EPOCHS,
+             "best_loss": best_loss,
+             "trainable_params": n_trainable,
+             "reference": "arXiv:2509.18362",
+         }
+         with open(f"{OUTPUT}/mtp_config.json", "w") as f:
+             json.dump(meta, f, indent=2)
+
+         api.upload_folder(folder_path=OUTPUT, repo_id=UPLOAD_REPO,
+                           commit_message=f"FastMTP E4B self-distill — {EPOCHS}ep, {len(samples)} samples, loss={best_loss:.4f}")
+         print(f"Uploaded: https://huggingface.co/{UPLOAD_REPO}")
+
+
+ if __name__ == "__main__":
+     main()