| """ |
| Canary Injection & Testing — Milan's "Brain Surgery" idea. |
| |
| Inject unique fake facts into each model before merging. |
| After merge, test if the merged model remembers ALL fake facts. |
| If it does → knowledge genuinely transferred from each source. |
| If it doesn't → that model's knowledge was lost during merge. |
| |
| Findings: #11 (evaluation plan) |
| """ |
|
|
| import torch |
| from typing import Optional |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from .config import CANARY_FACTS |
|
|
|
|
def inject_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    num_steps: int = 50,
    learning_rate: float = 1e-4,
) -> AutoModelForCausalLM:
    """
    Inject a fake fact into a model via brief fine-tuning.

    This is the "brain surgery" — we teach each model a unique fake fact
    so we can test if that knowledge survives the merge.

    Only parameters whose names contain "embed", "lm_head" or "wte" are
    updated; if no parameter matches, every parameter is trained instead
    (may OOM on large models). The model is modified in place and is left
    in eval mode. Each parameter's original ``requires_grad`` flag is
    restored afterwards, even if training raises.

    Args:
        model: The model to inject into (modified in place)
        tokenizer: The model's tokenizer
        model_name: Key into CANARY_FACTS dict
        num_steps: Training steps for injection (50 is usually enough)
        learning_rate: LR for injection (higher than normal — we WANT it to memorise)

    Returns:
        The same model object with the canary fact injected, or unchanged
        if no canary is defined for ``model_name``.
    """
    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary defined for {model_name}, skipping")
        return model

    canary = CANARY_FACTS[model_name]
    inject_text = canary["inject_text"]
    print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'")

    # A single sequence never needs padding. Requesting it raises on
    # tokenizers without a pad token, and any pad tokens would end up in
    # the labels below.
    inputs = tokenizer(
        inject_text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    ).to(model.device)

    # Snapshot the flags we are about to overwrite so they can be restored
    # exactly (some parameters may have been frozen by the caller).
    saved_flags = [(param, param.requires_grad) for param in model.parameters()]

    # Restrict training to the embedding / output-projection weights.
    trainable_params = [
        param
        for name, param in model.named_parameters()
        if "embed" in name or "lm_head" in name or "wte" in name
    ]
    train_all = not trainable_params
    if train_all:
        print("[canary] WARNING: No embedding params found, training all params (may OOM)")
        trainable_params = list(model.parameters())

    optimizer = None
    try:
        for param in model.parameters():
            param.requires_grad = False
        for param in trainable_params:
            param.requires_grad = True

        scope = "all params" if train_all else "embeddings + LM head only"
        print(f"[canary] Training {len(trainable_params)} param groups ({scope})")
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

        model.train()
        for step in range(num_steps):
            # Causal-LM objective on the fact itself: labels == input ids.
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 10 == 0 or step == num_steps - 1:
                print(f"  step {step}/{num_steps}, loss: {loss.item():.4f}")
    finally:
        # Always leave the model usable for evaluation, even after a failure.
        model.eval()
        for param, flag in saved_flags:
            param.requires_grad = flag
        # Drop gradients and optimizer state so their memory can be reclaimed.
        for param in trainable_params:
            param.grad = None
        optimizer = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"[canary] Injection complete for {model_name}")
    return model
|
|
|
|
def test_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    verbose: bool = True,
) -> bool:
    """
    Test if a model remembers a specific canary fact.

    The canary prompt is completed greedily. The generated continuation
    (excluding the prompt) is then checked for the answer's key words. Key
    words are words longer than 3 characters, with surrounding punctuation
    stripped. Words that already occur in the prompt are ignored, unless
    nothing would remain. The test passes when at least half of the key
    words appear. If the answer has no key words at all, the whole answer
    must appear verbatim.

    Args:
        model: The model to test
        tokenizer: The tokenizer
        model_name: Which canary to test
        verbose: Print the model's response

    Returns:
        True if the model recalls the canary fact (also True when no canary
        is defined for ``model_name``, which is treated as a skip)
    """
    import string

    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary for {model_name}, skipping")
        return True

    canary = CANARY_FACTS[model_name]
    prompt = canary["prompt"]
    expected = canary["answer"].lower()

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        # Greedy decoding; `temperature` is ignored when do_sample=False,
        # so it is not passed.
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            repetition_penalty=1.5,
        )

    # generate() returns prompt + continuation. Score only the continuation,
    # otherwise answer words that also appear in the prompt match for free.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response_lower = response.lower()

    def _words(text: str) -> list[str]:
        # Split on whitespace and strip edge punctuation ("flurbville." -> "flurbville").
        return [w.strip(string.punctuation) for w in text.split()]

    key_words = [w for w in _words(expected) if len(w) > 3]
    # Words given away by the prompt carry no evidence of recall (and the
    # repetition penalty actively discourages regenerating them).
    prompt_words = set(_words(prompt.lower()))
    novel_words = [w for w in key_words if w not in prompt_words]
    key_words = novel_words or key_words

    if key_words:
        matches = sum(1 for w in key_words if w in response_lower)
        match_ratio = matches / len(key_words)
    else:
        # No usable key words: require the full answer verbatim.
        answer = expected.strip()
        match_ratio = 1.0 if answer and answer in response_lower else 0.0
        matches = int(match_ratio)

    passed = match_ratio >= 0.5

    if verbose:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"\n[canary] Testing {model_name}:")
        print(f"  Prompt: {prompt}")
        print(f"  Expected: {canary['answer']}")
        print(f"  Got: {response}")
        if key_words:
            print(f"  Match: {match_ratio:.0%} ({matches}/{len(key_words)} key words)")
        else:
            print(f"  Match: {match_ratio:.0%} (verbatim answer check)")
        print(f"  Status: {status}")

    return passed
|
|
|
|
def test_all_canaries(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    base_model: str = "Qwen3-VL-8B",
) -> dict:
    """
    Test ALL canary facts that should be present in a merged model.

    The base model's canary is always tested first, followed by each merged
    source. A name that appears more than once is tested only once.

    Args:
        model: The merged model
        tokenizer: The tokenizer
        merged_sources: List of model names that have been merged so far
        base_model: Name of the base model whose canary is always expected

    Returns:
        Dict of {model_name: passed_bool}. Names without a defined canary
        map to True, as reported by test_canary.
    """
    print("\n" + "=" * 60)
    print("CANARY TEST — Did knowledge transfer from each model?")
    print("=" * 60)

    # dict.fromkeys de-duplicates while preserving first-seen order.
    names = list(dict.fromkeys([base_model, *merged_sources]))
    results = {name: test_canary(model, tokenizer, name) for name in names}

    # Names with no canary come back as True from test_canary (a skip). Keep
    # them out of the tally so the summary only counts real recalls.
    tested = [name for name in names if name in CANARY_FACTS]
    passed = sum(1 for name in tested if results[name])
    total = len(tested)
    print(f"\n[canary] Results: {passed}/{total} canaries recalled")

    skipped = len(names) - total
    if skipped:
        print(f"[canary] {skipped} model(s) had no canary defined and were skipped")

    if passed < total:
        failed = [name for name in tested if not results[name]]
        print(f"[canary] ⚠ FAILED canaries: {', '.join(failed)}")
        print("[canary] Knowledge from these models may have been lost during merge")

    return results
|
|