Commit ·
cb68421
1
Parent(s): 7adcc8f
Upload roberta_layers.py
Browse files- roberta_layers.py +58 -57
roberta_layers.py
CHANGED
|
@@ -202,64 +202,65 @@ class RobertaSelfAttention(nn.Module):
|
|
| 202 |
context_layer = xformers.memory_efficient_attention(
|
| 203 |
query_layer, key_layer, value_layer, p=self.dropout_prob
|
| 204 |
)
|
|
|
|
| 205 |
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
|
| 264 |
outputs = (context_layer, ) if output_attentions else (context_layer,)
|
| 265 |
|
|
|
|
| 202 |
context_layer = xformers.memory_efficient_attention(
|
| 203 |
query_layer, key_layer, value_layer, p=self.dropout_prob
|
| 204 |
)
|
| 205 |
+
else:
|
| 206 |
|
| 207 |
+
use_cache = past_key_value is not None
|
| 208 |
+
if self.is_decoder:
|
| 209 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 210 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 211 |
+
# key/value_states (first "if" case)
|
| 212 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 213 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 214 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 215 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 216 |
+
past_key_value = (key_layer, value_layer)
|
| 217 |
+
|
| 218 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 219 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 220 |
+
|
| 221 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 222 |
+
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
| 223 |
+
if use_cache:
|
| 224 |
+
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
| 225 |
+
-1, 1
|
| 226 |
+
)
|
| 227 |
+
else:
|
| 228 |
+
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 229 |
+
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 230 |
+
distance = position_ids_l - position_ids_r
|
| 231 |
+
|
| 232 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 233 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 234 |
+
|
| 235 |
+
if self.position_embedding_type == "relative_key":
|
| 236 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 237 |
+
attention_scores = attention_scores + relative_position_scores
|
| 238 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 239 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 240 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 241 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 242 |
+
|
| 243 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 244 |
+
if attention_mask is not None:
|
| 245 |
+
# Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
|
| 246 |
+
attention_scores = attention_scores + attention_mask
|
| 247 |
+
|
| 248 |
+
# Normalize the attention scores to probabilities.
|
| 249 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 250 |
+
|
| 251 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 252 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 253 |
+
attention_probs = self.dropout(attention_probs)
|
| 254 |
+
|
| 255 |
+
# Mask heads if we want to
|
| 256 |
+
if head_mask is not None:
|
| 257 |
+
attention_probs = attention_probs * head_mask
|
| 258 |
+
|
| 259 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 260 |
+
|
| 261 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 262 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 263 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 264 |
|
| 265 |
outputs = (context_layer, ) if output_attentions else (context_layer,)
|
| 266 |
|