| # -*- coding: utf-8 -*- | |
| # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang | |
| from typing import Optional | |
| import torch | |
def naive_parallel_rebased(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True,
) -> torch.Tensor:
    """Reference (quadratic-time) causal ReBased attention.

    Scores are the squared scaled dot products ``((scale * q) @ k^T) ** 2``.
    Entries above the diagonal (later key positions) are zeroed. The scores
    then weight ``v`` directly; no softmax is applied.

    Args:
        q: Queries; the last dim is the feature size and the second-to-last
            is the sequence length (presumably ``(..., T, D)`` — confirm with
            callers).
        k: Keys; must share ``q``'s feature size (matmul) and sequence length
            (the causal mask is built as ``T x T`` from ``q``).
        v: Values; its second-to-last dim must match the key length.
        scale: Multiplier applied to ``q`` before the dot product. Defaults to
            ``q.shape[-1] ** -0.5``. Because the scores are squared, the
            effective factor on ``q @ k^T`` is ``scale ** 2``.
        use_norm: If True, divide each output row by the sum of its masked
            scores plus ``1e-6``; otherwise return the raw weighted sum.

    Returns:
        ``scores @ v``, optionally row-normalised.
    """
    seq_len = q.shape[-2]
    if scale is None:
        scale = q.shape[-1] ** -0.5

    scores = ((q * scale) @ k.transpose(-2, -1)).pow(2)

    # Strictly-upper-triangular entries are future positions: triu(1) == ~tril.
    future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).triu(1)
    scores.masked_fill_(future, 0)

    out = scores @ v
    if not use_norm:
        return out

    # Epsilon keeps the division finite when a row's scores are all zero.
    denom = scores.sum(dim=-1, keepdim=True) + 1e-6
    return out / denom