File size: 6,762 Bytes
ba7df37 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | """NVFP4 PTQ on Qwen3-ASR-1.7B LLM via nvidia-modelopt.
Calibration data: up to 168 English VITW-2M samples (7 splits x 24 per split
by default; the same set used for GPTQ), forwarded as ``inputs_embeds``
through the HF model so that calibration sees the same
scattered-audio-embeds activation distribution as real ASR inference.

Output: the NVFP4-quantized model weights and tokenizer, written with
``save_pretrained`` to ``--out`` (default ``models/nvfp4/mega-asr-llm``),
intended as the input for a subsequent TensorRT-LLM export.
"""
import argparse
import io
import os
import re
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, "Mega-ASR-src/src")
import onnxruntime as ort
import soundfile as sf
from datasets import load_dataset, Audio
from transformers import AutoModelForCausalLM, AutoTokenizer
def is_english(s):
    """Return True when *s* looks like English text.

    A string qualifies when it contains at least one ASCII letter and no
    CJK ideograph (U+4E00-U+9FFF), Japanese kana (U+3040-U+30FF) or Hangul
    syllable (U+AC00-U+D7AF). Empty or ``None`` input is rejected.
    """
    if not s:
        return False
    has_latin = re.search(r"[a-zA-Z]", s) is not None
    # The ranges are spelled with \u escapes so that the class cannot be
    # corrupted by editors/renderers dropping unassigned code points; a
    # damaged class would treat "-" as a member and reject hyphenated text.
    has_cjk = re.search(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]", s) is not None
    return has_latin and not has_cjk
def _load_hf_token(env_path=Path(".env")):
    """Copy ``HF_TOKEN`` from a dotenv-style file into ``os.environ``.

    Does nothing when *env_path* does not exist or holds no ``HF_TOKEN=...``
    line, so a token already exported in the environment keeps working.
    Surrounding whitespace and single/double quotes are stripped from the
    value. Only the first matching line is used.
    """
    if not env_path.is_file():
        return
    for line in env_path.read_text().splitlines():
        key, sep, value = line.partition("=")
        # Require an exact key and an "=" so that e.g. HF_TOKEN_OTHER=... or
        # a bare "HF_TOKEN" line is ignored instead of misread / crashing.
        if sep and key.strip() == "HF_TOKEN":
            os.environ["HF_TOKEN"] = value.strip().strip('"').strip("'")
            return


def _num_audio_frames(t_mel):
    """Return the number of audio-embedding frames kept for *t_mel* mel frames.

    Mel frames are counted in chunks of 100: every chunk before the last
    contributes 13 frames and the last chunk contributes ceil(length / 8).
    NOTE(review): presumably mirrors the encoder's chunked downsampling --
    confirm against the exported ONNX encoder.
    """
    n_chunks = (t_mel + 99) // 100
    last_len = t_mel - (n_chunks - 1) * 100
    return (n_chunks - 1) * 13 + (last_len + 7) // 8


