# ChristophSchuhmann's picture
# Upload folder using huggingface_hub
# ce6d303 verified
from typing import Optional, List, Union, Tuple, Any
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.utils.auto_docstring import auto_docstring
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3DecoderLayer
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
from src.configuration_moss_audio import MossAudioEncoderConfig, MossAudioConfig
class SinusoidsPositionEmbedding(nn.Module):
    """Sinusoidal position table that is generated on demand.

    Only the inverse timescales are stored (as a non-persistent buffer, so
    they are not written to the state dict); the actual table is rebuilt for
    the requested length on every call.

    Args:
        num_positions: Accepted for interface compatibility; the visible code
            never reads it because the table is built per call.
        embedding_dim: Width of the embedding. ``embedding_dim // 2``
            frequencies are created, so the output width is
            ``2 * (embedding_dim // 2)``.
    """

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        half_dim = embedding_dim // 2
        # Log-spaced frequencies: the first is 1, the last is 1 / 10000.
        decay = math.log(10000.0) / (half_dim - 1)
        frequencies = torch.exp(-decay * torch.arange(half_dim).float())
        self.register_buffer("inv_timescales", frequencies, persistent=False)

    def forward(self, seq_len: int, device: torch.device):
        """Return a ``[1, seq_len, 2 * (embedding_dim // 2)]`` position table.

        The first half of the last dimension holds the sines, the second half
        the cosines. The result uses the dtype of the ``inv_timescales``
        buffer and lives on ``device``.
        """
        steps = torch.arange(seq_len, device=device, dtype=self.inv_timescales.dtype)
        angles = steps[:, None] * self.inv_timescales[None, :]
        table = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
        return table[None]
