# MiMo-V2.5-ASR-FP8 / quantize_fp8.py
# Provenance (Hugging Face Hub page header): uploaded by Infatoshi with
# "Upload folder using huggingface_hub", commit 04e43d3 (verified), 11.9 kB.
"""
MiMo-V2.5-ASR -> FP8 e4m3fn (per-channel weight quant, dynamic activation quant)
Quantize entrypoint loads MiMoAudioForCausalLM directly (no audio tokenizer / no
flash-attn needed -- the LLM is pure Qwen2). Verify/load paths still go through the
full MimoAudio stack and DO require flash-attn + the audio tokenizer.
"""
import os
import sys
import json
import shutil
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from safetensors.torch import save_file, safe_open
# Directory that contains this script. It is placed at the front of sys.path
# (once) so that the `src.*` imports performed lazily inside the functions
# below can be resolved relative to it.
REPO_ROOT = Path(__file__).resolve().parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
# ----- constants -----
# Storage dtype for quantized weights and quantized activations.
FP8_DTYPE = torch.float8_e4m3fn
# Largest finite magnitude of FP8_DTYPE; absmax values are mapped onto it.
FP8_MAX = torch.finfo(FP8_DTYPE).max # 448.0
# dtype in which every quantization scale is kept.
SCALE_DTYPE = torch.float32
# Module types that quantize_model counts as "skipped" and does not descend into.
SKIP_TYPES = (nn.Embedding, nn.LayerNorm, nn.GroupNorm,
              nn.BatchNorm1d, nn.BatchNorm2d, nn.RMSNorm)
# File names that save_fp8 copies from the source model directory into the
# output directory when they exist there.
CONFIG_FILES = [
    "config.json", "tokenizer_config.json", "tokenizer.json",
    "special_tokens_map.json", "generation_config.json",
    "added_tokens.json", "merges.txt", "vocab.json", "chat_template.jinja",
]
# Tokens that _build_args_and_tokenizer adds to the tokenizer when they are
# missing from its vocabulary.
SPECIAL_TOKENS = ["<|sosp|>", "<|eosp|>", "<|empty|>", "<|Human|>",
                  "<|SpeechLM|>", "<|sostm|>", "<|eostm|>", "<|eot|>"]
# ----- weight quantization -----
def quantize_weight_per_channel(weight: torch.Tensor):
    """Quantize a 2-D weight to FP8 with one absmax scale per output row.

    Args:
        weight: tensor laid out as [out, in].

    Returns:
        ``(w_fp8, scale)`` where ``w_fp8`` has dtype ``FP8_DTYPE`` and the
        shape of ``weight``, and ``scale`` has dtype ``SCALE_DTYPE`` and shape
        [out, 1]. ``w_fp8 * scale`` approximates the input.
    """
    weight_f32 = weight.float()
    # Row-wise largest magnitude; the floor keeps all-zero rows from
    # producing a zero scale (and a division by zero below).
    row_absmax = torch.amax(torch.abs(weight_f32), dim=1, keepdim=True)
    row_absmax = torch.clamp(row_absmax, min=1e-12)
    scale = torch.div(row_absmax, FP8_MAX).to(SCALE_DTYPE)
    # Clamp before the cast so rounding can never step outside the FP8 range.
    rescaled = torch.clamp(weight_f32 / scale, min=-FP8_MAX, max=FP8_MAX)
    return rescaled.to(FP8_DTYPE), scale
