| """End-to-end NVFP4 quantize + bench on the 8 VITW examples. |
| |
| 1. Load Qwen3 1.7B LLM (bf16). |
| 2. Build 168 calibration batches (audio-scattered inputs_embeds). |
| 3. modelopt PTQ with the chosen config (default: NVFP4_AWQ_LITE_CFG). |
| 4. Save to models/nvfp4/. |
| 5. Run the 8 VITW examples through (ONNX encoder + scattered inputs_embeds + |
| quantized LLM) and report WER + agreement. |
| """ |
| import argparse |
| import io |
| import os |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
| sys.path.insert(0, "Mega-ASR-src/src") |
| import onnxruntime as ort |
| import soundfile as sf |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoFeatureExtractor |
|
|
|
|
# Reference transcripts for the bundled example clips. Each key is the stem
# of the ``<key>.wav`` file that main() looks up under --examples-dir.
REFERENCES = dict(
    noise=(
        "I usually take the quieter road home because the main street gets crowded after work."
    ),
    far_field=(
        "Please remind me to print the forms before we leave for the appointment tomorrow."
    ),
    obstructed=(
        "I forgot my charger at home, so I need to find an outlet before the meeting starts."
    ),
    distortion=(
        "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored."
    ),
    recording=(
        "Can you check whether the train still stops at the downtown station after eight tonight?"
    ),
    echo=(
        "I need to return these shoes because the size feels fine standing up but terrible while walking."
    ),
    dropout=(
        "My aunt is learning video calls, and she gets excited whenever the picture actually works."
    ),
    mixed=(
        "My sister is bringing dinner over later, so we do not need to cook tonight."
    ),
)

# Matches any character that is not a lowercase ASCII letter, a digit or
# whitespace; normalize() replaces each match with a space.
_NORM = re.compile(r"[^a-z0-9\s]")
|
|
|
|
def normalize(t):
    """Canonicalise a transcript string before WER scoring.

    If the text contains an ``<asr_text>`` marker, everything up to and
    including its first occurrence is discarded. The remainder is
    lower-cased, every character other than ``a-z``, ``0-9`` or whitespace
    is replaced by a space, runs of whitespace are collapsed to a single
    space, and leading/trailing whitespace is removed.
    """
    _, marker, tail = t.partition("<asr_text>")
    if marker:
        # Marker found: keep only what follows it.
        t = tail
    lowered = t.lower()
    alnum_only = _NORM.sub(" ", lowered)
    single_spaced = re.sub(r"\s+", " ", alnum_only)
    return single_spaced.strip()
