GPA-v1.5 / modeling_arkasr.py
chua's picture
Upload GPA v1.5 model package
7e53c9a
from __future__ import annotations
from typing import Optional, List, Tuple, Union, Dict
import torch
from torch import Tensor, nn
from transformers import Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_arkasr import ArkasrConfig
from .modeling_audio import WhisperSpecialEncoder
class AudioMLPAdapter(nn.Module):
    """Audio front-end that maps encoder frames into the LLM embedding space.

    The module runs ``WhisperSpecialEncoder`` on the input, applies its own
    ``LayerNorm`` (the encoder's built-in ``layer_norm`` is replaced by
    ``nn.Identity``), concatenates every ``merge_factor`` consecutive frames
    along the feature axis, and projects the merged frames to
    ``config.hidden_size`` with a two-layer MLP.
    """

    def __init__(self, config: ArkasrConfig):
        """Build the encoder, the normalisation layer and the MLP projector.

        Args:
            config: Model configuration. Reads ``whisper_config``,
                ``merge_factor`` and ``hidden_size``, plus the optional
                ``use_rope`` (default ``False``) and ``mlp_adapter_act``
                (default ``"gelu"``) attributes.

        Raises:
            ValueError: If ``config.merge_factor`` is smaller than 1.
        """
        super().__init__()
        whisper_config = config.whisper_config

        self.merge_factor = int(config.merge_factor)
        if self.merge_factor < 1:
            # Fail at construction time rather than with a modulo-by-zero
            # (or a nonsensical reshape) deep inside forward().
            raise ValueError(
                f"`merge_factor` must be a positive integer, got {self.merge_factor}."
            )

        # Audio encoder
        self.whisper = WhisperSpecialEncoder(
            whisper_config,
            use_rope=getattr(config, "use_rope", False),
        )
        # Disable Whisper's built-in LayerNorm; normalisation is done by
        # `self.layer_norm` below instead.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(whisper_config.hidden_size)

        # Map activation names to classes so only the selected one is built.
        # Unknown names deliberately fall back to GELU.
        act_classes = {
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "selu": nn.SELU,
        }
        act_name = getattr(config, "mlp_adapter_act", "gelu")
        act = act_classes.get(act_name, nn.GELU)()

        input_dim = whisper_config.hidden_size * self.merge_factor
        output_dim = config.hidden_size
        self.adapting = nn.Sequential(
            nn.Linear(input_dim, output_dim * 2),
            act,
            nn.Linear(output_dim * 2, output_dim),
        )

    def forward(self, audios: Tensor) -> Tensor:
        """Encode audio and project it to the LLM hidden size.

        Args:
            audios: (B, mel, T) or (B, raw_len), depending on
                WhisperSpecialEncoder.

        Returns:
            adapted_features: (B, Seq_Audio, LLM_Hidden_Dim), where
            ``Seq_Audio`` is the encoder length divided by ``merge_factor``
            after padding/truncation. An encoder output with zero frames
            yields ``Seq_Audio == 0``.
        """
        bsz = audios.size(0)
        factor = self.merge_factor

        encoded = self.whisper(audios)[0]  # (B, T, D)
        encoded = self.layer_norm(encoded)

        seq_len = encoded.size(1)
        feat_dim = encoded.size(-1)
        remainder = seq_len % factor
        if remainder != 0:
            if seq_len < factor:
                # Guard for extremely short audio: zero-pad up to one full
                # merge group so at least one output frame is produced.
                pad = encoded.new_zeros((bsz, factor - seq_len, feat_dim))
                encoded = torch.cat([encoded, pad], dim=1)
            else:
                # Drop the trailing frames that do not fill a merge group.
                encoded = encoded[:, : seq_len - remainder, :]

        # Spell the merged length out instead of using -1: reshape cannot
        # infer a dimension for a tensor with zero elements.
        merged_len = encoded.size(1) // factor
        encoded = encoded.reshape(bsz, merged_len, feat_dim * factor)
        return self.adapting(encoded)  # (B, T/k, hidden)
