# ARK-ASR-0.6B / modeling_arkasr.py
# bupalinyu's picture
# Upload ARK-ASR-0.6B model card and support files
# 05f7466 verified
from __future__ import annotations
from typing import Optional, List, Tuple, Union, Dict
import torch
from torch import Tensor, nn
from transformers import Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_arkasr import ArkasrConfig
from .modeling_audio import WhisperSpecialEncoder
class AudioMLPAdapter(nn.Module):
    """Audio encoder followed by a frame-merging MLP projector.

    The encoder output is layer-normalized, every ``merge_factor`` consecutive
    frames are concatenated along the feature axis, and the merged frames are
    projected to ``config.hidden_size`` by a two-layer MLP.
    """

    def __init__(self, config: ArkasrConfig):
        """Build the encoder, the external LayerNorm and the MLP projector.

        Args:
            config: Model configuration. Reads ``whisper_config``,
                ``merge_factor``, ``hidden_size`` and, optionally,
                ``use_rope`` (default ``False``) and ``mlp_adapter_act``
                (default ``"gelu"``; unknown names also fall back to GELU).
        """
        super().__init__()
        encoder_cfg = config.whisper_config
        self.merge_factor = int(config.merge_factor)

        # Audio encoder.
        self.whisper = WhisperSpecialEncoder(
            encoder_cfg,
            use_rope=getattr(config, "use_rope", False),
        )
        # Disable the encoder's own LayerNorm; normalization is applied by
        # ``self.layer_norm`` in ``forward`` instead.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(encoder_cfg.hidden_size)

        # Activation lookup by name; only the selected one is instantiated.
        activation_classes = {
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "selu": nn.SELU,
        }
        activation_name = getattr(config, "mlp_adapter_act", "gelu")
        activation = activation_classes.get(activation_name, nn.GELU)()

        merged_width = encoder_cfg.hidden_size * self.merge_factor
        llm_width = config.hidden_size
        self.adapting = nn.Sequential(
            nn.Linear(merged_width, llm_width * 2),
            activation,
            nn.Linear(llm_width * 2, llm_width),
        )

    def forward(self, audios: Tensor) -> Tensor:
        """Encode audio and project it into the language-model embedding space.

        Args:
            audios: Batched encoder input; the first dimension is the batch.
                The remaining layout (e.g. ``(B, mel, T)`` or ``(B, raw_len)``)
                is whatever ``WhisperSpecialEncoder`` expects.

        Returns:
            Tensor of shape ``(B, T // merge_factor, hidden_size)``. Trailing
            frames that do not fill a complete group are dropped; an input
            shorter than one group is zero-padded to exactly one group.
        """
        batch = audios.size(0)
        # First element of the encoder output is the frame sequence (B, T, D).
        frames = self.layer_norm(self.whisper(audios)[0])

        group = self.merge_factor
        n_frames = frames.size(1)
        leftover = n_frames % group
        if leftover != 0:
            keep = n_frames - leftover  # largest multiple of `group` <= n_frames
            if keep <= 0:
                # Very short audio fallback: grow to one full group.
                keep = group
            if n_frames < keep:
                zeros = frames.new_zeros((batch, keep - n_frames, frames.size(-1)))
                frames = torch.cat([frames, zeros], dim=1)
            else:
                frames = frames[:, :keep, :]

        # Concatenate each run of `group` frames along the feature axis.
        merged = frames.reshape(batch, -1, frames.size(-1) * group)
        return self.adapting(merged)  # (B, T / group, hidden)
