# Quillan-Ronin / src / AceWeights.py
# Uploaded by CrashOverrideX
# Commit message: Add files using upload-large-folder tool
# Commit: 41a3927 (verified)
import math
import os
import torch
from safetensors import safe_open
# One MXFP4 block packs 32 four-bit values, i.e. 32 * 4 bits = 16 bytes.
BYTES_PER_BLOCK = 32 * 4 // 8

# Lookup table from a 4-bit code to the value it encodes: entry ``i`` is the
# number decoded for nibble ``i``.  Codes 0-7 are the non-negative magnitudes
# and codes 8-15 are the same magnitudes negated (code 8 is therefore -0.0).
FP4_VALUES = [
    sign * magnitude
    for sign in (1.0, -1.0)
    for magnitude in (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
]
# Translate the parameter names used by this implementation into checkpoint
# tensor names, for the MLP parameters of blocks 0..35.  A bias maps to a
# single tensor of the same name; a weight maps to a ``(blocks, scales)``
# pair of tensor names.  Entries are grouped by parameter kind first and by
# block index second.
PARAM_NAME_MAP = {
    f"block.{n}.mlp.{param}": (
        (f"block.{n}.mlp.{param}.blocks", f"block.{n}.mlp.{param}.scales")
        if param.endswith("_weight")
        else f"block.{n}.mlp.{param}"
    )
    for param in ("mlp1_bias", "mlp1_weight", "mlp2_bias", "mlp2_weight")
    for n in range(36)
}
class Checkpoint:
def __init__(self, path: str, device: torch.device):
device_str = (
device.type
if device.index is None
else device.type + ":" + str(device.index)
)
self.device_str = device_str
# Read from all files ending with .safetensors in the checkpoint directory
safetensor_files = [
os.path.join(path, fname)
for fname in os.listdir(path)
if fname.endswith(".safetensors")
]
# Build a mapping from tensor name to (file, key)
tensor_name_to_file = {}
for safetensor_file in safetensor_files:
with safe_open(safetensor_file, framework="pt", device=device_str) as f:
for key in f.keys():
tensor_name_to_file[key] = safetensor_file
self.tensor_name_to_file = tensor_name_to_file
def get(self, name: str) -> torch.Tensor:
match PARAM_NAME_MAP.get(name, name):
case (blocks_name, scales_name):
# MoE weights: are in block-based MXFP4 format
return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16)
case tensor_name:
# MoE biases and other weights
return self._get_tensor(tensor_name)
def _get_tensor(self, name: str) -> str:
assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint."
with safe_open(
self.tensor_name_to_file[name], framework="pt", device=self.device_str
) as f:
return f.get_tensor(name)
def _get_mxfp4_tensor(
self,
blocks_name: str,
scales_name: str,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 16384 * 512,
) -> torch.Tensor:
assert blocks_name in self.tensor_name_to_file, (
f"Blocks tensor {blocks_name} not found in checkpoint."
)
assert scales_name in self.tensor_name_to_file, (
f"Scales tensor {scales_name} not found in checkpoint."
)
blocks = self._get_tensor(blocks_name)
scales = self._get_tensor(scales_name).to(torch.int32) - 127
assert blocks.shape[:-1] == scales.shape, (
f"{blocks.shape=} does not match {scales.shape=}"
)
lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G
blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)
out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
for r0 in range(0, rows_total, rows_per_chunk):
r1 = min(r0 + rows_per_chunk, rows_total)
blk = blocks[r0:r1]
exp = scales[r0:r1]
# nibble indices -> int64
idx_lo = (blk & 0x0F).to(torch.long)
idx_hi = (blk >> 4).to(torch.long)
sub = out[r0:r1]
sub[:, 0::2] = lut[idx_lo]
sub[:, 1::2] = lut[idx_hi]
torch.ldexp(sub, exp, out=sub)
del idx_lo, idx_hi, blk, exp
return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16):
"short version that uses a lot of memory"
loaded_blocks = self._get_tensor(blocks_name)
# Split it into low and high nibbles, upcast to bytes, and interleave (for swiglu)
loaded_blocks_lo = loaded_blocks & 0x0F
loaded_blocks_hi = loaded_blocks >> 4
loaded_blocks = torch.stack((loaded_blocks_lo, loaded_blocks_hi), dim=-1)
loaded_blocks = loaded_blocks.view(*loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2)
loaded_scales = self._get_tensor(scales_name)
# Upcast to int32 and subtract bias
loaded_scales = loaded_scales.int() - 127
# Convert MXFP4 numbers into target dtype
fp4_values = torch.tensor(FP4_VALUES, dtype=dtype, device=self.device_str)
loaded_tensor = torch.ldexp(fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1))
loaded_tensor = loaded_tensor.view(*loaded_tensor.shape[:-2], -1)
return loaded_tensor