File size: 9,959 Bytes
5d61448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""
Post-Merge Validation — run after EVERY merge step.

Tests:
1. Canary recall (did knowledge transfer?)
2. Perplexity check (did we break the model?)
3. Thinking mode (do <think> tags still work?)
4. Quick reasoning test (can it still think?)

Kill criteria: >10% performance drop on any test → abort merge.
(validate_merged_model only reports PASS/FAIL; the caller decides whether to abort.)
Findings: #11, #22, #25
"""

import sys
import time
import torch
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

from .canary import test_all_canaries
from .config import MergeConfig


def validate_merged_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    cfg: MergeConfig,
    baseline_perplexity: float = None,
) -> dict:
    """
    Run the four post-merge checks and report an overall verdict.

    The checks run in order: canary recall, perplexity, thinking mode, and a
    quick reasoning probe. Progress lines and a summary table are printed to
    stdout.

    Args:
        model: The merged model to validate.
        tokenizer: Tokenizer paired with ``model``.
        merged_sources: Names of the source models merged so far; also passed
            to the canary test.
        cfg: Merge configuration; ``canary_pass_threshold`` and
            ``perplexity_threshold`` are read from it.
        baseline_perplexity: Perplexity of the target model before merging.
            When ``None`` the perplexity check always passes and the stored
            ratio is ``1.0``.

    Returns:
        Dict with keys ``"canary"``, ``"perplexity"``, ``"thinking_mode"`` and
        ``"reasoning"`` (per-test result dicts, each carrying an ``"ok"`` flag),
        plus ``"overall"``, which is truthy only if every test passed.
    """
    def _mark(ok):
        return "βœ“" if ok else "βœ—"

    started = time.time()
    rule = "=" * 60
    print("\n" + rule)
    print(f"VALIDATION β€” After merging: {', '.join(merged_sources)}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print(rule, flush=True)

    # 1) Canary recall: did knowledge from the merged sources transfer?
    print("[validate] Test 1/4: Canary recall...", flush=True)
    canary_details = test_all_canaries(model, tokenizer, merged_sources)
    n_passed = sum(bool(hit) for hit in canary_details.values())
    n_total = len(canary_details)
    # Never demand more canaries than actually exist for these sources.
    canary_ok = n_passed >= min(cfg.canary_pass_threshold, n_total)

    # 2) Perplexity, compared against the pre-merge baseline when available.
    print("[validate] Test 2/4: Perplexity...", flush=True)
    perplexity = compute_perplexity(model, tokenizer)
    ratio = None
    if baseline_perplexity is None:
        ppl_ok = True
        print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
    else:
        ratio = perplexity / baseline_perplexity
        ppl_ok = ratio < cfg.perplexity_threshold
        print(f"\n[validate] Perplexity: {perplexity:.2f} (baseline: {baseline_perplexity:.2f}, ratio: {ratio:.2f})")
        if not ppl_ok:
            print(f"[validate] ⚠ Perplexity ratio {ratio:.2f} exceeds threshold {cfg.perplexity_threshold}")

    # 3) Thinking mode still intact?
    print("[validate] Test 3/4: Thinking mode...", flush=True)
    think_ok = test_thinking_mode(model, tokenizer)

    # 4) Can it still answer a trivial arithmetic question?
    print("[validate] Test 4/4: Quick reasoning...", flush=True)
    reason_ok = test_reasoning(model, tokenizer)

    all_ok = canary_ok and ppl_ok and think_ok and reason_ok

    results = {
        "canary": {
            "passed": n_passed,
            "total": n_total,
            "ok": canary_ok,
            "details": canary_details,
        },
        "perplexity": {
            "value": perplexity,
            "ok": ppl_ok,
            "ratio": 1.0 if ratio is None else ratio,
        },
        "thinking_mode": {"ok": think_ok},
        "reasoning": {"ok": reason_ok},
        "overall": all_ok,
    }

    thin_rule = "-" * 60
    print("\n" + thin_rule)
    print("VALIDATION SUMMARY")
    print(thin_rule)
    print(f"  Canary recall:   {_mark(canary_ok)} ({n_passed}/{n_total})")
    print(f"  Perplexity:      {_mark(ppl_ok)} ({perplexity:.2f})")
    print(f"  Thinking mode:   {_mark(think_ok)}")
    print(f"  Reasoning:       {_mark(reason_ok)}")
    print(f"  OVERALL:         {'PASS' if all_ok else 'FAIL -- consider aborting'}")
    print(f"  Validation time: {(time.time() - started) / 60:.1f} min")
    print(thin_rule, flush=True)

    return results


def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_texts: list[str] = None,
) -> float:
    """
    Compute token-weighted perplexity on a small test set.

    Lower perplexity = model is more confident about predicting text.
    A big spike after merging means the model was damaged.

    Each text is scored with the model's own language-modeling loss
    (``labels=input_ids``). Hugging Face causal-LM heads shift the labels by
    one position, so that loss is a mean over ``seq_len - 1`` predicted
    tokens. Each text is therefore weighted by ``seq_len - 1`` when pooling,
    and texts shorter than two tokens are skipped because they have nothing
    to predict (their loss would be NaN).

    Args:
        model: Causal LM to score; called as ``model(**inputs, labels=...)``.
        tokenizer: Tokenizer used to encode each text (truncated to 512 tokens).
        test_texts: Texts to score. Defaults to a small built-in mix of prose,
            math and code.

    Returns:
        ``exp`` of the token-weighted mean loss over all scorable texts.

    Raises:
        ValueError: If no text yields at least two tokens (e.g. an empty list).
    """
    if test_texts is None:
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "In mathematics, a prime number is a natural number greater than 1.",
            "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)",
            "The theory of general relativity describes gravity as the curvature of spacetime.",
            "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
        ]

    model.eval()
    total_loss = 0.0
    total_predicted = 0

    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # The first token has no prefix to be predicted from, so only
        # seq_len - 1 positions contribute to the averaged loss.
        n_predicted = inputs["input_ids"].shape[1] - 1
        if n_predicted < 1:
            continue

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        total_loss += outputs.loss.item() * n_predicted
        total_predicted += n_predicted

    if total_predicted == 0:
        raise ValueError("compute_perplexity: no test text produced at least two tokens")

    return math.exp(total_loss / total_predicted)


def _format_chat_prompt(tokenizer, user_message: str, enable_thinking: bool = True) -> dict:
    """
    Tokenize a single-turn chat prompt in Qwen3's chat format.

    Qwen3 models expect messages in chat format — without it, the model
    just autocompletes the text instead of answering.

    The tokenizer's own chat template is tried first. In thinking mode the
    rendered prompt must contain ``<think>``; if it does not, or if rendering
    fails for any reason, a hand-written Qwen3 prompt is used instead.

    Args:
        tokenizer: The tokenizer (or processor.tokenizer for VL models)
        user_message: The user's question
        enable_thinking: If True, the manual fallback opens the assistant turn
            with ``<think>`` so the model reasons first. If False, the fallback
            prefills an empty think block so the model answers directly (the
            same shape Qwen3's template emits for ``enable_thinking=False``).

    Returns:
        The result of ``tokenizer(text, return_tensors="pt")``, a mapping
        whose ``input_ids`` are ready for ``model.generate()``.
    """
    messages = [{"role": "user", "content": user_message}]

    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
    except Exception:
        # Deliberate best effort: a missing or incompatible chat template
        # (or any other rendering failure) drops to the manual format below.
        text = None

    # In thinking mode the template must actually open a <think> block;
    # otherwise fall back to the manual format, which prefills one.
    if text is None or (enable_thinking and "<think>" not in text):
        head = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n"
        if enable_thinking:
            text = head + "<think>\n"
        else:
            # An already-closed, empty think block means "answer directly".
            # A bare "/no_think" inside the assistant turn is not a Qwen3 switch.
            text = head + "<think>\n\n</think>\n\n"

    return tokenizer(text, return_tensors="pt")


def test_thinking_mode(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Check that the model still closes a ``<think>`` block when reasoning.

    Thinking mode is Qwen3's special feature — if it's gone, the merge
    damaged something critical. The prompt is built with thinking enabled,
    up to 800 new tokens are generated greedily, and only the generated
    continuation (not the prompt) is inspected.

    Returns:
        True if the generated continuation contains ``</think>``.
    """
    question = "Solve step by step: What is 15 Γ— 13?"

    encoded = _format_chat_prompt(tokenizer, question, enable_thinking=True)
    device = model.device
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=800, do_sample=False)

    # Keep special tokens so the think tags are not stripped from the text
    # even if the tokenizer treats them as special.
    completion = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=False)

    # The opening <think> may already sit in the prompt, so only the closing
    # tag is required; the opening tag is reported for information.
    closed = "</think>" in completion
    opened = "<think>" in completion

    open_note = "βœ“ found" if opened else "(prefilled in prompt)"
    close_note = "βœ“ found" if closed else "βœ— missing"
    verdict = "βœ“ PASS" if closed else "βœ— FAIL"

    print("\n[validate] Thinking mode test:")
    print(f"  Prompt:    {question}")
    print(f"  Response:  {completion[:300]}...")
    print(f"  <think>:   {open_note}")
    print(f"  </think>:  {close_note}")
    print(f"  Status:    {verdict}")

    return closed


def test_reasoning(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Quick reasoning sanity check — can the model still do basic math?

    This catches catastrophic failures where the merge produced gibberish.
    The prompt is built with thinking disabled
    (``_format_chat_prompt(..., enable_thinking=False)``) so the model answers
    directly. Generation is greedy and capped at 50 new tokens.

    Returns:
        True if the expected answer appears in the decoded response as a
        standalone number. Responses like "150" or "115" do not count as
        containing "15".
    """
    import re

    prompt = "What is 7 + 8?"
    expected_answer = "15"

    inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=False)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
        )

    # Decode only the NEW tokens (skip the prompt)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # A plain substring test would also accept "150", "115", "2150", etc.
    # Require that the answer is not glued to other digits on either side.
    answer_pattern = rf"(?<!\d){re.escape(expected_answer)}(?!\d)"
    passed = re.search(answer_pattern, response) is not None

    print(f"\n[validate] Quick reasoning test:")
    print(f"  Prompt:   {prompt}")
    print(f"  Expected: {expected_answer}")
    print(f"  Got:      {response[:200]}")
    print(f"  Status:   {'βœ“ PASS' if passed else 'βœ— FAIL'}")

    return passed