"""End-to-end NVFP4 quantize + bench on the 8 VITW examples.

1. Load the Qwen3 1.7B LLM (bf16).
2. Build up to 168 calibration batches (7 VITW splits x --per-split clips,
   audio-scattered inputs_embeds).
3. modelopt PTQ with the chosen config (default: NVFP4_AWQ_LITE_CFG).
4. Save to models/nvfp4/.
5. Run the 8 VITW examples through (ONNX encoder + scattered inputs_embeds +
   quantized LLM) and report WER + agreement.

Heavy runtime dependencies (onnxruntime, soundfile, transformers, datasets,
modelopt, librosa) are imported inside the functions that need them, so the
scoring helpers (normalize / wer / is_english / col) can be imported on their own.
"""
import argparse
import io
import os
import re
import sys
from pathlib import Path

import numpy as np
import torch

sys.path.insert(0, "Mega-ASR-src/src")

SAMPLE_RATE = 16000  # audio is resampled to this rate before feature extraction
ENCODER_MEL_FRAMES = 3000  # fixed mel length fed to the ONNX encoder (truncate / zero-pad)
AUDIO_PAD_ID = 151676  # id of the <|audio_pad|> token whose slots receive audio embeddings
EOS_ID = 151645  # greedy decoding stops at this token id
# Marker after which normalize() keeps the transcript text.
# NOTE(review): assumed tag — the model is prompted up to "language English" and is
# presumed to emit the transcript after this marker; confirm against real outputs.
ASR_TEXT_MARKER = "<asr_text>"

CALIB_DATASET = "zhifeixie/Voices-in-the-Wild-2M"
CALIB_SPLITS = ("noise", "far_field", "obstructed", "distortion", "recording", "echo", "dropout")

REFERENCES = {
    "noise": "I usually take the quieter road home because the main street gets crowded after work.",
    "far_field": "Please remind me to print the forms before we leave for the appointment tomorrow.",
    "obstructed": "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
    "distortion": "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
    "recording": "Can you check whether the train still stops at the downtown station after eight tonight?",
    "echo": "I need to return these shoes because the size feels fine standing up but terrible while walking.",
    "dropout": "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
    "mixed": "My sister is bringing dinner over later, so we do not need to cook tonight.",
}

_NORM = re.compile(r"[^a-z0-9\s]")
_WS = re.compile(r"\s+")
_LATIN = re.compile(r"[a-zA-Z]")
# CJK unified ideographs, Hiragana/Katakana, Hangul syllables.
_CJK = re.compile(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]")
# (minimum percentage, ANSI colour code); anything below the last band is red.
_COLOR_BANDS = ((70, "92"), (50, "93"), (25, "33"))


def normalize(t, marker=ASR_TEXT_MARKER):
    """Lower-case *t*, replace non-alphanumerics with spaces and collapse whitespace.

    If a non-empty *marker* occurs in *t*, only the text after its first occurrence
    is kept. An empty marker disables this step (splitting on "" would raise).
    """
    if marker and marker in t:
        t = t.split(marker, 1)[1]
    return _WS.sub(" ", _NORM.sub(" ", t.lower())).strip()


def wer(ref, hyp):
    """Word error rate of *hyp* against *ref* (both split on whitespace).

    Returns ``(wer, edit_distance, n_ref_words)``. With an empty reference the
    WER is 1.0 if the hypothesis has any words, else 0.0.
    """
    r = ref.split()
    h = hyp.split()
    if not r:
        return (1.0 if h else 0.0, len(h), 0)
    # Word-level Levenshtein distance, keeping only the previous DP row.
    prev = list(range(len(h) + 1))
    for i, rw in enumerate(r, 1):
        cur = [i] + [0] * len(h)
        for j, hw in enumerate(h, 1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (rw != hw))
        prev = cur
    edits = prev[-1]
    return edits / len(r), edits, len(r)


def col(p, s):
    """Wrap *s* in an ANSI colour chosen by percentage *p* (green >=70 ... red <25)."""
    code = next((c for threshold, c in _COLOR_BANDS if p >= threshold), "91")
    return f"\033[{code}m{s}\033[0m"


