Commit ·
37dd16f
1
Parent(s): b7ed879
Upload roberta_layers.py
Browse files- roberta_layers.py +15 -24
roberta_layers.py
CHANGED
|
@@ -134,7 +134,7 @@ class RobertaEmbeddings(nn.Module):
|
|
| 134 |
return embeddings
|
| 135 |
|
| 136 |
|
| 137 |
-
class
|
| 138 |
"""Performs multi-headed self attention on a batch of unpadded sequences.
|
| 139 |
If Triton is installed, this module uses Flash Attention to greatly improve throughput.
|
| 140 |
The Flash Attention implementation used in Mosaic BERT supports arbitrary attention biases (which
|
|
@@ -158,18 +158,9 @@ class RobertaSelfAttention(nn.Module):
|
|
| 158 |
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 159 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 160 |
self.p_dropout = config.attention_probs_dropout_prob
|
|
|
|
| 161 |
|
| 162 |
-
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 163 |
-
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 164 |
-
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 165 |
|
| 166 |
-
# self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
|
| 167 |
-
|
| 168 |
-
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 169 |
-
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 170 |
-
x = x.view(new_x_shape)
|
| 171 |
-
return x.permute(0, 2, 1, 3)
|
| 172 |
-
|
| 173 |
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
| 174 |
max_seqlen_in_batch: int, indices: torch.Tensor,
|
| 175 |
attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
|
|
@@ -190,17 +181,17 @@ class RobertaSelfAttention(nn.Module):
|
|
| 190 |
Returns:
|
| 191 |
attention: (total_nnz, dim)
|
| 192 |
"""
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
# if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
|
| 201 |
-
q =
|
| 202 |
-
k =
|
| 203 |
-
v =
|
| 204 |
|
| 205 |
if self.p_dropout or xformers_available is False:
|
| 206 |
|
|
@@ -261,12 +252,12 @@ class RobertaSelfOutput(nn.Module):
|
|
| 261 |
return hidden_states
|
| 262 |
|
| 263 |
|
| 264 |
-
class
|
| 265 |
"""Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
|
| 266 |
|
| 267 |
def __init__(self, config):
|
| 268 |
super().__init__()
|
| 269 |
-
self.self =
|
| 270 |
self.output = RobertaSelfOutput(config)
|
| 271 |
|
| 272 |
def forward(
|
|
@@ -349,7 +340,7 @@ class RobertaLayer(nn.Module):
|
|
| 349 |
|
| 350 |
def __init__(self, config):
|
| 351 |
super(RobertaLayer, self).__init__()
|
| 352 |
-
self.attention =
|
| 353 |
self.mlp = RobertaGatedLinearUnitMLP(config)
|
| 354 |
|
| 355 |
def forward(
|
|
|
|
| 134 |
return embeddings
|
| 135 |
|
| 136 |
|
| 137 |
+
class RobertaUnpadSelfAttention(nn.Module):
|
| 138 |
"""Performs multi-headed self attention on a batch of unpadded sequences.
|
| 139 |
If Triton is installed, this module uses Flash Attention to greatly improve throughput.
|
| 140 |
The Flash Attention implementation used in Mosaic BERT supports arbitrary attention biases (which
|
|
|
|
| 158 |
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 159 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 160 |
self.p_dropout = config.attention_probs_dropout_prob
|
| 161 |
+
self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
|
| 162 |
|
|
|
|
|
|
|
|
|
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
| 165 |
max_seqlen_in_batch: int, indices: torch.Tensor,
|
| 166 |
attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 181 |
Returns:
|
| 182 |
attention: (total_nnz, dim)
|
| 183 |
"""
|
| 184 |
+
qkv = self.Wqkv(hidden_states)
|
| 185 |
+
qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
|
| 186 |
+
max_seqlen_in_batch) # batch, max_seqlen_in_batch, thd
|
| 187 |
+
qkv = rearrange(qkv,
|
| 188 |
+
'b s (t h d) -> b s t h d',
|
| 189 |
+
t=3,
|
| 190 |
+
h=self.num_attention_heads)
|
| 191 |
# if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
|
| 192 |
+
q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
|
| 193 |
+
k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
|
| 194 |
+
v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
|
| 195 |
|
| 196 |
if self.p_dropout or xformers_available is False:
|
| 197 |
|
|
|
|
| 252 |
return hidden_states
|
| 253 |
|
| 254 |
|
| 255 |
+
class RobertaUnpadAttention(nn.Module):
|
| 256 |
"""Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
|
| 257 |
|
| 258 |
def __init__(self, config):
|
| 259 |
super().__init__()
|
| 260 |
+
self.self = RobertaUnpadSelfAttention(config)
|
| 261 |
self.output = RobertaSelfOutput(config)
|
| 262 |
|
| 263 |
def forward(
|
|
|
|
| 340 |
|
| 341 |
def __init__(self, config):
|
| 342 |
super(RobertaLayer, self).__init__()
|
| 343 |
+
self.attention = RobertaUnpadAttention(config)
|
| 344 |
self.mlp = RobertaGatedLinearUnitMLP(config)
|
| 345 |
|
| 346 |
def forward(
|