|
|
|
|
def wer(ref, hyp):
    """Compute the word error rate between two whitespace-tokenised strings.

    Args:
        ref: Reference transcript.
        hyp: Hypothesis transcript.

    Returns:
        A tuple ``(rate, edits, ref_words)``:

        * ``edits`` is the word-level Levenshtein distance (insertions,
          deletions and substitutions all cost 1),
        * ``ref_words`` is the number of words in ``ref``,
        * ``rate`` is ``edits / ref_words`` as a plain Python ``float``.

        When ``ref`` has no words, ``rate`` is ``1.0`` if ``hyp`` has any
        words and ``0.0`` otherwise, ``edits`` is the number of hypothesis
        words and ``ref_words`` is ``0``.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()
    if not ref_words:
        return (1.0 if hyp_words else 0.0, len(hyp_words), 0)

    # Rolling-row edit distance: ``prev[j]`` is the distance between the
    # reference prefix processed so far and the first ``j`` hypothesis words.
    # Plain lists keep every value a Python int, so the returned rate is a
    # Python float on every path.
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        cur = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, substitution))
        prev = cur

    edits = prev[-1]
    return edits / len(ref_words), edits, len(ref_words)
|
|
|
|
def col(p, s):
    """Wrap ``s`` in an ANSI colour escape selected by the percentage ``p``.

    The SGR code is ``92`` for ``p >= 70``, ``93`` for ``p >= 50``, ``33``
    for ``p >= 25`` and ``91`` for anything lower. The returned string
    always ends with the reset sequence ``\\033[0m``.
    """
    # Thresholds are checked from highest to lowest; the first hit wins.
    bands = ((70, "92"), (50, "93"), (25, "33"))
    code = next((sgr for floor, sgr in bands if p >= floor), "91")
    return f"\033[{code}m{s}\033[0m"
|
|
|
|
_LATIN_LETTER_RE = re.compile(r"[a-zA-Z]")
# Written with explicit \u escapes so the ranges cannot be corrupted by an
# editor or encoding round-trip. The previous literal-character class had
# lost its range endpoints and ended up containing a bare "-", which made
# every hyphenated English sentence look non-English.
#   U+4E00-U+9FFF  CJK unified ideographs
#   U+3040-U+30FF  Hiragana and Katakana
#   U+AC00-U+D7AF  Hangul syllables
_CJK_CHAR_RE = re.compile(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]")


def is_english(s):
    """Heuristically decide whether ``s`` is English text.

    Args:
        s: Candidate transcript; ``None`` and ``""`` are accepted.

    Returns:
        ``True`` when ``s`` contains at least one ASCII letter and no
        Chinese, Japanese (kana) or Korean (Hangul) character; ``False``
        otherwise, including for empty input.
    """
    if not s:
        return False
    if _LATIN_LETTER_RE.search(s) is None:
        return False
    return _CJK_CHAR_RE.search(s) is None
|
|
|
|
def encode_audio(enc, audio, sr, feat):
    """Encode a waveform into audio embeddings placed on the GPU.

    Args:
        enc: Session object exposing ``run(output_names, feeds)``; it is fed
            a ``"mel"`` input and asked for the ``"audio_embeds"`` output.
        audio: NumPy waveform. Arrays with more than one dimension are
            averaged over axis 1.
        sr: Sample rate of ``audio``; anything other than 16000 is resampled
            with ``librosa``.
        feat: Callable feature extractor whose result provides
            ``"input_features"``.

    Returns:
        A ``torch.bfloat16`` tensor on ``"cuda"`` holding the leading rows of
        the first batch item of the encoder output.
    """
    target_sr = 16000
    max_frames = 3000

    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != target_sr:
        import librosa
        audio = librosa.resample(
            audio.astype(np.float32), orig_sr=sr, target_sr=target_sr
        )

    features = feat(
        audio,
        sampling_rate=target_sr,
        return_tensors="np",
        return_attention_mask=False,
    )
    mel = features["input_features"]

    # Crop and zero-pad the last axis to exactly ``max_frames`` frames, while
    # remembering how many of them came from the input.
    n_frames = min(mel.shape[-1], max_frames)
    mel = mel[..., :n_frames]
    pad_width = ((0, 0), (0, 0), (0, max_frames - n_frames))
    mel = np.pad(mel, pad_width, constant_values=0).astype(np.float32)

    audio_embeds = enc.run(["audio_embeds"], {"mel": mel})[0]

    # Number of embedding rows to keep: 13 per complete 100-frame chunk plus
    # ceil(frames / 8) for the last chunk.
    # NOTE(review): presumably mirrors the encoder's chunked downsampling;
    # confirm against the encoder export.
    n_chunks = (n_frames + 99) // 100
    last_chunk_frames = n_frames - (n_chunks - 1) * 100
    n_valid = (n_chunks - 1) * 13 + (last_chunk_frames + 7) // 8

    kept = torch.from_numpy(audio_embeds[0, :n_valid])
    return kept.to("cuda").to(torch.bfloat16)
|
|
|
|
def build_prompt_ids(tok, n_audio, AUDIO_PAD=151676):
    """Tokenise the fixed ASR chat prompt and widen its audio placeholder.

    Args:
        tok: Tokenizer exposing ``encode(text, add_special_tokens=...)``.
        n_audio: Number of ``AUDIO_PAD`` ids to emit in place of the first
            ``AUDIO_PAD`` id in the encoded prompt.
        AUDIO_PAD: Token id of the audio placeholder.

    Returns:
        The prompt's token ids with the first ``AUDIO_PAD`` occurrence
        replaced by ``n_audio`` copies of it.

    Raises:
        ValueError: If the encoded prompt contains no ``AUDIO_PAD`` id.
    """
    prompt = (
        "<|im_start|>system\n<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
        "<|im_start|>assistant\n"
        "language English<asr_text>"
    )
    token_ids = tok.encode(prompt, add_special_tokens=False)
    slot = token_ids.index(AUDIO_PAD)

    expanded = list(token_ids[:slot])
    expanded.extend([AUDIO_PAD] * n_audio)
    expanded.extend(token_ids[slot + 1:])
    return expanded
|
|
|
|
def _load_hf_token(env_path=Path(".env")):
    """Copy ``HF_TOKEN`` from a dotenv-style file into ``os.environ``.

    A missing file is ignored. Only the first line whose key is exactly
    ``HF_TOKEN`` is used; surrounding whitespace and quotes are stripped
    from the value.
    """
    if not env_path.is_file():
        return
    for line in env_path.read_text().splitlines():
        key, sep, value = line.partition("=")
        if sep and key.strip() == "HF_TOKEN":
            os.environ["HF_TOKEN"] = value.strip().strip('"').strip("'")
            break


def _model_inputs(embed_w, ids, audio_embeds, audio_pad):
    """Build LLM keyword inputs with audio embeddings scattered into the prompt.

    Token embeddings are looked up for ``ids`` and every row whose id equals
    ``audio_pad`` is overwritten with the corresponding row of
    ``audio_embeds``. All tensors are created on ``"cuda"`` with batch size 1.
    """
    ids_t = torch.tensor(ids, dtype=torch.long, device="cuda")
    embeds = embed_w[ids_t].clone()
    embeds[ids_t == audio_pad] = audio_embeds
    seq_len = len(ids)
    return {
        "inputs_embeds": embeds.unsqueeze(0),
        "attention_mask": torch.ones((1, seq_len), dtype=torch.long, device="cuda"),
        "position_ids": torch.arange(seq_len, device="cuda").unsqueeze(0),
    }


def _collect_calibration_batches(enc, feat, tok, embed_w, per_split, max_seq, audio_pad):
    """Stream English examples from each dataset split and turn them into LLM inputs.

    Splits or examples that fail to load are reported and skipped (best
    effort). Prompts longer than ``max_seq`` tokens are dropped. Collection
    for a split stops once ``per_split`` examples have been gathered.
    """
    from datasets import Audio, load_dataset

    splits = ["noise", "far_field", "obstructed", "distortion", "recording", "echo", "dropout"]
    batches = []
    for split in splits:
        print(f"streaming {split} ...", flush=True)
        try:
            ds = load_dataset("zhifeixie/Voices-in-the-Wild-2M", split=split, streaming=True)
            # Keep the raw bytes; decoding is done below with soundfile.
            ds = ds.cast_column("audio", Audio(decode=False))
        except Exception as e:
            print(f" skip: {e}")
            continue
        n = 0
        for ex in ds:
            a = ex.get("audio")
            if not a or not a.get("bytes"):
                continue
            text = ex.get("text", "") or ex.get("answer", "")
            if not is_english(text):
                continue
            try:
                arr, sr = sf.read(io.BytesIO(a["bytes"]))
                ae = encode_audio(enc, arr, sr, feat)
                ids = build_prompt_ids(tok, ae.shape[0], audio_pad)
                if len(ids) > max_seq:
                    continue
                batches.append(_model_inputs(embed_w, ids, ae, audio_pad))
                n += 1
            except Exception as e:
                print(f" skip ex: {e}")
                continue
            if n >= per_split:
                break
        print(f" collected {n} from {split}", flush=True)
    return batches


def _greedy_decode(model, embed_w, inputs, max_new_tokens, eos):
    """Greedy-decode from prefilled ``inputs`` using the model's KV cache.

    At least one token is always produced from the prefill pass. Decoding
    stops at ``eos`` or after ``max_new_tokens`` tokens; a trailing ``eos``
    is removed from the returned id list.
    """
    attn = inputs["attention_mask"]
    cur = attn.shape[1]  # position index of the next token to feed
    with torch.no_grad():
        out = model(**inputs, use_cache=True, return_dict=True)
        past = out.past_key_values
        nid = int(out.logits[0, -1, :].float().argmax().item())
        gen = [nid]
        for _ in range(max_new_tokens - 1):
            if nid == eos:
                break
            tok_e = embed_w[nid:nid + 1].unsqueeze(0)
            attn = torch.cat(
                [attn, torch.ones((1, 1), dtype=torch.long, device="cuda")], dim=1
            )
            pos = torch.tensor([[cur]], device="cuda")
            out = model(inputs_embeds=tok_e, attention_mask=attn,
                        position_ids=pos, past_key_values=past,
                        use_cache=True, return_dict=True)
            past = out.past_key_values
            nid = int(out.logits[0, -1, :].float().argmax().item())
            gen.append(nid)
            cur += 1
    if gen[-1] == eos:
        gen = gen[:-1]
    return gen


def main():
    """Quantize the LLM (unless ``--skip-quant``) and benchmark it on the example wavs.

    Steps: parse CLI arguments, load ``HF_TOKEN`` from ``.env`` if present,
    load the bf16 LLM, tokenizer, feature extractor and ONNX audio encoder,
    optionally calibrate + quantize with the selected modelopt config and
    save the result to ``--out``, then greedy-decode every reference example
    found under ``--examples-dir`` and print per-example and aggregate WER.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="models/mega-asr-llm-qwen3", type=Path)
    ap.add_argument("--encoder", default="models/mega-asr-export/audio_encoder_fp32.onnx", type=Path)
    ap.add_argument("--qwen-asr-dir", default="models/mega-asr/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--per-split", type=int, default=24)
    ap.add_argument("--max-seq", type=int, default=512)
    ap.add_argument("--max-new-tokens", type=int, default=80)
    ap.add_argument("--cfg", default="NVFP4_AWQ_LITE_CFG")
    ap.add_argument("--examples-dir", default="vitw-examples", type=Path)
    ap.add_argument("--skip-quant", action="store_true",
                    help="Bench the bf16 model without quantization (sanity baseline)")
    ap.add_argument("--out", default="models/nvfp4/mega-asr-llm", type=Path)
    args = ap.parse_args()

    _load_hf_token()

    print(f"Loading LLM from {args.model} ...")
    model = AutoModelForCausalLM.from_pretrained(
        str(args.model), torch_dtype=torch.bfloat16, device_map="cuda",
    )
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
    feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
    embed_w = model.model.embed_tokens.weight
    enc = ort.InferenceSession(str(args.encoder),
                               providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    AUDIO_PAD = 151676
    EOS = 151645

    if not args.skip_quant:
        import modelopt.torch.quantization as mtq

        # Resolve the config before streaming any data so a mistyped --cfg
        # fails immediately instead of after calibration collection.
        quant_cfg = getattr(mtq, args.cfg, None)
        if quant_cfg is None:
            ap.error(f"unknown modelopt quantization config: {args.cfg}")

        calib_batches = _collect_calibration_batches(
            enc, feat, tok, embed_w, args.per_split, args.max_seq, AUDIO_PAD,
        )
        print(f"Total: {len(calib_batches)} calibration batches")
        if not calib_batches:
            print("WARNING: no calibration batches were collected; "
                  "the quantizer will run without calibration data.",
                  file=sys.stderr)

        print(f"\nQuantizing with {args.cfg} ...")

        def forward_loop(m):
            for i, b in enumerate(calib_batches):
                with torch.no_grad():
                    m(**b)
                if (i + 1) % 20 == 0:
                    print(f" calib {i + 1}/{len(calib_batches)}", flush=True)

        mtq.quantize(model, quant_cfg, forward_loop)
        print("Quantization done.")
        args.out.mkdir(parents=True, exist_ok=True)
        # NOTE(review): ModelOpt documents
        # modelopt.torch.opt.enable_huggingface_checkpointing() for persisting
        # quantizer state through save_pretrained; confirm this plain save
        # produces a checkpoint that can be restored as quantized.
        model.save_pretrained(str(args.out), safe_serialization=True)
        tok.save_pretrained(str(args.out))
        print(f"Saved → {args.out}")

    print(f"\n=== Bench on {args.examples_dir} ===")
    total_wer = 0.0
    total_edits = 0
    total_words = 0
    n = 0
    for name in sorted(REFERENCES):
        wav_path = args.examples_dir / f"{name}.wav"
        if not wav_path.exists():
            print(f"skip {name} (no wav)")
            continue
        audio, sr = sf.read(str(wav_path))
        ae = encode_audio(enc, audio, sr, feat)
        ids = build_prompt_ids(tok, ae.shape[0], AUDIO_PAD)
        inputs = _model_inputs(embed_w, ids, ae, AUDIO_PAD)

        gen = _greedy_decode(model, embed_w, inputs, args.max_new_tokens, EOS)
        hyp_text = tok.decode(gen, skip_special_tokens=True)

        ref = normalize(REFERENCES[name])
        hyp = normalize(hyp_text)
        w, ed, words = wer(ref, hyp)
        agree = max(0.0, 1.0 - w) * 100
        total_wer += w
        total_edits += ed
        total_words += words
        n += 1
        print(f"\n[{col(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={col(agree, f'{agree:5.1f}%')}")
        print(f" REF: {ref}")
        print(f" HYP: {hyp}")

    if n == 0:
        # Without this guard the summary below would divide by zero.
        print(f"\nNo example wavs found under {args.examples_dir}; nothing to average.")
        return

    # NOTE(review): per-example agreement is clamped at 0 but this mean uses
    # the unclamped WER, so it can go negative when a WER exceeds 100%.
    avg = (1 - total_wer / n) * 100
    corpus_wer = total_edits / total_words * 100 if total_words else 0.0
    summary = f"=== AVERAGE: agreement {avg:.1f}% WER {corpus_wer:.1f}% ({total_edits}/{total_words}) ==="
    print(f"\n{col(avg, summary)}")


if __name__ == "__main__":
    main()
|
|