| import argparse |
| import json |
| from pathlib import Path |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import orbax.checkpoint as ocp |
| from safetensors.flax import save_file |
| from tqdm import tqdm |
|
|
# Key prefix under which the SigLIP vision-encoder parameters are looked up in
# the restored checkpoint tree (used by convert_siglip and convert_to_hf).
SIGLIP_PREFIX = "SigLiPFromPatches_0/siglip_encoder"
|
|
|
|
def flatten(x: np.ndarray, start: int = 0, end: int = -1):
    """Collapse axes ``start`` through ``end`` (inclusive) of ``x`` into one axis.

    Negative ``start`` / ``end`` are counted from the last axis. The axes
    before ``start`` and after ``end`` keep their sizes.
    """
    rank = x.ndim
    first = start + rank if start < 0 else start
    last = end + rank if end < 0 else end

    leading = x.shape[:first]
    trailing = x.shape[last + 1 :]
    # -1 lets reshape infer the size of the merged axis.
    return x.reshape((*leading, -1, *trailing))
|
|
|
|
def unflatten(x: np.ndarray, dim: int, sizes: tuple[int, ...]):
    """Split axis ``dim`` of ``x`` into several axes with the given ``sizes``.

    Args:
        x: Array to reshape.
        dim: Axis to split. Negative values count from the last axis, the
            same convention ``flatten`` uses.
        sizes: Sizes of the axes that replace ``dim``; one entry may be -1
            so that ``reshape`` infers it.

    Returns:
        ``x`` reshaped so that axis ``dim`` is replaced by ``sizes``.
    """
    # Normalise a negative axis first: slicing with a raw negative ``dim``
    # would make ``x.shape[dim + 1:]`` wrong (e.g. dim=-1 gives shape[0:]).
    if dim < 0:
        dim += x.ndim
    new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
    return x.reshape(new_shape)
|
|
|
|
| |
def check_groups(groups: np.ndarray, scales: np.ndarray, dim: int):
    """Check which groups survive a quantize/dequantize round trip.

    Each value is multiplied by the reciprocal of its scale, rounded to the
    nearest integer and multiplied back by the scale. A group passes when the
    largest absolute round-trip error along ``dim`` is below 1e-6.

    Args:
        groups: Values to test.
        scales: Scales that broadcast against ``groups`` (callers in this
            file pass an array with a size-1 axis at ``dim``).
        dim: Axis along which one group is laid out.

    Returns:
        ``(ok, max_diff)``: a boolean mask and the maximum absolute error,
        both reduced along ``dim`` with that axis kept as size 1.
    """
    # Clip before inverting so a zero scale does not divide by zero.
    safe_scales = scales.clip(1e-12)
    reciprocal = 1.0 / safe_scales

    quantized = np.round(groups * reciprocal)
    round_trip_error = np.abs(quantized * scales - groups)
    worst = round_trip_error.max(dim, keepdims=True)
    return worst < 1e-6, worst
|
|
|
|
def find_scales(w: np.ndarray, dim: int, group_size: int = 32):
    """Recover per-group quantization scales from already-quantized weights.

    Axis ``dim`` of ``w`` is split into groups of ``group_size`` values. For
    each group, candidate scales ``(max - min) / q`` are tried for
    ``q = 15, 14, ..., 1``, and a candidate is accepted when quantizing and
    dequantizing the group with it reproduces the values (see
    ``check_groups``). Later candidates overwrite earlier ones, so each group
    ends up with the accepted scale of the smallest ``q``, i.e. the largest
    scale that round-trips.

    Args:
        w: Weight array whose size along ``dim`` is a multiple of
            ``group_size``.
        dim: Axis to group along; negative values count from the last axis.
        group_size: Number of consecutive values sharing one scale.

    Returns:
        Array of scales shaped like ``w`` but with axis ``dim`` reduced to
        the number of groups.

    Raises:
        ValueError: If some group has no candidate scale that round-trips.
    """
    # Normalise here so ``dim + 1`` below always names the in-group axis.
    if dim < 0:
        dim += w.ndim
    group_axis = dim + 1

    w = unflatten(w, dim, (-1, group_size))
    group_range = w.max(group_axis, keepdims=True) - w.min(group_axis, keepdims=True)

    scales = np.zeros_like(group_range)
    for q in range(15, 0, -1):
        try_scale = group_range / q
        ok, _ = check_groups(w, try_scale, group_axis)
        scales[ok] = try_scale[ok]

    # Explicit raise instead of ``assert``: under ``python -O`` an assert is
    # stripped and groups without a valid scale would silently keep scale 0.
    ok, max_diff = check_groups(w, scales, group_axis)
    if not ok.all():
        raise ValueError(
            f"No exact quantization scale found for {int((~ok).sum())} group(s); "
            f"max round-trip error {float(max_diff.max())}"
        )

    return scales.squeeze(group_axis)
