| """ |
| Streaming CPU FP8 quantizer for MiMo-V2.5-ASR. |
| |
| Avoids loading the full model into RAM and never uses CUDA (works around the lack of |
| sm_120 / Blackwell kernels in torch 2.6+cu124). Reads the original safetensors shards, |
| quantizes every nn.Linear weight to float8_e4m3fn (per-out-channel absmax), leaves |
| embeddings / norms / biases in bf16, and writes a single model.safetensors whose keys |
| match the FP8Linear loader (`*.weight_fp8`, `*.weight_scale`). |
| """ |
| import sys, json, shutil, argparse |
| from pathlib import Path |
| import torch |
| import torch.nn as nn |
| from safetensors.torch import save_file, safe_open |
|
|
# Put this script's own directory at the front of sys.path so the
# project-local imports used by this file (quantize_fp8 just below) are looked
# up there first, independent of the current working directory.
REPO = Path(__file__).resolve().parent
sys.path.insert(0, str(REPO))
from quantize_fp8 import (
    CONFIG_FILES,
    FP8_DTYPE,
    _build_args_and_tokenizer,
    quantize_weight_per_channel,
)
|
|
|
|
def linear_weight_names(model_path: str) -> set:
    """Return the checkpoint-style weight key of every nn.Linear in the model.

    The model is instantiated under ``torch.device("meta")``, so only the
    module structure is built. Each ``nn.Linear`` found by ``named_modules()``
    contributes the key ``"<module path>.weight"``.

    Args:
        model_path: Directory handed to ``_build_args_and_tokenizer`` and
            ``AutoConfig.from_pretrained``.

    Returns:
        A set of ``"<module path>.weight"`` strings.
    """
    # Imported lazily: these are only needed when this function actually runs.
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM
    from transformers import AutoConfig

    model_args, _tokenizer = _build_args_and_tokenizer(model_path)
    config = AutoConfig.from_pretrained(model_path)
    with torch.device("meta"):
        skeleton = MiMoAudioForCausalLM(config, model_args)

    return {
        module_path + ".weight"
        for module_path, module in skeleton.named_modules()
        if isinstance(module, nn.Linear)
    }
|
|
|
|
def _load_weight_map(model_dir: Path) -> dict:
    """Return the checkpoint's ``{tensor key: shard file name}`` mapping.

    A sharded checkpoint is described by ``model.safetensors.index.json``.
    When that index is absent but a single ``model.safetensors`` exists, every
    key stored in that file is mapped to it, so unsharded checkpoints work too.

    Raises:
        FileNotFoundError: If neither file exists in ``model_dir``.
    """
    index_path = model_dir / "model.safetensors.index.json"
    if index_path.exists():
        return json.loads(index_path.read_text())["weight_map"]

    single = model_dir / "model.safetensors"
    if single.exists():
        with safe_open(str(single), framework="pt", device="cpu") as f:
            return {key: single.name for key in f.keys()}

    raise FileNotFoundError(
        f"no model.safetensors.index.json or model.safetensors in {model_dir}"
    )


