| | import math
|
| | import os
|
| |
|
| | import torch
|
| | from safetensors import safe_open
|
| |
|
| |
|
| |
|
# NOTE(review): not referenced anywhere in the visible code. Presumably the
# packed size in bytes of one MXFP4 block (two 4-bit codes per byte) -- confirm.
BYTES_PER_BLOCK = 16

# Lookup table from a 4-bit code (0..15) to the value it encodes.
# Entries 8..15 are the negated counterparts of entries 0..7.
FP4_VALUES = [
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]

# Maps a parameter name to the tensor name(s) it is stored under:
#   * "*_bias" entries map to a single tensor name (identical to the key);
#   * "*_weight" entries map to a ("<name>.blocks", "<name>.scales") pair.
# The parameter kind is the outer loop so that insertion order stays grouped
# by kind (all mlp1_bias entries first, then mlp1_weight, and so on).
PARAM_NAME_MAP = {
    f"block.{n}.mlp.{param}": (
        f"block.{n}.mlp.{param}"
        if param.endswith("_bias")
        else (f"block.{n}.mlp.{param}.blocks", f"block.{n}.mlp.{param}.scales")
    )
    for param in ("mlp1_bias", "mlp1_weight", "mlp2_bias", "mlp2_weight")
    for n in range(36)
}
|
| |
|
| |
|
| | class Checkpoint:
|
| | def __init__(self, path: str, device: torch.device):
|
| | device_str = (
|
| | device.type
|
| | if device.index is None
|
| | else device.type + ":" + str(device.index)
|
| | )
|
| | self.device_str = device_str
|
| |
|
| |
|
| | safetensor_files = [
|
| | os.path.join(path, fname)
|
| | for fname in os.listdir(path)
|
| | if fname.endswith(".safetensors")
|
| | ]
|
| |
|
| | tensor_name_to_file = {}
|
| | for safetensor_file in safetensor_files:
|
| | with safe_open(safetensor_file, framework="pt", device=device_str) as f:
|
| | for key in f.keys():
|
| | tensor_name_to_file[key] = safetensor_file
|
| |
|
| | self.tensor_name_to_file = tensor_name_to_file
|
| |
|
| | def get(self, name: str) -> torch.Tensor:
|
| | match PARAM_NAME_MAP.get(name, name):
|
| | case (blocks_name, scales_name):
|
| |
|
| | return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16)
|
| | case tensor_name:
|
| |
|
| | return self._get_tensor(tensor_name)
|
| |
|
| | def _get_tensor(self, name: str) -> str:
|
| | assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint."
|
| | with safe_open(
|
| | self.tensor_name_to_file[name], framework="pt", device=self.device_str
|
| | ) as f:
|
| | return f.get_tensor(name)
|
| |
|
| | def _get_mxfp4_tensor(
|
| | self,
|
| | blocks_name: str,
|
| | scales_name: str,
|
| | *,
|
| | dtype: torch.dtype = torch.bfloat16,
|
| | rows_per_chunk: int = 16384 * 512,
|
| | ) -> torch.Tensor:
|
| | assert blocks_name in self.tensor_name_to_file, (
|
| | f"Blocks tensor {blocks_name} not found in checkpoint."
|
| | )
|
| | assert scales_name in self.tensor_name_to_file, (
|
| | f"Scales tensor {scales_name} not found in checkpoint."
|
| | )
|
| |
|
| | blocks = self._get_tensor(blocks_name)
|
| | scales = self._get_tensor(scales_name).to(torch.int32) - 127
|
| |
|
| | assert blocks.shape[:-1] == scales.shape, (
|
| | f"{blocks.shape=} does not match {scales.shape=}"
|
| | )
|
| |
|
| | lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
|
| |
|
| | *prefix_shape, G, B = blocks.shape
|
| | rows_total = math.prod(prefix_shape) * G
|
| |
|
| | blocks = blocks.reshape(rows_total, B)
|
| | scales = scales.reshape(rows_total, 1)
|
| |
|
| | out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
|
| |
|
| | for r0 in range(0, rows_total, rows_per_chunk):
|
| | r1 = min(r0 + rows_per_chunk, rows_total)
|
| |
|
| | blk = blocks[r0:r1]
|
| | exp = scales[r0:r1]
|
| |
|
| |
|
| | idx_lo = (blk & 0x0F).to(torch.long)
|
| | idx_hi = (blk >> 4).to(torch.long)
|
| |
|
| | sub = out[r0:r1]
|
| | sub[:, 0::2] = lut[idx_lo]
|
| | sub[:, 1::2] = lut[idx_hi]
|
| |
|
| | torch.ldexp(sub, exp, out=sub)
|
| | del idx_lo, idx_hi, blk, exp
|
| |
|
| | return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
|
| |
|
| | def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16):
|
| | "short version that uses a lot of memory"
|
| |
|
| | loaded_blocks = self._get_tensor(blocks_name)
|
| |
|
| | loaded_blocks_lo = loaded_blocks & 0x0F
|
| | loaded_blocks_hi = loaded_blocks >> 4
|
| | loaded_blocks = torch.stack((loaded_blocks_lo, loaded_blocks_hi), dim=-1)
|
| | loaded_blocks = loaded_blocks.view(*loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2)
|
| |
|
| | loaded_scales = self._get_tensor(scales_name)
|
| |
|
| | loaded_scales = loaded_scales.int() - 127
|
| |
|
| |
|
| | fp4_values = torch.tensor(FP4_VALUES, dtype=dtype, device=self.device_str)
|
| | loaded_tensor = torch.ldexp(fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1))
|
| | loaded_tensor = loaded_tensor.view(*loaded_tensor.shape[:-2], -1)
|
| | return loaded_tensor |