Instructions to use RumiLabs/MOSS-Audio-4B-Thinking-MLX-4bit with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use RumiLabs/MOSS-Audio-4B-Thinking-MLX-4bit with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir MOSS-Audio-4B-Thinking-MLX-4bit RumiLabs/MOSS-Audio-4B-Thinking-MLX-4bit
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """MLX-native MossAudioEncoder. | |
| Direct port of src/modeling_moss_audio.py:36-155 (MossAudioEncoder). | |
| Adapted from ml-explore/mlx-examples/whisper/mlx_whisper/whisper.py with: | |
| - 3× Conv2d stride-2 stem (instead of Whisper's 2× Conv1d) | |
| - Pre-existing HF Whisper attribute names (q_proj/k_proj/v_proj/out_proj, fc1/fc2, | |
| self_attn_layer_norm/final_layer_norm) so weight remap is near-identity | |
| - DeepStack taps: capture hidden state AFTER layers in deepstack_layer_indexes | |
| - feature_lens-based padding mask | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Tuple | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| # ---- helpers ---------------------------------------------------------- | |
def sinusoids(length: int, channels: int, max_timescale: float = 10000.0) -> mx.array:
    """Build a Whisper-style sinusoidal position table of shape (length, channels).

    The first ``channels // 2`` columns hold sines and the rest hold cosines.
    Inverse timescales are geometrically spaced, so timescales run from 1 up to
    ``max_timescale``. This is the same construction as mlx-examples whisper.

    ``channels`` must be even (checked with ``assert``). Note that
    ``channels == 2`` divides by zero when computing the timescale step.
    """
    assert channels % 2 == 0
    half = channels // 2
    step = math.log(max_timescale) / (half - 1)
    inv_timescales = mx.exp(-step * mx.arange(half))
    # Outer product: one row per position, one column per frequency.
    angles = mx.arange(length)[:, None] * inv_timescales[None, :]
    return mx.concatenate([mx.sin(angles), mx.cos(angles)], axis=1)
| # ---- attention ------------------------------------------------------ | |
class WhisperAttention(nn.Module):
    """Multi-head self-attention using Hugging Face Whisper parameter names.

    The full ``head_dim ** -0.5`` scale is applied to the queries only (HF
    convention). mlx-examples instead splits it between Q and K.

    Parameter names (q_proj/k_proj/v_proj/out_proj) match HF, so weights map
    across by name. As in HF Whisper, q/v/out projections carry a bias and
    k_proj does not.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model == self.head_dim * n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def _split_heads(self, t: mx.array, batch: int, length: int) -> mx.array:
        """Reshape (B, T, D) into (B, n_heads, T, head_dim)."""
        return t.reshape(batch, length, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Attend over ``x`` (B, T, D), optionally adding an additive ``mask`` to the scores."""
        batch, length, width = x.shape
        queries = self._split_heads(self.q_proj(x), batch, length)
        keys = self._split_heads(self.k_proj(x), batch, length)
        values = self._split_heads(self.v_proj(x), batch, length)

        scaled_queries = queries * (self.head_dim ** -0.5)
        scores = scaled_queries @ keys.transpose(0, 1, 3, 2)  # (B, H, T, T)
        if mask is not None:
            scores = scores + mask
        probs = mx.softmax(scores, axis=-1, precise=True)

        context = probs @ values  # (B, H, T, head_dim)
        merged = context.transpose(0, 2, 1, 3).reshape(batch, length, width)
        return self.out_proj(merged)
| # ---- encoder layer -------------------------------------------------- | |
class WhisperEncoderBlock(nn.Module):
    """One pre-LayerNorm Whisper encoder layer, mirroring transformers.WhisperEncoderLayer.

    The layer has two residual sub-blocks. The first is LayerNorm followed by
    self-attention. The second is LayerNorm followed by fc1, GELU and fc2.
    """

    def __init__(self, d_model: int, n_heads: int, ffn_dim: int):
        super().__init__()
        self.self_attn = WhisperAttention(d_model, n_heads)
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(d_model, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)

    def _mlp(self, h: mx.array) -> mx.array:
        """Position-wise feed-forward: fc2(gelu(fc1(h)))."""
        return self.fc2(nn.gelu(self.fc1(h)))

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Apply both residual sub-blocks; ``mask`` is forwarded to attention."""
        attended = self.self_attn(self.self_attn_layer_norm(x), mask=mask)
        x = x + attended
        return x + self._mlp(self.final_layer_norm(x))
| # ---- encoder -------------------------------------------------------- | |
@dataclass
class EncoderConfig:
    """Hyper-parameters for MossAudioEncoderMLX.

    Must be a dataclass. Without ``@dataclass``, ``field(default_factory=...)``
    leaves a ``dataclasses.Field`` object as the class attribute instead of a
    list. Instances also could not be built with keyword overrides.
    """

    # Number of mel bins in the input spectrogram (the H axis of the conv stem).
    num_mel_bins: int = 128
    # Channel width of the three Conv2d stem layers.
    downsample_hidden_size: int = 480
    # Transformer hidden size.
    d_model: int = 1280
    # Attention heads per layer; must divide d_model.
    n_heads: int = 20
    # Inner size of each layer's feed-forward (fc1/fc2).
    ffn_dim: int = 5120
    # Number of transformer layers.
    n_layers: int = 32
    # Rows in the precomputed sinusoidal position table.
    max_source_positions: int = 1500
    # Epsilon for the encoder's final LayerNorm.
    layer_norm_eps: float = 1e-5
    # Output width; the MLX encoder currently requires output_dim == d_model.
    output_dim: int = 1280
    # 0-based layer indices whose outputs are captured as DeepStack taps.
    deepstack_layer_indexes: List[int] = field(default_factory=lambda: [8, 16, 24])
class MossAudioEncoderMLX(nn.Module):
    """MLX port of MOSS's Whisper-style audio encoder (MossAudioEncoder).

    Forward pipeline:
      1. Three Conv2d(kernel 3, stride 2, padding 1) + GELU layers over the
         (mel, time) plane.
      2. Flatten (channels, mel) per time step and project to ``d_model``.
      3. Add sinusoidal position embeddings.
      4. ``n_layers`` pre-LN transformer blocks with a key-padding mask derived
         from ``feature_lens``.
      5. Final LayerNorm.

    Hidden states after the layers in ``cfg.deepstack_layer_indexes`` (0-based,
    taken before the final LayerNorm) can also be returned as DeepStack taps.
    """

    def __init__(self, cfg: EncoderConfig):
        super().__init__()
        self.cfg = cfg
        hidden = cfg.downsample_hidden_size

        # Conv2d stem: 1 -> hidden -> hidden -> hidden channels, each stride 2.
        # MLX Conv2d expects NHWC input and weights shaped (OC, kH, kW, IC).
        self.conv1 = nn.Conv2d(1, hidden, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=2, padding=1)

        # The stem also shrinks the mel axis. With the default config this is
        # 128 -> 64 -> 32 -> 16, giving 480 * 16 = 7680 input features.
        # Derive the size from num_mel_bins rather than hard-coding 16.
        stem_mel_bins = self._compute_downsampled_length(cfg.num_mel_bins)
        self.stem_proj = nn.Linear(hidden * stem_mel_bins, cfg.d_model)

        # Precomputed (max_source_positions, d_model) sinusoid table, sliced per call.
        self._positions = sinusoids(cfg.max_source_positions, cfg.d_model)

        self.layers = [
            WhisperEncoderBlock(cfg.d_model, cfg.n_heads, cfg.ffn_dim)
            for _ in range(cfg.n_layers)
        ]
        self.layer_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)

        # MOSS has an optional out_proj. When output_dim == d_model it is an
        # Identity in PyTorch, so it is skipped here. Raise instead of assert,
        # because `python -O` strips asserts.
        if cfg.output_dim != cfg.d_model:
            raise NotImplementedError("non-identity out_proj not yet implemented")

        self._deepstack_set = set(cfg.deepstack_layer_indexes)

    def _compute_downsampled_length(self, L: int) -> int:
        """Return the length remaining after the three stride-2 stem convolutions.

        Each conv (kernel 3, stride 2, padding 1) maps ``n`` to
        ``(n + 2*1 - 3) // 2 + 1 == (n - 1) // 2 + 1``. That step is applied three times.
        """
        def step(n: int) -> int:
            return (n - 1) // 2 + 1

        return step(step(step(L)))

    def _stem(self, input_features: mx.array) -> mx.array:
        """Run the conv stem and projection: (B, n_mels, T) -> (B, T', d_model)."""
        # NHWC view: H = mel bins, W = frames, C_in = 1.
        x = input_features[..., None]
        x = nn.gelu(self.conv1(x))
        x = nn.gelu(self.conv2(x))
        x = nn.gelu(self.conv3(x))  # (B, F', T', C)
        # The PyTorch reference flattens (B, C, F, T) via permute(0, 3, 1, 2) into
        # (B, T, C, F) and then into (B, T, C*F). Transposing MLX's (B, F, T, C)
        # to (B, T, C, F) makes the flattened feature order match those weights.
        batch, freq, frames, channels = x.shape
        x = x.transpose(0, 2, 3, 1).reshape(batch, frames, channels * freq)
        return self.stem_proj(x)

    @staticmethod
    def _padding_mask(valid_lens: List[int], seq_len: int, dtype: mx.Dtype) -> mx.array:
        """Build an additive key-padding mask of shape (B, 1, 1, seq_len).

        Entry t of example b is -1e9 when ``t >= valid_lens[b]`` and 0 otherwise,
        so padded keys are suppressed by the softmax.
        """
        valid = mx.array(valid_lens, dtype=mx.int32)
        positions = mx.arange(seq_len, dtype=mx.int32)
        padding = positions[None, :] >= valid[:, None]  # (B, seq_len) bool
        # NOTE(review): -1e9 is not representable in float16 and becomes -inf;
        # it is finite for bfloat16/float32.
        neg_inf = mx.array(-1e9, dtype=dtype)
        mask = mx.where(padding, neg_inf, mx.array(0.0, dtype=dtype))
        return mask[:, None, None, :]

    def __call__(
        self,
        input_features: mx.array,
        feature_lens: Optional[mx.array] = None,
        return_deepstack: bool = True,
    ) -> Tuple[mx.array, Optional[List[mx.array]]]:
        """Encode a batch of mel spectrograms.

        Args:
            input_features: mel features shaped (B, n_mels, T), or (n_mels, T)
                for one unbatched example.
            feature_lens: valid frame count per example, before downsampling.
                Defaults to T for every example. An array (or list) with B
                entries is expected; a scalar is accepted when B == 1.
            return_deepstack: whether to also return the DeepStack taps.

        Returns:
            ``(hidden, deepstack)``. ``hidden`` has shape (B, T', d_model) after the
            final LayerNorm. ``deepstack`` is the list of hidden states captured
            after each layer in ``cfg.deepstack_layer_indexes``, before the final
            LayerNorm. It is ``None`` when ``return_deepstack`` is False.

        Raises:
            ValueError: if ``feature_lens`` does not have one entry per example,
                or the downsampled length exceeds ``max_source_positions``.
        """
        if input_features.ndim == 2:
            input_features = input_features[None]
        B, _, T = input_features.shape

        # Bring all lengths to the host in a single transfer.
        if feature_lens is None:
            lens = [T] * B
        else:
            lens = [int(n) for n in mx.array(feature_lens).reshape(-1).tolist()]
        if len(lens) != B:
            raise ValueError(
                f"feature_lens has {len(lens)} entries but the batch size is {B}"
            )
        ds_lens = [self._compute_downsampled_length(n) for n in lens]

        x = self._stem(input_features)  # (B, T', d_model)

        # Trim frames that exist only because the batch was padded past its
        # longest example. Downsampling is monotonic, so max(ds_lens) is the
        # downsampled length of the longest example.
        max_len = max(ds_lens)
        if x.shape[1] > max_len:
            x = x[:, :max_len, :]

        seq_len = x.shape[1]
        n_positions = self._positions.shape[0]
        if seq_len > n_positions:
            raise ValueError(
                f"downsampled sequence length {seq_len} exceeds "
                f"max_source_positions={n_positions}"
            )
        x = x + self._positions[:seq_len].astype(x.dtype)

        mask = self._padding_mask(ds_lens, seq_len, x.dtype)

        deepstack: List[mx.array] = []
        for layer_idx, layer in enumerate(self.layers):
            x = layer(x, mask=mask)
            if return_deepstack and layer_idx in self._deepstack_set:
                # Tap is taken before the final layer_norm. The original author
                # notes that this matches the PyTorch reference.
                deepstack.append(x)
        x = self.layer_norm(x)
        return x, (deepstack if return_deepstack else None)
| # ---- GatedMLP (for audio_adapter + deepstack_audio_merger_list) ---- | |
class GatedMLP(nn.Module):
    """SwiGLU feed-forward used by MOSS: ``down(silu(gate(x)) * up(x))``.

    Mirrors MOSS/src/modeling_moss_audio.py:155-164. None of the three
    projections has a bias.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(input_size, hidden_size, bias=False)
        self.down_proj = nn.Linear(hidden_size, output_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Project ``x`` through the gated SiLU MLP."""
        gate = nn.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
# Public API of this module.
__all__ = [
    "sinusoids",
    "WhisperAttention",
    "WhisperEncoderBlock",
    "EncoderConfig",
    "MossAudioEncoderMLX",
    "GatedMLP",
]