# gemma-3-1b-it-int4-awq / convert_flax.py
# Last commit 36e3234 ("add model") by gaunernst; file size 6.04 kB.
import argparse
from pathlib import Path
import jax.numpy as jnp
import numpy as np
from safetensors.flax import save_file
from tqdm import tqdm
from gemma import gm
def flatten(x: jnp.ndarray, start: int = 0, end: int = -1):
    """Merge the axes ``start`` .. ``end`` (both inclusive) of ``x`` into one axis.

    Negative axis indices count from the end, as in NumPy. A 0-d input is
    returned with shape ``(1,)``.

    Example: ``flatten(x, end=1)`` turns shape ``(a, b, c)`` into ``(a * b, c)``.

    Raises:
        ValueError: if the (normalized) axis range is empty or out of bounds.
    """
    if x.ndim == 0:
        return x.reshape(1)
    if start < 0:
        start += x.ndim
    if end < 0:
        end += x.ndim
    if not 0 <= start <= end < x.ndim:
        raise ValueError(f"invalid axis range start={start}, end={end} for a {x.ndim}-d array")
    # Compute the merged size explicitly rather than passing -1: reshape cannot
    # infer a -1 dimension when one of the kept dimensions is 0 (zero-size arrays).
    merged = 1
    for size in x.shape[start : end + 1]:
        merged *= size
    new_shape = x.shape[:start] + (merged,) + x.shape[end + 1 :]
    return x.reshape(new_shape)
def unflatten(x: jnp.ndarray, dim: int, sizes: tuple[int, ...]):
    """Split axis ``dim`` of ``x`` into consecutive axes of the given ``sizes``.

    This is the inverse of ``flatten``. ``dim`` may be negative (counted from the
    end, as in NumPy). One entry of ``sizes`` may be ``-1`` to let ``reshape``
    infer it.

    Example: ``unflatten(x, 0, (-1, 32))`` turns shape ``(K, N)`` into
    ``(K // 32, 32, N)``.
    """
    # Normalize a negative axis; otherwise x.shape[dim + 1:] would re-append the
    # whole shape (e.g. dim=-1 gives x.shape[0:]).
    if dim < 0:
        dim += x.ndim
    new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
    return x.reshape(new_shape)
def check_groups(groups: jnp.ndarray, scales: jnp.ndarray, dim: int):
    """Check whether ``scales`` reproduce ``groups`` under round-to-nearest quantization.

    Correct quantization parameters mean the quantization error is 0 (or close
    to 0). Each group is quantized as ``round(groups / scale)``, dequantized by
    multiplying with ``scale``, and compared against the original values.

    Shapes, as used by the caller: ``groups`` is e.g. ``(a, b, c, 32, d, e, f)``
    and ``scales`` is ``(a, b, c, 1, d, e, f)``, with ``dim`` the axis of size 32.

    Returns:
        ``(ok, max_diff)``: ``max_diff`` is the largest absolute reconstruction
        error along ``dim`` (kept as a size-1 axis), and ``ok`` is
        ``max_diff < 1e-6``.
    """
    # Clip only the divisor so zero scales do not divide by zero; the
    # reconstruction below still multiplies by the unclipped scale.
    safe_scales = scales.clip(1e-12)
    codes = jnp.round(groups * (1.0 / safe_scales))
    reconstruction_error = jnp.abs(codes * scales - groups)
    worst = reconstruction_error.max(dim, keepdims=True)
    return worst < 1e-6, worst
