# td-toolkit / td_fuse/validate.py
# td-builder's picture
# Fixed code: vocab mismatch fix for cross-arch merging (Llama/Falcon)
# 5d61448 verified
"""
Post-Merge Validation — run after EVERY merge step.
Tests:
1. Canary recall (did knowledge transfer?)
2. Perplexity check (did we break the model?)
3. Thinking mode (do <think> tags still work?)
4. Quick reasoning test (can it still think?)
Kill criteria: >10% performance drop on any test → abort merge.
Findings: #11, #22, #25
"""
import sys
import time
import torch
import math
from transformers import AutoModelForCausalLM, AutoTokenizer
from .canary import test_all_canaries
from .config import MergeConfig
def validate_merged_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    cfg: MergeConfig,
    baseline_perplexity: float = None,
) -> dict:
    """
    Run the full validation suite on a merged model.

    Four checks run in order: canary recall, perplexity, thinking mode and
    a quick reasoning probe. Progress and a summary are printed to stdout.

    Args:
        model: The merged model to validate.
        tokenizer: The tokenizer that goes with ``model``.
        merged_sources: Names of the source models merged so far.
        cfg: Merge configuration; ``canary_pass_threshold`` and
            ``perplexity_threshold`` are read from it.
        baseline_perplexity: Perplexity of the target model before merging,
            or None to skip the ratio comparison.

    Returns:
        Dict with keys ``canary``, ``perplexity``, ``thinking_mode`` and
        ``reasoning`` (each a dict carrying an ``ok`` flag) plus ``overall``,
        which is True only when every check passed.
    """
    val_start = time.time()
    print("\n" + "=" * 60)
    print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 60)
    sys.stdout.flush()

    results = {
        "canary": None,
        "perplexity": None,
        "thinking_mode": None,
        "reasoning": None,
        "overall": False,
    }

    # --- Test 1: Canary recall ---
    print("[validate] Test 1/4: Canary recall...", flush=True)
    canary_results = test_all_canaries(model, tokenizer, merged_sources)
    passed_canaries = sum(1 for v in canary_results.values() if v)
    total_canaries = len(canary_results)
    results["canary"] = {
        "passed": passed_canaries,
        "total": total_canaries,
        # Never demand more passes than there are canaries to test.
        "ok": passed_canaries >= min(cfg.canary_pass_threshold, total_canaries),
        "details": canary_results,
    }

    # --- Test 2: Perplexity ---
    print("[validate] Test 2/4: Perplexity...", flush=True)
    perplexity = compute_perplexity(model, tokenizer)
    # A NaN/inf perplexity must never count as a pass, with or without a
    # baseline to compare against.
    ppl_finite = math.isfinite(perplexity)
    ppl_ratio = 1.0
    if baseline_perplexity is not None:
        if math.isfinite(baseline_perplexity) and baseline_perplexity > 0:
            ppl_ratio = perplexity / baseline_perplexity
        else:
            # Unusable baseline (zero, negative, NaN or inf): fail closed
            # instead of dividing by zero or producing a ratio of 0.
            ppl_ratio = float("inf")
        ppl_ok = ppl_finite and ppl_ratio < cfg.perplexity_threshold
        print(
            f"\n[validate] Perplexity: {perplexity:.2f} "
            f"(baseline: {baseline_perplexity:.2f}, ratio: {ppl_ratio:.2f})"
        )
        if not ppl_ok:
            print(
                f"[validate] ⚠ Perplexity ratio {ppl_ratio:.2f} "
                f"exceeds threshold {cfg.perplexity_threshold}"
            )
    else:
        ppl_ok = ppl_finite
        print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
        if not ppl_ok:
            print("[validate] ⚠ Perplexity is not finite")
    results["perplexity"] = {"value": perplexity, "ok": ppl_ok, "ratio": ppl_ratio}

    # --- Test 3: Thinking mode ---
    print("[validate] Test 3/4: Thinking mode...", flush=True)
    think_ok = test_thinking_mode(model, tokenizer)
    results["thinking_mode"] = {"ok": think_ok}

    # --- Test 4: Quick reasoning ---
    print("[validate] Test 4/4: Quick reasoning...", flush=True)
    reason_ok = test_reasoning(model, tokenizer)
    results["reasoning"] = {"ok": reason_ok}

    # --- Overall verdict ---
    all_ok = (
        results["canary"]["ok"]
        and results["perplexity"]["ok"]
        and results["thinking_mode"]["ok"]
        and results["reasoning"]["ok"]
    )
    results["overall"] = all_ok

    # Summary
    print("\n" + "-" * 60)
    print("VALIDATION SUMMARY")
    print("-" * 60)
    print(f" Canary recall: {'✓' if results['canary']['ok'] else '✗'} ({passed_canaries}/{total_canaries})")
    print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})")
    print(f" Thinking mode: {'✓' if think_ok else '✗'}")
    print(f" Reasoning: {'✓' if reason_ok else '✗'}")
    print(f" OVERALL: {'PASS' if all_ok else 'FAIL -- consider aborting'}")
    print(f" Validation time: {(time.time()-val_start)/60:.1f} min")
    print("-" * 60)
    sys.stdout.flush()
    return results
