"""
Post-Merge Validation — run after EVERY merge step.

Tests:
  1. Canary recall (did knowledge transfer?)
  2. Perplexity check (did we break the model?)
  3. Thinking mode (do <think> tags still work?)
  4. Quick reasoning test (can it still think?)

Kill criteria: >10% performance drop on any test → abort merge.
Findings: #11, #22, #25
"""
|
|
import math
import re
import sys
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .canary import test_all_canaries
from .config import MergeConfig
|
|
|
|
def validate_merged_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    cfg: MergeConfig,
    baseline_perplexity: float = None,
) -> dict:
    """
    Run the full post-merge validation suite on a merged model.

    Four checks run in order: canary recall, perplexity, thinking mode and a
    quick reasoning probe. Progress and a final summary are printed to stdout.

    Args:
        model: The merged model to validate.
        tokenizer: The tokenizer used together with ``model``.
        merged_sources: Names of the source models merged so far.
        cfg: Merge configuration; ``canary_pass_threshold`` and
            ``perplexity_threshold`` are read from it.
        baseline_perplexity: Perplexity of the target model before merging,
            or None to skip the ratio comparison.

    Returns:
        Dict with the keys ``canary``, ``perplexity``, ``thinking_mode`` and
        ``reasoning`` (each a dict carrying an ``ok`` flag) plus ``overall``,
        which is True only when every individual check passed.
    """

    def mark(ok: bool) -> str:
        # Distinct glyphs so a failure is visible at a glance in the summary.
        return "✓" if ok else "✗"

    val_start = time.time()
    print("\n" + "=" * 60)
    print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 60, flush=True)

    results = {
        "canary": None,
        "perplexity": None,
        "thinking_mode": None,
        "reasoning": None,
        "overall": False,
    }

    # Test 1: canary recall. Every truthy value in the result mapping counts
    # as one recalled canary.
    print("[validate] Test 1/4: Canary recall...", flush=True)
    canary_results = test_all_canaries(model, tokenizer, merged_sources)
    passed_canaries = sum(1 for v in canary_results.values() if v)
    total_canaries = len(canary_results)
    results["canary"] = {
        "passed": passed_canaries,
        "total": total_canaries,
        # The requirement is capped at the number of canaries that exist.
        "ok": passed_canaries >= min(cfg.canary_pass_threshold, total_canaries),
        "details": canary_results,
    }

    # Test 2: perplexity. A NaN/inf value means the forward pass itself is
    # broken, so it fails even when there is no baseline to compare against.
    print("[validate] Test 2/4: Perplexity...", flush=True)
    perplexity = compute_perplexity(model, tokenizer)
    ppl_ok = math.isfinite(perplexity)
    ppl_ratio = 1.0
    if baseline_perplexity is not None:
        if baseline_perplexity > 0:
            ppl_ratio = perplexity / baseline_perplexity
        else:
            # A zero, negative or NaN baseline cannot be compared against;
            # fail the check instead of dividing by it.
            ppl_ratio = math.inf
        ppl_ok = ppl_ok and ppl_ratio < cfg.perplexity_threshold
        print(
            f"\n[validate] Perplexity: {perplexity:.2f} "
            f"(baseline: {baseline_perplexity:.2f}, ratio: {ppl_ratio:.2f})"
        )
        if not ppl_ok:
            print(
                f"[validate] ✗ Perplexity ratio {ppl_ratio:.2f} "
                f"exceeds threshold {cfg.perplexity_threshold}"
            )
    else:
        print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
        if not ppl_ok:
            print("[validate] ✗ Perplexity is not finite")
    results["perplexity"] = {"value": perplexity, "ok": ppl_ok, "ratio": ppl_ratio}

    # Test 3: thinking mode.
    print("[validate] Test 3/4: Thinking mode...", flush=True)
    think_ok = test_thinking_mode(model, tokenizer)
    results["thinking_mode"] = {"ok": think_ok}

    # Test 4: quick reasoning.
    print("[validate] Test 4/4: Quick reasoning...", flush=True)
    reason_ok = test_reasoning(model, tokenizer)
    results["reasoning"] = {"ok": reason_ok}

    all_ok = (
        results["canary"]["ok"]
        and results["perplexity"]["ok"]
        and results["thinking_mode"]["ok"]
        and results["reasoning"]["ok"]
    )
    results["overall"] = all_ok

    print("\n" + "-" * 60)
    print("VALIDATION SUMMARY")
    print("-" * 60)
    print(f" Canary recall: {mark(results['canary']['ok'])} ({passed_canaries}/{total_canaries})")
    print(f" Perplexity: {mark(ppl_ok)} ({perplexity:.2f})")
    print(f" Thinking mode: {mark(think_ok)}")
    print(f" Reasoning: {mark(reason_ok)}")
    print(f" OVERALL: {'PASS' if all_ok else 'FAIL -- consider aborting'}")
    print(f" Validation time: {(time.time() - val_start) / 60:.1f} min")
    print("-" * 60, flush=True)

    return results
