| """ |
| MiMo-V2.5-ASR -> FP8 e4m3fn (per-channel weight quant, dynamic activation quant) |
| |
| Quantize entrypoint loads MiMoAudioForCausalLM directly (no audio tokenizer / no |
| flash-attn needed -- the LLM is pure Qwen2). Verify/load paths still go through the |
| full MimoAudio stack and DO require flash-attn + the audio tokenizer. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import shutil |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| from safetensors.torch import save_file, safe_open |
|
|
| REPO_ROOT = Path(__file__).resolve().parent |
| if str(REPO_ROOT) not in sys.path: |
| sys.path.insert(0, str(REPO_ROOT)) |
|
|
| |
| FP8_DTYPE = torch.float8_e4m3fn |
| FP8_MAX = torch.finfo(FP8_DTYPE).max |
| SCALE_DTYPE = torch.float32 |
| SKIP_TYPES = (nn.Embedding, nn.LayerNorm, nn.GroupNorm, |
| nn.BatchNorm1d, nn.BatchNorm2d, nn.RMSNorm) |
|
|
| CONFIG_FILES = [ |
| "config.json", "tokenizer_config.json", "tokenizer.json", |
| "special_tokens_map.json", "generation_config.json", |
| "added_tokens.json", "merges.txt", "vocab.json", "chat_template.jinja", |
| ] |
|
|
| SPECIAL_TOKENS = ["<|sosp|>", "<|eosp|>", "<|empty|>", "<|Human|>", |
| "<|SpeechLM|>", "<|sostm|>", "<|eostm|>", "<|eot|>"] |
|
|
|
|
| |
def quantize_weight_per_channel(weight: torch.Tensor):
    """Quantize a weight matrix to FP8 using one absmax scale per output row.

    Args:
        weight: Weight laid out as ``[out, in]``.

    Returns:
        ``(w_fp8, scale)``: ``w_fp8`` has the shape of ``weight`` and dtype
        ``FP8_DTYPE``; ``scale`` has shape ``[out, 1]`` and dtype
        ``SCALE_DTYPE``. Multiplying ``w_fp8`` by ``scale`` approximates the
        original weight.
    """
    weight_f32 = weight.float()
    # Row-wise absolute maximum; the floor keeps an all-zero row from
    # producing a zero scale (and thus a division by zero below).
    row_absmax = torch.clamp(
        weight_f32.abs().amax(dim=1, keepdim=True), min=1e-12
    )
    scale = (row_absmax / FP8_MAX).to(SCALE_DTYPE)
    rescaled = torch.clamp(weight_f32 / scale, min=-FP8_MAX, max=FP8_MAX)
    return rescaled.to(FP8_DTYPE), scale
|
|
|
|
class FP8Linear(nn.Module):
    """Replacement for ``nn.Linear`` holding FP8 e4m3fn weights.

    Weights are quantized once at construction with one scale per output
    channel; activations are quantized on every call with a single scale for
    the whole input tensor. The output is always ``torch.bfloat16``.

    Buffers (these are the state-dict keys):
        weight_fp8:   ``[out, in]`` quantized weight.
        weight_scale: ``[out]`` per-output-channel dequantization scale.
        bias:         copy of the source layer's bias; a plain ``None``
                      attribute when the source layer has no bias.
    """

    def __init__(self, linear: nn.Linear):
        """Quantize ``linear`` and keep only the FP8 weight, scales and bias."""
        super().__init__()
        with torch.no_grad():
            w_fp8, w_scale = quantize_weight_per_channel(linear.weight)
        self.register_buffer("weight_fp8", w_fp8.contiguous())
        # quantize_weight_per_channel returns [out, 1]; store it as [out].
        self.register_buffer("weight_scale", w_scale.squeeze(1))
        if linear.bias is not None:
            self.register_buffer("bias", linear.bias.detach().clone())
        else:
            self.bias = None
        self.in_features = linear.in_features
        self.out_features = linear.out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear map to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension equals ``in_features``.

        Returns:
            A bfloat16 tensor shaped like ``x`` with the last dimension
            replaced by ``out_features``.
        """
        leading = x.shape[:-1]

        # Zero-element input: reshape(-1, ...) and max() both raise on empty
        # tensors, so answer directly with an (empty) result of the right
        # shape and dtype.
        if x.numel() == 0 and x.shape[-1] == self.in_features:
            out = x.new_zeros((*leading, self.out_features), dtype=torch.bfloat16)
            if self.bias is not None:
                out = out + self.bias.to(out.dtype)
            return out

        # Upcast once; the float32 copy feeds both the scale and the
        # quantized activations.
        x_f32 = x.reshape(-1, self.in_features).float()
        x_scale = (x_f32.abs().max().clamp(min=1e-12) / FP8_MAX).to(SCALE_DTYPE)
        x_fp8 = (x_f32 / x_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

        # The matmul is given one scalar weight scale (the largest channel
        # scale); the per-channel ratio is applied to the output afterwards.
        # It is recomputed on every call rather than cached because
        # weight_scale can be overwritten after construction (load_fp8_model
        # loads a state dict into already-built FP8Linear modules).
        w_scale_scalar = self.weight_scale.max().to(SCALE_DTYPE)
        out = torch._scaled_mm(
            x_fp8, self.weight_fp8.t(),
            scale_a=x_scale, scale_b=w_scale_scalar,
            out_dtype=torch.bfloat16, use_fast_accum=True,
        )
        correction = (self.weight_scale / w_scale_scalar).to(torch.bfloat16)
        out = out * correction.unsqueeze(0)
        if self.bias is not None:
            out = out + self.bias.to(out.dtype)
        return out.reshape(*leading, self.out_features)

    def extra_repr(self):
        """Return the short description shown inside ``repr(module)``."""
        return f"in={self.in_features}, out={self.out_features}, fp8=e4m3fn"
|
|
|
|
| |
def quantize_model(model: nn.Module, verbose: bool = True):
    """Swap eligible ``nn.Linear`` descendants of ``model`` for ``FP8Linear``.

    The module tree is walked depth-first. Children that are instances of
    ``SKIP_TYPES`` are counted as skipped and not descended into; other
    ``nn.Linear`` children are replaced in place on their parent; every
    remaining child is recursed into. ``model`` itself is never replaced.

    Args:
        model: Root module; modified in place.
        verbose: When true, print one line per converted layer.

    Returns:
        ``(model, stats)`` where ``stats`` has the integer entries
        ``converted``, ``skipped``, ``bytes_before`` and ``bytes_after``.
    """
    stats = dict.fromkeys(
        ("converted", "skipped", "bytes_before", "bytes_after"), 0
    )

    def _walk(parent, prefix=""):
        # Snapshot the children: entries are replaced while iterating.
        for name, module in list(parent.named_children()):
            full = name if not prefix else f"{prefix}.{name}"

            if isinstance(module, SKIP_TYPES):
                stats["skipped"] += 1
                continue
            if not isinstance(module, nn.Linear):
                _walk(module, full)
                continue

            # Size of the original layer (weight plus optional bias).
            b_before = module.weight.numel() * module.weight.element_size()
            if module.bias is not None:
                b_before += module.bias.numel() * module.bias.element_size()

            fp8mod = FP8Linear(module)

            # Quantized weight counted as one byte per element, scales as
            # four bytes each, bias at its stored element size.
            b_after = fp8mod.weight_fp8.numel() + fp8mod.weight_scale.numel() * 4
            if fp8mod.bias is not None:
                b_after += fp8mod.bias.numel() * fp8mod.bias.element_size()

            setattr(parent, name, fp8mod)
            stats["converted"] += 1
            stats["bytes_before"] += b_before
            stats["bytes_after"] += b_after
            if verbose:
                print(f" [FP8] {full:<70} {b_before/max(b_after,1):.1f}x")

    _walk(model)
    return model, stats
|
|
|
|
| |
def save_fp8(model, out_dir: Path, stats: dict, model_path: Path):
    """Write the quantized checkpoint, its config files and a metadata summary.

    Produces in ``out_dir`` (created if missing):
      * ``model.safetensors`` -- the model's full state dict, moved to CPU;
      * every file named in ``CONFIG_FILES`` that exists in ``model_path``;
      * ``fp8_meta.json`` -- quantization scheme and size statistics.

    Args:
        model: Module whose ``state_dict()`` is saved.
        out_dir: Destination directory.
        stats: Counters as returned by ``quantize_model``.
        model_path: Directory of the source checkpoint to copy configs from.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # --- weights ------------------------------------------------------------
    st_path = out_dir / "model.safetensors"
    tensors = {}
    for key, tensor in model.state_dict().items():
        tensors[key] = tensor.contiguous().cpu()
    save_file(tensors, str(st_path), metadata={"format": "pt"})

    # --- config / tokenizer files --------------------------------------------
    copied = []
    for cfg in CONFIG_FILES:
        src = model_path / cfg
        if not src.exists():
            continue
        shutil.copy2(src, out_dir / cfg)
        copied.append(cfg)
    if copied:
        print(f" Copied config: {', '.join(copied)}")

    # --- metadata -------------------------------------------------------------
    bytes_before = stats["bytes_before"]
    bytes_after = stats["bytes_after"]
    gb_before = bytes_before / 1e9
    gb_after = bytes_after / 1e9
    # max(..., 1) guards the ratio when nothing was converted.
    ratio = round(bytes_before / max(bytes_after, 1), 3)

    meta = {}
    meta["dtype"] = "float8_e4m3fn"
    meta["weight_scaling"] = "per_channel_absmax"
    meta["activation_scaling"] = "dynamic_per_tensor"
    meta["matmul_op"] = "torch._scaled_mm"
    meta["output_dtype"] = "bfloat16"
    meta["converted_layers"] = stats["converted"]
    meta["skipped_layers"] = stats["skipped"]
    meta["weight_gb_before"] = round(gb_before, 3)
    meta["weight_gb_after"] = round(gb_after, 3)
    meta["compression_ratio"] = ratio
    (out_dir / "fp8_meta.json").write_text(json.dumps(meta, indent=2))

    # --- summary --------------------------------------------------------------
    actual_gb = st_path.stat().st_size / 1e9
    print(f"\nOK {st_path} ({actual_gb:.2f} GB on disk)")
    print(f" weight bytes: {gb_before:.2f} GB -> {gb_after:.2f} GB ({ratio}x)")
    print(f" {stats['converted']} layers converted, {stats['skipped']} skipped")
|
|
|
|
def _build_args_and_tokenizer(model_path: str):
    """Load the text tokenizer and derive the model's special-token arguments.

    Every entry of ``SPECIAL_TOKENS`` that the tokenizer does not know yet is
    added as a special token, then the ids of the speech/stream marker tokens
    are looked up and packed into a ``MiMoAudioArguments``.

    Args:
        model_path: Directory (or identifier) passed to
            ``AutoTokenizer.from_pretrained`` and recorded as
            ``model_name_or_path``.

    Returns:
        ``(args, tok)``: the ``MiMoAudioArguments`` instance and the
        (possibly extended) tokenizer.
    """
    from transformers import AutoTokenizer
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioArguments

    tok = AutoTokenizer.from_pretrained(model_path)

    # Snapshot the vocabulary once instead of rebuilding it for every token.
    # SPECIAL_TOKENS holds distinct strings, so the snapshot stays valid for
    # the whole membership test even though tokens are added below.
    known = tok.get_vocab()
    missing = [t for t in SPECIAL_TOKENS if t not in known]
    # One add_tokens call per token, in list order, exactly as before, so the
    # ids assigned to newly added tokens do not change.
    for t in missing:
        tok.add_tokens([t], special_tokens=True)

    token_id = tok.convert_tokens_to_ids
    args = MiMoAudioArguments(
        model_name_or_path=model_path,
        sosp_idx=token_id("<|sosp|>"),
        eosp_idx=token_id("<|eosp|>"),
        empty_idx=token_id("<|empty|>"),
        sostm_idx=token_id("<|sostm|>"),
        eostm_idx=token_id("<|eostm|>"),
        eot_idx=token_id("<|eot|>"),
    )
    return args, tok
|
|
|
|
| |
def run_quantize(args):
    """Load the bf16 model, quantize it to FP8 and save the result.

    Args:
        args: Namespace providing ``model_path``, ``device``, ``quiet`` and
            ``out_dir`` (as produced by ``main``).
    """
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM

    source = args.model_path
    print(f"Loading MiMoAudioForCausalLM from {source} on {args.device} ...")

    # Only the argument object is needed here; the tokenizer is discarded.
    model_args = _build_args_and_tokenizer(source)[0]
    load_kwargs = {
        "args": model_args,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": args.device},
        "attn_implementation": "sdpa",
    }
    model = MiMoAudioForCausalLM.from_pretrained(source, **load_kwargs)
    model.eval()
    print("OK loaded\n")

    print("Quantizing to FP8 e4m3fn ...")
    show_layers = not args.quiet
    with torch.no_grad():
        model, stats = quantize_model(model, verbose=show_layers)

    save_fp8(model, Path(args.out_dir), stats, Path(source))
    print("\nDone.")
|
|
|
|
| |
def load_fp8_model(fp8_dir: str, tokenizer_path: str, repo_root: str, device: str = "cuda"):
    """Load an FP8 checkpoint for inference and return a ``MimoAudio`` wrapper.

    The original bf16 weights are not needed. The architecture is built from
    the config stored in ``fp8_dir``, its Linear layers are swapped for
    ``FP8Linear`` shells via ``quantize_model``, and the saved FP8 state dict
    is then loaded into those shells with ``strict=True``. The ``MimoAudio``
    object is created without running its ``__init__`` and its attributes are
    filled in by hand.

    Args:
        fp8_dir: Directory containing ``model.safetensors`` plus the config
            and tokenizer files.
        tokenizer_path: Path given to ``MiMoAudioTokenizer.from_pretrained``.
        repo_root: The cloned MiMo-V2.5-ASR repo (contains ``src/``); it is
            prepended to ``sys.path`` if missing.
        device: Device the model, audio tokenizer and mel transform go to.

    Returns:
        The populated ``MimoAudio`` instance.
    """
    rr = Path(repo_root).resolve()
    if str(rr) not in sys.path:
        sys.path.insert(0, str(rr))

    # Import everything up front so a missing dependency fails before the
    # expensive model construction starts.
    from src.mimo_audio.mimo_audio import MimoAudio
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM, MiMoSampler
    from src.mimo_audio_tokenizer import MiMoAudioTokenizer
    from transformers import AutoConfig
    from torchaudio.transforms import MelSpectrogram

    fp8_dir = Path(fp8_dir)
    model_args, tokenizer = _build_args_and_tokenizer(str(fp8_dir))

    # --- architecture -------------------------------------------------------
    print("Building architecture (config init) ...")
    cfg = AutoConfig.from_pretrained(str(fp8_dir))
    # NOTE(review): Module.to(dtype) also casts floating-point buffers. If the
    # architecture registers float32 buffers that are not restored from the
    # checkpoint (e.g. rotary frequencies), they are downcast here, unlike in
    # the from_pretrained path used by run_quantize -- confirm this is fine.
    model = MiMoAudioForCausalLM(cfg, model_args).to(torch.bfloat16)

    print("Installing FP8 modules ...")
    # The values quantized here are placeholders; only the module structure
    # matters, the real tensors come from the state dict below.
    with torch.no_grad():
        quantize_model(model, verbose=False)
    model = model.to(device)

    # --- weights --------------------------------------------------------------
    print("Loading FP8 weights ...")
    with safe_open(str(fp8_dir / "model.safetensors"), framework="pt", device=device) as f:
        state = {key: f.get_tensor(key) for key in f.keys()}
    model.load_state_dict(state, strict=True)
    model.eval()

    # --- wrapper --------------------------------------------------------------
    # Bypass MimoAudio.__init__ and set the attributes manually.
    mimo = object.__new__(MimoAudio)
    mimo.device = device
    mimo.path = str(fp8_dir)
    mimo.mimo_audio_tokenizer_path = tokenizer_path
    mimo.tokenizer = tokenizer
    mimo.padding_idx = int(tokenizer.pad_token_id)
    mimo.sosp_idx = model_args.sosp_idx
    mimo.eosp_idx = model_args.eosp_idx
    mimo.empty_token = model_args.empty_idx
    mimo.sostm_idx = model_args.sostm_idx
    mimo.eostm_idx = model_args.eostm_idx
    mimo.eot_idx = model_args.eot_idx
    mimo.im_start_idx = tokenizer.convert_tokens_to_ids("<|im_start|>")
    mimo.im_end_idx = tokenizer.convert_tokens_to_ids("<|im_end|>")
    mimo.model = model
    mimo.group_size = model.config.group_size
    mimo.audio_channels = model.config.audio_channels
    mimo.delay_pattern = model.config.delay_pattern
    mimo.vocab_size = model.config.vocab_size
    mimo.speech_zeroemb_idx = model.speech_empty_ids

    # --- sampling / generation defaults ---------------------------------------
    mimo.default_global_sampler = MiMoSampler(do_sample=True, temperature=0.6, top_k=50, top_p=0.95)
    mimo.default_local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_k=50, top_p=0.95)
    mimo.task_sampler_configs = {
        "asr": {
            "global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
            "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95),
        },
    }
    mimo.generate_kwargs = {
        "max_length": 8192,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }

    # --- audio front end --------------------------------------------------------
    mimo.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(tokenizer_path)
    mimo.mimo_audio_tokenizer.eval().bfloat16().to(device)
    tcfg = mimo.mimo_audio_tokenizer.config
    mimo.mel_transform = MelSpectrogram(
        sample_rate=tcfg.sampling_rate, n_fft=tcfg.nfft, hop_length=tcfg.hop_length,
        win_length=tcfg.window_size, f_min=tcfg.fmin, f_max=tcfg.fmax,
        n_mels=tcfg.n_mels, power=1.0, center=True,
    ).to(device)
    print("FP8 model ready\n")
    return mimo
|
|
|
|
def main():
    """Parse the command-line options and run the quantize entrypoint."""
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--model-path", {"required": True}),
        ("--out-dir", {"default": "./MiMo-V2.5-ASR-FP8"}),
        ("--device", {"default": "cuda"}),
        ("--quiet", {"action": "store_true"}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    run_quantize(parser.parse_args())


if __name__ == "__main__":
    main()
|
|