k-l-lambda committed on
Commit
52aa95f
·
verified ·
1 Parent(s): 1a31f9e

Upload convert_mtp_fp8_to_int4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. convert_mtp_fp8_to_int4.py +111 -0
convert_mtp_fp8_to_int4.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert MTP expert weights from FP8 to INT4 compressed-tensors (Marlin format).
2
+ Key fix: pack_factor=4 (4 INT4 values per INT32), matching K2.5 base model format.
3
+ """
4
+ import torch
5
+ from safetensors import safe_open
6
+ from safetensors.torch import save_file
7
+ from collections import OrderedDict
8
+
9
+ MTP_PATH = "/data/models/Kimi-K2.5-MTP/mtp_fp8_orig.safetensors"
10
+ OUTPUT_PATH = "/data/models/Kimi-K2.5-MTP/mtp.safetensors"
11
+ GROUP_SIZE = 32
12
+ FP8_BLOCK_SIZE = 8
13
+ PACK_FACTOR = 4 # 4 INT4 values per INT32 (matching base model Marlin format)
14
+
15
def dequantize_fp8_block(weight_u8, weight_scale_fp8, weight_scale_2, block_size=8):
    """Dequantize a block-quantized uint8-stored weight back to bfloat16.

    Args:
        weight_u8: ``(out_f, in_f)`` uint8 tensor; the code removes a +128
            offset, i.e. stored values are assumed biased by 128.
        weight_scale_fp8: per-block scales, broadcastable to
            ``(out_f, in_f // block_size, 1)`` after ``unsqueeze(-1)``.
        weight_scale_2: global scale tensor. Applied only when it has exactly
            one element; otherwise treated as 1.0.
            NOTE(review): silently ignoring a multi-element scale_2 looks
            intentional but should be confirmed against the checkpoint format.
        block_size: number of consecutive input columns sharing one scale
            (default 8, the script's FP8 block size; parameter added for
            generality — callers passing no value get the original behavior).

    Returns:
        ``(out_f, in_f)`` bfloat16 tensor of dequantized weights.
    """
    out_f, in_f = weight_u8.shape
    num_blocks = in_f // block_size
    # Work in float32 for precision; undo the +128 storage bias.
    w = weight_u8.to(torch.float32).reshape(out_f, num_blocks, block_size)
    w = w - 128.0
    scale = weight_scale_fp8.to(torch.float32).unsqueeze(-1)
    global_scale = weight_scale_2.item() if weight_scale_2.numel() == 1 else 1.0
    return (w * scale * global_scale).reshape(out_f, in_f).to(torch.bfloat16)
23
+
24
+ def quantize_int4_marlin(weight_bf16, group_size=32):
25
+ out_f, in_f = weight_bf16.shape
26
+ w = weight_bf16.to(torch.float32)
27
+
28
+ pad = (group_size - in_f % group_size) % group_size
29
+ if pad > 0:
30
+ w = torch.nn.functional.pad(w, (0, pad))
31
+ in_padded = w.shape[1]
32
+
33
+ w_grouped = w.reshape(out_f, -1, group_size)
34
+ scales = w_grouped.abs().amax(dim=-1) / 7.0
35
+ scales = scales.clamp(min=1e-10)
36
+
37
+ w_int = torch.round(w_grouped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
38
+ w_int = w_int.reshape(out_f, in_padded)
39
+
40
+ # Pack with PACK_FACTOR=4 (4 INT4 values per INT32)
41
+ assert in_padded % PACK_FACTOR == 0, f"in_padded={in_padded} not divisible by {PACK_FACTOR}"
42
+ w_unsigned = (w_int + 8).to(torch.int32) # [0, 15]
43
+ w_r = w_unsigned.reshape(out_f, -1, PACK_FACTOR)
44
+ packed = torch.zeros(out_f, w_r.shape[1], dtype=torch.int32)
45
+ for i in range(PACK_FACTOR):
46
+ packed |= (w_r[:, :, i] & 0xF) << (i * 4)
47
+
48
+ shape = torch.tensor([out_f, in_f], dtype=torch.int32)
49
+ return packed, scales.to(torch.bfloat16), shape
50
+
51
+ print("Loading MTP FP8 weights...")
52
+ new_tensors = OrderedDict()
53
+ converted_expert = 0
54
+ converted_shared = 0
55
+ passed = 0
56
+
57
with safe_open(MTP_PATH, framework="pt", device="cpu") as f:
    all_keys = sorted(f.keys())

    # An FP8-quantized projection is identified by a ".weight" tensor that
    # has a matching ".weight_scale" companion key in the checkpoint.
    fp8_bases = set()
    for k in all_keys:
        if k.endswith(".weight") and f"{k[:-7]}.weight_scale" in all_keys:
            fp8_bases.add(k[:-7])

    print(f"FP8 projections: {len(fp8_bases)}")

    processed = set()
    for k in all_keys:
        if k in processed:
            continue

        # Find the FP8 base (if any) this key belongs to. The trailing "."
        # prevents a base like "...experts.1" from matching "...experts.10".
        base = None
        for fb in fp8_bases:
            if k.startswith(fb + "."):
                base = fb
                break

        if base is not None:
            if k == f"{base}.weight":
                w_u8 = f.get_tensor(k)
                w_scale = f.get_tensor(f"{base}.weight_scale")
                w_scale2 = f.get_tensor(f"{base}.weight_scale_2")
                w_bf16 = dequantize_fp8_block(w_u8, w_scale, w_scale2)

                if ".mlp.experts." in base:
                    # Expert projections: requantize to INT4 compressed-tensors.
                    packed, scales, shape = quantize_int4_marlin(w_bf16, GROUP_SIZE)
                    new_tensors[f"{base}.weight_packed"] = packed
                    new_tensors[f"{base}.weight_scale"] = scales
                    new_tensors[f"{base}.weight_shape"] = shape
                    converted_expert += 1
                else:
                    # Non-expert projections are stored dequantized as BF16.
                    new_tensors[f"{base}.weight"] = w_bf16
                    converted_shared += 1

            # Mark the FP8 companion tensors handled so they are dropped from
            # the output (any key under an fp8 base that isn't the .weight —
            # e.g. an input_scale seen earlier in sorted order — is skipped).
            processed.update([k, f"{base}.weight_scale", f"{base}.weight_scale_2", f"{base}.input_scale"])
            continue

        # Not part of an FP8 projection: copy through unchanged.
        new_tensors[k] = f.get_tensor(k)
        passed += 1
100
+ print(f"Expert→INT4: {converted_expert}, Shared→BF16: {converted_shared}, Passthrough: {passed}")
101
+ print(f"Total: {len(new_tensors)}")
102
+
103
+ # Verify pack format matches base
104
+ sample = "model.layers.61.mlp.experts.0.gate_proj.weight_packed"
105
+ if sample in new_tensors:
106
+ print(f"\nVerify: {sample} shape={list(new_tensors[sample].shape)}")
107
+ print(f"Expected: [2048, 896] (3584/4=896)")
108
+
109
+ save_file(new_tensors, OUTPUT_PATH)
110
+ import os
111
+ print(f"Saved: {os.path.getsize(OUTPUT_PATH)/1024/1024:.1f} MB")