# Source: MOSS-Audio-4B-Thinking-MLX-4bit / scripts / moss_audio_encoder_mlx.py
# Uploader: leo-rumilabs
# Commit message: Upload folder using huggingface_hub
# Commit: e7662d1 (verified)
"""MLX-native MossAudioEncoder.
Direct port of src/modeling_moss_audio.py:36-155 (MossAudioEncoder).
Adapted from ml-explore/mlx-examples/whisper/mlx_whisper/whisper.py with:
- 3× Conv2d stride-2 stem (instead of Whisper's 2× Conv1d)
- Pre-existing HF Whisper attribute names (q_proj/k_proj/v_proj/out_proj, fc1/fc2,
self_attn_layer_norm/final_layer_norm) so weight remap is near-identity
- DeepStack taps: capture hidden state AFTER layers in deepstack_layer_indexes
- feature_lens-based padding mask
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
# ---- helpers ----------------------------------------------------------
def sinusoids(length: int, channels: int, max_timescale: float = 10000.0) -> mx.array:
    """Build a ``(length, channels)`` table of sinusoidal position embeddings.

    The first ``channels // 2`` columns hold sines and the remaining columns
    hold cosines of ``position * inv_timescale``, where the inverse timescales
    decay geometrically from 1 down to ``1 / max_timescale``.

    Args:
        length: Number of positions (rows).
        channels: Embedding width; must be even.
        max_timescale: Largest timescale of the geometric progression.

    Returns:
        An ``mx.array`` of shape ``(length, channels)``.
    """
    assert channels % 2 == 0
    half = channels // 2
    # Log-spacing between consecutive timescales.
    log_step = math.log(max_timescale) / (half - 1)
    inv_freq = mx.exp(-log_step * mx.arange(half))
    # Outer product: one row per position, one column per frequency.
    angles = mx.arange(length)[:, None] * inv_freq[None, :]
    sin_part = mx.sin(angles)
    cos_part = mx.cos(angles)
    return mx.concatenate([sin_part, cos_part], axis=1)
# ---- attention ------------------------------------------------------
class WhisperAttention(nn.Module):
    """Multi-head self-attention using HF-Whisper attribute names.

    The ``1/sqrt(head_dim)`` scale is applied to the queries only (rather than
    being split between queries and keys). Projections are named
    ``q_proj``/``k_proj``/``v_proj``/``out_proj`` so HF checkpoints map onto
    this module without renaming.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model == self.head_dim * n_heads
        # Bias layout follows HF Whisper: q/v/out carry a bias, k does not.
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Attend over ``x`` of shape ``(B, T, D)``.

        Args:
            x: Input hidden states, ``(B, T, D)``.
            mask: Optional additive mask, added to the ``(B, H, T, T)``
                attention scores before the softmax.

        Returns:
            Attention output projected back to ``(B, T, D)``.
        """
        batch, seq_len, width = x.shape

        def split_heads(t: mx.array) -> mx.array:
            # (B, T, D) -> (B, H, T, head_dim)
            t = t.reshape(batch, seq_len, self.n_heads, self.head_dim)
            return t.transpose(0, 2, 1, 3)

        queries = split_heads(self.q_proj(x))
        keys = split_heads(self.k_proj(x))
        values = split_heads(self.v_proj(x))

        q_scale = self.head_dim ** -0.5
        scores = (queries * q_scale) @ keys.transpose(0, 1, 3, 2)  # (B, H, T, T)
        if mask is not None:
            scores = scores + mask
        probs = mx.softmax(scores, axis=-1, precise=True)

        # (B, H, T, head_dim) -> (B, T, D)
        context = (probs @ values).transpose(0, 2, 1, 3).reshape(batch, seq_len, width)
        return self.out_proj(context)
# ---- encoder layer --------------------------------------------------
class WhisperEncoderBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block.

    Layout: ``x + attn(LN(x))`` followed by ``x + fc2(gelu(fc1(LN(x))))``.
    Attribute names are intended to mirror ``transformers.WhisperEncoderLayer``.
    """

    def __init__(self, d_model: int, n_heads: int, ffn_dim: int):
        super().__init__()
        self.self_attn = WhisperAttention(d_model, n_heads)
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(d_model, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            x: Hidden states, ``(B, T, D)``.
            mask: Optional additive attention mask forwarded to ``self_attn``.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        # Self-attention sub-layer (normalise first, then residual add).
        attn_out = self.self_attn(self.self_attn_layer_norm(x), mask=mask)
        x = x + attn_out

        # Feed-forward sub-layer.
        normed = self.final_layer_norm(x)
        ffn_out = self.fc2(nn.gelu(self.fc1(normed)))
        return x + ffn_out
# ---- encoder --------------------------------------------------------
def _default_deepstack_layer_indexes() -> List[int]:
    """Return a fresh copy of the default DeepStack tap layers."""
    return [8, 16, 24]


@dataclass
class EncoderConfig:
    """Hyper-parameters for ``MossAudioEncoderMLX``."""

    # --- convolutional stem ---
    num_mel_bins: int = 128
    downsample_hidden_size: int = 480

    # --- transformer ---
    d_model: int = 1280
    n_heads: int = 20
    ffn_dim: int = 5120
    n_layers: int = 32
    max_source_positions: int = 1500
    layer_norm_eps: float = 1e-5

    # --- output ---
    output_dim: int = 1280
    # Layer indexes after which the hidden state is captured for DeepStack.
    deepstack_layer_indexes: List[int] = field(
        default_factory=_default_deepstack_layer_indexes
    )
class MossAudioEncoderMLX(nn.Module):
    """Conv2d stem followed by a Whisper-style transformer encoder.

    Pipeline (see ``__call__``): three stride-2 ``Conv2d`` + GELU layers, a
    linear projection of the flattened (channel, mel) axes to ``d_model``,
    additive sinusoidal positions, ``n_layers`` pre-LN encoder blocks with a
    key-padding mask, and a final ``LayerNorm``. Hidden states produced right
    after the layers listed in ``cfg.deepstack_layer_indexes`` (0-based
    position in ``self.layers``) can be returned alongside the output.
    """

    def __init__(self, cfg: EncoderConfig):
        """Create all sub-modules from ``cfg``.

        Raises:
            AssertionError: If ``cfg.output_dim != cfg.d_model`` (a non-identity
                output projection is not implemented).
        """
        super().__init__()
        # Fail fast, before allocating any layers.
        assert cfg.output_dim == cfg.d_model, "non-identity out_proj not yet implemented"
        self.cfg = cfg

        hidden = cfg.downsample_hidden_size
        # Conv2d stem: 1 → hidden → hidden → hidden channels, each stride-2.
        # MLX Conv2d expects NHWC input, weight shape (OC, kH, kW, IC).
        self.conv1 = nn.Conv2d(1, hidden, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=2, padding=1)

        # The mel axis shrinks through the same three convs as the time axis
        # (128 → 64 → 32 → 16 for the default config), so the flattened width
        # is derived from the config instead of being hard-coded.
        mel_bins_out = self._compute_downsampled_length(cfg.num_mel_bins)
        self.stem_proj = nn.Linear(hidden * mel_bins_out, cfg.d_model)

        # Precomputed sinusoidal table; sliced to the sequence length per call.
        self._positions = sinusoids(cfg.max_source_positions, cfg.d_model)

        self.layers = [
            WhisperEncoderBlock(cfg.d_model, cfg.n_heads, cfg.ffn_dim)
            for _ in range(cfg.n_layers)
        ]
        self.layer_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        # With output_dim == d_model (asserted above) no output projection is
        # needed, so none is created.
        self._deepstack_set = set(cfg.deepstack_layer_indexes)

    def _compute_downsampled_length(self, L: int) -> int:
        """Return the length of an axis of size ``L`` after the 3-conv stem.

        Each conv (kernel 3, stride 2, padding 1) maps ``n`` to
        ``(n - 1) // 2 + 1``; the mapping is applied three times.
        """
        for _ in range(3):
            L = (L - 1) // 2 + 1
        return L

    def _padding_mask(self, lens: List[int], seq_len: int, dtype) -> mx.array:
        """Build the additive key-padding mask of shape ``(B, 1, 1, seq_len)``.

        Entries are ``-1e9`` where the key position is at or beyond the
        sample's downsampled length, and ``0`` elsewhere.
        """
        valid = mx.array(
            [self._compute_downsampled_length(n) for n in lens], dtype=mx.int32
        )  # (B,)
        positions = mx.arange(seq_len, dtype=mx.int32)
        is_padding = positions[None, :] >= valid[:, None]  # (B, seq_len) bool
        mask = mx.where(
            is_padding,
            mx.array(-1e9, dtype=dtype),
            mx.array(0.0, dtype=dtype),
        )
        return mask[:, None, None, :]

    def __call__(
        self,
        input_features: mx.array,  # (B, n_mels, T) mel spectrogram
        feature_lens: Optional[mx.array] = None,
        return_deepstack: bool = True,
    ) -> Tuple[mx.array, Optional[List[mx.array]]]:
        """Encode a batch of mel spectrograms.

        Args:
            input_features: ``(B, n_mels, T)`` array; a 2-D ``(n_mels, T)``
                input is treated as a batch of one.
            feature_lens: Optional ``(B,)`` array with the number of valid
                frames per sample. Defaults to ``T`` for every sample.
            return_deepstack: Whether to collect the DeepStack hidden states.

        Returns:
            ``(hidden, deepstack)`` where ``hidden`` is the final
            layer-normalised output ``(B, T', d_model)`` and ``deepstack`` is a
            list of pre-final-norm hidden states (or ``None`` when
            ``return_deepstack`` is false).

        Raises:
            ValueError: If ``feature_lens`` does not have one entry per sample,
                or if the downsampled sequence is longer than
                ``cfg.max_source_positions``.
        """
        if input_features.ndim == 2:
            input_features = input_features[None]
        batch, _, num_frames = input_features.shape

        # Pull the lengths to the host once; they drive Python-side slicing
        # and the mask construction below.
        if feature_lens is None:
            lens = [num_frames] * batch
        else:
            lens = [int(v) for v in feature_lens.tolist()]
            if len(lens) != batch:
                raise ValueError(
                    f"feature_lens has {len(lens)} entries but the batch size is {batch}"
                )

        # (B, n_mels, T) → (B, n_mels, T, 1): NHWC with H = mel bins,
        # W = frames and a single input channel.
        x = input_features[..., None]
        x = nn.gelu(self.conv1(x))  # mel and time axes halved
        x = nn.gelu(self.conv2(x))  # ... quartered
        x = nn.gelu(self.conv3(x))  # ... divided by eight

        # Post-conv layout is (B, F, T, C). Per the original port notes, the
        # PyTorch reference flattens (B, T, C, F), so move to that order first.
        b_out, f_out, t_out, c_out = x.shape
        x = x.transpose(0, 2, 3, 1).reshape(b_out, t_out, c_out * f_out)
        x = self.stem_proj(x)  # (B, T', d_model)

        # Drop trailing frames that exist only because the input was padded.
        max_len = self._compute_downsampled_length(max(lens))
        if x.shape[1] > max_len:
            x = x[:, :max_len, :]

        seq_len = x.shape[1]
        if seq_len > self._positions.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_source_positions "
                f"{self._positions.shape[0]}"
            )
        x = x + self._positions[:seq_len].astype(x.dtype)

        mask = self._padding_mask(lens, seq_len, x.dtype)

        deepstack: List[mx.array] = []
        for layer_idx, layer in enumerate(self.layers):
            x = layer(x, mask=mask)
            if return_deepstack and layer_idx in self._deepstack_set:
                # Captured before the final layer_norm is applied.
                deepstack.append(x)

        x = self.layer_norm(x)
        return x, (deepstack if return_deepstack else None)
# ---- GatedMLP (for audio_adapter + deepstack_audio_merger_list) ----
class GatedMLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``.

    All three linear layers are created without a bias term.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.down_proj = nn.Linear(hidden_size, output_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Project ``x`` through the gated MLP and return the result."""
        gate = nn.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
# Names exported by a star-import of this module.
__all__ = [
    "sinusoids",
    "WhisperAttention",
    "WhisperEncoderBlock",
    "EncoderConfig",
    "MossAudioEncoderMLX",
    "GatedMLP",
]