class ArkasrForConditionalGeneration(Qwen2ForCausalLM):
    """Qwen2 causal LM whose audio placeholder tokens carry encoded audio.

    Positions in ``input_ids`` equal to ``config.audio_token_id`` have their
    token embeddings overwritten with features produced by
    ``AudioMLPAdapter`` before the decoder runs. Injection happens only while
    the KV cache is empty, so incremental generation steps do not re-encode
    the audio.
    """

    config_class = ArkasrConfig
    _no_split_modules = ["WhisperSpecialEncoder"]

    def __init__(self, config: ArkasrConfig):
        """Create the language model and attach the audio adapter.

        Args:
            config: Model configuration; must define ``audio_token_id``.

        Raises:
            ValueError: If ``config.audio_token_id`` is missing or ``None``.
        """
        super().__init__(config)
        self.audio_encoder = AudioMLPAdapter(config)
        self.audio_token_id = getattr(config, "audio_token_id", None)
        if self.audio_token_id is None:
            raise ValueError("`audio_token_id` must be defined in config.")

    @staticmethod
    def _cache_seq_len(past_key_values) -> int:
        """Return the number of positions already stored in the KV cache.

        Accepts ``None``, an object exposing ``get_seq_length()``, or an
        indexable structure whose ``[0][0]`` entry is a tensor with the
        sequence length on its second-to-last axis. This is a best-effort
        probe: any failure to read the length is reported as 0 (empty cache).
        """
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            try:
                return int(past_key_values.get_seq_length())
            except Exception:
                # Best effort by design: treat an unreadable cache as empty.
                return 0
        try:
            return int(past_key_values[0][0].shape[-2])
        except Exception:
            # Best effort by design: treat an unreadable cache as empty.
            return 0

    def _inject_audio_embeddings_batch_encode_then_loop_scatter(
        self,
        input_ids: torch.LongTensor,  # (B, S)
        inputs_embeds: torch.FloatTensor,  # (B, S, H)
        audios: Tensor,  # (B, ...)
    ) -> torch.FloatTensor:
        """Overwrite audio-token embeddings with encoded audio features.

        First run one batched audio encoding pass for samples that contain
        audio tokens, then scatter each sample's audio features back into
        ``inputs_embeds`` at the corresponding audio_token positions.

        Benefits:
            - The encoder runs only once.
            - Scatter is performed per sample, so features cannot drift
              across samples.
            - Rows without audio_token are skipped directly, which keeps
              TTS-only rows unaffected.

        Constraint:
            - The number of audio tokens n_i in each sample should align with
              Sa from the audio encoder output. If they do not align, this
              path truncates or zero-pads to n_i instead of raising an error.

        Args:
            input_ids: Token ids, (B, S).
            inputs_embeds: Token embeddings, (B, S, H). Modified in place.
            audios: Batched audio input whose first axis is indexed by the
                same sample index as ``input_ids``.

        Returns:
            The same ``inputs_embeds`` tensor, with audio positions replaced.
        """
        hidden = inputs_embeds.size(-1)

        # Find the samples that require audio injection.
        mask = input_ids == self.audio_token_id  # (B, S)
        per_counts = mask.sum(dim=1)  # (B,)
        need_idx = (per_counts > 0).nonzero(as_tuple=False).squeeze(1)  # (K,)
        if need_idx.numel() == 0:
            return inputs_embeds

        # Encode only the subset of audio that needs injection. (K, ...)
        audios_sub = audios.index_select(0, need_idx)
        feats_sub = self.audio_encoder(audios_sub)  # (K, Sa, H)
        feats_sub = feats_sub.to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        num_feats = feats_sub.size(1)

        # Pull row indices and token counts to Python once, instead of one
        # `.item()` call per sample inside the loop.
        rows = need_idx.tolist()
        counts = per_counts.index_select(0, need_idx).tolist()

        # Inject one sample at a time. Every selected row has n_tokens > 0.
        for k, (row, n_tokens) in enumerate(zip(rows, counts)):
            feat = feats_sub[k]  # (Sa, H)
            # Align to the number of audio tokens for this sample.
            if num_feats < n_tokens:
                pad = feat.new_zeros((n_tokens - num_feats, hidden))
                feat = torch.cat([feat, pad], dim=0)
            elif num_feats > n_tokens:
                feat = feat[:n_tokens]
            positions = mask[row].nonzero(as_tuple=False).squeeze(1)  # (n_i,)
            # Write features back into embeddings.
            inputs_embeds[row, positions, :] = feat
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder, injecting audio features on the first step.

        Args:
            input_ids: Token ids. Required when ``inputs_embeds`` is not
                given, and required for audio injection.
            audios: Audio input for the adapter. Used only when
                ``input_ids`` is given and the KV cache is empty.
            attention_mask: Forwarded to the decoder.
            position_ids: Forwarded to the decoder.
            past_key_values: KV cache, forwarded to the decoder.
            inputs_embeds: Precomputed embeddings; computed from
                ``input_ids`` when omitted.
            use_cache: Forwarded to the decoder.
            labels: If given, a loss is computed with ``self.loss_function``.
            output_attentions: Forwarded to the decoder.
            output_hidden_states: Forwarded to the decoder.
            logits_to_keep: Positive int keeps only the last N positions for
                the LM head; a tensor is used as an index along the sequence
                axis; anything else keeps all positions.
            **kwargs: Passed to ``self.loss_function`` when ``labels`` is set.

        Returns:
            A ``CausalLMOutputWithPast``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is
                provided.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject only on the first step (past_len == 0) to avoid re-encoding
        # during generation.
        past_len = self._cache_seq_len(past_key_values)
        if audios is not None and input_ids is not None and past_len == 0:
            inputs_embeds = self._inject_audio_embeddings_batch_encode_then_loop_scatter(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                audios=audios,
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = outputs[0]

        # Restrict logits computation when possible to avoid redundant
        # lm_head work.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hidden_for_logits = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, torch.Tensor):
            hidden_for_logits = hidden_states[:, logits_to_keep, :]
        else:
            hidden_for_logits = hidden_states
        logits = self.lm_head(hidden_for_logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Once the cache holds at least one position, only the last token id is
        fed to the model. ``inputs_embeds`` replaces ``input_ids`` only while
        the cache is still empty.

        Args:
            input_ids: Token ids accumulated so far.
            past_key_values: KV cache, or ``None``.
            attention_mask: Passed through unchanged.
            inputs_embeds: Optional precomputed embeddings for the first step.
            **kwargs: ``use_cache`` and ``audios`` are read from here; other
                entries are ignored.

        Returns:
            Dict of keyword arguments for ``forward``.
        """
        past_len = self._cache_seq_len(past_key_values)
        if past_len > 0:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            # Pass audios through. Injection happens only when past_len == 0
            # in forward, so later generation steps do not re-encode.
            "audios": kwargs.get("audios", None),
        }

        # Gate on the cache being empty rather than on it being None: the
        # first step may already receive an (empty) cache object, and the
        # caller's embeddings must still be used in that case.
        if inputs_embeds is not None and past_len == 0:
            model_inputs["inputs_embeds"] = inputs_embeds
            del model_inputs["input_ids"]
        return model_inputs
# Names exported by `from <module> import *`.
__all__ = [
    "ArkasrForConditionalGeneration",
    "AudioMLPAdapter",
]