File size: 4,922 Bytes
04e43d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Streaming CPU FP8 quantizer for MiMo-V2.5-ASR.

Avoids materializing the full-precision model in RAM and never uses CUDA (works around
the lack of sm_120 / Blackwell kernels in torch 2.6+cu124). Reads the original
safetensors shards one tensor at a time, quantizes every nn.Linear weight to
float8_e4m3fn (per-out-channel absmax), casts all other floating-point tensors
(embeddings / norms / biases) to bf16, and writes a single model.safetensors whose keys
match the FP8Linear loader (`*.weight_fp8`, `*.weight_scale`). The converted state dict
is held in memory until that single final write.
"""
import sys, json, shutil, argparse
from pathlib import Path
import torch
import torch.nn as nn
from safetensors.torch import save_file, safe_open

# Directory containing this script. It is prepended to sys.path so that the sibling
# `quantize_fp8` module (and presumably the `src` package imported lazily inside
# linear_weight_names) resolve when the file is run directly as a script.
# NOTE: must run before the `quantize_fp8` import below.
REPO = Path(__file__).resolve().parent
sys.path.insert(0, str(REPO))
# Helpers shared with quantize_fp8, as used in this file:
#   quantize_weight_per_channel(w) -> (w_fp8, scale)  per-output-channel weight quantization
#   CONFIG_FILES                                       file names copied verbatim to --out-dir
#   _build_args_and_tokenizer(path) -> (args, tok)     model-construction args for the meta build
#   FP8_DTYPE                                          not referenced in the visible code
from quantize_fp8 import (
    quantize_weight_per_channel, CONFIG_FILES, _build_args_and_tokenizer, FP8_DTYPE,
)


def linear_weight_names(model_path: str) -> set:
    """Return the state-dict key ``'<module path>.weight'`` of every ``nn.Linear``.

    The model is constructed inside a ``torch.device("meta")`` context, so its
    parameters are never materialized; only the module tree is walked. Subclasses
    of ``nn.Linear`` are included (``isinstance`` check).

    Args:
        model_path: checkpoint directory understood by both ``AutoConfig`` and
            ``_build_args_and_tokenizer``.
    """
    from src.mimo_audio.modeling_mimo_audio import MiMoAudioForCausalLM
    from transformers import AutoConfig

    build_args, _unused_tokenizer = _build_args_and_tokenizer(model_path)
    hf_config = AutoConfig.from_pretrained(model_path)
    with torch.device("meta"):
        skeleton = MiMoAudioForCausalLM(hf_config, build_args)
    return {
        f"{module_path}.weight"
        for module_path, module in skeleton.named_modules()
        if isinstance(module, nn.Linear)
    }


def _load_weight_map(model_dir: Path) -> dict:
    """Map every checkpoint tensor name to the safetensors file (relative to *model_dir*) holding it.

    Uses ``model.safetensors.index.json`` when present (sharded checkpoint); otherwise
    falls back to a single ``model.safetensors`` and lists its keys directly.

    Raises:
        FileNotFoundError: if neither file exists in *model_dir*.
    """
    index_path = model_dir / "model.safetensors.index.json"
    if index_path.exists():
        return json.loads(index_path.read_text())["weight_map"]
    single = model_dir / "model.safetensors"
    if single.exists():
        with safe_open(str(single), framework="pt", device="cpu") as f:
            return {k: single.name for k in f.keys()}
    raise FileNotFoundError(
        f"no model.safetensors.index.json or model.safetensors found in {model_dir}"
    )


def _copy_config_files(src_dir: Path, dst_dir: Path) -> list:
    """Copy each CONFIG_FILES entry that exists in *src_dir* into *dst_dir*; return the names copied."""
    copied = []
    for name in CONFIG_FILES:
        src = src_dir / name
        if src.exists():
            shutil.copy2(src, dst_dir / name)
            copied.append(name)
    return copied


def main():
    """Command-line entry point: convert the checkpoint at --model-path to FP8 in --out-dir.

    Every tensor whose key is an ``nn.Linear`` weight (see linear_weight_names) is replaced
    by ``<base>.weight_fp8`` + ``<base>.weight_scale``; every other floating tensor is cast
    to bf16. Tensors are read one at a time from the source safetensors file(s); the
    converted state dict is accumulated in RAM and written once as ``model.safetensors``.
    Existing CONFIG_FILES are copied and an ``fp8_meta.json`` summary is written.

    CLI:
        --model-path  source checkpoint directory (sharded with an index.json, or a
                      single model.safetensors)
        --out-dir     destination directory; must differ from --model-path
        --dry-run     only report which keys would be quantized; write nothing
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-path", required=True)
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--dry-run", action="store_true")
    a = ap.parse_args()
    mp = Path(a.model_path)
    out = Path(a.out_dir)
    # Writing into the source directory would mix the new model.safetensors with the
    # original shards/index (or overwrite a single-file source), and shutil.copy2 onto
    # the same file raises SameFileError only after the slow conversion — fail fast.
    if not a.dry_run and out.resolve() == mp.resolve():
        ap.error("--out-dir must differ from --model-path")

    lin = linear_weight_names(str(mp))
    print(f"quantizable Linear weights: {len(lin)}")

    weight_map = _load_weight_map(mp)
    # Group keys by file so each safetensors file is opened exactly once.
    shards = {}
    for k, sh in weight_map.items():
        shards.setdefault(sh, []).append(k)
    all_keys = set(weight_map)
    matched = lin & all_keys
    missing = lin - all_keys
    print(f"checkpoint keys: {len(all_keys)} | linear matched: {len(matched)} | linear missing in ckpt: {len(missing)}")
    if missing:
        print("  MISSING (first 10):", sorted(missing)[:10])

    if a.dry_run:
        # show a few non-linear keys for sanity
        nonlin = sorted(all_keys - lin)
        print("sample NON-linear keys kept as-is:", nonlin[:8])
        print("total params:", len(all_keys), "-> fp8:", len(matched), "kept:", len(all_keys) - len(matched))
        return

    out.mkdir(parents=True, exist_ok=True)
    new_state = {}
    bytes_before = bytes_after = 0
    conv = 0
    for shard in sorted(shards):
        with safe_open(str(mp / shard), framework="pt", device="cpu") as f:
            for k in shards[shard]:
                t = f.get_tensor(k)
                bytes_before += t.numel() * t.element_size()
                if k in lin:
                    w_fp8, scale = quantize_weight_per_channel(t)   # scale [out,1]
                    base = k[:-len(".weight")]
                    w_fp8 = w_fp8.contiguous()
                    scale = scale.squeeze(1).contiguous()
                    new_state[base + ".weight_fp8"] = w_fp8
                    new_state[base + ".weight_scale"] = scale
                    # Count the bytes actually stored, whatever dtypes the helper returns
                    # (previously hard-coded 1 byte per weight and 4 bytes per scale).
                    bytes_after += (
                        w_fp8.numel() * w_fp8.element_size()
                        + scale.numel() * scale.element_size()
                    )
                    conv += 1
                else:
                    # kept tensors (embeddings/norms/biases): store bf16 to match the
                    # bf16 model the FP8Linear loader builds, and to shrink the file.
                    if t.is_floating_point():
                        t = t.to(torch.bfloat16)
                    new_state[k] = t.contiguous()
                    bytes_after += t.numel() * t.element_size()
        print(f"  done {shard}  (running fp8 layers: {conv})")

    st = out / "model.safetensors"
    save_file(new_state, str(st), metadata={"format": "pt"})

    copied = _copy_config_files(mp, out)

    ratio = round(bytes_before / max(bytes_after, 1), 3)
    meta = {
        "dtype": "float8_e4m3fn", "weight_scaling": "per_channel_absmax",
        "activation_scaling": "dynamic_per_tensor", "matmul_op": "torch._scaled_mm",
        "output_dtype": "bfloat16", "converted_layers": conv,
        "weight_gb_before": round(bytes_before / 1e9, 3),
        "weight_gb_after": round(bytes_after / 1e9, 3), "compression_ratio": ratio,
        "quantizer": "streaming_cpu",
    }
    (out / "fp8_meta.json").write_text(json.dumps(meta, indent=2))
    gb = st.stat().st_size / 1e9
    print(f"\nOK {st} ({gb:.2f} GB on disk)")
    print(f"   {conv} layers -> fp8 | {round(bytes_before/1e9,2)}GB -> {round(bytes_after/1e9,2)}GB ({ratio}x)")
    print(f"   copied config: {', '.join(copied)}")


# Script entry point: run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()