Add patch to handle a None attention_mask in task_heads.py
task_heads.py (+14, -0)
@@ -154,6 +154,20 @@ class SharedSpaceDecoderForCausalLM(SharedSpaceDecoderPreTrainedModel):
             - loss: Cross-entropy loss if labels provided, else None
             - hidden_states: Final layer hidden states [batch_size, seq_len, hidden_size]
         """
+
+        # Adding a patch for when attention_mask is None
+        # ---------------------------
+        # >>> PATCH: ensure a mask if none is provided
+        # ---------------------------
+        if attention_mask is None and input_ids is not None:
+            # Create an all-ones mask (no padding) so SDPA mask prep won't crash.
+            # dtype long/bool are both accepted by HF mask utils; long is common.
+            attention_mask = torch.ones(
+                (input_ids.size(0), input_ids.size(1)),
+                dtype=torch.long,
+                device=input_ids.device,
+            )
+        # ---------------------------
 
         # Run the base decoder model
         # This applies all the transformer layers with causal attention
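For reference, the fallback the patch applies can be sketched as a standalone, runnable snippet. The helper name _default_attention_mask is hypothetical and not part of task_heads.py; it just mirrors the logic added above:

import torch

def _default_attention_mask(input_ids, attention_mask=None):
    # Mirror of the patch: when no mask is supplied, assume no padding
    # and let every token attend (all-ones mask of shape [batch, seq]).
    if attention_mask is None and input_ids is not None:
        attention_mask = torch.ones(
            (input_ids.size(0), input_ids.size(1)),
            dtype=torch.long,
            device=input_ids.device,
        )
    return attention_mask

# Example: a batch of 2 sequences of 5 token ids, no mask passed in.
input_ids = torch.randint(0, 1000, (2, 5))
mask = _default_attention_mask(input_ids)
print(mask.shape, mask.dtype)  # torch.Size([2, 5]) torch.int64

Note that an all-ones default is only correct for unpadded batches; batches containing padding tokens still need the real mask produced by the tokenizer.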