kokolamba committed
Commit d527508 · 1 Parent(s): cf51f7c

Add patch to fix when attention mask is None in task_heads.py

Files changed (1):
  task_heads.py +14 -0
task_heads.py CHANGED
@@ -154,6 +154,20 @@ class SharedSpaceDecoderForCausalLM(SharedSpaceDecoderPreTrainedModel):
             - loss: Cross-entropy loss if labels provided, else None
             - hidden_states: Final layer hidden states [batch_size, seq_len, hidden_size]
         """
+
+        # Adding a patch for when attention_mask is None
+        # ---------------------------
+        # >>> PATCH: ensure a mask if none is provided
+        # ---------------------------
+        if attention_mask is None and input_ids is not None:
+            # Create an all-ones mask (no padding) so SDPA mask prep won't crash.
+            # dtype long/bool are both accepted by HF mask utils; long is common.
+            attention_mask = torch.ones(
+                (input_ids.size(0), input_ids.size(1)),
+                dtype=torch.long,
+                device=input_ids.device,
+            )
+        # ---------------------------
 
         # Run the base decoder model
         # This applies all the transformer layers with causal attention
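For context, a minimal standalone sketch of the fallback this patch introduces. The input_ids values below are made up for illustration; only the mask construction mirrors the committed code, and it shows why an all-ones mask is a safe default when the caller supplies no padding information.

    import torch

    # Hypothetical toy batch (2 sequences of length 4); values are illustrative only.
    input_ids = torch.tensor([[101, 2023, 2003, 102],
                              [101, 2742, 3231, 102]])
    attention_mask = None

    # Same fallback as in the patch: an all-ones mask marks every position as a
    # real (non-padding) token, which is what downstream SDPA mask preparation
    # expects when the caller passes no mask.
    if attention_mask is None and input_ids is not None:
        attention_mask = torch.ones(
            (input_ids.size(0), input_ids.size(1)),
            dtype=torch.long,
            device=input_ids.device,
        )

    print(attention_mask.shape)  # torch.Size([2, 4])
    print(attention_mask.dtype)  # torch.int64

Using dtype=torch.long matches what Hugging Face tokenizers return for attention_mask, so the HF mask utilities treat this fallback the same as a tokenizer-supplied mask.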