Commit
·
6bc83aa
1
Parent(s):
af54461
Corrected dtype setup
Browse files- modeling_hymba.py +2 -2
modeling_hymba.py
CHANGED
|
@@ -1709,9 +1709,9 @@ class HymbaBlock(nn.Module):
|
|
| 1709 |
## Attention Head
|
| 1710 |
if self.reuse_kv:
|
| 1711 |
assert kv_last_layer is not None
|
| 1712 |
-
attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.
|
| 1713 |
else:
|
| 1714 |
-
attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.
|
| 1715 |
|
| 1716 |
|
| 1717 |
if not self.pure_attn:
|
|
|
|
| 1709 |
## Attention Head
|
| 1710 |
if self.reuse_kv:
|
| 1711 |
assert kv_last_layer is not None
|
| 1712 |
+
attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.dtype)
|
| 1713 |
else:
|
| 1714 |
+
attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.dtype)
|
| 1715 |
|
| 1716 |
|
| 1717 |
if not self.pure_attn:
|