Automatic Speech Recognition
Transformers
Safetensors
PyTorch
arkasr
text-generation
speech
audio
ark-asr
custom_code
Instructions to use AutoArk-AI/ARK-ASR-0.6B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AutoArk-AI/ARK-ASR-0.6B with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="AutoArk-AI/ARK-ASR-0.6B", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("AutoArk-AI/ARK-ASR-0.6B", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
File size: 9,402 Bytes
from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from transformers import Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_arkasr import ArkasrConfig
from .modeling_audio import WhisperSpecialEncoder
class AudioMLPAdapter(nn.Module):
    """Encode audio and project the frames into the language model's hidden space.

    Pipeline: ``WhisperSpecialEncoder`` -> ``LayerNorm`` -> concatenate every
    ``merge_factor`` consecutive frames along the feature axis -> two-layer MLP
    mapping ``whisper_hidden * merge_factor`` to ``config.hidden_size``.
    """

    def __init__(self, config: ArkasrConfig):
        """Build the audio encoder, the post-encoder LayerNorm and the MLP projector.

        Args:
            config: model config. Reads ``whisper_config``, ``merge_factor`` and
                ``hidden_size``, plus the optional ``use_rope`` (default False)
                and ``mlp_adapter_act`` (default "gelu").

        Raises:
            ValueError: if ``merge_factor`` is smaller than 1. It is used as a
                modulus and as a reshape factor in ``forward``.
        """
        super().__init__()
        whisper_config = config.whisper_config
        self.merge_factor = int(config.merge_factor)
        if self.merge_factor < 1:
            raise ValueError(
                f"`merge_factor` must be a positive integer, got {config.merge_factor!r}."
            )

        # Audio encoder.
        self.whisper = WhisperSpecialEncoder(
            whisper_config,
            use_rope=getattr(config, "use_rope", False),
        )
        # Disable the encoder's built-in final LayerNorm; `self.layer_norm` is applied instead.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(whisper_config.hidden_size)

        # Unknown activation names silently fall back to GELU (kept for compatibility
        # with existing configs). Only the selected activation is instantiated.
        act_cls_map = {
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "selu": nn.SELU,
        }
        act_cls = act_cls_map.get(getattr(config, "mlp_adapter_act", "gelu"), nn.GELU)

        input_dim = whisper_config.hidden_size * self.merge_factor
        output_dim = config.hidden_size
        self.adapting = nn.Sequential(
            nn.Linear(input_dim, output_dim * 2),
            act_cls(),
            nn.Linear(output_dim * 2, output_dim),
        )

    def _align_to_merge_factor(self, encoded: Tensor) -> Tensor:
        """Make the time axis (dim 1) of ``encoded`` a positive multiple of ``merge_factor``.

        Trailing frames that do not fill a complete group are dropped. If fewer
        frames than one full group exist (including zero frames), the sequence is
        zero-padded to exactly ``merge_factor`` frames so that at least one merged
        step is produced.
        """
        seq_len = encoded.size(1)
        target_len = max(seq_len // self.merge_factor, 1) * self.merge_factor
        if target_len == seq_len:
            return encoded
        if target_len > seq_len:
            # Very short (or empty) audio: pad up to one full group.
            pad = encoded.new_zeros((encoded.size(0), target_len - seq_len, encoded.size(-1)))
            return torch.cat([encoded, pad], dim=1)
        return encoded[:, :target_len, :]

    def forward(self, audios: Tensor) -> Tensor:
        """Encode ``audios`` and return adapted features.

        Args:
            audios: whatever ``WhisperSpecialEncoder`` accepts, batched along
                dim 0. The original note says (B, mel, T) or (B, raw_len); confirm
                against the encoder.

        Returns:
            Tensor of shape (B, T', config.hidden_size), where T' is the
            merge-factor-aligned frame count divided by ``merge_factor`` (at least 1).
        """
        bsz = audios.size(0)
        encoded = self.whisper(audios)[0]  # (B, T, D)
        encoded = self.layer_norm(encoded)
        encoded = self._align_to_merge_factor(encoded)
        # Merge each group of `merge_factor` consecutive frames by concatenating their features.
        encoded = encoded.reshape(bsz, -1, encoded.size(-1) * self.merge_factor)
        return self.adapting(encoded)  # (B, T', hidden)
class ArkasrForConditionalGeneration(Qwen2ForCausalLM):
    """Qwen2 causal LM whose prompt embeddings carry encoded audio.

    Audio is encoded by :class:`AudioMLPAdapter`, and the resulting features
    overwrite the token embeddings at every position holding
    ``config.audio_token_id``. Everything after that is the regular Qwen2 decoder
    plus LM head.
    """

    config_class = ArkasrConfig
    # Setting this attribute replaces the parent's list rather than extending it,
    # so the decoder layer has to be repeated here. Otherwise a device_map could
    # split a single decoder layer (and its residual adds) across devices.
    _no_split_modules = ["Qwen2DecoderLayer", "WhisperSpecialEncoder"]

    def __init__(self, config: ArkasrConfig):
        """Build the Qwen2 LM and the audio adapter.

        Raises:
            ValueError: if ``config.audio_token_id`` is missing or None.
        """
        super().__init__(config)
        self.audio_encoder = AudioMLPAdapter(config)
        self.audio_token_id = getattr(config, "audio_token_id", None)
        if self.audio_token_id is None:
            raise ValueError("`audio_token_id` must be defined in config.")

    @staticmethod
    def _cache_seq_len(past_key_values) -> int:
        """Return the number of cached positions, or 0 if it cannot be determined.

        Supports Cache objects exposing ``get_seq_length()`` and legacy
        tuple-of-tuples caches (length read from the key tensor's dim -2).
        This is best-effort: any failure is reported as 0.
        """
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            try:
                return int(past_key_values.get_seq_length())
            except Exception:
                return 0
        try:
            return int(past_key_values[0][0].shape[-2])
        except Exception:
            return 0

    def _inject_audio_embeddings_batch_encode_then_loop_scatter(
        self,
        input_ids: torch.LongTensor,  # (B, S)
        inputs_embeds: torch.FloatTensor,  # (B, S, H)
        audios: Tensor,  # (B, ...)
    ) -> torch.FloatTensor:
        """Overwrite audio-token embeddings with encoded audio features, in place.

        Rows of ``input_ids`` that contain at least one ``audio_token_id`` are
        selected. Their ``audios`` entries (same row index along dim 0) are
        encoded in a single batch. Each row's features are then written back at
        that row's audio-token positions, so features never leak across rows.
        Rows without audio tokens are left unchanged.

        If the encoder output length Sa differs from a row's audio-token count
        n_i, the features are zero-padded or truncated to n_i. No error is raised.

        Returns:
            ``inputs_embeds`` (the same tensor, modified in place).
        """
        mask = input_ids == self.audio_token_id  # (B, S)
        per_counts = mask.sum(dim=1)  # (B,)
        need_idx = (per_counts > 0).nonzero(as_tuple=False).squeeze(1)  # (K,)
        if need_idx.numel() == 0:
            return inputs_embeds

        # Encode only the rows that need injection, in one batch.
        # index_select requires the index tensor on the same device as `audios`.
        audios_sub = audios.index_select(0, need_idx.to(audios.device))
        feats_sub = self.audio_encoder(audios_sub)  # (K, Sa, H)
        feats_sub = feats_sub.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        num_feats = feats_sub.size(1)
        hidden = inputs_embeds.size(-1)

        counts = per_counts.tolist()
        for k, row in enumerate(need_idx.tolist()):
            n_tokens = int(counts[row])
            feat = feats_sub[k]  # (Sa, H)
            # Align to this row's number of audio tokens.
            if num_feats < n_tokens:
                pad = feat.new_zeros((n_tokens - num_feats, hidden))
                feat = torch.cat([feat, pad], dim=0)
            elif num_feats > n_tokens:
                feat = feat[:n_tokens]
            positions = mask[row].nonzero(as_tuple=False).squeeze(1)  # (n_tokens,)
            inputs_embeds[row, positions, :] = feat
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LM, injecting audio features on the prefill step.

        Args:
            audios: audio input whose dim 0 is aligned with the rows of
                ``input_ids``. It is only used when ``input_ids`` is given and the
                cache is empty.
            logits_to_keep: int > 0 keeps the last N positions; a tensor selects
                positions by index; 0 keeps all positions.
            **kwargs: forwarded to ``self.loss_function``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject only on the first step (empty cache), so later generation steps
        # do not re-encode the audio.
        past_len = self._cache_seq_len(past_key_values)
        if audios is not None and input_ids is not None and past_len == 0:
            inputs_embeds = self._inject_audio_embeddings_batch_encode_then_loop_scatter(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                audios=audios,
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = outputs[0]

        # Run lm_head only on the positions that are actually needed.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hidden_for_logits = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, torch.Tensor):
            hidden_for_logits = hidden_states[:, logits_to_keep, :]
        else:
            hidden_for_logits = hidden_states
        logits = self.lm_head(hidden_for_logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build the ``forward`` kwargs for one ``generate()`` step.

        - Once the cache is non-empty, only the last token id is fed.
        - Caller-supplied ``inputs_embeds`` are used only on the first step,
          detected by an empty cache rather than ``past_key_values is None``,
          because recent ``generate()`` passes an empty Cache object there.
        - When no explicit ``position_ids`` are given and a 2-D
          ``attention_mask`` is available, positions are derived from the mask
          so that left-padded rows start at 0. The positions are then sliced to
          the current step's length.
        - ``audios`` is passed through on every step. ``forward`` injects it only
          when the cache is empty, so it is not re-encoded.
        """
        past_len = self._cache_seq_len(past_key_values)
        if past_len > 0:
            input_ids = input_ids[:, -1:]

        use_embeds = inputs_embeds is not None and past_len == 0
        cur_len = inputs_embeds.shape[1] if use_embeds else input_ids.shape[1]

        position_ids = kwargs.get("position_ids")
        if position_ids is None and attention_mask is not None and attention_mask.dim() == 2:
            # Count only attended tokens so each (possibly left-padded) row starts at 0;
            # padded slots get a dummy position that is masked out anyway.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        if position_ids is not None:
            position_ids = position_ids[:, -cur_len:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "audios": kwargs.get("audios", None),
        }
        if "logits_to_keep" in kwargs:
            model_inputs["logits_to_keep"] = kwargs["logits_to_keep"]
        if use_embeds:
            model_inputs["inputs_embeds"] = inputs_embeds
            del model_inputs["input_ids"]
        return model_inputs
# Public API of this module.
__all__ = ["ArkasrForConditionalGeneration", "AudioMLPAdapter"]