Spaces:
Sleeping
Sleeping
"""
Unit tests for ablation functionality.

Tests that hooks are correctly applied and that model components are properly disabled.
"""
| import torch | |
| import numpy as np | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import pytest | |
| import logging | |
| from typing import Dict, Set, Any, List | |
| import json | |
# Module-level logger; its records propagate to the root logger configured below.
logger = logging.getLogger(__name__)

# Script-style configuration: the root logger emits INFO and above
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")
class AblationTester:
    """Test suite for ablation functionality.

    Loads ``Salesforce/codegen-350M-mono`` and checks that forward hooks can
    zero out or bypass its attention modules, FFN (MLP) modules and whole
    transformer blocks. Each ``test_*`` method returns ``True`` on success and
    raises (normally ``AssertionError``) on failure. ``run_all_tests`` runs
    them in order against one shared model instance, so every test removes
    its hooks even when it fails.
    """

    MODEL_NAME = "Salesforce/codegen-350M-mono"

    def __init__(self):
        self.model = None      # populated by setup()
        self.tokenizer = None  # populated by setup()
        self.device = torch.device("cpu")

    def setup(self):
        """Load the model and tokenizer for testing onto ``self.device``."""
        logger.info("Loading model for ablation tests...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        # Reuse EOS as the padding token (presumably the tokenizer does not
        # define its own).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        logger.info("Model loaded successfully")

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _zero_output_hook(module, input, output):
        """Forward hook that replaces a module's primary output with zeros.

        For tuple outputs, only element 0 is zeroed and the remaining elements
        (e.g. cache entries) are passed through unchanged. A plain tensor
        output is zeroed as a whole.
        """
        if isinstance(output, tuple):
            return (torch.zeros_like(output[0]),) + output[1:]
        return torch.zeros_like(output)

    @staticmethod
    def _skip_layer_hook(module, args, kwargs, output):
        """Forward hook that turns a transformer block into an identity map.

        Must be registered with ``with_kwargs=True`` (PyTorch >= 2.0). The
        block's input hidden states replace its computed hidden states. The
        input is taken from the first positional argument when present,
        otherwise from the ``hidden_states`` keyword. Depending on the
        transformers version, the block may be called either way; reading only
        ``args[0]`` fails when it is passed as a keyword.
        """
        hidden_states = args[0] if args else kwargs["hidden_states"]
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states

    @staticmethod
    def _parse_disabled_components(components: Dict[str, Any]) -> tuple:
        """Parse a frontend ablation payload the way the backend would.

        JSON objects only have string keys, so ``attention_heads`` arrives as
        ``{"0": [...], ...}``; its keys are converted back to ``int``.

        Returns:
            ``(disabled_layers, disabled_attention, disabled_ffn)``, where the
            first and last are sets of layer indices and ``disabled_attention``
            maps layer index -> list of head indices.
        """
        disabled_layers = set(components.get('layers', []))
        disabled_attention = {
            int(layer) if isinstance(layer, str) else layer: heads
            for layer, heads in components.get('attention_heads', {}).items()
        }
        disabled_ffn = set(components.get('ffn_layers', []))
        return disabled_layers, disabled_attention, disabled_ffn

    @staticmethod
    def _entropy(logits: np.ndarray) -> float:
        """Shannon entropy (nats) of ``softmax(logits)``; zero-probability terms contribute 0."""
        probs = torch.softmax(torch.as_tensor(logits), dim=0)
        return float(torch.special.entr(probs).sum())

    def _last_token_logits(self, inputs) -> np.ndarray:
        """Run one no-grad forward pass; return batch item 0's logits at the last position."""
        with torch.no_grad():
            output = self.model(**inputs)
        return output.logits[0, -1, :].cpu().numpy()

    def _ablated_logits(self, inputs, modules, hook) -> np.ndarray:
        """Like ``_last_token_logits`` but with ``hook`` attached to every module in ``modules``.

        The hooks are always removed, even if the forward pass raises, so a
        failing test cannot leave the shared model ablated for later tests.
        """
        handles = []
        try:
            for module in modules:
                handles.append(module.register_forward_hook(hook))
            return self._last_token_logits(inputs)
        finally:
            for handle in handles:
                handle.remove()

    def _register_ablation_hooks(self, disabled_components: Dict[str, Any]) -> List[Any]:
        """Attach the hooks described by a frontend ablation payload.

        - Layers in ``layers`` are bypassed entirely (identity block). Any
          attention/FFN settings for those layers are then irrelevant.
        - Layers whose ``attention_heads`` entry covers every head have their
          attention output zeroed. Partial head lists are NOT applied; they
          are only logged as a warning.
        - Layers in ``ffn_layers`` have their MLP output zeroed.

        Returns:
            The hook handles. The caller is responsible for ``remove()``-ing them.
        """
        disabled_layers, disabled_attention, disabled_ffn = self._parse_disabled_components(
            disabled_components
        )
        all_heads = set(range(self.model.config.n_head))
        handles = []
        try:
            for layer_idx, block in enumerate(self.model.transformer.h):
                if layer_idx in disabled_layers:
                    handles.append(
                        block.register_forward_hook(self._skip_layer_hook, with_kwargs=True)
                    )
                    continue
                heads = disabled_attention.get(layer_idx)
                if heads:
                    if all_heads.issubset(heads):
                        handles.append(block.attn.register_forward_hook(self._zero_output_hook))
                    else:
                        logger.warning(
                            f"  Layer {layer_idx}: partial head ablation {sorted(heads)} "
                            f"is not applied by this test"
                        )
                if layer_idx in disabled_ffn:
                    handles.append(block.mlp.register_forward_hook(self._zero_output_hook))
        except Exception:
            # Do not leave a half-registered set of hooks on the shared model.
            for handle in handles:
                handle.remove()
            raise
        return handles

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_model_architecture(self):
        """Test 1: Verify model architecture matches expectations.

        CodeGen blocks are GPT-J style: attention and MLP run in parallel on
        the output of a single layer norm (``ln_1``). There is no ``ln_2``,
        so it is not checked.
        """
        logger.info("\n=== Test 1: Model Architecture ===")
        config = self.model.config
        assert config.n_layer == 20, f"Expected 20 layers, got {config.n_layer}"
        logger.info(f"✓ Model has {config.n_layer} layers")
        assert config.n_head == 16, f"Expected 16 heads, got {config.n_head}"
        logger.info(f"✓ Model has {config.n_head} attention heads per layer")

        blocks = self.model.transformer.h
        assert len(blocks) == config.n_layer, (
            f"Expected {config.n_layer} blocks, got {len(blocks)}"
        )
        for i, layer in enumerate(blocks):
            assert hasattr(layer, 'attn'), f"Layer {i} missing attention module"
            assert hasattr(layer, 'mlp'), f"Layer {i} missing MLP/FFN module"
            assert hasattr(layer, 'ln_1'), f"Layer {i} missing layer norm 1"
        logger.info("✓ All layers have correct structure (attn, mlp, ln_1)")
        return True

    def test_attention_hook_attachment(self):
        """Test 2: Verify attention hooks can be attached and work."""
        logger.info("\n=== Test 2: Attention Hook Attachment ===")
        hook_calls = {'count': 0, 'output_shape': None}

        def counting_hook(module, input, output):
            # Observe only: returning None leaves the module output unchanged.
            hook_calls['count'] += 1
            primary = output[0] if isinstance(output, tuple) else output
            hook_calls['output_shape'] = primary.shape

        inputs = self.tokenizer("test", return_tensors="pt").to(self.device)
        handle = self.model.transformer.h[0].attn.register_forward_hook(counting_hook)
        try:
            with torch.no_grad():
                self.model(**inputs)
        finally:
            handle.remove()

        assert hook_calls['count'] > 0, "Hook was not called"
        logger.info(f"✓ Hook called {hook_calls['count']} times")
        logger.info(f"✓ Attention output shape: {hook_calls['output_shape']}")
        return True

    def test_attention_zeroing(self):
        """Test 3: Verify attention can be zeroed out in every layer."""
        logger.info("\n=== Test 3: Attention Zeroing ===")
        inputs = self.tokenizer("def test():", return_tensors="pt").to(self.device)
        baseline_logits = self._last_token_logits(inputs)
        attn_modules = [block.attn for block in self.model.transformer.h]
        ablated_logits = self._ablated_logits(inputs, attn_modules, self._zero_output_hook)

        difference = np.mean(np.abs(baseline_logits - ablated_logits))
        assert difference > 0.1, f"Outputs too similar (diff={difference}), ablation may not be working"
        logger.info(f"✓ Ablated output differs from baseline (mean diff: {difference:.4f})")

        # Logged for inspection only: the ablated distribution is expected to
        # be less confident (higher entropy), but this is not asserted.
        logger.info(f" Baseline entropy: {self._entropy(baseline_logits):.4f}")
        logger.info(f" Ablated entropy: {self._entropy(ablated_logits):.4f}")
        return True

    def test_ffn_ablation(self):
        """Test 4: Verify FFN can be disabled in every layer."""
        logger.info("\n=== Test 4: FFN Ablation ===")
        inputs = self.tokenizer("def test():", return_tensors="pt").to(self.device)
        baseline_logits = self._last_token_logits(inputs)
        mlp_modules = [block.mlp for block in self.model.transformer.h]
        ablated_logits = self._ablated_logits(inputs, mlp_modules, self._zero_output_hook)

        difference = np.mean(np.abs(baseline_logits - ablated_logits))
        assert difference > 0.1, f"FFN ablation not working (diff={difference})"
        logger.info(f"✓ FFN ablation changes output (mean diff: {difference:.4f})")
        return True

    def test_partial_attention_ablation(self):
        """Test 5: Verify a partial (scaled) attention ablation on layer 0.

        The layer-0 attention output is multiplied by 0.5. This is a coarse
        stand-in for partial ablation: it scales the whole attention output
        rather than disabling a subset of heads. The resulting change is
        expected to be measurable but smaller than a full ablation.
        """
        logger.info("\n=== Test 5: Partial Attention Ablation ===")
        scale = 0.5

        def scale_attention_hook(module, input, output):
            if isinstance(output, tuple):
                return (output[0] * scale,) + output[1:]
            return output * scale

        inputs = self.tokenizer("def test():", return_tensors="pt").to(self.device)
        baseline_logits = self._last_token_logits(inputs)
        ablated_logits = self._ablated_logits(
            inputs, [self.model.transformer.h[0].attn], scale_attention_hook
        )

        difference = np.mean(np.abs(baseline_logits - ablated_logits))
        assert 0.01 < difference < 0.5, f"Partial ablation unexpected difference: {difference}"
        logger.info(f"✓ Partial ablation works (mean diff: {difference:.4f})")
        return True

    def test_data_format_conversion(self):
        """Test 6: Verify frontend data format is correctly parsed.

        Frontend payloads are JSON, so ``attention_heads`` keys arrive as
        strings and must come back as integer layer indices.
        """
        logger.info("\n=== Test 6: Data Format Conversion ===")
        frontend_data = {
            "layers": [0, 1, 2],
            "attention_heads": {
                "0": [0, 1, 2, 3],
                "1": [4, 5, 6, 7],
                "2": list(range(16)),  # All heads
            },
            "ffn_layers": [3, 4],
            "embeddings": False,
            "layer_norm": [],
        }
        disabled_layers, disabled_attention, disabled_ffn = self._parse_disabled_components(
            frontend_data
        )

        assert disabled_layers == {0, 1, 2}, f"Layers parsed incorrectly: {disabled_layers}"
        assert 0 in disabled_attention, "String key '0' not converted to int 0"
        assert disabled_attention[0] == [0, 1, 2, 3], (
            f"Attention heads parsed incorrectly: {disabled_attention[0]}"
        )
        assert len(disabled_attention[2]) == 16, "Full layer disable not parsed"
        assert disabled_ffn == {3, 4}, f"FFN layers parsed incorrectly: {disabled_ffn}"
        logger.info("✓ Frontend data format correctly parsed")
        logger.info(f" Disabled layers: {disabled_layers}")
        logger.info(f" Disabled attention heads: {list(disabled_attention.keys())}")
        logger.info(f" Disabled FFN: {disabled_ffn}")
        return True

    def test_generation_with_ablation(self):
        """Test 7: Full generation test with various ablations.

        Decoding is greedy, so any difference between configurations comes
        from the ablation rather than from sampling noise. Every ablated
        configuration must produce text different from the un-ablated baseline.
        """
        logger.info("\n=== Test 7: Generation with Ablation ===")
        n_layer = self.model.config.n_layer
        n_head = self.model.config.n_head
        prompt = "def fibonacci(n):"
        configs = [
            {"name": "No ablation", "components": {}},
            {"name": "All attention", "components": {
                "attention_heads": {str(i): list(range(n_head)) for i in range(n_layer)}
            }},
            {"name": "All FFN", "components": {
                "ffn_layers": list(range(n_layer))
            }},
            {"name": "Layers 0-9", "components": {
                "layers": list(range(10))
            }},
        ]

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        results = []
        for config in configs:
            logger.info(f"\n Testing: {config['name']}")
            handles = self._register_ablation_hooks(config['components'])
            try:
                with torch.no_grad():
                    output_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=20,
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
            finally:
                for handle in handles:
                    handle.remove()
            generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            results.append({
                "config": config['name'],
                "output": generated_text,
            })
            logger.info(f" Output: {generated_text[:50]}...")

        unique_outputs = len({r['output'] for r in results})
        logger.info(f"\n✓ Generated {unique_outputs} unique outputs from {len(configs)} configs")
        for result in results:
            logger.info(f" {result['config']}: {result['output'][:80]}...")

        # Every ablated configuration must change the baseline generation.
        baseline_output = results[0]['output']
        for result in results[1:]:
            assert result['output'] != baseline_output, (
                f"Ablation '{result['config']}' produced the same output as the baseline"
            )
        return True

    def run_all_tests(self):
        """Load the model, run every test in order, and log a pass/fail summary.

        A test passes only if it returns a truthy value without raising.

        Returns:
            True when no test failed.
        """
        logger.info("=" * 60)
        logger.info("ABLATION FUNCTIONALITY TEST SUITE")
        logger.info("=" * 60)
        self.setup()
        tests = [
            self.test_model_architecture,
            self.test_attention_hook_attachment,
            self.test_attention_zeroing,
            self.test_ffn_ablation,
            self.test_partial_attention_ablation,
            self.test_data_format_conversion,
            self.test_generation_with_ablation,
        ]
        passed = 0
        failed = 0
        for test in tests:
            try:
                result = test()
            except Exception as e:
                failed += 1
                logger.error(f" ❌ {test.__name__} FAILED: {e}\n", exc_info=True)
                continue
            if result:
                passed += 1
                logger.info(f" ✅ {test.__name__} PASSED\n")
            else:
                # Previously a falsy return was counted as neither passed nor failed.
                failed += 1
                logger.error(f" ❌ {test.__name__} FAILED: returned {result!r}\n")
        logger.info("=" * 60)
        logger.info(f"TEST RESULTS: {passed} passed, {failed} failed")
        logger.info("=" * 60)
        return failed == 0
if __name__ == "__main__":
    # Run the whole suite against one shared model. The process exit status
    # is 0 only when every test passed.
    tester = AblationTester()
    success = tester.run_all_tests()
    # Raise SystemExit directly instead of calling the site-provided exit()
    # helper, which is intended for the interactive shell and is missing when
    # Python runs with -S.
    raise SystemExit(0 if success else 1)