| |
| """Convert chromadb/context-1 BF16 weights to MXFP4 format for vLLM.""" |
|
|
| import json |
| import math |
| import os |
| import shutil |
| import struct |
| import sys |
| import time |
|
|
| import numpy as np |
|
|
# Input checkpoint and converted-output directories, resolved relative to this script.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "context-1")
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "context-1-mxfp4")

# Number of consecutive elements that share one E8M0 scale in MXFP4.
GROUP_SIZE = 32

# The eight non-negative magnitudes representable by E2M1 (the sign bit is
# handled separately, giving 16 codes total).
FP4_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32
)
FP4_MAX = 6.0

# Rounding thresholds: the midpoints between consecutive FP4_VALUES entries
# (ties round away from zero), with +inf as a catch-all last bucket.
FP4_BOUNDARIES = np.array(
    [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, np.inf], dtype=np.float32
)
|
|
|
|
def float_to_e2m1(values: np.ndarray) -> np.ndarray:
    """Quantize float values to 4-bit E2M1 codes (0..15)."""
    magnitudes = np.abs(values)
    # Bucket each magnitude against the midpoints between representable
    # E2M1 values; the bucket index is the 3-bit magnitude code.
    mantissa_codes = np.digitize(magnitudes, FP4_BOUNDARIES).astype(np.uint8)
    mantissa_codes = np.minimum(mantissa_codes, np.uint8(7))
    # Bit 3 carries the sign.
    sign_bits = np.where(values < 0, np.uint8(8), np.uint8(0))
    return (sign_bits | mantissa_codes).astype(np.uint8)
|
|
|
|
def compute_e8m0_scale(group: np.ndarray) -> tuple[np.uint8, float]:
    """Compute E8M0 shared exponent for a group. Returns (e8m0_byte, scale_float)."""
    peak = np.max(np.abs(group))
    if peak == 0:
        # All-zero group: identity scale, encoded exponent byte 0.
        return np.uint8(0), 1.0

    # Smallest power-of-two scale that brings the peak magnitude within the
    # FP4 representable range, clamped to E8M0's exponent range [-127, 127].
    exponent = math.ceil(math.log2(max(peak / FP4_MAX, 2**-127)))
    exponent = min(max(exponent, -127), 127)
    return np.uint8(exponent + 127), 2.0 ** exponent
