Upload model_Custm.py
Browse files- model_Custm.py +2 -1
model_Custm.py
CHANGED
|
@@ -288,6 +288,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 288 |
labels=None,
|
| 289 |
src=None,
|
| 290 |
tgt=None,
|
|
|
|
| 291 |
src_key_padding_mask=None,
|
| 292 |
tgt_key_padding_mask=None,
|
| 293 |
memory_key_padding_mask=None,
|
|
@@ -319,7 +320,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 319 |
key = query
|
| 320 |
value = query
|
| 321 |
|
| 322 |
-
#
|
| 323 |
if src_mask is None and src is not None:
|
| 324 |
# Create a default mask that allows all tokens to attend to all other tokens
|
| 325 |
src_seq_len = src.size(1)
|
|
|
|
| 288 |
labels=None,
|
| 289 |
src=None,
|
| 290 |
tgt=None,
|
| 291 |
+
src_mask: Optional[torch.Tensor] = None, # added
|
| 292 |
src_key_padding_mask=None,
|
| 293 |
tgt_key_padding_mask=None,
|
| 294 |
memory_key_padding_mask=None,
|
|
|
|
| 320 |
key = query
|
| 321 |
value = query
|
| 322 |
|
| 323 |
+
# CRITICAL: Initialize src_mask if it's None
|
| 324 |
if src_mask is None and src is not None:
|
| 325 |
# Create a default mask that allows all tokens to attend to all other tokens
|
| 326 |
src_seq_len = src.size(1)
|