Instructions to use fredchu/MOSS-Audio-8B-Instruct-MLX with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use fredchu/MOSS-Audio-8B-Instruct-MLX with MLX:
# Download the model from the Hub
pip install "huggingface_hub[hf_xet]"
huggingface-cli download --local-dir MOSS-Audio-8B-Instruct-MLX fredchu/MOSS-Audio-8B-Instruct-MLX
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| """Standalone MOSS-Audio-{4B,8B}-Thinking MLX inference. | |
| Usage: | |
| python inference.py --audio path/to/clip.wav [--max-tokens 2048] | |
| Both 4B INT4 and 8B hybrid bundles work with this script. Audio-path | |
| dtype is inferred from the saved adapter weights (`scales` key => INT4). | |
| """ | |
| from __future__ import annotations | |
| import argparse, sys, time | |
| from pathlib import Path | |
| HERE = Path(__file__).resolve().parent | |
| sys.path.insert(0, str(HERE / "scripts")) | |
| import librosa | |
| import mlx.core as mx | |
| import numpy as np | |
| from mlx_lm import load as mlx_load | |
| from mlx_lm.generate import generate_step | |
| from mlx_lm.sample_utils import make_sampler, make_logits_processors | |
| from moss_audio_mlx_bridge_v3 import ( | |
| load_mlx_audio_path, | |
| build_mel_spectrogram, | |
| run_mlx_audio_pipeline, | |
| install_deepstack_hooks, | |
| ) | |
def _detect_audio_bundle(adapter_file: Path) -> tuple[int, bool]:
    """Inspect the saved audio adapter; return ``(llm_hidden, int4_audio)``.

    A ``down_proj.scales`` tensor marks an INT4 (quantized) adapter. Dimension 0
    of that tensor (or of ``down_proj.weight`` for the unquantized case) is
    taken as the LLM hidden size.

    The weights dict is local, so it is no longer referenced once detection
    returns (previously it stayed alive for the whole run).

    Raises:
        ValueError: if neither expected key is present in the file.
    """
    weights = mx.load(str(adapter_file))
    if "down_proj.scales" in weights:
        return weights["down_proj.scales"].shape[0], True
    if "down_proj.weight" in weights:
        return weights["down_proj.weight"].shape[0], False
    raise ValueError(
        f"{adapter_file}: expected a 'down_proj.scales' or 'down_proj.weight' tensor"
    )


def _stop_token_ids(tokenizer) -> set[int]:
    """Return every token id that should end generation.

    mlx_lm's ``TokenizerWrapper`` exposes ``eos_token_ids`` (a set that can
    hold several ids for chat models). Fall back to the single
    ``eos_token_id`` when that attribute is absent. ``None`` entries are
    dropped so a tokenizer without an EOS id simply never stops early.
    """
    ids = getattr(tokenizer, "eos_token_ids", None) or {tokenizer.eos_token_id}
    return {int(i) for i in ids if i is not None}


def _splice_audio_embeddings(text_embeds, primary, audio_positions):
    """Overwrite the audio-placeholder rows of the prompt embeddings.

    ``text_embeds`` and ``primary`` are 3-D arrays whose first axis is the
    batch (only row 0 is used). The i-th row of ``primary[0]`` replaces
    ``text_embeds[0, audio_positions[i]]``. Returns a bfloat16 MLX array.

    Raises:
        ValueError: when the number of placeholder tokens and audio embeddings
            differ. Without this check numpy would either broadcast a single
            audio row over every placeholder or fail with a shape error.
    """
    n_audio = primary.shape[1]
    if n_audio != len(audio_positions):
        raise ValueError(
            f"prompt has {len(audio_positions)} audio placeholder tokens but the "
            f"audio pipeline produced {n_audio} embeddings"
        )
    # Round-trip through float32 numpy for the fancy-index assignment.
    text_np = np.array(text_embeds.astype(mx.float32))
    primary_np = np.array(primary.astype(mx.float32))
    text_np[0, audio_positions, :] = primary_np[0, :, :]
    return mx.array(text_np).astype(mx.bfloat16)


