"""Helpers for locating runtime model files and loading text embedding weights."""

from __future__ import annotations

from pathlib import Path

import numpy as np
from ml_dtypes import bfloat16


def _unique_paths(paths):
    """Return *paths* as ``Path`` objects with duplicates removed.

    Two entries are duplicates when their string forms are equal. The first
    occurrence wins and the original order is preserved.
    """
    unique: dict[str, Path] = {}
    for raw in paths:
        path = Path(raw)
        # Plain dicts keep insertion order, so setdefault keeps the first hit.
        unique.setdefault(str(path), path)
    return list(unique.values())


def resolve_runtime_file(base_dir: str | Path, candidates: list[str]) -> Path:
    """Return the first existing ``base_dir / name`` for *name* in *candidates*.

    Empty candidate names are ignored.

    Raises:
        FileNotFoundError: If none of the non-empty candidates exists. The
            message lists the file names that were checked.
    """
    base_path = Path(base_dir)
    checked = []
    for name in candidates:
        if not name:
            continue
        candidate = base_path / name
        if candidate.exists():
            return candidate
        checked.append(candidate)
    checked_msg = ", ".join(path.name for path in checked)
    raise FileNotFoundError(f"Cannot find any runtime file in {base_path}. Checked: {checked_msg}")


def _has_text_runtime_files(base_dir: Path) -> bool:
    """Tell whether *base_dir* is a directory holding the text runtime files.

    That means at least one file matching ``*_p*_l0_together.axmodel`` and at
    least one matching ``*_post.axmodel``.
    """
    # is_dir() is already False for a path that does not exist.
    if not base_dir.is_dir():
        return False
    has_layers = any(base_dir.glob("*_p*_l0_together.axmodel"))
    has_post = any(base_dir.glob("*_post.axmodel"))
    return has_layers and has_post


def default_axmodel_path(script_dir: str | Path) -> str:
    """Pick the default model directory relative to *script_dir*.

    *script_dir* itself and a fixed list of sub-directories are probed in
    order; the first one that contains the text runtime files is returned.
    When none qualifies, *script_dir* is returned unchanged (as a string).
    """
    script_dir = Path(script_dir)
    candidates = _unique_paths(
        [
            script_dir,
            script_dir / "gemma_4_e2b_it_ax650n_axmodel",
            script_dir / "gemma_4_e2b_it_ax650n_w4a16_axmodel",
            script_dir / "gemma-4-E2B-it_axmodel",
        ]
    )
    for candidate in candidates:
        if _has_text_runtime_files(candidate):
            return str(candidate)
    return str(candidates[0])


def _memmap_bfloat16_embeddings(path: Path, config):
    """Memory-map the raw bfloat16 file *path* as a 2-D read-only array.

    The row width is ``config.hidden_size``. The row count is
    ``config.vocab_size`` when that is a positive number; otherwise it is
    derived from the file size.

    Raises:
        ValueError: If ``hidden_size`` is not positive, or if the row count
            must be inferred and the file size is not a whole number of rows.
    """
    hidden_size = int(config.hidden_size)
    if hidden_size <= 0:
        # Guard against a zero divisor in the shape inference below.
        raise ValueError(f"Invalid hidden_size={hidden_size} for embedding file {path}")

    vocab_size = int(getattr(config, "vocab_size", 0) or 0)
    if vocab_size <= 0:
        file_size = path.stat().st_size
        bytes_per_row = np.dtype(bfloat16).itemsize * hidden_size
        if file_size % bytes_per_row != 0:
            raise ValueError(
                f"Cannot infer embedding shape from {path}: "
                f"file_size={file_size}, hidden_size={hidden_size}, bytes_per_row={bytes_per_row}"
            )
        vocab_size = file_size // bytes_per_row

    shape = (vocab_size, hidden_size)
    return np.memmap(str(path), dtype=bfloat16, mode="r", shape=shape)


def load_text_embeddings(base_dir: str | Path, config):
    """Load the text embedding table found in *base_dir*.

    Files are probed in this order:

    1. ``model.embed_tokens.weight.npy``
    2. ``config.filename_tokens_embed`` (only when set and non-empty)
    3. ``model.embed_tokens.weight.bfloat16.bin``

    Args:
        base_dir: Directory that holds the embedding file.
        config: Object providing ``hidden_size`` and, optionally,
            ``vocab_size`` and ``filename_tokens_embed``.

    Returns:
        For a ``.npy`` file, the result of ``np.load(..., mmap_mode="r")``.
        For a ``.bin`` file, a read-only ``np.memmap`` of dtype ``bfloat16``
        with shape ``(vocab_size, hidden_size)``.

    Raises:
        FileNotFoundError: If no candidate is a regular file with a ``.npy``
            or ``.bin`` suffix.
        ValueError: If a ``.bin`` file's shape cannot be determined.
    """
    runtime_path = Path(base_dir)
    configured_name = getattr(config, "filename_tokens_embed", "") or ""
    names = [
        "model.embed_tokens.weight.npy",
        configured_name,
        "model.embed_tokens.weight.bfloat16.bin",
    ]
    # An empty configured name must be dropped: ``runtime_path / ""`` is the
    # directory itself, which would otherwise be probed as a weight file.
    candidates = _unique_paths(runtime_path / name for name in names if name)

    for candidate in candidates:
        # Only regular files can be loaded; a directory with a matching name
        # is treated the same as a missing file.
        if not candidate.is_file():
            continue
        if candidate.suffix == ".npy":
            return np.load(str(candidate), mmap_mode="r")
        if candidate.suffix == ".bin":
            return _memmap_bfloat16_embeddings(candidate, config)
        # Any other suffix is not a supported format; keep looking.

    checked = ", ".join(path.name for path in candidates)
    raise FileNotFoundError(
        f"Cannot find text embedding weights in {runtime_path}. Checked: {checked}"
    )