import argparse
from pathlib import Path

import jax.numpy as jnp
import numpy as np
from safetensors.flax import save_file
from tqdm import tqdm

from gemma import gm
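
# Converts a Gemma Flax checkpoint (loaded via the `gemma` package) into an
# AWQ-format safetensors file: first remap the parameter tree to Hugging Face
# naming, then recover per-group int4 scales and pack the weights. The
# scale-recovery step assumes the checkpoint's weights already lie exactly on
# an int4 grid (e.g. a quantization-aware-trained checkpoint).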


def flatten(x: jnp.ndarray, start: int = 0, end: int = -1):
    """Collapse dims ``start``..``end`` (inclusive) into one, like ``torch.flatten``."""
    if start < 0:
        start += x.ndim
    if end < 0:
        end += x.ndim
    new_shape = x.shape[:start] + (-1,) + x.shape[end + 1 :]
    return x.reshape(new_shape)


def unflatten(x: jnp.ndarray, dim: int, sizes: tuple[int, ...]):
    """Split dim ``dim`` into ``sizes``, like ``torch.unflatten``."""
    new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
    return x.reshape(new_shape)


def check_groups(groups: jnp.ndarray, scales: jnp.ndarray, dim: int):
    """Check whether ``scales`` reproduces ``groups`` exactly after a round trip.

    Returns a boolean mask (True where quantize -> dequantize is exact within
    1e-6) and the maximum absolute reconstruction error along ``dim``.
    """
    inv_scale = 1.0 / scales.clip(1e-12)
    q_group = jnp.round(groups * inv_scale)
    max_diff = jnp.abs(q_group * scales - groups).max(dim, keepdims=True)
    return max_diff < 1e-6, max_diff
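
# E.g. a group [0.2, -0.1, 0.3] with scale 0.1 rounds to [2, -1, 3] and
# reconstructs exactly, so its mask entry is True; with scale 0.15 the
# residual on 0.2 is 0.05 > 1e-6 and the entry is False.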


def find_scales(w: jnp.ndarray, dim: int, pbar: bool = True):
    """Recover the per-group quantization scale along ``dim`` (group size 32).

    Brute-force search: for q = 15..1 levels, try ``scale = group_range / q``
    and keep the largest scale that reproduces every weight exactly. This only
    succeeds if the weights already lie on such a grid.
    """
    w = unflatten(w, dim, (-1, 32))
    group_range = w.max(dim + 1, keepdims=True) - w.min(dim + 1, keepdims=True)

    scales = np.zeros_like(group_range)
    for q in tqdm(range(15, 0, -1), disable=not pbar):
        try_scale = group_range / q
        ok, _ = check_groups(w, try_scale, dim + 1)
        scales[ok] = try_scale[ok]

    # Every group must have found an exact scale; otherwise the checkpoint is
    # not int4-quantized and the conversion would be lossy.
    ok, _ = check_groups(w, scales, dim + 1)
    assert ok.all()

    return scales.squeeze(dim + 1)
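
# Worked example (hypothetical values): a group containing only multiples of
# 0.1 with range 0.3, say {-0.1, 0.0, 0.1, 0.2}, is matched at q = 3
# (scale = 0.3 / 3 = 0.1); the larger candidates 0.3 (q = 1) and 0.15 (q = 2)
# fail the exactness check, so 0.1 is the final (largest exact) scale kept.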


def convert_to_hf(params):
    """Remap the Gemma Flax parameter tree to Hugging Face ``state_dict`` naming."""
    state_dict = dict()

    state_dict["model.embed_tokens.weight"] = params["embedder"]["input_embedding"]
    state_dict["model.norm.weight"] = params["final_norm"]["scale"]

    layer_idx = 0
    while f"layer_{layer_idx}" in params:
        prefix = f"model.layers.{layer_idx}."
        layer_params = params[f"layer_{layer_idx}"]
        state_dict[f"{prefix}input_layernorm.weight"] = layer_params["pre_attention_norm"]["scale"]
        state_dict[f"{prefix}post_attention_layernorm.weight"] = layer_params["post_attention_norm"]["scale"]
        state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = layer_params["pre_ffw_norm"]["scale"]
        state_dict[f"{prefix}post_feedforward_layernorm.weight"] = layer_params["post_ffw_norm"]["scale"]

        prefix = f"model.layers.{layer_idx}.self_attn."
        attn_params = layer_params["attn"]
        state_dict[f"{prefix}q_norm.weight"] = attn_params["_query_norm"]["scale"]
        state_dict[f"{prefix}k_norm.weight"] = attn_params["_key_norm"]["scale"]

        # The einsum kernels are stored head-major; transpose so head_dim sits
        # next to the heads, then flatten the leading two dims into the
        # (num_heads * head_dim, embed) HF projection layout.
        state_dict[f"{prefix}q_proj.weight"] = flatten(attn_params["q_einsum"]["w"].transpose(0, 2, 1), end=1)
        state_dict[f"{prefix}k_proj.weight"] = flatten(attn_params["kv_einsum"]["w"][0].transpose(0, 2, 1), end=1)
        state_dict[f"{prefix}v_proj.weight"] = flatten(attn_params["kv_einsum"]["w"][1].transpose(0, 2, 1), end=1)

        # The output projection is already (num_heads, head_dim, embed); flatten
        # and transpose to the (embed, num_heads * head_dim) HF layout.
        state_dict[f"{prefix}o_proj.weight"] = flatten(attn_params["attn_vec_einsum"]["w"], end=1).T

        prefix = f"model.layers.{layer_idx}.mlp."
        mlp_params = layer_params["mlp"]
        # gating_einsum stacks the gate and up kernels; in this checkpoint
        # layout they already match HF's (intermediate, hidden) shape, while
        # `linear` needs a transpose to become down_proj.
        state_dict[f"{prefix}gate_proj.weight"] = mlp_params["gating_einsum"][0]
        state_dict[f"{prefix}up_proj.weight"] = mlp_params["gating_einsum"][1]
        state_dict[f"{prefix}down_proj.weight"] = mlp_params["linear"].T

        layer_idx += 1

    return state_dict
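
# Illustration of the attention reshape above, with hypothetical Gemma 3 4B
# shapes: q_einsum["w"] is (num_heads=8, embed=2560, head_dim=256);
# transpose(0, 2, 1) gives (8, 256, 2560) and flatten(..., end=1) yields the
# (2048, 2560) = (num_heads * head_dim, hidden_size) HF q_proj.weight layout.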


def convert_awq(state_dict: dict[str, jnp.ndarray]):
    """Quantize every 2-D weight to AWQ int4 (group size 32) and pack it."""
    awq_state_dict = dict()

    for k, v in tqdm(state_dict.items(), total=len(state_dict)):
        # Embeddings and norm scales stay unquantized, in bfloat16.
        if k == "model.embed_tokens.weight" or v.ndim == 1:
            awq_state_dict[k] = v.astype(jnp.bfloat16)
            continue

        assert v.ndim == 2
        # AWQ stores weights as (in_features, out_features).
        v = v.T

        v = np.asarray(v)
        K, N = v.shape
        scales = find_scales(v, dim=0, pbar=False)
        inv_scale = 1 / scales.clip(1e-12)
        qweight = np.round(v.reshape(K // 32, 32, N) * inv_scale[:, None])

        # Shift the signed int4 values [-8, 7] into the unsigned nibble range
        # [0, 15]; the fixed zero point of 8 below undoes this at dequantization.
        assert qweight.min() >= -8 and qweight.max() <= 7
        qweight = (qweight + 8).astype(np.uint32)

        # Pack eight 4-bit values into each int32 using the AWQ GEMM nibble
        # order [0, 2, 4, 6, 1, 3, 5, 7] (LSB first).
        qweight = qweight.reshape(K, N // 8, 8)
        qweight_packed = (
            (qweight[..., 7] << (7 * 4))
            | (qweight[..., 5] << (6 * 4))
            | (qweight[..., 3] << (5 * 4))
            | (qweight[..., 1] << (4 * 4))
            | (qweight[..., 6] << (3 * 4))
            | (qweight[..., 4] << (2 * 4))
            | (qweight[..., 2] << (1 * 4))
            | (qweight[..., 0] << (0 * 4))
        )
        qweight_packed = qweight_packed.view(np.int32).reshape(K, N // 8)

        prefix = k.removesuffix(".weight")
        awq_state_dict[f"{prefix}.qweight"] = qweight_packed
        # Zero point 8 in every nibble (0x8888_8888) emulates symmetric int4.
        awq_state_dict[f"{prefix}.qzeros"] = np.full((K // 32, N // 8), 0x8888_8888, dtype=np.uint32).view(np.int32)
        awq_state_dict[f"{prefix}.scales"] = jnp.asarray(scales).astype(jnp.bfloat16)

    return awq_state_dict
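
# Sanity-check sketch (not part of the conversion itself): `_unpack_awq` is a
# hypothetical helper that inverts the packing above, recovering the biased
# [0, 15] nibbles from the packed int32s so a round trip can be verified
# against the pre-packing qweight.
def _unpack_awq(qweight_packed: np.ndarray) -> np.ndarray:
    """(K, N // 8) int32 -> (K, N) uint32 values in [0, 15]."""
    order = [0, 2, 4, 6, 1, 3, 5, 7]  # nibble i (LSB first) holds element order[i]
    packed = qweight_packed.view(np.uint32)
    out = np.empty((*packed.shape, 8), dtype=np.uint32)
    for i, src in enumerate(order):
        out[..., src] = (packed >> (4 * i)) & 0xF
    return out.reshape(packed.shape[0], -1)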


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_dir", required=True, type=Path)
    parser.add_argument("--save_path", required=True, type=Path)
    args = parser.parse_args()

    params = gm.ckpts.load_params(args.ckpt_dir.absolute())
    state_dict = convert_to_hf(params)
    awq_state_dict = convert_awq(state_dict)
    args.save_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(awq_state_dict, args.save_path)
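
# Example invocation (script name and paths are placeholders):
#   python convert_qat_to_awq.py --ckpt_dir /path/to/gemma_ckpt --save_path out/model.safetensors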