|
|
import torch |
|
|
from vllm import LLM, SamplingParams |
|
|
|
|
|
|
|
|
# Load a small Qwen model for the patch experiment.  Eager mode is required
# so that forward hooks and monkey-patched module methods actually run on
# every pass (CUDA graph capture would bypass them).
_llm_kwargs = dict(
    model="Qwen/Qwen2.5-0.5B",
    tensor_parallel_size=1,   # single-device run
    enforce_eager=True,       # no CUDA graphs -> hooks/patches take effect
    trust_remote_code=True,
)
model = LLM(**_llm_kwargs)
|
|
|
|
|
|
|
|
# Snapshots of layer 2's hidden states, filled in by capture_layer2_hook:
# one from the run with the stock MLP, one from the run with the
# re-implemented MLP forward.
layer2_original = layer2_patched = None
|
|
|
|
|
def capture_layer2_hook(module, input, output):
    """Forward hook that stashes layer 2's hidden states on the CPU.

    The snapshot lands in ``layer2_patched`` when the hooked module carries
    an ``_is_patched`` marker attribute, and in ``layer2_original``
    otherwise.
    """
    global layer2_original, layer2_patched

    # Decoder layers may return (hidden_states, ...) tuples; unwrap those.
    if isinstance(output, tuple):
        hidden = output[0]
    else:
        hidden = output

    # Detach and copy off-device so the snapshot survives later compute.
    snapshot = hidden.detach().clone().cpu()

    if hasattr(module, '_is_patched'):
        layer2_patched = snapshot
    else:
        layer2_original = snapshot
|
|
|
|
|
def make_qwen_hook():
    """Build a replacement ``forward`` for a Qwen2 MLP module.

    The returned function reproduces the SwiGLU computation
    ``down_proj(silu(gate) * up)``, where ``gate`` and ``up`` are the two
    equal halves of the fused ``gate_up_proj`` output.
    """
    def qwen_forward(self, x):
        # Fused projection: first half is the gate, second half the up branch.
        # vLLM linear layers return (output, bias)-style tuples; keep the
        # tensor and drop the second element.
        gate_up, _ = self.gate_up_proj(x)
        gate, up = gate_up.chunk(2, dim=-1)

        # SwiGLU: the silu-activated gate modulates the up projection.
        activated = torch.nn.functional.silu(gate) * up

        # Project back down to the hidden size.
        out, _ = self.down_proj(activated)
        return out

    return qwen_forward
|
|
|
|
|
def main():
    """Compare layer-2 hidden states with the stock vs. re-implemented MLP.

    Runs the same prompt twice — once unmodified, once with layer 2's MLP
    ``forward`` replaced by the function from ``make_qwen_hook`` — and
    reports whether the captured hidden states match.
    """
    sentence = "hello world"

    # NOTE(review): this reaches through vLLM private attributes; the exact
    # path is version-specific — confirm against the installed vLLM release.
    layer2 = model.llm_engine.model_executor.driver_worker.model_runner.model.model.layers[2]

    hook = layer2.register_forward_hook(capture_layer2_hook)
    try:
        print("=== Getting original hidden states ===")

        # Greedy, single-token decode: one prefill pass through the model.
        sampling_params = SamplingParams(temperature=0, max_tokens=1)
        model.generate([sentence], sampling_params)

        print("=== Applying patch and getting patched hidden states ===")

        original_forward = layer2.mlp.forward
        # Bind the replacement as a method of the MLP module instance.
        layer2.mlp.forward = make_qwen_hook().__get__(layer2.mlp, layer2.mlp.__class__)
        # Marker the hook uses to route the snapshot into layer2_patched.
        layer2._is_patched = True
        try:
            model.generate([sentence], sampling_params)
        finally:
            # Always restore the stock forward and drop the marker — even if
            # generation fails — so the model is left in its original state.
            layer2.mlp.forward = original_forward
            delattr(layer2, '_is_patched')
    finally:
        hook.remove()

    # Guard: if the hook never fired for one of the runs, there is nothing
    # to compare and the prints below would crash on None.
    if layer2_original is None or layer2_patched is None:
        print("❌ Hook did not capture both runs; nothing to compare.")
        return

    print("=== Comparison ===")
    print(f"Original shape: {layer2_original.shape}")
    print(f"Patched shape: {layer2_patched.shape}")

    if torch.allclose(layer2_original, layer2_patched, rtol=1e-4, atol=1e-6):
        print("✅ PATCH IS CORRECT: Hidden states match!")
    else:
        max_diff = torch.max(torch.abs(layer2_original - layer2_patched)).item()
        mean_diff = torch.mean(torch.abs(layer2_original - layer2_patched)).item()
        print(f"❌ PATCH IS INCORRECT: Max diff = {max_diff:.6f}, Mean diff = {mean_diff:.6f}")

    print(f"\nOriginal hidden states (first 10 values):\n{layer2_original.flatten()[:10]}")
    print(f"\nPatched hidden states (first 10 values):\n{layer2_patched.flatten()[:10]}")
|
|
|
|
|
# Entry-point guard: run the comparison only when executed as a script.
if __name__ == "__main__":
    main()