class FP8Linear(nn.Module):
    """Replacement for ``nn.Linear`` holding FP8 e4m3fn weights.

    Weights are quantized once at construction with one absmax scale per
    output channel. Activations are quantized on every call with a single
    scale computed from the whole input tensor. The matmul runs through
    ``torch._scaled_mm`` and always yields bfloat16.

    Buffers (all part of the state dict):
        weight_fp8:   [out, in], dtype ``FP8_DTYPE``.
        weight_scale: [out], dtype ``SCALE_DTYPE``.
        bias:         copy of the source bias; a plain ``None`` attribute
                      when the source layer has no bias.
    """

    def __init__(self, linear: nn.Linear):
        """Quantize ``linear``'s weight and copy its bias and dimensions."""
        super().__init__()
        with torch.no_grad():
            w_fp8, w_scale = quantize_weight_per_channel(linear.weight)
        self.register_buffer("weight_fp8", w_fp8.contiguous())  # [out, in]
        self.register_buffer("weight_scale", w_scale.squeeze(1))  # [out]
        if linear.bias is not None:
            self.register_buffer("bias", linear.bias.detach().clone())
        else:
            self.bias = None
        self.in_features = linear.in_features
        self.out_features = linear.out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear map to the last dimension of ``x``.

        Args:
            x: tensor whose last dimension is ``in_features``.

        Returns:
            bfloat16 tensor shaped like ``x`` with the last dimension
            replaced by ``out_features``.
        """
        leading = x.shape[:-1]
        # A zero-element input has no absmax to reduce over (``max()`` raises
        # on empty tensors), so return the matching empty result directly,
        # as ``nn.Linear`` would for an empty batch.
        if x.numel() == 0:
            return x.new_zeros((*leading, self.out_features), dtype=torch.bfloat16)

        # Upcast once; both the scale and the quantized values derive from it.
        x2d = x.reshape(-1, self.in_features).float()
        x_scale = (x2d.abs().max().clamp(min=1e-12) / FP8_MAX).to(SCALE_DTYPE)
        x_fp8 = (x2d / x_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

        # The matmul is given scalar scales only: the activation scale and the
        # largest weight scale. The per-channel weight scales are restored
        # afterwards by multiplying every output column by
        # weight_scale / max(weight_scale).
        # These are recomputed on each call (not cached) so they remain valid
        # after ``load_state_dict`` overwrites ``weight_scale``.
        w_scale_scalar = self.weight_scale.max().to(SCALE_DTYPE)
        out = torch._scaled_mm(
            x_fp8, self.weight_fp8.t(),
            scale_a=x_scale, scale_b=w_scale_scalar,
            out_dtype=torch.bfloat16, use_fast_accum=True,
        )
        correction = (self.weight_scale / w_scale_scalar).to(torch.bfloat16)
        out = out * correction.unsqueeze(0)
        if self.bias is not None:
            out = out + self.bias.to(out.dtype)
        return out.reshape(*leading, self.out_features)

    def extra_repr(self):
        """Return the summary shown inside ``repr(module)``."""
        return f"in={self.in_features}, out={self.out_features}, fp8=e4m3fn"
# ----- model walk -----
def quantize_model(model: nn.Module, verbose: bool = True):
    """Swap every ``nn.Linear`` in ``model`` for an ``FP8Linear``, in place.

    Children whose type is in ``SKIP_TYPES`` are counted and left alone; any
    other non-Linear child is searched recursively.

    Args:
        model: root module; it is mutated.
        verbose: print one line per converted layer.

    Returns:
        ``(model, stats)`` where ``stats`` has the keys ``converted``,
        ``skipped``, ``bytes_before`` and ``bytes_after``.
    """
    stats = {"converted": 0, "skipped": 0, "bytes_before": 0, "bytes_after": 0}

    def _nbytes(tensor):
        # Storage size of a tensor in bytes.
        return tensor.numel() * tensor.element_size()

    def _walk(parent, prefix=""):
        # Snapshot the children: setattr below replaces entries while looping.
        for child_name, child in list(parent.named_children()):
            qualified = child_name if not prefix else f"{prefix}.{child_name}"

            if isinstance(child, SKIP_TYPES):
                stats["skipped"] += 1
                continue
            if not isinstance(child, nn.Linear):
                _walk(child, qualified)
                continue

            b_before = _nbytes(child.weight)
            if child.bias is not None:
                b_before += _nbytes(child.bias)

            replacement = FP8Linear(child)
            # FP8 weights take one byte per element, scales four.
            b_after = replacement.weight_fp8.numel() + replacement.weight_scale.numel() * 4
            if replacement.bias is not None:
                b_after += _nbytes(replacement.bias)

            setattr(parent, child_name, replacement)
            stats["converted"] += 1
            stats["bytes_before"] += b_before
            stats["bytes_after"] += b_after
            if verbose:
                print(f" [FP8] {qualified:<70} {b_before/max(b_after,1):.1f}x")

    _walk(model)
    return model, stats
# ----- save -----
def save_fp8(model, out_dir: Path, stats: dict, model_path: Path):
    """Write the quantized checkpoint, config files and a metadata summary.

    Args:
        model: quantized module whose ``state_dict()`` is saved.
        out_dir: destination directory (created if missing).
        stats: counters as returned by ``quantize_model``.
        model_path: source model directory from which the files listed in
            ``CONFIG_FILES`` are copied when present.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Everything goes into one safetensors file, as contiguous CPU tensors.
    weights_file = out_dir / "model.safetensors"
    tensors = {}
    for key, tensor in model.state_dict().items():
        tensors[key] = tensor.contiguous().cpu()
    save_file(tensors, str(weights_file), metadata={"format": "pt"})

    copied = []
    for filename in CONFIG_FILES:
        candidate = model_path / filename
        if not candidate.exists():
            continue
        shutil.copy2(candidate, out_dir / filename)
        copied.append(filename)
    if copied:
        print(f" Copied config: {', '.join(copied)}")

    bytes_before = stats["bytes_before"]
    bytes_after = stats["bytes_after"]
    gb_before = bytes_before / 1e9
    gb_after = bytes_after / 1e9
    # max(..., 1) guards the division when nothing was converted.
    ratio = round(bytes_before / max(bytes_after, 1), 3)

    meta = {
        "dtype": "float8_e4m3fn",
        "weight_scaling": "per_channel_absmax",
        "activation_scaling": "dynamic_per_tensor",
        "matmul_op": "torch._scaled_mm",
        "output_dtype": "bfloat16",
        "converted_layers": stats["converted"],
        "skipped_layers": stats["skipped"],
        "weight_gb_before": round(gb_before, 3),
        "weight_gb_after": round(gb_after, 3),
        "compression_ratio": ratio,
    }
    (out_dir / "fp8_meta.json").write_text(json.dumps(meta, indent=2))

    actual_gb = weights_file.stat().st_size / 1e9
    print(f"\nOK {weights_file} ({actual_gb:.2f} GB on disk)")
    print(f" weight bytes: {gb_before:.2f} GB -> {gb_after:.2f} GB ({ratio}x)")
    print(f" {stats['converted']} layers converted, {stats['skipped']} skipped")
