"""NVFP4 PTQ on Qwen3-ASR-1.7B LLM via nvidia-modelopt.

Calibration data: 168 English VITW-2M samples by default (7 splits x 24 per
split; same set used for GPTQ), forwarded as ``inputs_embeds`` through the HF
model so the calibration sees the same scattered-audio-embeds activation
distribution as real ASR inference.

Output: NVFP4 quantized PyTorch state_dict + a TensorRT-LLM exportable
checkpoint at models/nvfp4/.
"""
import argparse
import io
import os
import re
import sys
from pathlib import Path

import numpy as np
import torch

# Kept ahead of the remaining imports, as in the original file, so that the
# local Mega-ASR sources take precedence on the module search path.
sys.path.insert(0, "Mega-ASR-src/src")
import onnxruntime as ort
import soundfile as sf
from datasets import load_dataset, Audio
from transformers import AutoFeatureExtractor, AutoModelForCausalLM, AutoTokenizer

# Token id that build/scatter logic treats as the audio placeholder.
AUDIO_PAD = 151676
# Sample rate every clip is resampled to before feature extraction.
SAMPLE_RATE = 16000
# The mel input is truncated / zero-padded to exactly this many frames.
MAX_MEL_FRAMES = 3000
# Device used for the LLM and for every calibration tensor.
DEVICE = "cuda"
# Dataset splits streamed for calibration.
VITW_SPLITS = (
    "noise",
    "far_field",
    "obstructed",
    "distortion",
    "recording",
    "echo",
    "dropout",
)
# Chat prompt with a single audio placeholder that is expanded per sample.
PROMPT_TEMPLATE = (
    "<|im_start|>system\n<|im_end|>\n"
    "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
    "<|im_start|>assistant\n"
    "language English"
)

_LATIN_RE = re.compile(r"[a-zA-Z]")
_CJK_RE = re.compile(r"[一-鿿぀-ヿ가-힯]")


def is_english(s):
    """Return True if *s* has a Latin letter and no CJK/kana/hangul character.

    Empty or ``None`` input yields False.
    """
    if not s:
        return False
    return _LATIN_RE.search(s) is not None and _CJK_RE.search(s) is None


def _load_hf_token(env_path=Path(".env")):
    """Copy ``HF_TOKEN`` from a dotenv-style file into ``os.environ``.

    A missing file is not an error. Only a line whose key is exactly
    ``HF_TOKEN`` is used; surrounding whitespace and quotes are stripped from
    the value. The first matching line wins.
    """
    if not env_path.is_file():
        return
    for line in env_path.read_text().splitlines():
        key, sep, value = line.partition("=")
        if sep and key.strip() == "HF_TOKEN":
            os.environ["HF_TOKEN"] = value.strip().strip('"').strip("'")
            break


def _split_prompt_template(tok):
    """Tokenize the prompt once and split it around the audio placeholder.

    Returns:
        ``(prefix_ids, suffix_ids)``: the token ids before and after the
        single ``AUDIO_PAD`` token.

    Raises:
        ValueError: if the tokenizer output does not contain ``AUDIO_PAD``.
    """
    ids = tok.encode(PROMPT_TEMPLATE, add_special_tokens=False)
    try:
        pos = ids.index(AUDIO_PAD)
    except ValueError as err:
        raise ValueError(
            f"tokenizer did not emit audio pad id {AUDIO_PAD} for the prompt template"
        ) from err
    return ids[:pos], ids[pos + 1:]


def _audio_frame_count(t_mel):
    """Return the number of encoder output frames kept for *t_mel* mel frames.

    Mel frames are counted in chunks of 100: each chunk before the last
    contributes 13 frames and the last contributes ``ceil(len / 8)``.
    """
    chunks = (t_mel + 99) // 100
    last_chunk = t_mel - (chunks - 1) * 100
    return (chunks - 1) * 13 + (last_chunk + 7) // 8


