# mega-asr-nvfp4 / inference_bench.py
# Commit 1f1078e (verified) by Reza2kn: "NVFP4 AWQ-Lite Qwen3-ASR-1.7B (91.4% on VITW)"
"""End-to-end NVFP4 quantize + bench on the 8 VITW examples.
1. Load Qwen3 1.7B LLM (bf16).
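2. Build up to 168 calibration batches (7 VITW splits x --per-split 24), each an
   inputs_embeds tensor with the audio embeddings scattered into the audio-pad slots.
3. modelopt PTQ with the chosen config (default: NVFP4_AWQ_LITE_CFG).
4. Save to --out (default: models/nvfp4/mega-asr-llm/).
   Steps 2-4 are skipped with --skip-quant (bf16 baseline).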
5. Run the 8 VITW examples through (ONNX encoder + scattered inputs_embeds +
quantized LLM) and report WER + agreement.
"""
import argparse
import io
import os
import re
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, "Mega-ASR-src/src")
import onnxruntime as ort
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoFeatureExtractor
# Ground-truth transcripts for the eight VITW example clips. Each key is the
# file stem main() looks for as <examples-dir>/<key>.wav.
REFERENCES = dict(
    noise="I usually take the quieter road home because the main street gets crowded after work.",
    far_field="Please remind me to print the forms before we leave for the appointment tomorrow.",
    obstructed="I forgot my charger at home, so I need to find an outlet before the meeting starts.",
    distortion="The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
    recording="Can you check whether the train still stops at the downtown station after eight tonight?",
    echo="I need to return these shoes because the size feels fine standing up but terrible while walking.",
    dropout="My aunt is learning video calls, and she gets excited whenever the picture actually works.",
    mixed="My sister is bringing dinner over later, so we do not need to cook tonight.",
)
# Any character that is not a lowercase ASCII letter, digit or whitespace.
_NORM = re.compile(r"[^a-z0-9\s]")


def normalize(t):
    """Reduce a transcript to lowercase alphanumeric words joined by single spaces.

    If ``t`` contains the ``<asr_text>`` marker, only the text after its first
    occurrence is kept. Punctuation and other non-alphanumeric characters become
    word separators, so "It's" normalizes to "it s".
    """
    _, marker, tail = t.partition("<asr_text>")
    if marker:
        t = tail
    cleaned = _NORM.sub(" ", t.lower())
    return " ".join(cleaned.split())
def wer(ref, hyp):
    """Word error rate of ``hyp`` against ``ref`` (both whitespace-tokenized strings).

    Uses word-level Levenshtein distance with unit-cost insertion, deletion and
    substitution.

    Returns:
        (rate, edits, ref_words): ``rate`` is edits / ref_words as a plain float,
        ``edits`` is the integer edit distance, and ``ref_words`` is the number of
        reference words. An empty reference yields rate 1.0 if ``hyp`` is
        non-empty (else 0.0), with ``edits = len(hyp words)`` and ``ref_words = 0``.
    """
    r = ref.split()
    h = hyp.split()
    if not r:
        return (1.0 if h else 0.0, len(h), 0)
    # Two rolling DP rows of plain ints. This is O(len(h)) memory and avoids
    # per-element numpy scalar indexing inside the Python double loop.
    prev = list(range(len(h) + 1))
    for i, rw in enumerate(r, 1):
        cur = [i] + [0] * len(h)
        for j, hw in enumerate(h, 1):
            cur[j] = min(
                prev[j] + 1,                   # deletion
                cur[j - 1] + 1,                # insertion
                prev[j - 1] + (rw != hw),      # substitution (0 on match)
            )
        prev = cur
    edits = int(prev[-1])
    return edits / len(r), edits, len(r)
def col(p, s):
    """Wrap ``s`` in an ANSI color escape chosen by the percentage ``p``.

    p >= 70 -> bright green (92), p >= 50 -> bright yellow (93),
    p >= 25 -> yellow (33), anything lower (or NaN) -> bright red (91).
    """
    for threshold, code in ((70, "92"), (50, "93"), (25, "33")):
        if p >= threshold:
            break
    else:
        code = "91"
    return f"\033[{code}m{s}\033[0m"
def is_english(s):
    """Heuristic English check for calibration transcripts.

    True when ``s`` contains at least one ASCII Latin letter and no character
    from U+4E00-U+9FFF (CJK ideographs), U+3040-U+30FF (hiragana/katakana) or
    U+AC00-U+D7AF (Hangul syllables). Empty or falsy input is never English.
    """
    if not s:
        return False
    if re.search(r"[a-zA-Z]", s) is None:
        return False
    return re.search(r"[一-鿿぀-ヿ가-힯]", s) is None
_MEL_FRAMES = 3000      # fixed mel length fed to the ONNX encoder (longer input is truncated)
_CHUNK_FRAMES = 100     # mel frames per chunk in the output-length formula
_FRAMES_PER_EMBED = 8   # each chunk yields ceil(frames / 8) embedding rows


def _audio_embed_count(t_mel):
    """Number of valid encoder output rows for ``t_mel`` real (unpadded) mel frames.

    Every 100-frame chunk contributes ceil(frames / 8) rows (13 for a full
    chunk). This presumably mirrors the encoder's per-chunk downsampling
    (TODO: confirm against the ONNX export). Returns 0 for ``t_mel == 0``.
    """
    n_chunks = (t_mel + _CHUNK_FRAMES - 1) // _CHUNK_FRAMES
    last_chunk = t_mel - (n_chunks - 1) * _CHUNK_FRAMES
    per_full_chunk = (_CHUNK_FRAMES + _FRAMES_PER_EMBED - 1) // _FRAMES_PER_EMBED
    return (n_chunks - 1) * per_full_chunk + (last_chunk + _FRAMES_PER_EMBED - 1) // _FRAMES_PER_EMBED


def encode_audio(enc, audio, sr, feat, *, device="cuda"):
    """Run one clip through the ONNX audio encoder and return its embeddings.

    Args:
        enc: onnxruntime-style session; ``enc.run(["audio_embeds"], {"mel": mel})``
            must return a list whose first item is the embedding array.
        audio: waveform array. 2-D input is averaged over axis 1 to mono
            (``soundfile.read`` returns ``(frames, channels)``).
        sr: sample rate of ``audio``. Resampled to 16 kHz with librosa if different.
        feat: feature extractor returning a dict with ``"input_features"``.
        device: torch device for the returned tensor (default ``"cuda"``).

    Returns:
        bfloat16 tensor of the first ``_audio_embed_count(T_mel)`` rows of the
        encoder output for batch item 0, where ``T_mel`` is the mel length
        after truncation to 3000 frames.
    """
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != 16000:
        import librosa  # only needed for non-16 kHz input
        audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000)
    mel = feat(audio, sampling_rate=16000, return_tensors="np",
               return_attention_mask=False)["input_features"]
    # Truncate to, then zero-pad up to, the fixed length the encoder always receives.
    t_mel = min(mel.shape[-1], _MEL_FRAMES)
    mel = mel[..., :t_mel]
    mel = np.pad(mel, ((0, 0), (0, 0), (0, _MEL_FRAMES - t_mel)),
                 constant_values=0).astype(np.float32)
    ae = enc.run(["audio_embeds"], {"mel": mel})[0]
    # Keep only the rows derived from real frames, dropping those from the padding.
    valid = ae[0, :_audio_embed_count(t_mel)]
    return torch.from_numpy(valid).to(device).to(torch.bfloat16)
def build_prompt_ids(tok, n_audio, AUDIO_PAD=151676):
    """Tokenize the fixed English-ASR chat prompt and expand its audio placeholder.

    The template holds one ``<|audio_pad|>`` token. The first occurrence of
    ``AUDIO_PAD`` in the encoded ids is replaced by ``n_audio`` copies, giving
    one slot per audio-embedding row.

    Raises:
        ValueError: if the encoded prompt contains no ``AUDIO_PAD`` id.
    """
    template = "".join((
        "<|im_start|>system\n<|im_end|>\n",
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n",
        "<|im_start|>assistant\n",
        "language English<asr_text>",
    ))
    encoded = tok.encode(template, add_special_tokens=False)
    slot = encoded.index(AUDIO_PAD)
    expanded = encoded[:slot]
    expanded.extend(AUDIO_PAD for _ in range(n_audio))
    expanded.extend(encoded[slot + 1:])
    return expanded
def _load_hf_token(env_path=Path(".env")):
    """Export ``HF_TOKEN`` from a dotenv-style file into ``os.environ``, if present.

    A missing file is not fatal. ``--skip-quant`` runs never stream the
    calibration dataset, and local model paths presumably need no token.
    Only a line whose key is exactly ``HF_TOKEN`` is used. Surrounding quotes
    are stripped from the value.
    """
    if not env_path.is_file():
        print(f"{env_path} not found; continuing without loading HF_TOKEN from it")
        return
    for line in env_path.read_text().splitlines():
        key, sep, value = line.partition("=")
        if sep and key.strip() == "HF_TOKEN":
            os.environ["HF_TOKEN"] = value.strip().strip('"').strip("'")
            return


def _embed_prompt(embed_w, ids, ae, audio_pad):
    """Token-embed ``ids`` and overwrite each ``audio_pad`` slot with a row of ``ae``, in order.

    ``ids`` must contain exactly ``ae.shape[0]`` occurrences of ``audio_pad``,
    as produced by ``build_prompt_ids(tok, ae.shape[0])``.
    Returns a ``(1, len(ids), hidden)`` tensor.
    """
    ids_t = torch.tensor(ids, dtype=torch.long, device="cuda")
    # Tensor (advanced) indexing already returns a copy, so the in-place
    # scatter below never writes into the embedding table itself.
    te = embed_w[ids_t]
    te[ids_t == audio_pad] = ae
    return te.unsqueeze(0)


def _greedy_decode(model, embed_w, inputs_embeds, eos_id, max_new_tokens):
    """Greedy KV-cached decode from prebuilt ``inputs_embeds`` (batch size 1).

    Runs one prefill pass, then feeds one token embedding per step with explicit
    position ids. Stops at ``eos_id`` or after ``max_new_tokens`` tokens (at least
    one token is always produced). Returns the generated ids without a trailing EOS.
    """
    L = inputs_embeds.shape[1]
    attn = torch.ones((1, L), dtype=torch.long, device="cuda")
    pos = torch.arange(L, device="cuda").unsqueeze(0)
    with torch.no_grad():
        out = model(inputs_embeds=inputs_embeds, attention_mask=attn,
                    position_ids=pos, use_cache=True, return_dict=True)
        past = out.past_key_values
        nid = int(out.logits[0, -1, :].float().argmax().item())
        gen = [nid]
        cur = L
        for _ in range(max_new_tokens - 1):
            if nid == eos_id:
                break
            tok_e = embed_w[nid:nid + 1].unsqueeze(0)  # (1, 1, hidden)
            attn = torch.cat([attn, torch.ones((1, 1), dtype=torch.long, device="cuda")], dim=1)
            pos = torch.tensor([[cur]], device="cuda")
            out = model(inputs_embeds=tok_e, attention_mask=attn, position_ids=pos,
                        past_key_values=past, use_cache=True, return_dict=True)
            past = out.past_key_values
            nid = int(out.logits[0, -1, :].float().argmax().item())
            gen.append(nid)
            cur += 1
    if gen and gen[-1] == eos_id:
        gen = gen[:-1]
    return gen


def _collect_calibration(args, tok, feat, enc, embed_w, audio_pad):
    """Stream English VITW clips and build up to ``args.per_split`` calibration batches per split.

    Best-effort: a split that fails to load, or an example that fails to
    decode or encode, is reported and skipped. Prompts longer than
    ``args.max_seq`` tokens are dropped.
    """
    from datasets import load_dataset, Audio
    configs = ["noise", "far_field", "obstructed", "distortion", "recording", "echo", "dropout"]
    calib_batches = []
    for cfg in configs:
        print(f"streaming {cfg} ...", flush=True)
        try:
            ds = load_dataset("zhifeixie/Voices-in-the-Wild-2M", split=cfg, streaming=True)
            ds = ds.cast_column("audio", Audio(decode=False))
        except Exception as e:
            print(f" skip: {e}")
            continue
        n = 0
        for ex in ds:
            a = ex.get("audio")
            if not a or not a.get("bytes"):
                continue
            text = ex.get("text", "") or ex.get("answer", "")
            if not is_english(text):
                continue
            try:
                arr, sr = sf.read(io.BytesIO(a["bytes"]))
                ae = encode_audio(enc, arr, sr, feat)
                ids = build_prompt_ids(tok, ae.shape[0], AUDIO_PAD=audio_pad)
                if len(ids) > args.max_seq:
                    continue
                calib_batches.append({
                    "inputs_embeds": _embed_prompt(embed_w, ids, ae, audio_pad),
                    "attention_mask": torch.ones((1, len(ids)), dtype=torch.long, device="cuda"),
                    "position_ids": torch.arange(len(ids), device="cuda").unsqueeze(0),
                })
                n += 1
            except Exception as e:
                print(f" skip ex: {e}")
                continue
            if n >= args.per_split:
                break
        print(f" collected {n} from {cfg}", flush=True)
    return calib_batches


def _quantize_and_save(model, tok, calib_batches, args):
    """Apply modelopt PTQ config ``args.cfg`` in place using ``calib_batches``, then save to ``args.out``."""
    import modelopt.torch.quantization as mtq
    quant_cfg = getattr(mtq, args.cfg, None)
    if quant_cfg is None:
        sys.exit(f"Unknown modelopt quantization config: {args.cfg!r}")
    print(f"\nQuantizing with {args.cfg} ...")

    def forward_loop(m):
        for i, b in enumerate(calib_batches):
            with torch.no_grad():
                m(**b)
            if (i + 1) % 20 == 0:
                print(f" calib {i + 1}/{len(calib_batches)}", flush=True)

    mtq.quantize(model, quant_cfg, forward_loop)
    print("Quantization done.")
    args.out.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(args.out), safe_serialization=True)
    tok.save_pretrained(str(args.out))
    print(f"Saved → {args.out}")


def _bench(model, tok, feat, enc, embed_w, args, audio_pad, eos_id):
    """Transcribe each ``<examples-dir>/<name>.wav`` for names in REFERENCES and print WER.

    Prints one line per clip plus a summary with mean agreement (1 - WER,
    averaged over clips) and corpus WER (total edits / total reference words).
    Missing wavs are skipped. With no clips at all, no summary is printed.
    """
    print(f"\n=== Bench on {args.examples_dir} ===")
    total_wer = 0.0
    total_edits = 0
    total_words = 0
    n = 0
    for name in sorted(REFERENCES):
        wav_path = args.examples_dir / f"{name}.wav"
        if not wav_path.exists():
            print(f"skip {name} (no wav)")
            continue
        audio, sr = sf.read(str(wav_path))
        ae = encode_audio(enc, audio, sr, feat)
        ids = build_prompt_ids(tok, ae.shape[0], AUDIO_PAD=audio_pad)
        inputs_embeds = _embed_prompt(embed_w, ids, ae, audio_pad)
        gen = _greedy_decode(model, embed_w, inputs_embeds, eos_id, args.max_new_tokens)
        hyp_text = tok.decode(gen, skip_special_tokens=True)
        ref = normalize(REFERENCES[name])
        hyp = normalize(hyp_text)
        w, ed, words = wer(ref, hyp)
        agree = max(0.0, 1.0 - w) * 100
        total_wer += w
        total_edits += ed
        total_words += words
        n += 1
        print(f"\n[{col(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={col(agree, f'{agree:5.1f}%')}")
        print(f" REF: {ref}")
        print(f" HYP: {hyp}")
    if n == 0:
        # Previously this fell through to total_edits / total_words and crashed.
        print("\nNo examples were benchmarked.")
        return
    avg = (1 - total_wer / n) * 100
    corpus_wer = total_edits / total_words * 100 if total_words else 0.0
    print(f"\n{col(avg, f'=== AVERAGE: agreement {avg:.1f}% WER {corpus_wer:.1f}% ({total_edits}/{total_words}) ===')}")


def main():
    """CLI entry point: load the LLM, optionally PTQ-quantize and save it, then bench on VITW clips."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="models/mega-asr-llm-qwen3", type=Path)
    ap.add_argument("--encoder", default="models/mega-asr-export/audio_encoder_fp32.onnx", type=Path)
    ap.add_argument("--qwen-asr-dir", default="models/mega-asr/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--per-split", type=int, default=24)
    ap.add_argument("--max-seq", type=int, default=512)
    ap.add_argument("--max-new-tokens", type=int, default=80)
    ap.add_argument("--cfg", default="NVFP4_AWQ_LITE_CFG")
    ap.add_argument("--examples-dir", default="vitw-examples", type=Path)
    ap.add_argument("--skip-quant", action="store_true",
                    help="Bench the bf16 model without quantization (sanity baseline)")
    ap.add_argument("--out", default="models/nvfp4/mega-asr-llm", type=Path)
    args = ap.parse_args()

    _load_hf_token(Path(".env"))

    print(f"Loading LLM from {args.model} ...")
    model = AutoModelForCausalLM.from_pretrained(
        str(args.model), torch_dtype=torch.bfloat16, device_map="cuda",
    )
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
    feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
    embed_w = model.model.embed_tokens.weight
    enc = ort.InferenceSession(str(args.encoder),
                               providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    AUDIO_PAD = 151676
    EOS = 151645

    if not args.skip_quant:
        calib_batches = _collect_calibration(args, tok, feat, enc, embed_w, AUDIO_PAD)
        print(f"Total: {len(calib_batches)} calibration batches")
        if not calib_batches:
            sys.exit("No calibration batches collected; refusing to quantize without calibration data.")
        _quantize_and_save(model, tok, calib_batches, args)
        # Release the GPU-resident calibration tensors before benchmarking.
        del calib_batches

    _bench(model, tok, feat, enc, embed_w, args, AUDIO_PAD, EOS)


if __name__ == "__main__":
    main()