Fix mask dimensions: unsqueeze 3D->4D for OpenNMT-py 3.x attention
Browse files
molscribe/transformer/decoder.py
CHANGED
|
@@ -172,6 +172,13 @@ class TransformerDecoderLayerBase(nn.Module):
|
|
| 172 |
layer_cache["memory_keys"] = cache_dict.get("keys")
|
| 173 |
layer_cache["memory_values"] = cache_dict.get("values")
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
|
| 176 |
if isinstance(self.self_attn, MultiHeadedAttention):
|
| 177 |
# OpenNMT-py 3.x: layer_cache and attn_type are instance attributes
|
|
@@ -181,7 +188,7 @@ class TransformerDecoderLayerBase(nn.Module):
|
|
| 181 |
inputs_norm,
|
| 182 |
inputs_norm,
|
| 183 |
inputs_norm,
|
| 184 |
-
mask=dec_mask,
|
| 185 |
)
|
| 186 |
self._from_onmt3_cache(self.self_attn.layer_cache, layer_cache, "self")
|
| 187 |
return result
|
|
@@ -309,7 +316,7 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
|
|
| 309 |
memory_bank,
|
| 310 |
memory_bank,
|
| 311 |
query_norm,
|
| 312 |
-
mask=src_pad_mask,
|
| 313 |
)
|
| 314 |
self._from_onmt3_cache(self.context_attn.layer_cache, layer_cache, "context")
|
| 315 |
output = self.feed_forward(self.drop(mid) + query)
|
|
|
|
| 172 |
layer_cache["memory_keys"] = cache_dict.get("keys")
|
| 173 |
layer_cache["memory_values"] = cache_dict.get("values")
|
| 174 |
|
| 175 |
+
@staticmethod
|
| 176 |
+
def _expand_mask(mask):
|
| 177 |
+
"""Expand 3D mask (B, 1, L) to 4D (B, 1, 1, L) for OpenNMT-py 3.x."""
|
| 178 |
+
if mask is not None and mask.dim() == 3:
|
| 179 |
+
return mask.unsqueeze(1)
|
| 180 |
+
return mask
|
| 181 |
+
|
| 182 |
def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
|
| 183 |
if isinstance(self.self_attn, MultiHeadedAttention):
|
| 184 |
# OpenNMT-py 3.x: layer_cache and attn_type are instance attributes
|
|
|
|
| 188 |
inputs_norm,
|
| 189 |
inputs_norm,
|
| 190 |
inputs_norm,
|
| 191 |
+
mask=self._expand_mask(dec_mask),
|
| 192 |
)
|
| 193 |
self._from_onmt3_cache(self.self_attn.layer_cache, layer_cache, "self")
|
| 194 |
return result
|
|
|
|
| 316 |
memory_bank,
|
| 317 |
memory_bank,
|
| 318 |
query_norm,
|
| 319 |
+
mask=self._expand_mask(src_pad_mask),
|
| 320 |
)
|
| 321 |
self._from_onmt3_cache(self.context_attn.layer_cache, layer_cache, "context")
|
| 322 |
output = self.feed_forward(self.drop(mid) + query)
|