from collections.abc import Callable
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import PreTrainedModel, Qwen2ForCausalLM
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.utils import auto_docstring

from .configuration_uas_audio import UASAudioConfig, UASAudioEncoderConfig, UASAudioEncoderOnlyConfig


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def _get_feat_extract_output_lengths(input_lengths):
    """
    Computes the output length of the convolutional layers and the output length of the audio encoder.

    Every full block of 100 input frames contributes 13 output frames. The remaining
    ``input_lengths % 100`` frames go through three ``(n - 1) // 2 + 1`` reductions, which is the output
    length of each of the three kernel-3 / stride-2 / padding-1 convolutions in ``UASAudioEncoder``.

    NOTE(review): the hard-coded 100 presumably equals ``2 * n_window`` of the encoder config — confirm.
    """
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    return output_lengths


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain PyTorch scaled dot-product attention.

    ``query``/``key``/``value`` are ``(batch, heads, seq, head_dim)``. ``attention_mask``, when given, is an
    additive float mask broadcast onto the attention scores. Extra ``kwargs`` (e.g. the FA2 ``cu_seq_lens_*``
    arguments) are accepted and ignored.

    Returns:
        ``(attn_output, attn_weights)`` with ``attn_output`` of shape ``(batch, seq, heads, head_dim)``.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax in float32 for numerical stability, then cast back.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class SinusoidsPositionEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal position table of shape ``(length, channels)``: sin half, then cos half."""

    def __init__(self, length, channels, max_timescale=10000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer(
            "positional_embedding",
            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
            persistent=False,
        )

    def forward(self, seqlen: int):
        """Return the first ``seqlen`` rows of the table."""
        return self.positional_embedding[:seqlen, :]


class UASAudioAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        # NOTE: stored but unused; the attention call below reads `self.attention_dropout`, which is fixed to 0.0.
        self.dropout = config.attention_dropout
        self.head_dim = self.embed_dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.config = config

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_decoder = False
        self.is_causal = False
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Input shape: ``(seq_len, embed_dim)``, with all windows/utterances packed along the first axis.

        ``cu_seqlens`` (cumulative window boundaries) is required in practice: it is used to compute
        ``max_seqlen`` and is forwarded to the FA2 varlen kernel. Non-flash implementations rely on
        ``attention_mask`` instead to stay within the windows.

        Returns:
            Tensor of shape ``(seq_len, embed_dim)``.
        """
        seq_length, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim)
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output


class UASAudioEncoderLayer(GradientCheckpointingLayer):
    """Pre-LayerNorm transformer block: self-attention then a two-layer MLP, each with a residual connection."""

    def __init__(self, config: UASAudioEncoderConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = UASAudioAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # NOTE: `dropout` and `activation_dropout` are stored but not applied in `forward`.
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): packed input of shape `(seq_len, embed_dim)`.
            cu_seqlens (`torch.Tensor`): cumulative boundaries of the attention windows inside `hidden_states`.
            attention_mask (`torch.FloatTensor`, *optional*): additive mask of size `(1, 1, seq_len, seq_len)`
                where masked elements are indicated by very large negative values, or `None`.

        Returns:
            `tuple[torch.Tensor]`: a one-element tuple holding the output of shape `(seq_len, embed_dim)`.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        # Keep fp16 activations finite.
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        return outputs


