import torch

from fla.ops.rebased.parallel import parallel_rebased

def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True):
    """Naive O(L^2) reference for ReBased attention: squared dot-product
    scores, causal masking, and optional row-sum normalization."""
    if use_scale:
        q = q * (q.shape[-1] ** -0.5)
    # Squared attention scores: the degree-2 kernel (q k^T)^2.
    attn = q @ k.transpose(-2, -1)
    attn = attn ** 2
    # Zero out future (non-causal) positions.
    attn.masked_fill_(~torch.tril(torch.ones(
        q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
    o = attn @ v
    if use_norm:
        # Normalize by the row sums of the masked score matrix; the epsilon
        # guards against division by zero.
        z = attn.sum(-1)
        return o / (z[..., None] + 1e-6)
    return o
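
# Why the squared score admits a linear-time form: (q . k)^2 equals a dot
# product of degree-2 feature maps, phi(x) = flatten(outer(x, x)). A minimal
# sketch of that identity (illustrative only; `_phi` is a hypothetical helper,
# not how the fla Triton kernel is implemented):
def _phi(x):
    # Flattened outer product of the head dimension with itself.
    return (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(-2)
# e.g. for a, b = torch.randn(16), torch.randn(16):
#   (a @ b) ** 2 is close to _phi(a) @ _phi(b)
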
if __name__ == "__main__":
    B = 4
    H = 4
    L = 128

    dtype = torch.float32
    # Queries/keys use head dim 16; values use a wider head dim of 128.
    q = torch.randn(B, H, L, 16).cuda().to(dtype).requires_grad_(True)
    k = torch.randn(B, H, L, 16).cuda().to(dtype).requires_grad_(True)
    v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)

    # Reference: forward + backward through the naive implementation.
    # randn_like inherits v's device and dtype, so the gradient is on GPU.
    do = torch.randn_like(v)
    ref = naive_parallel_rebased(q, k, v, True, True)
    ref.backward(do, retain_graph=True)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None

    # Kernel under test: same inputs and upstream gradient.
    tri = parallel_rebased(q, k, v, 1e-6, True, True)
    tri.backward(do, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None

    # Max absolute differences between the naive reference and the kernel.
    print((ref - tri).abs().max())
    print((ref_dq - tri_dq).abs().max())
    print((ref_dk - tri_dk).abs().max())
    print((ref_dv - tri_dv).abs().max())
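
    # A stricter check one could add (a sketch; these float32 tolerances are
    # assumptions, not values taken from the fla test suite):
    torch.testing.assert_close(ref, tri, rtol=1e-3, atol=1e-2)
    torch.testing.assert_close(ref_dq, tri_dq, rtol=1e-3, atol=1e-2)
    torch.testing.assert_close(ref_dk, tri_dk, rtol=1e-3, atol=1e-2)
    torch.testing.assert_close(ref_dv, tri_dv, rtol=1e-3, atol=1e-2)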