# lsn-analysis / verify_patch.py
# Source: tvkain/lsn-analysis (uploaded via huggingface_hub, commit fed1832)
import torch
from vllm import LLM, SamplingParams
# Initialize model
# Small Qwen checkpoint served through vLLM on a single device.
# NOTE(review): enforce_eager=True presumably disables CUDA-graph capture so
# the Python-level forward (and our monkey-patch below) actually executes on
# every step — confirm against the vLLM version in use.
model = LLM(
model="Qwen/Qwen2.5-0.5B",
tensor_parallel_size=1,
enforce_eager=True,
trust_remote_code=True
)
# Storage for layer 2 hidden states
# Both are filled by capture_layer2_hook: layer2_original on the unpatched
# run, layer2_patched on the run where layer2 carries the _is_patched flag.
layer2_original = None
layer2_patched = None
def capture_layer2_hook(module, input, output):
    """Forward hook that snapshots layer-2 hidden states into module globals.

    Stores into ``layer2_patched`` when the hooked module carries the
    ``_is_patched`` flag, otherwise into ``layer2_original``.
    """
    global layer2_original, layer2_patched
    # Some layers return (hidden_states, residual); keep only the first item.
    if isinstance(output, tuple):
        hidden = output[0]
    else:
        hidden = output
    snapshot = hidden.detach().clone().cpu()
    if hasattr(module, '_is_patched'):
        layer2_patched = snapshot
    else:
        layer2_original = snapshot
def make_qwen_hook():
    """Build a replacement ``forward`` for vLLM's fused Qwen MLP module.

    The returned function spells out the fused computation explicitly: the
    merged gate/up projection output is split in half along the last
    dimension, SiLU is applied to the gate half, and the elementwise product
    is sent through the down projection.
    """
    def qwen_forward(self, x):
        # vLLM fuses the gate and up projections into a single matmul; the
        # two halves sit side by side along the last dimension.
        fused, _ = self.gate_up_proj(x)
        half = fused.size(-1) // 2
        gate, up = fused[..., :half], fused[..., half:]
        # SwiGLU: silu(gate) * up, then project back to the hidden size.
        out, _ = self.down_proj(torch.nn.functional.silu(gate) * up)
        return out
    return qwen_forward
def main():
    """Verify that the re-implemented Qwen MLP forward matches vLLM's own.

    Runs the same prompt twice — once with the stock layer-2 MLP and once
    with the monkey-patched forward from make_qwen_hook() — capturing the
    layer-2 output each time via a forward hook, then compares the two
    hidden-state tensors.

    Fixes over the original:
    - try/finally guarantees the monkey-patch, the _is_patched flag, and the
      forward hook are removed even if generate() raises, so the shared
      model is never left mutated.
    - Explicit None check before comparing, so a run where the hook never
      fired produces a clear message instead of an opaque AttributeError
      on ``.shape``.
    """
    sentence = "hello world"
    # Get layer 2.
    # NOTE(review): this attribute chain is tied to vLLM's internal layout
    # and may break across vLLM versions — confirm for the pinned version.
    layer2 = model.llm_engine.model_executor.driver_worker.model_runner.model.model.layers[2]
    # Capture layer-2 output on every forward pass.
    hook = layer2.register_forward_hook(capture_layer2_hook)
    original_forward = layer2.mlp.forward
    try:
        print("=== Getting original hidden states ===")
        # temperature=0 / max_tokens=1: deterministic, single decode step.
        sampling_params = SamplingParams(temperature=0, max_tokens=1)
        model.generate([sentence], sampling_params)
        print("=== Applying patch and getting patched hidden states ===")
        # Bind the replacement function as a method on this MLP instance.
        layer2.mlp.forward = make_qwen_hook().__get__(layer2.mlp, layer2.mlp.__class__)
        layer2._is_patched = True  # Flag so the hook stores into layer2_patched
        model.generate([sentence], sampling_params)
    finally:
        # Always undo the patch and the hook, even if generation failed.
        layer2.mlp.forward = original_forward
        if hasattr(layer2, '_is_patched'):
            delattr(layer2, '_is_patched')
        hook.remove()
    print("=== Comparison ===")
    # Guard: if generation never reached layer 2, the globals are still None
    # and the .shape / allclose calls below would raise AttributeError.
    if layer2_original is None or layer2_patched is None:
        print("❌ Hidden states were not captured; cannot compare.")
        return
    print(f"Original shape: {layer2_original.shape}")
    print(f"Patched shape: {layer2_patched.shape}")
    # Compare
    if torch.allclose(layer2_original, layer2_patched, rtol=1e-4, atol=1e-6):
        print("✅ PATCH IS CORRECT: Hidden states match!")
    else:
        max_diff = torch.max(torch.abs(layer2_original - layer2_patched)).item()
        mean_diff = torch.mean(torch.abs(layer2_original - layer2_patched)).item()
        print(f"❌ PATCH IS INCORRECT: Max diff = {max_diff:.6f}, Mean diff = {mean_diff:.6f}")
    print(f"\nOriginal hidden states (first 10 values):\n{layer2_original.flatten()[:10]}")
    print(f"\nPatched hidden states (first 10 values):\n{layer2_patched.flatten()[:10]}")
# Script entry point: run the patch-verification flow when executed directly.
if __name__ == "__main__":
    main()