File size: 9,959 Bytes
5d61448 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 | """
Post-Merge Validation — run after EVERY merge step.
Tests:
1. Canary recall (did knowledge transfer?)
2. Perplexity check (did we break the model?)
3. Thinking mode (do <think> tags still work?)
4. Quick reasoning test (can it still think?)
Kill criteria: >10% performance drop on any test → abort merge.
Findings: #11, #22, #25
"""
import math
import re
import sys
import time
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .canary import test_all_canaries
from .config import MergeConfig
def validate_merged_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    cfg: MergeConfig,
    baseline_perplexity: Optional[float] = None,
) -> dict:
    """
    Run full validation suite on a merged model.

    Runs four checks in order (canary recall, perplexity, thinking mode,
    quick reasoning), prints progress and a summary to stdout, and returns
    the per-test results.

    Args:
        model: The merged model to validate
        tokenizer: The tokenizer
        merged_sources: List of source models merged so far; forwarded to
            the canary test and shown in the banner
        cfg: Merge configuration; ``canary_pass_threshold`` and
            ``perplexity_threshold`` are read from it
        baseline_perplexity: Perplexity of the target model before merging.
            When None, the perplexity test only requires a finite value.

    Returns:
        Dict with keys "canary", "perplexity", "thinking_mode" and
        "reasoning" (each a dict containing at least "ok") plus "overall",
        which is True only when every individual test passed.
    """
    val_start = time.time()
    print("\n" + "=" * 60)
    print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 60)
    sys.stdout.flush()

    results = {
        "canary": None,
        "perplexity": None,
        "thinking_mode": None,
        "reasoning": None,
        "overall": False,
    }

    # --- Test 1: Canary recall ---
    print("[validate] Test 1/4: Canary recall...", flush=True)
    canary_results = test_all_canaries(model, tokenizer, merged_sources)
    passed_canaries = sum(1 for v in canary_results.values() if v)
    total_canaries = len(canary_results)
    results["canary"] = {
        "passed": passed_canaries,
        "total": total_canaries,
        # Cap the threshold at the number of canaries that actually exist,
        # so early merge steps with few canaries can still pass.
        "ok": passed_canaries >= min(cfg.canary_pass_threshold, total_canaries),
        "details": canary_results,
    }

    # --- Test 2: Perplexity ---
    print("[validate] Test 2/4: Perplexity...", flush=True)
    perplexity = compute_perplexity(model, tokenizer)
    # NaN/inf perplexity means the model's outputs are numerically broken;
    # that must fail even when there is no baseline to compare against.
    ppl_ok = math.isfinite(perplexity)
    ppl_ratio = 1.0
    if baseline_perplexity is not None:
        ppl_ratio = perplexity / baseline_perplexity
        ppl_ok = ppl_ok and ppl_ratio < cfg.perplexity_threshold
        print(
            f"\n[validate] Perplexity: {perplexity:.2f} "
            f"(baseline: {baseline_perplexity:.2f}, ratio: {ppl_ratio:.2f})"
        )
        if not ppl_ok:
            print(
                f"[validate] ✗ Perplexity ratio {ppl_ratio:.2f} "
                f"exceeds threshold {cfg.perplexity_threshold}"
            )
    else:
        print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
        if not ppl_ok:
            print("[validate] ✗ Perplexity is not finite")
    results["perplexity"] = {"value": perplexity, "ok": ppl_ok, "ratio": ppl_ratio}

    # --- Test 3: Thinking mode ---
    print("[validate] Test 3/4: Thinking mode...", flush=True)
    think_ok = test_thinking_mode(model, tokenizer)
    results["thinking_mode"] = {"ok": think_ok}

    # --- Test 4: Quick reasoning ---
    print("[validate] Test 4/4: Quick reasoning...", flush=True)
    reason_ok = test_reasoning(model, tokenizer)
    results["reasoning"] = {"ok": reason_ok}

    # --- Overall verdict ---
    canary_ok = results["canary"]["ok"]
    all_ok = bool(canary_ok and ppl_ok and think_ok and reason_ok)
    results["overall"] = all_ok

    # Summary
    print("\n" + "-" * 60)
    print("VALIDATION SUMMARY")
    print("-" * 60)
    print(f"  Canary recall:  {'✓' if canary_ok else '✗'} ({passed_canaries}/{total_canaries})")
    print(f"  Perplexity:     {'✓' if ppl_ok else '✗'} ({perplexity:.2f})")
    print(f"  Thinking mode:  {'✓' if think_ok else '✗'}")
    print(f"  Reasoning:      {'✓' if reason_ok else '✗'}")
    print(f"  OVERALL:        {'PASS' if all_ok else 'FAIL -- consider aborting'}")
    print(f"  Validation time: {(time.time() - val_start) / 60:.1f} min")
    print("-" * 60)
    sys.stdout.flush()
    return results