def is_english(s):
    """True if *s* contains a Latin letter and no CJK / kana / Hangul characters."""
    if not s:
        return False
    return bool(_LATIN.search(s)) and not _CJK.search(s)


def encode_audio(enc, audio, sr, feat, device="cuda"):
    """Run *audio* through the ONNX audio encoder and return its valid output frames.

    Args:
        enc: onnxruntime session with input "mel" and output "audio_embeds".
        audio: waveform array; a multi-dimensional array is averaged over axis 1.
        sr: sample rate of *audio*; resampled to 16 kHz (librosa) if different.
        feat: feature extractor whose result has an "input_features" mel array.
        device: torch device of the returned tensor.

    Returns:
        bf16 tensor holding the first ``n_frames`` encoder frames, where
        ``n_frames`` is derived from the un-padded mel length below.
    """
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != SAMPLE_RATE:
        import librosa
        audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=SAMPLE_RATE)
    mel = feat(audio, sampling_rate=SAMPLE_RATE, return_tensors="np",
               return_attention_mask=False)["input_features"]
    # The encoder takes a fixed-length mel: truncate long clips, zero-pad short ones.
    # NOTE(review): if the feature extractor already pads to 30 s, t_mel is always
    # ENCODER_MEL_FRAMES and no output frames are trimmed — confirm.
    t_mel = min(mel.shape[-1], ENCODER_MEL_FRAMES)
    mel = mel[..., :t_mel]
    mel = np.pad(mel, ((0, 0), (0, 0), (0, ENCODER_MEL_FRAMES - t_mel)),
                 constant_values=0).astype(np.float32)
    audio_embeds = enc.run(["audio_embeds"], {"mel": mel})[0]
    # Keep 13 frames per full 100-mel-frame chunk plus ceil(last_len / 8) for the last
    # (possibly partial) chunk; presumably the remaining frames come from the padding.
    n_chunks = (t_mel + 99) // 100
    last_len = t_mel - (n_chunks - 1) * 100
    n_frames = (n_chunks - 1) * 13 + (last_len + 7) // 8
    return torch.from_numpy(audio_embeds[0, :n_frames]).to(device).to(torch.bfloat16)


def build_prompt_ids(tok, n_audio, AUDIO_PAD=AUDIO_PAD_ID):
    """Encode the chat prompt and expand its single audio-pad token to *n_audio* copies.

    Raises ValueError if the encoded prompt does not contain *AUDIO_PAD*.
    """
    prompt = (
        "<|im_start|>system\n<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
        "<|im_start|>assistant\n"
        "language English"
    )
    ids = tok.encode(prompt, add_special_tokens=False)
    pos = ids.index(AUDIO_PAD)
    return ids[:pos] + [AUDIO_PAD] * n_audio + ids[pos + 1:]


def _scatter_embeds(embed_w, ids, audio_embeds, audio_pad=AUDIO_PAD_ID, device="cuda"):
    """Embed *ids* with *embed_w* and overwrite the audio-pad rows with *audio_embeds*.

    Returns ``(inputs_embeds (1, L, H), attention_mask (1, L), position_ids (1, L))``.
    """
    ids_t = torch.tensor(ids, dtype=torch.long, device=device)
    embeds = embed_w[ids_t].clone()
    embeds[ids_t == audio_pad] = audio_embeds
    length = len(ids)
    attention_mask = torch.ones((1, length), dtype=torch.long, device=device)
    position_ids = torch.arange(length, device=device).unsqueeze(0)
    return embeds.unsqueeze(0), attention_mask, position_ids


def _load_hf_token(env_path):
    """Export HF_TOKEN from a dotenv-style file, if the file exists and defines it."""
    if not env_path.is_file():
        return
    for line in env_path.read_text().splitlines():
        key, sep, value = line.partition("=")
        if sep and key.strip() == "HF_TOKEN":
            os.environ["HF_TOKEN"] = value.strip().strip('"').strip("'")
            break