def _build_args_and_tokenizer(model_path: str):
    """Load the text tokenizer and build the matching ``MiMoAudioArguments``.

    Every entry of ``SPECIAL_TOKENS`` that the tokenizer does not already
    know is added as a special token; the ids of six of those tokens are
    then passed to ``MiMoAudioArguments``.

    Args:
        model_path: value handed to ``AutoTokenizer.from_pretrained`` and
            stored as ``model_name_or_path``.

    Returns:
        ``(args, tok)``: the arguments object and the tokenizer.
    """
    from transformers import AutoTokenizer
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioArguments

    tok = AutoTokenizer.from_pretrained(model_path)

    # One vocabulary snapshot is enough: SPECIAL_TOKENS contains no duplicates,
    # so a token added in this loop cannot change a later membership check.
    known = tok.get_vocab()
    for token in SPECIAL_TOKENS:
        if token not in known:
            tok.add_tokens([token], special_tokens=True)

    gid = tok.convert_tokens_to_ids
    args = MiMoAudioArguments(
        model_name_or_path=model_path,
        sosp_idx=gid("<|sosp|>"), eosp_idx=gid("<|eosp|>"),
        empty_idx=gid("<|empty|>"), sostm_idx=gid("<|sostm|>"),
        eostm_idx=gid("<|eostm|>"), eot_idx=gid("<|eot|>"),
    )
    return args, tok
# ----- quantize entrypoint (direct LLM load, no audio tokenizer / flash-attn) -----
def run_quantize(args):
    """Load the bf16 model, quantize its Linear layers and save the result.

    Args:
        args: namespace providing ``model_path``, ``out_dir``, ``device`` and
            ``quiet`` (as produced by ``main``).
    """
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM

    source_dir = args.model_path
    print(f"Loading MiMoAudioForCausalLM from {source_dir} on {args.device} ...")

    # Only the arguments object is needed here; the tokenizer is discarded.
    model_args, _ = _build_args_and_tokenizer(source_dir)
    load_kwargs = {
        "args": model_args,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": args.device},
        "attn_implementation": "sdpa",
    }
    model = MiMoAudioForCausalLM.from_pretrained(source_dir, **load_kwargs)
    model.eval()
    print("OK loaded\n")

    print("Quantizing to FP8 e4m3fn ...")
    show_layers = not args.quiet
    with torch.no_grad():
        model, stats = quantize_model(model, verbose=show_layers)

    save_fp8(model, Path(args.out_dir), stats, Path(source_dir))
    print("\nDone.")
