# vini-pico / scripts/fix_and_retrain.py
# Uploaded by jayptl-rq with huggingface_hub (commit 174811b, verified)
#!/usr/bin/env python3
"""
Fix 2 critical bugs and re-run SFT:
1. SFT label masking: only train on assistant tokens
2. Dead layer revival: reinit layers with collapsed FFN weights
"""
import sys, json, math, os, time, random, glob
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# Make the shared training utilities (model, tokenizer, LR schedule, checkpointing)
# importable from the sibling scripts directory.
sys.path.insert(0, "/teamspace/studios/this_studio/vini/scripts")
from train_pico import ViniPico, ViniTokenizer, get_lr, save_checkpoint, BitLinear

# Project paths: root, tokenizer file, SFT jsonl data, and the Phase 2
# pretrain checkpoint this run resumes from.
VINI = "/teamspace/studios/this_studio/vini"
TOK = f"{VINI}/models/vini-pico-tokenizer.json"
SFT_DATA = f"{VINI}/data/processed/sft_data.jsonl"
PRETRAIN_CKPT = f"{VINI}/checkpoints/pico/phase2/latest.pt"
# ============================================================================
# FIX 1: Proper SFT Dataset with assistant-only masking
# ============================================================================
class SFTDatasetFixed(Dataset):
    """SFT dataset that ONLY trains on assistant responses.

    Each sample is a chat transcript: a list of {"role", "content"} messages.
    ``__getitem__`` renders the transcript with role-marker tokens and masks
    the labels of every non-assistant token with -100, so the loss is computed
    only on assistant replies (plus the trailing EOS).
    """

    def __init__(self, path, tokenizer, max_seq_len=2048):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.samples = []
        # Role -> opening marker token used when rendering a message segment.
        self.role_tokens = {
            "system": "<|system|>",
            "user": "<|user|>",
            "assistant": "<|assistant|>",
            "tool": "<tool_result>",
        }
        self.end_token = "<|end|>"
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Skip malformed lines rather than abort the whole run.
                    continue
                messages = data.get("messages", [])
                # Require at least one prompt/response pair.
                if len(messages) >= 2:
                    self.samples.append(messages)
        print(f"Loaded {len(self.samples)} SFT samples (FIXED: assistant-only masking)")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (input_ids, labels); labels are -100 except on assistant tokens.

        Both tensors are 1-D, dtype long, padded/truncated to max_seq_len.
        """
        messages = self.samples[idx]
        # Build each segment with offset tracking
        all_ids = []
        assistant_mask = []  # True for tokens we want to train on
        # Start with BOS (context only, never trained on).
        bos_ids = self.tokenizer.encode("<|bos|>")
        all_ids.extend(bos_ids)
        assistant_mask.extend([False] * len(bos_ids))
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            start_token = self.role_tokens.get(role, "<|user|>")
            if role == "tool":
                segment = f"\n{start_token}\n{content}\n</tool_result>"
            else:
                segment = f"\n{start_token}\n{content}\n{self.end_token}"
            seg_ids = self.tokenizer.encode(segment)
            all_ids.extend(seg_ids)
            # FIX 1: only assistant tokens contribute to the loss;
            # system/user/tool segments are context only.
            is_trainable = role == "assistant"
            assistant_mask.extend([is_trainable] * len(seg_ids))
        # Trailing EOS is trainable so the model learns to stop generating.
        eos_ids = self.tokenizer.encode("\n<|eos|>")
        all_ids.extend(eos_ids)
        assistant_mask.extend([True] * len(eos_ids))
        # Truncate to the context window.
        all_ids = all_ids[:self.max_seq_len]
        assistant_mask = assistant_mask[:self.max_seq_len]
        input_ids = torch.tensor(all_ids, dtype=torch.long)
        labels = input_ids.clone()
        # Mask non-assistant positions in one vectorized op (was a per-token
        # Python loop; behavior identical).
        trainable = torch.tensor(assistant_mask, dtype=torch.bool)
        labels[~trainable] = -100
        # Right-pad; pad positions use pad id 0 for inputs and -100 for labels
        # so they never contribute to the loss.
        pad_len = self.max_seq_len - input_ids.size(0)
        if pad_len > 0:
            input_ids = F.pad(input_ids, (0, pad_len), value=0)
            labels = F.pad(labels, (0, pad_len), value=-100)
        return input_ids, labels
# ============================================================================
# FIX 2: Dead layer revival
# ============================================================================
def revive_dead_layers(model, threshold=0.001):
"""Reinitialize FFN layers that have collapsed (beta near zero)."""
revived = 0
for i, layer in enumerate(model.layers):
beta_w1 = layer.ffn.w1.weight.data.abs().mean().item()
beta_w3 = layer.ffn.w3.weight.data.abs().mean().item()
if beta_w1 < threshold or beta_w3 < threshold:
print(f" REVIVING layer {i} FFN (w1 beta={beta_w1:.6f}, w3 beta={beta_w3:.6f})")
# Reinitialize with small values (same as model init)
nn.init.normal_(layer.ffn.w1.weight, mean=0.0, std=0.02)
nn.init.normal_(layer.ffn.w2.weight, mean=0.0, std=0.02)
nn.init.normal_(layer.ffn.w3.weight, mean=0.0, std=0.02)
revived += 1
if revived == 0:
print(" No dead layers found")
else:
print(f" Revived {revived} dead FFN layers")
return revived
# ============================================================================
# FIXED SFT Training
# ============================================================================
def train_sft_fixed():
    """Load the Phase 2 checkpoint, apply both fixes, and run one SFT epoch.

    Side effects: reads the tokenizer/data/checkpoint paths defined at module
    level, writes periodic and final checkpoints under
    checkpoints/pico/phase2/sft_fixed, and logs progress to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Model hyperparameters -- must match the pretrain checkpoint's architecture
    # or load_state_dict below will fail.
    config = {
        "vocab_size": 32000, "dim": 384, "n_layers": 8,
        "n_heads": 6, "n_kv_heads": 2, "hidden_dim": 1024,
        "max_seq_len": 2048, "pad_token_id": 0, "eos_token_id": 1, "bos_token_id": 2,
    }
    tokenizer = ViniTokenizer(TOK)
    model = ViniPico(config).to(device)
    # Load Phase 2 pretrain checkpoint.
    print(f"\nLoading pretrain checkpoint: {PRETRAIN_CKPT}")
    # weights_only=False: checkpoint is trusted local output of our own run
    # (full unpickling; never use on untrusted files).
    ckpt = torch.load(PRETRAIN_CKPT, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model"])
    print(f"Loaded (step {ckpt.get('step', '?')})")
    # FIX 2: Revive dead layers BEFORE training.
    print("\nChecking for dead FFN layers:")
    revive_dead_layers(model)
    # Verify fix: log per-layer mean |w1| so a still-collapsed layer is visible.
    # NOTE(review): hard-codes 8 layers; presumably == config["n_layers"] — confirm.
    print("\nPost-revival FFN stats:")
    for i in range(8):
        beta = model.layers[i].ffn.w1.weight.data.abs().mean().item()
        print(f" Layer {i}: beta={beta:.5f}")
    # FIX 1: dataset with assistant-only label masking.
    dataset = SFTDatasetFixed(SFT_DATA, tokenizer, config["max_seq_len"])
    # Limit to 150K samples (same as Day 2).
    max_samples = 150000
    if len(dataset) > max_samples:
        dataset.samples = dataset.samples[:max_samples]
        print(f"Limited to {max_samples} samples")
    # Sanity-check the masking on one sample before committing to the full run.
    # NOTE(review): percentages are over the padded length, so padding inflates
    # the "masked" share.
    sample_in, sample_lab = dataset[0]
    n_total = sample_in.shape[0]
    n_masked = (sample_lab == -100).sum().item()
    n_train = (sample_lab != -100).sum().item()
    print(f"\nMasking check (sample 0):")
    print(f" Total tokens: {n_total}")
    print(f" Masked (non-assistant): {n_masked} ({n_masked/n_total:.0%})")
    print(f" Training (assistant only): {n_train} ({n_train/n_total:.0%})")
    batch_size = 16
    grad_accum = 4  # effective batch = 16 x 4 = 64 sequences per optimizer step
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=2, pin_memory=True, drop_last=True)
    # Training config.
    max_lr = 2e-4
    min_lr = max_lr * 0.01  # schedule floor at 1% of the peak LR
    epochs = 1
    total_steps = len(loader) * epochs  # counted in micro-batches, not optimizer steps
    max_steps = total_steps  # No cap this time
    warmup_steps = min(500, total_steps // 10)
    save_every = 3000
    log_every = 100
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95),
                                  weight_decay=0.01)
    checkpoint_dir = Path(f"{VINI}/checkpoints/pico/phase2/sft_fixed")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n{'='*60}")
    print(f"FIXED SFT TRAINING")
    print(f"{'='*60}")
    print(f" FIX 1: Assistant-only masking (was: train on ALL tokens)")
    print(f" FIX 2: Dead FFN layers revived")
    print(f" Samples: {len(dataset)}")
    print(f" Steps: {total_steps}")
    print(f" Batch: {batch_size} x {grad_accum} = {batch_size * grad_accum}")
    print(f" LR: {max_lr}")
    print(f" Save every: {save_every}")
    print(f"{'='*60}\n")
    model.train()
    step = 0
    t_start = time.time()
    running_loss = []
    for epoch in range(epochs):
        for input_ids, labels in loader:
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            # LR schedule advances every micro-batch (step counts micro-batches).
            lr = get_lr(step, total_steps, max_lr, min_lr, warmup_steps)
            for pg in optimizer.param_groups:
                pg["lr"] = lr
            # bf16 autocast needs no GradScaler (unlike fp16).
            # NOTE(review): autocast is hard-wired to "cuda" — presumably this
            # script only ever runs on GPU despite the CPU fallback above; confirm.
            with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
                logits, loss = model(input_ids, labels)
            # Scale so accumulated gradients average over the accumulation window.
            scaled = loss / grad_accum
            scaled.backward()
            # Step the optimizer only every grad_accum micro-batches.
            # NOTE(review): any leftover gradients at epoch end (when the batch
            # count isn't divisible by grad_accum) are never stepped — confirm intended.
            if (step + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
            running_loss.append(loss.item())  # unscaled loss, for logging only
            step += 1
            if step % log_every == 0:
                avg = sum(running_loss[-log_every:]) / len(running_loss[-log_every:])
                elapsed = time.time() - t_start
                print(f"step {step:>6d}/{total_steps} | loss {avg:.4f} | "
                      f"lr {lr:.2e} | {elapsed/60:.1f} min")
            if step % save_every == 0:
                save_checkpoint(model, optimizer, config, step, checkpoint_dir, prefix="sft_fixed")
            if step >= max_steps:
                break
        avg_loss = sum(running_loss) / len(running_loss)
        print(f"\nEpoch {epoch+1}: avg_loss={avg_loss:.4f}")
    # Final checkpoint.
    save_checkpoint(model, optimizer, config, step, checkpoint_dir, is_final=True, prefix="sft_fixed")
    elapsed = time.time() - t_start
    print(f"\n{'='*60}")
    print(f"FIXED SFT COMPLETE")
    print(f" Steps: {step}, Time: {elapsed/60:.1f} min")
    print(f" Avg loss: {sum(running_loss)/len(running_loss):.4f}")
    print(f" Final loss (last 100): {sum(running_loss[-100:])/len(running_loss[-100:]):.4f}")
    print(f"{'='*60}")
    # Post-training check: are layers still alive?
    print("\nPost-training FFN health check:")
    for i in range(8):
        beta = model.layers[i].ffn.w1.weight.data.abs().mean().item()
        status = "DEAD" if beta < 0.001 else ("weak" if beta < 0.01 else "ok")
        print(f" Layer {i}: beta={beta:.5f} [{status}]")
# Script entry point: run the fixed SFT training end to end.
if __name__ == "__main__":
    train_sft_fixed()