"""Audio-conditioned Qwen2 causal LM.

Audio is encoded by ``WhisperSpecialEncoder``, projected into the LLM embedding
space by ``AudioMLPAdapter``, and written over the embeddings of the
``config.audio_token_id`` placeholder tokens before the Qwen2 decoder runs.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from transformers import Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_arkasr import ArkasrConfig
from .modeling_audio import WhisperSpecialEncoder

# Activation classes selectable through ``config.mlp_adapter_act``.
_ADAPTER_ACTIVATIONS = {
    "gelu": nn.GELU,
    "relu": nn.ReLU,
    "selu": nn.SELU,
}


class AudioMLPAdapter(nn.Module):
    """Encode audio and project it into the LLM hidden size.

    Encoder frames are grouped ``merge_factor`` at a time, each group is
    concatenated along the feature axis, and the result goes through a
    two-layer MLP ``D * merge_factor -> 2 * H -> H``. Here ``D`` is the
    encoder hidden size and ``H`` is ``config.hidden_size``.
    """

    def __init__(self, config: ArkasrConfig):
        super().__init__()
        whisper_config = config.whisper_config
        merge_factor = int(config.merge_factor)
        if merge_factor < 1:
            # A non-positive factor would otherwise only fail later, inside
            # forward (modulo by zero / negative reshape).
            raise ValueError(f"`merge_factor` must be >= 1, got {merge_factor}.")
        self.merge_factor = merge_factor

        # Audio encoder.
        self.whisper = WhisperSpecialEncoder(
            whisper_config,
            use_rope=getattr(config, "use_rope", False),
        )
        # Disable the encoder's built-in final LayerNorm; ``self.layer_norm``
        # below is applied to the encoder output instead.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(whisper_config.hidden_size)

        # Unknown activation names silently fall back to GELU.
        act_name = getattr(config, "mlp_adapter_act", "gelu")
        act = _ADAPTER_ACTIVATIONS.get(act_name, nn.GELU)()

        input_dim = whisper_config.hidden_size * self.merge_factor
        output_dim = config.hidden_size
        self.adapting = nn.Sequential(
            nn.Linear(input_dim, output_dim * 2),
            act,
            nn.Linear(output_dim * 2, output_dim),
        )

    def forward(self, audios: Tensor) -> Tensor:
        """Encode ``audios`` and return LLM-space features.

        Args:
            audios: audio batch ``(B, ...)``. The exact layout (e.g.
                ``(B, n_mels, T)`` or ``(B, raw_len)``) is whatever
                ``WhisperSpecialEncoder`` accepts.

        Returns:
            ``(B, T_enc // merge_factor, config.hidden_size)``, where ``T_enc``
            is the encoder output length. Trailing frames that do not fill a
            whole group are dropped. Inputs with ``0 < T_enc < merge_factor``
            are zero-padded to one group instead, so they yield one step.
        """
        k = self.merge_factor
        bsz = audios.size(0)
        encoded = self.whisper(audios)[0]  # (B, T_enc, D)
        encoded = self.layer_norm(encoded)

        seq_len = encoded.size(1)
        dim = encoded.size(-1)
        if seq_len % k:
            # Trim to a multiple of k, but never below one full group: very
            # short audio is zero-padded up to k frames instead.
            target_len = max(seq_len // k, 1) * k
            if target_len > seq_len:
                pad = encoded.new_zeros((bsz, target_len - seq_len, dim))
                encoded = torch.cat([encoded, pad], dim=1)
            else:
                encoded = encoded[:, :target_len, :]

        # Concatenate each run of k consecutive frames: (B, T_enc / k, D * k).
        encoded = encoded.reshape(bsz, -1, dim * k)
        return self.adapting(encoded)  # (B, T_enc / k, hidden)


class ArkasrForConditionalGeneration(Qwen2ForCausalLM):
    """Qwen2 causal LM whose ``audio_token_id`` placeholders carry audio features."""

    config_class = ArkasrConfig
    # Keep every Qwen2 decoder layer and the audio encoder whole on a single
    # device when the model is sharded with ``device_map``. Overriding this
    # with only the encoder would let decoder layers be split across devices.
    _no_split_modules = ["Qwen2DecoderLayer", "WhisperSpecialEncoder"]

    def __init__(self, config: ArkasrConfig):
        # Validate before building the (large) LM and audio encoder.
        audio_token_id = getattr(config, "audio_token_id", None)
        if audio_token_id is None:
            raise ValueError("`audio_token_id` must be defined in config.")
        super().__init__(config)
        self.audio_encoder = AudioMLPAdapter(config)
        self.audio_token_id = audio_token_id

    @staticmethod
    def _cache_seq_len(past_key_values) -> int:
        """Return how many tokens are already in the KV cache (0 if unknown or empty)."""
        if past_key_values is None:
            return 0
        try:
            if hasattr(past_key_values, "get_seq_length"):
                # transformers ``Cache`` object.
                return int(past_key_values.get_seq_length())
            # Legacy tuple-of-(key, value) cache: read the length from dim -2
            # of the first layer's key tensor.
            return int(past_key_values[0][0].shape[-2])
        except (AttributeError, IndexError, KeyError, TypeError, ValueError, NotImplementedError):
            # Empty or unrecognised cache layout: best effort, treat as "no cache".
            return 0

    @staticmethod
    def _select_audio_rows(audios: Tensor, need_idx: Tensor, batch_size: int) -> Tensor:
        """Return the audio rows for the samples listed in ``need_idx`` (K of them).

        ``audios`` may hold one row per text sample (``batch_size`` rows). It may
        also hold only the rows of the samples that contain audio tokens, in batch
        order (K rows).
        """
        n_audio = audios.size(0)
        num_need = need_idx.numel()
        if n_audio == batch_size:
            if num_need == batch_size:
                # Every sample needs audio; the gather would be an identity copy.
                return audios
            # index_select requires the index on the same device as ``audios``.
            return audios.index_select(0, need_idx.to(audios.device))
        if n_audio == num_need:
            return audios
        raise ValueError(
            f"`audios` has {n_audio} rows, but expected either the batch size "
            f"({batch_size}) or the number of samples containing audio tokens ({num_need})."
        )

    def _inject_audio_embeddings_batch_encode_then_loop_scatter(
        self,
        input_ids: torch.LongTensor,  # (B, S)
        inputs_embeds: torch.FloatTensor,  # (B, S, H)
        audios: Tensor,  # (B, ...) or (K, ...)
    ) -> torch.FloatTensor:
        """Encode audio once for the samples that need it, then scatter it per sample.

        Samples containing no ``audio_token_id`` are skipped, e.g. TTS / text-only
        rows in a mixed batch. The audio encoder runs once, batched over the K
        selected rows. The write-back then loops over samples, so features can
        never be written into another sample's row.

        If sample ``i`` has ``n_i`` audio tokens but the encoder yields ``Sa``
        frames, its features are truncated (``Sa > n_i``) or zero-padded
        (``Sa < n_i``) to exactly ``n_i`` rows; no error is raised.

        Note: ``inputs_embeds`` is modified in place and also returned.

        Raises:
            ValueError: if ``audios`` has neither B nor K rows.
        """
        batch_size = input_ids.size(0)
        hidden = inputs_embeds.size(-1)

        # Find which samples need injection.
        mask = input_ids == self.audio_token_id  # (B, S)
        per_counts = mask.sum(dim=1)  # (B,)
        need_idx = (per_counts > 0).nonzero(as_tuple=False).squeeze(1)  # (K,)
        if need_idx.numel() == 0:
            return inputs_embeds

        # Encode only the audio rows that will actually be used: (K, Sa, H).
        audios_sub = self._select_audio_rows(audios, need_idx, batch_size)
        feats_sub = self.audio_encoder(audios_sub).to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        enc_len = feats_sub.size(1)

        # One host sync for all indices/counts instead of two .item() calls per sample.
        counts = per_counts.tolist()
        for k, i in enumerate(need_idx.tolist()):
            n_i = counts[i]  # > 0 by construction of need_idx
            feat_i = feats_sub[k]  # (Sa, H)
            # Align the feature length to this sample's audio-token count.
            if enc_len < n_i:
                feat_i = torch.cat([feat_i, feat_i.new_zeros((n_i - enc_len, hidden))], dim=0)
            elif enc_len > n_i:
                feat_i = feat_i[:n_i]
            pos_i = mask[i].nonzero(as_tuple=False).squeeze(1)  # (n_i,)
            inputs_embeds[i, pos_i, :] = feat_i
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LM, replacing audio-placeholder embeddings with encoded audio.

        Audio is injected only when ``audios`` and ``input_ids`` are both given
        and the KV cache is empty. The empty-cache condition means the prefill
        step, so later generation steps never re-encode the audio.

        Args:
            logits_to_keep: ``0`` computes logits for every position. A
                positive int keeps only the last ``logits_to_keep`` positions.
                A tensor indexes the sequence dimension.
            **kwargs: forwarded to ``self.loss_function`` only.

        Returns:
            ``CausalLMOutputWithPast`` (always; ``return_dict`` is not consulted).

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject only on the first step (past_len == 0) to avoid re-encoding
        # the audio on every generation step.
        past_len = self._cache_seq_len(past_key_values)
        if audios is not None and input_ids is not None and past_len == 0:
            inputs_embeds = self._inject_audio_embeddings_batch_encode_then_loop_scatter(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                audios=audios,
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = outputs[0]

        # Slice hidden states before lm_head so unused positions cost nothing.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hidden_for_logits = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, torch.Tensor):
            hidden_for_logits = hidden_states[:, logits_to_keep, :]
        else:
            hidden_for_logits = hidden_states
        logits = self.lm_head(hidden_for_logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build per-step model inputs for ``generate``.

        Once the cache is non-empty only the newest token is fed. ``audios`` is
        passed through on every step; ``forward`` injects it only while the
        cache is empty.
        """
        past_len = self._cache_seq_len(past_key_values)
        if past_len > 0:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            # Passed through every step; forward only injects while past_len == 0.
            "audios": kwargs.get("audios", None),
        }
        # Forward generate's ``logits_to_keep`` so prefill does not compute
        # logits for every prompt position.
        if "logits_to_keep" in kwargs:
            model_inputs["logits_to_keep"] = kwargs["logits_to_keep"]

        # Use caller-provided embeddings only on the prefill step. Testing the
        # cached length (not ``past_key_values is None``) also covers the empty
        # cache object that newer ``generate`` creates before the first step.
        if inputs_embeds is not None and past_len == 0:
            model_inputs["inputs_embeds"] = inputs_embeds
            del model_inputs["input_ids"]
        return model_inputs


__all__ = ["ArkasrForConditionalGeneration", "AudioMLPAdapter"]