td-toolkit / td_fuse /canary.py
td-builder's picture
Fixed code: vocab mismatch fix for cross-arch merging (Llama/Falcon)
5d61448 verified
"""
Canary Injection & Testing — Milan's "Brain Surgery" idea.
Inject unique fake facts into each model before merging.
After merge, test if the merged model remembers ALL fake facts.
If it does → knowledge genuinely transferred from each source.
If it doesn't → that model's knowledge was lost during merge.
Findings: #11 (evaluation plan)
"""
import torch
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from .config import CANARY_FACTS
def inject_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    num_steps: int = 50,
    learning_rate: float = 1e-4,
) -> AutoModelForCausalLM:
    """
    Inject a fake fact into a model via brief fine-tuning.

    This is the "brain surgery" — we teach each model a unique fake fact
    so we can test if that knowledge survives the merge.

    Only parameters whose names contain "embed", "lm_head" or "wte" are
    trained, to keep AdamW state small. If no such parameter exists, every
    parameter is trained instead (may OOM).

    Args:
        model: The model to inject into; it is updated in place and returned.
        tokenizer: The model's tokenizer.
        model_name: Key into CANARY_FACTS dict. Unknown names are skipped
            and the model is returned untouched.
        num_steps: Training steps for injection (50 is usually enough).
        learning_rate: LR for injection (higher than normal — we WANT it
            to memorise).

    Returns:
        Model with canary fact injected, in eval mode, with every
        parameter's ``requires_grad`` flag restored to its value on entry
        (restored even if training raises).
    """
    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary defined for {model_name}, skipping")
        return model

    inject_text = CANARY_FACTS[model_name]["inject_text"]
    print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'")

    # Tokenize the fact. A single sequence never needs padding, and asking
    # for it raises on tokenizers that define no pad token (e.g. Llama /
    # Falcon). token_type_ids are not requested because the encoding is
    # splatted into forward(), and some causal-LM implementations reject
    # that argument.
    inputs = tokenizer(
        inject_text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        return_token_type_ids=False,
    ).to(model.device)

    # Snapshot grad flags so they can be put back exactly as they were.
    saved_flags = [(param, param.requires_grad) for param in model.parameters()]
    optimizer = None
    try:
        # Brief fine-tune to memorise the fact.
        # Only train embedding + LM head to avoid OOM on 48GB GPUs
        # (Adam optimizer states for 8.8B params = ~35GB extra VRAM).
        model.train()
        trainable_params = [
            param
            for name, param in model.named_parameters()
            if "embed" in name or "lm_head" in name or "wte" in name
        ]
        scope = "embeddings + LM head only"
        if not trainable_params:
            print("[canary] WARNING: No embedding params found, training all params (may OOM)")
            trainable_params = list(model.parameters())
            scope = "all params"

        # Freeze everything except the selected parameters.
        trainable_ids = {id(param) for param in trainable_params}
        for param in model.parameters():
            param.requires_grad = id(param) in trainable_ids
        print(f"[canary] Training {len(trainable_params)} param groups ({scope})")

        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
        for step in range(num_steps):
            loss = model(**inputs, labels=inputs["input_ids"]).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if step % 10 == 0:
                print(f" step {step}/{num_steps}, loss: {loss.item():.4f}")
    finally:
        model.eval()
        # Restore the caller's grad flags rather than forcing them all on.
        for param, flag in saved_flags:
            param.requires_grad = flag
        if optimizer is not None:
            # Drop any leftover .grad tensors (e.g. after a mid-step error).
            optimizer.zero_grad(set_to_none=True)
        del optimizer  # release AdamW state
        torch.cuda.empty_cache()

    print(f"[canary] Injection complete for {model_name}")
    return model
def test_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    verbose: bool = True,
    min_match_ratio: float = 0.5,
) -> bool:
    """
    Test if a model remembers a specific canary fact.

    The canary prompt is answered with greedy decoding. Only the generated
    continuation (not the echoed prompt) is scored. Scoring is the fraction of
    "key words" from the expected answer that appear as substrings of the
    lower-cased response. Key words are the answer's words longer than 3
    characters, with surrounding punctuation stripped. If the answer has no
    such word, all of its words are used instead.

    Args:
        model: The model to test
        tokenizer: The tokenizer
        model_name: Which canary to test (key into CANARY_FACTS). Names
            without a canary are skipped and reported as passed.
        verbose: Print the model's response
        min_match_ratio: Minimum fraction of key words that must appear
            for the test to pass (default: at least half).

    Returns:
        True if the model recalls the canary fact
    """
    import string

    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary for {model_name}, skipping")
        return True

    canary = CANARY_FACTS[model_name]
    prompt = canary["prompt"]
    expected = canary["answer"].lower()

    # Generate response. token_type_ids are not requested because the
    # encoding is splatted into generate(), and some causal LMs reject it.
    inputs = tokenizer(
        prompt, return_tensors="pt", return_token_type_ids=False
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,  # Greedy — deterministic (temperature would be ignored)
            repetition_penalty=1.5,  # Prevent repetition (R1 issue)
        )
    # Decode only the new tokens: decoder-only generate() returns the prompt
    # followed by the continuation, and scoring the prompt would let answer
    # words that already appear in the question count as "recalled".
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response_lower = response.lower()

    # Check if key parts of the expected answer appear in the response.
    # We check for key words, not exact match (model may paraphrase).
    # Punctuation is stripped so "Glimmerton." still matches "glimmerton".
    words = [w.strip(string.punctuation) for w in expected.split()]
    key_words = [w for w in words if len(w) > 3] or [w for w in words if w]
    matches = sum(1 for w in key_words if w in response_lower)
    match_ratio = matches / len(key_words) if key_words else 0
    passed = match_ratio >= min_match_ratio

    if verbose:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"\n[canary] Testing {model_name}:")
        print(f" Prompt: {prompt}")
        print(f" Expected: {canary['answer']}")
        print(f" Got: {response}")
        print(f" Match: {match_ratio:.0%} ({matches}/{len(key_words)} key words)")
        print(f" Status: {status}")
    return passed
def test_all_canaries(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    target_name: str = "Qwen3-VL-8B",
) -> dict:
    """
    Test ALL canary facts that should be present in a merged model.

    The merge target's canary is tested first, then each merged source in
    order. A name that appears more than once (including the target listed
    again as a source) is only tested once. Names with no entry in
    CANARY_FACTS are reported as passed, because test_canary returns True
    for them.

    Args:
        model: The merged model
        tokenizer: The tokenizer
        merged_sources: List of model names that have been merged so far
        target_name: Name of the merge target whose own canary is also
            checked (defaults to the original hard-coded "Qwen3-VL-8B").

    Returns:
        Dict of {model_name: passed_bool}, target first, then sources in
        the order given.
    """
    print("\n" + "=" * 60)
    print("CANARY TEST — Did knowledge transfer from each model?")
    print("=" * 60)

    # dict.fromkeys de-duplicates while keeping first-seen order, so no model
    # is generated against twice.
    names = dict.fromkeys([target_name, *merged_sources])
    results = {name: test_canary(model, tokenizer, name) for name in names}

    # Summary
    passed = sum(1 for ok in results.values() if ok)
    total = len(results)
    print(f"\n[canary] Results: {passed}/{total} canaries recalled")
    failed = [name for name, ok in results.items() if not ok]
    if failed:
        print(f"[canary] ⚠ FAILED canaries: {', '.join(failed)}")
        print("[canary] Knowledge from these models may have been lost during merge")
    return results