igorktech commited on
Commit
5677281
·
1 Parent(s): 44d00ce

Update modelling_hier.py

Browse files
Files changed (1) hide show
  1. modelling_hier.py +2 -2
modelling_hier.py CHANGED
@@ -338,7 +338,7 @@ class HierBert(Module):
338
  # Shared encoders or Segment-wise encoders
339
  # print("SWE")
340
  enc_inp, att_w = layer(enc_inp,
341
- src_key_padding_mask=src_key_padding_mask,
342
  src_mask=enc_mask_utt.repeat(self.config.num_attention_heads, 1, 1))
343
  else:
344
  # Positional Embedding for Context Encoder if few connected CSE use it before
@@ -346,7 +346,7 @@ class HierBert(Module):
346
  # Context encoder or Cross-segment encoders
347
  # print("CSE")
348
  enc_inp, att_w = layer(enc_inp,
349
- src_key_padding_mask=src_key_padding_mask,
350
  src_mask=enc_mask_ct.repeat(self.config.num_attention_heads, 1, 1))
351
  if output_attentions:
352
  all_self_attentions = all_self_attentions + (att_w,)
 
338
  # Shared encoders or Segment-wise encoders
339
  # print("SWE")
340
  enc_inp, att_w = layer(enc_inp,
341
+ # src_key_padding_mask=src_key_padding_mask,
342
  src_mask=enc_mask_utt.repeat(self.config.num_attention_heads, 1, 1))
343
  else:
344
  # Positional Embedding for Context Encoder if few connected CSE use it before
 
346
  # Context encoder or Cross-segment encoders
347
  # print("CSE")
348
  enc_inp, att_w = layer(enc_inp,
349
+ # src_key_padding_mask=src_key_padding_mask,
350
  src_mask=enc_mask_ct.repeat(self.config.num_attention_heads, 1, 1))
351
  if output_attentions:
352
  all_self_attentions = all_self_attentions + (att_w,)