alexmarques commited on
Commit
aa0effb
·
verified ·
1 Parent(s): 1f6c083

Update convert_voxtral_hf_to_mistral.py

Browse files
Files changed (1) hide show
  1. convert_voxtral_hf_to_mistral.py +2 -2
convert_voxtral_hf_to_mistral.py CHANGED
@@ -101,11 +101,11 @@ def convert_state_dict(hf_state_dict, config):
101
  if "language_model" in hf_key:
102
  if hf_key.endswith("q_proj.weight"):
103
  tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
104
- elif hf_key.endswith("q_proj.weight_scale") and tensor.size(0) == num_attention_heads:
105
  tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, 1)
106
  elif hf_key.endswith("k_proj.weight"):
107
  tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
108
- elif hf_key.endswith("k_proj.weight_scale") and tensor.size(0) == num_key_value_heads:
109
  tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, 1)
110
 
111
  mistral_dict[mistral_key] = tensor
 
101
  if "language_model" in hf_key:
102
  if hf_key.endswith("q_proj.weight"):
103
  tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
104
+ elif hf_key.endswith("q_proj.weight_scale") and tensor.size(0) > 1:
105
  tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, 1)
106
  elif hf_key.endswith("k_proj.weight"):
107
  tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
108
+ elif hf_key.endswith("k_proj.weight_scale") and tensor.size(0) > 1:
109
  tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, 1)
110
 
111
  mistral_dict[mistral_key] = tensor