"""Standalone MOSS-Audio-{4B,8B}-Thinking MLX inference. Usage: python inference.py --audio path/to/clip.wav [--max-tokens 2048] Both 4B INT4 and 8B hybrid bundles work with this script. Audio-path dtype is inferred from the saved adapter weights (`scales` key => INT4). """ from __future__ import annotations import argparse, sys, time from pathlib import Path HERE = Path(__file__).resolve().parent sys.path.insert(0, str(HERE / "scripts")) import librosa import mlx.core as mx import numpy as np from mlx_lm import load as mlx_load from mlx_lm.generate import generate_step from mlx_lm.sample_utils import make_sampler, make_logits_processors from moss_audio_mlx_bridge_v3 import ( load_mlx_audio_path, build_mel_spectrogram, run_mlx_audio_pipeline, install_deepstack_hooks, ) def main(): p = argparse.ArgumentParser() p.add_argument("--audio", required=True, help="Path to input .wav (16 kHz, mono)") p.add_argument("--max-tokens", type=int, default=2048) p.add_argument("--repetition-penalty", type=float, default=1.02, help="1.02 kills decode-loops without over-penalizing descriptions") args = p.parse_args() ad_w = mx.load(str(HERE / "mlx_audio/audio_adapter.safetensors")) if "down_proj.scales" in ad_w: llm_hidden = ad_w["down_proj.scales"].shape[0] int4_audio = True else: llm_hidden = ad_w["down_proj.weight"].shape[0] int4_audio = False size_tag = "4B" if llm_hidden == 2560 else "8B" print(f"[detect] {size_tag} bundle, audio int4={int4_audio}") print(f"[load] LLM from {HERE / 'mlx_llm'}") t0 = time.perf_counter() mlx_model, mlx_tokenizer = mlx_load(str(HERE / "mlx_llm")) print(f"[load] LLM: {time.perf_counter()-t0:.1f}s") t0 = time.perf_counter() encoder, adapter, mergers = load_mlx_audio_path(HERE / "mlx_audio", int4=int4_audio) print(f"[load] audio path: {time.perf_counter()-t0:.1f}s") y, _ = librosa.load(args.audio, sr=16000, mono=True) y = y.astype(np.float32) print(f"[audio] {args.audio} ({len(y)/16000:.1f}s)") # Pure-MLX mel + input_ids (no torch). mel, lens, input_ids_mx, audio_token_id = build_mel_spectrogram(y, mlx_tokenizer) primary, ds_embeds = run_mlx_audio_pipeline(encoder, adapter, mergers, mel, lens) primary = primary.astype(mx.bfloat16) ds_embeds = [d.astype(mx.bfloat16) for d in ds_embeds] mx.eval(primary, *ds_embeds) del encoder, adapter, mergers, mel, lens import gc; gc.collect(); mx.clear_cache() audio_mask = input_ids_mx == audio_token_id audio_positions = np.where(np.array(audio_mask[0]))[0] text_embeds = mlx_model.model.embed_tokens(input_ids_mx) text_np = np.array(text_embeds.astype(mx.float32)) primary_np = np.array(primary.astype(mx.float32)) text_np[0, audio_positions, :] = primary_np[0, :, :] merged = mx.array(text_np).astype(mx.bfloat16) ds_flat = [d[0] for d in ds_embeds] install_deepstack_hooks(mlx_model, ds_flat, audio_positions) sampler = make_sampler(temp=1.0, top_p=1.0, top_k=50) logits_processors = make_logits_processors( repetition_penalty=args.repetition_penalty, repetition_context_size=20 ) if args.repetition_penalty else None gen_kwargs = dict( prompt=input_ids_mx[0], model=mlx_model, input_embeddings=merged[0], max_tokens=args.max_tokens, sampler=sampler, ) if logits_processors: gen_kwargs["logits_processors"] = logits_processors t0 = time.perf_counter() generated = [] for tok, _ in generate_step(**gen_kwargs): generated.append(int(tok)) if tok == mlx_tokenizer.eos_token_id: break elapsed = time.perf_counter() - t0 print(f"[gen] {len(generated)} tokens in {elapsed:.2f}s ({len(generated)/elapsed:.1f} t/s)") print(f"\n=== OUTPUT ===\n{mlx_tokenizer.decode(generated)}\n=== END ===") if __name__ == "__main__": main()