def _collect_calibration(enc, feat, tok, embed_w, per_split, max_seq):
    """Stream English VITW clips and build modelopt calibration batches.

    Takes up to *per_split* usable clips from each split in CALIB_SPLITS. Each batch
    is a dict of model kwargs (inputs_embeds / attention_mask / position_ids).
    Splits that fail to load and clips that fail to decode/encode are skipped
    with a message; prompts longer than *max_seq* tokens are skipped silently.
    """
    import soundfile as sf
    from datasets import Audio, load_dataset

    batches = []
    for split in CALIB_SPLITS:
        print(f"streaming {split} ...", flush=True)
        try:
            ds = load_dataset(CALIB_DATASET, split=split, streaming=True)
            ds = ds.cast_column("audio", Audio(decode=False))
        except Exception as e:  # best-effort: an unavailable split must not abort the run
            print(f" skip: {e}")
            continue
        n = 0
        for ex in ds:
            audio = ex.get("audio")
            if not audio or not audio.get("bytes"):
                continue
            text = ex.get("text", "") or ex.get("answer", "")
            if not is_english(text):
                continue
            try:
                arr, sr = sf.read(io.BytesIO(audio["bytes"]))
                ae = encode_audio(enc, arr, sr, feat)
                ids = build_prompt_ids(tok, ae.shape[0])
                if len(ids) > max_seq:
                    continue
                embeds, attn, pos = _scatter_embeds(embed_w, ids, ae)
            except Exception as e:  # best-effort: one bad clip must not abort calibration
                print(f" skip ex: {e}")
                continue
            batches.append({"inputs_embeds": embeds, "attention_mask": attn, "position_ids": pos})
            n += 1
            if n >= per_split:
                break
        print(f" collected {n} from {split}", flush=True)
    return batches


def _quantize_and_save(model, tok, calib_batches, cfg_name, out_dir):
    """Quantize *model* in place with modelopt config *cfg_name*, then save to *out_dir*."""
    import modelopt.torch.quantization as mtq

    quant_cfg = getattr(mtq, cfg_name)
    print(f"\nQuantizing with {cfg_name} ...")

    def forward_loop(m):
        with torch.no_grad():
            for i, batch in enumerate(calib_batches, 1):
                m(**batch)
                if i % 20 == 0:
                    print(f" calib {i}/{len(calib_batches)}", flush=True)

    mtq.quantize(model, quant_cfg, forward_loop)
    print("Quantization done.")
    out_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): a plain save_pretrained may not persist modelopt quantizer state for
    # reloading; modelopt documents enable_huggingface_checkpointing / export APIs — confirm.
    model.save_pretrained(str(out_dir), safe_serialization=True)
    tok.save_pretrained(str(out_dir))
    print(f"Saved → {out_dir}")


@torch.no_grad()
def _greedy_decode(model, embed_w, inputs_embeds, attention_mask, position_ids,
                   max_new_tokens, eos_id=EOS_ID):
    """Prefill on *inputs_embeds*, then greedily decode with the KV cache.

    Produces at most *max_new_tokens* ids (at least one, from the prefill step) and
    stops after emitting *eos_id*; a trailing *eos_id* is removed from the result.
    """
    out = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                position_ids=position_ids, use_cache=True, return_dict=True)
    past = out.past_key_values
    nid = int(out.logits[0, -1, :].float().argmax().item())
    gen = [nid]
    cur = inputs_embeds.shape[1]  # next position id == prompt length
    attn = attention_mask
    device = attn.device
    for _ in range(max_new_tokens - 1):
        if nid == eos_id:
            break
        tok_e = embed_w[nid:nid + 1].unsqueeze(0)  # (1, 1, hidden)
        attn = torch.cat([attn, torch.ones((1, 1), dtype=torch.long, device=device)], dim=1)
        pos = torch.tensor([[cur]], device=device)
        out = model(inputs_embeds=tok_e, attention_mask=attn, position_ids=pos,
                    past_key_values=past, use_cache=True, return_dict=True)
        past = out.past_key_values
        nid = int(out.logits[0, -1, :].float().argmax().item())
        gen.append(nid)
        cur += 1
    if gen and gen[-1] == eos_id:
        gen.pop()
    return gen