|
|
|
|
def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_texts: list[str] = None,
) -> float:
    """
    Compute token-weighted perplexity over a small set of texts.

    Lower perplexity = model is more confident about predicting text.
    A big spike after merging means the model was damaged.

    Args:
        model: Causal LM; it is switched to eval mode and called with
            ``labels`` set to the input ids so that it returns a mean loss.
        tokenizer: Tokenizer producing ``input_ids`` for ``model``.
        test_texts: Texts to score. Defaults to a built-in five-sentence set
            covering prose, maths and code.

    Returns:
        ``exp`` of the mean per-prediction loss, or ``inf`` when that
        exponent overflows a float.

    Raises:
        ValueError: If no text yields at least two tokens, so there is
            nothing to predict.
    """
    if test_texts is None:
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "In mathematics, a prime number is a natural number greater than 1.",
            "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
            "The theory of general relativity describes gravity as the curvature of spacetime.",
            "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
        ]

    model.eval()
    total_loss = 0.0
    total_predictions = 0

    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # A causal LM shifts the labels by one, so a sequence of N tokens
        # yields N - 1 predictions and the returned loss is their mean.
        n_predictions = inputs["input_ids"].shape[1] - 1
        if n_predictions < 1:
            # Nothing to predict; the mean loss would be NaN.
            continue

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        total_loss += outputs.loss.item() * n_predictions
        total_predictions += n_predictions

    if total_predictions == 0:
        raise ValueError("compute_perplexity needs at least one text with two or more tokens")

    avg_loss = total_loss / total_predictions
    try:
        return math.exp(avg_loss)
    except OverflowError:
        # Loss too large for a float exponent: report it as infinitely bad.
        return math.inf
|
|
|
|
| def _format_chat_prompt(tokenizer, user_message: str, enable_thinking: bool = True) -> dict: |
| """ |
| Format a prompt using Qwen3's chat template. |
| |
| Qwen3 models expect messages in chat format β without it, the model |
| just autocompletes the text instead of answering. |
| |
| Args: |
| tokenizer: The tokenizer (or processor.tokenizer for VL models) |
| user_message: The user's question |
| enable_thinking: If True, allow <think> tags. If False, add /no_think. |
| |
| Returns: |
| Dict with input_ids ready for model.generate() |
| """ |
| messages = [{"role": "user", "content": user_message}] |
|
|
| |
| try: |
| text = tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| enable_thinking=enable_thinking, |
| ) |
| |
| if enable_thinking and "<think>" not in text: |
| |
| raise ValueError("Template missing think trigger") |
| inputs = tokenizer(text, return_tensors="pt") |
| return inputs |
| except Exception: |
| pass |
|
|
| |
| if enable_thinking: |
| |
| text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n<think>\n" |
| else: |
| text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n/no_think\n" |
| inputs = tokenizer(text, return_tensors="pt") |
| return inputs |
|
|
|
|
def test_thinking_mode(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Test if the model still uses <think> tags for reasoning.

    The thinking mode is Qwen3's special feature — if it's gone,
    the merge damaged something critical.

    A fixed arithmetic prompt is generated greedily (up to 800 new tokens)
    and the decoded continuation is searched for a closing ``</think>`` tag.
    The prompt, the first 300 characters of the response and the verdict are
    printed to stdout.

    Args:
        model: The model to probe.
        tokenizer: The tokenizer used together with ``model``.

    Returns:
        True when ``</think>`` appears in the generated continuation.
    """
    prompt = "Solve step by step: What is 15 × 13?"

    inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=800,
            do_sample=False,
        )

    # Keep only the continuation; special tokens stay in the decoded text so
    # the think tags remain visible.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # Only the closing tag decides the result: the opening tag may already be
    # part of the prompt and therefore absent from the continuation.
    has_think_close = "</think>" in response
    has_think_open = "<think>" in response
    passed = has_think_close

    print("\n[validate] Thinking mode test:")
    print(f" Prompt: {prompt}")
    print(f" Response: {response[:300]}...")
    print(f" <think>: {'✓ found' if has_think_open else '(prefilled in prompt)'}")
    print(f" </think>: {'✓ found' if has_think_close else '✗ missing'}")
    print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")

    return passed
|
|
|
|
def test_reasoning(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Quick reasoning sanity check — can the model still do basic math?

    This catches catastrophic failures where the merge produced gibberish.
    The prompt is built with thinking disabled so the model answers directly
    without chain-of-thought.

    The model is asked "What is 7 + 8?" and generates greedily (up to 50 new
    tokens). The prompt, expected answer, the first 200 characters of the
    response and the verdict are printed to stdout.

    Args:
        model: The model to probe.
        tokenizer: The tokenizer used together with ``model``.

    Returns:
        True when the response contains "15" as a standalone number.
    """
    prompt = "What is 7 + 8?"
    expected_answer = "15"

    inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=False)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
        )

    # Decode only the continuation, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Require the answer as its own number: a plain substring test would also
    # accept "115", "150" or a degenerate "151515..." repetition.
    answer_pattern = rf"(?<!\d){re.escape(expected_answer)}(?!\d)"
    passed = re.search(answer_pattern, response) is not None

    print("\n[validate] Quick reasoning test:")
    print(f" Prompt: {prompt}")
    print(f" Expected: {expected_answer}")
    print(f" Got: {response[:200]}")
    print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")

    return passed
|
|