def _decode_audio(raw):
    """Decode encoded audio bytes to a mono waveform at ``SAMPLE_RATE``."""
    arr, sr = sf.read(io.BytesIO(raw))
    if arr.ndim > 1:
        # Down-mix multi-channel audio by averaging the channels.
        arr = arr.mean(axis=1)
    if sr != SAMPLE_RATE:
        # Imported lazily: only needed when a clip is not already 16 kHz.
        import librosa

        arr = librosa.resample(
            arr.astype(np.float32), orig_sr=sr, target_sr=SAMPLE_RATE
        )
    return arr


def _make_calib_batch(arr, *, feat, enc, embed_w, prompt_prefix, prompt_suffix, max_seq):
    """Build one calibration batch (model kwargs) from a waveform.

    Returns:
        A dict with ``inputs_embeds``, ``attention_mask`` and ``position_ids``
        tensors on ``DEVICE``, or ``None`` when the sample has no audio frames
        or its prompt is longer than *max_seq* tokens.
    """
    f = feat(arr, sampling_rate=16000, return_tensors="np", return_attention_mask=False)
    mel = f["input_features"]
    # NOTE(review): if the feature extractor pads to a fixed length by default,
    # the frame count below is always the same regardless of clip duration -
    # confirm that this is intended.
    t_mel = min(mel.shape[-1], MAX_MEL_FRAMES)

    n_frames = _audio_frame_count(t_mel)
    if n_frames <= 0:
        # Nothing to scatter; an audio-free prompt is not useful calibration data.
        return None

    prompt_ids = prompt_prefix + [AUDIO_PAD] * n_frames + prompt_suffix
    seq_len = len(prompt_ids)
    if seq_len > max_seq:
        # Checked before running the encoder so over-long samples cost nothing.
        return None

    mel = mel[..., :t_mel]
    mel = np.pad(
        mel, ((0, 0), (0, 0), (0, MAX_MEL_FRAMES - t_mel)), constant_values=0
    ).astype(np.float32)
    ae = enc.run(["audio_embeds"], {"mel": mel})[0]
    audio_embeds = torch.from_numpy(ae[0, :n_frames]).to(DEVICE).to(torch.bfloat16)

    ids_t = torch.tensor(prompt_ids, dtype=torch.long, device=DEVICE)
    token_embeds = embed_w[ids_t].clone()  # (L, hidden)
    # Overwrite the placeholder rows with the audio encoder output.
    token_embeds[ids_t == AUDIO_PAD] = audio_embeds
    return {
        "inputs_embeds": token_embeds.unsqueeze(0),  # (1, L, hidden)
        "attention_mask": torch.ones((1, seq_len), dtype=torch.long, device=DEVICE),
        "position_ids": torch.arange(seq_len, device=DEVICE).unsqueeze(0),
    }


def _collect_calibration_batches(*, feat, enc, embed_w, prompt_prefix, prompt_suffix,
                                 per_split, max_seq):
    """Stream the VITW splits and return a list of calibration batches.

    A split that cannot be opened is skipped. Within a split, samples without
    audio bytes or without English text are ignored, and a sample that raises
    during processing is reported and skipped (best effort). Collection stops
    for a split once *per_split* batches have been gathered from it.
    """
    batches = []
    for split in VITW_SPLITS:
        print(f"streaming {split} ...", flush=True)
        try:
            ds = load_dataset(
                "zhifeixie/Voices-in-the-Wild-2M", split=split, streaming=True
            )
            # Keep raw bytes; decoding is done by _decode_audio.
            ds = ds.cast_column("audio", Audio(decode=False))
        except Exception as e:
            print(f"  skip {split}: {e}")
            continue

        n = 0
        for ex in ds:
            audio = ex.get("audio")
            if not audio or not audio.get("bytes"):
                continue
            text = ex.get("text", "") or ex.get("answer", "")
            if not is_english(text):
                continue
            try:
                batch = _make_calib_batch(
                    _decode_audio(audio["bytes"]),
                    feat=feat,
                    enc=enc,
                    embed_w=embed_w,
                    prompt_prefix=prompt_prefix,
                    prompt_suffix=prompt_suffix,
                    max_seq=max_seq,
                )
            except Exception as e:
                # Deliberately broad: one bad clip must not abort calibration.
                print(f"  skip ex: {e}")
                continue
            if batch is None:
                continue
            batches.append(batch)
            n += 1
            if n >= per_split:
                break
        print(f"  collected {n} from {split}", flush=True)
    return batches


