| """ |
| A collection of attention layers. |
| """ |
|
|
| import torch |
|
|
|
|
| class Attention(torch.nn.Module): |
| """ |
| Basic but flexible attention module. |
| """ |
|
|
| def __init__( |
| self, |
| hidden_dim, |
| num_heads, |
| bias, |
| use_rope, |
| context_window, |
| is_causal, |
| group_size, |
| ): |
        super().__init__()
        assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by num heads"
        assert num_heads % group_size == 0, "Num heads must be divisible by group size"
|
|
| |
        # Fused projection producing full-width queries plus grouped keys and
        # values (grouped-query attention when group_size > 1).
        self.c_attn = torch.nn.Linear(
            hidden_dim, hidden_dim + 2 * hidden_dim // group_size, bias=bias
        )
|
|
| |
        # Output projection back to the residual stream.
        self.c_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
|
| |
        # NOTE: torch.nn.Dropout() defaults to p=0.5; pass an explicit
        # probability here if a different dropout rate is intended.
        self.attn_dropout = torch.nn.Dropout()
|
|
| self.num_heads = num_heads |
| self.group_size = group_size |
| self.is_causal = is_causal |
|
|
| |
        self.use_rope = use_rope
        if self.use_rope:
            # RoPE pairs adjacent channels, so the per-head dimension must be even.
            assert (hidden_dim // num_heads) % 2 == 0, "Head dim must be even for RoPE"
            self.freqs_cis = compute_freqs_cis(
                seq_len=context_window, head_dim=hidden_dim // num_heads
            )
|
|
    def forward(self, x, attention_mask=None):
        """
        Forward pass.

        Args:
            x: input tensor of shape (batch, seq_len, hidden_dim)
            attention_mask: optional mask of shape (batch, seq_len, seq_len);
                it is broadcast over the head dimension internally
        """
        B, S, H = x.size()
        num_head_groups = self.num_heads // self.group_size
        hidden_dim_per_group_member = H // self.group_size
|
|
| |
| |
        # Project, then split into queries, grouped keys, and grouped values.
        q, k, v = self.c_attn(x).split(
            [H, hidden_dim_per_group_member, hidden_dim_per_group_member], dim=-1
        )

        # (B, S, features) -> (B, S, heads, head_dim)
        q = q.view(B, S, self.num_heads, H // self.num_heads)
        k = k.view(B, S, num_head_groups, H // self.num_heads)
        v = v.view(B, S, num_head_groups, H // self.num_heads)
|
|
        # Rotate queries and keys with the precomputed RoPE frequencies.
        if self.use_rope:
            q, k = apply_rotary_emb(
                q, k, freqs_cis=self.freqs_cis[:S].to(x.device)
            )
|
|
        # (B, S, heads, head_dim) -> (B, heads, S, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand the grouped key/value heads so every query head in a group
        # attends to its group's key/value head.
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)
|
|
| |
        # Broadcast the (batch, seq, seq) mask over the head dimension.
        attention_mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
|
|
| |
| |
        # Scaled dot-product attention; an explicit mask takes precedence over is_causal.
        y = torch.nn.functional.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p if self.training else 0,
            is_causal=False if attention_mask is not None else self.is_causal,
        )
|
|
        # Re-assemble the heads: (B, heads, S, head_dim) -> (B, S, H)
        y = y.transpose(1, 2).contiguous().view(B, S, H)
|
|
| |
        # Output projection followed by dropout.
        y = self.attn_dropout(self.c_proj(y))
|
|
| return y |
|
|
|
|
def _reshape_for_broadcast(freqs_cis, x):
    """
    Reshape freqs_cis so it broadcasts against x along the sequence and the
    last (frequency) dimension only, e.g. (S, head_dim // 2) -> (1, S, 1, head_dim // 2).
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
|
|
|
|
| def apply_rotary_emb(xq, xk, freqs_cis): |
| """ |
    Apply rotary positional embeddings to the query and key tensors.
| """ |
| xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) |
| xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) |
| freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) |
| xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) |
| xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) |
| return xq_out.type_as(xq), xk_out.type_as(xk) |
|
|
|
|
| def compute_freqs_cis(seq_len, head_dim): |
| """Computes complex frequences used for rotary positional encodings""" |
| freqs = 1.0 / ( |
| 10_000 ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim) |
| ) |
| t = torch.arange(seq_len * 2, device=freqs.device, dtype=torch.float32) |
| freqs = torch.outer(t, freqs) |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| return freqs_cis |
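

# Note: compute_freqs_cis precomputes rotations for 2 * seq_len positions and
# returns a complex tensor of shape (2 * seq_len, head_dim // 2);
# Attention.forward slices the first S rows before passing it to apply_rotary_emb.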
|
|
|
|
| ATTENTION_DICT = { |
| "generic": lambda hidden_dim, context_window, use_rope, attn_cfg: Attention( |
| hidden_dim=hidden_dim, |
| num_heads=attn_cfg["num_heads"], |
| bias=attn_cfg["bias"], |
| use_rope=use_rope, |
| context_window=context_window, |
| is_causal=attn_cfg["is_causal"], |
| group_size=attn_cfg["group_size"], |
| ) |
| } |
|
|
|
|
| def build_attention(hidden_dim, context_window, use_rope, attn_cfg): |
| """ |
| Build an attention layer |
| |
| Args: |
| hidden_dim: hidden dimension |
| context_window: context window |
| use_rope: whether to use rope |
| attn_cfg: attention config |
| """ |
| return ATTENTION_DICT[attn_cfg["attn_type"]]( |
| hidden_dim=hidden_dim, |
| context_window=context_window, |
| use_rope=use_rope, |
| attn_cfg=attn_cfg, |
| ) |
|
|