Update README.md
Browse files
README.md
CHANGED
|
@@ -152,6 +152,87 @@ class BaseTransformerComp(nn.Module):
|
|
| 152 |
return mask
|
| 153 |
```
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
### RPB Components
|
| 156 |
|
| 157 |
```python
|
|
|
|
| 152 |
return mask
|
| 153 |
```
|
| 154 |
|
| 155 |
+
### Transformer Encoder Layer with RPB
|
| 156 |
+
|
| 157 |
+
```python
|
| 158 |
+
class TransformerEncoderLayerWithRPB(nn.Module):
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
d_model: int,
|
| 162 |
+
nhead: int,
|
| 163 |
+
dim_feedforward: int,
|
| 164 |
+
dropout: float,
|
| 165 |
+
rbp,
|
| 166 |
+
):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.d_model = d_model
|
| 169 |
+
self.nhead = nhead
|
| 170 |
+
self.rbp = rbp
|
| 171 |
+
|
| 172 |
+
# QKV projections
|
| 173 |
+
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
|
| 174 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 175 |
+
|
| 176 |
+
# FFN layers
|
| 177 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 178 |
+
self.dropout = nn.Dropout(dropout)
|
| 179 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 180 |
+
|
| 181 |
+
# Normalization and dropout
|
| 182 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 183 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 184 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 185 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 186 |
+
self.activation = F.relu
|
| 187 |
+
|
| 188 |
+
def forward(
|
| 189 |
+
self,
|
| 190 |
+
src: torch.Tensor,
|
| 191 |
+
src_mask: Optional[torch.Tensor] = None,
|
| 192 |
+
src_key_padding_mask: Optional[torch.Tensor] = None,
|
| 193 |
+
is_causal: bool = False,
|
| 194 |
+
) -> torch.Tensor:
|
| 195 |
+
seq_len, batch_size, d_model = src.shape
|
| 196 |
+
head_dim = d_model // self.nhead
|
| 197 |
+
qkv = self.qkv_proj(src)
|
| 198 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 199 |
+
q = q.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
|
| 200 |
+
k = k.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
|
| 201 |
+
v = v.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
|
| 202 |
+
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
|
| 203 |
+
|
| 204 |
+
# Add RBP after QK^T
|
| 205 |
+
rbp_bias = self.rbp(
|
| 206 |
+
seq_len, seq_len, device=src.device
|
| 207 |
+
) # [nhead, seq_len, seq_len]
|
| 208 |
+
attn_weights = attn_weights + rbp_bias.unsqueeze(
|
| 209 |
+
0
|
| 210 |
+
) # [batch, nhead, seq_len, seq_len]
|
| 211 |
+
|
| 212 |
+
if src_mask is not None:
|
| 213 |
+
attn_weights = attn_weights + src_mask.unsqueeze(0).unsqueeze(0)
|
| 214 |
+
|
| 215 |
+
if src_key_padding_mask is not None:
|
| 216 |
+
attn_weights = attn_weights.masked_fill(
|
| 217 |
+
src_key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 221 |
+
attn_weights = self.dropout1(attn_weights)
|
| 222 |
+
attn_output = torch.matmul(attn_weights, v) # [batch, nhead, seq_len, head_dim]
|
| 223 |
+
attn_output = attn_output.permute(2, 0, 1, 3).reshape(
|
| 224 |
+
seq_len, batch_size, d_model
|
| 225 |
+
)
|
| 226 |
+
attn_output = self.out_proj(attn_output)
|
| 227 |
+
src2 = src + self.dropout1(attn_output)
|
| 228 |
+
src2 = self.norm1(src2)
|
| 229 |
+
ffn_output = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
| 230 |
+
src3 = src2 + self.dropout2(ffn_output)
|
| 231 |
+
src3 = self.norm2(src3)
|
| 232 |
+
|
| 233 |
+
return src3
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
### RPB Components
|
| 237 |
|
| 238 |
```python
|