Refactor Logits Naming

#15
by codecho - opened
Files changed (1) hide show
  1. modeling_moonshot_kimia.py +4 -4
modeling_moonshot_kimia.py CHANGED
@@ -902,15 +902,15 @@ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
902
  else:
903
  hidden_states, mimo_hidden_states = outputs[0], outputs[1]
904
 
905
- audio_logits = self.lm_head(hidden_states)
906
- text_logits = self.mimo_output(mimo_hidden_states)
907
 
908
  if not return_dict:
909
- output = (text_logits, audio_logits) + outputs[2:]
910
  return output
911
  return CausalLMOutputWithPast(
912
  loss=None,
913
- logits=(text_logits, audio_logits),
914
  past_key_values=outputs.past_key_values,
915
  hidden_states=outputs.hidden_states,
916
  attentions=outputs.attentions,
 
902
  else:
903
  hidden_states, mimo_hidden_states = outputs[0], outputs[1]
904
 
905
+ text_logits = self.lm_head(hidden_states)
906
+ audio_logits = self.mimo_output(mimo_hidden_states)
907
 
908
  if not return_dict:
909
+ output = (audio_logits, text_logits) + outputs[2:]
910
  return output
911
  return CausalLMOutputWithPast(
912
  loss=None,
913
+ logits=(audio_logits, text_logits),
914
  past_key_values=outputs.past_key_values,
915
  hidden_states=outputs.hidden_states,
916
  attentions=outputs.attentions,