def main():
    """CLI entry point: convert a safetensors checkpoint to the FP8 layout.

    Command-line flags:
        --model-path  Directory holding the source checkpoint and config files.
        --out-dir     Directory that receives ``model.safetensors``,
                      ``fp8_meta.json`` and the copied config files.
        --dry-run     Only report which keys would be quantized; write nothing.

    For every checkpoint key that is the weight of an ``nn.Linear`` (as
    reported by ``linear_weight_names``), the tensor is passed to
    ``quantize_weight_per_channel`` and stored as ``<base>.weight_fp8`` and
    ``<base>.weight_scale``. Every other tensor is kept under its own key;
    floating-point ones are cast to bfloat16.

    The whole converted state dict is held in memory until the final
    ``save_file`` call; only the *source* shards are read one at a time.

    Raises:
        FileNotFoundError: If the model directory contains neither a shard
            index nor a single ``model.safetensors``.
        SystemExit: Via ``argparse`` on bad arguments, including an
            ``--out-dir`` that resolves to ``--model-path``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-path", required=True)
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--dry-run", action="store_true")
    a = ap.parse_args()
    mp = Path(a.model_path)
    out = Path(a.out_dir)

    # Writing into the source directory could overwrite the checkpoint that is
    # being read (single-file case) and would make the config copy step copy
    # files onto themselves, so refuse before doing any expensive work.
    if not a.dry_run and out.resolve() == mp.resolve():
        ap.error("--out-dir must be different from --model-path")

    lin = linear_weight_names(str(mp))
    print(f"quantizable Linear weights: {len(lin)}")

    # Group checkpoint keys by the shard file that stores them.
    weight_map = _load_weight_map(mp)
    shards = {}
    for k, sh in weight_map.items():
        shards.setdefault(sh, []).append(k)
    all_keys = set(weight_map)
    matched = lin & all_keys
    missing = lin - all_keys
    print(f"checkpoint keys: {len(all_keys)} | linear matched: {len(matched)} | linear missing in ckpt: {len(missing)}")
    if missing:
        print(" MISSING (first 10):", sorted(missing)[:10])

    if a.dry_run:
        nonlin = sorted(all_keys - lin)
        print("sample NON-linear keys kept as-is:", nonlin[:8])
        print("total params:", len(all_keys), "-> fp8:", len(matched), "kept:", len(all_keys) - len(matched))
        return

    out.mkdir(parents=True, exist_ok=True)
    new_state = {}
    bytes_before = bytes_after = 0
    conv = 0
    for shard in sorted(shards):
        with safe_open(str(mp / shard), framework="pt", device="cpu") as f:
            for k in shards[shard]:
                t = f.get_tensor(k)
                bytes_before += t.numel() * t.element_size()
                if k in lin:
                    w_fp8, scale = quantize_weight_per_channel(t)
                    w_fp8 = w_fp8.contiguous()
                    # assumes the helper returns scale with a size-1 dim 1
                    # (e.g. (out_features, 1)) -- TODO confirm in quantize_fp8
                    scale = scale.squeeze(1).contiguous()
                    base = k[:-len(".weight")]
                    new_state[base + ".weight_fp8"] = w_fp8
                    new_state[base + ".weight_scale"] = scale
                    # Use the tensors' real element sizes rather than assuming
                    # 1 byte per weight and 4 bytes per scale.
                    bytes_after += w_fp8.numel() * w_fp8.element_size()
                    bytes_after += scale.numel() * scale.element_size()
                    conv += 1
                else:
                    # Non-Linear tensors keep their key; floating-point ones
                    # are normalised to bfloat16, integer/bool ones untouched.
                    if t.is_floating_point():
                        t = t.to(torch.bfloat16)
                    new_state[k] = t.contiguous()
                    bytes_after += t.numel() * t.element_size()
        print(f" done {shard} (running fp8 layers: {conv})")

    st = out / "model.safetensors"
    save_file(new_state, str(st), metadata={"format": "pt"})

    # Copy whichever of the known config files exist next to the checkpoint.
    copied = []
    for cfg in CONFIG_FILES:
        src = mp / cfg
        if src.exists():
            shutil.copy2(src, out / cfg)
            copied.append(cfg)

    # max(..., 1) guards the division when nothing was written.
    ratio = round(bytes_before / max(bytes_after, 1), 3)
    meta = {
        "dtype": "float8_e4m3fn", "weight_scaling": "per_channel_absmax",
        "activation_scaling": "dynamic_per_tensor", "matmul_op": "torch._scaled_mm",
        "output_dtype": "bfloat16", "converted_layers": conv,
        "weight_gb_before": round(bytes_before / 1e9, 3),
        "weight_gb_after": round(bytes_after / 1e9, 3), "compression_ratio": ratio,
        "quantizer": "streaming_cpu",
    }
    (out / "fp8_meta.json").write_text(json.dumps(meta, indent=2))
    gb = st.stat().st_size / 1e9
    print(f"\nOK {st} ({gb:.2f} GB on disk)")
    print(f" {conv} layers -> fp8 | {round(bytes_before/1e9,2)}GB -> {round(bytes_after/1e9,2)}GB ({ratio}x)")
    print(f" copied config: {', '.join(copied)}")
|
|
|
# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|