# File size: 11,904 Bytes
# 04e43d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
"""
MiMo-V2.5-ASR -> FP8 e4m3fn  (per-channel weight quant, dynamic activation quant)

Quantize entrypoint loads MiMoAudioForCausalLM directly (no audio tokenizer / no
flash-attn needed -- the LLM is pure Qwen2). Verify/load paths still go through the
full MimoAudio stack and DO require flash-attn + the audio tokenizer.
"""

import os
import sys
import json
import shutil
import argparse
from pathlib import Path

import torch
import torch.nn as nn
from safetensors.torch import save_file, safe_open

# Directory containing this script; put on sys.path so the function-level
# `from src.mimo_audio...` imports below can resolve when the script sits in the repo root.
REPO_ROOT = Path(__file__).resolve().parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# ----- constants -----
# Storage dtype for quantized weights and quantized activations.
FP8_DTYPE   = torch.float8_e4m3fn
# Largest finite magnitude representable in FP8_DTYPE; used as the absmax target
# and as the clamp bound when quantizing.
FP8_MAX     = torch.finfo(FP8_DTYPE).max   # 448.0
# dtype of every quantization scale (weight and activation).
SCALE_DTYPE = torch.float32
# Module types that quantize_model counts as "skipped" and does not descend into.
SKIP_TYPES  = (nn.Embedding, nn.LayerNorm, nn.GroupNorm,
               nn.BatchNorm1d, nn.BatchNorm2d, nn.RMSNorm)

# Side files copied verbatim from the source model directory into the FP8 output
# directory by save_fp8 (only those that actually exist are copied).
CONFIG_FILES = [
    "config.json", "tokenizer_config.json", "tokenizer.json",
    "special_tokens_map.json", "generation_config.json",
    "added_tokens.json", "merges.txt", "vocab.json", "chat_template.jinja",
]

# Tokens that _build_args_and_tokenizer adds to the text tokenizer when missing.
# Six of them are also resolved to ids and passed to MiMoAudioArguments.
SPECIAL_TOKENS = ["<|sosp|>", "<|eosp|>", "<|empty|>", "<|Human|>",
                  "<|SpeechLM|>", "<|sostm|>", "<|eostm|>", "<|eot|>"]


# ----- weight quantization -----
def quantize_weight_per_channel(weight: torch.Tensor):
    """Quantize a 2-D weight to FP8 with one absmax scale per output row.

    Args:
        weight: tensor laid out as [out, in].

    Returns:
        ``(w_fp8, scale)`` where ``w_fp8`` has the shape of ``weight`` in
        ``FP8_DTYPE`` and ``scale`` is ``[out, 1]`` in ``SCALE_DTYPE`` such that
        ``w_fp8 * scale`` approximates ``weight``.
    """
    weight_f32 = weight.to(torch.float32)
    # Row-wise absolute maximum; the floor keeps all-zero rows from producing
    # a zero scale (and a division by zero below).
    row_absmax = torch.clamp(
        torch.amax(torch.abs(weight_f32), dim=1, keepdim=True), min=1e-12
    )
    row_scale = torch.div(row_absmax, FP8_MAX).to(SCALE_DTYPE)
    scaled = torch.div(weight_f32, row_scale)
    quantized = torch.clamp(scaled, min=-FP8_MAX, max=FP8_MAX).to(FP8_DTYPE)
    return quantized, row_scale


