Spaces:
Running on Zero
Running on Zero
| import torch | |
| from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb_with_cos_sin | |
| from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding | |
class TestRotaryEmbeddingWithPrecomputedCosSin:
    """Shape check for ``apply_rotary_pos_emb_with_cos_sin`` with precomputed cos/sin.

    The cos/sin tensors are obtained from ``RotaryEmbedding.get_cos_sin``. Only the
    shape of the result is verified here, not the rotated values.
    """

    def setup_method(self):
        """Set the tensor dimensions and build the rotary embedding used by the test."""
        self.batch_size = 3
        self.seq_len = 4
        self.d_rot = 6
        self.rotary_embedding = RotaryEmbedding(kv_channels=4, rotary_percent=1.0)

    def test_output_shapes_match(self):
        """The cos/sin rotary path must return a tensor shaped like its input."""
        # Create input tensors
        # NOTE(review): the axis of size 2 is presumably the attention-head axis -- confirm.
        t = torch.randn(self.seq_len, self.batch_size, 2, self.d_rot * 2, device="cuda")
        rotary_pos_cos, rotary_pos_sin = self.rotary_embedding.get_cos_sin(self.seq_len)

        # Test using Flash Decoding optimized kernel which requires precomputed cos & sin tensors
        # The expected shape is taken from the input itself. Deriving it from
        # `seq_len // 2` and `seq_len * batch_size` only matched the input's last two
        # dimensions by numeric coincidence and would break as soon as a fixture
        # value changed.
        expected_shape = t.shape
        output_flash_rotary = apply_rotary_pos_emb_with_cos_sin(
            t, rotary_pos_cos, rotary_pos_sin, rotary_interleaved=True
        )

        assert (
            output_flash_rotary.shape == expected_shape
        ), f"Outputs do not match: {output_flash_rotary.shape} != {expected_shape}"