| |
| |
| |
| |
| |
|
|
| import torch |
|
|
| from audiocraft.modules.rope import RotaryEmbedding |
| from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend |
|
|
|
|
def test_rope():
    """rotate_qk with a non-zero start offset preserves the [B, T, H, C] shapes of q and k."""
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128

    rope = RotaryEmbedding(dim=C)
    queries = torch.rand((B, T, H, C))
    keys = torch.rand((B, T, H, C))
    rotated_q, rotated_k = rope.rotate_qk(queries, keys, start=7)

    expected_shape = [B, T, H, C]
    assert list(rotated_q.shape) == expected_shape
    assert list(rotated_k.shape) == expected_shape
|
|
|
|
def test_rope_io_dtypes():
    """Output dtype of rotate_qk follows the *input* dtype, not the embedding's internal dtype.

    Checked for both a float32 and a float64 RotaryEmbedding, with bfloat16 and
    float32 inputs.
    """
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128

    rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
    rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)

    for in_dtype in (torch.bfloat16, torch.float32):
        q = torch.rand((B, T, H, C)).to(in_dtype)
        k = torch.rand((B, T, H, C)).to(in_dtype)
        # Both internal precisions must echo the caller's dtype back.
        for rope in (rope_32, rope_64):
            q_out, _k_out = rope.rotate_qk(q, k)
            assert q_out.dtype == in_dtype
|
|
|
|
def test_transformer_with_rope():
    """StreamingTransformer forward keeps the input shape for both 'rope' and 'sin_rope'."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    for embedding_kind in ['rope', 'sin_rope']:
        model = StreamingTransformer(
            16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
            positional_embedding=embedding_kind)
        model.eval()
        num_steps = 12
        inp = torch.randn(3, num_steps, 16)

        result = model(inp)
        assert list(result.shape) == list(inp.shape)
|
|
|
|
@torch.no_grad()
def test_rope_streaming():
    """Streaming one timestep at a time must match the full (non-streaming) forward with RoPE."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    model = StreamingTransformer(
        16, 4, 2, causal=True, dropout=0.,
        custom=True, positional_embedding='rope')
    model.eval()
    steps = 12
    full_input = torch.randn(3, steps, 16)

    # Reference: everything in a single forward pass.
    ref = model(full_input)

    # Streamed: feed one frame (timestep) per call inside the streaming context.
    with model.streaming():
        chunks = [model(full_input[:, t: t + 1]) for t in range(steps)]

    out = torch.cat(chunks, dim=1)
    assert list(out.shape) == [3, steps, 16]
    delta = torch.norm(out - ref) / torch.norm(out)
    assert delta < 1e-6, delta
|
|
|
|
@torch.no_grad()
def test_rope_streaming_past_context():
    """Streaming matches the full forward with RoPE, with and without a limited past context."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)

    for context in [None, 10]:
        # Use a single layer when a past context is set, two otherwise.
        model = StreamingTransformer(
            16, 4, 1 if context else 2,
            causal=True, past_context=context, custom=True,
            dropout=0., positional_embedding='rope')
        model.eval()

        steps = 20
        full_input = torch.randn(3, steps, 16)
        ref = model(full_input)

        # Feed one timestep per call inside the streaming context.
        with model.streaming():
            chunks = [model(full_input[:, t: t + 1]) for t in range(steps)]

        out = torch.cat(chunks, dim=1)
        assert list(out.shape) == [3, steps, 16]
        delta = torch.norm(out - ref) / torch.norm(out)
        assert delta < 1e-6, delta
|
|
|
|
def test_rope_memory_efficient():
    """The memory-efficient attention path must match the custom path with RoPE.

    Both models share the same weights via ``load_state_dict``; outputs are
    compared at float32 precision (the RoPE default).
    """
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    tr = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
        positional_embedding='rope')
    tr_mem_efficient = StreamingTransformer(
        16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
        positional_embedding='rope')
    tr_mem_efficient.load_state_dict(tr.state_dict())
    tr.eval()
    # Fix: the original only put `tr` in eval mode, leaving the memory-efficient
    # model in train mode. With dropout=0 this happened to be benign, but any
    # train/eval-dependent module would break the comparison below.
    tr_mem_efficient.eval()
    steps = 12
    x = torch.randn(3, steps, 16)

    with torch.no_grad():
        y = tr(x)
        y2 = tr_mem_efficient(x)
        assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
|
|
|
|
def test_rope_with_xpos():
    """rotate_qk keeps q/k shapes when the xpos extension is enabled."""
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128

    rope = RotaryEmbedding(dim=C, xpos=True)
    queries = torch.rand((B, T, H, C))
    keys = torch.rand((B, T, H, C))
    rotated_q, rotated_k = rope.rotate_qk(queries, keys, start=7)

    expected_shape = [B, T, H, C]
    assert list(rotated_q.shape) == expected_shape
    assert list(rotated_k.shape) == expected_shape
|
|
|
|
def test_positional_scale():
    """With scale=0.0 the rotary embedding must be a no-op: outputs equal inputs."""
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128

    rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
    queries = torch.rand((B, T, H, C))
    keys = torch.rand((B, T, H, C))
    rotated_q, rotated_k = rope.rotate_qk(queries, keys, start=7)

    # A zero positional scale disables the rotation entirely.
    assert torch.allclose(queries, rotated_q)
    assert torch.allclose(keys, rotated_k)
|
|