| """Minimal AGILLM-4 sublinear attention V2 snippets. |
| |
| These are the core blocks now folded into `nB300_agillm4_vram_dblock.py`. |
| """ |
|
|
| |
# Look up the anchor positions for the current key length on the target device.
# NOTE(review): other snippets in this file call `.numel()` / `.view(1, -1)` on
# the result and use it to index dim 2 of the key/value prefix sums, so this is
# presumably a 1-D integer index tensor -- confirm against the helper.
| anchors = self._sublinear_anchor_positions(k_len, device) |
|
|
| |
# Pooled landmarks: for each anchor, average k and v over the window of up to
# `sublinear_stride` positions that ends at the anchor (anchor included),
# using prefix sums along dim 2 instead of a per-anchor loop.
# Only taken when there is at least one anchor, pooling is enabled, and the
# stride is > 1.
# NOTE(review): indentation is not visible in this excerpt; the statements
# below are presumably the body of this `if` -- confirm against the full file.
# NOTE(review): k and v are 4-D here (dims 0, 1, 3 are copied into the zero
# pad); presumably (batch, heads, seq, head_dim) -- confirm.
| if anchors.numel() and self.sublinear_pooled_landmarks and self.sublinear_stride > 1: |
# Half-open window [starts, ends): `ends` is one past the anchor position.
| ends = anchors + 1 |
# Windows near the start of the sequence are truncated at position 0.
| starts = (ends - self.sublinear_stride).clamp_min(0) |
# A single zero slice along dim 2, so that prefix[..., i, :] holds the sum of
# the first i positions (and prefix[..., 0, :] is zero).
| zero_k = k.new_zeros(k.size(0), k.size(1), 1, k.size(3)) |
| zero_v = v.new_zeros(v.size(0), v.size(1), 1, v.size(3)) |
# NOTE(review): the cumulative sum runs in the dtype of k / v. If those are
# fp16/bf16, differences of large prefix sums may lose precision on long
# sequences -- confirm the dtype, or whether accumulation should be float32.
| prefix_k = torch.cat([zero_k, k.cumsum(dim=2)], dim=2) |
| prefix_v = torch.cat([zero_v, v.cumsum(dim=2)], dim=2) |
# Number of positions in each window, shaped to broadcast over dim 2.
# clamp_min(1) is a defensive guard against dividing by zero.
| denom = (ends - starts).to(dtype=k.dtype).view(1, 1, -1, 1).clamp_min(1) |
# Window sum = prefix[ends] - prefix[starts]; dividing by the count gives the mean.
| anchor_k = (prefix_k[:, :, ends, :] - prefix_k[:, :, starts, :]) / denom |
| anchor_v = (prefix_v[:, :, ends, :] - prefix_v[:, :, starts, :]) / denom |
|
|
| |
# Build a per-query mask of the anchors that lie outside the local window.
# Repeat the anchor positions once per query row: shape (cur, num_anchors).
| anchor_idx = anchors.view(1, -1).expand(cur, -1) |
# Inclusive local-window bounds per query, as column vectors for broadcasting.
# NOTE(review): presumably q_pos is 1-D with `cur` entries -- confirm.
| local_lo = (q_pos - self.sublinear_window).clamp_min(0).view(-1, 1) |
# max(0, k_len - 1) keeps the upper clamp non-negative when k_len is 0.
| local_hi = (q_pos + self.sublinear_window).clamp_max(max(0, k_len - 1)).view(-1, 1) |
# True where the anchor falls strictly outside [local_lo, local_hi].
# NOTE(review): presumably this avoids duplicating keys that the local window
# already covers -- confirm. Anchors beyond local_hi are accepted, so no causal
# restriction is applied in this excerpt; check whether one is applied elsewhere.
| anchor_valid = (anchor_idx < local_lo) | (anchor_idx > local_hi) |
|
|
| |
# Subtract a penalty from the scores that grows linearly with |query - key| distance.
# Absolute distance between each query position and each gathered key index,
# shaped (1, 1, cur, n) and held in float32.
# NOTE(review): presumably idx holds the key positions selected per query
# (cur rows) -- confirm against the code that builds it.
| dist = (q_pos.view(1, 1, cur, 1) - idx.view(1, 1, cur, -1)).abs().to(torch.float32) |
# The bias is cast to the dtype of `scores` before being added.
# NOTE(review): `slopes` is defined outside this excerpt; this looks like an
# ALiBi-style per-head linear bias -- confirm its shape and sign convention.
| scores = scores + (-slopes * dist).to(scores.dtype) |
|
|