def _benchmark(model, embed_w, enc, feat, tok, examples_dir, max_new_tokens):
    """Transcribe each REFERENCES example wav and print per-clip and average WER/agreement."""
    import soundfile as sf

    print(f"\n=== Bench on {examples_dir} ===")
    total_wer = 0.0
    total_edits = 0
    total_words = 0
    n = 0
    for name in sorted(REFERENCES):
        wav_path = examples_dir / f"{name}.wav"
        if not wav_path.exists():
            print(f"skip {name} (no wav)")
            continue
        audio, sr = sf.read(str(wav_path))
        ae = encode_audio(enc, audio, sr, feat)
        ids = build_prompt_ids(tok, ae.shape[0])
        inputs_embeds, attn, pos = _scatter_embeds(embed_w, ids, ae)
        gen = _greedy_decode(model, embed_w, inputs_embeds, attn, pos, max_new_tokens)
        hyp_text = tok.decode(gen, skip_special_tokens=True)

        ref = normalize(REFERENCES[name])
        hyp = normalize(hyp_text)
        w, ed, words = wer(ref, hyp)
        agree = max(0.0, 1.0 - w) * 100
        total_wer += w
        total_edits += ed
        total_words += words
        n += 1
        print(f"\n[{col(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={col(agree, f'{agree:5.1f}%')}")
        print(f" REF: {ref}")
        print(f" HYP: {hyp}")

    if n == 0 or total_words == 0:
        # Avoid dividing by zero when no example wav was found.
        print("\nNo examples were scored (missing wav files?).")
        return
    avg = (1 - total_wer / n) * 100
    corpus_wer = total_edits / total_words * 100
    summary = (f"=== AVERAGE: agreement {avg:.1f}% WER {corpus_wer:.1f}% "
               f"({total_edits}/{total_words}) ===")
    print(f"\n{col(avg, summary)}")


def main():
    """CLI entry point: optionally PTQ-quantize the LLM, then benchmark on the VITW examples."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="models/mega-asr-llm-qwen3", type=Path)
    ap.add_argument("--encoder", default="models/mega-asr-export/audio_encoder_fp32.onnx", type=Path)
    ap.add_argument("--qwen-asr-dir", default="models/mega-asr/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--per-split", type=int, default=24)
    ap.add_argument("--max-seq", type=int, default=512)
    ap.add_argument("--max-new-tokens", type=int, default=80)
    ap.add_argument("--cfg", default="NVFP4_AWQ_LITE_CFG")
    ap.add_argument("--examples-dir", default="vitw-examples", type=Path)
    ap.add_argument("--skip-quant", action="store_true",
                    help="Bench the bf16 model without quantization (sanity baseline)")
    ap.add_argument("--out", default="models/nvfp4/mega-asr-llm", type=Path)
    args = ap.parse_args()

    # Imported here (like librosa / datasets / modelopt) so the scoring helpers above
    # stay importable without the inference stack.
    import onnxruntime as ort
    from transformers import AutoFeatureExtractor, AutoModelForCausalLM, AutoTokenizer

    # A missing .env is fine (e.g. with --skip-quant); HF_TOKEN is only needed for streaming.
    _load_hf_token(Path(".env"))

    print(f"Loading LLM from {args.model} ...")
    model = AutoModelForCausalLM.from_pretrained(
        str(args.model), torch_dtype=torch.bfloat16, device_map="cuda",
    )
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
    feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
    embed_w = model.model.embed_tokens.weight
    enc = ort.InferenceSession(str(args.encoder),
                               providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

    if not args.skip_quant:
        calib_batches = _collect_calibration(enc, feat, tok, embed_w, args.per_split, args.max_seq)
        print(f"Total: {len(calib_batches)} calibration batches")
        if not calib_batches:
            # PTQ without any calibration data cannot produce meaningful scales.
            raise SystemExit("No calibration batches were collected; aborting before quantization.")
        _quantize_and_save(model, tok, calib_batches, args.cfg, args.out)
        del calib_batches  # release the GPU-resident calibration tensors before benching

    _benchmark(model, embed_w, enc, feat, tok, args.examples_dir, args.max_new_tokens)


if __name__ == "__main__":
    main()