khala / models /Megatron /tests /unit_tests /inference /test_flash_decode.py
multimodalart's picture
multimodalart HF Staff
Initial best-effort ZeroGPU port of Khala song generation
d1f1097 verified
import torch
from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb_with_cos_sin
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
class TestRotaryEmbeddingWithPrecomputedCosSin:
    """Shape test for ``apply_rotary_pos_emb_with_cos_sin``.

    The cos/sin inputs are obtained from ``RotaryEmbedding.get_cos_sin`` and the
    input tensor is allocated on the ``"cuda"`` device, so a GPU is required.
    """

    def setup_method(self):
        """Set the tensor dimensions and build the ``RotaryEmbedding`` used by the test."""
        self.batch_size = 3
        self.seq_len = 4
        self.d_rot = 6
        self.rotary_embedding = RotaryEmbedding(kv_channels=4, rotary_percent=1.0)

    def test_output_shapes_match(self):
        """Check that the output tensor has the same shape as the input tensor."""
        # Create input tensors
        t = torch.randn(self.seq_len, self.batch_size, 2, self.d_rot * 2, device="cuda")
        rotary_pos_cos, rotary_pos_sin = self.rotary_embedding.get_cos_sin(self.seq_len)

        # The expected shape is the input's own shape, captured before the call.
        # It was previously spelled as
        # [seq_len, batch_size, seq_len // 2, seq_len * batch_size], which equals
        # t.shape ([4, 3, 2, 12]) only because of the particular constants chosen
        # in setup_method; deriving it from t keeps the check valid if they change.
        expected_shape = t.shape

        # Test using Flash Decoding optimized kernel which requires precomputed cos & sin tensors
        output_flash_rotary = apply_rotary_pos_emb_with_cos_sin(
            t, rotary_pos_cos, rotary_pos_sin, rotary_interleaved=True
        )

        assert (
            output_flash_rotary.shape == expected_shape
        ), f"Outputs do not match: {output_flash_rotary.shape} != {expected_shape}"