"""
Fix 2 critical bugs and re-run SFT:
1. SFT label masking: only train on assistant tokens
2. Dead layer revival: reinit layers with collapsed FFN weights
"""
import sys, json, time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Training utilities shared with the pretraining script.
sys.path.insert(0, "/teamspace/studios/this_studio/vini/scripts")
from train_pico import ViniPico, ViniTokenizer, get_lr, save_checkpoint

VINI = "/teamspace/studios/this_studio/vini"
TOK = f"{VINI}/models/vini-pico-tokenizer.json"
SFT_DATA = f"{VINI}/data/processed/sft_data.jsonl"
PRETRAIN_CKPT = f"{VINI}/checkpoints/pico/phase2/latest.pt"


class SFTDatasetFixed(Dataset):
    """SFT dataset that ONLY trains on assistant responses."""

    def __init__(self, path, tokenizer, max_seq_len=2048):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.samples = []

        self.role_tokens = {
            "system": "<|system|>",
            "user": "<|user|>",
            "assistant": "<|assistant|>",
            "tool": "<tool_result>",
        }
        self.end_token = "<|end|>"
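
        # Rendered conversation under this template (illustrative content):
        #
        #   <|bos|>
        #   <|user|>
        #   What is 2+2?
        #   <|end|>
        #   <|assistant|>
        #   4
        #   <|end|>
        #   <|eos|>
        #
        # Only the <|assistant|> segment and the final <|eos|> keep real labels.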

        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # skip malformed lines rather than aborting
                messages = data.get("messages", [])
                if len(messages) >= 2:
                    self.samples.append(messages)

        print(f"Loaded {len(self.samples)} SFT samples (FIXED: assistant-only masking)")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        messages = self.samples[idx]

        all_ids = []
        assistant_mask = []  # True where a token should contribute to the loss

        # BOS is never trained on.
        bos_ids = self.tokenizer.encode("<|bos|>")
        all_ids.extend(bos_ids)
        assistant_mask.extend([False] * len(bos_ids))

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            start_token = self.role_tokens.get(role, "<|user|>")

            if role == "tool":
                segment = f"\n{start_token}\n{content}\n</tool_result>"
            else:
                segment = f"\n{start_token}\n{content}\n{self.end_token}"

            seg_ids = self.tokenizer.encode(segment)
            all_ids.extend(seg_ids)

            # FIX 1: only assistant segments get real labels.
            is_trainable = role == "assistant"
            assistant_mask.extend([is_trainable] * len(seg_ids))

        # Train on the final EOS so the model learns to stop.
        eos_ids = self.tokenizer.encode("\n<|eos|>")
        all_ids.extend(eos_ids)
        assistant_mask.extend([True] * len(eos_ids))

        # Truncate to the context window (a truncated sample loses its EOS).
        all_ids = all_ids[:self.max_seq_len]
        assistant_mask = assistant_mask[:self.max_seq_len]

        input_ids = torch.tensor(all_ids, dtype=torch.long)
        labels = input_ids.clone()

        # Mask non-assistant positions; cross-entropy ignores -100 labels.
        mask = torch.tensor(assistant_mask, dtype=torch.bool)
        labels[~mask] = -100

        # Right-pad to max_seq_len: pad ids with pad_token_id (0) and pad
        # labels with -100 so padding never contributes to the loss.
        pad_len = self.max_seq_len - len(all_ids)
        if pad_len > 0:
            input_ids = F.pad(input_ids, (0, pad_len), value=0)
            labels = F.pad(labels, (0, pad_len), value=-100)

        return input_ids, labels
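
# Quick REPL check of the masking logic: a sketch assuming a stub tokenizer
# that maps one character to one id (not executed by this script):
#
#   class StubTok:
#       def encode(self, text):
#           return [ord(c) % 32000 for c in text]
#
#   ds = SFTDatasetFixed.__new__(SFTDatasetFixed)  # skip file loading
#   ds.tokenizer, ds.max_seq_len = StubTok(), 128
#   ds.role_tokens = {"user": "<|user|>", "assistant": "<|assistant|>"}
#   ds.end_token = "<|end|>"
#   ds.samples = [[{"role": "user", "content": "hi"},
#                  {"role": "assistant", "content": "hello"}]]
#   ids, labels = ds[0]
#   assert (labels != -100).any()   # assistant tokens are trainable
#   assert (labels == -100).any()   # user tokens are masked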


def revive_dead_layers(model, threshold=0.001):
    """Reinitialize FFN layers that have collapsed (beta near zero)."""
    revived = 0
    for i, layer in enumerate(model.layers):
        # Mean absolute weight acts as the BitLinear-style scale ("beta");
        # a near-zero beta means the FFN's output has effectively vanished.
        beta_w1 = layer.ffn.w1.weight.data.abs().mean().item()
        beta_w3 = layer.ffn.w3.weight.data.abs().mean().item()

        if beta_w1 < threshold or beta_w3 < threshold:
            print(f"  REVIVING layer {i} FFN (w1 beta={beta_w1:.6f}, w3 beta={beta_w3:.6f})")
            # FIX 2: reinit the whole FFN block (w2 included) so all three
            # projections restart from a consistent scale.
            nn.init.normal_(layer.ffn.w1.weight, mean=0.0, std=0.02)
            nn.init.normal_(layer.ffn.w2.weight, mean=0.0, std=0.02)
            nn.init.normal_(layer.ffn.w3.weight, mean=0.0, std=0.02)
            revived += 1

    if revived == 0:
        print("  No dead layers found")
    else:
        print(f"  Revived {revived} dead FFN layers")
    return revived
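
# Compact beta inspection for a REPL session (assumes the layer.ffn.{w1,w2,w3}
# layout probed above):
#
#   betas = {f"L{i}.{name}": getattr(layer.ffn, name).weight.abs().mean().item()
#            for i, layer in enumerate(model.layers)
#            for name in ("w1", "w2", "w3")}
#   print(json.dumps(betas, indent=2))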


def train_sft_fixed():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    config = {
        "vocab_size": 32000, "dim": 384, "n_layers": 8,
        "n_heads": 6, "n_kv_heads": 2, "hidden_dim": 1024,
        "max_seq_len": 2048, "pad_token_id": 0, "eos_token_id": 1, "bos_token_id": 2,
    }
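
    # Rough scale check (an estimate, assuming tied embeddings and GQA
    # projections with kv_dim = dim * n_kv_heads / n_heads = 128):
    # ~12.3M embedding params + ~1.6M per block x 8 blocks ≈ 25M total.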

    tokenizer = ViniTokenizer(TOK)
    model = ViniPico(config).to(device)

    print(f"\nLoading pretrain checkpoint: {PRETRAIN_CKPT}")
    ckpt = torch.load(PRETRAIN_CKPT, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model"])
    print(f"Loaded (step {ckpt.get('step', '?')})")

    # FIX 2: revive collapsed FFN blocks before fine-tuning starts.
    print("\nChecking for dead FFN layers:")
    revive_dead_layers(model)

    print("\nPost-revival FFN stats:")
    for i in range(config["n_layers"]):
        beta = model.layers[i].ffn.w1.weight.data.abs().mean().item()
        print(f"  Layer {i}: beta={beta:.5f}")

    dataset = SFTDatasetFixed(SFT_DATA, tokenizer, config["max_seq_len"])

    # Cap the run length with a deterministic head-slice of the data.
    max_samples = 150000
    if len(dataset) > max_samples:
        dataset.samples = dataset.samples[:max_samples]
        print(f"Limited to {max_samples} samples")

    # Sanity-check FIX 1 on one sample before spending GPU time.
    sample_in, sample_lab = dataset[0]
    n_total = sample_in.shape[0]
    n_masked = (sample_lab == -100).sum().item()
    n_train = (sample_lab != -100).sum().item()
    print("\nMasking check (sample 0):")
    print(f"  Total tokens: {n_total}")
    print(f"  Masked (non-assistant): {n_masked} ({n_masked/n_total:.0%})")
    print(f"  Training (assistant only): {n_train} ({n_train/n_total:.0%})")

    batch_size = 16
    grad_accum = 4
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=2, pin_memory=True, drop_last=True)
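
    # Effective batch per optimizer step: 16 micro-batches x 4 accumulation
    # steps = 64 sequences, i.e. at most 64 * 2048 = 131,072 tokens (padding
    # included; only unmasked assistant tokens actually contribute to loss).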

    max_lr = 2e-4
    min_lr = max_lr * 0.01
    epochs = 1
    total_steps = len(loader) * epochs  # schedule is counted in micro-batches
    max_steps = total_steps
    warmup_steps = min(500, total_steps // 10)
    save_every = 3000
    log_every = 100

    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95),
                                  weight_decay=0.01)
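
    # A common refinement, not applied here: exempt 1-D tensors (biases,
    # norm gains) from weight decay. A sketch with the same hyperparameters:
    #
    #   decay = [p for p in model.parameters() if p.ndim >= 2]
    #   no_decay = [p for p in model.parameters() if p.ndim < 2]
    #   optimizer = torch.optim.AdamW(
    #       [{"params": decay, "weight_decay": 0.01},
    #        {"params": no_decay, "weight_decay": 0.0}],
    #       lr=max_lr, betas=(0.9, 0.95))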

    checkpoint_dir = Path(f"{VINI}/checkpoints/pico/phase2/sft_fixed")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("FIXED SFT TRAINING")
    print(f"{'='*60}")
    print("  FIX 1: Assistant-only masking (was: train on ALL tokens)")
    print("  FIX 2: Dead FFN layers revived")
    print(f"  Samples: {len(dataset)}")
    print(f"  Steps: {total_steps}")
    print(f"  Batch: {batch_size} x {grad_accum} = {batch_size * grad_accum}")
    print(f"  LR: {max_lr}")
    print(f"  Save every: {save_every}")
    print(f"{'='*60}\n")

    model.train()
    optimizer.zero_grad()  # begin the first accumulation window cleanly
    step = 0
    t_start = time.time()
    running_loss = []

    for epoch in range(epochs):
        for input_ids, labels in loader:
            input_ids = input_ids.to(device)
            labels = labels.to(device)

            # Warmup + decay schedule, stepped per micro-batch.
            lr = get_lr(step, total_steps, max_lr, min_lr, warmup_steps)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            # bf16 autocast needs no GradScaler; tie it to the actual device
            # type instead of hard-coding "cuda" so CPU runs stay valid.
            with torch.amp.autocast(device_type=device.type,
                                    enabled=(device.type == "cuda"),
                                    dtype=torch.bfloat16):
                logits, loss = model(input_ids, labels)
                # Scale so accumulated gradients match one full-batch step.
                scaled = loss / grad_accum

            scaled.backward()

            if (step + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            running_loss.append(loss.item())
            step += 1

            if step % log_every == 0:
                avg = sum(running_loss[-log_every:]) / log_every
                elapsed = time.time() - t_start
                print(f"step {step:>6d}/{total_steps} | loss {avg:.4f} | "
                      f"lr {lr:.2e} | {elapsed/60:.1f} min")

            if step % save_every == 0:
                save_checkpoint(model, optimizer, config, step, checkpoint_dir, prefix="sft_fixed")

            if step >= max_steps:
                break

        avg_loss = sum(running_loss) / len(running_loss)
        print(f"\nEpoch {epoch+1}: avg_loss={avg_loss:.4f}")

    # Always write a final checkpoint, even when step % save_every != 0.
    save_checkpoint(model, optimizer, config, step, checkpoint_dir, is_final=True, prefix="sft_fixed")

    elapsed = time.time() - t_start
    print(f"\n{'='*60}")
    print("FIXED SFT COMPLETE")
    print(f"  Steps: {step}, Time: {elapsed/60:.1f} min")
    print(f"  Avg loss: {sum(running_loss)/len(running_loss):.4f}")
    print(f"  Final loss (last 100): {sum(running_loss[-100:])/len(running_loss[-100:]):.4f}")
    print(f"{'='*60}")

    # Confirm FIX 2 held: no FFN should have collapsed again during SFT.
    print("\nPost-training FFN health check:")
    for i in range(config["n_layers"]):
        beta = model.layers[i].ffn.w1.weight.data.abs().mean().item()
        status = "DEAD" if beta < 0.001 else ("weak" if beta < 0.01 else "ok")
        print(f"  Layer {i}: beta={beta:.5f} [{status}]")


if __name__ == "__main__":
    train_sft_fixed()