Update AsteriskForCausalLM.py
Browse files — AsteriskForCausalLM.py (+4, −2)
AsteriskForCausalLM.py
CHANGED
|
@@ -188,12 +188,14 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
|
|
| 188 |
|
| 189 |
# Attention branch - use parent's self_attn
|
| 190 |
attn_outputs = self.self_attn(
|
| 191 |
-
hidden_states,
|
| 192 |
-
position_embeddings,
|
| 193 |
attention_mask=attention_mask,
|
|
|
|
| 194 |
past_key_values=past_key_values,
|
| 195 |
cache_position=cache_position,
|
|
|
|
| 196 |
)
|
|
|
|
| 197 |
attn_output = attn_outputs[0]
|
| 198 |
|
| 199 |
# Gated fusion
|
|
|
|
| 188 |
|
| 189 |
# Attention branch - use parent's self_attn
|
| 190 |
attn_outputs = self.self_attn(
|
| 191 |
+
hidden_states=hidden_states,
|
|
|
|
| 192 |
attention_mask=attention_mask,
|
| 193 |
+
position_ids=position_ids,
|
| 194 |
past_key_values=past_key_values,
|
| 195 |
cache_position=cache_position,
|
| 196 |
+
position_embeddings=position_embeddings,
|
| 197 |
)
|
| 198 |
+
|
| 199 |
attn_output = attn_outputs[0]
|
| 200 |
|
| 201 |
# Gated fusion
|