def find_scales(w: jnp.ndarray, dim: int, pbar: bool = True):
    """Recover per-group quantization scales of an already-quantized tensor.

    ``w`` is split into groups of 32 consecutive elements along ``dim``. For each
    group, candidate scales ``(max - min) / q`` for ``q = 15, 14, ..., 1`` are
    tried. A candidate is accepted when ``round(w / scale) * scale`` reproduces
    the group (see ``check_groups``). Later candidates (larger scales) overwrite
    earlier ones, so the largest exact candidate wins.

    Groups whose values are all equal have zero range, so every candidate is 0.
    Those groups get ``|value|`` as their scale: they quantize to +/-1, or stay at
    scale 0 when the group is all zeros.

    Args:
        w: tensor whose size along ``dim`` is a multiple of 32.
        dim: axis along which groups are formed; negative values count from the end.
        pbar: show a tqdm progress bar over the candidate ``q`` values.

    Returns:
        Array shaped like ``w`` except that axis ``dim`` is 32 times smaller.

    Raises:
        ValueError: if some group is not reproduced exactly by its chosen scale.
    """
    if dim < 0:
        dim += w.ndim
    w = unflatten(w, dim, (-1, 32))
    group_axis = dim + 1  # axis holding the 32 members of each group
    group_max = w.max(group_axis, keepdims=True)
    group_range = group_max - w.min(group_axis, keepdims=True)
    scales = np.zeros_like(group_range)
    for q in tqdm(range(15, 0, -1), disable=not pbar):
        try_scale = group_range / q
        ok, _ = check_groups(w, try_scale, group_axis)
        scales[ok] = try_scale[ok]
    # A constant non-zero group has zero range, so every candidate above is 0 and
    # fails the check. Its own magnitude is an exact scale (codes of +/-1). For an
    # all-zero group this assigns 0, which the loop already set.
    constant = np.asarray(group_range == 0)
    scales[constant] = np.abs(np.asarray(group_max))[constant]
    ok, _ = check_groups(w, scales, group_axis)
    # Explicit error instead of `assert`, which is stripped under `python -O`.
    if not ok.all():
        n_bad = int(np.sum(~np.asarray(ok)))
        raise ValueError(f"{n_bad} group(s) cannot be represented exactly by any candidate scale")
    return scales.squeeze(group_axis)
# convert to HF format first, then apply quantization
def _heads_to_rows(w):
    """Reshape a per-head projection (heads, hidden_size, head_dim) to (heads * head_dim, hidden_size)."""
    return flatten(w.transpose(0, 2, 1), end=1)


def convert_to_hf(params):
    """Re-key a Gemma Flax parameter tree into a Hugging Face style state dict.

    No quantization happens here. Tensors are only renamed, indexed, reshaped or
    transposed. Decoder layers are read from ``params["layer_0"]``,
    ``params["layer_1"]``, ... until the first missing index.

    Args:
        params: nested mapping of Flax parameters, e.g. as loaded with
            ``gm.ckpts.load_params``.

    Returns:
        dict mapping HF parameter names (``model.*``) to arrays.
    """
    # TODO: output projection (no lm_head entry is emitted)
    hf = {
        "model.embed_tokens.weight": params["embedder"]["input_embedding"],
        "model.norm.weight": params["final_norm"]["scale"],
    }

    # (HF norm name, Flax norm module name), in emission order.
    norm_names = (
        ("input_layernorm", "pre_attention_norm"),
        ("post_attention_layernorm", "post_attention_norm"),
        ("pre_feedforward_layernorm", "pre_ffw_norm"),
        ("post_feedforward_layernorm", "post_ffw_norm"),
    )

    idx = 0
    while (layer_key := f"layer_{idx}") in params:
        layer = params[layer_key]
        base = f"model.layers.{idx}."

        for hf_name, flax_name in norm_names:
            hf[base + hf_name + ".weight"] = layer[flax_name]["scale"]

        attn = layer["attn"]
        attn_base = base + "self_attn."
        hf[attn_base + "q_norm.weight"] = attn["_query_norm"]["scale"]
        hf[attn_base + "k_norm.weight"] = attn["_key_norm"]["scale"]
        hf[attn_base + "q_proj.weight"] = _heads_to_rows(attn["q_einsum"]["w"])
        # kv_einsum stacks the key projection at index 0 and the value projection at index 1.
        hf[attn_base + "k_proj.weight"] = _heads_to_rows(attn["kv_einsum"]["w"][0])
        hf[attn_base + "v_proj.weight"] = _heads_to_rows(attn["kv_einsum"]["w"][1])
        # (num_heads, head_dim, hidden_size) -> (hidden_size, num_heads * head_dim)
        hf[attn_base + "o_proj.weight"] = flatten(attn["attn_vec_einsum"]["w"], end=1).T

        mlp = layer["mlp"]
        mlp_base = base + "mlp."
        # NOTE(review): gate/up are taken without a transpose; the original
        # author flagged "may need to transpose?" — confirm against HF shapes.
        hf[mlp_base + "gate_proj.weight"] = mlp["gating_einsum"][0]
        hf[mlp_base + "up_proj.weight"] = mlp["gating_einsum"][1]
        hf[mlp_base + "down_proj.weight"] = mlp["linear"].T

        idx += 1
    return hf
# AWQ packs eight uint4 codes into one uint32. Entry i of this tuple is the
# column (within each run of 8 consecutive output columns) stored in nibble i,
# counting from the least significant bits. Read from high bits to low bits,
# the layout is [7 5 3 1 6 4 2 0].
_AWQ_PACK_ORDER = (0, 2, 4, 6, 1, 3, 5, 7)


