File size: 2,885 Bytes
e101805
 
 
 
 
 
 
 
 
 
 
 
 
 
2a1ba39
e101805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import Callable, List, Optional

import torch
from torch import Tensor, nn

from .attention import SelfAttention
from .ffn_layers import SwiGLUFFN

class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward network.

    When ``dual_attention=True`` a second, independently parameterized
    attention module (``nope_attn``) is created alongside ``attn``; callers
    choose between the two via ``use_nope_attn`` in
    :meth:`compute_attention_output`.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        ffn_drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = SwiGLUFFN,
        dual_attention: bool = False,
        device=None,
    ) -> None:
        """Build norms, one or two attention modules, and the FFN.

        Args:
            dim: Embedding dimension of the token stream.
            num_heads: Number of attention heads.
            ffn_ratio: FFN hidden size as a multiple of ``dim``.
            qkv_bias / proj_bias / ffn_bias: Bias switches forwarded to the
                attention and FFN constructors.
            ffn_drop / attn_drop: Dropout rates for the FFN and the attention
                projection, respectively.
            norm_layer / attn_class / ffn_layer: Factories for the submodules.
            dual_attention: If True, also build the ``nope_attn`` module.
            device: Forwarded to submodule constructors.
        """
        super().__init__()
        self._dual_attention = bool(dual_attention)

        # Both attention modules (when two are built) share identical
        # construction arguments, so centralize them in one factory.
        def make_attention() -> nn.Module:
            return attn_class(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                proj_drop=attn_drop,
                device=device,
            )

        # NOTE: submodule assignment order is kept identical to preserve
        # nn.Module registration order.
        self.norm1 = norm_layer(dim)
        self.attn = make_attention()
        if self._dual_attention:
            self.nope_attn = make_attention()
        else:
            self.nope_attn = None

        self.norm2 = norm_layer(dim)
        hidden = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=hidden,
            drop=ffn_drop,
            bias=ffn_bias,
            device=device,
        )

    def compute_attention_output(
        self,
        x: Tensor,
        rope: tuple[Tensor, Tensor] | None = None,
        attn_mask: Tensor | None = None,
        use_nope_attn: bool = False,
    ) -> Tensor:
        """Run pre-norm self-attention on ``x``; the residual is NOT added here.

        Raises:
            RuntimeError: If ``use_nope_attn`` is requested but the block was
                constructed without ``dual_attention=True``.
        """
        if use_nope_attn and self.nope_attn is None:
            raise RuntimeError("nope_attn is not initialized; instantiate with dual_attention=True")
        module = self.nope_attn if use_nope_attn else self.attn
        return module(self.norm1(x), attn_mask=attn_mask, rope=rope)

    def forward_with_attention_output(self, x: Tensor, attn_output: Tensor) -> Tensor:
        """Finish the block from a precomputed attention output: residual, then FFN."""
        after_attn = x + attn_output
        return after_attn + self.mlp(self.norm2(after_attn))

    def forward(self, x: Tensor, rope: tuple[Tensor, Tensor] | None = None, attn_mask: Tensor | None = None) -> Tensor:
        """Full block forward for batched tensor inputs only."""
        return self.forward_with_attention_output(
            x, self.compute_attention_output(x, rope=rope, attn_mask=attn_mask)
        )