MiMo-V2.5-ASR-FP8 / quantize_fp8_stream.py
Infatoshi's picture
Upload folder using huggingface_hub
04e43d3 verified
Raw
History Blame Contribute Delete
4.92 kB
"""
Streaming CPU FP8 quantizer for MiMo-V2.5-ASR.
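Avoids instantiating the model with real weights (it is only built on the meta device to
enumerate its nn.Linear modules) and never uses CUDA (works around the lack of sm_120 /
Blackwell kernels in torch 2.6+cu124). Reads the original safetensors shards tensor by tensor,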
quantizes every nn.Linear weight to float8_e4m3fn (per-out-channel absmax), leaves
embeddings / norms / biases in bf16, and writes a single model.safetensors whose keys
match the FP8Linear loader (`*.weight_fp8`, `*.weight_scale`).
"""
import sys, json, shutil, argparse
from pathlib import Path
import torch
import torch.nn as nn
from safetensors.torch import save_file, safe_open
REPO = Path(__file__).resolve().parent
sys.path.insert(0, str(REPO))
from quantize_fp8 import (
quantize_weight_per_channel, CONFIG_FILES, _build_args_and_tokenizer, FP8_DTYPE,
)
def linear_weight_names(model_path: str) -> set:
    """Return the '<module path>.weight' key of every nn.Linear in the model.

    The model class is constructed inside a ``torch.device("meta")`` context,
    using the config loaded from ``model_path`` and the args returned by
    ``_build_args_and_tokenizer``; only its module tree is inspected here.

    Args:
        model_path: Directory (or identifier) handed to ``AutoConfig`` and to
            ``_build_args_and_tokenizer``.

    Returns:
        A set of strings, one per nn.Linear submodule, each being the
        submodule's dotted name followed by ``.weight``.
    """
    # Imported lazily, in the same order as before, so merely importing this
    # script does not pull in the model code or transformers.
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM
    from transformers import AutoConfig

    build_args, _tokenizer = _build_args_and_tokenizer(model_path)
    config = AutoConfig.from_pretrained(model_path)

    with torch.device("meta"):
        skeleton = MiMoAudioForCausalLM(config, build_args)

    return {
        module_path + ".weight"
        for module_path, module in skeleton.named_modules()
        if isinstance(module, nn.Linear)
    }
def main():
    """Quantize a sharded safetensors checkpoint to FP8 and write one output file.

    Command-line flags:
        --model-path  Directory containing the shards and
                      ``model.safetensors.index.json``.
        --out-dir     Directory that receives ``model.safetensors``, any
                      existing files named in ``CONFIG_FILES`` and
                      ``fp8_meta.json``.
        --dry-run     Only report which keys would be quantized; nothing is
                      written.

    Every checkpoint key that matches an nn.Linear weight is replaced by a
    ``<base>.weight_fp8`` / ``<base>.weight_scale`` pair; every other tensor is
    kept under its own key, cast to bfloat16 when it is floating point.

    Raises:
        FileNotFoundError: If the shard index is missing from --model-path.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-path", required=True)
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--dry-run", action="store_true")
    a = ap.parse_args()

    mp = Path(a.model_path)
    out = Path(a.out_dir)

    lin = linear_weight_names(str(mp))
    print(f"quantizable Linear weights: {len(lin)}")

    # Only sharded checkpoints are handled: the index maps each key to its shard.
    index_path = mp / "model.safetensors.index.json"
    try:
        idx = json.loads(index_path.read_text(encoding="utf-8"))
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"shard index not found: {index_path} "
            "(this quantizer reads sharded checkpoints only)"
        ) from err
    weight_map = idx["weight_map"]

    # Group keys by shard so each shard file is opened exactly once.
    shards = {}
    for k, sh in weight_map.items():
        shards.setdefault(sh, []).append(k)

    all_keys = set(weight_map)
    matched = lin & all_keys
    missing = lin - all_keys
    print(f"checkpoint keys: {len(all_keys)} | linear matched: {len(matched)} | linear missing in ckpt: {len(missing)}")
    if missing:
        print("  MISSING (first 10):", sorted(missing)[:10])

    if a.dry_run:
        # show a few non-linear keys for sanity
        nonlin = sorted(all_keys - lin)
        print("sample NON-linear keys kept as-is:", nonlin[:8])
        print("total params:", len(all_keys), "-> fp8:", len(matched), "kept:", len(all_keys) - len(matched))
        return

    out.mkdir(parents=True, exist_ok=True)
    suffix = ".weight"
    new_state = {}
    bytes_before = bytes_after = 0
    conv = 0
    for shard in sorted(shards):
        with safe_open(str(mp / shard), framework="pt", device="cpu") as f:
            for k in shards[shard]:
                t = f.get_tensor(k)
                bytes_before += t.numel() * t.element_size()
                if k in lin:
                    # scale is presumably [out, 1]; squeeze(1) drops that axis.
                    w_fp8, scale = quantize_weight_per_channel(t)
                    w_fp8 = w_fp8.contiguous()
                    scale = scale.squeeze(1).contiguous()
                    base = k[:-len(suffix)]
                    new_state[base + ".weight_fp8"] = w_fp8
                    new_state[base + ".weight_scale"] = scale
                    # Count the bytes of what is actually stored rather than
                    # assuming 1-byte weights and 4-byte scales.
                    bytes_after += w_fp8.numel() * w_fp8.element_size()
                    bytes_after += scale.numel() * scale.element_size()
                    conv += 1
                else:
                    # kept tensors (embeddings/norms/biases): store bf16 to match the
                    # bf16 model the FP8Linear loader builds, and to shrink the file.
                    if t.is_floating_point():
                        t = t.to(torch.bfloat16)
                    new_state[k] = t.contiguous()
                    bytes_after += t.numel() * t.element_size()
        print(f"  done {shard}  (running fp8 layers: {conv})")

    st = out / "model.safetensors"
    save_file(new_state, str(st), metadata={"format": "pt"})

    # Copy over whichever of the known config files exist next to the shards.
    copied = []
    for cfg in CONFIG_FILES:
        src = mp / cfg
        if src.exists():
            shutil.copy2(src, out / cfg)
            copied.append(cfg)

    # max(..., 1) guards the ratio against an empty checkpoint.
    ratio = round(bytes_before / max(bytes_after, 1), 3)
    meta = {
        "dtype": "float8_e4m3fn", "weight_scaling": "per_channel_absmax",
        "activation_scaling": "dynamic_per_tensor", "matmul_op": "torch._scaled_mm",
        "output_dtype": "bfloat16", "converted_layers": conv,
        "weight_gb_before": round(bytes_before / 1e9, 3),
        "weight_gb_after": round(bytes_after / 1e9, 3), "compression_ratio": ratio,
        "quantizer": "streaming_cpu",
    }
    (out / "fp8_meta.json").write_text(json.dumps(meta, indent=2))

    gb = st.stat().st_size / 1e9
    print(f"\nOK {st}  ({gb:.2f} GB on disk)")
    print(f"  {conv} layers -> fp8 | {round(bytes_before/1e9,2)}GB -> {round(bytes_after/1e9,2)}GB ({ratio}x)")
    print(f"  copied config: {', '.join(copied)}")
# Script entry point: run the quantizer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()