Maxtimer97 committed on
Commit
6bc83aa
·
1 Parent(s): af54461

Corrected dtype setup

Browse files
Files changed (1) hide show
  1. modeling_hymba.py +2 -2
modeling_hymba.py CHANGED
@@ -1709,9 +1709,9 @@ class HymbaBlock(nn.Module):
1709
  ## Attention Head
1710
  if self.reuse_kv:
1711
  assert kv_last_layer is not None
1712
- attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.type)
1713
  else:
1714
- attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.type)
1715
 
1716
 
1717
  if not self.pure_attn:
 
1709
  ## Attention Head
1710
  if self.reuse_kv:
1711
  assert kv_last_layer is not None
1712
+ attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.dtype)
1713
  else:
1714
+ attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params, target_dtype=self.in_proj.weight.dtype)
1715
 
1716
 
1717
  if not self.pure_attn: