# context-1-mxfp4 / convert_mxfp4.py
# Uploaded by evilfreelancer via huggingface_hub (commit 4cbd575, verified)
#!/usr/bin/env python3
"""Convert chromadb/context-1 BF16 weights to MXFP4 format for vLLM."""
import json
import math
import os
import shutil
import struct
import sys
import time
import numpy as np
# Input (BF16 checkpoint) and output (MXFP4) directories, resolved
# relative to this script's location.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "context-1")
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "context-1-mxfp4")
# Number of elements sharing one E8M0 scale (MXFP4 block size).
GROUP_SIZE = 32
# E2M1 FP4 positive lookup: index -> value
FP4_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32
)
# Largest magnitude representable in E2M1.
FP4_MAX = 6.0
# Midpoints between consecutive FP4 values for nearest rounding
FP4_BOUNDARIES = np.array(
    [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, np.inf], dtype=np.float32
)
def float_to_e2m1(values: np.ndarray) -> np.ndarray:
    """Quantize float values to 4-bit E2M1 codes (0..15).

    Bit 3 carries the sign; bits 0-2 index into FP4_VALUES. Magnitudes
    are rounded to the nearest representable value by binning against
    the FP4_BOUNDARIES midpoints.
    """
    magnitude = np.abs(values)
    # digitize returns the bin index; cap at 7 (the E2M1 max, 6.0).
    index = np.minimum(np.digitize(magnitude, FP4_BOUNDARIES), 7).astype(np.uint8)
    sign_bit = np.where(values < 0, np.uint8(8), np.uint8(0))
    return sign_bit | index
def compute_e8m0_scale(group: np.ndarray) -> tuple[np.uint8, float]:
    """Compute the E8M0 shared exponent for one group.

    Returns (e8m0_byte, scale_float) where e8m0_byte is the biased
    exponent (bias 127) and scale_float is the power-of-two scale it
    represents. An all-zero group returns (0, 1.0) since every code in
    it will be zero regardless of scale.
    """
    peak = float(np.max(np.abs(group)))
    if peak == 0:
        return np.uint8(0), 1.0
    # Smallest power of two such that peak / scale <= FP4_MAX, clamped
    # to the representable exponent range [-127, 127].
    exponent = math.ceil(math.log2(max(peak / FP4_MAX, 2**-127)))
    exponent = min(max(exponent, -127), 127)
    return np.uint8(exponent + 127), 2.0 ** exponent
def quantize_mxfp4(weight: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Quantize a 2D float weight [rows, cols] to MXFP4 (vectorized).

    Each row is split into GROUP_SIZE-wide groups; each group gets one
    E8M0 shared exponent, and the scaled values are quantized to E2M1
    codes, two of which are packed per output byte.

    cols must be divisible by GROUP_SIZE.
    Returns (packed_uint8 [rows, cols//2], scales_uint8 [rows, cols//GROUP_SIZE]).
    """
    rows, cols = weight.shape
    assert cols % GROUP_SIZE == 0, f"cols={cols} not divisible by {GROUP_SIZE}"
    groups = weight.astype(np.float32).reshape(rows, cols // GROUP_SIZE, GROUP_SIZE)
    # Per-group E8M0 exponent: smallest power of two with amax/scale <= FP4_MAX.
    # The floor of 2**-127 avoids log2(0) for all-zero groups.
    peak = np.maximum(np.abs(groups).max(axis=-1), 2**-127)
    exponents = np.clip(np.ceil(np.log2(peak / FP4_MAX)).astype(np.int32), -127, 127)
    scales = (exponents + 127).astype(np.uint8)  # [rows, n_groups] biased bytes
    divisor = np.power(2.0, exponents.astype(np.float64)).astype(np.float32)
    # Normalize every group by its scale, then quantize all values at once.
    normalized = (groups / divisor[..., None]).reshape(rows, cols)
    codes = float_to_e2m1(normalized)
    # Pack two FP4 codes per byte: even column -> low nibble, odd -> high.
    packed = ((codes[:, 1::2] << 4) | codes[:, 0::2]).astype(np.uint8)
    return packed, scales
def read_safetensors_header(path: str) -> tuple[dict, int]:
    """Read the JSON header of a safetensors file.

    Args:
        path: Path to the safetensors file.

    Returns:
        (header, header_size): the parsed header dict and the byte length
        of the serialized header (excluding the 8-byte size prefix).
        Fixes the original annotation, which claimed a bare ``dict`` even
        though a 2-tuple was returned.
    """
    with open(path, "rb") as f:
        # safetensors layout: u64-LE header length, then the JSON header.
        header_size = struct.unpack("<Q", f.read(8))[0]
        header_json = f.read(header_size).decode("utf-8")
    return json.loads(header_json), header_size
def read_tensor(path: str, info: dict) -> np.ndarray:
    """Load one tensor's raw bytes from a safetensors file.

    ``info`` is the tensor's header entry with "dtype", "shape", and
    "data_offsets" (offsets are relative to the start of the data
    section). BF16 tensors come back as raw uint16 words.
    """
    kind_table = {"BF16": (np.uint16, 2), "F32": (np.float32, 4), "U8": (np.uint8, 1)}
    np_dtype, _elem_size = kind_table[info["dtype"]]
    begin, finish = info["data_offsets"]
    with open(path, "rb") as f:
        # The data section starts after the 8-byte size prefix + header.
        prefix = struct.unpack("<Q", f.read(8))[0]
        f.seek(8 + prefix + begin)
        payload = f.read(finish - begin)
    return np.frombuffer(payload, dtype=np_dtype).reshape(info["shape"])
def bf16_to_f32(arr: np.ndarray) -> np.ndarray:
    """Convert BF16 (stored as uint16) to float32.

    BF16 is the upper 16 bits of an IEEE-754 float32, so widening each
    word to uint32 and shifting it into the high half reconstructs the
    exact float32 bit pattern.
    """
    widened = arr.astype(np.uint32) << 16
    return widened.view(np.float32)
def build_safetensors(tensors: dict, metadata: dict | None = None) -> bytes:
"""Build safetensors binary from dict of {name: (np_array, dtype_str)}."""
header = {}
if metadata:
header["__metadata__"] = metadata
data_parts = []
offset = 0
for name, (arr, dtype_str) in sorted(tensors.items()):
raw = arr.tobytes()
data_parts.append(raw)
header[name] = {
"dtype": dtype_str,
"shape": list(arr.shape),
"data_offsets": [offset, offset + len(raw)],
}
offset += len(raw)
header_json = json.dumps(header, separators=(",", ":")).encode("utf-8")
# Pad header to 8-byte alignment
padding = (8 - len(header_json) % 8) % 8
header_json += b" " * padding
result = struct.pack("<Q", len(header_json)) + header_json
for part in data_parts:
result += part
return result
def main():
    """Convert the BF16 checkpoint in MODEL_DIR to MXFP4 in OUTPUT_DIR.

    MoE expert projection weights are quantized to MXFP4 (packed FP4
    blocks plus E8M0 scales); every other tensor is copied through
    unchanged. Tokenizer/config files are copied alongside, and
    config.json gains a vLLM-style ``quantization_config`` entry.
    """
    st_path = os.path.join(MODEL_DIR, "model.safetensors")
    print("Reading safetensors header...")
    # NOTE(review): header_size is never used after this call.
    header, header_size = read_safetensors_header(st_path)
    # Preserve source metadata; default to {"format": "pt"} if absent.
    metadata = header.pop("__metadata__", {"format": "pt"})
    # Only the MoE expert projections are quantized.
    expert_weight_suffixes = (".mlp.experts.gate_up_proj", ".mlp.experts.down_proj")
    output_tensors = {}
    # NOTE(review): "__metadata__" was already popped above, so this
    # filter is redundant but harmless.
    total = len([k for k in header if k != "__metadata__"])
    done = 0
    for name, info in sorted(header.items()):
        if name == "__metadata__":
            continue
        done += 1
        is_expert_weight = any(name.endswith(s) for s in expert_weight_suffixes)
        if is_expert_weight:
            print(f"[{done}/{total}] Quantizing {name} {info['shape']}...")
            t0 = time.time()
            raw = read_tensor(st_path, info)
            weight_f32 = bf16_to_f32(raw)
            # NOTE(review): is_gate_up is computed but never used below;
            # both projection kinds take the same transpose path.
            is_gate_up = name.endswith(".gate_up_proj")
            num_experts = weight_f32.shape[0]
            # BF16 checkpoint stores weights as [E, in_features, out_features].
            # vLLM MXFP4 expects [E, out_features, in_features // 2] (packed).
            # Both gate_up_proj and down_proj need transposing to [E, out, in].
            weight_f32 = np.ascontiguousarray(
                np.transpose(weight_f32, (0, 2, 1))
            )
            blocks_list = []
            scales_list = []
            # Quantize one expert at a time to bound peak memory.
            for e in range(num_experts):
                packed, scales = quantize_mxfp4(weight_f32[e])
                blocks_list.append(packed)
                scales_list.append(scales)
            blocks = np.stack(blocks_list, axis=0)
            scales = np.stack(scales_list, axis=0)
            # Emit the "<name>_blocks" / "<name>_scales" tensor pair.
            blocks_name = name.replace(".gate_up_proj", ".gate_up_proj_blocks").replace(
                ".down_proj", ".down_proj_blocks"
            )
            scales_name = name.replace(".gate_up_proj", ".gate_up_proj_scales").replace(
                ".down_proj", ".down_proj_scales"
            )
            output_tensors[blocks_name] = (blocks, "U8")
            output_tensors[scales_name] = (scales, "U8")
            dt = time.time() - t0
            print(f" -> {blocks_name} {list(blocks.shape)}, "
                  f"{scales_name} {list(scales.shape)} ({dt:.1f}s)")
        else:
            # Non-expert tensors pass through unchanged (still raw bytes).
            print(f"[{done}/{total}] Copying {name} {info['shape']}...")
            raw = read_tensor(st_path, info)
            output_tensors[name] = (raw, info["dtype"])
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print("\nWriting output safetensors...")
    out_path = os.path.join(OUTPUT_DIR, "model.safetensors")
    # Write in streaming fashion to avoid huge memory spike
    # First build header, then write data
    header_dict = {}
    if metadata:
        header_dict["__metadata__"] = metadata
    offset = 0
    # Data must be written in the same (sorted) order the offsets assume.
    tensor_order = sorted(output_tensors.keys())
    for tname in tensor_order:
        arr, dtype_str = output_tensors[tname]
        raw_size = arr.nbytes
        header_dict[tname] = {
            "dtype": dtype_str,
            "shape": list(arr.shape),
            "data_offsets": [offset, offset + raw_size],
        }
        offset += raw_size
    header_json = json.dumps(header_dict, separators=(",", ":")).encode("utf-8")
    # Pad header to an 8-byte boundary with spaces (legal JSON whitespace).
    padding = (8 - len(header_json) % 8) % 8
    header_json += b" " * padding
    with open(out_path, "wb") as f:
        f.write(struct.pack("<Q", len(header_json)))
        f.write(header_json)
        for tname in tensor_order:
            arr, _ = output_tensors[tname]
            f.write(arr.tobytes())
    print(f"Saved {out_path} ({os.path.getsize(out_path) / 1e9:.2f} GB)")
    # Copy config files and add quantization_config
    for fname in ["generation_config.json", "tokenizer_config.json",
                  "tokenizer.json", "chat_template.jinja"]:
        src = os.path.join(MODEL_DIR, fname)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(OUTPUT_DIR, fname))
            print(f"Copied {fname}")
    # Update config.json with quantization_config
    with open(os.path.join(MODEL_DIR, "config.json")) as f:
        config = json.load(f)
    # Modules listed here stay in their original precision.
    config["quantization_config"] = {
        "modules_to_not_convert": [
            "model.layers.*.self_attn",
            "model.layers.*.mlp.router",
            "model.embed_tokens",
            "lm_head",
        ],
        "quant_method": "mxfp4",
    }
    with open(os.path.join(OUTPUT_DIR, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
        f.write("\n")
    print("Wrote config.json with quantization_config")
    print("\nDone!")
# Run the conversion when executed as a script.
if __name__ == "__main__":
    main()