def _pack_awq_uint4(codes: np.ndarray) -> np.ndarray:
    """Pack uint32 codes in [0, 15] shaped (..., 8) into int32 words shaped (...)."""
    packed = np.zeros(codes.shape[:-1], dtype=np.uint32)
    for nibble, col in enumerate(_AWQ_PACK_ORDER):
        packed |= codes[..., col] << np.uint32(4 * nibble)
    return packed.view(np.int32)


def convert_awq(state_dict: dict[str, jnp.ndarray]):
    """Quantize the 2-D weights of an HF-layout state dict into AWQ INT4 tensors.

    The token embedding and all rank-1 tensors are only cast to bfloat16. Every
    other tensor ``<name>.weight`` of shape ``(N, K)`` is replaced by:

    - ``<name>.qweight``: int32 ``(K, N // 8)``, eight 4-bit codes per word;
    - ``<name>.qzeros``: int32 ``(K // 32, N // 8)``, constant zero point 8;
    - ``<name>.scales``: bfloat16 ``(K // 32, N)``, one scale per 32 rows of K.

    Raises:
        AssertionError: if a quantized tensor is not 2-D.
        ValueError: if K is not a multiple of 32, N is not a multiple of 8, or a
            quantized code falls outside the INT4 range [-8, 7].
    """
    awq_state_dict = dict()
    for k, v in tqdm(state_dict.items(), total=len(state_dict)):
        # AWQ doesn't support INT4 embeddings; rank-1 tensors stay unquantized too.
        if k == "model.embed_tokens.weight" or v.ndim == 1:
            awq_state_dict[k] = v.astype(jnp.bfloat16)
            continue

        assert v.ndim == 2, f"{k}: expected a 2-D weight, got shape {v.shape}"
        v = v.T  # AWQ transposes the weight relative to the HF layout
        # use numpy since jnp is very slow, likely due to bad memory management on CUDA
        v = np.asarray(v)
        K, N = v.shape
        if K % 32 != 0 or N % 8 != 0:
            raise ValueError(f"{k}: transposed shape {(K, N)} needs K % 32 == 0 and N % 8 == 0")

        scales = find_scales(v, dim=0, pbar=False)  # (K/32, N)
        inv_scale = 1 / scales.clip(1e-12)
        qweight = np.round(v.reshape(K // 32, 32, N) * inv_scale[:, None])

        # AWQ is actually UINT4 (instead of INT4), so codes are shifted up by 8 and
        # zero_point = 8 is stored (even though Google AQT only uses [-7, 7]).
        # A code outside [-8, 7] would not fit in a nibble: it would bleed into the
        # neighbouring nibble (or wrap on the uint32 cast) and silently corrupt the
        # packed word, so reject it. The `not (...)` form also rejects NaN.
        lo, hi = float(qweight.min()), float(qweight.max())
        if not (lo >= -8 and hi <= 7):
            raise ValueError(f"{k}: quantized codes span [{lo}, {hi}], outside the INT4 range [-8, 7]")
        codes = (qweight + 8).astype(np.uint32).reshape(K, N // 8, 8)

        prefix = k.removesuffix(".weight")
        awq_state_dict[f"{prefix}.qweight"] = _pack_awq_uint4(codes)
        # 0x88888888 = zero point 8 in each of the 8 nibbles.
        awq_state_dict[f"{prefix}.qzeros"] = np.full((K // 32, N // 8), 0x8888_8888, dtype=np.uint32).view(np.int32)
        awq_state_dict[f"{prefix}.scales"] = jnp.asarray(scales).astype(jnp.bfloat16)
    return awq_state_dict
def main():
    """Command-line entry point: load a Gemma Flax checkpoint, convert it to AWQ INT4 and save it.

    Reads ``--ckpt_dir`` with ``gm.ckpts.load_params``, renames the parameters
    to HF layout, quantizes them, and writes a safetensors file to
    ``--save_path``, creating parent directories as needed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--ckpt_dir", required=True, type=Path)
    cli.add_argument("--save_path", required=True, type=Path)
    opts = cli.parse_args()

    flax_params = gm.ckpts.load_params(opts.ckpt_dir.absolute())
    hf_weights = convert_to_hf(flax_params)
    quantized = convert_awq(hf_weights)

    out_path: Path = opts.save_path
    out_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(quantized, out_path)


if __name__ == "__main__":
    main()