adalbertojunior committed on
Commit
cb68421
·
1 Parent(s): 7adcc8f

Upload roberta_layers.py

Browse files
Files changed (1) hide show
  1. roberta_layers.py +58 -57
roberta_layers.py CHANGED
@@ -202,64 +202,65 @@ class RobertaSelfAttention(nn.Module):
202
  context_layer = xformers.memory_efficient_attention(
203
  query_layer, key_layer, value_layer, p=self.dropout_prob
204
  )
 
205
 
206
- use_cache = past_key_value is not None
207
- if self.is_decoder:
208
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
209
- # Further calls to cross_attention layer can then reuse all cross-attention
210
- # key/value_states (first "if" case)
211
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
212
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
213
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
214
- # if encoder bi-directional self-attention `past_key_value` is always `None`
215
- past_key_value = (key_layer, value_layer)
216
-
217
- # Take the dot product between "query" and "key" to get the raw attention scores.
218
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
219
-
220
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
221
- query_length, key_length = query_layer.shape[2], key_layer.shape[2]
222
- if use_cache:
223
- position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
224
- -1, 1
225
- )
226
- else:
227
- position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
228
- position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
229
- distance = position_ids_l - position_ids_r
230
-
231
- positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
232
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
233
-
234
- if self.position_embedding_type == "relative_key":
235
- relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
236
- attention_scores = attention_scores + relative_position_scores
237
- elif self.position_embedding_type == "relative_key_query":
238
- relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
239
- relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
240
- attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
241
-
242
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
243
- if attention_mask is not None:
244
- # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
245
- attention_scores = attention_scores + attention_mask
246
-
247
- # Normalize the attention scores to probabilities.
248
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
249
-
250
- # This is actually dropping out entire tokens to attend to, which might
251
- # seem a bit unusual, but is taken from the original Transformer paper.
252
- attention_probs = self.dropout(attention_probs)
253
-
254
- # Mask heads if we want to
255
- if head_mask is not None:
256
- attention_probs = attention_probs * head_mask
257
-
258
- context_layer = torch.matmul(attention_probs, value_layer)
259
-
260
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
261
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
262
- context_layer = context_layer.view(new_context_layer_shape)
263
 
264
  outputs = (context_layer, ) if output_attentions else (context_layer,)
265
 
 
202
  context_layer = xformers.memory_efficient_attention(
203
  query_layer, key_layer, value_layer, p=self.dropout_prob
204
  )
205
+ else:
206
 
207
+ use_cache = past_key_value is not None
208
+ if self.is_decoder:
209
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
210
+ # Further calls to cross_attention layer can then reuse all cross-attention
211
+ # key/value_states (first "if" case)
212
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
213
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
214
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
215
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
216
+ past_key_value = (key_layer, value_layer)
217
+
218
+ # Take the dot product between "query" and "key" to get the raw attention scores.
219
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
220
+
221
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
222
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
223
+ if use_cache:
224
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
225
+ -1, 1
226
+ )
227
+ else:
228
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
229
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
230
+ distance = position_ids_l - position_ids_r
231
+
232
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
233
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
234
+
235
+ if self.position_embedding_type == "relative_key":
236
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
237
+ attention_scores = attention_scores + relative_position_scores
238
+ elif self.position_embedding_type == "relative_key_query":
239
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
240
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
241
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
242
+
243
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
244
+ if attention_mask is not None:
245
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
246
+ attention_scores = attention_scores + attention_mask
247
+
248
+ # Normalize the attention scores to probabilities.
249
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
250
+
251
+ # This is actually dropping out entire tokens to attend to, which might
252
+ # seem a bit unusual, but is taken from the original Transformer paper.
253
+ attention_probs = self.dropout(attention_probs)
254
+
255
+ # Mask heads if we want to
256
+ if head_mask is not None:
257
+ attention_probs = attention_probs * head_mask
258
+
259
+ context_layer = torch.matmul(attention_probs, value_layer)
260
+
261
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
262
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
263
+ context_layer = context_layer.view(new_context_layer_shape)
264
 
265
  outputs = (context_layer, ) if output_attentions else (context_layer,)
266