import math
import os
import torch
from safetensors import safe_open
# Bytes per MXFP4 block: 32 FP4 numbers packed in 16 bytes
BYTES_PER_BLOCK = 16

# The 16 values representable by a 4-bit FP4 code: indices 0-7 are the
# positive values, 8-15 the corresponding negatives (sign bit is the high
# bit of the nibble). Used as a lookup table during dequantization.
FP4_VALUES = [
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]

# Map the names assumed in this implementation to the checkpoint names.
# MXFP4-quantized MoE weights map to a (blocks, scales) pair of checkpoint
# tensor names; biases map to themselves (single tensor name).
# NOTE(review): hard-codes 36 transformer blocks — confirm against model config.
PARAM_NAME_MAP = {
    f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.mlp1_bias" for n in range(36)
} | {
    f"block.{n}.mlp.mlp1_weight": (f"block.{n}.mlp.mlp1_weight.blocks", f"block.{n}.mlp.mlp1_weight.scales") for n in range(36)
} | {
    f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.mlp2_bias" for n in range(36)
} | {
    f"block.{n}.mlp.mlp2_weight": (f"block.{n}.mlp.mlp2_weight.blocks", f"block.{n}.mlp.mlp2_weight.scales") for n in range(36)
}
class Checkpoint:
def __init__(self, path: str, device: torch.device):
device_str = (
device.type
if device.index is None
else device.type + ":" + str(device.index)
)
self.device_str = device_str
# Read from all files ending with .safetensors in the checkpoint directory
safetensor_files = [
os.path.join(path, fname)
for fname in os.listdir(path)
if fname.endswith(".safetensors")
]
# Build a mapping from tensor name to (file, key)
tensor_name_to_file = {}
for safetensor_file in safetensor_files:
with safe_open(safetensor_file, framework="pt", device=device_str) as f:
for key in f.keys():
tensor_name_to_file[key] = safetensor_file
self.tensor_name_to_file = tensor_name_to_file
def get(self, name: str) -> torch.Tensor:
match PARAM_NAME_MAP.get(name, name):
case (blocks_name, scales_name):
# MoE weights: are in block-based MXFP4 format
return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16)
case tensor_name:
# MoE biases and other weights
return self._get_tensor(tensor_name)
def _get_tensor(self, name: str) -> str:
assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint."
with safe_open(
self.tensor_name_to_file[name], framework="pt", device=self.device_str
) as f:
return f.get_tensor(name)
def _get_mxfp4_tensor(
self,
blocks_name: str,
scales_name: str,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 16384 * 512,
) -> torch.Tensor:
assert blocks_name in self.tensor_name_to_file, (
f"Blocks tensor {blocks_name} not found in checkpoint."
)
assert scales_name in self.tensor_name_to_file, (
f"Scales tensor {scales_name} not found in checkpoint."
)
blocks = self._get_tensor(blocks_name)
scales = self._get_tensor(scales_name).to(torch.int32) - 127
assert blocks.shape[:-1] == scales.shape, (
f"{blocks.shape=} does not match {scales.shape=}"
)
lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G
blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)
out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
for r0 in range(0, rows_total, rows_per_chunk):
r1 = min(r0 + rows_per_chunk, rows_total)
blk = blocks[r0:r1]
exp = scales[r0:r1]
# nibble indices -> int64
idx_lo = (blk & 0x0F).to(torch.long)
idx_hi = (blk >> 4).to(torch.long)
sub = out[r0:r1]
sub[:, 0::2] = lut[idx_lo]
sub[:, 1::2] = lut[idx_hi]
torch.ldexp(sub, exp, out=sub)
del idx_lo, idx_hi, blk, exp
return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16):
"short version that uses a lot of memory"
loaded_blocks = self._get_tensor(blocks_name)
# Split it into low and high nibbles, upcast to bytes, and interleave (for swiglu)
loaded_blocks_lo = loaded_blocks & 0x0F
loaded_blocks_hi = loaded_blocks >> 4
loaded_blocks = torch.stack((loaded_blocks_lo, loaded_blocks_hi), dim=-1)
loaded_blocks = loaded_blocks.view(*loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2)
loaded_scales = self._get_tensor(scales_name)
# Upcast to int32 and subtract bias
loaded_scales = loaded_scales.int() - 127
# Convert MXFP4 numbers into target dtype
fp4_values = torch.tensor(FP4_VALUES, dtype=dtype, device=self.device_str)
loaded_tensor = torch.ldexp(fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1))
loaded_tensor = loaded_tensor.view(*loaded_tensor.shape[:-2], -1)
return loaded_tensor |