"""Musci: a Qwen3 language model with an audio encoder front-end.

Defines the composite config, a gated-MLP adapter that projects audio-encoder
features into the language model's hidden size, the base model that merges
audio embeddings into the token embeddings, and a causal-LM head on top.
"""

from typing import Optional, List, Union, Tuple

import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.utils import logging
from transformers.models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel, Qwen3DecoderLayer
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeAudioEncoder
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeAudioEncoderConfig
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.utils.auto_docstring import auto_docstring
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import GenerationMixin


class MusciConfig(Qwen3Config):
    """Composite configuration holding an audio-encoder and a Qwen3 sub-config.

    Args:
        audio_config: ``dict``, ``Qwen3OmniMoeAudioEncoderConfig`` or ``None``
            (``None`` builds a default audio-encoder config).
        language_config: ``dict``, ``Qwen3Config`` or ``None`` (``None`` builds
            a default ``Qwen3Config``).
        adapter_hidden_size: Hidden width of the audio adapter MLP.
        ignore_index: Label value skipped by the cross-entropy loss.
        **kwargs: Forwarded to ``Qwen3Config.__init__``.
    """

    model_type = "musci"
    is_composition = True
    # Make the architecture discoverable by Megatron-Bridge's AutoBridge
    # when loading configs from disk.
    architectures = ["MusciForCausalLM"]

    def __init__(
        self,
        audio_config=None,
        language_config=None,
        adapter_hidden_size=8192,
        ignore_index=-100,
        **kwargs
    ):
        # Mirror the language model's depth onto the top-level config so the
        # parent Qwen3Config is initialised with the same `num_hidden_layers`.
        num_hidden_layers = None
        if isinstance(language_config, dict):
            num_hidden_layers = language_config.get("num_hidden_layers", None)
        elif isinstance(language_config, Qwen3Config):
            num_hidden_layers = language_config.num_hidden_layers
        if num_hidden_layers is not None:
            kwargs.update({"num_hidden_layers": num_hidden_layers})

        # Initialize parent Qwen3Config with kwargs to handle standard config params
        super().__init__(**kwargs)

        # Normalise the sub-configs: dict -> config object, None -> defaults.
        # Any other value (including an already-built config) is kept as-is.
        if isinstance(audio_config, dict):
            audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)
        elif audio_config is None:
            audio_config = Qwen3OmniMoeAudioEncoderConfig()

        if isinstance(language_config, dict):
            language_config = Qwen3Config(**language_config)
        elif language_config is None:
            language_config = Qwen3Config()

        self.audio_config = audio_config
        self.language_config = language_config
        self.adapter_hidden_size = adapter_hidden_size
        self.ignore_index = ignore_index
        self.dtype = language_config.dtype

    def to_dict(self):
        """Serialize to a dict, expanding both sub-configs via their ``to_dict``."""
        output = super().to_dict()
        for key in ("audio_config", "language_config"):
            sub_config = getattr(self, key, None)
            if sub_config is None:
                continue
            output[key] = sub_config.to_dict() if hasattr(sub_config, "to_dict") else sub_config
        return output


class MusciGatedMLP(nn.Module):
    """Bias-free gated MLP: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.gate_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.down_proj = nn.Linear(hidden_size, output_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Project ``x`` (last dim ``input_size``) to last dim ``output_size``."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


# Shared base class: declares the config type and the capability flags read by
# the Transformers loading / attention-dispatch machinery.
@auto_docstring
class MusciPreTrainedModel(PreTrainedModel):
    config: MusciConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
    }


class MusciModel(MusciPreTrainedModel):
    """Audio encoder + adapter feeding a Qwen3 decoder stack (no LM head)."""

    config_class = MusciConfig

    def __init__(self, config: MusciConfig):
        super().__init__(config)
        self.audio_model = Qwen3OmniMoeAudioEncoder(config.audio_config)
        self.language_model = Qwen3Model(config.language_config)
        # Maps audio-encoder features to the language model's hidden size.
        self.audio_adapter = MusciGatedMLP(
            input_size=config.audio_config.output_dim,
            hidden_size=config.adapter_hidden_size,
            output_size=config.language_config.hidden_size,
        )
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the language model's token-embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the language model's token-embedding module."""
        self.language_model.set_input_embeddings(value)

    def get_audio_features(self, input_features, feature_lens):
        """Run the audio encoder and return its ``last_hidden_state``."""
        audio_outputs = self.audio_model(
            input_features=input_features,
            feature_lens=feature_lens,
        )
        return audio_outputs.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        audio_data: Optional[torch.FloatTensor] = None,
        audio_data_seqlens: Optional[torch.Tensor] = None,
        audio_input_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed tokens, splice in adapted audio embeddings, run the decoder.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is ``None``.
            inputs_embeds: Pre-computed embeddings; takes precedence over
                ``input_ids``. The caller's tensor is never modified.
            audio_data: Input features for the audio encoder. When given,
                ``audio_input_mask`` is required.
            audio_data_seqlens: Passed to the audio encoder as ``feature_lens``.
            audio_input_mask: ``[B, L]`` mask; true positions of the embedding
                sequence are overwritten with audio embeddings.
            (remaining arguments are forwarded to the language model.)

        Returns:
            Whatever ``self.language_model`` returns.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is
                given, or ``audio_data`` is given without ``audio_input_mask``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 1. Get text embeddings
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either input_ids or inputs_embeds")
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # 2. Process audio and merge embeddings if audio is present
        if audio_data is not None:
            if audio_input_mask is None:
                raise ValueError("audio_input_mask must be provided together with audio_data")

            # NOTE(review): the original comment described this as [B, Audio_Len, D];
            # the actual layout is whatever the audio encoder returns — TODO confirm.
            audio_embeds = self.get_audio_features(audio_data, audio_data_seqlens)
            audio_embeds = self.audio_adapter(audio_embeds)
            # masked_scatter needs matching dtype/device between source and target.
            audio_embeds = audio_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)

            # audio_input_mask: [B, L] -> [B, L, 1] -> [B, L, D]
            # D elements per masked position will be replaced by audio embeddings.
            mask_expanded = (
                audio_input_mask.to(device=inputs_embeds.device, dtype=torch.bool)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
            )
            # Out-of-place on purpose: an in-place scatter would mutate a
            # caller-supplied `inputs_embeds` and fails on leaf tensors that
            # require grad. Source elements are consumed in order; the count is
            # not validated here.
            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, audio_embeds)

        # 3. Forward pass through language model
        return self.language_model(
            input_ids=None,  # We pass inputs_embeds
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )


class MusciForCausalLM(MusciPreTrainedModel, GenerationMixin):
    """``MusciModel`` with a linear LM head tied to the input embeddings."""

    config_class = MusciConfig
    _tied_weights_keys = ["lm_head.weight"]
    _keys_to_ignore_on_save = ["lm_head.weight"]

    def __init__(self, config: MusciConfig):
        super().__init__(config)
        self.model = MusciModel(config)
        self.vocab_size = config.language_config.vocab_size
        self.lm_head = nn.Linear(config.language_config.hidden_size, self.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """Run the default tying, then always share ``lm_head`` with ``embed_tokens``."""
        super().tie_weights()
        # tie lm_head to input embeddings
        self.lm_head.weight = self.model.language_model.embed_tokens.weight

    def get_input_embeddings(self):
        """Return the token-embedding module of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the wrapped model."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        audio_data: Optional[torch.FloatTensor] = None,
        audio_data_seqlens: Optional[torch.Tensor] = None,
        audio_input_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the base model, project to vocabulary logits, optionally compute loss.

        Args:
            labels: Target token ids. When given, a next-token cross-entropy
                loss is computed (logits shifted left by one against labels),
                skipping positions equal to ``config.ignore_index``.
            (all other arguments are forwarded to ``MusciModel.forward``.)

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is truthy, else a
            tuple ``(loss?, logits, *rest_of_base_outputs)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            audio_data=audio_data,
            audio_data_seqlens=audio_data_seqlens,
            audio_input_mask=audio_input_mask,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. Use the logits' own width rather than the
            # config value so a resized head cannot make the reshape fail.
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.ignore_index)
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs
    ):
        """Build the per-step ``forward`` kwargs used during generation.

        A decoding step (``cache_position`` starts past 0) keeps only the last
        token / position id and drops the audio inputs. Any other step is
        treated as prefill and pulls the audio inputs from ``kwargs``.

        Returns:
            Dict with ``input_ids`` or ``inputs_embeds`` plus cache, mask,
            position and audio entries.
        """
        position_ids = kwargs.get("position_ids", None)
        is_decoding = cache_position is not None and bool(cache_position[0] > 0)

        if is_decoding:
            input_ids = input_ids[:, -1:]
            if position_ids is not None:
                position_ids = position_ids[:, -1:]
            audio_data = None
            audio_input_mask = None
            audio_data_seqlens = None
        else:
            audio_data = kwargs.get("audio_data", None)
            audio_input_mask = kwargs.get("audio_input_mask", None)
            audio_data_seqlens = kwargs.get("audio_data_seqlens", None)

        # Prefer inputs_embeds at the first step when present. The cache object
        # may already exist (but be empty) at prefill, so `past_key_values is
        # None` alone is not a reliable first-step test; also accept a
        # cache_position that starts at 0.
        starts_at_zero = cache_position is not None and not is_decoding
        if inputs_embeds is not None and (past_key_values is None or starts_at_zero):
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "audio_data": audio_data,
                "audio_input_mask": audio_input_mask,
                "audio_data_seqlens": audio_data_seqlens,
            }
        )
        return model_inputs


__all__ = [
    "MusciConfig",
    "MusciModel",
    "MusciForCausalLM",
]