fredchu's picture
Upload folder using huggingface_hub
59237c3 verified
Raw
History Blame Contribute Delete
3.96 kB
"""Standalone MOSS-Audio-{4B,8B}-Thinking MLX inference.
Usage:
python inference.py --audio path/to/clip.wav [--max-tokens 2048]
Both 4B INT4 and 8B hybrid bundles work with this script. Audio-path
dtype is inferred from the saved adapter weights (`scales` key => INT4).
"""
from __future__ import annotations
import argparse, sys, time
from pathlib import Path
# Directory holding this script (the bundle root). Its `scripts/` folder is
# pushed to the front of the import path so the bundled
# `moss_audio_mlx_bridge_v3` module imported below can be found.
HERE = Path(__file__).resolve().parents[0]
sys.path.insert(0, str(HERE.joinpath("scripts")))
import librosa
import mlx.core as mx
import numpy as np
from mlx_lm import load as mlx_load
from mlx_lm.generate import generate_step
from mlx_lm.sample_utils import make_sampler, make_logits_processors
from moss_audio_mlx_bridge_v3 import (
load_mlx_audio_path,
build_mel_spectrogram,
run_mlx_audio_pipeline,
install_deepstack_hooks,
)
# Input audio is resampled (and down-mixed) to this rate before the mel front-end.
_SAMPLE_RATE = 16_000


def _parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface.

    The sampling flags default to the values this script previously
    hard-coded (temp=1.0, top_p=1.0, top_k=50, repetition context 20, no
    seed), so existing invocations behave exactly as before.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--audio", required=True, help="Path to input .wav (16 kHz, mono)")
    p.add_argument("--max-tokens", type=int, default=2048)
    p.add_argument("--repetition-penalty", type=float, default=1.02,
                   help="1.02 kills decode-loops without over-penalizing descriptions")
    p.add_argument("--repetition-context-size", type=int, default=20,
                   help="How many recent tokens the repetition penalty considers")
    p.add_argument("--temp", type=float, default=1.0, help="Sampling temperature")
    p.add_argument("--top-p", type=float, default=1.0, help="Nucleus-sampling threshold")
    p.add_argument("--top-k", type=int, default=50, help="Top-k sampling cutoff")
    p.add_argument("--seed", type=int, default=None,
                   help="Seed MLX's RNG for reproducible sampling (default: unseeded)")
    return p.parse_args()


def _stop_token_ids(tokenizer) -> set[int]:
    """Return the set of token ids that should terminate generation.

    Always includes ``tokenizer.eos_token_id``. Some mlx_lm tokenizer wrappers
    also expose an ``eos_token_ids`` collection that may carry additional
    stop ids. Those are merged in when present, so a model whose real stop
    token differs from ``eos_token_id`` does not run on to ``--max-tokens``.
    """
    ids = {int(t) for t in (getattr(tokenizer, "eos_token_ids", None) or ())}
    if tokenizer.eos_token_id is not None:
        ids.add(int(tokenizer.eos_token_id))
    return ids