|
|
|
|
def convert_siglip(params, num_layers: int):
    """Build a flat ``name -> array`` dict for the SigLIP vision encoder.

    Parameters are read from ``params`` under ``SIGLIP_PREFIX`` and stored
    under names such as ``encoder.layers.{i}.self_attn.q_proj.weight``.

    Args:
        params: Mapping from checkpoint path strings to dicts of arrays
            (each layer dict has a "bias" plus either "kernel" or "scale").
        num_layers: Number of ``encoderblock_{i}`` entries to convert.

    Returns:
        Dict mapping the new parameter names to arrays.

    Raises:
        RuntimeError: If a layer has an unsupported kernel/bias rank, or has
            neither a "kernel" nor a "scale" entry.
    """
    state_dict = dict()

    def convert_layer(prefix: str, layer: dict[str, np.ndarray]):
        """Store ``layer`` as ``{prefix}weight`` / ``{prefix}bias``."""
        bias = layer["bias"]

        if "kernel" in layer:
            w = layer["kernel"]
            if w.ndim == 2:
                # 2-D kernel: stored transposed.
                w = w.T

            elif w.ndim == 3:
                if bias.ndim == 2:
                    # Merge the last two kernel axes and flatten the bias
                    # (presumably a per-head q/k/v projection -- confirm).
                    w = flatten(w, 1, 2).T
                    bias = bias.reshape(-1)

                elif bias.ndim == 1:
                    # Merge the first two kernel axes (presumably the
                    # attention output projection -- confirm).
                    w = flatten(w, 0, 1).T

                else:
                    # Previously this case fell through and stored the
                    # kernel unflattened without any warning.
                    raise RuntimeError(f"Unsupported {bias.shape=} for {w.shape=}")

            elif w.ndim == 4:
                # Move the last two axes to the front (presumably a conv
                # kernel going from HWIO to OIHW layout -- confirm).
                w = w.transpose(3, 2, 0, 1)

            else:
                raise RuntimeError(f"Unsupported {w.shape=}")

        elif "scale" in layer:
            # Normalisation-style layer: "scale" is used as the weight.
            w = layer["scale"]

        else:
            raise RuntimeError(f"Layer for {prefix!r} has neither 'kernel' nor 'scale'")

        state_dict[f"{prefix}weight"] = w
        state_dict[f"{prefix}bias"] = bias

    convert_layer("embeddings.patch_embedding.", params[f"{SIGLIP_PREFIX}/embedding"])
    # Drop the leading size-1 axis of the position embedding.
    state_dict["embeddings.position_embedding.weight"] = params[SIGLIP_PREFIX]["pos_embedding"].squeeze(0)
    convert_layer("post_layernorm.", params[f"{SIGLIP_PREFIX}/Transformer/encoder_norm"])

    for layer_idx in range(num_layers):
        prefix = f"encoder.layers.{layer_idx}."
        layer_prefix = f"{SIGLIP_PREFIX}/Transformer/encoderblock_{layer_idx}/"

        convert_layer(f"{prefix}layer_norm1.", params[f"{layer_prefix}LayerNorm_0"])
        convert_layer(f"{prefix}layer_norm2.", params[f"{layer_prefix}LayerNorm_1"])

        attn_prefix = f"{layer_prefix}MultiHeadDotProductAttention_0/"
        convert_layer(f"{prefix}self_attn.q_proj.", params[f"{attn_prefix}query"])
        convert_layer(f"{prefix}self_attn.k_proj.", params[f"{attn_prefix}key"])
        convert_layer(f"{prefix}self_attn.v_proj.", params[f"{attn_prefix}value"])
        convert_layer(f"{prefix}self_attn.out_proj.", params[f"{attn_prefix}out"])

        mlp_prefix = f"{layer_prefix}MlpBlock_0/"
        convert_layer(f"{prefix}mlp.fc1.", params[f"{mlp_prefix}Dense_0"])
        convert_layer(f"{prefix}mlp.fc2.", params[f"{mlp_prefix}Dense_1"])

    return state_dict
