File size: 6,762 Bytes
ba7df37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""NVFP4 PTQ on Qwen3-ASR-1.7B LLM via nvidia-modelopt.

Calibration data: up to 168 English VITW-2M samples (24 per split across 7
splits by default; same set used for GPTQ), forwarded as ``inputs_embeds``
through the HF model so the calibration sees the same scattered-audio-embeds
activation distribution as real ASR inference.

Output: the quantized model (``save_pretrained``, safetensors) and its
tokenizer, written to ``--out`` (default models/nvfp4/mega-asr-llm). No
TensorRT-LLM checkpoint export is performed by this script.
"""
import argparse
import io
import os
import re
import sys
from pathlib import Path

import numpy as np
import torch

sys.path.insert(0, "Mega-ASR-src/src")
import onnxruntime as ort
import soundfile as sf
from datasets import load_dataset, Audio
from transformers import AutoModelForCausalLM, AutoTokenizer


def is_english(s):
    """Heuristically decide whether the transcript ``s`` is English.

    The text counts as English when it contains at least one ASCII letter and
    no character from the U+4E00-U+9FFF (CJK ideographs), U+3040-U+30FF
    (Hiragana/Katakana) or U+AC00-U+D7AF (Hangul syllables) ranges.
    Empty strings and None are never English.
    """
    if not s:
        return False
    if re.search(r"[a-zA-Z]", s) is None:
        return False
    return re.search(r"[一-鿿぀-ヿ가-힯]", s) is None


# Id the prompt's "<|audio_pad|>" placeholder is expected to encode to; verified
# once at runtime by _prompt_template_ids. Each occurrence in the final prompt is
# overwritten with one audio-encoder embedding.
_AUDIO_PAD = 151676
# VITW-2M splits streamed for calibration.
_CALIB_SPLITS = ("noise", "far_field", "obstructed", "distortion", "recording", "echo", "dropout")
# Mel length fed to the ONNX encoder: inputs are truncated / zero-padded to this.
_MEL_FRAMES = 3000
_CALIB_PROMPT = (
    "<|im_start|>system\n<|im_end|>\n"
    "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
    "<|im_start|>assistant\n"
    "language English<asr_text>"
)


def _load_hf_token(env_path=".env"):
    """Export ``HF_TOKEN`` from a dotenv-style file into ``os.environ``.

    Uses the first ``KEY=VALUE`` line whose key is exactly ``HF_TOKEN``. It strips
    surrounding whitespace and any leading/trailing quote characters from the value.
    A missing file, or a file without such a line, is not an error: the token may
    already be set in the environment.
    """
    env = Path(env_path)
    if not env.is_file():
        return
    for line in env.read_text().splitlines():
        key, sep, value = line.partition("=")
        if sep and key.strip() == "HF_TOKEN":
            os.environ["HF_TOKEN"] = value.strip().strip('"').strip("'")
            return


def _prompt_template_ids(tok):
    """Tokenize the fixed ASR prompt once and split it around the audio placeholder.

    Returns:
        ``(head, tail)`` token-id lists such that ``head + [_AUDIO_PAD] * n + tail``
        is the prompt for an utterance with ``n`` audio embeddings.

    Raises:
        ValueError: if the encoded prompt does not contain ``_AUDIO_PAD``. Checking
            this once up front prevents every sample from failing (and being
            silently skipped) inside the calibration loop.
    """
    ids = tok.encode(_CALIB_PROMPT, add_special_tokens=False)
    try:
        pos = ids.index(_AUDIO_PAD)
    except ValueError:
        raise ValueError(
            f"audio pad token id {_AUDIO_PAD} not found in the tokenized prompt; "
            "check the tokenizer passed via --qwen-asr-dir"
        ) from None
    return ids[:pos], ids[pos + 1:]


def _num_audio_frames(n_mel):
    """Return the number of encoder embeddings covering the first ``n_mel`` real mel frames.

    Counts 13 embeddings per complete 100-frame mel chunk, plus ceil(r / 8) for the
    final chunk of r frames. NOTE(review): this is assumed to match the exported
    encoder's downsampling; re-confirm if the encoder changes.
    """
    n_chunks = (n_mel + 99) // 100
    last_chunk = n_mel - (n_chunks - 1) * 100
    return (n_chunks - 1) * 13 + (last_chunk + 7) // 8


def _make_calib_batch(audio_bytes, embed_w, feat, enc, head, tail, max_seq):
    """Turn one encoded audio file into a calibration batch.

    Returns ``None`` when the resulting prompt would exceed ``max_seq`` tokens.
    Otherwise returns a dict of ``inputs_embeds`` / ``attention_mask`` /
    ``position_ids`` (batch size 1, on CUDA) in which the ``<|audio_pad|>`` token
    embeddings are replaced by the encoder's audio embeddings.
    """
    arr, sr = sf.read(io.BytesIO(audio_bytes))
    if arr.ndim > 1:
        arr = arr.mean(axis=1)  # downmix multi-channel audio to mono
    if sr != 16000:
        import librosa
        arr = librosa.resample(arr.astype(np.float32), orig_sr=sr, target_sr=16000)
    mel = feat(arr, sampling_rate=16000, return_tensors="np", return_attention_mask=False)["input_features"]

    n_mel = min(mel.shape[-1], _MEL_FRAMES)
    n_audio = _num_audio_frames(n_mel)
    seq_len = len(head) + n_audio + len(tail)
    # The prompt length depends only on the mel length, so reject over-long
    # samples before paying for an encoder run.
    if seq_len > max_seq:
        return None

    mel = mel[..., :n_mel]
    mel = np.pad(mel, ((0, 0), (0, 0), (0, _MEL_FRAMES - n_mel)), constant_values=0).astype(np.float32)
    ae = enc.run(["audio_embeds"], {"mel": mel})[0]
    # Keep only the embeddings for real (unpadded) audio.
    ae = torch.from_numpy(ae[0, :n_audio]).to("cuda").to(torch.bfloat16)

    ids_t = torch.tensor(head + [_AUDIO_PAD] * n_audio + tail, dtype=torch.long, device="cuda")
    token_embeds = embed_w[ids_t].clone()  # (seq_len, hidden)
    token_embeds[ids_t == _AUDIO_PAD] = ae  # scatter audio embeds into placeholder slots
    return {
        "inputs_embeds": token_embeds.unsqueeze(0),
        "attention_mask": torch.ones((1, seq_len), dtype=torch.long, device="cuda"),
        "position_ids": torch.arange(seq_len, device="cuda").unsqueeze(0),
    }


def _collect_calibration_batches(embed_w, tok, feat, enc, per_split, max_seq):
    """Stream VITW-2M splits and build up to ``per_split`` English calibration batches each.

    Splits that fail to load, and individual examples that fail to decode or
    encode, are reported and skipped (best effort). Examples without audio bytes
    or whose transcript fails ``is_english`` are skipped silently.

    Returns:
        A list of kwargs dicts ready to be passed as ``model(**batch)``.
    """
    head, tail = _prompt_template_ids(tok)
    batches = []
    for split in _CALIB_SPLITS:
        print(f"streaming {split} ...", flush=True)
        try:
            ds = load_dataset("zhifeixie/Voices-in-the-Wild-2M", split=split, streaming=True)
            ds = ds.cast_column("audio", Audio(decode=False))
        except Exception as e:
            print(f"  skip {split}: {e}")
            continue
        n = 0
        for ex in ds:
            audio = ex.get("audio")
            if not audio or not audio.get("bytes"):
                continue
            text = ex.get("text", "") or ex.get("answer", "")
            if not is_english(text):
                continue
            try:
                batch = _make_calib_batch(audio["bytes"], embed_w, feat, enc, head, tail, max_seq)
            except Exception as e:
                print(f"  skip ex: {e}")
                continue
            if batch is None:
                continue
            batches.append(batch)
            n += 1
            if n >= per_split:
                break
        print(f"  collected {n} from {split}", flush=True)
    return batches


def main():
    """CLI entry point: calibrate and quantize the Mega-ASR Qwen3 LLM with modelopt.

    The steps are:
      1. Load HF_TOKEN from ./.env, if the file exists.
      2. Load the bf16 LLM, the Qwen3-ASR tokenizer and feature extractor, and the
         ONNX audio encoder.
      3. Build audio-conditioned calibration inputs.
      4. Run ``mtq.quantize`` with the modelopt config named by ``--cfg``.
      5. Save the model and tokenizer to ``--out``.

    Exits with an error if no calibration samples could be collected.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="models/mega-asr-llm-qwen3", type=Path)
    ap.add_argument("--encoder", default="models/mega-asr-export/audio_encoder_fp32.onnx", type=Path)
    ap.add_argument("--qwen-asr-dir", default="models/mega-asr/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--per-split", type=int, default=24)
    ap.add_argument("--max-seq", type=int, default=512)
    ap.add_argument("--cfg", default="NVFP4_AWQ_LITE_CFG", choices=[
        "NVFP4_DEFAULT_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_AWQ_CLIP_CFG",
        "NVFP4_AFFINE_KV_CFG", "MXFP4_DEFAULT_CFG",
    ])
    ap.add_argument("--out", default="models/nvfp4/mega-asr-llm", type=Path)
    args = ap.parse_args()

    _load_hf_token(".env")

    print(f"Loading LLM from {args.model} ...")
    model = AutoModelForCausalLM.from_pretrained(
        str(args.model), torch_dtype=torch.bfloat16, device_map="cuda",
    )
    model.eval()
    model.requires_grad_(False)  # calibration only: no gradients needed on weights
    tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
    embed_w = model.model.embed_tokens.weight  # (151936, 2048) on cuda
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params")

    # Feature extractor (mel) + ONNX audio encoder, used only to build calibration inputs.
    from transformers import AutoFeatureExtractor
    feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
    enc = ort.InferenceSession(
        str(args.encoder),
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    calib_batches = _collect_calibration_batches(
        embed_w, tok, feat, enc, per_split=args.per_split, max_seq=args.max_seq,
    )
    print(f"\nTotal calibration batches: {len(calib_batches)}")
    del enc
    if not calib_batches:
        # With zero batches modelopt has no activations to calibrate from; fail
        # loudly instead of saving a meaningless checkpoint.
        sys.exit("No calibration samples were collected; refusing to quantize.")

    # Modelopt PTQ
    import modelopt.torch.quantization as mtq

    quant_cfg = getattr(mtq, args.cfg)
    print(f"\nQuantizing with {args.cfg} ...")

    def forward_loop(m):
        """Feed every calibration batch through ``m`` so modelopt can observe activations."""
        with torch.no_grad():
            for i, batch in enumerate(calib_batches, start=1):
                m(**batch)
                if i % 20 == 0:
                    print(f"  calib {i}/{len(calib_batches)}", flush=True)

    mtq.quantize(model, quant_cfg, forward_loop)
    print("Quantization done.")

    # Save HF-compatible weights + config and the tokenizer.
    # NOTE(review): a plain save_pretrained may not store modelopt's quantizer
    # state in a form from_pretrained can restore. No TensorRT-LLM export happens
    # here either. Consider modelopt's export APIs and confirm them against the
    # installed modelopt version.
    args.out.mkdir(parents=True, exist_ok=True)
    print(f"Saving to {args.out} ...")
    model.save_pretrained(str(args.out), safe_serialization=True)
    tok.save_pretrained(str(args.out))
    print("Done.")


# Run the CLI (argument parsing, calibration, quantization, save) only when the
# file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()