from typing import Optional

import torch
from torch import Tensor, nn
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_glmasr import GlmasrConfig
from .modeling_audio import WhisperSpecialEncoder


class AudioMLPAdapter(nn.Module):
    """Whisper encoder followed by an MLP that projects merged audio frames
    into the language model's embedding space."""

    def __init__(self, config: GlmasrConfig):
        super().__init__()
        whisper_config = config.whisper_config
        self.merge_factor = config.merge_factor
        self.whisper = WhisperSpecialEncoder(
            whisper_config,
            use_rope=config.use_rope,
        )
        # Disable the encoder's final norm; normalization happens here instead,
        # just before frame merging.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(whisper_config.hidden_size)
        act = {
            "gelu": nn.GELU(),
            "relu": nn.ReLU(),
            "selu": nn.SELU(),
        }[config.mlp_adapter_act]
        # `merge_factor` consecutive frames are concatenated feature-wise before
        # the MLP, shortening the audio sequence by that factor.
        hidden = whisper_config.hidden_size * self.merge_factor
        output_dim = config.lm_config.hidden_size
        self.adapting = nn.Sequential(
            nn.Linear(hidden, output_dim * 2),
            act,
            nn.Linear(output_dim * 2, output_dim),
        )
        # Learned begin-of-audio / end-of-audio embeddings.
        self.audio_bos_eos_token = nn.Embedding(2, output_dim)

    def forward(self, audios: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        bsz = audios.size(0)
        encoded = self.whisper(audios)[0]
        encoded = self.layer_norm(encoded)
        # Merge `merge_factor` adjacent frames into one wider feature vector.
        encoded = encoded.reshape(bsz, -1, encoded.size(-1) * self.merge_factor)
        adapted = self.adapting(encoded)
        # Return 1-D (hidden_size,) vectors so they can be assigned directly
        # into single embedding slots in GlmasrModel.forward.
        boa = self.audio_bos_eos_token.weight[0]
        eoa = self.audio_bos_eos_token.weight[1]
        return adapted, boa, eoa


class GlmasrModel(LlamaForCausalLM):
    config_class = GlmasrConfig

    def __init__(self, config: GlmasrConfig):
        super().__init__(config.lm_config)
        self.audio_encoder = AudioMLPAdapter(config)
        self.all_config = config

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        audio_offsets: Optional[list[list[int]]] = None,
        audio_length: Optional[list[list[int]]] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[tuple] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        # Audio placeholder ids may lie outside the text vocabulary; clamp them
        # so the embedding lookup is valid (the slots are overwritten below).
        vocab_size = self.config.vocab_size
        tokens = torch.clamp(input_ids, 0, vocab_size - 1)
        language_embs = self.model.embed_tokens(tokens)

        # Splice audio features in only on the prefill pass; on later decode
        # steps they are already part of the KV cache. Note: `past_key_values`
        # is an explicit parameter, so it must be checked directly rather than
        # via `kwargs.get("past_key_values")`, which would always return None.
        have_audio = audios is not None and (
            past_key_values is None or len(past_key_values) == 0
        )
        if have_audio:
            if audio_offsets is None or audio_length is None:
                raise ValueError(
                    "audio_offsets and audio_length are required when audios are provided"
                )
            audio_embs, boa, eoa = self.audio_encoder(audios)
            index = 0
            for batch, (offsets, lengths) in enumerate(zip(audio_offsets, audio_length)):
                for offset, length in zip(offsets, lengths):
                    # Overwrite the placeholder span with adapted audio features,
                    # bracketed by the learned BOA/EOA embeddings.
                    language_embs[batch, offset : offset + length] = audio_embs[index, :length]
                    language_embs[batch, offset - 1] = boa
                    language_embs[batch, offset + length] = eoa
                    index += 1

        kwargs.pop("inputs_embeds", None)
        kwargs.pop("is_first_forward", None)
        outputs = self.model(
            inputs_embeds=language_embs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )
        logits = self.lm_head(outputs[0])
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _update_model_kwargs_for_generation(self, *args, **kwargs):
        model_kwargs = super()._update_model_kwargs_for_generation(*args, **kwargs)
        # After the prefill pass, every subsequent step is a decode step.
        model_kwargs["is_first_forward"] = False
        # Extend position_ids by one slot for the next generated token.
        position_ids = model_kwargs.get("position_ids")
        if position_ids is not None:
            next_pos = position_ids[..., -1:].clone() + 1
            model_kwargs["position_ids"] = torch.cat([position_ids, next_pos], dim=-1)
        return model_kwargs
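    # Generation protocol used by the two hooks around this comment:
    # `is_first_forward` marks the prefill pass, where forward() splices audio
    # features into the prompt embeddings. The hook above flips the flag and
    # appends the next position id after every step; the hook below trims
    # tokens already covered by the KV cache and forwards the audio kwargs,
    # dropping `audios` on decode steps so the encoder runs exactly once per
    # generate() call.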
    def prepare_inputs_for_generation(
        self,
        *args,
        past_key_values: Optional[tuple] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        use_cache: Optional[bool] = None,
        is_first_forward: bool = True,
        **kwargs,
    ):
        prepared = super().prepare_inputs_for_generation(
            *args,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            is_first_forward=is_first_forward,
            **kwargs,
        )
        # Forward audio-specific kwargs (audios, audio_offsets, audio_length)
        # that the base implementation does not know about.
        for key, value in kwargs.items():
            if key not in prepared and key.startswith("audio"):
                prepared[key] = value
        # If a cache was handed in before the first forward pass, drop the
        # tokens (and positions) it already covers.
        if is_first_forward and past_key_values is not None and len(past_key_values) > 0:
            cached_len = past_key_values[0][0].shape[2]
            prepared["input_ids"] = prepared["input_ids"][:, cached_len:]
            if "position_ids" in prepared:
                prepared["position_ids"] = prepared["position_ids"][:, cached_len:]
        # Decode steps never re-encode audio.
        if not is_first_forward:
            prepared["audios"] = None
        return prepared


__all__ = ["GlmasrModel"]
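
# --- Usage sketch -----------------------------------------------------------
# Illustrative only: the checkpoint path, shapes, and variable names below are
# assumptions, not defined by this module. `audios` is whatever batched
# feature tensor WhisperSpecialEncoder expects, and `input_ids` must reserve,
# per clip, `length` placeholder tokens starting at `offset`, plus one slot
# before (BOA) and one after (EOA) that forward() overwrites in place.
#
#   model = GlmasrModel.from_pretrained("path/to/glmasr-checkpoint")
#   out = model.generate(
#       input_ids=input_ids,
#       attention_mask=attention_mask,
#       audios=audios,
#       audio_offsets=[[offset]],      # one list of span starts per batch row
#       audio_length=[[length]],       # matching span lengths
#   )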