# gemma-4-E2B-it/utils/runtime_layout.py
# yongqiang
# Align gemma4 runtime layout and refresh Python deployment docs
# 03f32f8
from __future__ import annotations
from pathlib import Path
import numpy as np
from ml_dtypes import bfloat16
def _unique_paths(paths):
seen = set()
unique = []
for path in paths:
path = Path(path)
key = str(path)
if key not in seen:
seen.add(key)
unique.append(path)
return unique
def resolve_runtime_file(base_dir: str | Path, candidates: list[str]) -> Path:
base_path = Path(base_dir)
checked = []
for name in candidates:
if not name:
continue
candidate = base_path / name
checked.append(candidate)
if candidate.exists():
return candidate
checked_msg = ", ".join(path.name for path in checked) if checked else "<none>"
raise FileNotFoundError(f"Cannot find any runtime file in {base_path}. Checked: {checked_msg}")
def _has_text_runtime_files(base_dir: Path) -> bool:
if not base_dir.exists() or not base_dir.is_dir():
return False
has_layers = any(base_dir.glob("*_p*_l0_together.axmodel"))
has_post = any(base_dir.glob("*_post.axmodel"))
return has_layers and has_post
def default_axmodel_path(script_dir: str | Path) -> str:
    """Choose the default axmodel directory relative to *script_dir*.

    *script_dir* itself is probed first, followed by three known
    subdirectory names. The first location for which
    ``_has_text_runtime_files`` is true is returned as a string. When none
    qualifies, *script_dir* is returned.
    """
    root = Path(script_dir)
    subdir_names = (
        "gemma_4_e2b_it_ax650n_axmodel",
        "gemma_4_e2b_it_ax650n_w4a16_axmodel",
        "gemma-4-E2B-it_axmodel",
    )
    search_order = _unique_paths([root, *(root / name for name in subdir_names)])
    # The first entry (script_dir itself) doubles as the fallback.
    chosen = next(filter(_has_text_runtime_files, search_order), search_order[0])
    return str(chosen)
def load_text_embeddings(base_dir: str | Path, config):
runtime_path = Path(base_dir)
candidates = _unique_paths(
[
runtime_path / "model.embed_tokens.weight.npy",
runtime_path / (getattr(config, "filename_tokens_embed", "") or ""),
runtime_path / "model.embed_tokens.weight.bfloat16.bin",
]
)
for candidate in candidates:
if not candidate.exists():
continue
if candidate.suffix == ".npy":
return np.load(str(candidate), mmap_mode="r")
if candidate.suffix == ".bin":
hidden_size = int(config.hidden_size)
vocab_size = int(getattr(config, "vocab_size", 0) or 0)
if vocab_size <= 0:
file_size = candidate.stat().st_size
bytes_per_row = np.dtype(bfloat16).itemsize * hidden_size
if file_size % bytes_per_row != 0:
raise ValueError(
f"Cannot infer embedding shape from {candidate}: "
f"file_size={file_size}, hidden_size={hidden_size}, bytes_per_row={bytes_per_row}"
)
vocab_size = file_size // bytes_per_row
shape = (vocab_size, hidden_size)
return np.memmap(str(candidate), dtype=bfloat16, mode="r", shape=shape)
checked = ", ".join(path.name for path in candidates)
raise FileNotFoundError(
f"Cannot find text embedding weights in {runtime_path}. Checked: {checked}"
)