Commit: Update modelling_hier.py
File changed: modelling_hier.py (+2, −2)
|
@@ -338,7 +338,7 @@ class HierBert(Module):
 338                 # Shared encoders or Segment-wise encoders
 339                 # print("SWE")
 340                 enc_inp, att_w = layer(enc_inp,
-341                                        src_key_padding_mask=src_key_padding_mask,
+341                                        # src_key_padding_mask=src_key_padding_mask,
 342                                        src_mask=enc_mask_utt.repeat(self.config.num_attention_heads, 1, 1))
 343             else:
 344                 # Positional Embedding for Context Encoder if few connected CSE use it before
@@ -346,7 +346,7 @@ class HierBert(Module):
 346                 # Context encoder or Cross-segment encoders
 347                 # print("CSE")
 348                 enc_inp, att_w = layer(enc_inp,
-349                                        src_key_padding_mask=src_key_padding_mask,
+349                                        # src_key_padding_mask=src_key_padding_mask,
 350                                        src_mask=enc_mask_ct.repeat(self.config.num_attention_heads, 1, 1))
 351             if output_attentions:
 352                 all_self_attentions = all_self_attentions + (att_w,)
|