|
|
|
|
def quantize_mxfp4(weight: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Quantize a 2D weight [rows, cols] from float to MXFP4 (vectorized).
    cols must be divisible by GROUP_SIZE.
    Returns (packed_uint8 [rows, cols//2], scales_uint8 [rows, cols//GROUP_SIZE]).
    """
    rows, cols = weight.shape
    assert cols % GROUP_SIZE == 0, f"cols={cols} not divisible by {GROUP_SIZE}"

    groups = weight.reshape(rows, cols // GROUP_SIZE, GROUP_SIZE).astype(np.float32)

    # Per-group E8M0 scale: ceil(log2(amax / FP4_MAX)), clamped to [-127, 127].
    # The floor on amax avoids log2(0) for all-zero groups.
    peak = np.maximum(np.max(np.abs(groups), axis=-1), 2**-127)
    exponents = np.ceil(np.log2(peak / FP4_MAX)).astype(np.int32)
    exponents = np.clip(exponents, -127, 127)
    scales = (exponents + 127).astype(np.uint8)
    divisors = np.power(2.0, exponents.astype(np.float64)).astype(np.float32)

    # Normalize each group into FP4 range, then map elements to E2M1 codes.
    normalized = groups / divisors[:, :, np.newaxis]
    codes = float_to_e2m1(normalized.reshape(rows, cols))

    # Pack two 4-bit codes per byte: even-index element in the low nibble,
    # odd-index element in the high nibble.
    low = codes[:, 0::2]
    high = codes[:, 1::2]
    packed = ((high << 4) | low).astype(np.uint8)

    return packed, scales
|
|
|
|
def read_safetensors_header(path: str) -> tuple[dict, int]:
    """Read the JSON header of a safetensors file.

    The file starts with an 8-byte little-endian length prefix followed by
    that many bytes of UTF-8 JSON.

    Returns:
        (header_dict, header_size) where header_size is the JSON byte length,
        excluding the 8-byte prefix.
    """
    # Fix: the annotation previously claimed `-> dict` although the function
    # has always returned a (dict, int) tuple.
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header_json = f.read(header_size).decode("utf-8")
    return json.loads(header_json), header_size
|
|
|
|
def read_tensor(path: str, info: dict) -> np.ndarray:
    """Load one tensor's bytes from a safetensors file and view as an ndarray.

    `info` is the tensor's header entry: {"dtype", "shape", "data_offsets"}.
    """
    dtype_map = {"BF16": (np.uint16, 2), "F32": (np.float32, 4), "U8": (np.uint8, 1)}
    np_dtype, _elem_size = dtype_map[info["dtype"]]
    start, end = info["data_offsets"]

    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        # data_offsets are relative to the end of the JSON header, which
        # itself sits after the 8-byte length prefix.
        f.seek(8 + header_len + start)
        payload = f.read(end - start)

    return np.frombuffer(payload, dtype=np_dtype).reshape(info["shape"])
|
|
|
|
def bf16_to_f32(arr: np.ndarray) -> np.ndarray:
    """Convert BF16 (stored as uint16) to float32."""
    # BF16 is the top 16 bits of an IEEE-754 float32: widen and shift the
    # bit pattern into place, then reinterpret. The uint32 shift amount keeps
    # the result dtype at uint32 so the float32 view is valid.
    widened = arr.astype(np.uint32) << np.uint32(16)
    return widened.view(np.float32)
|
|
|
|
| def build_safetensors(tensors: dict, metadata: dict | None = None) -> bytes: |
| """Build safetensors binary from dict of {name: (np_array, dtype_str)}.""" |
| header = {} |
| if metadata: |
| header["__metadata__"] = metadata |
|
|
| data_parts = [] |
| offset = 0 |
| for name, (arr, dtype_str) in sorted(tensors.items()): |
| raw = arr.tobytes() |
| data_parts.append(raw) |
| header[name] = { |
| "dtype": dtype_str, |
| "shape": list(arr.shape), |
| "data_offsets": [offset, offset + len(raw)], |
| } |
| offset += len(raw) |
|
|
| header_json = json.dumps(header, separators=(",", ":")).encode("utf-8") |
| |
| padding = (8 - len(header_json) % 8) % 8 |
| header_json += b" " * padding |
|
|
| result = struct.pack("<Q", len(header_json)) + header_json |
| for part in data_parts: |
| result += part |
| return result |
|
|
|
|
def main():
    """Convert the checkpoint in MODEL_DIR to MXFP4 and write it to OUTPUT_DIR."""
    st_path = os.path.join(MODEL_DIR, "model.safetensors")

    print("Reading safetensors header...")
    # header_size is unused here; read_tensor re-reads the prefix itself.
    header, header_size = read_safetensors_header(st_path)

    # Preserve any existing metadata; default to the conventional "pt" tag.
    metadata = header.pop("__metadata__", {"format": "pt"})

    # Only MoE expert matmul weights are quantized; every other tensor is
    # copied through unchanged.
    expert_weight_suffixes = (".mlp.experts.gate_up_proj", ".mlp.experts.down_proj")

    output_tensors = {}
    total = len([k for k in header if k != "__metadata__"])
    done = 0

    for name, info in sorted(header.items()):
        if name == "__metadata__":
            continue
        done += 1
        is_expert_weight = any(name.endswith(s) for s in expert_weight_suffixes)

        if is_expert_weight:
            print(f"[{done}/{total}] Quantizing {name} {info['shape']}...")
            t0 = time.time()
            # BF16 payload comes back as raw uint16; widen to float32 first.
            raw = read_tensor(st_path, info)
            weight_f32 = bf16_to_f32(raw)

            is_gate_up = name.endswith(".gate_up_proj")  # NOTE(review): currently unused
            num_experts = weight_f32.shape[0]

            # Swap the last two axes so quantization groups run along the
            # final (contiguous) dimension after the transpose.
            # NOTE(review): assumes a 3-D [experts, dim_a, dim_b] layout
            # where groups belong on dim_a — confirm against the checkpoint.
            weight_f32 = np.ascontiguousarray(
                np.transpose(weight_f32, (0, 2, 1))
            )

            # Quantize each expert's 2-D slice independently.
            blocks_list = []
            scales_list = []
            for e in range(num_experts):
                packed, scales = quantize_mxfp4(weight_f32[e])
                blocks_list.append(packed)
                scales_list.append(scales)

            blocks = np.stack(blocks_list, axis=0)
            scales = np.stack(scales_list, axis=0)

            # Emit the *_blocks / *_scales tensor-name pair in place of the
            # original weight; only one of the two replaces can match.
            blocks_name = name.replace(".gate_up_proj", ".gate_up_proj_blocks").replace(
                ".down_proj", ".down_proj_blocks"
            )
            scales_name = name.replace(".gate_up_proj", ".gate_up_proj_scales").replace(
                ".down_proj", ".down_proj_scales"
            )

            output_tensors[blocks_name] = (blocks, "U8")
            output_tensors[scales_name] = (scales, "U8")

            dt = time.time() - t0
            print(f" -> {blocks_name} {list(blocks.shape)}, "
                  f"{scales_name} {list(scales.shape)} ({dt:.1f}s)")
        else:
            # Non-expert tensor: pass through with its original dtype tag.
            print(f"[{done}/{total}] Copying {name} {info['shape']}...")
            raw = read_tensor(st_path, info)
            output_tensors[name] = (raw, info["dtype"])

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("\nWriting output safetensors...")
    out_path = os.path.join(OUTPUT_DIR, "model.safetensors")

    # Build the output header first (offsets only), then stream tensor bytes
    # to disk in the same sorted order — avoids holding one giant buffer.
    header_dict = {}
    if metadata:
        header_dict["__metadata__"] = metadata

    offset = 0
    tensor_order = sorted(output_tensors.keys())
    for tname in tensor_order:
        arr, dtype_str = output_tensors[tname]
        raw_size = arr.nbytes
        header_dict[tname] = {
            "dtype": dtype_str,
            "shape": list(arr.shape),
            "data_offsets": [offset, offset + raw_size],
        }
        offset += raw_size

    header_json = json.dumps(header_dict, separators=(",", ":")).encode("utf-8")
    # Space-pad the header to an 8-byte multiple per the safetensors format.
    padding = (8 - len(header_json) % 8) % 8
    header_json += b" " * padding

    with open(out_path, "wb") as f:
        f.write(struct.pack("<Q", len(header_json)))
        f.write(header_json)
        for tname in tensor_order:
            arr, _ = output_tensors[tname]
            f.write(arr.tobytes())

    print(f"Saved {out_path} ({os.path.getsize(out_path) / 1e9:.2f} GB)")

    # Copy tokenizer/config sidecar files that exist alongside the weights.
    for fname in ["generation_config.json", "tokenizer_config.json",
                  "tokenizer.json", "chat_template.jinja"]:
        src = os.path.join(MODEL_DIR, fname)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(OUTPUT_DIR, fname))
            print(f"Copied {fname}")

    # Rewrite config.json with a quantization_config block so loaders know
    # which modules stayed unquantized.
    with open(os.path.join(MODEL_DIR, "config.json")) as f:
        config = json.load(f)

    config["quantization_config"] = {
        "modules_to_not_convert": [
            "model.layers.*.self_attn",
            "model.layers.*.mlp.router",
            "model.embed_tokens",
            "lm_head",
        ],
        "quant_method": "mxfp4",
    }

    with open(os.path.join(OUTPUT_DIR, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
        f.write("\n")
    print("Wrote config.json with quantization_config")

    print("\nDone!")
|
|
|
|
# Run the full conversion when executed as a script.
if __name__ == "__main__":
    main()
|
|