Update convert_voxtral_hf_to_mistral.py
Browse files
convert_voxtral_hf_to_mistral.py
CHANGED
|
@@ -101,11 +101,11 @@ def convert_state_dict(hf_state_dict, config):
|
|
| 101 |
if "language_model" in hf_key:
|
| 102 |
if hf_key.endswith("q_proj.weight"):
|
| 103 |
tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
|
| 104 |
-
elif hf_key.endswith("q_proj.weight_scale") and tensor.size(0)
|
| 105 |
tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, 1)
|
| 106 |
elif hf_key.endswith("k_proj.weight"):
|
| 107 |
tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
|
| 108 |
-
elif hf_key.endswith("k_proj.weight_scale") and tensor.size(0)
|
| 109 |
tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, 1)
|
| 110 |
|
| 111 |
mistral_dict[mistral_key] = tensor
|
|
|
|
| 101 |
if "language_model" in hf_key:
|
| 102 |
if hf_key.endswith("q_proj.weight"):
|
| 103 |
tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, hidden_size)
|
| 104 |
+
elif hf_key.endswith("q_proj.weight_scale") and tensor.size(0) > 1:
|
| 105 |
tensor = permute_for_mistral_rope(tensor, num_attention_heads, query_dim, 1)
|
| 106 |
elif hf_key.endswith("k_proj.weight"):
|
| 107 |
tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
|
| 108 |
+
elif hf_key.endswith("k_proj.weight_scale") and tensor.size(0) > 1:
|
| 109 |
tensor = permute_for_mistral_rope(tensor, num_key_value_heads, key_value_dim, 1)
|
| 110 |
|
| 111 |
mistral_dict[mistral_key] = tensor
|