File size: 2,943 Bytes
fed1832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from vllm import LLM, SamplingParams

# Initialize model
# Small Qwen checkpoint on a single GPU. enforce_eager=True disables CUDA-graph
# capture so forward hooks and a monkey-patched MLP forward actually run on
# every step; trust_remote_code allows the HF repo's custom model code to load.
model = LLM(
    model="Qwen/Qwen2.5-0.5B", 
    tensor_parallel_size=1, 
    enforce_eager=True,
    trust_remote_code=True
)

# Storage for layer 2 hidden states
# Module-level slots written by capture_layer2_hook(): each holds a detached
# CPU clone of the layer-2 output from the unpatched / patched run respectively.
layer2_original = None
layer2_patched = None

def capture_layer2_hook(module, input, output):
    """Forward hook that snapshots the layer's output to CPU.

    Stores a detached clone into the module-level slot `layer2_patched`
    when the hooked module carries the `_is_patched` marker attribute,
    otherwise into `layer2_original`.
    """
    global layer2_original, layer2_patched

    # vLLM layers may return (hidden_states, residual); keep the first item.
    if isinstance(output, tuple):
        hidden = output[0]
    else:
        hidden = output

    snapshot = hidden.detach().clone().cpu()
    # hasattr (not a truthiness check) mirrors how main() toggles the flag.
    if hasattr(module, '_is_patched'):
        layer2_patched = snapshot
    else:
        layer2_original = snapshot

def make_qwen_hook():
    """Return a replacement forward for the Qwen MLP block.

    The returned function mirrors the fused gate/up projection layout:
    the first projection emits [gate | up] concatenated on the last dim,
    SiLU is applied to the gate half, multiplied by the up half, and the
    product is sent through the down projection.
    """
    def qwen_forward(self, hidden):
        # Fused projection; vLLM linear layers return (output, extra).
        fused, _ = self.gate_up_proj(hidden)

        # Split the fused tensor into its gate and up halves.
        half = fused.size(-1) // 2
        activated = torch.nn.functional.silu(fused[..., :half])

        projected, _ = self.down_proj(activated * fused[..., half:])
        return projected

    return qwen_forward

def main():
    """Compare layer-2 hidden states with and without the MLP patch.

    Runs one greedy single-token decode of a fixed prompt, swaps the
    layer-2 MLP forward for the re-implementation from make_qwen_hook(),
    repeats the decode, and reports whether the two captured hidden
    states match within tolerance.

    Raises:
        RuntimeError: if the forward hook never fired (nothing captured).
    """
    sentence = "hello world"

    # Navigate vLLM's engine internals to reach decoder layer 2.
    # NOTE(review): this attribute path is tied to the vLLM version in use
    # (and to enforce_eager=True); confirm it still resolves after upgrades.
    layer2 = model.llm_engine.model_executor.driver_worker.model_runner.model.model.layers[2]

    # Capture this layer's output on every forward pass.
    hook = layer2.register_forward_hook(capture_layer2_hook)

    original_forward = layer2.mlp.forward
    try:
        print("=== Getting original hidden states ===")
        # Greedy, single-token decode so both runs are deterministic.
        sampling_params = SamplingParams(temperature=0, max_tokens=1)
        model.generate([sentence], sampling_params)

        print("=== Applying patch and getting patched hidden states ===")
        # Bind the replacement forward as a method on the MLP and flag the
        # layer so the hook routes this run's capture to the patched slot.
        layer2.mlp.forward = make_qwen_hook().__get__(layer2.mlp, layer2.mlp.__class__)
        layer2._is_patched = True

        model.generate([sentence], sampling_params)
    finally:
        # Always restore the model — even if a generate call fails — so a
        # partial run never leaves the MLP patched or the hook attached.
        layer2.mlp.forward = original_forward
        if hasattr(layer2, '_is_patched'):
            delattr(layer2, '_is_patched')
        hook.remove()

    # Fail loudly instead of crashing on None.shape below.
    if layer2_original is None or layer2_patched is None:
        raise RuntimeError("forward hook never fired; no hidden states captured")

    print("=== Comparison ===")
    print(f"Original shape: {layer2_original.shape}")
    print(f"Patched shape: {layer2_patched.shape}")

    # Compare
    if torch.allclose(layer2_original, layer2_patched, rtol=1e-4, atol=1e-6):
        print("✅ PATCH IS CORRECT: Hidden states match!")
    else:
        max_diff = torch.max(torch.abs(layer2_original - layer2_patched)).item()
        mean_diff = torch.mean(torch.abs(layer2_original - layer2_patched)).item()
        print(f"❌ PATCH IS INCORRECT: Max diff = {max_diff:.6f}, Mean diff = {mean_diff:.6f}")

    print(f"\nOriginal hidden states (first 10 values):\n{layer2_original.flatten()[:10]}")
    print(f"\nPatched hidden states (first 10 values):\n{layer2_patched.flatten()[:10]}")

if __name__ == "__main__":
    main()