def main():
    """Run MOSS-Audio-Thinking inference on one audio clip and print the text.

    Pipeline:
    1. Detect the bundle variant from the saved audio-adapter weights.
    2. Load the MLX LLM and the MLX audio path.
    3. Load the clip, build the mel spectrogram and prompt ids, and run the
       audio pipeline.
    4. Splice the primary audio embeddings into the prompt's token embeddings
       at the audio-token positions.
    5. Install the deepstack hooks.
    6. Sample with ``generate_step`` until a stop token or ``--max-tokens``.

    Raises:
        ValueError: if the number of audio-token positions in the prompt does
            not match the number of audio embeddings the pipeline produced.
    """
    import gc

    args = _parse_args()
    if args.seed is not None:
        mx.random.seed(args.seed)

    # A quantized (INT4) adapter is recognized by its `down_proj.scales`
    # tensor. shape[0] of that tensor (or of the dense weight) is taken as the
    # LLM hidden size.
    ad_w = mx.load(str(HERE / "mlx_audio/audio_adapter.safetensors"))
    if "down_proj.scales" in ad_w:
        llm_hidden = ad_w["down_proj.scales"].shape[0]
        int4_audio = True
    else:
        llm_hidden = ad_w["down_proj.weight"].shape[0]
        int4_audio = False
    del ad_w  # only needed for detection; the audio path loads its own copy
    size_tag = "4B" if llm_hidden == 2560 else "8B"
    print(f"[detect] {size_tag} bundle, audio int4={int4_audio}")

    print(f"[load] LLM from {HERE / 'mlx_llm'}")
    t0 = time.perf_counter()
    mlx_model, mlx_tokenizer = mlx_load(str(HERE / "mlx_llm"))
    print(f"[load] LLM: {time.perf_counter()-t0:.1f}s")
    t0 = time.perf_counter()
    encoder, adapter, mergers = load_mlx_audio_path(HERE / "mlx_audio", int4=int4_audio)
    print(f"[load] audio path: {time.perf_counter()-t0:.1f}s")

    y, _ = librosa.load(args.audio, sr=_SAMPLE_RATE, mono=True)
    y = y.astype(np.float32)
    print(f"[audio] {args.audio} ({len(y)/_SAMPLE_RATE:.1f}s)")

    # Pure-MLX mel + input_ids (no torch).
    mel, lens, input_ids_mx, audio_token_id = build_mel_spectrogram(y, mlx_tokenizer)
    primary, ds_embeds = run_mlx_audio_pipeline(encoder, adapter, mergers, mel, lens)
    primary = primary.astype(mx.bfloat16)
    ds_embeds = [d.astype(mx.bfloat16) for d in ds_embeds]
    mx.eval(primary, *ds_embeds)
    # The audio-path objects are not used past this point; release them and
    # MLX's cache before generation.
    del encoder, adapter, mergers, mel, lens
    gc.collect()
    mx.clear_cache()

    audio_mask = input_ids_mx == audio_token_id
    audio_positions = np.where(np.array(audio_mask[0]))[0]
    primary_np = np.array(primary.astype(mx.float32))
    # Check explicitly. NumPy fancy assignment would silently broadcast a
    # single audio embedding across every slot, and any other mismatch
    # surfaces only as an opaque broadcast error.
    if len(audio_positions) != primary_np.shape[1]:
        raise ValueError(
            f"prompt has {len(audio_positions)} audio-token positions but the "
            f"audio pipeline produced {primary_np.shape[1]} embeddings"
        )
    text_embeds = mlx_model.model.embed_tokens(input_ids_mx)
    text_np = np.array(text_embeds.astype(mx.float32))
    text_np[0, audio_positions, :] = primary_np[0, :, :]
    merged = mx.array(text_np).astype(mx.bfloat16)

    ds_flat = [d[0] for d in ds_embeds]
    install_deepstack_hooks(mlx_model, ds_flat, audio_positions)

    sampler = make_sampler(temp=args.temp, top_p=args.top_p, top_k=args.top_k)
    # A falsy --repetition-penalty (e.g. 0) disables the penalty entirely.
    logits_processors = make_logits_processors(
        repetition_penalty=args.repetition_penalty,
        repetition_context_size=args.repetition_context_size,
    ) if args.repetition_penalty else None
    gen_kwargs = dict(
        prompt=input_ids_mx[0], model=mlx_model,
        input_embeddings=merged[0], max_tokens=args.max_tokens, sampler=sampler,
    )
    if logits_processors:
        gen_kwargs["logits_processors"] = logits_processors

    stop_ids = _stop_token_ids(mlx_tokenizer)
    t0 = time.perf_counter()
    generated = []
    hit_stop = False
    for tok, _ in generate_step(**gen_kwargs):
        tok = int(tok)
        generated.append(tok)
        if tok in stop_ids:
            hit_stop = True
            break
    elapsed = time.perf_counter() - t0
    rate = len(generated) / elapsed if elapsed > 0 else float("inf")
    print(f"[gen] {len(generated)} tokens in {elapsed:.2f}s ({rate:.1f} t/s)")
    # Decode without the trailing stop token so it does not leak into the text.
    text_ids = generated[:-1] if hit_stop else generated
    print(f"\n=== OUTPUT ===\n{mlx_tokenizer.decode(text_ids)}\n=== END ===")
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()