""" MiMo-V2.5-ASR -> FP8 e4m3fn (per-channel weight quant, dynamic activation quant) Quantize entrypoint loads MiMoAudioForCausalLM directly (no audio tokenizer / no flash-attn needed -- the LLM is pure Qwen2). Verify/load paths still go through the full MimoAudio stack and DO require flash-attn + the audio tokenizer. """ import os import sys import json import shutil import argparse from pathlib import Path import torch import torch.nn as nn from safetensors.torch import save_file, safe_open REPO_ROOT = Path(__file__).resolve().parent if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) # ----- constants ----- FP8_DTYPE = torch.float8_e4m3fn FP8_MAX = torch.finfo(FP8_DTYPE).max # 448.0 SCALE_DTYPE = torch.float32 SKIP_TYPES = (nn.Embedding, nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.RMSNorm) CONFIG_FILES = [ "config.json", "tokenizer_config.json", "tokenizer.json", "special_tokens_map.json", "generation_config.json", "added_tokens.json", "merges.txt", "vocab.json", "chat_template.jinja", ] SPECIAL_TOKENS = ["<|sosp|>", "<|eosp|>", "<|empty|>", "<|Human|>", "<|SpeechLM|>", "<|sostm|>", "<|eostm|>", "<|eot|>"] # ----- weight quantization ----- def quantize_weight_per_channel(weight: torch.Tensor): """Per output-channel absmax scaling. weight: [out, in]""" w = weight.float() amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) scale = (amax / FP8_MAX).to(SCALE_DTYPE) w_fp8 = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) return w_fp8, scale class FP8Linear(nn.Module): """FP8 e4m3fn weights + per-channel scales; dynamic per-tensor activation quant.""" def __init__(self, linear: nn.Linear): super().__init__() with torch.no_grad(): w_fp8, w_scale = quantize_weight_per_channel(linear.weight) self.register_buffer("weight_fp8", w_fp8.contiguous()) # [out, in] self.register_buffer("weight_scale", w_scale.squeeze(1)) # [out] if linear.bias is not None: self.register_buffer("bias", linear.bias.detach().clone()) else: self.bias = None self.in_features = linear.in_features self.out_features = linear.out_features def forward(self, x: torch.Tensor) -> torch.Tensor: leading = x.shape[:-1] x2d = x.reshape(-1, self.in_features) x_scale = (x2d.float().abs().max().clamp(min=1e-12) / FP8_MAX).to(SCALE_DTYPE) x_fp8 = (x2d.float() / x_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE) w_scale_scalar = self.weight_scale.max().to(SCALE_DTYPE) out = torch._scaled_mm( x_fp8, self.weight_fp8.t(), scale_a=x_scale, scale_b=w_scale_scalar, out_dtype=torch.bfloat16, use_fast_accum=True, ) correction = (self.weight_scale / w_scale_scalar).to(torch.bfloat16) out = out * correction.unsqueeze(0) if self.bias is not None: out = out + self.bias.to(out.dtype) return out.reshape(*leading, self.out_features) def extra_repr(self): return f"in={self.in_features}, out={self.out_features}, fp8=e4m3fn" # ----- model walk ----- def quantize_model(model: nn.Module, verbose: bool = True): stats = {"converted": 0, "skipped": 0, "bytes_before": 0, "bytes_after": 0} def _walk(parent, prefix=""): for name, module in list(parent.named_children()): full = f"{prefix}.{name}" if prefix else name if isinstance(module, nn.Linear) and not isinstance(module, SKIP_TYPES): b_before = module.weight.numel() * module.weight.element_size() if module.bias is not None: b_before += module.bias.numel() * module.bias.element_size() fp8mod = FP8Linear(module) b_after = fp8mod.weight_fp8.numel() + fp8mod.weight_scale.numel() * 4 if fp8mod.bias is not None: b_after += fp8mod.bias.numel() * fp8mod.bias.element_size() setattr(parent, name, fp8mod) stats["converted"] += 1 stats["bytes_before"] += b_before stats["bytes_after"] += b_after if verbose: print(f" [FP8] {full:<70} {b_before/max(b_after,1):.1f}x") elif isinstance(module, SKIP_TYPES): stats["skipped"] += 1 else: _walk(module, full) _walk(model) return model, stats # ----- save ----- def save_fp8(model, out_dir: Path, stats: dict, model_path: Path): out_dir.mkdir(parents=True, exist_ok=True) state = {k: v.contiguous().cpu() for k, v in model.state_dict().items()} st_path = out_dir / "model.safetensors" save_file(state, str(st_path), metadata={"format": "pt"}) copied = [] for cfg in CONFIG_FILES: src = model_path / cfg if src.exists(): shutil.copy2(src, out_dir / cfg) copied.append(cfg) if copied: print(f" Copied config: {', '.join(copied)}") gb_before = stats["bytes_before"] / 1e9 gb_after = stats["bytes_after"] / 1e9 ratio = round(stats["bytes_before"] / max(stats["bytes_after"], 1), 3) meta = { "dtype": "float8_e4m3fn", "weight_scaling": "per_channel_absmax", "activation_scaling": "dynamic_per_tensor", "matmul_op": "torch._scaled_mm", "output_dtype": "bfloat16", "converted_layers": stats["converted"], "skipped_layers": stats["skipped"], "weight_gb_before": round(gb_before, 3), "weight_gb_after": round(gb_after, 3), "compression_ratio": ratio, } with open(out_dir / "fp8_meta.json", "w") as f: json.dump(meta, f, indent=2) actual_gb = st_path.stat().st_size / 1e9 print(f"\nOK {st_path} ({actual_gb:.2f} GB on disk)") print(f" weight bytes: {gb_before:.2f} GB -> {gb_after:.2f} GB ({ratio}x)") print(f" {stats['converted']} layers converted, {stats['skipped']} skipped") def _build_args_and_tokenizer(model_path: str): from transformers import AutoTokenizer from src.mimo_audio.modeling_mimo_audio import MiMoAudioArguments tok = AutoTokenizer.from_pretrained(model_path) for t in SPECIAL_TOKENS: if t not in tok.get_vocab(): tok.add_tokens([t], special_tokens=True) gid = lambda t: tok.convert_tokens_to_ids(t) args = MiMoAudioArguments( model_name_or_path=model_path, sosp_idx=gid("<|sosp|>"), eosp_idx=gid("<|eosp|>"), empty_idx=gid("<|empty|>"), sostm_idx=gid("<|sostm|>"), eostm_idx=gid("<|eostm|>"), eot_idx=gid("<|eot|>"), ) return args, tok # ----- quantize entrypoint (direct LLM load, no audio tokenizer / flash-attn) ----- def run_quantize(args): from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM print(f"Loading MiMoAudioForCausalLM from {args.model_path} on {args.device} ...") model_args, _ = _build_args_and_tokenizer(args.model_path) model = MiMoAudioForCausalLM.from_pretrained( args.model_path, args=model_args, torch_dtype=torch.bfloat16, device_map={"": args.device}, attn_implementation="sdpa", ) model.eval() print("OK loaded\n") print("Quantizing to FP8 e4m3fn ...") with torch.no_grad(): model, stats = quantize_model(model, verbose=not args.quiet) save_fp8(model, Path(args.out_dir), stats, Path(args.model_path)) print("\nDone.") # ----- load FP8 model (for inference / verify) ----- def load_fp8_model(fp8_dir: str, tokenizer_path: str, repo_root: str, device: str = "cuda"): """ Load the FP8 checkpoint for inference, returning a MimoAudio wrapper exposing .asr_sft(). Strategy: instantiate the real architecture via from_pretrained on the ORIGINAL repo weights is NOT required -- instead we build the bf16 architecture from config (correct rotary init), replace Linears with FP8Linear shells, then load the FP8 state dict. repo_root must be the cloned MiMo-V2.5-ASR repo (contains src/). fp8_dir must contain model.safetensors + config/tokenizer files. """ rr = Path(repo_root).resolve() if str(rr) not in sys.path: sys.path.insert(0, str(rr)) from src.mimo_audio.mimo_audio import MimoAudio from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM from src.mimo_audio_tokenizer import MiMoAudioTokenizer from transformers import AutoTokenizer, AutoConfig, GenerationConfig fp8_dir = Path(fp8_dir) model_args, tokenizer = _build_args_and_tokenizer(str(fp8_dir)) # Build architecture with real init (correct rotary inv_freq), no pretrained shards. print("Building architecture (config init) ...") cfg = AutoConfig.from_pretrained(str(fp8_dir)) model = MiMoAudioForCausalLM(cfg, model_args).to(torch.bfloat16) print("Installing FP8 modules ...") with torch.no_grad(): quantize_model(model, verbose=False) model = model.to(device) print("Loading FP8 weights ...") state = {} with safe_open(str(fp8_dir / "model.safetensors"), framework="pt", device=device) as f: for key in f.keys(): state[key] = f.get_tensor(key) model.load_state_dict(state, strict=True) model.eval() # Wrap in MimoAudio without re-running its __init__ (which would reload weights). mimo = object.__new__(MimoAudio) mimo.device = device mimo.path = str(fp8_dir) mimo.mimo_audio_tokenizer_path = tokenizer_path mimo.tokenizer = tokenizer mimo.padding_idx = int(tokenizer.pad_token_id) mimo.sosp_idx = model_args.sosp_idx mimo.eosp_idx = model_args.eosp_idx mimo.empty_token = model_args.empty_idx mimo.sostm_idx = model_args.sostm_idx mimo.eostm_idx = model_args.eostm_idx mimo.eot_idx = model_args.eot_idx mimo.im_start_idx = tokenizer.convert_tokens_to_ids("<|im_start|>") mimo.im_end_idx = tokenizer.convert_tokens_to_ids("<|im_end|>") mimo.model = model mimo.group_size = model.config.group_size mimo.audio_channels = model.config.audio_channels mimo.delay_pattern = model.config.delay_pattern mimo.vocab_size = model.config.vocab_size mimo.speech_zeroemb_idx = model.speech_empty_ids from src.mimo_audio.modeling_mimo_audio import MiMoSampler mimo.default_global_sampler = MiMoSampler(do_sample=True, temperature=0.6, top_k=50, top_p=0.95) mimo.default_local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_k=50, top_p=0.95) mimo.task_sampler_configs = { "asr": {"global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0), "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)}, } mimo.generate_kwargs = { "max_length": 8192, "eos_token_id": tokenizer.eos_token_id, "pad_token_id": tokenizer.pad_token_id, } mimo.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(tokenizer_path) mimo.mimo_audio_tokenizer.eval().bfloat16().to(device) from torchaudio.transforms import MelSpectrogram tcfg = mimo.mimo_audio_tokenizer.config mimo.mel_transform = MelSpectrogram( sample_rate=tcfg.sampling_rate, n_fft=tcfg.nfft, hop_length=tcfg.hop_length, win_length=tcfg.window_size, f_min=tcfg.fmin, f_max=tcfg.fmax, n_mels=tcfg.n_mels, power=1.0, center=True, ).to(device) print("FP8 model ready\n") return mimo def main(): ap = argparse.ArgumentParser() ap.add_argument("--model-path", required=True) ap.add_argument("--out-dir", default="./MiMo-V2.5-ASR-FP8") ap.add_argument("--device", default="cuda") ap.add_argument("--quiet", action="store_true") args = ap.parse_args() run_quantize(args) if __name__ == "__main__": main()