"""MLX-native MossAudioEncoder.

Direct port of src/modeling_moss_audio.py:36-155 (MossAudioEncoder).
Adapted from ml-explore/mlx-examples/whisper/mlx_whisper/whisper.py with:

- 3x Conv2d stride-2 stem (instead of Whisper's 2x Conv1d)
- Pre-existing HF Whisper attribute names (q_proj/k_proj/v_proj/out_proj,
  fc1/fc2, self_attn_layer_norm/final_layer_norm) so the weight remap is
  near-identity
- DeepStack taps: capture the hidden state AFTER the layers listed in
  deepstack_layer_indexes
- feature_lens-based padding mask
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn


# ---- helpers ----------------------------------------------------------


def sinusoids(length: int, channels: int, max_timescale: float = 10000.0) -> mx.array:
    """Return Whisper-style sinusoidal position embeddings, shape (length, channels).

    The first ``channels // 2`` columns are sines and the remaining columns are
    cosines, over geometrically spaced timescales up to ``max_timescale``.
    Matches mlx-examples whisper.
    """
    assert channels % 2 == 0
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = mx.exp(-log_timescale_increment * mx.arange(channels // 2))
    scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
    return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)


# ---- attention ------------------------------------------------------


class WhisperAttention(nn.Module):
    """HF-Whisper-style multi-head self-attention.

    Scaling convention: ``1/sqrt(head_dim)`` is applied to Q only (not split
    between Q and K as mlx-examples does). Attribute names match HF so the
    weight remap is identity: q_proj/k_proj/v_proj/out_proj.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model == self.head_dim * n_heads
        # HF Whisper: q/v/out have bias; k does not.
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Self-attend over ``x`` of shape (B, T, D); returns (B, T, D).

        ``mask``, if given, is ADDED to the (B, H, T, T) attention logits before
        the softmax, so it must broadcast to that shape.
        """
        B, T, D = x.shape
        q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)

        scale = self.head_dim ** -0.5
        attn = (q * scale) @ k.transpose(0, 1, 3, 2)  # (B, H, T, T)
        if mask is not None:
            attn = attn + mask
        w = mx.softmax(attn, axis=-1, precise=True)
        out = (w @ v).transpose(0, 2, 1, 3).reshape(B, T, D)
        return self.out_proj(out)


# ---- encoder layer --------------------------------------------------


class WhisperEncoderBlock(nn.Module):
    """Pre-LN Whisper encoder block. Matches transformers.WhisperEncoderLayer.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        ffn_dim: hidden width of the fc1 -> gelu -> fc2 feed-forward.
        eps: epsilon for both LayerNorms (default 1e-5, the value previously
            used implicitly).
    """

    def __init__(self, d_model: int, n_heads: int, ffn_dim: int, eps: float = 1e-5):
        super().__init__()
        self.self_attn = WhisperAttention(d_model, n_heads)
        self.self_attn_layer_norm = nn.LayerNorm(d_model, eps=eps)
        self.fc1 = nn.Linear(d_model, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model, eps=eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        # Attention sub-layer with residual.
        h = self.self_attn_layer_norm(x)
        x = x + self.self_attn(h, mask=mask)
        # Feed-forward sub-layer with residual.
        h = self.final_layer_norm(x)
        x = x + self.fc2(nn.gelu(self.fc1(h)))
        return x


# ---- encoder --------------------------------------------------------


@dataclass
class EncoderConfig:
    num_mel_bins: int = 128
    downsample_hidden_size: int = 480
    d_model: int = 1280
    n_heads: int = 20
    ffn_dim: int = 5120
    n_layers: int = 32
    max_source_positions: int = 1500
    layer_norm_eps: float = 1e-5
    output_dim: int = 1280
    # 0-based layer indices whose outputs are returned as DeepStack taps.
    deepstack_layer_indexes: List[int] = field(default_factory=lambda: [8, 16, 24])


class MossAudioEncoderMLX(nn.Module):
    """Conv2d-stem Whisper-style audio encoder with DeepStack taps."""

    def __init__(self, cfg: EncoderConfig):
        super().__init__()
        self.cfg = cfg
        hidden = cfg.downsample_hidden_size

        # Conv2d stem: 1 -> hidden -> hidden -> hidden, each stride 2.
        # MLX Conv2d expects NHWC input, weight shape (OC, kH, kW, IC).
        self.conv1 = nn.Conv2d(1, hidden, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=2, padding=1)

        # Mel-axis size after 3x stride-2 (128 -> 64 -> 32 -> 16 for the default
        # config), so the flattened stem feature is hidden * 16 = 7680 there.
        freq_bins = self._compute_downsampled_length(cfg.num_mel_bins)
        self.stem_proj = nn.Linear(hidden * freq_bins, cfg.d_model)

        # Precomputed sinusoids (max_source_positions, d_model), sliced per call.
        self._positions = sinusoids(cfg.max_source_positions, cfg.d_model)

        self.layers = [
            WhisperEncoderBlock(cfg.d_model, cfg.n_heads, cfg.ffn_dim, eps=cfg.layer_norm_eps)
            for _ in range(cfg.n_layers)
        ]
        self.layer_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)

        # MOSS has an optional out_proj; for 4B output_dim == d_model, so it is
        # Identity in PyTorch. We skip it entirely (equivalent).
        assert cfg.output_dim == cfg.d_model, "non-identity out_proj not yet implemented"

        self._deepstack_set = set(cfg.deepstack_layer_indexes)

    def _compute_downsampled_length(self, L: int) -> int:
        """Length after the 3x (kernel=3, stride=2, padding=1) conv stem.

        Each conv maps n -> (n - 1) // 2 + 1 (i.e. ceil(n / 2) for n >= 1).
        """

        def step(n):
            return (n - 1) // 2 + 1

        return step(step(step(L)))

    def __call__(
        self,
        input_features: mx.array,  # (B, n_mels, T) mel spectrogram, e.g. bf16
        feature_lens: Optional[mx.array] = None,
        return_deepstack: bool = True,
    ) -> Tuple[mx.array, Optional[List[mx.array]]]:
        """Encode a batch of mel spectrograms.

        Args:
            input_features: (B, n_mels, T) mel features; a 2-D (n_mels, T)
                input is treated as a batch of one.
            feature_lens: per-example number of valid input frames, any shape
                with B elements. Defaults to T for every example.
            return_deepstack: whether to collect the DeepStack taps.

        Returns:
            ``(hidden, deepstack)``: ``hidden`` is (B, T', d_model) after the
            final LayerNorm; ``deepstack`` is the list of outputs of the layers
            in ``cfg.deepstack_layer_indexes`` (taken before the final
            LayerNorm), or ``None`` when ``return_deepstack`` is False.

        Raises:
            ValueError: if ``feature_lens`` does not have B entries, or the
                downsampled length exceeds ``max_source_positions``.
        """
        if input_features.ndim == 2:
            input_features = input_features[None]
        B, n_mels, T = input_features.shape
        if feature_lens is None:
            feature_lens = mx.full((B,), T, dtype=mx.int32)

        # Map (B, n_mels, T) onto MLX's NHWC: H = n_mels, W = frames, C_in = 1.
        x = input_features[..., None]  # (B, n_mels, T, 1)
        x = nn.gelu(self.conv1(x))  # (B, n_mels/2, T/2, hidden)
        x = nn.gelu(self.conv2(x))  # (B, n_mels/4, T/4, hidden)
        x = nn.gelu(self.conv3(x))  # (B, n_mels/8, T/8, hidden)

        # PyTorch reference: (B, C, F, T) -> permute(0,3,1,2) -> (B, T, C, F)
        # -> flatten -> (B, T, C*F). MLX is (B, F, T, C) post-conv, so transpose
        # to (B, T, C, F) to reproduce PyTorch's flatten order.
        B_, H_, W_, C_ = x.shape  # H_ = F, W_ = T', C_ = C
        x = x.transpose(0, 2, 3, 1).reshape(B_, W_, C_ * H_)  # (B, T', C*F)
        x = self.stem_proj(x)  # (B, T', d_model)

        # One host sync for all lengths instead of one .item() per example.
        raw_lens = feature_lens.reshape(-1).tolist()
        if len(raw_lens) != B:
            raise ValueError(
                f"feature_lens has {len(raw_lens)} entries but batch size is {B}"
            )
        down_lens = [self._compute_downsampled_length(int(n)) for n in raw_lens]

        # Trim to the longest actual downsampled length (input may be padded).
        # step() is monotone, so max of downsampled == downsampled of max.
        max_len = max(down_lens)
        if x.shape[1] > max_len:
            x = x[:, :max_len, :]

        # Add sinusoidal positions.
        seq_len = x.shape[1]
        max_pos = self._positions.shape[0]
        if seq_len > max_pos:
            raise ValueError(
                f"downsampled sequence length {seq_len} exceeds "
                f"max_source_positions={max_pos}"
            )
        x = x + self._positions[:seq_len].astype(x.dtype)

        # Additive key-padding mask (B, 1, 1, seq_len): position t of example b
        # is masked out when t >= down_lens[b].
        dsl = mx.array(down_lens, dtype=mx.int32)  # (B,)
        padding = mx.arange(seq_len, dtype=mx.int32)[None, :] >= dsl[:, None]
        neg_inf = mx.array(-1e9, dtype=x.dtype)
        mask = mx.where(padding, neg_inf, mx.array(0.0, dtype=x.dtype))
        mask = mask[:, None, None, :]

        deepstack: List[mx.array] = []
        for layer_idx, layer in enumerate(self.layers):
            x = layer(x, mask=mask)
            if return_deepstack and layer_idx in self._deepstack_set:
                # Tap the raw layer output, i.e. BEFORE the final layer_norm
                # (the original port notes this matches the PyTorch reference).
                deepstack.append(x)

        x = self.layer_norm(x)
        return x, (deepstack if return_deepstack else None)


# ---- GatedMLP (for audio_adapter + deepstack_audio_merger_list) ----


class GatedMLP(nn.Module):
    """MOSS's GatedMLP: down(silu(gate(x)) * up(x)). SwiGLU convention.

    Matches MOSS/src/modeling_moss_audio.py:155-164. All linears are bias=False.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.down_proj = nn.Linear(hidden_size, output_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


__all__ = [
    "sinusoids",
    "WhisperAttention",
    "WhisperEncoderBlock",
    "EncoderConfig",
    "MossAudioEncoderMLX",
    "GatedMLP",
]