def main():
    """CLI entry point: encode one audio clip, run the LLM on it, print the text."""
    import gc

    p = argparse.ArgumentParser()
    p.add_argument("--audio", required=True, help="Path to input .wav (16 kHz, mono)")
    p.add_argument("--max-tokens", type=int, default=2048)
    p.add_argument("--repetition-penalty", type=float, default=1.02,
                   help="1.02 kills decode-loops without over-penalizing descriptions")
    p.add_argument("--temp", type=float, default=1.0, help="Sampling temperature")
    p.add_argument("--top-p", type=float, default=1.0, help="Nucleus (top-p) threshold")
    p.add_argument("--top-k", type=int, default=50, help="Top-k sampling cutoff")
    p.add_argument("--seed", type=int, default=None,
                   help="Seed MLX's RNG before generation for reproducible sampling")
    args = p.parse_args()

    # Fail fast: the model loads below take seconds, so validate the input first.
    if not Path(args.audio).is_file():
        p.error(f"--audio: file not found: {args.audio}")

    llm_hidden, int4_audio = _detect_audio_bundle(
        HERE / "mlx_audio/audio_adapter.safetensors"
    )
    size_tag = "4B" if llm_hidden == 2560 else "8B"
    print(f"[detect] {size_tag} bundle, audio int4={int4_audio}")

    print(f"[load] LLM from {HERE / 'mlx_llm'}")
    t0 = time.perf_counter()
    mlx_model, mlx_tokenizer = mlx_load(str(HERE / "mlx_llm"))
    print(f"[load] LLM: {time.perf_counter()-t0:.1f}s")

    t0 = time.perf_counter()
    encoder, adapter, mergers = load_mlx_audio_path(HERE / "mlx_audio", int4=int4_audio)
    print(f"[load] audio path: {time.perf_counter()-t0:.1f}s")

    y, _ = librosa.load(args.audio, sr=16000, mono=True)
    y = y.astype(np.float32)
    print(f"[audio] {args.audio} ({len(y)/16000:.1f}s)")

    # Pure-MLX mel + input_ids (no torch).
    mel, lens, input_ids_mx, audio_token_id = build_mel_spectrogram(y, mlx_tokenizer)
    primary, ds_embeds = run_mlx_audio_pipeline(encoder, adapter, mergers, mel, lens)
    primary = primary.astype(mx.bfloat16)
    ds_embeds = [d.astype(mx.bfloat16) for d in ds_embeds]
    mx.eval(primary, *ds_embeds)

    # Drop the audio-path modules and intermediates, then ask MLX to clear its
    # buffer cache before text generation.
    del encoder, adapter, mergers, mel, lens
    gc.collect()
    mx.clear_cache()

    audio_mask = input_ids_mx == audio_token_id
    audio_positions = np.where(np.array(audio_mask[0]))[0]
    merged = _splice_audio_embeddings(
        mlx_model.model.embed_tokens(input_ids_mx), primary, audio_positions
    )
    install_deepstack_hooks(mlx_model, [d[0] for d in ds_embeds], audio_positions)

    sampler = make_sampler(temp=args.temp, top_p=args.top_p, top_k=args.top_k)
    gen_kwargs = dict(
        prompt=input_ids_mx[0], model=mlx_model,
        input_embeddings=merged[0], max_tokens=args.max_tokens, sampler=sampler,
    )
    if args.repetition_penalty:
        gen_kwargs["logits_processors"] = make_logits_processors(
            repetition_penalty=args.repetition_penalty, repetition_context_size=20
        )

    if args.seed is not None:
        mx.random.seed(args.seed)

    stop_ids = _stop_token_ids(mlx_tokenizer)
    t0 = time.perf_counter()
    generated = []
    for tok, _ in generate_step(**gen_kwargs):
        token = int(tok)
        # Stop on any EOS id; the stop token itself is not part of the output.
        if token in stop_ids:
            break
        generated.append(token)
    elapsed = time.perf_counter() - t0
    tps = len(generated) / elapsed if elapsed > 0 else 0.0
    print(f"[gen] {len(generated)} tokens in {elapsed:.2f}s ({tps:.1f} t/s)")
    print(f"\n=== OUTPUT ===\n{mlx_tokenizer.decode(generated)}\n=== END ===")


if __name__ == "__main__":
    main()