|
|
|
|
| |
def convert_to_hf(path: Path):
    """Restore a checkpoint and yield its weights in chunks under new names.

    This is a generator. It yields, in order:

    1. one dict with the token embedding and final norm (plus the
       multimodal projector weights when a vision encoder is present),
    2. one dict per transformer layer,
    3. one dict with the vision-encoder weights, when present.

    Each yielded dict is a separate object, so a tensor is produced exactly
    once across the whole iteration.

    Args:
        path: Directory of the checkpoint to restore.

    Yields:
        Dicts mapping parameter names to arrays.
    """
    path = path.absolute()
    ckpt = ocp.StandardCheckpointer()
    metadata = dict(ckpt.metadata(path))
    metadata = jax.tree.map(ocp.utils.to_shape_dtype_struct, metadata)

    # Count layers by probing the metadata for consecutive layer keys.
    num_layers = num_siglip_layers = 0
    while f"transformer/layer_{num_layers}/attn/_key_norm" in metadata:
        num_layers += 1
    while f"{SIGLIP_PREFIX}/Transformer/encoderblock_{num_siglip_layers}/LayerNorm_0" in metadata:
        num_siglip_layers += 1
    print(f"{num_layers=}")
    print(f"{num_siglip_layers=}")

    params = ckpt.restore(path)
    state_dict = dict()

    if num_siglip_layers > 0:
        # Append 64 zero rows to the embedding table for multimodal
        # checkpoints (presumably to match the target vocabulary size --
        # confirm against the consuming model config).
        embed = params["transformer/embedder"]["input_embedding"]
        params["transformer/embedder"]["input_embedding"] = np.pad(embed, ((0, 64), (0, 0)))
        gemma_prefix = "language_model."

        prefix = "multi_modal_projector.mm_"
        jax_prefix = "transformer/embedder/"
        state_dict[f"{prefix}input_projection_weight"] = params[f"{jax_prefix}mm_input_projection"]["w"]
        state_dict[f"{prefix}soft_emb_norm.weight"] = params[f"{jax_prefix}mm_soft_embedding_norm"]["scale"]

    else:
        gemma_prefix = ""

    state_dict[f"{gemma_prefix}model.embed_tokens.weight"] = params["transformer/embedder"]["input_embedding"]
    state_dict[f"{gemma_prefix}model.norm.weight"] = params["transformer/final_norm"]["scale"]

    yield state_dict

    for layer_idx in range(num_layers):
        jax_prefix = f"transformer/layer_{layer_idx}/"

        state_dict = dict()
        prefix = f"{gemma_prefix}model.layers.{layer_idx}."
        state_dict[f"{prefix}input_layernorm.weight"] = params[f"{jax_prefix}pre_attention_norm"]["scale"]
        state_dict[f"{prefix}post_attention_layernorm.weight"] = params[f"{jax_prefix}post_attention_norm"]["scale"]
        state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = params[f"{jax_prefix}pre_ffw_norm"]["scale"]
        state_dict[f"{prefix}post_feedforward_layernorm.weight"] = params[f"{jax_prefix}post_ffw_norm"]["scale"]

        prefix = f"{gemma_prefix}model.layers.{layer_idx}.self_attn."
        jax_prefix = f"transformer/layer_{layer_idx}/attn/"
        state_dict[f"{prefix}q_norm.weight"] = params[f"{jax_prefix}_query_norm"]["scale"]
        state_dict[f"{prefix}k_norm.weight"] = params[f"{jax_prefix}_key_norm"]["scale"]

        # Swap the last two axes, then merge the first two -- presumably
        # (num_heads, hidden, head_dim) -> (num_heads * head_dim, hidden);
        # confirm against the checkpoint layout.
        state_dict[f"{prefix}q_proj.weight"] = flatten(params[f"{jax_prefix}q_einsum"]["w"].transpose(0, 2, 1), end=1)
        # kv_einsum stacks two tensors on its first axis: index 0 is used
        # for k_proj and index 1 for v_proj.
        state_dict[f"{prefix}k_proj.weight"] = flatten(
            params[f"{jax_prefix}kv_einsum"]["w"][0].transpose(0, 2, 1), end=1
        )
        state_dict[f"{prefix}v_proj.weight"] = flatten(
            params[f"{jax_prefix}kv_einsum"]["w"][1].transpose(0, 2, 1), end=1
        )

        # Merge the first two axes, then transpose -- presumably
        # (num_heads, head_dim, hidden) -> (hidden, num_heads * head_dim).
        state_dict[f"{prefix}o_proj.weight"] = flatten(params[f"{jax_prefix}attn_vec_einsum"]["w"], end=1).T

        prefix = f"{gemma_prefix}model.layers.{layer_idx}.mlp."
        jax_prefix = f"transformer/layer_{layer_idx}/mlp/"
        # gating_einsum stacks two tensors on its first axis: index 0 is
        # used for gate_proj and index 1 for up_proj.
        state_dict[f"{prefix}gate_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][0]
        state_dict[f"{prefix}up_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][1]
        state_dict[f"{prefix}down_proj.weight"] = params[f"{jax_prefix}linear"]["w"].T

        yield state_dict

    if num_siglip_layers > 0:
        siglip_state_dict = convert_siglip(params, num_siglip_layers)
        # Yield a fresh dict. Reusing ``state_dict`` here would re-emit the
        # previously yielded chunk together with the vision weights, so the
        # consumer would process and count those tensors twice.
        yield {f"vision_tower.vision_model.{k}": v for k, v in siglip_state_dict.items()}
|
|
|
|
def convert_awq(state_dict: dict[str, np.ndarray]):
    """Convert a chunk of weights into packed 4-bit ``qweight``/``qzeros``/``scales``.

    Tensors whose key ends with ``model.embed_tokens.weight``, whose key
    starts with ``vision_tower`` or ``multi_modal_projector``, or that are
    1-D, are only cast to bfloat16. Every other tensor must be 2-D and is
    replaced by three entries:

    * ``{name}.qweight``: int32, eight 4-bit codes packed per element,
    * ``{name}.qzeros``: int32 filled with 0x88888888 (zero point 8 in
      every nibble),
    * ``{name}.scales``: bfloat16 per-group scales (group size 32).

    Args:
        state_dict: Mapping from parameter names to arrays.

    Returns:
        New dict with the converted tensors.

    Raises:
        ValueError: If a tensor to quantize is not 2-D, or if its integer
            codes do not fit the signed 4-bit range [-8, 7].
    """
    awq_state_dict = dict()

    for k, v in state_dict.items():
        if (
            k.endswith("model.embed_tokens.weight")
            or k.startswith(("vision_tower", "multi_modal_projector"))
            or v.ndim == 1
        ):
            awq_state_dict[k] = v.astype(jnp.bfloat16)
            continue

        # Explicit raise rather than ``assert`` so the check survives -O.
        if v.ndim != 2:
            raise ValueError(f"Cannot quantize {k}: expected a 2-D tensor, got {v.shape=}")
        # Transpose so the groups of 32 run along the first axis (K).
        v = v.T

        K, N = v.shape
        scales = find_scales(v, dim=0)
        inv_scale = 1 / scales.clip(1e-12)
        qweight = np.round(v.reshape(K // 32, 32, N) * inv_scale[:, None])

        # After the +8 shift below each code must fit in one nibble (0..15).
        # A code outside [-8, 7] would spill into the neighbouring nibble
        # when packed and silently corrupt adjacent weights.
        if qweight.min() < -8 or qweight.max() > 7:
            raise ValueError(
                f"Cannot pack {k}: quantized values span "
                f"[{float(qweight.min())}, {float(qweight.max())}], outside the 4-bit range [-8, 7]"
            )

        # Shift the signed codes to unsigned ones; qzeros below stores the
        # matching zero point of 8.
        qweight = (qweight + 8).astype(np.uint32)

        # Pack 8 consecutive codes along N into one 32-bit word. From the
        # lowest nibble upwards the element order is 0, 2, 4, 6, 1, 3, 5, 7.
        qweight = qweight.reshape(K, N // 8, 8)
        qweight_packed = (
            (qweight[..., 7] << (7 * 4))
            | (qweight[..., 5] << (6 * 4))
            | (qweight[..., 3] << (5 * 4))
            | (qweight[..., 1] << (4 * 4))
            | (qweight[..., 6] << (3 * 4))
            | (qweight[..., 4] << (2 * 4))
            | (qweight[..., 2] << (1 * 4))
            | (qweight[..., 0] << (0 * 4))
        )
        # Reinterpret the bits as int32 without changing them.
        qweight_packed = qweight_packed.view(np.int32).reshape(K, N // 8)

        prefix = k.removesuffix(".weight")
        awq_state_dict[f"{prefix}.qweight"] = qweight_packed
        awq_state_dict[f"{prefix}.qzeros"] = np.full((K // 32, N // 8), 0x8888_8888, dtype=np.uint32).view(np.int32)
        awq_state_dict[f"{prefix}.scales"] = scales.astype(jnp.bfloat16)

    return awq_state_dict
|
|
|
|
if __name__ == "__main__":
    # Convert the checkpoint in --ckpt_dir into sharded safetensors files
    # plus a "model.safetensors.index.json" index inside --save_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_dir", required=True, type=Path)
    parser.add_argument("--save_dir", required=True, type=Path)
    args = parser.parse_args()

    args.save_dir.mkdir(parents=True, exist_ok=True)

    total_size = 0
    weight_map = dict()

    # Tensors accumulated for the shard currently being filled.
    state_dict = dict()
    size = 0
    shard_idx = 0
    filename = f"model-{shard_idx + 1:05d}.safetensors"
    for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
        sub_state_dict = convert_awq(sub_state_dict)
        new_size = sum(v.nbytes for v in sub_state_dict.values())

        # Start a new shard once adding this chunk would exceed 5e9 bytes.
        # The ``state_dict`` guard avoids writing an empty shard file when
        # the very first chunk is already larger than the limit.
        if state_dict and size + new_size > 5e9:
            save_file(state_dict, args.save_dir / filename)
            state_dict = dict()
            size = 0
            shard_idx += 1
            filename = f"model-{shard_idx + 1:05d}.safetensors"

        size += new_size
        total_size += new_size
        state_dict.update(sub_state_dict)
        # Record which shard file holds each tensor.
        weight_map.update(dict.fromkeys(sub_state_dict, filename))

    # Flush the last, partially filled shard.
    save_file(state_dict, args.save_dir / filename)

    # Use a context manager so the index file is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(args.save_dir / "model.safetensors.index.json", "w") as f:
        json.dump(
            dict(metadata=dict(total_size=total_size), weight_map=weight_map),
            f,
        )
|
|