Automatic Speech Recognition
Transformers
Safetensors
PyTorch
arkasr
text-generation
speech
audio
ark-asr
custom_code
Instructions to use AutoArk-AI/ARK-ASR-0.6B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AutoArk-AI/ARK-ASR-0.6B with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="AutoArk-AI/ARK-ASR-0.6B", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("AutoArk-AI/ARK-ASR-0.6B", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| from typing import Optional, List, Tuple, Union, Dict | |
| import torch | |
| from torch import Tensor, nn | |
| from transformers import Qwen2ForCausalLM | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from .configuration_arkasr import ArkasrConfig | |
| from .modeling_audio import WhisperSpecialEncoder | |
class AudioMLPAdapter(nn.Module):
    """Audio encoder followed by a frame-merging MLP projector.

    The encoder output is layer-normalised, every ``merge_factor`` consecutive
    frames are concatenated along the feature axis, and the merged frames are
    projected to ``config.hidden_size`` by a two-layer MLP.
    """

    def __init__(self, config: ArkasrConfig):
        """Build the encoder, the normalisation layer and the projector.

        Args:
            config: Model configuration. Reads ``whisper_config``,
                ``merge_factor``, ``hidden_size`` and, optionally,
                ``use_rope`` (default ``False``) and ``mlp_adapter_act``
                (default ``"gelu"``).

        Raises:
            ValueError: If ``merge_factor`` is smaller than 1.
        """
        super().__init__()
        whisper_config = config.whisper_config
        self.merge_factor = int(config.merge_factor)
        if self.merge_factor < 1:
            # Fail at construction time instead of with a ZeroDivisionError
            # (or a nonsensical reshape) deep inside forward().
            raise ValueError(
                f"`merge_factor` must be >= 1, got {self.merge_factor}."
            )

        # Audio encoder.
        self.whisper = WhisperSpecialEncoder(
            whisper_config,
            use_rope=getattr(config, "use_rope", False),
        )
        # Disable the encoder's own LayerNorm; `self.layer_norm` is applied
        # to the encoder output in forward() instead.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(whisper_config.hidden_size)

        act_fn_map = {
            "gelu": nn.GELU(),
            "relu": nn.ReLU(),
            "selu": nn.SELU(),
        }
        # Unknown activation names fall back to GELU.
        act = act_fn_map.get(getattr(config, "mlp_adapter_act", "gelu"), nn.GELU())

        input_dim = whisper_config.hidden_size * self.merge_factor
        output_dim = config.hidden_size
        self.adapting = nn.Sequential(
            nn.Linear(input_dim, output_dim * 2),
            act,
            nn.Linear(output_dim * 2, output_dim),
        )

    def forward(self, audios: Tensor) -> Tensor:
        """Encode audio and project it to the language-model width.

        Args:
            audios: Batched encoder input; the exact layout (e.g.
                ``(B, mel, T)`` or ``(B, raw_len)``) is whatever
                ``WhisperSpecialEncoder`` accepts.

        Returns:
            Tensor of shape ``(B, T_merged, output_dim)`` where ``T_merged``
            is ``max(1, T // merge_factor)`` for a non-empty encoder output
            of ``T`` frames and ``0`` for an empty one.
        """
        bsz = audios.size(0)
        encoded = self.whisper(audios)[0]  # (B, T, D)
        encoded = self.layer_norm(encoded)

        seq_len = encoded.size(1)
        feat_dim = encoded.size(-1)
        factor = self.merge_factor

        remainder = seq_len % factor
        if remainder:
            if seq_len < factor:
                # Very short audio: zero-pad up to exactly one merged frame.
                pad = encoded.new_zeros((bsz, factor - seq_len, feat_dim))
                encoded = torch.cat([encoded, pad], dim=1)
            else:
                # Drop the trailing frames that do not fill a whole group.
                encoded = encoded[:, : seq_len - remainder, :]

        # Pass the merged length explicitly: `-1` cannot be inferred when the
        # tensor has zero elements (empty batch or zero frames).
        merged_len = encoded.size(1) // factor
        encoded = encoded.reshape(bsz, merged_len, feat_dim * factor)
        return self.adapting(encoded)  # (B, T // factor, output_dim)
class ArkasrForConditionalGeneration(Qwen2ForCausalLM):
    """Qwen2 causal LM that consumes audio through placeholder tokens.

    Positions in ``input_ids`` equal to ``audio_token_id`` have their token
    embeddings overwritten with features produced by ``AudioMLPAdapter``
    before the decoder runs.
    """

    config_class = ArkasrConfig
    # NOTE(review): this replaces whatever `_no_split_modules` the Qwen2 base
    # class declares -- confirm the decoder layers do not need listing too.
    _no_split_modules = ["WhisperSpecialEncoder"]

    def __init__(self, config: ArkasrConfig):
        """Create the language model and the audio adapter.

        Raises:
            ValueError: If ``config`` does not define ``audio_token_id``.
        """
        super().__init__(config)
        self.audio_encoder = AudioMLPAdapter(config)
        self.audio_token_id = getattr(config, "audio_token_id", None)
        if self.audio_token_id is None:
            raise ValueError("`audio_token_id` must be defined in config.")

    @staticmethod
    def _cache_seq_len(past_key_values) -> int:
        """Return the number of positions already cached, or 0 if unknown.

        Accepts ``None``, an object exposing ``get_seq_length()``, or a
        nested sequence indexed as ``past_key_values[0][0]`` whose
        second-to-last dimension is the cached length. This is a best-effort
        probe: any failure is reported as 0, i.e. "no cache yet".

        It must be a static method because callers invoke it as
        ``self._cache_seq_len(past_key_values)``; without the decorator the
        instance would be bound to ``past_key_values`` and the call would
        raise ``TypeError``.
        """
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            try:
                return int(past_key_values.get_seq_length())
            except Exception:  # deliberate best-effort fallback
                return 0
        try:
            return int(past_key_values[0][0].shape[-2])
        except Exception:  # deliberate best-effort fallback
            return 0

    def _inject_audio_embeddings_batch_encode_then_loop_scatter(
        self,
        input_ids: torch.LongTensor,  # (B, S)
        inputs_embeds: torch.FloatTensor,  # (B, S, H)
        audios: Tensor,  # (B, ...)
    ) -> torch.FloatTensor:
        """Overwrite audio-placeholder embeddings with encoded audio features.

        All samples that contain at least one audio token are encoded in a
        single batched call to the audio encoder; the resulting features are
        then written back sample by sample at that sample's audio-token
        positions, so features can never land in another sample's row.
        Samples without any audio token are left untouched.

        If a sample has ``n_i`` audio tokens and the encoder yields ``Sa``
        frames, the features are truncated (``Sa > n_i``) or zero-padded
        (``Sa < n_i``) to ``n_i`` rows; no error is raised on a mismatch.

        Args:
            input_ids: Token ids, shape ``(B, S)``.
            inputs_embeds: Token embeddings, shape ``(B, S, H)``. Modified
                in place.
            audios: Batched audio input whose first dimension is indexed by
                the same sample index as ``input_ids``.

        Returns:
            ``inputs_embeds`` (the same tensor object) with audio positions
            replaced.
        """
        H = inputs_embeds.size(-1)
        device = inputs_embeds.device
        dtype = inputs_embeds.dtype

        # Find which samples need injection.
        mask = input_ids == self.audio_token_id  # (B, S)
        per_counts = mask.sum(dim=1)  # (B,)
        need_idx = (per_counts > 0).nonzero(as_tuple=False).squeeze(1)  # (K,)
        if need_idx.numel() == 0:
            return inputs_embeds

        # Encode only the audio of the samples that need it: (K, ...).
        # index_select requires the index on the same device as its input.
        audios_sub = audios.index_select(0, need_idx.to(audios.device))
        feats_sub = self.audio_encoder(audios_sub)  # (K, Sa, H)
        feats_sub = feats_sub.to(device=device, dtype=dtype)
        Sa = feats_sub.size(1)

        # One host transfer for all indices/counts instead of two `.item()`
        # synchronisations per sample.
        rows = need_idx.tolist()
        counts = per_counts[need_idx].tolist()

        for k, (i, n_i) in enumerate(zip(rows, counts)):
            feat_i = feats_sub[k]  # (Sa, H)

            # Align the feature length with this sample's audio-token count.
            if Sa < n_i:
                pad = feat_i.new_zeros((n_i - Sa, H))
                feat_i_use = torch.cat([feat_i, pad], dim=0)
            elif Sa > n_i:
                feat_i_use = feat_i[:n_i]
            else:
                feat_i_use = feat_i

            pos_i = mask[i].nonzero(as_tuple=False).squeeze(1)  # (n_i,)
            # Write the features back into the embeddings.
            inputs_embeds[i, pos_i, :] = feat_i_use

        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder, injecting audio features on the first step.

        Audio features are injected only when ``audios`` and ``input_ids``
        are both given and the cache is empty, so later generation steps do
        not re-run the audio encoder.

        Args:
            input_ids: Token ids; required when ``inputs_embeds`` is ``None``
                and required for audio injection.
            audios: Batched audio input for the audio encoder.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: Forwarded unchanged to
                the underlying decoder (``self.model``).
            inputs_embeds: Pre-computed embeddings; computed from
                ``input_ids`` when omitted.
            labels: If given, a loss is computed with ``self.loss_function``.
            logits_to_keep: Positive int ``k`` keeps only the last ``k``
                positions; a tensor is used as an index along the sequence
                axis; anything else keeps all positions.
            **kwargs: Forwarded to ``self.loss_function`` only.

        Returns:
            ``CausalLMOutputWithPast`` with loss (if ``labels`` was given),
            logits, cache, hidden states and attentions.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is
                provided.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject only on the first step (empty cache) so that subsequent
        # generation steps do not encode the audio again.
        past_len = self._cache_seq_len(past_key_values)
        if audios is not None and input_ids is not None and past_len == 0:
            inputs_embeds = self._inject_audio_embeddings_batch_encode_then_loop_scatter(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                audios=audios,
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = outputs[0]

        # Slice before the LM head so logits are computed only where needed.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hidden_for_logits = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, torch.Tensor):
            hidden_for_logits = hidden_states[:, logits_to_keep, :]
        else:
            hidden_for_logits = hidden_states
        logits = self.lm_head(hidden_for_logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Once the cache holds at least one position, only the last token of
        ``input_ids`` is fed. ``audios`` is always passed through; ``forward``
        ignores it once the cache is non-empty. Caller-supplied
        ``inputs_embeds`` replace ``input_ids`` on the first step only.

        Returns:
            Dict of keyword arguments for ``forward``.
        """
        past_len = self._cache_seq_len(past_key_values)
        if past_len > 0:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            # Passed through on every step; forward() injects audio only
            # while the cache is empty, so it is not re-encoded later.
            "audios": kwargs.get("audios", None),
        }

        # Detect the first step by cache length rather than `is None`, so an
        # already-allocated but still empty cache object also counts as the
        # first step (consistent with the checks above and in forward()).
        if inputs_embeds is not None and past_len == 0:
            model_inputs["inputs_embeds"] = inputs_embeds
            del model_inputs["input_ids"]

        return model_inputs
# Names exported by `from <module> import *`.
__all__ = [
    "ArkasrForConditionalGeneration",
    "AudioMLPAdapter",
]