def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_texts: list[str] = None,
) -> float:
    """
    Compute token-weighted perplexity on a small test set.

    Lower perplexity = model is more confident about predicting text.
    A big spike after merging means the model was damaged.

    Args:
        model: Causal LM; it is switched to eval mode and called with
            ``labels`` so that it returns a loss.
        tokenizer: Tokenizer used to encode each text (truncated to 512
            tokens).
        test_texts: Texts to score. When None, a built-in set of five short
            prose/code/math snippets is used.

    Returns:
        ``exp`` of the mean per-token loss. ``float("inf")`` is returned when
        nothing could be scored (no texts, or every text shorter than two
        tokens) or when the exponent overflows, so callers see a failing
        value rather than an exception.
    """
    if test_texts is None:
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "In mathematics, a prime number is a natural number greater than 1.",
            "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
            "The theory of general relativity describes gravity as the curvature of spacetime.",
            "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
        ]

    model.eval()
    total_nll = 0.0
    total_predicted = 0
    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Labels are shifted inside the model, so a sequence of n tokens
        # yields n - 1 predictions and the returned loss is their mean.
        n_predicted = inputs["input_ids"].shape[1] - 1
        if n_predicted < 1:
            # Nothing to predict; the loss would be NaN.
            continue
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        total_nll += outputs.loss.item() * n_predicted
        total_predicted += n_predicted

    if total_predicted == 0:
        return float("inf")
    try:
        return math.exp(total_nll / total_predicted)
    except OverflowError:
        # Mean loss too large to exponentiate: the model is badly broken.
        return float("inf")
def _format_chat_prompt(tokenizer, user_message: str, enable_thinking: bool = True) -> dict:
    """
    Format a single-turn prompt using Qwen3's chat format and tokenize it.

    Qwen3 models expect messages in chat format — without it, the model
    just autocompletes the text instead of answering.

    The tokenizer's own chat template is tried first. A manual ChatML
    prompt is used instead when the template call fails, or when thinking
    is requested but the templated text contains no ``<think>`` trigger.

    Args:
        tokenizer: The tokenizer (or processor.tokenizer for VL models)
        user_message: The user's question
        enable_thinking: If True, the prompt leads into a ``<think>`` block.
            If False, the manual fallback prefills an empty
            ``<think></think>`` block so the model answers directly.

    Returns:
        The tokenizer output (``return_tensors="pt"``) ready to be unpacked
        into model.generate()
    """
    messages = [{"role": "user", "content": user_message}]

    # Try using the chat template (Qwen3 has one built in). This stays
    # best-effort: any template failure falls through to the manual format,
    # but the reason is reported instead of being silently discarded.
    templated = None
    try:
        templated = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
    except Exception as exc:
        print(
            f"[validate] Chat template failed ({type(exc).__name__}: {exc}); "
            "falling back to manual Qwen3 format"
        )

    # Accept the templated text unless thinking was requested and the
    # template did not emit the <think> trigger.
    if templated is not None and (not enable_thinking or "<think>" in templated):
        # The template already contains every special token it needs;
        # adding them again could duplicate BOS on some tokenizers.
        return tokenizer(templated, return_tensors="pt", add_special_tokens=False)

    # Fallback: manual Qwen3 chat format
    head = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n"
    if enable_thinking:
        # Qwen3 thinking mode: start assistant turn with <think> to trigger CoT
        text = head + "<think>\n"
    else:
        # Non-thinking mode: prefill an empty think block. A "/no_think"
        # marker in the assistant turn would be read as assistant output.
        text = head + "<think>\n\n</think>\n\n"
    return tokenizer(text, return_tensors="pt")
def test_thinking_mode(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Test if the model still uses <think> tags for reasoning.

    The thinking mode is Qwen3's special feature — if it's gone,
    the merge damaged something critical.

    A step-by-step arithmetic prompt is generated greedily for up to 800
    new tokens and the result is printed.

    Args:
        model: The merged model to probe.
        tokenizer: The tokenizer that goes with ``model``.

    Returns:
        True when the newly generated text contains ``</think>``.
    """
    prompt = "Solve step by step: What is 15 × 13?"
    inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=800,
            do_sample=False,
        )

    # Decode only the NEW tokens (skip the prompt)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    # Keep special tokens so the think tags survive decoding.
    response = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # Check for thinking tags (we may have prefilled <think> in the prompt,
    # so check for </think> which the model must produce to end thinking)
    has_think_close = "</think>" in response
    # <think> only shows up here if the model emitted it itself; it is
    # reported for information and does not affect the verdict.
    has_think_open = "<think>" in response
    # Pass if model produced </think> (thinking happened, whether <think> was prefilled or not)
    passed = has_think_close

    print("\n[validate] Thinking mode test:")
    print(f" Prompt: {prompt}")
    print(f" Response: {response[:300]}...")
    print(f" <think>: {'✓ found' if has_think_open else '(prefilled in prompt)'}")
    print(f" </think>: {'✓ found' if has_think_close else '✗ missing'}")
    print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
    return passed
def test_reasoning(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Quick reasoning sanity check — can the model still do basic math?

    This catches catastrophic failures where the merge produced gibberish.
    The prompt is built with thinking disabled so the model answers
    directly without chain-of-thought, then generated greedily for up to
    50 new tokens.

    Args:
        model: The merged model to probe.
        tokenizer: The tokenizer that goes with ``model``.

    Returns:
        True when the expected answer appears in the response as a whole
        number, i.e. as a complete run of digits.
    """
    prompt = "What is 7 + 8?"
    expected_answer = "15"
    inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=False)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
        )

    # Decode only the NEW tokens (skip the prompt)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Compare against whole digit runs: a plain substring test would also
    # accept "150", "115" or repeated-digit gibberish such as "151515".
    digit_runs = "".join(
        ch if "0" <= ch <= "9" else " " for ch in response
    ).split()
    passed = expected_answer in digit_runs

    print("\n[validate] Quick reasoning test:")
    print(f" Prompt: {prompt}")
    print(f" Expected: {expected_answer}")
    print(f" Got: {response[:200]}")
    print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
    return passed