| | |
| | |
| | |
| | |
| | |
| |
|
| | from itertools import product |
| |
|
| | import pytest |
| | import torch |
| |
|
| | from audiocraft.modules.transformer import ( |
| | StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend) |
| |
|
| |
|
| | def test_transformer_causal_streaming(): |
| | torch.manual_seed(1234) |
| |
|
| | for context, custom in product([None, 10], [False, True]): |
| | |
| | |
| | tr = StreamingTransformer( |
| | 16, 4, 1 if context else 2, |
| | causal=True, past_context=context, custom=custom, |
| | dropout=0.) |
| | steps = 20 |
| | for k in [0, 10, 15, 19]: |
| | x = torch.randn(4, steps, 16, requires_grad=True) |
| | y = tr(x) |
| | y[:, k].abs().sum().backward() |
| | if k + 1 < steps: |
| | assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm() |
| | assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm() |
| | if context is not None and k > context: |
| | limit = k - context - 1 |
| | assert torch.allclose(x.grad[:, :limit], |
| | torch.tensor(0.)), x.grad[:, :limit].norm() |
| |
|
| | |
| | x = torch.randn(4, steps, 16) |
| | y = tr(x) |
| | ys = [] |
| | with tr.streaming(): |
| | for k in range(steps): |
| | chunk = x[:, k:k + 1, :] |
| | ys.append(tr(chunk)) |
| | y_stream = torch.cat(ys, dim=1) |
| | delta = torch.norm(y_stream - y) / torch.norm(y) |
| | assert delta < 1e-6, delta |
| |
|
| |
|
| | def test_transformer_vs_pytorch(): |
| | torch.manual_seed(1234) |
| | |
| | |
| | for custom in [False, True]: |
| | tr = StreamingTransformer( |
| | 16, 4, 2, |
| | causal=False, custom=custom, dropout=0., positional_scale=0.) |
| | layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True) |
| | tr_ref = torch.nn.TransformerEncoder(layer, 2) |
| | tr.load_state_dict(tr_ref.state_dict()) |
| |
|
| | x = torch.randn(4, 20, 16) |
| | y = tr(x) |
| | y2 = tr_ref(x) |
| | delta = torch.norm(y2 - y) / torch.norm(y) |
| | assert delta < 1e-6, delta |
| |
|
| |
|
| | def test_streaming_api(): |
| | tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.) |
| | tr.eval() |
| | steps = 12 |
| | x = torch.randn(1, steps, 16) |
| |
|
| | with torch.no_grad(): |
| | with tr.streaming(): |
| | _ = tr(x[:, :1]) |
| | state = {k: v.clone() for k, v in tr.get_streaming_state().items()} |
| | y = tr(x[:, 1:2]) |
| | tr.set_streaming_state(state) |
| | y2 = tr(x[:, 1:2]) |
| | assert torch.allclose(y, y2), (y - y2).norm() |
| | assert tr.flush() is None |
| |
|
| |
|
| | def test_memory_efficient(): |
| | for backend in ['torch', 'xformers']: |
| | torch.manual_seed(1234) |
| | set_efficient_attention_backend(backend) |
| |
|
| | tr = StreamingTransformer( |
| | 16, 4, 2, custom=True, dropout=0., layer_scale=0.1) |
| | tr_mem_efficient = StreamingTransformer( |
| | 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1) |
| | tr_mem_efficient.load_state_dict(tr.state_dict()) |
| | tr.eval() |
| | steps = 12 |
| | x = torch.randn(3, steps, 16) |
| |
|
| | with torch.no_grad(): |
| | y = tr(x) |
| | y2 = tr_mem_efficient(x) |
| | assert torch.allclose(y, y2), ((y - y2).norm(), backend) |
| |
|
| |
|
| | def test_attention_as_float32(): |
| | torch.manual_seed(1234) |
| | cases = [ |
| | {'custom': True}, |
| | {'custom': False}, |
| | ] |
| | for case in cases: |
| | tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case) |
| | tr_float32 = StreamingTransformer( |
| | 16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case) |
| | if not case['custom']: |
| | |
| | |
| | for layer in tr_float32.layers: |
| | layer.self_attn.mha.to(torch.float32) |
| | tr_float32.load_state_dict(tr.state_dict()) |
| | steps = 12 |
| | x = torch.randn(3, steps, 16, dtype=torch.bfloat16) |
| |
|
| | with torch.no_grad(): |
| | y = tr(x) |
| | y2 = tr_float32(x) |
| | assert not torch.allclose(y, y2), (y - y2).norm() |
| |
|
| |
|
| | @torch.no_grad() |
| | def test_streaming_memory_efficient(): |
| | for backend in ['torch', 'xformers']: |
| | torch.manual_seed(1234) |
| | set_efficient_attention_backend(backend) |
| | tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True) |
| | tr_mem_efficient = StreamingTransformer( |
| | 16, 4, 2, dropout=0., memory_efficient=True, causal=True) |
| | tr.load_state_dict(tr_mem_efficient.state_dict()) |
| | tr.eval() |
| | tr_mem_efficient.eval() |
| | steps = 12 |
| | x = torch.randn(3, steps, 16) |
| |
|
| | ref = tr(x) |
| |
|
| | with tr_mem_efficient.streaming(): |
| | outs = [] |
| | |
| | frame_sizes = [1] * steps |
| |
|
| | for frame_size in frame_sizes: |
| | frame = x[:, :frame_size] |
| | x = x[:, frame_size:] |
| | outs.append(tr_mem_efficient(frame)) |
| |
|
| | out = torch.cat(outs, dim=1) |
| | delta = torch.norm(out - ref) / torch.norm(out) |
| | assert delta < 1e-6, delta |
| |
|
| |
|
| | def test_cross_attention(): |
| | torch.manual_seed(1234) |
| | for norm_first in [True, False]: |
| | m = StreamingTransformer( |
| | 16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True) |
| | m_cross = StreamingTransformer( |
| | 16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True) |
| | m_cross.load_state_dict(m.state_dict(), strict=False) |
| | x = torch.randn(2, 5, 16) |
| | cross_x = torch.randn(2, 3, 16) |
| | y_ref = m(x) |
| | y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x) |
| | |
| | |
| | |
| | atol = 0. if norm_first else 1e-6 |
| | print((y_ref - y_cross_zero).norm() / y_ref.norm()) |
| | assert torch.allclose(y_ref, y_cross_zero, atol=atol) |
| |
|
| | |
| | y_cross = m_cross(x, cross_attention_src=cross_x) |
| | assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2) |
| |
|
| | with pytest.raises(AssertionError): |
| | _ = m_cross(x) |
| | _ = m(x, cross_attention_src=cross_x) |
| |
|
| |
|
| | def test_cross_attention_compat(): |
| | torch.manual_seed(1234) |
| | num_heads = 2 |
| | dim = num_heads * 64 |
| | with pytest.raises(AssertionError): |
| | StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True) |
| |
|
| | cross_attn = StreamingMultiheadAttention( |
| | dim, num_heads, dropout=0, cross_attention=True, custom=True) |
| | ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True) |
| |
|
| | |
| | |
| | cross_attn.load_state_dict(ref_attn.state_dict()) |
| |
|
| | queries = torch.randn(3, 7, dim) |
| | keys = torch.randn(3, 9, dim) |
| | values = torch.randn(3, 9, dim) |
| |
|
| | y = cross_attn(queries, keys, values)[0] |
| | y_ref = ref_attn(queries, keys, values)[0] |
| | assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm() |
| |
|
| | |
| | with cross_attn.streaming(): |
| | ys = [] |
| | for step in range(queries.shape[1]): |
| | ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0]) |
| | y_streaming = torch.cat(ys, dim=1) |
| | assert torch.allclose(y_streaming, y, atol=1e-7) |
| |
|
| |
|
| | def test_repeat_kv(): |
| | torch.manual_seed(1234) |
| | num_heads = 8 |
| | kv_repeat = 4 |
| | dim = num_heads * 64 |
| | with pytest.raises(AssertionError): |
| | mha = StreamingMultiheadAttention( |
| | dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True) |
| | mha = StreamingMultiheadAttention( |
| | dim, num_heads, causal=True, kv_repeat=kv_repeat) |
| | mha = StreamingMultiheadAttention( |
| | dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True) |
| | x = torch.randn(4, 18, dim) |
| | y = mha(x, x, x)[0] |
| | assert x.shape == y.shape |
| |
|
| |
|
| | def test_qk_layer_norm(): |
| | torch.manual_seed(1234) |
| | tr = StreamingTransformer( |
| | 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False) |
| | steps = 12 |
| | x = torch.randn(3, steps, 16) |
| | y = tr(x) |
| |
|
| | tr = StreamingTransformer( |
| | 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True) |
| | z = torch.randn(3, 21, 16) |
| | y = tr(x, cross_attention_src=z) |
| | assert y.shape == x.shape |
| |
|