# ----- load FP8 model (for inference / verify) -----
def load_fp8_model(fp8_dir: str, tokenizer_path: str, repo_root: str, device: str = "cuda"):
    """
    Load the FP8 checkpoint for inference and return it wrapped in a MimoAudio object.

    The original bf16 weights are NOT required. The architecture is built from
    the config found in fp8_dir using the regular constructor (which, per the
    original author's note, gives a correct rotary init), its Linears are
    replaced with FP8Linear shells via quantize_model, and the tensors from
    model.safetensors are then loaded with strict=True.

    The MimoAudio wrapper is allocated with object.__new__ and its attributes
    are assigned by hand, so MimoAudio.__init__ never runs. The original
    docstring describes the wrapper as exposing .asr_sft().

    Args:
        fp8_dir: directory that must contain model.safetensors + config/tokenizer files.
        tokenizer_path: passed to MiMoAudioTokenizer.from_pretrained (audio tokenizer).
        repo_root: the cloned MiMo-V2.5-ASR repo (contains src/); added to sys.path.
        device: device the model, audio tokenizer and mel transform are moved to.

    Returns:
        The populated MimoAudio instance.
    """
    # Make `src.*` importable from the given repo checkout.
    rr = Path(repo_root).resolve()
    if str(rr) not in sys.path:
        sys.path.insert(0, str(rr))
    from src.mimo_audio.mimo_audio import MimoAudio
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM
    from src.mimo_audio_tokenizer import MiMoAudioTokenizer
    # NOTE(review): AutoTokenizer and GenerationConfig are not referenced in
    # this function; only AutoConfig is used.
    from transformers import AutoTokenizer, AutoConfig, GenerationConfig
    fp8_dir = Path(fp8_dir)
    model_args, tokenizer = _build_args_and_tokenizer(str(fp8_dir))
    # Build architecture with real init (correct rotary inv_freq), no pretrained shards.
    print("Building architecture (config init) ...")
    cfg = AutoConfig.from_pretrained(str(fp8_dir))
    model = MiMoAudioForCausalLM(cfg, model_args).to(torch.bfloat16)
    print("Installing FP8 modules ...")
    # No checkpoint has been loaded yet, so the values quantized here are
    # whatever the constructor produced; the call only serves to create
    # FP8Linear buffers, which load_state_dict below overwrites.
    with torch.no_grad():
        quantize_model(model, verbose=False)
    model = model.to(device)
    print("Loading FP8 weights ...")
    # Read every tensor of the checkpoint straight onto the target device.
    state = {}
    with safe_open(str(fp8_dir / "model.safetensors"), framework="pt", device=device) as f:
        for key in f.keys():
            state[key] = f.get_tensor(key)
    # strict=True: checkpoint keys must match the rebuilt module exactly.
    model.load_state_dict(state, strict=True)
    model.eval()
    # Wrap in MimoAudio without re-running its __init__ (which would reload weights).
    mimo = object.__new__(MimoAudio)
    mimo.device = device
    mimo.path = str(fp8_dir)
    mimo.mimo_audio_tokenizer_path = tokenizer_path
    mimo.tokenizer = tokenizer
    # NOTE(review): int() would raise if pad_token_id were None -- confirm the
    # shipped tokenizer always defines a pad token.
    mimo.padding_idx = int(tokenizer.pad_token_id)
    # Special-token ids resolved by _build_args_and_tokenizer.
    mimo.sosp_idx = model_args.sosp_idx
    mimo.eosp_idx = model_args.eosp_idx
    mimo.empty_token = model_args.empty_idx
    mimo.sostm_idx = model_args.sostm_idx
    mimo.eostm_idx = model_args.eostm_idx
    mimo.eot_idx = model_args.eot_idx
    mimo.im_start_idx = tokenizer.convert_tokens_to_ids("<|im_start|>")
    mimo.im_end_idx = tokenizer.convert_tokens_to_ids("<|im_end|>")
    mimo.model = model
    # Values mirrored from the model / its config onto the wrapper.
    mimo.group_size = model.config.group_size
    mimo.audio_channels = model.config.audio_channels
    mimo.delay_pattern = model.config.delay_pattern
    mimo.vocab_size = model.config.vocab_size
    mimo.speech_zeroemb_idx = model.speech_empty_ids
    from src.mimo_audio.modeling_mimo_audio import MiMoSampler
    # Default samplers, plus an "asr" override whose global sampler has
    # sampling disabled (do_sample=False).
    mimo.default_global_sampler = MiMoSampler(do_sample=True, temperature=0.6, top_k=50, top_p=0.95)
    mimo.default_local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_k=50, top_p=0.95)
    mimo.task_sampler_configs = {
        "asr": {"global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
                "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)},
    }
    mimo.generate_kwargs = {
        "max_length": 8192,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Audio tokenizer: loaded from tokenizer_path, set to eval mode, cast to
    # bfloat16 and moved to the target device.
    mimo.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(tokenizer_path)
    mimo.mimo_audio_tokenizer.eval().bfloat16().to(device)
    from torchaudio.transforms import MelSpectrogram
    # Mel front end parameterised entirely from the audio tokenizer's config.
    tcfg = mimo.mimo_audio_tokenizer.config
    mimo.mel_transform = MelSpectrogram(
        sample_rate=tcfg.sampling_rate, n_fft=tcfg.nfft, hop_length=tcfg.hop_length,
        win_length=tcfg.window_size, f_min=tcfg.fmin, f_max=tcfg.fmax,
        n_mels=tcfg.n_mels, power=1.0, center=True,
    ).to(device)
    print("FP8 model ready\n")
    return mimo
def main():
    """Parse the command-line flags and run the quantization entrypoint."""
    parser = argparse.ArgumentParser()
    # Flag name -> keyword arguments for add_argument, in declaration order.
    options = (
        ("--model-path", {"required": True}),
        ("--out-dir", {"default": "./MiMo-V2.5-ASR-FP8"}),
        ("--device", {"default": "cuda"}),
        ("--quiet", {"action": "store_true"}),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    run_quantize(parser.parse_args())


if __name__ == "__main__":
    main()