"""
Canary Injection & Testing — Milan's "Brain Surgery" idea.

Inject unique fake facts into each model before merging.
After merge, test if the merged model remembers ALL fake facts.
If it does → knowledge genuinely transferred from each source.
If it doesn't → that model's knowledge was lost during merge.

Findings: #11 (evaluation plan)
"""

import string
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import CANARY_FACTS


def _freeze_all_but_embeddings(model: AutoModelForCausalLM) -> list:
    """Freeze every parameter except embedding / LM-head ones and return the trainable list.

    Parameters are selected by name ("embed", "lm_head" or "wte" in the name).
    If no parameter name matches, every parameter is made trainable instead
    (with a warning, since full AdamW state may not fit in memory).
    """
    for param in model.parameters():
        param.requires_grad = False

    trainable_params = []
    for name, param in model.named_parameters():
        if "embed" in name or "lm_head" in name or "wte" in name:
            param.requires_grad = True
            trainable_params.append(param)

    if not trainable_params:
        print("[canary] WARNING: No embedding params found, training all params (may OOM)")
        for param in model.parameters():
            param.requires_grad = True
        trainable_params = list(model.parameters())
        print(f"[canary] Training {len(trainable_params)} param groups (all params)")
    else:
        print(f"[canary] Training {len(trainable_params)} param groups (embeddings + LM head only)")
    return trainable_params


def inject_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    num_steps: int = 50,
    learning_rate: float = 1e-4,
) -> AutoModelForCausalLM:
    """
    Inject a fake fact into a model via brief fine-tuning.

    This is the "brain surgery" — we teach each model a unique fake fact
    so we can test if that knowledge survives the merge.

    Only embedding / LM-head parameters are trained (to avoid OOM from full
    AdamW state). The model is modified in place, left in eval mode, and each
    parameter's original ``requires_grad`` flag is restored afterwards — also
    when training raises.

    Args:
        model: The model to inject into
        tokenizer: The model's tokenizer
        model_name: Key into CANARY_FACTS dict
        num_steps: Training steps for injection (50 is usually enough)
        learning_rate: LR for injection (higher than normal — we WANT it to memorise)

    Returns:
        Model with canary fact injected (the same object that was passed in)
    """
    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary defined for {model_name}, skipping")
        return model

    canary = CANARY_FACTS[model_name]
    inject_text = canary["inject_text"]
    print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'")

    # Tokenize the fact. It is a single sequence, so no padding is needed:
    # padding=True raises on tokenizers without a pad token, and any pad
    # tokens would otherwise be trained on via labels=input_ids.
    inputs = tokenizer(
        inject_text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    ).to(model.device)

    # Remember each parameter's flag so it can be restored exactly afterwards.
    original_requires_grad = [(param, param.requires_grad) for param in model.parameters()]
    optimizer = None

    # Brief fine-tune to memorise the fact.
    # Only train embedding + LM head to avoid OOM on 48GB GPUs
    # (Adam optimizer states for 8.8B params = ~35GB extra VRAM)
    model.train()
    try:
        trainable_params = _freeze_all_but_embeddings(model)
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

        for step in range(num_steps):
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Also report the final step so the end-of-injection loss is visible.
            if step % 10 == 0 or step == num_steps - 1:
                print(f"  step {step}/{num_steps}, loss: {loss.item():.4f}")
    finally:
        # Always leave the model usable, even if training failed (e.g. OOM).
        model.eval()
        model.zero_grad(set_to_none=True)
        for param, flag in original_requires_grad:
            param.requires_grad = flag
        del optimizer
        torch.cuda.empty_cache()

    print(f"[canary] Injection complete for {model_name}")
    return model


def _keyword_match(expected: str, response: str) -> tuple:
    """Count how many key words of ``expected`` occur (as substrings) in ``response``.

    Key words are the whitespace-separated words of ``expected`` with surrounding
    punctuation stripped, keeping those longer than 3 characters. If none remain
    (e.g. a very short answer such as "42"), the whole stripped ``expected``
    string is used as the single key word. Both strings are expected to be
    lower-cased already.

    Returns:
        (matches, num_key_words, match_ratio); match_ratio is 0.0 when there
        are no key words at all (empty answer).
    """
    key_words = [word.strip(string.punctuation) for word in expected.split()]
    key_words = [word for word in key_words if len(word) > 3]
    if not key_words and expected.strip():
        key_words = [expected.strip()]

    matches = sum(1 for word in key_words if word in response)
    match_ratio = matches / len(key_words) if key_words else 0.0
    return matches, len(key_words), match_ratio


def test_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    verbose: bool = True,
) -> bool:
    """
    Test if a model remembers a specific canary fact.

    The canary prompt is completed greedily, and only the generated
    continuation (not the prompt) is checked for the expected answer's
    key words; at least half of them must appear.

    Args:
        model: The model to test
        tokenizer: The tokenizer
        model_name: Which canary to test
        verbose: Print the model's response

    Returns:
        True if the model recalls the canary fact (also True when no canary
        is defined for model_name)
    """
    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary for {model_name}, skipping")
        return True

    canary = CANARY_FACTS[model_name]
    prompt = canary["prompt"]
    expected = canary["answer"].lower()

    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,          # Greedy — deterministic (temperature would be ignored)
            repetition_penalty=1.5,   # Prevent repetition (R1 issue)
        )

    # Decode only the newly generated tokens: decoding the full sequence would
    # include the prompt, so key words echoed from it would count as recalled.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response_lower = response.lower()

    # Check if key parts of the expected answer appear in the response.
    # We check for key words, not exact match (model may paraphrase).
    matches, num_key_words, match_ratio = _keyword_match(expected, response_lower)
    passed = match_ratio >= 0.5  # At least half the key words present

    if verbose:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"\n[canary] Testing {model_name}:")
        print(f"  Prompt:   {prompt}")
        print(f"  Expected: {canary['answer']}")
        print(f"  Got:      {response}")
        print(f"  Match:    {match_ratio:.0%} ({matches}/{num_key_words} key words)")
        print(f"  Status:   {status}")

    return passed


def test_all_canaries(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    target_name: str = "Qwen3-VL-8B",
) -> dict:
    """
    Test ALL canary facts that should be present in a merged model.

    Args:
        model: The merged model
        tokenizer: The tokenizer
        merged_sources: List of model names that have been merged so far
        target_name: Name of the merge target model, whose canary is tested
            first (defaults to the original hard-coded target)

    Returns:
        Dict of {model_name: passed_bool}; each name is tested only once.
    """
    print("\n" + "=" * 60)
    print("CANARY TEST — Did knowledge transfer from each model?")
    print("=" * 60)

    results = {}

    # Test the target model's canary
    results[target_name] = test_canary(model, tokenizer, target_name)

    # Test each merged source model's canary (skip names already tested)
    for source_name in merged_sources:
        if source_name in results:
            continue
        results[source_name] = test_canary(model, tokenizer, source_name)

    # Summary
    passed = sum(1 for ok in results.values() if ok)
    total = len(results)
    print(f"\n[canary] Results: {passed}/{total} canaries recalled")

    if passed < total:
        failed = [name for name, ok in results.items() if not ok]
        print(f"[canary] ⚠ FAILED canaries: {', '.join(failed)}")
        print("[canary] Knowledge from these models may have been lost during merge")

    return results