File size: 5,445 Bytes
e2bfccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""Gamma SSM Block with residual connections and normalization."""

import torch
import torch.nn as nn
from typing import Optional, Tuple

from .ssm_gamma import SSMGamma
from .normalization import LayerNorm


class GammaSingleBlock(nn.Module):
    """Single Gamma SSM block with a residual connection and layer normalization.

    With ``s = residual_scale`` the block computes:

        prenorm=True:   y = s * x + Dropout(SSM(LayerNorm(x)))
        prenorm=False:  y = LayerNorm(s * x + Dropout(SSM(x)))

    The same computation is used by both :meth:`forward` (full sequence) and
    :meth:`step` (single timestep), so the two paths agree for a given set of
    weights.

    Args:
        d_model: Model dimension (passed to the SSM as ``state_dim``).
        hidden_dim: Hidden dimension for the SSM state.
        delta_t: Time discretization step (default: 0.1).
        kernel_length: Convolution kernel length for future use (default: 4).
        A_type: Type of A matrix initialization (default: "tridiagonal").
        prenorm: Use prenorm (LayerNorm -> SSM -> residual) instead of
            postnorm (SSM -> residual -> LayerNorm) (default: True).
        residual_scale: Scaling factor applied to the skip branch
            (default: 1.0).
        dropout: Dropout rate applied to the SSM output (default: 0.0).

    Shape:
        - Input: (batch, seq_len, d_model)
        - Output: (batch, seq_len, d_model)
    """

    def __init__(
        self,
        d_model: int,
        hidden_dim: int,
        delta_t: float = 0.1,
        kernel_length: int = 4,
        A_type: str = "tridiagonal",
        prenorm: bool = True,
        residual_scale: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.prenorm = prenorm
        self.dropout_p = dropout
        self.residual_scale = residual_scale

        # Normalization (applied before the SSM for prenorm, after the
        # residual sum for postnorm).
        self.norm = LayerNorm(d_model)

        # SSM block
        self.ssm = SSMGamma(
            state_dim=d_model,
            hidden_dim=hidden_dim,
            delta_t=delta_t,
            kernel_length=kernel_length,
            A_type=A_type,
        )

        # Dropout is kept as ``None`` when disabled so that external
        # ``block.dropout is None`` checks keep working.
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def _merge_residual(
        self,
        skip: torch.Tensor,
        ssm_out: torch.Tensor,
    ) -> torch.Tensor:
        """Apply dropout, add the scaled skip branch, and post-normalize.

        Shared by :meth:`forward` and :meth:`step` so both paths compute the
        same function of their inputs.
        """
        if self.dropout is not None:
            ssm_out = self.dropout(ssm_out)

        merged = skip * self.residual_scale + ssm_out

        # Postnorm normalizes exactly once, after the residual sum:
        # y = LayerNorm(Block(x) + x).
        if not self.prenorm:
            merged = self.norm(merged)
        return merged

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        return_state: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the block over a full sequence.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            state: Optional initial hidden state (batch, hidden_dim).
            mask: Optional mask (batch, seq_len) for valid positions.
            return_state: When False, the second element of the returned
                tuple is ``None`` instead of the final SSM state.

        Returns:
            output: (batch, seq_len, d_model)
            final_state: Final hidden state from the SSM (batch, hidden_dim),
                or ``None`` if ``return_state`` is False.
        """
        # Prenorm feeds the normalized input to the SSM; postnorm feeds the
        # raw input and normalizes after the residual instead.
        ssm_in = self.norm(x) if self.prenorm else x
        ssm_out, final_state = self.ssm(ssm_in, mask=mask, state=state)

        output = self._merge_residual(x, ssm_out)

        if not return_state:
            final_state = None
        return output, final_state

    def step(self, u: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single-step inference through the block (RNN style).

        Mirrors :meth:`forward` for one timestep, including the
        post-residual normalization when ``prenorm`` is False.

        Args:
            u: Input tensor (batch, d_model) - single timestep.
            h: Hidden state (batch, hidden_dim).

        Returns:
            output: (batch, d_model) - block output.
            h_new: (batch, hidden_dim) - new hidden state.
        """
        ssm_in = self.norm(u) if self.prenorm else u
        ssm_out, h_new = self.ssm.step(ssm_in, h)

        output = self._merge_residual(u, ssm_out)
        return output, h_new

    def allocate_inference_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Allocate cache for efficient inference.

        Delegates to the wrapped SSM's ``allocate_inference_cache`` and
        returns whatever it returns.
        """
        return self.ssm.allocate_inference_cache(batch_size, seq_len, device, dtype)

    def allocate_deployment_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Alias of :meth:`allocate_inference_cache` with the same arguments."""
        return self.allocate_inference_cache(batch_size, seq_len, device, dtype)

    def allocate_balanced_deployment_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Alias of :meth:`allocate_inference_cache` with the same arguments."""
        return self.allocate_inference_cache(batch_size, seq_len, device, dtype)