def main():
    """Parse CLI args, build calibration data, quantize the LLM and save it.

    Raises:
        SystemExit: if no calibration sample could be collected.
        ValueError: if the tokenizer does not produce the audio pad token.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="models/mega-asr-llm-qwen3", type=Path)
    ap.add_argument("--encoder", default="models/mega-asr-export/audio_encoder_fp32.onnx", type=Path)
    ap.add_argument("--qwen-asr-dir", default="models/mega-asr/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--per-split", type=int, default=24)
    ap.add_argument("--max-seq", type=int, default=512)
    ap.add_argument("--cfg", default="NVFP4_AWQ_LITE_CFG", choices=[
        "NVFP4_DEFAULT_CFG",
        "NVFP4_AWQ_LITE_CFG",
        "NVFP4_AWQ_CLIP_CFG",
        "NVFP4_AFFINE_KV_CFG",
        "MXFP4_DEFAULT_CFG",
    ])
    ap.add_argument("--out", default="models/nvfp4/mega-asr-llm", type=Path)
    args = ap.parse_args()

    # Load HF token (optional: only if a .env file is present).
    _load_hf_token()

    print(f"Loading LLM from {args.model} ...")
    model = AutoModelForCausalLM.from_pretrained(
        str(args.model),
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    model.eval()
    model.requires_grad_(False)
    tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
    embed_w = model.model.embed_tokens.weight  # (vocab, hidden) on cuda
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params")

    # Load encoder + processor (for mel + audio_embeds during calibration).
    feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
    enc = ort.InferenceSession(
        str(args.encoder),
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # Tokenize the prompt once; fails fast if the pad token id is wrong.
    prompt_prefix, prompt_suffix = _split_prompt_template(tok)

    # Stream calibration data: English VITW samples -> mel -> audio embeds -> scatter.
    calib_batches = _collect_calibration_batches(
        feat=feat,
        enc=enc,
        embed_w=embed_w,
        prompt_prefix=prompt_prefix,
        prompt_suffix=prompt_suffix,
        per_split=args.per_split,
        max_seq=args.max_seq,
    )
    print(f"\nTotal calibration batches: {len(calib_batches)}")
    # The encoder is no longer needed once the batches exist.
    del enc

    if not calib_batches:
        # Quantizing without any calibration data would save an uncalibrated model.
        raise SystemExit("No calibration samples collected; aborting before quantization.")

    # Modelopt NVFP4 PTQ (imported late: only needed from here on).
    import modelopt.torch.quantization as mtq

    quant_cfg = getattr(mtq, args.cfg)
    print(f"\nQuantizing with {args.cfg} ...")

    def forward_loop(m):
        """Run every calibration batch through *m* without gradients."""
        total = len(calib_batches)
        with torch.no_grad():
            for i, batch in enumerate(calib_batches, start=1):
                m(**batch)
                if i % 20 == 0:
                    print(f"  calib {i}/{total}", flush=True)

    mtq.quantize(model, quant_cfg, forward_loop)
    print("Quantization done.")

    # Save HF-compatible state_dict + config.
    # NOTE(review): confirm that save_pretrained() alone produces the
    # TensorRT-LLM exportable checkpoint the module docstring promises, or
    # whether modelopt's dedicated export API is required.
    args.out.mkdir(parents=True, exist_ok=True)
    print(f"Saving to {args.out} ...")
    model.save_pretrained(str(args.out), safe_serialization=True)
    tok.save_pretrained(str(args.out))
    print("Done.")


if __name__ == "__main__":
    main()