Refactor Logits Naming
#15 opened by codecho
modeling_moonshot_kimia.py (CHANGED)
|
@@ -902,15 +902,15 @@ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
|
|
| 902 |
else:
|
| 903 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
| 904 |
|
| 905 |
-
|
| 906 |
-
|
| 907 |
|
| 908 |
if not return_dict:
|
| 909 |
-
output = (
|
| 910 |
return output
|
| 911 |
return CausalLMOutputWithPast(
|
| 912 |
loss=None,
|
| 913 |
-
logits=(
|
| 914 |
past_key_values=outputs.past_key_values,
|
| 915 |
hidden_states=outputs.hidden_states,
|
| 916 |
attentions=outputs.attentions,
|
|
|
|
| 902 |
else:
|
| 903 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
| 904 |
|
| 905 |
+
text_logits = self.lm_head(hidden_states)
|
| 906 |
+
audio_logits = self.mimo_output(mimo_hidden_states)
|
| 907 |
|
| 908 |
if not return_dict:
|
| 909 |
+
output = (audio_logits, text_logits) + outputs[2:]
|
| 910 |
return output
|
| 911 |
return CausalLMOutputWithPast(
|
| 912 |
loss=None,
|
| 913 |
+
logits=(audio_logits, text_logits),
|
| 914 |
past_key_values=outputs.past_key_values,
|
| 915 |
hidden_states=outputs.hidden_states,
|
| 916 |
attentions=outputs.attentions,
|