class UASAudioEncoder(PreTrainedModel):
    """
    Audio encoder: three stride-2 2D convolutions over mel chunks, a linear projection, sinusoidal positions,
    then a stack of `UASAudioEncoderLayer`s run on the packed (unpadded) frame sequence.
    """

    config: UASAudioEncoderConfig
    main_input_name = "input_features"
    input_modalities = "audio"
    _no_split_modules = ["UASAudioEncoderLayer"]
    _supports_sdpa = True

    def __init__(self, config: UASAudioEncoderConfig):
        super().__init__(config)
        self.dropout = config.dropout

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.n_window = config.n_window
        self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
        self.layers = nn.ModuleList([UASAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.ln_post = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        # Input width = channels * mel bins left after three stride-2 reductions along the frequency axis.
        self.conv_out = nn.Linear(
            config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
            config.d_model,
            bias=False,
        )
        self.n_window_infer = self.config.n_window_infer
        self.conv_chunksize = self.config.conv_chunksize
        self.post_init()

    def _freeze_parameters(self):
        """Disable gradients for every parameter of the encoder."""
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        """Return the first convolution, which consumes the mel features."""
        # Fixed: previously returned the non-existent `self.conv1`.
        return self.conv2d1

    def set_input_embeddings(self, value: nn.Module):
        """Replace the first convolution."""
        self.conv2d1 = value

    def _prepare_attention_mask(
        self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """
        Build a `(1, 1, seq, seq)` additive block-diagonal mask from `cu_seqlens`.

        Flash-attention implementations don't need a 4D mask and rely on `cu_seqlens/max_seqlen`, so `None` is
        returned for them. For the other implementations, the mask allows bidirectional attention within each
        `cu_seqlens` block and none between blocks. It approximates FA2's ragged `varlen` path but may not
        match it 100% numerically.
        """
        if "flash" in (self.config._attn_implementation or ""):
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    @auto_docstring
    def forward(
        self,
        input_features,
        feature_lens=None,
        aftercnn_lens=None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
            mel length
        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
            mel length after cnn. Currently ignored: it is always recomputed from `feature_lens`.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", False)
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else getattr(self.config, "output_hidden_states", False)
        )

        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)

        # Split each utterance (concatenated along the frame axis of `input_features`) into chunks of
        # `2 * n_window` frames; the tail chunk keeps the remainder.
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        # A zero remainder means the tail chunk is a full-length chunk.
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2

        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
        feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
        padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
            [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
            batch_first=True,
        )
        padded_feature = padded_feature.unsqueeze(1)

        # Split to chunk to avoid OOM during convolution
        padded_embeds = []
        for chunk in padded_feature.split(self.conv_chunksize, dim=0):
            padded_embed = F.gelu(self.conv2d1(chunk))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
            padded_embeds.append(padded_embed)
        padded_embed = torch.cat(padded_embeds, dim=0)
        b, c, f, t = padded_embed.size()
        padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))

        positional_embedding = (
            self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
            .unsqueeze(0)
            .to(padded_embed.dtype)
        )
        padded_embed = padded_embed + positional_embedding
        # Drop padded positions: `hidden_states` is the packed `(total_frames_after_cnn, d_model)` sequence.
        hidden_states = padded_embed[padded_mask_after_cnn]

        # Attention windows of `window_aftercnn` frames per utterance; the last window keeps the remainder.
        cu_chunk_lens = [0]
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
        for cnn_len in aftercnn_lens:
            cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
            remainder = cnn_len % window_aftercnn
            if remainder != 0:
                cu_chunk_lens += [remainder]
        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)

        # Fixed: the mask was never built/passed, so eager/SDPA attended across all windows and across
        # different utterances of the batch. FA2 gets `None` here and keeps using `cu_seqlens`.
        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                cu_seqlens,
                attention_mask=attention_mask,
            )
            hidden_states = layer_outputs[0]
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        hidden_states = self.ln_post(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
        """
        Pads a sequence of tensors to their maximum length on indicated `padding_side`.
        Then prepares a mask so that pad tokens are not attended to.

        NOTE: `padding_side` is currently ignored; padding is always applied on the right.

        Returns:
            `(padded_tensor, batch_mask.unsqueeze(1), batch_mask_after_cnn)` where the last mask assumes a single
            `(n - 1) // 2 + 1` length reduction.
        """
        max_len = tensor_len.max()
        dim = tensor_list[0].shape[0]
        padded_tensor = torch.full(
            size=(len(tensor_list), dim, max_len),
            fill_value=padding_value,
            dtype=self.dtype,
            device=tensor_list[0].device,
        )

        batch_mask = torch.zeros(
            (len(tensor_len), max_len),
            dtype=torch.long,
            device=padded_tensor.device,
        )
        for i, length in enumerate(tensor_len):
            batch_mask[i, :length] = 1
            padded_tensor[i, :, :length] = tensor_list[i]

        feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
        max_len_after_cnn = feature_lens_after_cnn.max()
        batch_mask_after_cnn = torch.zeros(
            (len(tensor_len), max_len_after_cnn),
            dtype=torch.long,
            device=padded_tensor.device,
        )
        for i, length in enumerate(feature_lens_after_cnn):
            batch_mask_after_cnn[i, :length] = 1
        return (
            padded_tensor,
            batch_mask.unsqueeze(1),
            batch_mask_after_cnn.bool(),
        )


class Adapter(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) projecting encoder features to the LLM hidden size."""

    def __init__(
        self,
        d_model: int,
        n_embd: int,
    ):
        super().__init__()
        self.audio_projector = torch.nn.Sequential(
            torch.nn.Linear(d_model, n_embd), torch.nn.GELU(), torch.nn.Linear(n_embd, n_embd)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.audio_projector(x)
        return x


def _gather_and_pad(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Keep, per batch row, the positions where ``mask`` is True, then right-pad the rows with zeros.

    Args:
        values: tensor of shape ``(batch, seq_len, hidden)``.
        mask: boolean tensor of shape ``(batch, seq_len)``.

    Returns:
        Tensor of shape ``(batch, max_kept, hidden)`` with the same dtype/device as ``values``; ``max_kept`` is
        the largest number of kept positions over the batch (0 if nothing is kept).
    """
    kept = [row[row_mask] for row, row_mask in zip(values, mask)]
    if not kept:  # empty batch
        return values.new_zeros((0, 0, values.shape[-1]))
    return nn.utils.rnn.pad_sequence(kept, batch_first=True)


class UASAudioForCausalLM(PreTrainedModel, GenerationMixin):
    """Qwen2 causal LM whose audio-token embeddings are replaced by `UASAudioEncoder` + `Adapter` features."""

    config_class = UASAudioConfig
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

    def __init__(self, config: UASAudioConfig):
        super().__init__(config)
        dtype = getattr(config, "dtype", None)
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        self.bf16 = dtype == torch.bfloat16

        self.llm = Qwen2ForCausalLM(config.text_config)
        self.audio_encoder = UASAudioEncoder(config.audio_encoder_config)
        d_model = config.audio_encoder_config.d_model
        self.adapter = Adapter(
            d_model,
            config.text_config.hidden_size,
        )
        self.audio_token = config.audio_token

        if self.bf16:
            self.audio_encoder = self.audio_encoder.bfloat16()
            self.adapter = self.adapter.bfloat16()

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mels=None,
        mel_masks=None,
        past_key_values=None,
        **kwargs
    ):
        # If past_key_values are provided, we are in the generation phase and should not process audio inputs again
        if past_key_values is not None:
            outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                **kwargs
            )
        else:
            # First get the text embeddings for the input_ids, then replace audio token positions with audio embeddings
            hidden_states = self.embedding_with_audio_tokens(input_ids, mels, mel_masks)
            outputs = self.llm(
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
                past_key_values=None,
                **kwargs
            )
        return outputs

    def embedding_with_audio_tokens(
        self,
        input_ids,
        mels,
        mel_masks
    ):
        """
        Get input embeddings for the LLM, replacing audio token positions with audio features from the audio encoder.

        The number of `audio_token` positions in `input_ids` must equal the number of encoded audio frames.
        """
        hidden_states = self.embeddings(input_ids)
        if mels is None:
            return hidden_states
        audio_embeddings = self.audio_encoding(mels, mel_masks)  # data -> feature
        audio_embeddings = self.adapter(audio_embeddings)
        audio_mask = input_ids == self.audio_token
        # The adapter may run in bf16 while the LLM embeddings use another dtype; index-put requires a match.
        hidden_states[audio_mask] = audio_embeddings.to(hidden_states.dtype)
        return hidden_states

    def audio_encoding(
        self,
        audio_features: torch.Tensor,
        audio_features_mask: torch.Tensor,
        output_hidden_states: bool = False
    ):
        """
        Encode audio features into embeddings.

        Args:
            audio_features: Audio features tensor, permuted/masked as `(batch, frames, mel_bins)` below — presumably
                given as `(batch, mel_bins, frames)`; confirm against the feature extractor.
            audio_features_mask: Audio features mask over frames, `(batch, frames)`; valid frames are non-zero.
            output_hidden_states: Whether to return hidden states from all encoder layers

        Returns:
            If output_hidden_states=False: audio_features_encoded tensor
            If output_hidden_states=True: BaseModelOutput with last_hidden_state and hidden_states
        """
        feature_lens = audio_features_mask.sum(-1).long()  # [batch_size]
        # Keep only valid frames and concatenate all utterances: (mel_bins, total_valid_frames).
        input_features = audio_features.permute(0, 2, 1)[audio_features_mask.bool()].permute(1, 0)

        audio_encoder_outputs = self.audio_encoder(
            input_features,
            feature_lens=feature_lens,
            output_hidden_states=output_hidden_states,
            return_dict=output_hidden_states,  # Only return dict when we need hidden states
        )

        if output_hidden_states:
            # When output_hidden_states=True, we get BaseModelOutput
            return audio_encoder_outputs
        # When output_hidden_states=False, we get tuple (hidden_states, ...)
        # Extract the first element (hidden_states tensor) for backward compatibility
        if isinstance(audio_encoder_outputs, tuple):
            return audio_encoder_outputs[0]
        return audio_encoder_outputs

    @property
    def embeddings(self):
        """Return the LLM's token embedding layer (`self.llm.model.embed_tokens`)."""
        return self.llm.model.embed_tokens

    def forward_with_detailed_outputs(
        self,
        input_ids=None,
        attention_mask=None,
        mels=None,
        mel_masks=None,
        past_key_values=None,
        output_hidden_states: bool = True,
        **kwargs
    ):
        """
        Forward pass that returns detailed outputs including:
        - Audio encoder final output
        - Audio features after projector (adapter)
        - Text embedding features
        - Hidden states from each layer (separated for audio and text)

        Args:
            input_ids: Input token ids
            attention_mask: Attention mask
            mels: Audio mel features
            mel_masks: Audio mel masks
            past_key_values: Past key values for generation
            output_hidden_states: Whether to return hidden states from all layers
            **kwargs: Additional arguments

        Returns:
            dict containing:
            - audio_encoder_output: Final output from audio encoder
            - audio_features_after_adapter: Audio features after projector/adapter
            - text_embeddings: Text embedding features (before audio replacement); if any audio tokens are
              present, audio positions are removed and rows are right-padded with zeros
            - audio_encoder_hidden_states: Tuple of hidden states from each audio encoder layer
            - llm_hidden_states: Tuple of hidden states from each LLM layer (mixed audio+text)
            - llm_hidden_states_text_only: Tuple of text-only hidden states from each LLM layer (zero-padded)
            - llm_hidden_states_audio_only: Tuple of audio-only hidden states from each LLM layer (zero-padded)
            - llm_outputs: Full LLM outputs (CausalLMOutputWithPast)
        """
        # Pure text embeddings (before audio replacement); audio positions are removed further below.
        text_embeddings_pure = self.embeddings(input_ids)

        audio_encoder_output = None
        audio_features_after_adapter = None
        audio_encoder_hidden_states = None

        # Separate copy for the LLM input so the text embeddings are not overwritten with audio features.
        input_embeddings_for_llm = text_embeddings_pure.clone()

        # Identify audio token positions (even if no audio is provided, audio tokens may exist in input_ids)
        audio_mask = input_ids == self.audio_token

        if mels is not None:
            audio_encoder_outputs = self.audio_encoding(
                mels,
                mel_masks,
                output_hidden_states=output_hidden_states
            )
            if output_hidden_states:
                audio_encoder_output = audio_encoder_outputs.last_hidden_state
                audio_encoder_hidden_states = audio_encoder_outputs.hidden_states
            else:
                audio_encoder_output = audio_encoder_outputs
                audio_encoder_hidden_states = None

            audio_features_after_adapter = self.adapter(audio_encoder_output)
            # Cast so the scatter works when the adapter dtype differs from the LLM embedding dtype.
            input_embeddings_for_llm[audio_mask] = audio_features_after_adapter.to(input_embeddings_for_llm.dtype)

        # Remove audio positions from the returned text embeddings (deleted, not zeroed); if there are no
        # audio tokens, text_embeddings_pure stays unchanged.
        if audio_mask.any():
            text_embeddings_pure = _gather_and_pad(text_embeddings_pure, ~audio_mask)

        if past_key_values is not None:
            # Incremental decoding
            llm_outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                **kwargs
            )
        else:
            # First step: use combined embeddings (may include audio features)
            llm_outputs = self.llm(
                inputs_embeds=input_embeddings_for_llm,
                attention_mask=attention_mask,
                past_key_values=None,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                **kwargs
            )

        llm_hidden_states = llm_outputs.hidden_states if output_hidden_states else None

        # Split each layer's hidden states into text-only and audio-only parts (positions deleted, rows padded).
        llm_hidden_states_text_only = None
        llm_hidden_states_audio_only = None
        if output_hidden_states and llm_hidden_states is not None and audio_mask is not None:
            llm_hidden_states_text_only = tuple(
                _gather_and_pad(layer_hidden_states, ~audio_mask) for layer_hidden_states in llm_hidden_states
            )
            llm_hidden_states_audio_only = tuple(
                _gather_and_pad(layer_hidden_states, audio_mask) for layer_hidden_states in llm_hidden_states
            )

        return {
            "audio_encoder_output": audio_encoder_output,
            "audio_features_after_adapter": audio_features_after_adapter,
            "text_embeddings": text_embeddings_pure,  # Return text embeddings with audio parts removed
            "audio_encoder_hidden_states": audio_encoder_hidden_states,
            "llm_hidden_states": llm_hidden_states,
            "llm_hidden_states_text_only": llm_hidden_states_text_only,
            "llm_hidden_states_audio_only": llm_hidden_states_audio_only,
            "llm_outputs": llm_outputs,
        }

    def generate(
        self,
        input_ids,
        attention_mask=None,
        mels=None,
        mel_masks=None,
        generation_config=None,
        **generate_kwargs
    ):
        """
        New implementation of the generate method to support audio inputs.

        This method will:
        1. Handle the initial processing of audio inputs;
        2. Call the underlying LLM's generate method with the appropriate embeddings;
        3. The incremental decoding will be handled by the LLM's generate method using past_key_values.
        """
        # Process audio inputs and get combined embeddings for the initial step
        input_embeddings = self.embedding_with_audio_tokens(input_ids, mels, mel_masks)

        # The LLM's generate drives the decoding loop from `inputs_embeds`; audio is encoded only once above.
        outputs = self.llm.generate(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            generation_config=generation_config,
            use_cache=True,
            **generate_kwargs
        )
        return outputs


