# mega-asr-nvfp4 / nvfp4_quantize.py
# Uploaded by: Reza2kn
# NVFP4 AWQ-Lite Qwen3-ASR-1.7B (91.4% on VITW)
# Commit: ba7df37 (verified)
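"""NVFP4 PTQ on the Qwen3-ASR-1.7B LLM via nvidia-modelopt.
Calibration data: up to 168 English VITW-2M samples (7 splits x --per-split 24;
the same set used for GPTQ). They are forwarded as ``inputs_embeds`` through the
HF model, with audio-encoder embeddings scattered into the <|audio_pad|>
positions of the ASR prompt. Calibration therefore sees the same activation
distribution as real ASR inference.
Output: the quantized model saved with ``save_pretrained`` (safetensors) plus
the tokenizer, under models/nvfp4/mega-asr-llm (``--out``). This script does
not itself run a TensorRT-LLM export.
"""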
import argparse
import io
import os
import re
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, "Mega-ASR-src/src")
import onnxruntime as ort
import soundfile as sf
from datasets import load_dataset, Audio
from transformers import AutoModelForCausalLM, AutoTokenizer
def is_english(s):
    """Heuristically decide whether *s* is English text.

    Returns True only when *s* is non-empty, contains at least one ASCII
    Latin letter, and contains no character from the CJK ideograph
    (U+4E00-U+9FFF), Hiragana/Katakana (U+3040-U+30FF) or Hangul syllable
    (U+AC00-U+D7AF) ranges. ``None`` and ``""`` count as non-English.
    """
    if not s:
        return False
    if re.search(r"[a-zA-Z]", s) is None:
        # No Latin letters at all (digits, punctuation, other scripts).
        return False
    # Any CJK / kana / Hangul character disqualifies the text.
    return re.search(r"[一-鿿぀-ヿ가-힯]", s) is None
def _load_hf_token(env_path):
    """Export ``HF_TOKEN`` from a dotenv-style file into ``os.environ``.

    The first line whose key (the text before the first ``=``) is exactly
    ``HF_TOKEN`` wins. Surrounding whitespace and single/double quotes are
    stripped from its value. A missing file is not an error, so the script
    still runs when the token is already exported or is not needed.

    Args:
        env_path: ``Path`` to the dotenv file.

    Returns:
        The exported token string, or None if the file or key is absent.
    """
    if not env_path.is_file():
        return None
    for line in env_path.read_text().splitlines():
        # partition() never raises. A bare "HF_TOKEN" line (no "=") is ignored
        # instead of causing an IndexError. An exact key match avoids picking
        # up e.g. "HF_TOKEN_OLD=...".
        key, sep, value = line.partition("=")
        if sep and key.strip() == "HF_TOKEN":
            token = value.strip().strip('"').strip("'")
            os.environ["HF_TOKEN"] = token
            return token
    return None


def main():
    """Calibrate and NVFP4-quantize the Qwen3-ASR LLM, then save it.

    Pipeline:
      1. Export ``HF_TOKEN`` from ``./.env`` if that file exists.
      2. Load the bf16 LLM on CUDA, the tokenizer and feature extractor from
         ``--qwen-asr-dir``, and the ONNX audio encoder.
      3. Stream English examples from each VITW-2M split, stopping at
         ``--per-split`` kept samples per split. For each one: audio -> mel ->
         encoder audio embeddings. The embeddings are scattered into the
         ``<|audio_pad|>`` slots of the ASR prompt, giving one
         ``inputs_embeds`` calibration batch.
      4. Run ``mtq.quantize`` with the ModelOpt config named by ``--cfg``,
         replaying the batches as the calibration forward loop.
      5. Save the model (safetensors) and tokenizer to ``--out``.

    Raises:
        SystemExit: on invalid arguments, if the prompt template does not
            contain the audio-pad token id, or if no calibration sample
            could be collected.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="models/mega-asr-llm-qwen3", type=Path)
    ap.add_argument("--encoder", default="models/mega-asr-export/audio_encoder_fp32.onnx", type=Path)
    ap.add_argument("--qwen-asr-dir", default="models/mega-asr/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--per-split", type=int, default=24)
    ap.add_argument("--max-seq", type=int, default=512)
    ap.add_argument("--cfg", default="NVFP4_AWQ_LITE_CFG", choices=[
        "NVFP4_DEFAULT_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_AWQ_CLIP_CFG",
        "NVFP4_AFFINE_KV_CFG", "MXFP4_DEFAULT_CFG",
    ])
    ap.add_argument("--out", default="models/nvfp4/mega-asr-llm", type=Path)
    args = ap.parse_args()
    if args.per_split < 1:
        ap.error("--per-split must be >= 1")

    # Exported into the environment so the Hugging Face downloads below can
    # authenticate. A missing .env is tolerated.
    _load_hf_token(Path(".env"))

    print(f"Loading LLM from {args.model} ...")
    model = AutoModelForCausalLM.from_pretrained(
        str(args.model), torch_dtype=torch.bfloat16, device_map="cuda",
    )
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
    # Token embedding table (author's note: (151936, 2048) on cuda). Used to
    # embed the text part of each calibration prompt.
    embed_w = model.model.embed_tokens.weight
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params")

    # Feature extractor (audio -> mel) and ONNX encoder (mel -> audio embeds),
    # needed only while building calibration data.
    from transformers import AutoFeatureExtractor
    feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
    enc = ort.InferenceSession(
        str(args.encoder),
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # ASR prompt template. Its single <|audio_pad|> is expanded to one
    # placeholder per audio embedding. It is tokenized once here, so a
    # tokenizer that does not produce the pad id fails loudly instead of
    # every example being skipped.
    AUDIO_PAD = 151676
    prompt = (
        "<|im_start|>system\n<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
        "<|im_start|>assistant\n"
        "language English<asr_text>"
    )
    template_ids = tok.encode(prompt, add_special_tokens=False)
    if AUDIO_PAD not in template_ids:
        raise SystemExit(
            f"Prompt template does not encode <|audio_pad|> as id {AUDIO_PAD}; "
            f"check the tokenizer in {args.qwen_asr_dir}"
        )
    pad_pos = template_ids.index(AUDIO_PAD)
    prefix_ids = template_ids[:pad_pos]
    suffix_ids = template_ids[pad_pos + 1:]

    def build_prompt_ids(n_audio):
        """Return the prompt token ids with *n_audio* audio-pad placeholders."""
        return prefix_ids + [AUDIO_PAD] * n_audio + suffix_ids

    # Stream calibration data: English VITW samples -> mel -> audio embeds -> scatter.
    splits = ["noise", "far_field", "obstructed", "distortion", "recording", "echo", "dropout"]
    calib_batches = []
    for split in splits:
        print(f"streaming {split} ...", flush=True)
        try:
            ds = load_dataset("zhifeixie/Voices-in-the-Wild-2M", split=split, streaming=True)
            # Keep raw bytes; decoding is done below with soundfile.
            ds = ds.cast_column("audio", Audio(decode=False))
        except Exception as e:
            print(f"  skip {split}: {e}")
            continue
        n = 0
        for ex in ds:
            audio = ex.get("audio")
            if not audio or not audio.get("bytes"):
                continue
            text = ex.get("text", "") or ex.get("answer", "")
            if not is_english(text):
                continue
            # Best effort: a corrupt or unexpected example is reported and
            # skipped rather than aborting the whole calibration run.
            try:
                arr, sr = sf.read(io.BytesIO(audio["bytes"]))
                if arr.ndim > 1:
                    arr = arr.mean(axis=1)  # downmix to mono
                if sr != 16000:
                    import librosa
                    arr = librosa.resample(arr.astype(np.float32), orig_sr=sr, target_sr=16000)
                f = feat(arr, sampling_rate=16000, return_tensors="np", return_attention_mask=False)
                # NOTE(review): if this feature extractor pads to a fixed length
                # by default (as Whisper-style extractors do), T_mel is always
                # 3000 and "real" frames include padding. Confirm against the
                # preprocessor config.
                mel = f["input_features"]
                T_mel = mel.shape[-1]
                if T_mel > 3000:
                    mel = mel[..., :3000]
                    T_mel = 3000
                # The encoder is fed a fixed 3000-frame window, zero-padded.
                mel = np.pad(mel, ((0, 0), (0, 0), (0, 3000 - T_mel)), constant_values=0).astype(np.float32)
                # Number of encoder outputs that cover real (unpadded) mel
                # frames: 13 per full 100-frame chunk, ceil(last / 8) for the
                # trailing partial chunk.
                real_chunks = (T_mel + 99) // 100
                last_chunk = T_mel - (real_chunks - 1) * 100
                real_frames = (real_chunks - 1) * 13 + (last_chunk + 7) // 8
                if real_frames <= 0:
                    continue
                prompt_ids = build_prompt_ids(real_frames)
                L = len(prompt_ids)
                # Check the length before running the encoder, so samples that
                # are over-long and will be discarded cost nothing.
                if L > args.max_seq:
                    continue
                ae = enc.run(["audio_embeds"], {"mel": mel})[0]
                ae = torch.from_numpy(ae[0, :real_frames]).to("cuda").to(torch.bfloat16)
                ids_t = torch.tensor(prompt_ids, dtype=torch.long, device="cuda")
                # Advanced indexing returns a copy, so the embedding table is
                # never modified by the scatter below.
                token_embeds = embed_w[ids_t]  # (L, 2048)
                token_embeds[ids_t == AUDIO_PAD] = ae
                inputs_embeds = token_embeds.unsqueeze(0)  # (1, L, 2048)
                attn = torch.ones((1, L), dtype=torch.long, device="cuda")
                pos = torch.arange(L, device="cuda").unsqueeze(0)
            except Exception as e:
                print(f"  skip ex: {e}")
                continue
            calib_batches.append({"inputs_embeds": inputs_embeds, "attention_mask": attn, "position_ids": pos})
            n += 1
            # Stop right after the last needed sample so no extra example is streamed.
            if n >= args.per_split:
                break
        print(f"  collected {n} from {split}", flush=True)
    print(f"\nTotal calibration batches: {len(calib_batches)}")
    del enc
    if not calib_batches:
        raise SystemExit("No calibration samples collected; refusing to quantize with uncalibrated scales.")

    # ModelOpt PTQ with the selected config.
    import modelopt.torch.quantization as mtq
    quant_cfg = getattr(mtq, args.cfg)
    print(f"\nQuantizing with {args.cfg} ...")

    def forward_loop(m):
        """Calibration loop for mtq.quantize: run every stored batch through *m*."""
        with torch.no_grad():
            for i, batch in enumerate(calib_batches, start=1):
                m(**batch)
                if i % 20 == 0:
                    print(f"  calib {i}/{len(calib_batches)}", flush=True)

    mtq.quantize(model, quant_cfg, forward_loop)
    print("Quantization done.")

    # Save HF-compatible weights + config and the tokenizer.
    # NOTE(review): ModelOpt documents `modelopt.torch.opt.enable_huggingface_checkpointing()`
    # and the `modelopt.torch.export` helpers for checkpoints that preserve
    # quantizer state or target TensorRT-LLM. Confirm that plain save_pretrained
    # produces what downstream consumers expect.
    args.out.mkdir(parents=True, exist_ok=True)
    print(f"Saving to {args.out} ...")
    model.save_pretrained(str(args.out), safe_serialization=True)
    tok.save_pretrained(str(args.out))
    print("Done.")
if __name__ == "__main__":
    # Script entry point: run calibration + quantization only when executed directly.
    main()