def main():
    """Calibrate and quantize the LLM with ModelOpt, then save it.

    Steps: parse CLI arguments, load ``HF_TOKEN`` from ``.env`` if present,
    load the bf16 LLM, tokenizer, feature extractor and ONNX audio encoder,
    stream English samples from each dataset split, turn each one into an
    ``inputs_embeds`` batch (prompt token embeddings with the audio-pad
    positions overwritten by encoder output), run ``mtq.quantize`` with those
    batches as the calibration loop, and write model + tokenizer to ``--out``.

    Raises:
        ValueError: if the tokenizer does not map ``<|audio_pad|>`` to the
            expected id.
        SystemExit: if no calibration sample could be collected.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="models/mega-asr-llm-qwen3", type=Path)
    ap.add_argument("--encoder", default="models/mega-asr-export/audio_encoder_fp32.onnx", type=Path)
    ap.add_argument("--qwen-asr-dir", default="models/mega-asr/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--per-split", type=int, default=24)
    ap.add_argument("--max-seq", type=int, default=512)
    ap.add_argument("--cfg", default="NVFP4_AWQ_LITE_CFG", choices=[
        "NVFP4_DEFAULT_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_AWQ_CLIP_CFG",
        "NVFP4_AFFINE_KV_CFG", "MXFP4_DEFAULT_CFG",
    ])
    ap.add_argument("--out", default="models/nvfp4/mega-asr-llm", type=Path)
    args = ap.parse_args()

    # Load HF token (optional: skipped when .env is absent).
    _load_hf_token()

    print(f"Loading LLM from {args.model} ...")
    model = AutoModelForCausalLM.from_pretrained(
        str(args.model), torch_dtype=torch.bfloat16, device_map="cuda",
    )
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
    embed_w = model.model.embed_tokens.weight  # (151936, 2048) on cuda
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params")

    # Load encoder + processor (for mel + audio_embeds during calibration)
    from transformers import AutoFeatureExtractor
    feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
    enc = ort.InferenceSession(
        str(args.encoder),
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # Build prompt template. The prompt text is constant, so it is tokenized
    # once here instead of once per calibration sample.
    AUDIO_PAD = 151676
    prompt = (
        "<|im_start|>system\n<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
        "<|im_start|>assistant\n"
        "language English<asr_text>"
    )
    template_ids = tok.encode(prompt, add_special_tokens=False)
    try:
        pad_pos = template_ids.index(AUDIO_PAD)
    except ValueError as err:
        # Fail fast: otherwise every sample would be skipped one by one
        # while the whole dataset stream is scanned.
        raise ValueError(
            f"tokenizer at {args.qwen_asr_dir} does not map <|audio_pad|> to id {AUDIO_PAD}"
        ) from err

    def build_prompt_ids(n_audio):
        """Return the prompt ids with the single audio pad repeated *n_audio* times."""
        return template_ids[:pad_pos] + [AUDIO_PAD] * n_audio + template_ids[pad_pos + 1:]

    # Stream calibration data: 168 English VITW samples → mel → audio embeds → scatter
    configs = ["noise", "far_field", "obstructed", "distortion", "recording", "echo", "dropout"]
    calib_batches = []
    for cfg in configs:
        print(f"streaming {cfg} ...", flush=True)
        try:
            ds = load_dataset("zhifeixie/Voices-in-the-Wild-2M", split=cfg, streaming=True)
            # Keep raw bytes; decoding is done below with soundfile.
            ds = ds.cast_column("audio", Audio(decode=False))
        except Exception as e:
            print(f" skip {cfg}: {e}")
            continue
        n = 0
        for ex in ds:
            audio = ex.get("audio")
            if not audio or not audio.get("bytes"):
                continue
            text = ex.get("text", "") or ex.get("answer", "")
            if not is_english(text):
                continue
            # Best-effort per sample: any decoding / encoder / shape failure
            # skips this example and moves on to the next one.
            try:
                arr, sr = sf.read(io.BytesIO(audio["bytes"]))
                if arr.ndim > 1:
                    arr = arr.mean(axis=1)  # down-mix to mono
                if sr != 16000:
                    import librosa
                    arr = librosa.resample(arr.astype(np.float32), orig_sr=sr, target_sr=16000)
                f = feat(arr, sampling_rate=16000, return_tensors="np", return_attention_mask=False)
                mel = f["input_features"]
                # NOTE(review): if the feature extractor already pads to a
                # fixed length, T_mel is always 3000 here -- confirm.
                T_mel = mel.shape[-1]
                if T_mel > 3000:
                    mel = mel[..., :3000]
                    T_mel = 3000
                # Right-pad the last axis with zeros up to 3000 frames.
                mel = np.pad(mel, ((0, 0), (0, 0), (0, 3000 - T_mel)), constant_values=0).astype(np.float32)
                ae = enc.run(["audio_embeds"], {"mel": mel})[0]
                real_frames = _num_audio_frames(T_mel)
                ae = torch.from_numpy(ae[0, :real_frames]).to("cuda").to(torch.bfloat16)
                prompt_ids = build_prompt_ids(real_frames)
                L = len(prompt_ids)
                if L > args.max_seq:
                    continue
                ids_t = torch.tensor(prompt_ids, dtype=torch.long, device="cuda")
                token_embeds = embed_w[ids_t].clone()  # (L, 2048)
                mask = (ids_t == AUDIO_PAD)
                # Overwrite the audio-pad slots with the encoder output rows.
                token_embeds[mask] = ae
                inputs_embeds = token_embeds.unsqueeze(0)  # (1, L, 2048)
                attn = torch.ones((1, L), dtype=torch.long, device="cuda")
                pos = torch.arange(L, device="cuda").unsqueeze(0)
                calib_batches.append({"inputs_embeds": inputs_embeds, "attention_mask": attn, "position_ids": pos})
                n += 1
            except Exception as e:
                print(f" skip ex: {e}")
                continue
            if n >= args.per_split:
                break
        print(f" collected {n} from {cfg}", flush=True)
    print(f"\nTotal calibration batches: {len(calib_batches)}")
    if not calib_batches:
        # Quantizing without any calibration forward pass would silently
        # produce an uncalibrated checkpoint.
        raise SystemExit("No calibration samples collected; refusing to quantize without calibration data.")
    del enc  # the encoder is no longer needed once all embeddings are built

    # Modelopt NVFP4 PTQ
    import modelopt.torch.quantization as mtq
    quant_cfg = getattr(mtq, args.cfg)
    print(f"\nQuantizing with {args.cfg} ...")

    def forward_loop(m):
        """Run every calibration batch through *m* without tracking gradients."""
        with torch.no_grad():
            for i, b in enumerate(calib_batches, start=1):
                m(**b)
                if i % 20 == 0:
                    print(f" calib {i}/{len(calib_batches)}", flush=True)

    mtq.quantize(model, quant_cfg, forward_loop)
    print("Quantization done.")

    # Save HF-compatible state_dict + config
    # NOTE(review): whether save_pretrained persists ModelOpt quantizer state
    # depends on ModelOpt's Hugging Face checkpointing support -- confirm.
    args.out.mkdir(parents=True, exist_ok=True)
    print(f"Saving to {args.out} ...")
    model.save_pretrained(str(args.out), safe_serialization=True)
    tok.save_pretrained(str(args.out))
    print("Done.")
# Run the quantization pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|