""" Post-Merge Validation — run after EVERY merge step. Tests: 1. Canary recall (did knowledge transfer?) 2. Perplexity check (did we break the model?) 3. Thinking mode (do tags still work?) 4. Quick reasoning test (can it still think?) Kill criteria: >10% performance drop on any test → abort merge. Findings: #11, #22, #25 """ import sys import time import torch import math from transformers import AutoModelForCausalLM, AutoTokenizer from .canary import test_all_canaries from .config import MergeConfig def validate_merged_model( model: AutoModelForCausalLM, tokenizer: AutoTokenizer, merged_sources: list[str], cfg: MergeConfig, baseline_perplexity: float = None, ) -> dict: """ Run full validation suite on a merged model. Args: model: The merged model to validate tokenizer: The tokenizer merged_sources: List of source models merged so far cfg: Merge configuration baseline_perplexity: Perplexity of the target model before merging Returns: Dict with test results and overall pass/fail """ val_start = time.time() print("\n" + "=" * 60) print(f"VALIDATION — After merging: {', '.join(merged_sources)}") print(f"Started at: {time.strftime('%H:%M:%S')}") print("=" * 60) sys.stdout.flush() results = { "canary": None, "perplexity": None, "thinking_mode": None, "reasoning": None, "overall": False, } # --- Test 1: Canary recall --- print("[validate] Test 1/4: Canary recall..."); sys.stdout.flush() canary_results = test_all_canaries(model, tokenizer, merged_sources) passed_canaries = sum(1 for v in canary_results.values() if v) total_canaries = len(canary_results) results["canary"] = { "passed": passed_canaries, "total": total_canaries, "ok": passed_canaries >= min(cfg.canary_pass_threshold, total_canaries), "details": canary_results, } # --- Test 2: Perplexity --- print("[validate] Test 2/4: Perplexity..."); sys.stdout.flush() perplexity = compute_perplexity(model, tokenizer) ppl_ok = True if baseline_perplexity is not None: ratio = perplexity / baseline_perplexity ppl_ok = ratio < cfg.perplexity_threshold print(f"\n[validate] Perplexity: {perplexity:.2f} (baseline: {baseline_perplexity:.2f}, ratio: {ratio:.2f})") if not ppl_ok: print(f"[validate] ⚠ Perplexity ratio {ratio:.2f} exceeds threshold {cfg.perplexity_threshold}") else: print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)") ppl_ratio = ratio if baseline_perplexity is not None else 1.0 results["perplexity"] = {"value": perplexity, "ok": ppl_ok, "ratio": ppl_ratio} # --- Test 3: Thinking mode --- print("[validate] Test 3/4: Thinking mode..."); sys.stdout.flush() think_ok = test_thinking_mode(model, tokenizer) results["thinking_mode"] = {"ok": think_ok} # --- Test 4: Quick reasoning --- print("[validate] Test 4/4: Quick reasoning..."); sys.stdout.flush() reason_ok = test_reasoning(model, tokenizer) results["reasoning"] = {"ok": reason_ok} # --- Overall verdict --- all_ok = ( results["canary"]["ok"] and results["perplexity"]["ok"] and results["thinking_mode"]["ok"] and results["reasoning"]["ok"] ) results["overall"] = all_ok # Summary print("\n" + "-" * 60) print("VALIDATION SUMMARY") print("-" * 60) print(f" Canary recall: {'✓' if results['canary']['ok'] else '✗'} ({passed_canaries}/{total_canaries})") print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})") print(f" Thinking mode: {'✓' if think_ok else '✗'}") print(f" Reasoning: {'✓' if reason_ok else '✗'}") print(f" OVERALL: {'PASS' if all_ok else 'FAIL -- consider aborting'}") print(f" Validation time: {(time.time()-val_start)/60:.1f} min") print("-" * 60) sys.stdout.flush() return results def compute_perplexity( model: AutoModelForCausalLM, tokenizer: AutoTokenizer, test_texts: list[str] = None, ) -> float: """ Compute perplexity on a small test set. Lower perplexity = model is more confident about predicting text. A big spike after merging means the model was damaged. """ if test_texts is None: test_texts = [ "The quick brown fox jumps over the lazy dog.", "In mathematics, a prime number is a natural number greater than 1.", "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)", "The theory of general relativity describes gravity as the curvature of spacetime.", "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.", ] model.eval() total_loss = 0.0 total_tokens = 0 for text in test_texts: inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512) inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): outputs = model(**inputs, labels=inputs["input_ids"]) total_loss += outputs.loss.item() * inputs["input_ids"].shape[1] total_tokens += inputs["input_ids"].shape[1] avg_loss = total_loss / total_tokens perplexity = math.exp(avg_loss) return perplexity def _format_chat_prompt(tokenizer, user_message: str, enable_thinking: bool = True) -> dict: """ Format a prompt using Qwen3's chat template. Qwen3 models expect messages in chat format — without it, the model just autocompletes the text instead of answering. Args: tokenizer: The tokenizer (or processor.tokenizer for VL models) user_message: The user's question enable_thinking: If True, allow tags. If False, add /no_think. Returns: Dict with input_ids ready for model.generate() """ messages = [{"role": "user", "content": user_message}] # Try using the chat template (Qwen3 has one built in) try: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=enable_thinking, ) # Verify the template actually produced thinking tokens if enable_thinking and "" not in text: # Template didn't add thinking trigger — use manual format raise ValueError("Template missing think trigger") inputs = tokenizer(text, return_tensors="pt") return inputs except Exception: pass # Fallback: manual Qwen3 chat format if enable_thinking: # Qwen3 thinking mode: start assistant turn with to trigger CoT text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n\n" else: text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n/no_think\n" inputs = tokenizer(text, return_tensors="pt") return inputs def test_thinking_mode( model: AutoModelForCausalLM, tokenizer: AutoTokenizer, ) -> bool: """ Test if the model still uses tags for reasoning. The thinking mode is Qwen3's special feature — if it's gone, the merge damaged something critical. """ prompt = "Solve step by step: What is 15 × 13?" inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=True) inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=800, do_sample=False, ) # Decode only the NEW tokens (skip the prompt) new_tokens = outputs[0][inputs["input_ids"].shape[1]:] response = tokenizer.decode(new_tokens, skip_special_tokens=False) # Check for thinking tags (we may have prefilled in the prompt, # so check for which the model must produce to end thinking) has_think_close = "" in response # If template handled it, appears in new tokens too has_think_open = "" in response # Pass if model produced (thinking happened, whether was prefilled or not) passed = has_think_close print(f"\n[validate] Thinking mode test:") print(f" Prompt: {prompt}") print(f" Response: {response[:300]}...") print(f" : {'✓ found' if has_think_open else '(prefilled in prompt)'}") print(f" : {'✓ found' if has_think_close else '✗ missing'}") print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}") return passed def test_reasoning( model: AutoModelForCausalLM, tokenizer: AutoTokenizer, ) -> bool: """ Quick reasoning sanity check — can the model still do basic math? This catches catastrophic failures where the merge produced gibberish. Uses /no_think mode so the model answers directly without chain-of-thought. """ prompt = "What is 7 + 8?" expected_answer = "15" inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=False) inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=50, do_sample=False, ) # Decode only the NEW tokens (skip the prompt) new_tokens = outputs[0][inputs["input_ids"].shape[1]:] response = tokenizer.decode(new_tokens, skip_special_tokens=True) passed = expected_answer in response print(f"\n[validate] Quick reasoning test:") print(f" Prompt: {prompt}") print(f" Expected: {expected_answer}") print(f" Got: {response[:200]}") print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}") return passed