WildnerveAI commited on
Commit
6e0ad94
·
verified ·
1 Parent(s): 7b8ab89

Upload model_Custm.py

Browse files
Files changed (1) hide show
  1. model_Custm.py +2 -1
model_Custm.py CHANGED
@@ -288,6 +288,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
288
  labels=None,
289
  src=None,
290
  tgt=None,
 
291
  src_key_padding_mask=None,
292
  tgt_key_padding_mask=None,
293
  memory_key_padding_mask=None,
@@ -319,7 +320,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
319
  key = query
320
  value = query
321
 
322
- # IMPORTANT: Initialize src_mask if it's None
323
  if src_mask is None and src is not None:
324
  # Create a default mask that allows all tokens to attend to all other tokens
325
  src_seq_len = src.size(1)
 
288
  labels=None,
289
  src=None,
290
  tgt=None,
291
+ src_mask: Optional[torch.Tensor] = None, # added
292
  src_key_padding_mask=None,
293
  tgt_key_padding_mask=None,
294
  memory_key_padding_mask=None,
 
320
  key = query
321
  value = query
322
 
323
+ # CRITICAL: Initialize src_mask if it's None
324
  if src_mask is None and src is not None:
325
  # Create a default mask that allows all tokens to attend to all other tokens
326
  src_seq_len = src.size(1)