L3ul committed on
Commit
18b6dbd
·
verified ·
1 Parent(s): 1d66007

Fix mask dimensions: unsqueeze 3D->4D for OpenNMT-py 3.x attention

Browse files
Files changed (1) hide show
  1. molscribe/transformer/decoder.py +9 -2
molscribe/transformer/decoder.py CHANGED
@@ -172,6 +172,13 @@ class TransformerDecoderLayerBase(nn.Module):
172
  layer_cache["memory_keys"] = cache_dict.get("keys")
173
  layer_cache["memory_values"] = cache_dict.get("values")
174
 
 
 
 
 
 
 
 
175
  def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
176
  if isinstance(self.self_attn, MultiHeadedAttention):
177
  # OpenNMT-py 3.x: layer_cache and attn_type are instance attributes
@@ -181,7 +188,7 @@ class TransformerDecoderLayerBase(nn.Module):
181
  inputs_norm,
182
  inputs_norm,
183
  inputs_norm,
184
- mask=dec_mask,
185
  )
186
  self._from_onmt3_cache(self.self_attn.layer_cache, layer_cache, "self")
187
  return result
@@ -309,7 +316,7 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
309
  memory_bank,
310
  memory_bank,
311
  query_norm,
312
- mask=src_pad_mask,
313
  )
314
  self._from_onmt3_cache(self.context_attn.layer_cache, layer_cache, "context")
315
  output = self.feed_forward(self.drop(mid) + query)
 
172
  layer_cache["memory_keys"] = cache_dict.get("keys")
173
  layer_cache["memory_values"] = cache_dict.get("values")
174
 
175
+ @staticmethod
176
+ def _expand_mask(mask):
177
+ """Expand 3D mask (B, 1, L) to 4D (B, 1, 1, L) for OpenNMT-py 3.x."""
178
+ if mask is not None and mask.dim() == 3:
179
+ return mask.unsqueeze(1)
180
+ return mask
181
+
182
  def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
183
  if isinstance(self.self_attn, MultiHeadedAttention):
184
  # OpenNMT-py 3.x: layer_cache and attn_type are instance attributes
 
188
  inputs_norm,
189
  inputs_norm,
190
  inputs_norm,
191
+ mask=self._expand_mask(dec_mask),
192
  )
193
  self._from_onmt3_cache(self.self_attn.layer_cache, layer_cache, "self")
194
  return result
 
316
  memory_bank,
317
  memory_bank,
318
  query_norm,
319
+ mask=self._expand_mask(src_pad_mask),
320
  )
321
  self._from_onmt3_cache(self.context_attn.layer_cache, layer_cache, "context")
322
  output = self.feed_forward(self.drop(mid) + query)