""" Streaming CPU FP8 quantizer for MiMo-V2.5-ASR. Avoids loading the full model into RAM and never uses CUDA (works around the lack of sm_120 / Blackwell kernels in torch 2.6+cu124). Reads the original safetensors shards, quantizes every nn.Linear weight to float8_e4m3fn (per-out-channel absmax), leaves embeddings / norms / biases in bf16, and writes a single model.safetensors whose keys match the FP8Linear loader (`*.weight_fp8`, `*.weight_scale`). """ import sys, json, shutil, argparse from pathlib import Path import torch import torch.nn as nn from safetensors.torch import save_file, safe_open REPO = Path(__file__).resolve().parent sys.path.insert(0, str(REPO)) from quantize_fp8 import ( quantize_weight_per_channel, CONFIG_FILES, _build_args_and_tokenizer, FP8_DTYPE, ) def linear_weight_names(model_path: str) -> set: """Build the model on meta and collect '.weight' for every nn.Linear.""" from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM from transformers import AutoConfig args, _ = _build_args_and_tokenizer(model_path) cfg = AutoConfig.from_pretrained(model_path) with torch.device("meta"): model = MiMoAudioForCausalLM(cfg, args) names = set() for name, mod in model.named_modules(): if isinstance(mod, nn.Linear): names.add(name + ".weight") return names def main(): ap = argparse.ArgumentParser() ap.add_argument("--model-path", required=True) ap.add_argument("--out-dir", required=True) ap.add_argument("--dry-run", action="store_true") a = ap.parse_args() mp = Path(a.model_path) out = Path(a.out_dir) lin = linear_weight_names(str(mp)) print(f"quantizable Linear weights: {len(lin)}") idx = json.loads((mp / "model.safetensors.index.json").read_text()) weight_map = idx["weight_map"] shards = {} for k, sh in weight_map.items(): shards.setdefault(sh, []).append(k) all_keys = set(weight_map) matched = lin & all_keys missing = lin - all_keys print(f"checkpoint keys: {len(all_keys)} | linear matched: {len(matched)} | linear missing in ckpt: {len(missing)}") if missing: print(" MISSING (first 10):", sorted(missing)[:10]) if a.dry_run: # show a few non-linear keys for sanity nonlin = sorted(all_keys - lin) print("sample NON-linear keys kept as-is:", nonlin[:8]) print("total params:", len(all_keys), "-> fp8:", len(matched), "kept:", len(all_keys) - len(matched)) return out.mkdir(parents=True, exist_ok=True) new_state = {} bytes_before = bytes_after = 0 conv = 0 for shard in sorted(shards): with safe_open(str(mp / shard), framework="pt", device="cpu") as f: for k in shards[shard]: t = f.get_tensor(k) if k in lin: w_fp8, scale = quantize_weight_per_channel(t) # scale [out,1] base = k[:-len(".weight")] new_state[base + ".weight_fp8"] = w_fp8.contiguous() new_state[base + ".weight_scale"] = scale.squeeze(1).contiguous() bytes_before += t.numel() * t.element_size() bytes_after += w_fp8.numel() + scale.numel() * 4 conv += 1 else: bytes_before += t.numel() * t.element_size() # kept tensors (embeddings/norms/biases): store bf16 to match the # bf16 model the FP8Linear loader builds, and to shrink the file. if t.is_floating_point(): t = t.to(torch.bfloat16) new_state[k] = t.contiguous() bytes_after += t.numel() * t.element_size() print(f" done {shard} (running fp8 layers: {conv})") st = out / "model.safetensors" save_file(new_state, str(st), metadata={"format": "pt"}) copied = [] for cfg in CONFIG_FILES: src = mp / cfg if src.exists(): shutil.copy2(src, out / cfg); copied.append(cfg) ratio = round(bytes_before / max(bytes_after, 1), 3) meta = { "dtype": "float8_e4m3fn", "weight_scaling": "per_channel_absmax", "activation_scaling": "dynamic_per_tensor", "matmul_op": "torch._scaled_mm", "output_dtype": "bfloat16", "converted_layers": conv, "weight_gb_before": round(bytes_before / 1e9, 3), "weight_gb_after": round(bytes_after / 1e9, 3), "compression_ratio": ratio, "quantizer": "streaming_cpu", } (out / "fp8_meta.json").write_text(json.dumps(meta, indent=2)) gb = st.stat().st_size / 1e9 print(f"\nOK {st} ({gb:.2f} GB on disk)") print(f" {conv} layers -> fp8 | {round(bytes_before/1e9,2)}GB -> {round(bytes_after/1e9,2)}GB ({ratio}x)") print(f" copied config: {', '.join(copied)}") if __name__ == "__main__": main()