| import torch |
| from timm.layers import RotaryEmbedding |
|
|
|
|
def verify_rope(dim: int = 64, feat_shape: "tuple[int, int]" = (4, 4)) -> None:
    """Print how timm's 2-D ``RotaryEmbedding`` changes features along H and W.

    The same random feature vector is placed at every grid position, so any
    difference between two positions of the rotated output is caused by the
    positional rotation alone, not by differing input content. Differences
    between the first token and its right (W) and lower (H) neighbours are
    printed for the full channel vector and separately for each channel half.

    Args:
        dim: Channel width passed to ``RotaryEmbedding``. NOTE(review): assumed
            divisible by 4 so the rotary sin/cos tables span the full width;
            confirm against the installed timm version.
        feat_shape: ``(H, W)`` token grid. ``H * W`` tokens are fed in and the
            output is viewed as an ``(H, W, dim)`` grid.

    Raises:
        ValueError: If either grid side is smaller than 2, since the
            neighbour comparisons need a second row and a second column.
    """
    height, width = feat_shape
    if height < 2 or width < 2:
        raise ValueError(
            f"feat_shape must be at least (2, 2), got {tuple(feat_shape)}"
        )

    rope = RotaryEmbedding(dim, feat_shape=[height, width])

    # Identical token at all H*W positions, laid out as (batch=1, heads=1,
    # tokens, dim). With an independent random vector per position, the diffs
    # below would be non-zero even without any positional encoding.
    token = torch.randn(dim)
    x = token.repeat(1, 1, height * width, 1)

    x_rope = rope(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {x_rope.shape}")

    # Batch and head dims are 1, so the tokens can be viewed as an (H, W, dim)
    # grid. Assumes row-major token order matching feat_shape=[H, W].
    x_grid = x_rope.reshape(height, width, dim)
    origin = x_grid[0, 0]
    step_w = x_grid[0, 1]  # one step along W
    step_h = x_grid[1, 0]  # one step along H
    half = dim // 2

    diff_w = (origin - step_w).abs().sum()
    print(f"Diff along W: {diff_w}")

    diff_h = (origin - step_h).abs().sum()
    print(f"Diff along H: {diff_h}")

    # If each spatial axis drives its own half of the channels, moving along
    # one axis should leave the other axis's half (near-)unchanged.
    diff_w_first_half = (origin[:half] - step_w[:half]).abs().sum()
    diff_w_second_half = (origin[half:] - step_w[half:]).abs().sum()

    print(f"Diff W (First Half): {diff_w_first_half}")
    print(f"Diff W (Second Half): {diff_w_second_half}")

    diff_h_first_half = (origin[:half] - step_h[:half]).abs().sum()
    diff_h_second_half = (origin[half:] - step_h[half:]).abs().sum()

    print(f"Diff H (First Half): {diff_h_first_half}")
    print(f"Diff H (Second Half): {diff_h_second_half}")
|
|
|
|
if __name__ == "__main__":  # script entry point; nothing runs on import
    verify_rope()
|
|