adalbertojunior committed on
Commit
37dd16f
·
1 Parent(s): b7ed879

Upload roberta_layers.py

Browse files
Files changed (1) hide show
  1. roberta_layers.py +15 -24
roberta_layers.py CHANGED
@@ -134,7 +134,7 @@ class RobertaEmbeddings(nn.Module):
134
  return embeddings
135
 
136
 
137
- class RobertaSelfAttention(nn.Module):
138
  """Performs multi-headed self attention on a batch of unpadded sequences.
139
  If Triton is installed, this module uses Flash Attention to greatly improve throughput.
140
  The Flash Attention implementation used in Mosaic BERT supports arbitrary attention biases (which
@@ -158,18 +158,9 @@ class RobertaSelfAttention(nn.Module):
158
  self.all_head_size = self.num_attention_heads * self.attention_head_size
159
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
160
  self.p_dropout = config.attention_probs_dropout_prob
 
161
 
162
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
163
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
164
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
165
 
166
- # self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
167
-
168
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
169
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
170
- x = x.view(new_x_shape)
171
- return x.permute(0, 2, 1, 3)
172
-
173
  def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
174
  max_seqlen_in_batch: int, indices: torch.Tensor,
175
  attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
@@ -190,17 +181,17 @@ class RobertaSelfAttention(nn.Module):
190
  Returns:
191
  attention: (total_nnz, dim)
192
  """
193
- # qkv = self.Wqkv(hidden_states)
194
- # qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
195
- # max_seqlen_in_batch) # batch, max_seqlen_in_batch, thd
196
- # qkv = rearrange(qkv,
197
- # 'b s (t h d) -> b s t h d',
198
- # t=3,
199
- # h=self.num_attention_heads)
200
  # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
201
- q = self.transpose_for_scores(self.query(hidden_states))#qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
202
- k = self.transpose_for_scores(self.key(hidden_states))#qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
203
- v = self.transpose_for_scores(self.value(hidden_states))#qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
204
 
205
  if self.p_dropout or xformers_available is False:
206
 
@@ -261,12 +252,12 @@ class RobertaSelfOutput(nn.Module):
261
  return hidden_states
262
 
263
 
264
- class RobertaAttention(nn.Module):
265
  """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
266
 
267
  def __init__(self, config):
268
  super().__init__()
269
- self.self = RobertaSelfAttention(config)
270
  self.output = RobertaSelfOutput(config)
271
 
272
  def forward(
@@ -349,7 +340,7 @@ class RobertaLayer(nn.Module):
349
 
350
  def __init__(self, config):
351
  super(RobertaLayer, self).__init__()
352
- self.attention = RobertaAttention(config)
353
  self.mlp = RobertaGatedLinearUnitMLP(config)
354
 
355
  def forward(
 
134
  return embeddings
135
 
136
 
137
+ class RobertaUnpadSelfAttention(nn.Module):
138
  """Performs multi-headed self attention on a batch of unpadded sequences.
139
  If Triton is installed, this module uses Flash Attention to greatly improve throughput.
140
  The Flash Attention implementation used in Mosaic BERT supports arbitrary attention biases (which
 
158
  self.all_head_size = self.num_attention_heads * self.attention_head_size
159
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
160
  self.p_dropout = config.attention_probs_dropout_prob
161
+ self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
162
 
 
 
 
163
 
 
 
 
 
 
 
 
164
  def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
165
  max_seqlen_in_batch: int, indices: torch.Tensor,
166
  attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
 
181
  Returns:
182
  attention: (total_nnz, dim)
183
  """
184
+ qkv = self.Wqkv(hidden_states)
185
+ qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
186
+ max_seqlen_in_batch) # batch, max_seqlen_in_batch, thd
187
+ qkv = rearrange(qkv,
188
+ 'b s (t h d) -> b s t h d',
189
+ t=3,
190
+ h=self.num_attention_heads)
191
  # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
192
+ q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
193
+ k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
194
+ v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
195
 
196
  if self.p_dropout or xformers_available is False:
197
 
 
252
  return hidden_states
253
 
254
 
255
+ class RobertaUnpadAttention(nn.Module):
256
  """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
257
 
258
  def __init__(self, config):
259
  super().__init__()
260
+ self.self = RobertaUnpadSelfAttention(config)
261
  self.output = RobertaSelfOutput(config)
262
 
263
  def forward(
 
340
 
341
  def __init__(self, config):
342
  super(RobertaLayer, self).__init__()
343
+ self.attention = RobertaUnpadAttention(config)
344
  self.mlp = RobertaGatedLinearUnitMLP(config)
345
 
346
  def forward(