class FP8Linear(nn.Module):
    """Linear layer with FP8 e4m3fn weights and dynamic FP8 activations.

    Weights are quantized once at construction with one absmax scale per
    output channel. Activations are quantized on every call with a single
    absmax scale over the whole input tensor. The matmul runs through
    ``torch._scaled_mm`` and the result is always returned as bfloat16.

    Buffers (all saved in the state dict):
        weight_fp8:   [out, in] quantized weight.
        weight_scale: [out] float32 per-channel dequantization scale.
        bias:         copy of the source layer's bias, or ``None``.
    """

    def __init__(self, linear: nn.Linear):
        """Build the FP8 layer from an existing ``nn.Linear`` (which is not modified)."""
        super().__init__()
        with torch.no_grad():
            w_fp8, w_scale = quantize_weight_per_channel(linear.weight)
        self.register_buffer("weight_fp8",   w_fp8.contiguous())   # [out, in]
        self.register_buffer("weight_scale", w_scale.squeeze(1))   # [out]
        if linear.bias is not None:
            self.register_buffer("bias", linear.bias.detach().clone())
        else:
            # Register the slot even when empty so `bias` is always a buffer;
            # None buffers are left out of the state dict.
            self.register_buffer("bias", None)
        self.in_features  = linear.in_features
        self.out_features = linear.out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape ``[..., in_features]``.

        Returns a bfloat16 tensor of shape ``[..., out_features]``.
        """
        leading = x.shape[:-1]
        if x.numel() == 0:
            # reshape(-1, k) and abs().max() both raise on zero-element tensors,
            # whereas nn.Linear accepts them; zero rows in means zero rows out.
            return x.new_empty((*leading, self.out_features), dtype=torch.bfloat16)

        x32     = x.reshape(-1, self.in_features).float()
        # Dynamic per-tensor activation scale (floor avoids dividing by zero).
        x_scale = (x32.abs().max().clamp(min=1e-12) / FP8_MAX).to(SCALE_DTYPE)
        x_fp8   = (x32 / x_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

        # _scaled_mm is called with scalar scales here, so the matmul uses the
        # largest channel scale and the per-channel ratio is applied afterwards.
        w_scale_scalar = self.weight_scale.max().to(SCALE_DTYPE)
        out = torch._scaled_mm(
            x_fp8, self.weight_fp8.t(),
            scale_a=x_scale, scale_b=w_scale_scalar,
            out_dtype=torch.bfloat16, use_fast_accum=True,
        )

        # Apply the per-channel correction (and bias) in float32 and round to
        # bf16 once at the end; rounding the correction factors themselves to
        # bf16 would add a systematic per-channel error on top of the FP8 one.
        correction = self.weight_scale.to(SCALE_DTYPE) / w_scale_scalar
        out = out.float() * correction.unsqueeze(0)
        if self.bias is not None:
            out = out + self.bias.float()
        return out.to(torch.bfloat16).reshape(*leading, self.out_features)

    def extra_repr(self):
        """Short description shown in the module's repr."""
        return f"in={self.in_features}, out={self.out_features}, fp8=e4m3fn"


# ----- model walk -----
def quantize_model(model: nn.Module, verbose: bool = True):
    """Replace every ``nn.Linear`` below ``model`` with an ``FP8Linear``, in place.

    Modules whose type is listed in ``SKIP_TYPES`` are counted as skipped and
    not descended into; every other non-Linear module is walked recursively.
    ``model`` itself is never replaced, only its descendants.

    Args:
        model: root module to rewrite.
        verbose: print one line per converted layer with its size ratio.

    Returns:
        ``(model, stats)`` where ``stats`` holds the keys ``converted``,
        ``skipped``, ``bytes_before`` and ``bytes_after``.
    """
    stats = dict.fromkeys(("converted", "skipped", "bytes_before", "bytes_after"), 0)

    def _tensor_bytes(t):
        # Storage size of a tensor; an absent bias contributes nothing.
        return 0 if t is None else t.numel() * t.element_size()

    def _walk(parent, prefix=""):
        # Snapshot the children first: setattr below mutates the parent's registry.
        for name, module in list(parent.named_children()):
            full = name if not prefix else f"{prefix}.{name}"

            if isinstance(module, SKIP_TYPES):
                stats["skipped"] += 1
                continue
            if not isinstance(module, nn.Linear):
                _walk(module, full)
                continue

            b_before = _tensor_bytes(module.weight) + _tensor_bytes(module.bias)
            fp8mod = FP8Linear(module)
            # 1 byte per FP8 weight element, 4 bytes per float32 scale.
            b_after = (fp8mod.weight_fp8.numel()
                       + fp8mod.weight_scale.numel() * 4
                       + _tensor_bytes(fp8mod.bias))
            setattr(parent, name, fp8mod)

            stats["converted"] += 1
            stats["bytes_before"] += b_before
            stats["bytes_after"] += b_after
            if verbose:
                print(f"  [FP8] {full:<70} {b_before/max(b_after,1):.1f}x")

    _walk(model)
    return model, stats


