| """Convert MTP expert weights from FP8 to INT4 compressed-tensors (Marlin format). |
| Key fix: pack_factor=4 (4 INT4 values per INT32), matching K2.5 base model format. |
| """ |
| import torch |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
| from collections import OrderedDict |
|
|
# Quantization layout parameters.
PACK_FACTOR = 4       # 4-bit values stored in each packed int32 word
FP8_BLOCK_SIZE = 8    # input elements covered by one block scale when dequantizing
GROUP_SIZE = 32       # input elements covered by one scale in the INT4 output

# Checkpoint locations: converted output and original FP8 source.
OUTPUT_PATH = "/data/models/Kimi-K2.5-MTP/mtp.safetensors"
MTP_PATH = "/data/models/Kimi-K2.5-MTP/mtp_fp8_orig.safetensors"
|
|
def dequantize_fp8_block(weight_u8, weight_scale_fp8, weight_scale_2, *, block_size=None):
    """Dequantize a block-scaled 8-bit weight matrix to bfloat16.

    Each row of ``weight_u8`` is split into consecutive blocks of
    ``block_size`` elements. Every element is converted to float32, shifted
    by -128, then multiplied by its block scale and by the global scale.

    NOTE(review): the weights are handled as offset-by-128 integers rather
    than being reinterpreted as float8 bit patterns -- confirm this matches
    how the source checkpoint was encoded.

    Args:
        weight_u8: 2-D tensor of shape ``(out_features, in_features)``.
        weight_scale_fp8: Per-block scales. After a trailing axis is added
            they must broadcast against
            ``(out_features, in_features // block_size, block_size)``.
        weight_scale_2: Global scale. It is applied only when it holds
            exactly one element; a tensor of any other size is ignored and
            a factor of 1.0 is used instead.
        block_size: Number of input elements per scale block. Defaults to
            the module-level ``FP8_BLOCK_SIZE``.

    Returns:
        A bfloat16 tensor of shape ``(out_features, in_features)``.

    Raises:
        ValueError: If ``block_size`` is not positive or ``in_features`` is
            not a multiple of it.
    """
    if block_size is None:
        block_size = FP8_BLOCK_SIZE
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    out_f, in_f = weight_u8.shape
    if in_f % block_size != 0:
        # Fail with a readable message instead of an opaque reshape error.
        raise ValueError(
            f"in_features={in_f} is not a multiple of block_size={block_size}"
        )

    blocks = weight_u8.to(torch.float32).reshape(out_f, in_f // block_size, block_size)
    blocks = blocks - 128.0

    # Trailing axis lets one scale broadcast over every element of its block.
    block_scale = weight_scale_fp8.to(torch.float32).unsqueeze(-1)
    global_scale = weight_scale_2.item() if weight_scale_2.numel() == 1 else 1.0

    return (blocks * block_scale * global_scale).reshape(out_f, in_f).to(torch.bfloat16)
|
|
def quantize_int4_marlin(weight_bf16, group_size=32, *, pack_factor=None):
    """Quantize a 2-D weight matrix to group-wise symmetric INT4 and pack it.

    The input dimension is zero-padded up to a multiple of ``group_size``.
    Each group gets one scale equal to ``max(|w|) / 7`` (floored at 1e-10),
    values are rounded and clamped to ``[-8, 7]``, shifted by +8 into
    ``[0, 15]``, and ``pack_factor`` consecutive 4-bit values are packed
    into one int32 word, lowest nibble first.

    NOTE(review): with the default pack factor of 4 only the low 16 bits of
    each int32 word are used -- confirm this is the layout the consumer of
    the checkpoint expects.

    Args:
        weight_bf16: 2-D tensor of shape ``(out_features, in_features)``.
        group_size: Number of input elements that share one scale.
        pack_factor: Number of 4-bit values per packed int32 word (1..8).
            Defaults to the module-level ``PACK_FACTOR``.

    Returns:
        A tuple ``(packed, scales, shape)``:
        ``packed`` is int32 of shape ``(out_features, padded_in // pack_factor)``,
        ``scales`` is bfloat16 of shape ``(out_features, padded_in // group_size)``,
        ``shape`` is an int32 tensor holding the unpadded
        ``[out_features, in_features]``.

    Raises:
        ValueError: If ``group_size`` is not positive, ``pack_factor`` is
            outside 1..8, or the padded input width is not a multiple of
            ``pack_factor``.
    """
    if pack_factor is None:
        pack_factor = PACK_FACTOR
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    if not 1 <= pack_factor <= 8:
        # Eight nibbles is all that fits in a 32-bit word.
        raise ValueError(f"pack_factor must be in 1..8, got {pack_factor}")

    out_f, in_f = weight_bf16.shape
    w = weight_bf16.to(torch.float32)

    # Zero-pad the input dimension so it splits evenly into groups.
    pad = (group_size - in_f % group_size) % group_size
    if pad > 0:
        w = torch.nn.functional.pad(w, (0, pad))
    in_padded = w.shape[1]

    # Explicit raise rather than assert: asserts vanish under ``python -O``.
    if in_padded % pack_factor != 0:
        raise ValueError(f"in_padded={in_padded} not divisible by {pack_factor}")

    w_grouped = w.reshape(out_f, -1, group_size)
    # Floor the scale so an all-zero group does not divide by zero.
    scales = (w_grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-10)

    w_int = torch.round(w_grouped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    w_int = w_int.reshape(out_f, in_padded)

    # Shift into the unsigned range [0, 15] and group nibbles per output word.
    nibbles = (w_int + 8).to(torch.int32).reshape(out_f, -1, pack_factor)

    # Allocate on the input's device so non-CPU tensors pack without a
    # device-mismatch error.
    packed = torch.zeros(
        out_f, nibbles.shape[1], dtype=torch.int32, device=nibbles.device
    )
    for i in range(pack_factor):
        packed |= (nibbles[:, :, i] & 0xF) << (i * 4)

    shape = torch.tensor([out_f, in_f], dtype=torch.int32)
    return packed, scales.to(torch.bfloat16), shape
|
|
# Conversion driver: read every tensor of the FP8 checkpoint, dequantize the
# FP8 projections, re-quantize routed-expert projections to packed INT4, keep
# the remaining FP8 projections as plain bfloat16, and copy everything else.
print("Loading MTP FP8 weights...")
new_tensors = OrderedDict()
converted_expert = 0
converted_shared = 0
passed = 0

# Tensors that only describe a projection's FP8 encoding. They are consumed
# (or superseded) when the projection's weight is converted and must never be
# copied to the output, whatever order the keys are visited in.
fp8_aux_leaves = {"weight_scale", "weight_scale_2", "input_scale"}

with safe_open(MTP_PATH, framework="pt", device="cpu") as f:
    all_keys = sorted(f.keys())
    key_set = set(all_keys)  # O(1) membership instead of scanning the list

    # A projection counts as FP8 when "<base>.weight" has a sibling
    # "<base>.weight_scale".
    weight_suffix = ".weight"
    fp8_bases = set()
    for k in all_keys:
        if not k.endswith(weight_suffix):
            continue
        candidate = k[: -len(weight_suffix)]
        if f"{candidate}.weight_scale" in key_set:
            fp8_bases.add(candidate)

    print(f"FP8 projections: {len(fp8_bases)}")

    for k in all_keys:
        # Split "a.b.c.weight" into parent "a.b.c" and leaf "weight" so the
        # owning projection is found with one set lookup.
        base, _, leaf = k.rpartition(".")

        if base in fp8_bases:
            if leaf in fp8_aux_leaves:
                # Encoding metadata of a converted projection: drop it.
                continue

            if leaf == "weight":
                w_u8 = f.get_tensor(k)
                w_scale = f.get_tensor(f"{base}.weight_scale")
                w_scale2 = f.get_tensor(f"{base}.weight_scale_2")
                w_bf16 = dequantize_fp8_block(w_u8, w_scale, w_scale2)

                if ".mlp.experts." in base:
                    packed, scales, shape = quantize_int4_marlin(w_bf16, GROUP_SIZE)
                    new_tensors[f"{base}.weight_packed"] = packed
                    new_tensors[f"{base}.weight_scale"] = scales
                    new_tensors[f"{base}.weight_shape"] = shape
                    converted_expert += 1
                else:
                    new_tensors[f"{base}.weight"] = w_bf16
                    converted_shared += 1
                continue

        # Anything else is copied unchanged.
        new_tensors[k] = f.get_tensor(k)
        passed += 1

print(f"Expert→INT4: {converted_expert}, Shared→BF16: {converted_shared}, Passthrough: {passed}")
print(f"Total: {len(new_tensors)}")

# Spot check: print the packed shape of one known expert projection, if present.
sample = "model.layers.61.mlp.experts.0.gate_proj.weight_packed"
if sample in new_tensors:
    print(f"\nVerify: {sample} shape={list(new_tensors[sample].shape)}")
    print("Expected: [2048, 896] (3584/4=896)")

save_file(new_tensors, OUTPUT_PATH)
import os  # kept local to this block; only needed for the size report below
print(f"Saved: {os.path.getsize(OUTPUT_PATH)/1024/1024:.1f} MB")
|
|