import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange


class ScaleDotProductAttention(nn.Module):

    def __init__(self, layer_number, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.layer_number = layer_number
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v, attn_mask=None, order='sbhd'):
        """Implements multihead softmax attention via
        torch.nn.functional.scaled_dot_product_attention.

        Arguments
        ---------
            q, k, v: Query, key, and value tensors. Layout is given by `order`:
                's b h d' (sequence, batch, heads, head_dim) for 'sbhd' or
                'b h s d' for 'bhsd'.
            attn_mask: Optional boolean mask where True marks positions to mask out.
            order: Layout of the input tensors, 'sbhd' (default) or 'bhsd'.
        """
        if order == 'sbhd':
            # Reshape from (s, b, h, d) to the (b, h, s, d) layout expected by
            # F.scaled_dot_product_attention.
            q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
                       for x in (q, k, v)]
        elif order == 'bhsd':
            pass
        else:
            raise ValueError(f"Unsupported tensor order: {order}")

        if attn_mask is not None:
            # F.scaled_dot_product_attention treats True as "may attend", so a mask
            # that marks masked-out positions with True has to be inverted.
            attn_mask = (~attn_mask.clone().bool()).contiguous()

        if self.training:
            if self.causal:
                # Causal masking during training assumes square (self-)attention.
                assert q.shape[-2] == k.shape[-2]
            is_causal = self.causal
            dropout_p = self.dropout_p
        else:
            if self.causal:
                # At inference time, only use the built-in causal mask when query and
                # key lengths match (e.g. the prefill step); with a single query
                # attending to a longer cached key sequence it must stay disabled.
                is_causal = q.shape[-2] == k.shape[-2]
            else:
                is_causal = self.causal
            # No attention dropout at evaluation time.
            dropout_p = 0.0

        o = F.scaled_dot_product_attention(q, k, v,
                                           attn_mask=attn_mask,
                                           dropout_p=dropout_p,
                                           is_causal=is_causal,
                                           scale=self.softmax_scale)

        # Rearrange the output back to (s, b, h*d) for the caller.
        o = rearrange(o, 'b h s d -> s b (h d)').contiguous()
        return o
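

# Minimal usage sketch (illustrative only): the tensor sizes, head count, and the
# layer_number value below are assumptions chosen for demonstration, not values
# taken from any particular model configuration.
if __name__ == "__main__":
    s, b, h, d = 16, 2, 4, 32  # hypothetical sequence length, batch, heads, head_dim

    attn = ScaleDotProductAttention(layer_number=0, causal=True, attention_dropout=0.1)
    attn.eval()  # eval mode: dropout is forced to 0.0 inside forward

    q = torch.randn(s, b, h, d)
    k = torch.randn(s, b, h, d)
    v = torch.randn(s, b, h, d)

    out = attn(q, k, v, order='sbhd')
    print(out.shape)  # -> torch.Size([16, 2, 128]), i.e. (s, b, h * d)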