# ----- save -----
def save_fp8(model, out_dir: Path, stats: dict, model_path: Path):
    """Write the quantized checkpoint, its config files and a metadata summary.

    Args:
        model: quantized module whose ``state_dict()`` is saved.
        out_dir: destination directory (created if needed).
        stats: counters returned by ``quantize_model``.
        model_path: source model directory the ``CONFIG_FILES`` are copied from.

    Writes ``model.safetensors`` and ``fp8_meta.json`` into ``out_dir`` and
    prints a short report.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Every tensor is made contiguous and moved to CPU before serialization.
    state = {}
    for key, tensor in model.state_dict().items():
        state[key] = tensor.contiguous().cpu()
    st_path = out_dir / "model.safetensors"
    save_file(state, str(st_path), metadata={"format": "pt"})

    # Carry over whichever tokenizer/config side files the source directory has.
    copied = []
    for cfg in CONFIG_FILES:
        src = model_path / cfg
        if not src.exists():
            continue
        shutil.copy2(src, out_dir / cfg)
        copied.append(cfg)
    if copied:
        print(f"  Copied config: {', '.join(copied)}")

    bytes_before = stats["bytes_before"]
    bytes_after = stats["bytes_after"]
    gb_before = bytes_before / 1e9
    gb_after = bytes_after / 1e9
    ratio = round(bytes_before / max(bytes_after, 1), 3)

    meta = {
        "dtype": "float8_e4m3fn",
        "weight_scaling": "per_channel_absmax",
        "activation_scaling": "dynamic_per_tensor",
        "matmul_op": "torch._scaled_mm",
        "output_dtype": "bfloat16",
        "converted_layers": stats["converted"],
        "skipped_layers": stats["skipped"],
        "weight_gb_before": round(gb_before, 3),
        "weight_gb_after": round(gb_after, 3),
        "compression_ratio": ratio,
    }
    with open(out_dir / "fp8_meta.json", "w") as f:
        f.write(json.dumps(meta, indent=2))

    actual_gb = st_path.stat().st_size / 1e9
    print(f"\nOK {st_path}  ({actual_gb:.2f} GB on disk)")
    print(f"   weight bytes: {gb_before:.2f} GB -> {gb_after:.2f} GB ({ratio}x)")
    print(f"   {stats['converted']} layers converted, {stats['skipped']} skipped")


def _build_args_and_tokenizer(model_path: str):
    """Load the text tokenizer and derive the MiMoAudioArguments from it.

    Any entry of ``SPECIAL_TOKENS`` missing from the tokenizer is added as a
    special token, then the ids of the speech/stream marker tokens are looked
    up and passed to ``MiMoAudioArguments``.

    Args:
        model_path: directory (or hub id) handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``(args, tok)``: the ``MiMoAudioArguments`` instance and the tokenizer.
    """
    from transformers import AutoTokenizer
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioArguments

    tok = AutoTokenizer.from_pretrained(model_path)

    # Snapshot the vocabulary once instead of calling get_vocab() (which
    # materializes the whole token->id mapping) for every candidate token.
    known = set(tok.get_vocab())
    for t in SPECIAL_TOKENS:
        if t not in known:
            tok.add_tokens([t], special_tokens=True)
            known.add(t)

    gid = tok.convert_tokens_to_ids
    args = MiMoAudioArguments(
        model_name_or_path=model_path,
        sosp_idx=gid("<|sosp|>"), eosp_idx=gid("<|eosp|>"),
        empty_idx=gid("<|empty|>"), sostm_idx=gid("<|sostm|>"),
        eostm_idx=gid("<|eostm|>"), eot_idx=gid("<|eot|>"),
    )
    return args, tok


# ----- quantize entrypoint (direct LLM load, no audio tokenizer / flash-attn) -----
def run_quantize(args):
    """Load the bf16 LLM, convert its Linear layers to FP8 and save the result.

    Args:
        args: parsed CLI namespace; reads ``model_path``, ``device``,
            ``quiet`` and ``out_dir``.
    """
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM

    source_dir = args.model_path
    print(f"Loading MiMoAudioForCausalLM from {args.model_path} on {args.device} ...")

    # Only the argument object is needed here; the tokenizer is discarded.
    audio_args = _build_args_and_tokenizer(source_dir)[0]
    load_options = {
        "args": audio_args,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": args.device},
        "attn_implementation": "sdpa",
    }
    llm = MiMoAudioForCausalLM.from_pretrained(source_dir, **load_options)
    llm.eval()
    print("OK loaded\n")

    print("Quantizing to FP8 e4m3fn ...")
    with torch.no_grad():
        quantized, summary = quantize_model(llm, verbose=not args.quiet)

    save_fp8(quantized, Path(args.out_dir), summary, Path(source_dir))
    print("\nDone.")


# ----- load FP8 model (for inference / verify) -----
def load_fp8_model(fp8_dir: str, tokenizer_path: str, repo_root: str, device: str = "cuda"):
    """
    Load the FP8 checkpoint for inference and return it wrapped in a MimoAudio object
    (per the original author's note, the wrapper exposes .asr_sft()).

    Strategy: the original bf16 weights are NOT needed. The architecture is built from
    the config alone (so module initialisation, e.g. rotary tables, runs normally),
    every Linear is swapped for an FP8Linear shell via quantize_model, and the FP8
    state dict is then loaded over those shells with strict key matching.

    Args:
        fp8_dir: directory containing model.safetensors plus config/tokenizer files.
        tokenizer_path: path passed to MiMoAudioTokenizer.from_pretrained.
        repo_root: the cloned MiMo-V2.5-ASR repo (must contain src/); added to sys.path.
        device: device the model, audio tokenizer and mel transform are moved to.

    Returns:
        A MimoAudio instance created without running its __init__, with its attributes
        populated by hand below.
    """
    # Make the repo's `src` package importable for the imports that follow.
    rr = Path(repo_root).resolve()
    if str(rr) not in sys.path:
        sys.path.insert(0, str(rr))

    # NOTE(review): AutoTokenizer and GenerationConfig are imported but not used in
    # this function.
    from src.mimo_audio.mimo_audio import MimoAudio
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM
    from src.mimo_audio_tokenizer import MiMoAudioTokenizer
    from transformers import AutoTokenizer, AutoConfig, GenerationConfig

    fp8_dir = Path(fp8_dir)
    # The text tokenizer is read from the FP8 directory itself (save_fp8 copies it there).
    model_args, tokenizer = _build_args_and_tokenizer(str(fp8_dir))

    # Build architecture with real init (correct rotary inv_freq), no pretrained shards.
    # NOTE(review): .to(torch.bfloat16) casts floating-point buffers as well as
    # parameters; if the rotary inv_freq is a buffer it gets rounded to bf16 here --
    # confirm this matches what from_pretrained(torch_dtype=bfloat16) produces.
    print("Building architecture (config init) ...")
    cfg = AutoConfig.from_pretrained(str(fp8_dir))
    model = MiMoAudioForCausalLM(cfg, model_args).to(torch.bfloat16)

    # This quantizes the randomly initialised weights; the values are throwaway and
    # only serve to create FP8Linear modules with correctly shaped buffers.
    print("Installing FP8 modules ...")
    with torch.no_grad():
        quantize_model(model, verbose=False)
    model = model.to(device)

    # Read every tensor straight onto the target device, then load with strict=True so
    # any missing or unexpected key raises instead of being silently ignored.
    print("Loading FP8 weights ...")
    state = {}
    with safe_open(str(fp8_dir / "model.safetensors"), framework="pt", device=device) as f:
        for key in f.keys():
            state[key] = f.get_tensor(key)
    model.load_state_dict(state, strict=True)
    model.eval()

    # Wrap in MimoAudio without re-running its __init__ (which would reload weights).
    # Every attribute below is set by hand; presumably this mirrors what
    # MimoAudio.__init__ sets -- keep the two in sync (TODO confirm against the class).
    mimo = object.__new__(MimoAudio)
    mimo.device = device
    mimo.path = str(fp8_dir)
    mimo.mimo_audio_tokenizer_path = tokenizer_path
    mimo.tokenizer = tokenizer
    # int(None) raises TypeError, so this line fails if the tokenizer has no pad token.
    mimo.padding_idx = int(tokenizer.pad_token_id)
    mimo.sosp_idx = model_args.sosp_idx
    mimo.eosp_idx = model_args.eosp_idx
    mimo.empty_token = model_args.empty_idx
    mimo.sostm_idx = model_args.sostm_idx
    mimo.eostm_idx = model_args.eostm_idx
    mimo.eot_idx = model_args.eot_idx
    mimo.im_start_idx = tokenizer.convert_tokens_to_ids("<|im_start|>")
    mimo.im_end_idx = tokenizer.convert_tokens_to_ids("<|im_end|>")
    mimo.model = model
    # Values mirrored from the loaded model / its config.
    mimo.group_size = model.config.group_size
    mimo.audio_channels = model.config.audio_channels
    mimo.delay_pattern = model.config.delay_pattern
    mimo.vocab_size = model.config.vocab_size
    mimo.speech_zeroemb_idx = model.speech_empty_ids

    # Sampler presets are hard-coded here: a default global/local pair plus an "asr"
    # task entry whose global sampler has do_sample=False.
    from src.mimo_audio.modeling_mimo_audio import MiMoSampler
    mimo.default_global_sampler = MiMoSampler(do_sample=True, temperature=0.6, top_k=50, top_p=0.95)
    mimo.default_local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_k=50, top_p=0.95)
    mimo.task_sampler_configs = {
        "asr": {"global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
                "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)},
    }
    mimo.generate_kwargs = {
        "max_length": 8192,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }

    # The audio tokenizer is loaded separately, in bf16, from tokenizer_path.
    # The result of the chained call is not reassigned -- assumes these methods modify
    # the module in place (as nn.Module does); TODO confirm for MiMoAudioTokenizer.
    mimo.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(tokenizer_path)
    mimo.mimo_audio_tokenizer.eval().bfloat16().to(device)
    # Mel front-end configured entirely from the audio tokenizer's own config.
    from torchaudio.transforms import MelSpectrogram
    tcfg = mimo.mimo_audio_tokenizer.config
    mimo.mel_transform = MelSpectrogram(
        sample_rate=tcfg.sampling_rate, n_fft=tcfg.nfft, hop_length=tcfg.hop_length,
        win_length=tcfg.window_size, f_min=tcfg.fmin, f_max=tcfg.fmax,
        n_mels=tcfg.n_mels, power=1.0, center=True,
    ).to(device)
    print("FP8 model ready\n")
    return mimo


def main():
    """Parse the command line and run the quantization entrypoint."""
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword options), registered in this order.
    cli_options = (
        ("--model-path", {"required": True}),
        ("--out-dir", {"default": "./MiMo-V2.5-ASR-FP8"}),
        ("--device", {"default": "cuda"}),
        ("--quiet", {"action": "store_true"}),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    run_quantize(parser.parse_args())


if __name__ == "__main__":
    main()