File size: 3,159 Bytes
03f32f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import annotations

from pathlib import Path

import numpy as np
from ml_dtypes import bfloat16


def _unique_paths(paths):
    seen = set()
    unique = []
    for path in paths:
        path = Path(path)
        key = str(path)
        if key not in seen:
            seen.add(key)
            unique.append(path)
    return unique


def resolve_runtime_file(base_dir: str | Path, candidates: list[str]) -> Path:
    base_path = Path(base_dir)
    checked = []
    for name in candidates:
        if not name:
            continue
        candidate = base_path / name
        checked.append(candidate)
        if candidate.exists():
            return candidate

    checked_msg = ", ".join(path.name for path in checked) if checked else "<none>"
    raise FileNotFoundError(f"Cannot find any runtime file in {base_path}. Checked: {checked_msg}")


def _has_text_runtime_files(base_dir: Path) -> bool:
    if not base_dir.exists() or not base_dir.is_dir():
        return False
    has_layers = any(base_dir.glob("*_p*_l0_together.axmodel"))
    has_post = any(base_dir.glob("*_post.axmodel"))
    return has_layers and has_post


def default_axmodel_path(script_dir: str | Path) -> str:
    """Pick the default model directory relative to ``script_dir``.

    The script directory itself is tried first, followed by a fixed list of
    sub-directory names. The first location for which
    ``_has_text_runtime_files`` returns true is chosen.

    Args:
        script_dir: Directory to search from.

    Returns:
        The chosen directory as a string. When no location qualifies, the
        first search location (``script_dir`` itself) is returned.
    """
    root = Path(script_dir)
    subdir_names = (
        "gemma_4_e2b_it_ax650n_axmodel",
        "gemma_4_e2b_it_ax650n_w4a16_axmodel",
        "gemma-4-E2B-it_axmodel",
    )
    search_order = _unique_paths([root, *(root / name for name in subdir_names)])
    match = next(
        (location for location in search_order if _has_text_runtime_files(location)),
        None,
    )
    chosen = search_order[0] if match is None else match
    return str(chosen)


def load_text_embeddings(base_dir: str | Path, config):
    """Open the text embedding table stored under ``base_dir``.

    Candidate files are tried in this order, and the first one that is an
    existing regular file with a recognised suffix is used:

    1. ``model.embed_tokens.weight.npy``
    2. the name in ``config.filename_tokens_embed`` (only if set and non-empty)
    3. ``model.embed_tokens.weight.bfloat16.bin``

    ``.npy`` files are opened with ``np.load(..., mmap_mode="r")``. ``.bin``
    files are memory-mapped read-only as ``bfloat16`` with shape
    ``(vocab_size, hidden_size)``. When ``config.vocab_size`` is missing or
    not positive, the row count is inferred from the file size. An existing
    candidate whose suffix is neither ``.npy`` nor ``.bin`` is skipped.

    Args:
        base_dir: Directory containing the embedding file.
        config: Object exposing ``hidden_size`` and, optionally,
            ``vocab_size`` and ``filename_tokens_embed``.

    Returns:
        The array returned by ``np.load`` or ``np.memmap``.

    Raises:
        FileNotFoundError: If no usable candidate file is found.
        ValueError: If a ``.bin`` file is selected and ``hidden_size`` is not
            positive, or the file size is not a whole number of rows when the
            vocabulary size has to be inferred.
    """
    runtime_path = Path(base_dir)

    # An empty configured name must be left out: ``runtime_path / ""`` is the
    # runtime directory itself, which would then be probed as a weights file
    # and listed in the "Checked:" error message.
    configured_name = getattr(config, "filename_tokens_embed", "") or ""
    names = ["model.embed_tokens.weight.npy"]
    if configured_name:
        names.append(configured_name)
    names.append("model.embed_tokens.weight.bfloat16.bin")
    candidates = _unique_paths([runtime_path / name for name in names])

    for candidate in candidates:
        # is_file() rather than exists(): a directory with a matching name
        # cannot be loaded and should not shadow later candidates.
        if not candidate.is_file():
            continue
        if candidate.suffix == ".npy":
            return np.load(str(candidate), mmap_mode="r")
        if candidate.suffix == ".bin":
            hidden_size = int(config.hidden_size)
            if hidden_size <= 0:
                # Fail with a clear message instead of a ZeroDivisionError
                # in the shape inference below.
                raise ValueError(
                    f"Cannot map embeddings from {candidate}: "
                    f"hidden_size must be positive, got {hidden_size}"
                )
            vocab_size = int(getattr(config, "vocab_size", 0) or 0)
            if vocab_size <= 0:
                # No usable vocab size in the config: derive the row count
                # from the file size, which must be a whole number of rows.
                file_size = candidate.stat().st_size
                bytes_per_row = np.dtype(bfloat16).itemsize * hidden_size
                if file_size % bytes_per_row != 0:
                    raise ValueError(
                        f"Cannot infer embedding shape from {candidate}: "
                        f"file_size={file_size}, hidden_size={hidden_size}, bytes_per_row={bytes_per_row}"
                    )
                vocab_size = file_size // bytes_per_row
            shape = (vocab_size, hidden_size)
            return np.memmap(str(candidate), dtype=bfloat16, mode="r", shape=shape)

    checked = ", ".join(path.name for path in candidates)
    raise FileNotFoundError(
        f"Cannot find text embedding weights in {runtime_path}. Checked: {checked}"
    )