gemma-4-E2B-it / utils /extract_per_layer_weights.py
yongqiang
first commit
c6de0f4
"""Extract Gemma 4 per-layer weights from model.safetensors into standalone .npy files.
Usage:
python extract_per_layer_weights.py \
--input_path /path/to/gemma-4-E2B-it \
--output_path /path/to/axmodel_dir
Produces:
{output_path}/embed_tokens_per_layer.weight.npy (~4.4 GB, float16)
{output_path}/per_layer_model_projection.weight.npy (~26 MB, float32)
{output_path}/per_layer_projection_norm.weight.npy (<1 KB, float32)
"""
import argparse
from pathlib import Path
import numpy as np
from safetensors import safe_open
# Tensor names, relative to the checkpoint's key prefix, that extract() copies
# out of model.safetensors. Each one is written to "<name>.npy" in the output
# directory.
EXTRACTED_KEYS = [
    "embed_tokens_per_layer.weight",
    "per_layer_model_projection.weight",
    "per_layer_projection_norm.weight",
]
def _infer_prefix(safe_file) -> str:
key = next((k for k in safe_file.keys() if k.endswith("embed_tokens.weight")), None)
if key is None:
return "model.language_model."
return key[: -len("embed_tokens.weight")]
def extract(input_path: str, output_path: str):
    """Copy the per-layer tensors out of ``model.safetensors`` into ``.npy`` files.

    Every name in ``EXTRACTED_KEYS`` is looked up under the prefix reported by
    ``_infer_prefix`` and written to ``{output_path}/{name}.npy``. One summary
    line is printed per tensor, followed by ``Done.``.

    Args:
        input_path: Directory that contains ``model.safetensors``.
        output_path: Destination directory; created (with parents) if missing.

    Raises:
        FileNotFoundError: If ``model.safetensors`` is not in ``input_path``.
    """
    model_safetensors = Path(input_path) / "model.safetensors"
    if not model_safetensors.exists():
        raise FileNotFoundError(f"model.safetensors not found in {input_path}")

    out_dir = Path(output_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Use PyTorch framework to handle bfloat16 tensors, then convert to numpy
    with safe_open(str(model_safetensors), framework="pt") as f:
        prefix = _infer_prefix(f)
        for short_key in EXTRACTED_KEYS:
            full_key = prefix + short_key
            out_file = out_dir / f"{short_key}.npy"
            tensor = f.get_tensor(full_key)

            # Large embedding: save as float16 (2 bytes/element) to keep the
            # file small. Convert straight to half precision instead of going
            # through float32 first, which would allocate a temporary twice
            # the size of the source tensor.
            # NOTE(review): values outside the float16 range would become inf;
            # presumably the embedding stays well inside it -- confirm.
            # Projection/norm: save as float32 for direct use.
            if "embed_tokens_per_layer" in short_key:
                data = tensor.half().numpy()
            else:
                data = tensor.float().numpy()

            np.save(str(out_file), data)
            size_mb = out_file.stat().st_size / 1024 / 1024
            print(f" {full_key} -> {out_file.name} ({size_mb:.1f} MB, {data.dtype})")

    print("Done.")
if __name__ == "__main__":
    # Command-line entry point: both directories are mandatory options.
    cli = argparse.ArgumentParser(description="Extract Gemma 4 per-layer weights")
    cli.add_argument(
        "--input_path",
        required=True,
        help="HF model directory with model.safetensors",
    )
    cli.add_argument(
        "--output_path",
        required=True,
        help="axmodel output directory",
    )
    options = cli.parse_args()
    extract(options.input_path, options.output_path)