class MossAudioEncoder(nn.Module):
    """Audio encoder with conv-stem downsampling and Whisper transformer layers.

    Mel features are cut into chunks of ``2 * config.n_window`` frames. Every
    chunk is encoded independently: three stride-2 convolutions, a linear
    projection, sinusoidal positions that restart at 0 for each chunk, then
    the ``WhisperEncoderLayer`` stack. The valid (non-padding) tokens of all
    chunks are finally concatenated into a single sequence.

    Hidden states produced right after the layers listed in
    ``config.deepstack_encoder_layer_indexes`` can be returned in addition to
    the final output ("deepstack" states).
    """

    def __init__(self, config: MossAudioEncoderConfig):
        super().__init__()
        self.config = config
        self.gelu = nn.GELU()
        self.conv1 = nn.Conv2d(
            1,
            config.downsample_hidden_size,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
        )
        self.conv2 = nn.Conv2d(
            config.downsample_hidden_size,
            config.downsample_hidden_size,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
        )
        self.conv3 = nn.Conv2d(
            config.downsample_hidden_size,
            config.downsample_hidden_size,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
        )
        # 128 mel bins / 8 = 16 after 3 convs with stride=2
        self.stem_proj = nn.Linear(config.downsample_hidden_size * 16, config.d_model)
        self.embed_positions = SinusoidsPositionEmbedding(
            config.max_source_positions, config.d_model
        )
        self.layers = nn.ModuleList(
            [WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
        )
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        # Only project when the requested output width differs from d_model.
        self.out_proj = (
            nn.Linear(config.d_model, config.output_dim, bias=False)
            if config.output_dim != config.d_model
            else nn.Identity()
        )
        self.deepstack_encoder_layer_indexes = list(
            config.deepstack_encoder_layer_indexes or []
        )
        # layer index -> slot in the ordered list of captured hidden states
        self._deepstack_capture_map = {
            layer_idx: capture_idx
            for capture_idx, layer_idx in enumerate(self.deepstack_encoder_layer_indexes)
        }
        self.n_window = int(config.n_window)
        self.chunk_frames = int(self.n_window * 2)
        self.conv_chunksize = int(config.conv_chunksize)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first convolution's weights."""
        return self.conv1.weight.dtype

    @staticmethod
    def _compute_downsampled_length(lengths: torch.Tensor) -> torch.Tensor:
        """Map frame counts to token counts after the three stride-2 convs."""

        def conv_out_len(L):
            # kernel 3, stride 2, padding 1  ->  ceil(L / 2) for L >= 1
            return (L - 1) // 2 + 1

        return conv_out_len(conv_out_len(conv_out_len(lengths)))

    def _encode_chunk_batch(
        self,
        input_features: torch.Tensor,
        seq_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Encode a batch of (already padded) chunks through the conv stem and
        transformer layers.

        Args:
            input_features: ``[B, n_mels, T]`` padded chunks (a 2-D input is
                treated as a batch of one).
            seq_lengths: ``[B]`` number of valid frames in each chunk.

        Returns:
            ``(last_hidden, ordered_deepstack_hidden_states)``. Both are
            trimmed along time to the longest downsampled length in this
            batch. The deepstack states are captured before the final layer
            norm and passed through ``out_proj`` when it is a real projection.
        """
        if input_features.dim() == 2:
            input_features = input_features.unsqueeze(0)
        downsampled_lengths = self._compute_downsampled_length(seq_lengths)
        # [B, n_mels, T] -> [B, 1, n_mels, T]
        x = input_features.unsqueeze(1)
        x = self.gelu(self.conv1(x))
        x = self.gelu(self.conv2(x))
        x = self.gelu(self.conv3(x))
        # [B, C, F, T] -> [B, T, C*F]
        x = x.permute(0, 3, 1, 2).contiguous().flatten(2)
        x = self.stem_proj(x)
        max_len = int(downsampled_lengths.max().item())
        if x.size(1) > max_len:
            x = x[:, :max_len, :]
        positions = self.embed_positions(x.shape[1], x.device)
        x = x + positions.to(x.dtype)
        padding_mask = (
            torch.arange(x.size(1), device=x.device)[None, :] >= downsampled_lengths[:, None]
        )
        # Additive mask: 0 on valid tokens, the most negative value on padding.
        attention_mask = (1.0 - (~padding_mask).to(dtype=x.dtype)) * torch.finfo(x.dtype).min
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
        deepstack_hidden_states: List[Optional[torch.Tensor]] = [None] * len(
            self.deepstack_encoder_layer_indexes
        )
        for layer_idx, layer in enumerate(self.layers):
            x = layer(
                x,
                attention_mask,
                layer_head_mask=None,
                output_attentions=False,
            )[0]
            capture_idx = self._deepstack_capture_map.get(layer_idx)
            if capture_idx is not None:
                deepstack_hidden_states[capture_idx] = x
        x = self.layer_norm(x)
        x = self.out_proj(x)
        # Slots whose layer index never occurred stay None and are dropped;
        # forward() detects the resulting count mismatch.
        ordered_deepstack_hidden_states = [
            h for h in deepstack_hidden_states if h is not None
        ]
        if not isinstance(self.out_proj, nn.Identity):
            ordered_deepstack_hidden_states = [
                self.out_proj(h) for h in ordered_deepstack_hidden_states
            ]
        return x, ordered_deepstack_hidden_states

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: Optional[torch.Tensor] = None,
        output_deepstack_hidden_states: bool = True,
    ) -> BaseModelOutputWithPast:
        """Encode mel features into one packed token sequence.

        Args:
            input_features: ``[n_mels, T_total]`` (all samples concatenated
                along time) or ``[B, n_mels, T]`` (padded batch).
            feature_lens: Valid frame count per sample. Defaults to the full
                time length of every sample. Zero-length samples are allowed
                and contribute no tokens.
            output_deepstack_hidden_states: Whether to also return the
                captured intermediate states.

        Returns:
            ``BaseModelOutputWithPast`` whose ``last_hidden_state`` is
            ``[1, N_valid, output_dim]`` and whose ``hidden_states`` is either
            ``None`` or a tuple with one ``[1, N_valid, output_dim]`` tensor
            per configured deepstack layer.

        Raises:
            ValueError: If ``input_features`` is neither 2-D nor 3-D.
            RuntimeError: If the number of captured deepstack states differs
                from the number of configured layer indexes.
        """
        if input_features.dim() == 3:
            if feature_lens is None:
                feature_lens = torch.full(
                    (input_features.size(0),),
                    input_features.size(-1),
                    dtype=torch.long,
                    device=input_features.device,
                )
            else:
                feature_lens = feature_lens.to(
                    device=input_features.device, dtype=torch.long
                )
            # One host sync for all lengths instead of one `.item()` per sample.
            lengths = feature_lens.tolist()
            valid_chunks = [
                input_features[i, :, : lengths[i]]
                for i in range(int(input_features.shape[0]))
            ]
            if valid_chunks:
                input_features = torch.cat(valid_chunks, dim=1)
            else:
                # Empty batch: keep the mel dimension, no frames.
                input_features = input_features.new_empty((input_features.size(1), 0))
        elif input_features.dim() != 2:
            raise ValueError(
                f"Expected [n_mels, T] or [B, n_mels, T], got {tuple(input_features.shape)}."
            )
        if feature_lens is None:
            feature_lens = torch.tensor(
                [int(input_features.shape[1])],
                device=input_features.device,
                dtype=torch.long,
            )
        else:
            feature_lens = feature_lens.to(
                device=input_features.device, dtype=torch.long
            )

        chunk_frames = int(self.chunk_frames)
        # Integer ceil-division; exact for any length, unlike float32 ceil.
        chunk_num = torch.div(
            feature_lens + (chunk_frames - 1), chunk_frames, rounding_mode="floor"
        )
        total_chunks = int(chunk_num.sum().item())
        num_deepstack = len(self.deepstack_encoder_layer_indexes)
        want_deepstack = output_deepstack_hidden_states and num_deepstack > 0

        if total_chunks == 0:
            # Nothing to encode (no samples, or only zero-length samples).
            empty = torch.empty(
                (1, 0, self.config.output_dim),
                device=input_features.device,
                dtype=self.dtype,
            )
            return BaseModelOutputWithPast(
                last_hidden_state=empty,
                hidden_states=(
                    tuple(empty.clone() for _ in range(num_deepstack))
                    if want_deepstack
                    else None
                ),
            )

        chunk_lengths = torch.full(
            (total_chunks,),
            chunk_frames,
            dtype=torch.long,
            device=feature_lens.device,
        )
        # The last chunk of each sample holds the remainder. Samples without
        # any chunk are skipped: their "tail index" would alias the previous
        # sample's tail (or be -1) and corrupt its length.
        has_chunks = chunk_num > 0
        tail_chunk_index = chunk_num.cumsum(0) - 1
        tail_lengths = feature_lens % chunk_frames
        chunk_lengths[tail_chunk_index[has_chunks]] = tail_lengths[has_chunks]
        # A remainder of 0 means the tail chunk is exactly full.
        chunk_lengths[chunk_lengths == 0] = chunk_frames

        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        padded_feature = nn.utils.rnn.pad_sequence(
            chunk_list, batch_first=True
        ).transpose(1, 2)
        feature_lens_after_cnn = self._compute_downsampled_length(chunk_lengths)
        t_down_max = int(feature_lens_after_cnn.max().item())
        # [num_chunks, t_down_max] — True where a chunk has a real token.
        padded_mask_after_cnn = (
            torch.arange(t_down_max, device=padded_feature.device)[None, :]
            < feature_lens_after_cnn[:, None]
        )

        padded_embeds: List[torch.Tensor] = []
        deepstack_padded_embeds: List[List[torch.Tensor]] = [
            [] for _ in range(num_deepstack)
        ]
        # Encode at most `conv_chunksize` chunks per pass.
        for feat_chunk, len_chunk in zip(
            padded_feature.split(self.conv_chunksize, dim=0),
            chunk_lengths.split(self.conv_chunksize, dim=0),
        ):
            out, deepstack_outs = self._encode_chunk_batch(feat_chunk, len_chunk)
            # Each pass is trimmed to its own longest chunk; pad the time
            # axis back to the global maximum so passes can be concatenated.
            if out.shape[1] < t_down_max:
                out = F.pad(out, (0, 0, 0, t_down_max - out.shape[1]))
            padded_embeds.append(out)
            if want_deepstack:
                if len(deepstack_outs) != num_deepstack:
                    raise RuntimeError(
                        "Deepstack output count does not match configured layer indexes."
                    )
                for capture_idx, ds in enumerate(deepstack_outs):
                    if ds.shape[1] < t_down_max:
                        ds = F.pad(ds, (0, 0, 0, t_down_max - ds.shape[1]))
                    deepstack_padded_embeds[capture_idx].append(ds)

        padded_embed = torch.cat(padded_embeds, dim=0)
        valid_tokens = padded_embed[padded_mask_after_cnn]  # [N_valid, D]
        last_hidden_state = valid_tokens.unsqueeze(0)  # [1, N_valid, D]

        deepstack_states: Optional[Tuple[torch.Tensor, ...]] = None
        if want_deepstack:
            deepstack_states = tuple(
                torch.cat(chunks, dim=0)[padded_mask_after_cnn].unsqueeze(0)
                for chunks in deepstack_padded_embeds
            )
        return BaseModelOutputWithPast(
            last_hidden_state=last_hidden_state,
            hidden_states=deepstack_states,
        )
class GatedMLP(nn.Module):
    """Bias-free gated feed-forward block.

    Computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``, mapping the last
    dimension from ``input_size`` to ``output_size`` through a hidden width of
    ``hidden_size``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.gate_proj = nn.Linear(
            in_features=input_size, out_features=hidden_size, bias=False
        )
        self.up_proj = nn.Linear(
            in_features=input_size, out_features=hidden_size, bias=False
        )
        self.down_proj = nn.Linear(
            in_features=hidden_size, out_features=output_size, bias=False
        )
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated projection to the last dimension of ``x``."""
        gate = self.act_fn(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)
# Shared `PreTrainedModel` base for the MOSS audio models in this file.
# The body only declares class-level attributes; how each one is interpreted
# is decided by the `transformers` base class, not by code in this file.
# NOTE(review): the per-attribute remarks below describe the usual
# `transformers` conventions — confirm against the installed version.
@auto_docstring
class MossAudioPreTrainedModel(PreTrainedModel):
    # Configuration class associated with this model family.
    config_class = MossAudioConfig
    config: MossAudioConfig
    # Empty prefix: no sub-module is designated as "the" base model.
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    # Presumably keeps each Qwen3 decoder layer on a single device when the
    # model is sharded — TODO confirm.
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention-backend capability flags.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    # Maps the "hidden_states" output name to the Qwen3 decoder layer class.
    _can_record_outputs = {"hidden_states": Qwen3DecoderLayer}
class MossAudioModel(MossAudioPreTrainedModel, GenerationMixin):
    """Audio-conditioned causal language model.

    Audio is encoded by ``MossAudioEncoder``, projected into the language
    model's embedding space by ``audio_adapter`` and written over the token
    embeddings at the positions flagged by ``audio_input_mask``. Intermediate
    encoder states ("deepstack") are projected by
    ``deepstack_audio_merger_list`` and added to the outputs of the first
    decoder layers through temporary forward hooks.
    """

    config_class = MossAudioConfig
    _tied_weights_keys: List[str] = []

    def __init__(self, config: MossAudioConfig):
        super().__init__(config)
        self.audio_encoder = MossAudioEncoder(config.audio_config)
        self.language_model = Qwen3Model(config.language_config)
        self.audio_adapter = GatedMLP(
            input_size=config.audio_config.output_dim,
            hidden_size=config.adapter_hidden_size,
            output_size=config.language_config.hidden_size,
        )
        # One merger per injected deepstack state, optionally capped by
        # `config.deepstack_num_inject_layers`.
        deepstack_k = len(getattr(config.audio_config, "deepstack_encoder_layer_indexes", []) or [])
        if config.deepstack_num_inject_layers is not None:
            deepstack_k = min(deepstack_k, int(config.deepstack_num_inject_layers))
        self.deepstack_audio_merger_list = nn.ModuleList(
            [
                GatedMLP(
                    input_size=config.audio_config.output_dim,
                    hidden_size=config.adapter_hidden_size,
                    output_size=config.language_config.hidden_size,
                )
                for _ in range(deepstack_k)
            ]
        )
        self.vocab_size = config.language_config.vocab_size
        self.lm_head = nn.Linear(config.language_config.hidden_size, self.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the language model's token embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the language model's token embedding module."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def get_audio_features(self, input_features, feature_lens):
        """Run the audio encoder.

        Returns:
            ``(last_hidden_state, deepstack)`` where ``deepstack`` is a list of
            the encoder's extra hidden states, or ``None`` if it returned none.
        """
        audio_outputs = self.audio_encoder(
            input_features=input_features,
            feature_lens=feature_lens,
            output_deepstack_hidden_states=True,
        )
        deepstack = list(audio_outputs.hidden_states) if audio_outputs.hidden_states is not None else None
        return audio_outputs.last_hidden_state, deepstack

    def _apply_deepstack_to_hidden_states(
        self,
        hidden_states: torch.Tensor,
        audio_input_mask: torch.Tensor,
        deepstack_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Return a copy of ``hidden_states`` with ``deepstack_embeds`` added
        at the positions selected by ``audio_input_mask``.

        The embeddings are flattened to ``[-1, D]`` and consumed in mask order.
        """
        # Force a boolean mask: an integer tensor would be interpreted as
        # indices by `hs[mask]` instead of as a selection mask.
        audio_input_mask = audio_input_mask.to(device=hidden_states.device, dtype=torch.bool)
        deepstack_embeds = deepstack_embeds.to(hidden_states.device, hidden_states.dtype)
        flat = deepstack_embeds.reshape(-1, deepstack_embeds.shape[-1])
        hs = hidden_states.clone()
        hs[audio_input_mask] = hs[audio_input_mask] + flat
        return hs

    def _register_llm_deepstack_hooks(
        self,
        audio_input_mask: torch.Tensor,
        deepstack_audio_embeds: List[torch.Tensor],
    ):
        """Attach forward hooks that add the k-th deepstack embedding to the
        output of the k-th decoder layer.

        Returns:
            The list of hook handles; the caller must ``remove()`` them. If
            registration fails part-way, the hooks registered so far are
            removed before the exception propagates.

        Raises:
            RuntimeError: If the language model has no ``layers`` attribute.
        """
        if deepstack_audio_embeds is None or len(deepstack_audio_embeds) == 0:
            return []
        layers = getattr(self.language_model, "layers", None)
        if layers is None:
            raise RuntimeError("Qwen3Model does not expose `.layers`; cannot register DeepStack hooks.")

        def _make_llm_hook(embeds: torch.Tensor):
            # Factory so every hook closes over its own embedding tensor.
            def _hook(_module, _inputs, _output):
                if isinstance(_output, (tuple, list)):
                    new_hs = self._apply_deepstack_to_hidden_states(
                        _output[0], audio_input_mask, embeds
                    )
                    return (new_hs,) + tuple(_output[1:])
                return self._apply_deepstack_to_hidden_states(
                    _output, audio_input_mask, embeds
                )

            return _hook

        handles = []
        try:
            # zip stops at the shorter of (layers, embeddings).
            for layer, embeds in zip(layers, deepstack_audio_embeds):
                handles.append(layer.register_forward_hook(_make_llm_hook(embeds)))
        except Exception:
            # Do not leave stray hooks on the shared decoder layers.
            for handle in handles:
                handle.remove()
            raise
        return handles

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        audio_data: Optional[torch.FloatTensor] = None,
        audio_data_seqlens: Optional[torch.Tensor] = None,
        audio_input_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model, optionally conditioned on audio.

        Args:
            audio_data: Mel features forwarded to the audio encoder.
            audio_data_seqlens: Valid frame count per audio sample.
            audio_input_mask: Mask over token positions (same leading shape as
                ``input_ids``) marking where audio embeddings are written. Any
                dtype is accepted; non-zero means "audio position".
            labels: If given, a next-token cross-entropy loss is computed on
                the shifted logits.
            Remaining arguments are forwarded to the language model.

        Returns:
            ``CausalLMOutputWithPast`` or, when ``return_dict`` is false, a
            tuple ``(loss?, logits, *rest_of_language_model_outputs)``.

        Raises:
            ValueError: If ``audio_data`` is given without ``audio_input_mask``
                or the number of masked positions differs from the number of
                audio tokens (final or deepstack).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        deepstack_audio_embeds: List[torch.Tensor] = []
        if audio_data is not None:
            if audio_input_mask is None:
                raise ValueError("audio_input_mask is required when audio_data is provided.")
            # masked_scatter_ needs a boolean mask on the target's device.
            audio_input_mask = audio_input_mask.to(device=inputs_embeds.device, dtype=torch.bool)
            audio_embeds, deepstack = self.get_audio_features(audio_data, audio_data_seqlens)
            audio_embeds = self.audio_adapter(audio_embeds)
            audio_token_count = int(audio_input_mask.sum().item())
            if audio_token_count != int(audio_embeds.shape[1]):
                raise ValueError(
                    f"Audio token count mismatch: audio_input_mask has {audio_token_count} audio tokens, "
                    f"but audio_embeds has length {int(audio_embeds.shape[1])}."
                )
            # masked_scatter_ also requires matching dtype/device for the source.
            audio_embeds = audio_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            mask_expanded = audio_input_mask.unsqueeze(-1).expand_as(inputs_embeds)
            # Clone so the caller's / embedding layer's tensor is not mutated.
            inputs_embeds = inputs_embeds.clone()
            inputs_embeds.masked_scatter_(mask_expanded, audio_embeds)
            if deepstack is not None and len(self.deepstack_audio_merger_list) > 0:
                for i, x in enumerate(deepstack[: len(self.deepstack_audio_merger_list)]):
                    ds = self.deepstack_audio_merger_list[i](x)
                    if int(ds.shape[1]) != audio_token_count:
                        raise ValueError(
                            f"DeepStack audio seq_len mismatch at index {i}: "
                            f"expected {audio_token_count}, got {int(ds.shape[1])}."
                        )
                    deepstack_audio_embeds.append(ds)

        # Returns [] when there is nothing to inject; cleans up after itself
        # if registration fails.
        hook_handles = self._register_llm_deepstack_hooks(audio_input_mask, deepstack_audio_embeds)
        try:
            outputs = self.language_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs,
            )
        finally:
            # Hooks are per-call state; never leave them on the shared layers.
            for h in hook_handles:
                h.remove()

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Fall back to the PyTorch default when the config has no
            # `ignore_index` attribute.
            loss_fct = nn.CrossEntropyLoss(ignore_index=getattr(self.config, "ignore_index", -100))
            # Use the real logits width so a replaced/resized lm_head still works.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        When ``cache_position`` shows that earlier positions were already
        processed, only the last token (and last position id) is passed and
        the audio inputs are dropped. Otherwise the audio inputs are taken
        from ``kwargs``.
        """
        position_ids = kwargs.get("position_ids")
        if cache_position is not None and cache_position[0] > 0:
            # Continuation step: only the newest token is passed on.
            input_ids = input_ids[:, -1:]
            if position_ids is not None:
                position_ids = position_ids[:, -1:]
            audio_data = None
            audio_input_mask = None
            audio_data_seqlens = None
        else:
            audio_data = kwargs.get("audio_data")
            audio_input_mask = kwargs.get("audio_input_mask")
            audio_data_seqlens = kwargs.get("audio_data_seqlens")
        # Caller-supplied embeddings are only usable before a cache exists.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "audio_data": audio_data,
                "audio_input_mask": audio_input_mask,
                "audio_data_seqlens": audio_data_seqlens,
            }
        )
        return model_inputs
# Names exported by `from <this module> import *`: the two configuration
# classes imported from `src.configuration_moss_audio` and the top-level
# audio-conditioned language model defined in this file.
__all__ = [
    "MossAudioEncoderConfig",
    "MossAudioConfig",
    "MossAudioModel",
]