Update modeling_nort5.py
Browse files- modeling_nort5.py +9 -8
modeling_nort5.py
CHANGED
|
@@ -62,7 +62,7 @@ class Decoder(nn.Module):
|
|
| 62 |
self_relative_embedding = self.self_relative_embedding()
|
| 63 |
cross_relative_embedding = self.cross_relative_embedding()
|
| 64 |
|
| 65 |
-
if past_key_values is
|
| 66 |
autoreg_mask = torch.triu(
|
| 67 |
torch.full((x.size(0), x.size(0)), True, device=x.device),
|
| 68 |
diagonal=1
|
|
@@ -259,12 +259,12 @@ class Attention(nn.Module):
|
|
| 259 |
|
| 260 |
if past_key_value is not None:
|
| 261 |
if not self.is_cross_attention:
|
| 262 |
-
key = torch.cat([past_key_value[0], key], dim=1)
|
| 263 |
-
value = torch.cat([past_key_value[1], value], dim=1)
|
| 264 |
key_len = key.size(1)
|
| 265 |
elif past_key_value[0].size(1) == kv.size(0):
|
| 266 |
-
key = past_key_value[0]
|
| 267 |
-
value = past_key_value[1]
|
| 268 |
|
| 269 |
if self.position_indices.size(0) < max(query_len, key_len):
|
| 270 |
position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
|
|
@@ -306,7 +306,10 @@ class Attention(nn.Module):
|
|
| 306 |
context = self.post_layer_norm(context)
|
| 307 |
context = self.dropout(context)
|
| 308 |
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
|
| 312 |
class WordEmbedding(nn.Module):
|
|
@@ -662,9 +665,7 @@ class NorT5ForConditionalGeneration(NorT5Model):
|
|
| 662 |
reordered_layer_past_states = ()
|
| 663 |
for layer_past_state in layer_past_states:
|
| 664 |
# need to set correct `past` for each of the four key / value states
|
| 665 |
-
layer_past_state = layer_past_state.unflatten(0, (-1, self.config.num_attention_heads))
|
| 666 |
layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
|
| 667 |
-
layer_past_state = layer_past_state.flatten(0, 1)
|
| 668 |
reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)
|
| 669 |
|
| 670 |
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
|
|
|
|
| 62 |
self_relative_embedding = self.self_relative_embedding()
|
| 63 |
cross_relative_embedding = self.cross_relative_embedding()
|
| 64 |
|
| 65 |
+
if past_key_values is None:
|
| 66 |
autoreg_mask = torch.triu(
|
| 67 |
torch.full((x.size(0), x.size(0)), True, device=x.device),
|
| 68 |
diagonal=1
|
|
|
|
| 259 |
|
| 260 |
if past_key_value is not None:
|
| 261 |
if not self.is_cross_attention:
|
| 262 |
+
key = torch.cat([past_key_value[0].flatten(0, 1), key], dim=1)
|
| 263 |
+
value = torch.cat([past_key_value[1].flatten(0, 1), value], dim=1)
|
| 264 |
key_len = key.size(1)
|
| 265 |
elif past_key_value[0].size(1) == kv.size(0):
|
| 266 |
+
key = past_key_value[0].flatten(0, 1)
|
| 267 |
+
value = past_key_value[1].flatten(0, 1)
|
| 268 |
|
| 269 |
if self.position_indices.size(0) < max(query_len, key_len):
|
| 270 |
position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
|
|
|
|
| 306 |
context = self.post_layer_norm(context)
|
| 307 |
context = self.dropout(context)
|
| 308 |
|
| 309 |
+
key = key.detach().unflatten(0, (-1, self.num_heads))
|
| 310 |
+
value = value.detach().unflatten(0, (-1, self.num_heads))
|
| 311 |
+
|
| 312 |
+
return context, attention_probs.detach(), (key, value)
|
| 313 |
|
| 314 |
|
| 315 |
class WordEmbedding(nn.Module):
|
|
|
|
| 665 |
reordered_layer_past_states = ()
|
| 666 |
for layer_past_state in layer_past_states:
|
| 667 |
# need to set correct `past` for each of the four key / value states
|
|
|
|
| 668 |
layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
|
|
|
|
| 669 |
reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)
|
| 670 |
|
| 671 |
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
|