Abner0803 committed on
Commit
113dff4
·
verified ·
1 Parent(s): 011fa3c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +81 -0
README.md CHANGED
@@ -152,6 +152,87 @@ class BaseTransformerComp(nn.Module):
152
  return mask
153
  ```
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  ### RPB Components
156
 
157
  ```python
 
152
  return mask
153
  ```
154
 
155
+ ### Transformer Encoder Layer with RPB
156
+
157
+ ```python
158
+ class TransformerEncoderLayerWithRPB(nn.Module):
159
+ def __init__(
160
+ self,
161
+ d_model: int,
162
+ nhead: int,
163
+ dim_feedforward: int,
164
+ dropout: float,
165
+ rbp,
166
+ ):
167
+ super().__init__()
168
+ self.d_model = d_model
169
+ self.nhead = nhead
170
+ self.rbp = rbp
171
+
172
+ # QKV projections
173
+ self.qkv_proj = nn.Linear(d_model, 3 * d_model)
174
+ self.out_proj = nn.Linear(d_model, d_model)
175
+
176
+ # FFN layers
177
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
178
+ self.dropout = nn.Dropout(dropout)
179
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
180
+
181
+ # Normalization and dropout
182
+ self.norm1 = nn.LayerNorm(d_model)
183
+ self.norm2 = nn.LayerNorm(d_model)
184
+ self.dropout1 = nn.Dropout(dropout)
185
+ self.dropout2 = nn.Dropout(dropout)
186
+ self.activation = F.relu
187
+
188
+ def forward(
189
+ self,
190
+ src: torch.Tensor,
191
+ src_mask: Optional[torch.Tensor] = None,
192
+ src_key_padding_mask: Optional[torch.Tensor] = None,
193
+ is_causal: bool = False,
194
+ ) -> torch.Tensor:
195
+ seq_len, batch_size, d_model = src.shape
196
+ head_dim = d_model // self.nhead
197
+ qkv = self.qkv_proj(src)
198
+ q, k, v = qkv.chunk(3, dim=-1)
199
+ q = q.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
200
+ k = k.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
201
+ v = v.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
202
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
203
+
204
+ # Add RBP after QK^T
205
+ rbp_bias = self.rbp(
206
+ seq_len, seq_len, device=src.device
207
+ ) # [nhead, seq_len, seq_len]
208
+ attn_weights = attn_weights + rbp_bias.unsqueeze(
209
+ 0
210
+ ) # [batch, nhead, seq_len, seq_len]
211
+
212
+ if src_mask is not None:
213
+ attn_weights = attn_weights + src_mask.unsqueeze(0).unsqueeze(0)
214
+
215
+ if src_key_padding_mask is not None:
216
+ attn_weights = attn_weights.masked_fill(
217
+ src_key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
218
+ )
219
+
220
+ attn_weights = F.softmax(attn_weights, dim=-1)
221
+ attn_weights = self.dropout1(attn_weights)
222
+ attn_output = torch.matmul(attn_weights, v) # [batch, nhead, seq_len, head_dim]
223
+ attn_output = attn_output.permute(2, 0, 1, 3).reshape(
224
+ seq_len, batch_size, d_model
225
+ )
226
+ attn_output = self.out_proj(attn_output)
227
+ src2 = src + self.dropout1(attn_output)
228
+ src2 = self.norm1(src2)
229
+ ffn_output = self.linear2(self.dropout(self.activation(self.linear1(src2))))
230
+ src3 = src2 + self.dropout2(ffn_output)
231
+ src3 = self.norm2(src3)
232
+
233
+ return src3
234
+ ```
235
+
236
  ### RPB Components
237
 
238
  ```python