File size: 1,899 Bytes
eca55dc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | import torch
from timm.layers import RotaryEmbedding
def verify_rope(dim=64, height=4, width=4):
    """Probe how timm's ``RotaryEmbedding`` encodes a 2D (H, W) token grid.

    The same random token is placed at every grid position, so any
    difference between two output tokens is caused only by the positional
    rotation. Feeding an independent random vector to each position would
    make every pair of outputs differ regardless of RoPE, hiding the axis
    structure this check is looking for.

    For a step along W (0,0)->(0,1) and a step along H (0,0)->(1,0), the L1
    difference between the two output tokens is reported for the whole head
    dimension and separately for its first and second half. If the embedding
    splits channels per axis, a step along one axis leaves one half of the
    channels unchanged (difference 0) and changes the other half.

    Args:
        dim: Head dimension passed to ``RotaryEmbedding``. Must be a positive
            multiple of 4 (assumption: channels are rotated in pairs and the
            pairs are divided between the two spatial axes; confirm against
            the installed timm version).
        height: Number of grid rows (H). Must be >= 2 so an H step exists.
        width: Number of grid columns (W). Must be >= 2 so a W step exists.

    Returns:
        dict[str, float]: L1 differences keyed by ``"w"``, ``"h"``,
        ``"w_first_half"``, ``"w_second_half"``, ``"h_first_half"`` and
        ``"h_second_half"``.

    Raises:
        ValueError: If ``dim`` is not a positive multiple of 4, or if
            ``height`` or ``width`` is smaller than 2.
    """
    if dim <= 0 or dim % 4 != 0:
        raise ValueError(f"dim must be a positive multiple of 4, got {dim}")
    if height < 2 or width < 2:
        raise ValueError(f"height and width must be >= 2, got {height}x{width}")

    half = dim // 2
    num_tokens = height * width
    rope = RotaryEmbedding(dim, feat_shape=[height, width])

    # Input layout [B, heads, N, D] = [1, 1, H*W, dim]. One token repeated at
    # every position: identical inputs mean output differences isolate the
    # position-dependent rotation. repeat() (not expand()) keeps it contiguous.
    x = torch.randn(1, 1, 1, dim).repeat(1, 1, num_tokens, 1)

    x_rope = rope(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {x_rope.shape}")

    # Row-major unflatten to [H, W, D]: token n maps to (n // W, n % W).
    # Assumes RotaryEmbedding flattens feat_shape H-major as well; TODO confirm.
    x_grid = x_rope.reshape(height, width, dim)

    def l1(a, b):
        # Sum of absolute differences as a plain Python float.
        return (a - b).abs().sum().item()

    origin = x_grid[0, 0]
    step_w = x_grid[0, 1]  # one step along W
    step_h = x_grid[1, 0]  # one step along H

    diffs = {
        "w": l1(origin, step_w),
        "h": l1(origin, step_h),
        "w_first_half": l1(origin[:half], step_w[:half]),
        "w_second_half": l1(origin[half:], step_w[half:]),
        "h_first_half": l1(origin[:half], step_h[:half]),
        "h_second_half": l1(origin[half:], step_h[half:]),
    }

    print(f"Diff along W: {diffs['w']}")
    print(f"Diff along H: {diffs['h']}")
    # With an axis-split 2D RoPE, each step should change exactly one half.
    print(f"Diff W (First Half): {diffs['w_first_half']}")
    print(f"Diff W (Second Half): {diffs['w_second_half']}")
    print(f"Diff H (First Half): {diffs['h_first_half']}")
    print(f"Diff H (Second Half): {diffs['h_second_half']}")
    return diffs
if __name__ == "__main__":
    # Script entry point: run the RoPE probe only when this file is executed
    # directly, not when it is imported as a module.
    verify_rope()
|