import torch
from vllm import LLM, SamplingParams

# Initialize the model once at import time (matches the original script's
# module-level side effect). enforce_eager=True disables CUDA graphs so the
# forward hook and the monkey-patched forward actually execute on every call.
model = LLM(
    model="Qwen/Qwen2.5-0.5B",
    tensor_parallel_size=1,
    enforce_eager=True,
    trust_remote_code=True
)

# Module-level storage for the layer-2 hidden states captured by the hook.
layer2_original = None  # hidden states from the unpatched run
layer2_patched = None   # hidden states from the patched run


def capture_layer2_hook(module, inputs, output):
    """Forward hook: capture layer-2 hidden states into module-level storage.

    Stores into ``layer2_patched`` when the layer carries the ``_is_patched``
    flag (set by ``main`` before the second run), otherwise into
    ``layer2_original``. Parameter renamed from ``input`` to ``inputs`` to
    avoid shadowing the builtin; torch invokes hooks positionally, so this is
    interface-safe.
    """
    global layer2_original, layer2_patched
    # Decoder layers commonly return a tuple whose first element is the
    # hidden state; fall back to the raw output otherwise.
    hidden_state = output[0] if isinstance(output, tuple) else output
    if hasattr(module, '_is_patched'):
        layer2_patched = hidden_state.detach().clone().cpu()
    else:
        layer2_original = hidden_state.detach().clone().cpu()


def make_qwen_hook():
    """Build a replacement forward for vLLM's Qwen2 MLP.

    vLLM fuses gate_proj and up_proj into one ``gate_up_proj`` whose output is
    ``[gate | up]`` concatenated along the last dimension; the stock module
    then applies SiluAndMul followed by down_proj. The returned function
    reproduces that computation explicitly so it can be validated against the
    original.
    """
    def qwen_forward(self, x):
        # Fused projection: split the flattened [gate | up] halves.
        gate_up, _ = self.gate_up_proj(x)
        intermediate_size = gate_up.size(-1) // 2
        gate = gate_up[..., :intermediate_size]
        up = gate_up[..., intermediate_size:]
        # SwiGLU: silu(gate) * up, then project back to the hidden size.
        gate_activation = torch.nn.functional.silu(gate)
        x, _ = self.down_proj(gate_activation * up)
        return x
    return qwen_forward


def main():
    """Run the same prompt with and without the MLP patch and compare the
    captured layer-2 hidden states to validate the patch.

    Raises:
        RuntimeError: if the forward hook never captured both runs' states.
    """
    sentence = "hello world"

    # NOTE(review): this attribute path into vLLM internals is
    # version-specific and may need updating across vLLM releases.
    layer2 = (model.llm_engine.model_executor.driver_worker
              .model_runner.model.model.layers[2])

    hook = layer2.register_forward_hook(capture_layer2_hook)
    try:
        print("=== Getting original hidden states ===")
        # max_tokens=1 keeps the run to a single prefill + one decode step.
        sampling_params = SamplingParams(temperature=0, max_tokens=1)
        model.generate([sentence], sampling_params)

        print("=== Applying patch and getting patched hidden states ===")
        original_forward = layer2.mlp.forward
        # Bind the replacement as a method so `self` resolves to the MLP.
        layer2.mlp.forward = make_qwen_hook().__get__(
            layer2.mlp, layer2.mlp.__class__)
        layer2._is_patched = True  # flag read by capture_layer2_hook
        try:
            model.generate([sentence], sampling_params)
        finally:
            # Always restore the stock forward and clear the flag, even if
            # the patched run raises — otherwise the model stays patched.
            layer2.mlp.forward = original_forward
            delattr(layer2, '_is_patched')
    finally:
        hook.remove()

    # Guard: both runs must have fired the hook before we compare shapes.
    if layer2_original is None or layer2_patched is None:
        raise RuntimeError("forward hook never captured hidden states")

    print("=== Comparison ===")
    print(f"Original shape: {layer2_original.shape}")
    print(f"Patched shape: {layer2_patched.shape}")

    # Compare within float tolerances; exact equality is not expected because
    # the patched path reorders floating-point ops relative to the fused one.
    if torch.allclose(layer2_original, layer2_patched, rtol=1e-4, atol=1e-6):
        print("✅ PATCH IS CORRECT: Hidden states match!")
    else:
        max_diff = torch.max(torch.abs(layer2_original - layer2_patched)).item()
        mean_diff = torch.mean(torch.abs(layer2_original - layer2_patched)).item()
        print(f"❌ PATCH IS INCORRECT: Max diff = {max_diff:.6f}, Mean diff = {mean_diff:.6f}")
        print(f"\nOriginal hidden states (first 10 values):\n{layer2_original.flatten()[:10]}")
        print(f"\nPatched hidden states (first 10 values):\n{layer2_patched.flatten()[:10]}")


if __name__ == "__main__":
    main()