| import sys |
| import os |
| import torch |
|
|
| sys.path.append(os.path.abspath(".")) |
| from src.models.components.rope import RotaryEmbedding2D |
|
|
|
|
def verify_custom_rope(dim=64, grid_size=(4, 4), tol=1e-5, min_change=1.0):
    """Sanity-check that RotaryEmbedding2D separates its channels by axis.

    Builds all-ones query/key tensors of shape ``(1, 1, H * W, dim)``, applies
    ``RotaryEmbedding2D`` with flattened positions ``0 .. H*W - 1`` over the
    given grid, and compares the output at grid cell (0, 0) with its
    neighbours. It expects the following:

    * Stepping one cell along W changes only the second half of the
      channels (the "W part").
    * Stepping one cell along H changes only the first half of the
      channels (the "H part").

    With the defaults this reproduces the original 64-dim, 4x4 check.

    Args:
        dim: Head dimension passed to ``RotaryEmbedding2D``. The check splits
            the channels at ``dim // 2``, so it must be even.
        grid_size: ``(H, W)`` grid. It is also used as ``max_res``. Both sides
            must be >= 2 so that the (0, 1) and (1, 0) neighbours exist.
        tol: Largest summed absolute difference allowed in the half that
            should be unaffected by a step.
        min_change: Smallest summed absolute difference required in the half
            that should be rotated by a step.

    Raises:
        ValueError: If ``dim`` is odd or the grid is smaller than 2x2.
        AssertionError: If any axis-separation check fails. All failing
            checks are listed in the message.
    """
    height, width = grid_size
    if height < 2 or width < 2:
        raise ValueError(f"grid_size must be at least (2, 2), got {grid_size}")
    if dim % 2:
        raise ValueError(f"dim must be even to split into H/W halves, got {dim}")

    seq_len = height * width
    half = dim // 2

    rope = RotaryEmbedding2D(dim, max_res=(height, width))

    # All-ones inputs: any difference between two positions comes purely from
    # the rotation applied at those positions.
    q = torch.ones(1, 1, seq_len, dim)
    k = torch.ones(1, 1, seq_len, dim)
    pos_ids = torch.arange(seq_len).unsqueeze(0)

    # Inference-only check; no autograd graph is needed.
    with torch.no_grad():
        q_rot, _k_rot = rope(q, k, pos_ids, (height, width))

    print(f"Output shape: {q_rot.shape}")

    # Lay the flattened sequence back out as an (H, W, dim) grid.
    # NOTE(review): this presumes batch = heads = 1 and row-major flattening
    # (p = h * W + w), so that axis 1 is W. Confirm against RotaryEmbedding2D.
    q_grid = q_rot.reshape(height, width, dim)
    origin = q_grid[0, 0]

    def half_diffs(other):
        """Return summed |origin - other| over the first and second halves."""
        delta = (origin - other).abs()
        return delta[:half].sum().item(), delta[half:].sum().item()

    # Step one column along W: only the second (W) half should rotate.
    diff_w_first_half, diff_w_second_half = half_diffs(q_grid[0, 1])
    print(f"Diff W (First Half - H part): {diff_w_first_half}")
    print(f"Diff W (Second Half - W part): {diff_w_second_half}")

    # Step one row along H: only the first (H) half should rotate.
    diff_h_first_half, diff_h_second_half = half_diffs(q_grid[1, 0])
    print(f"Diff H (First Half - H part): {diff_h_first_half}")
    print(f"Diff H (Second Half - W part): {diff_h_second_half}")

    # Explicit checks instead of `assert`: `python -O` strips asserts, which
    # would let this script report success without verifying anything.
    checks = [
        (diff_w_first_half < tol, "First half should not change with W"),
        (diff_w_second_half > min_change, "Second half should change with W"),
        (diff_h_first_half > min_change, "First half should change with H"),
        (diff_h_second_half < tol, "Second half should not change with H"),
    ]
    failures = [message for passed, message in checks if not passed]
    if failures:
        raise AssertionError("; ".join(failures))

    print("Verification Successful!")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the 2D RoPE axis-separation check only when this
    # file is executed directly, not when it is imported.
    verify_custom_rope()
|
|