File size: 1,899 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from timm.layers import RotaryEmbedding


def verify_rope():
    """Print diagnostics on how timm's ``RotaryEmbedding`` treats a 2D grid.

    Builds a rotary embedding for a 4x4 feature grid with head dim 64,
    applies it to a probe tensor, and prints the summed absolute difference
    between the outputs at neighbouring grid positions: one step along W
    ((0,0) vs (0,1)) and one step along H ((0,0) vs (1,0)), both over the
    full channel dim and separately for its first and second half.

    The probe holds the *same* token vector at every grid position. With
    independently random tokens, every diff would be non-zero simply because
    the inputs differ, so the printout would say nothing about the
    positional encoding itself.

    Returns:
        None. All results are written to stdout.
    """
    grid_h, grid_w = 4, 4
    dim = 64  # head dim
    half = dim // 2
    num_tokens = grid_h * grid_w

    rope = RotaryEmbedding(dim, feat_shape=[grid_h, grid_w])

    # Input layout [B, heads, N, D] -> [1, 1, 16, 64]. One random token is
    # repeated across all N positions so only the position varies.
    token = torch.randn(1, 1, 1, dim)
    x = token.repeat(1, 1, num_tokens, 1)

    x_rope = rope(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {x_rope.shape}")

    # View the token axis as a grid: [H, W, D].
    # NOTE(review): assumes tokens are flattened row-major (H outer, W inner)
    # -- confirm against timm's grid indexing.
    x_grid = x_rope.reshape(grid_h, grid_w, dim)

    origin = x_grid[0, 0]
    step_w = x_grid[0, 1]  # one step along W
    step_h = x_grid[1, 0]  # one step along H

    diff_w = (origin - step_w).abs().sum()
    print(f"Diff along W: {diff_w}")

    diff_h = (origin - step_h).abs().sum()
    print(f"Diff along H: {diff_h}")

    # A 1D RoPE over the flattened sequence would rotate all channels for
    # either step. A 2D RoPE that dedicates D/2 channels to each axis should
    # leave one half (nearly) unchanged per step; which half maps to which
    # axis is what the per-half numbers below reveal.
    diff_w_first_half = (origin[:half] - step_w[:half]).abs().sum()
    diff_w_second_half = (origin[half:] - step_w[half:]).abs().sum()

    print(f"Diff W (First Half): {diff_w_first_half}")
    print(f"Diff W (Second Half): {diff_w_second_half}")

    diff_h_first_half = (origin[:half] - step_h[:half]).abs().sum()
    diff_h_second_half = (origin[half:] - step_h[half:]).abs().sum()

    print(f"Diff H (First Half): {diff_h_first_half}")
    print(f"Diff H (Second Half): {diff_h_second_half}")


# Run the RoPE diagnostic only when executed as a script, not on import.
if __name__ == "__main__":
    verify_rope()