class ArkasrForConditionalGeneration(Qwen2ForCausalLM):
    """Qwen2 causal LM whose audio-placeholder embeddings are replaced by audio features.

    Positions in ``input_ids`` equal to ``config.audio_token_id`` have their
    token embeddings overwritten with the output of ``AudioMLPAdapter`` before
    the decoder runs. Injection happens only while the KV cache is empty, so
    the audio is not re-encoded on later generation steps.
    """

    config_class = ArkasrConfig
    _no_split_modules = ["WhisperSpecialEncoder"]

    def __init__(self, config: ArkasrConfig):
        """Create the language model and its audio encoder.

        Raises:
            ValueError: If ``config`` has no ``audio_token_id`` (or it is None).
        """
        super().__init__(config)
        self.audio_encoder = AudioMLPAdapter(config)
        self.audio_token_id = getattr(config, "audio_token_id", None)
        if self.audio_token_id is None:
            raise ValueError("`audio_token_id` must be defined in config.")

    @staticmethod
    def _cache_seq_len(past_key_values) -> int:
        """Return the number of positions already held in the KV cache.

        Accepts ``None``, an object exposing ``get_seq_length()``, or a nested
        sequence indexed as ``past_key_values[0][0]`` whose second-to-last
        dimension is the sequence length. This is deliberately best-effort:
        anything that cannot be inspected is treated as an empty cache (0).
        """
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            try:
                return int(past_key_values.get_seq_length())
            except Exception:
                return 0
        try:
            return int(past_key_values[0][0].shape[-2])
        except Exception:
            return 0

    def _inject_audio_embeddings_batch_encode_then_loop_scatter(
        self,
        input_ids: torch.LongTensor,  # (B, S)
        inputs_embeds: torch.FloatTensor,  # (B, S, H)
        audios: Tensor,  # (B, ...)
    ) -> torch.FloatTensor:
        """Encode audio once for the batch, then scatter features row by row.

        Only rows containing at least one audio token are encoded. Their
        features are then written to the audio-token positions of that row.

        Benefits:
          - the encoder runs a single time for the batch;
          - the write-back is per row, so features never leak across samples;
          - rows without any audio token are skipped untouched (e.g. TTS rows).

        Constraint:
          - the number of audio tokens ``n_i`` in a row should match the
            encoder output length ``Sa``. On mismatch the features are
            truncated or zero-padded to ``n_i`` instead of raising.

        Args:
            input_ids: Token ids, shape ``(B, S)``.
            inputs_embeds: Token embeddings, shape ``(B, S, H)``. Modified in
                place at the audio-token positions.
            audios: Audio inputs, row-aligned with ``input_ids``.

        Returns:
            ``inputs_embeds`` (the same tensor object that was passed in).
        """
        hidden = inputs_embeds.size(-1)

        # Find which rows need injection.
        mask = input_ids == self.audio_token_id  # (B, S)
        per_counts = mask.sum(dim=1)  # (B,)
        need_idx = (per_counts > 0).nonzero(as_tuple=False).squeeze(1)  # (K,)
        if need_idx.numel() == 0:
            return inputs_embeds

        # Encode only the rows that need it. The index is moved to the audio
        # tensor's device because index_select requires both on one device.
        audios_sub = audios.index_select(0, need_idx.to(audios.device))
        feats_sub = self.audio_encoder(audios_sub)  # (K, Sa, H)
        feats_sub = feats_sub.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        n_frames = feats_sub.size(1)

        # One host transfer per tensor instead of two `.item()` syncs per row.
        rows = need_idx.tolist()
        counts = per_counts.index_select(0, need_idx).tolist()

        for k, (row, n_tokens) in enumerate(zip(rows, counts)):
            feat = feats_sub[k]  # (Sa, H)
            # Align the features to this row's number of audio tokens.
            if n_frames < n_tokens:
                pad = feat.new_zeros((n_tokens - n_frames, hidden))
                feat = torch.cat([feat, pad], dim=0)
            elif n_frames > n_tokens:
                feat = feat[:n_tokens]
            positions = mask[row].nonzero(as_tuple=False).squeeze(1)  # (n_tokens,)
            # Overwrite the placeholder embeddings.
            inputs_embeds[row, positions, :] = feat
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder, injecting audio features on the first step.

        Args:
            input_ids: Token ids. Needed to locate audio tokens; embedded with
                ``self.model.embed_tokens`` when ``inputs_embeds`` is None.
            audios: Audio inputs. Used only when ``input_ids`` is given and
                the KV cache is empty.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: Forwarded to ``self.model``.
            inputs_embeds: Optional precomputed embeddings. They are copied
                before audio injection, so the caller's tensor is not modified.
            labels: If given, a loss is computed via ``self.loss_function``.
            logits_to_keep: A positive int keeps only the last N positions;
                a tensor is used as an index on the sequence axis; otherwise
                logits are computed for all positions.
            **kwargs: Forwarded to ``self.loss_function`` only.

        Returns:
            ``CausalLMOutputWithPast`` holding loss, logits, cache, hidden
            states and attentions.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
        """
        caller_supplied_embeds = inputs_embeds is not None
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject only on the first step (empty cache) so the audio is not
        # re-encoded on later generation steps.
        past_len = self._cache_seq_len(past_key_values)
        if audios is not None and input_ids is not None and past_len == 0:
            if caller_supplied_embeds:
                # The injection writes in place; do not clobber the caller's tensor.
                inputs_embeds = inputs_embeds.clone()
            inputs_embeds = self._inject_audio_embeddings_batch_encode_then_loop_scatter(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                audios=audios,
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = outputs[0]

        # Apply lm_head only to the requested positions.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hidden_for_logits = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, torch.Tensor):
            hidden_for_logits = hidden_states[:, logits_to_keep, :]
        else:
            hidden_for_logits = hidden_states
        logits = self.lm_head(hidden_for_logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Once the cache holds at least one position, only the last token id is
        fed. ``inputs_embeds`` replaces ``input_ids`` only while the cache is
        empty. Emptiness is judged by ``_cache_seq_len`` so that an empty,
        pre-allocated cache object is treated the same as ``None``.

        Returns:
            Dict with ``past_key_values``, ``use_cache``, ``attention_mask``,
            ``audios`` and either ``input_ids`` or ``inputs_embeds``.
        """
        past_len = self._cache_seq_len(past_key_values)
        if past_len > 0:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            # `audios` is passed through on every step; forward() injects only
            # when the cache is empty, so later steps do not re-encode it.
            "audios": kwargs.get("audios", None),
        }

        if inputs_embeds is not None and past_len == 0:
            model_inputs["inputs_embeds"] = inputs_embeds
            del model_inputs["input_ids"]
        return model_inputs
# Names exported by `from <module> import *`: the full model and its audio adapter.
__all__ = ["ArkasrForConditionalGeneration", "AudioMLPAdapter"]