class UASAudioEncoderOnly(PreTrainedModel):
    """
    UASAudio encoder-only model that contains only the audio encoder and adapter.

    Input: audio features
    Output: features processed by encoder and adapter
    """

    config_class = UASAudioEncoderOnlyConfig
    main_input_name = "input_features"
    input_modalities = "audio"

    def __init__(self, config: UASAudioEncoderOnlyConfig):
        super().__init__(config)
        # Fixed ordering: the bf16 default previously could never apply because `config.dtype` was read first.
        dtype = getattr(config, "dtype", torch.bfloat16)
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        self.bf16 = dtype == torch.bfloat16

        self.audio_encoder = UASAudioEncoder(config.audio_encoder_config)
        d_model = config.audio_encoder_config.d_model
        self.adapter = Adapter(
            d_model,
            config.hidden_size,
        )

        if self.bf16:
            self.audio_encoder = self.audio_encoder.bfloat16()
            self.adapter = self.adapter.bfloat16()

        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """
        Forward pass through audio encoder and adapter.

        Args:
            input_features: Packed audio features of shape `(num_mel_bins, total_frames)`: all utterances are
                concatenated along the frame axis (the encoder splits `input_features.T` by frames).
            feature_lens: Tensor of shape `(batch_size,)` with the number of frames of each utterance; required
                in practice, since the encoder derives all chunk lengths from it.
            **kwargs: Additional arguments passed to audio encoder

        Returns:
            torch.Tensor: Features processed by encoder and adapter, `(total_frames_after_cnn, hidden_size)`.
        """
        audio_features_encoded = self.audio_encoder(
            input_features,
            feature_lens=feature_lens,
            **kwargs
        )

        # Handle tuple output from encoder (backward compatibility)
        if isinstance(audio_features_encoded, tuple):
            audio_features_encoded = audio_features_encoded[0]
        elif hasattr(audio_features_encoded, "last_hidden_state"):
            audio_features_encoded = audio_features_encoded.last_hidden_state

        # Apply adapter (projector)
        output_features = self.adapter(audio_features_encoded)

        return output_features