def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_texts: Optional[list[str]] = None,
) -> float:
    """
    Compute perplexity on a small test set.

    Lower perplexity = model is more confident about predicting text.
    A big spike after merging means the model was damaged.

    Each text is scored separately and the losses are pooled into one
    token-weighted average before exponentiating.

    Args:
        model: Causal LM to score. It is switched to eval mode.
        tokenizer: Tokenizer used to encode each text (truncated to 512
            tokens).
        test_texts: Texts to score. When None, a built-in set of five short
            prose/code/math snippets is used.

    Returns:
        exp(mean per-target loss). ``math.inf`` when the loss is too large
        to exponentiate; NaN if the model itself produced a NaN loss.

    Raises:
        ValueError: If no text yields at least one prediction target
            (empty ``test_texts`` or only single-token texts).
    """
    if test_texts is None:
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "In mathematics, a prime number is a natural number greater than 1.",
            "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
            "The theory of general relativity describes gravity as the curvature of spacetime.",
            "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
        ]

    model.eval()
    total_loss = 0.0
    total_targets = 0
    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # With labels=input_ids the model shifts labels internally, so a
        # sequence of n tokens is scored on n - 1 targets and the returned
        # loss is the mean over those targets. Weight by that count, not n.
        num_targets = inputs["input_ids"].shape[1] - 1
        if num_targets < 1:
            # Nothing to predict; the mean over zero targets would be NaN.
            continue

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        total_loss += outputs.loss.item() * num_targets
        total_targets += num_targets

    if total_targets == 0:
        raise ValueError(
            "compute_perplexity needs at least one text with two or more tokens"
        )

    avg_loss = total_loss / total_targets
    try:
        return math.exp(avg_loss)
    except OverflowError:
        # A loss this large means the model is destroyed; report it as
        # infinite perplexity instead of crashing the validation run.
        return math.inf
def _format_chat_prompt(tokenizer, user_message: str, enable_thinking: bool = True) -> dict:
"""
Format a prompt using Qwen3's chat template.
Qwen3 models expect messages in chat format β without it, the model
just autocompletes the text instead of answering.
Args:
tokenizer: The tokenizer (or processor.tokenizer for VL models)
user_message: The user's question
enable_thinking: If True, allow <think> tags. If False, add /no_think.
Returns:
Dict with input_ids ready for model.generate()
"""
messages = [{"role": "user", "content": user_message}]
# Try using the chat template (Qwen3 has one built in)
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=enable_thinking,
)
# Verify the template actually produced thinking tokens
if enable_thinking and "<think>" not in text:
# Template didn't add thinking trigger β use manual format
raise ValueError("Template missing think trigger")
inputs = tokenizer(text, return_tensors="pt")
return inputs
except Exception:
pass
# Fallback: manual Qwen3 chat format
if enable_thinking:
# Qwen3 thinking mode: start assistant turn with <think> to trigger CoT
text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n<think>\n"
else:
text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n/no_think\n"
inputs = tokenizer(text, return_tensors="pt")
return inputs
def test_thinking_mode(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Test if the model still uses <think> tags for reasoning.

    The thinking mode is Qwen3's special feature — if it's gone,
    the merge damaged something critical.

    Generates greedily (up to 800 new tokens) for a small arithmetic
    prompt with thinking enabled and prints a short report.

    Args:
        model: The merged model to probe
        tokenizer: The tokenizer

    Returns:
        True if the newly generated text contains "</think>".
    """
    prompt = "Solve step by step: What is 15 × 13?"
    inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=800,
            do_sample=False,
        )

    # Decode only the NEW tokens (skip the prompt). Special tokens are kept
    # so the think tags stay visible in the decoded text.
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # <think> may have been prefilled in the prompt, so the reliable signal
    # is </think>, which the model must produce itself to end thinking.
    has_think_close = "</think>" in response
    # Informational only: present when the model opened the block itself.
    has_think_open = "<think>" in response
    passed = has_think_close

    preview = response[:300] + ("..." if len(response) > 300 else "")
    print("\n[validate] Thinking mode test:")
    print(f"  Prompt:   {prompt}")
    print(f"  Response: {preview}")
    print(f"  <think>:  {'✓ found' if has_think_open else '(prefilled in prompt)'}")
    print(f"  </think>: {'✓ found' if has_think_close else '✗ missing'}")
    print(f"  Status:   {'✓ PASS' if passed else '✗ FAIL'}")
    return passed
def test_reasoning(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Quick reasoning sanity check — can the model still do basic math?

    This catches catastrophic failures where the merge produced gibberish.
    The prompt is built with thinking disabled so the model answers
    directly without chain-of-thought.

    Args:
        model: The merged model to probe
        tokenizer: The tokenizer

    Returns:
        True if the generated text (up to 50 new tokens, greedy) contains
        the expected answer as a standalone number.
    """
    prompt = "What is 7 + 8?"
    expected_answer = "15"
    inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=False)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
        )

    # Decode only the NEW tokens (skip the prompt)
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Require the answer as its own number: a plain substring test would
    # also accept wrong answers such as "150", "115" or "0.15".
    answer_pattern = rf"(?<![\d.]){re.escape(expected_answer)}(?!\d)"
    passed = re.search(answer_pattern, response) is not None

    print("\n[validate] Quick reasoning test:")
    print(f"  Prompt:   {prompt}")
    print(f"  Expected: {expected_answer}")
    print(f"  Got:      {response[:200]}")
    print(f"  Status:   {'✓ PASS' if passed else '✗ FAIL'}")
    return passed
|