MiniCPM-V-4.6 / python /infer_axmodel.py
yongqiang
Add AX650 MiniCPM-V 4.6 release package
aa969ca
import argparse
import atexit
import hashlib
import json
import os
import re
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
import numpy as np
from axengine import InferenceSession
from ml_dtypes import bfloat16
from transformers import AutoProcessor
def resolve_default_hf_model() -> str:
    """Pick the default tokenizer/config directory shipped next to this script.

    The candidates ``minicpm_v46_tokenizer``, ``MiniCPM-V-4.6`` and
    ``MiniCPM-V-4.6-GPTQ`` (relative to this file's directory) are tried in
    that order; the first one containing a ``config.json`` is returned. When
    none does, the first candidate's path is returned as a fallback.
    """
    script_dir = os.path.dirname(__file__)
    subdirs = ("minicpm_v46_tokenizer", "MiniCPM-V-4.6", "MiniCPM-V-4.6-GPTQ")
    found = next(
        (
            os.path.join(script_dir, sub)
            for sub in subdirs
            if os.path.exists(os.path.join(script_dir, sub, "config.json"))
        ),
        None,
    )
    return found if found is not None else os.path.join(script_dir, subdirs[0])


# Resolved once at import time; used as the default for --hf-model.
DEFAULT_HF_MODEL = resolve_default_hf_model()
def release_ax_inference_session(session):
    """Best-effort release of the native handle behind an axengine session.

    Looks up ``session._sess._unload``. If it is callable, the function calls
    it, printing a ``[WARN]`` line instead of raising when it fails. It then
    replaces ``_unload`` with a no-op so any later unload call (presumably
    axengine's own cleanup; not visible here) cannot release the handle twice.
    Sessions without that private hook are ignored.
    """
    inner = getattr(session, "_sess", None)
    unload_fn = getattr(inner, "_unload", None)
    if callable(unload_fn):
        try:
            unload_fn()
        except Exception as err:
            print(f"[WARN] Failed to unload axengine session cleanly: {err}")
        finally:
            try:
                inner._unload = lambda: None
            except Exception:
                # The hook could not be replaced; nothing more can be done here.
                pass
def bf16_zeros(shape: Sequence[int]) -> np.ndarray:
    """Return a zero-filled ``bfloat16`` array whose shape is ``tuple(shape)``."""
    dims = tuple(shape)
    return np.zeros(dims, dtype=bfloat16)
def dtype_from_axengine(dtype) -> np.dtype:
    """Map an axengine tensor dtype descriptor to the NumPy type used for buffers.

    The descriptor is matched on its lower-cased ``str()`` form, so dtype
    objects and plain strings both work. Several spellings are substrings of
    others (``"float16"`` in ``"bfloat16"``, ``"int32"`` in ``"uint32"``,
    ``"int8"`` in ``"uint8"``), so the checks must run from the most to the
    least specific name.

    Args:
        dtype: Any object whose string form names the element type.

    Returns:
        The matching scalar type (``bfloat16``, ``np.float32``, ...).

    Raises:
        ValueError: If the descriptor matches none of the supported types.
    """
    name = str(dtype).lower()
    if "bfloat16" in name or "bf16" in name:
        return bfloat16
    if "float32" in name or "fp32" in name:
        return np.float32
    if "float16" in name or "fp16" in name:
        return np.float16
    if "uint32" in name or "u32" in name:
        return np.uint32
    # Signed / wider integer inputs (e.g. int32 position indices) used to be
    # rejected outright, aborting model loading.
    if "int32" in name:
        return np.int32
    if "uint64" in name:
        return np.uint64
    if "int64" in name:
        return np.int64
    if "uint8" in name:
        return np.uint8
    if "int8" in name:
        return np.int8
    raise ValueError(f"Unsupported axengine dtype: {dtype}")
def tensor_digest(arr: np.ndarray) -> str:
    """Return a 16-hex-character SHA-256 fingerprint of an array's raw bytes.

    Only the element bytes (C order, as produced by ``tobytes()``) are hashed.
    Shape and dtype are not part of the digest. ``bfloat16`` data is first
    reinterpreted as ``uint16`` words.
    """
    data = np.asarray(arr)
    payload = data.view(np.uint16).tobytes() if data.dtype == bfloat16 else data.tobytes()
    return hashlib.sha256(payload).hexdigest()[:16]
def tensor_stats(arr: np.ndarray) -> str:
    """Summarise an array in one line for debug logging.

    Always reports ``finite=`` and the ``tensor_digest`` hash.
    - Finite arrays also report sum/max/min, computed in float32.
    - Non-finite arrays report NaN and Inf counts instead.
    - Empty arrays report ``size=0``; calling ``max()``/``min()`` on a
      zero-size array would raise.
    """
    raw = np.asarray(arr)
    arr32 = np.asarray(raw, dtype=np.float32)
    digest = tensor_digest(raw)
    if arr32.size == 0:
        return f"finite=True hash={digest} size=0"
    if np.isfinite(arr32).all():
        return (
            f"finite=True hash={digest} "
            f"sum={float(arr32.sum()):.6f} max={float(arr32.max()):.6f} min={float(arr32.min()):.6f}"
        )
    return (
        f"finite=False hash={digest} "
        f"nans={int(np.isnan(arr32).sum())} infs={int(np.isinf(arr32).sum())}"
    )
def ensure_finite(name: str, arr: np.ndarray):
    """Raise ``RuntimeError`` if *arr*, cast to float32, holds any NaN or Inf.

    *name* labels the tensor in the error message, which also carries the
    ``tensor_stats`` summary.
    """
    if np.isfinite(np.asarray(arr, dtype=np.float32)).all():
        return
    raise RuntimeError(f"{name} is non-finite: {tensor_stats(np.asarray(arr))}")
def load_text_config(hf_model: str) -> Tuple[dict, int]:
    """Read the language-model settings from ``<hf_model>/config.json``.

    Args:
        hf_model: Directory containing ``config.json``.

    Returns:
        ``(text_cfg, eos_token_id)``.
        - ``text_cfg`` is the ``text_config`` sub-dict. When that key is
          missing or empty, it is the whole config (text-only checkpoint).
        - ``eos_token_id`` comes from the top level, then from ``text_cfg``.
          For a list the first entry is used, and an empty list counts as
          missing. 248044 is the fallback when neither source gives an id.
          This is presumably MiniCPM-V-4.6's end token; verify against the
          tokenizer.

    Raises:
        OSError: If ``config.json`` cannot be opened.
        json.JSONDecodeError: If it is not valid JSON.
    """
    with open(os.path.join(hf_model, "config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)
    text_cfg = cfg.get("text_config") or cfg

    def _first_id(value):
        # Configs may store a single id or a list of ids; [] means "unset".
        if isinstance(value, list):
            return value[0] if value else None
        return value

    eos_token_id = _first_id(cfg.get("eos_token_id"))
    if eos_token_id is None:
        eos_token_id = _first_id(text_cfg.get("eos_token_id"))
    return text_cfg, int(eos_token_id if eos_token_id is not None else 248044)
@dataclass
class LayerFiles:
    """Paths of the compiled axmodels that make up one text decoder.

    Attributes:
        layer_paths: Per-layer axmodel paths, ordered by layer index.
        post_path: Path of the ``<prefix>_post.axmodel`` matching the layers.
    """

    layer_paths: List[str]
    post_path: str


def detect_layer_files(model_dir: str, max_layers: Optional[int] = None) -> LayerFiles:
    """Locate the per-layer and post axmodels inside *model_dir* (non-recursive).

    Two layer naming schemes are recognised:
    - ``<prefix>_p<prefill>_l<idx>_together.axmodel``
    - ``<prefix>_l<idx>.axmodel``

    Files are grouped by ``<prefix>``. The prefix with the most layer files
    wins; ties go to the alphabetically first file name.

    Args:
        model_dir: Directory to scan.
        max_layers: If given, keep only the first ``max_layers`` layers; must be >= 1.

    Returns:
        A ``LayerFiles`` with the layer paths sorted by index and the post path.

    Raises:
        ValueError: If ``max_layers`` < 1, or two files claim the same layer
            index for the chosen prefix (e.g. a together and a decode-only
            export side by side), which would otherwise load that layer twice.
        FileNotFoundError: If no layer file, or no matching post file, exists.
    """
    if max_layers is not None and max_layers < 1:
        raise ValueError(f"max_layers must be >= 1, got {max_layers}")
    # Sorted so prefix tie-breaking and the post lookup do not depend on the
    # arbitrary order os.listdir() returns.
    names = sorted(os.listdir(model_dir))
    layer_pattern = re.compile(r"^(?P<prefix>.*)_p(?P<prefill>\d+)_l(?P<idx>\d+)_together\.axmodel$")
    decode_layer_pattern = re.compile(r"^(?P<prefix>.*)_l(?P<idx>\d+)\.axmodel$")
    post_pattern = re.compile(r"^(?P<prefix>.*)_post\.axmodel$")

    prefix_map = {}  # prefix -> [(layer index, file name), ...]
    for name in names:
        m = layer_pattern.match(name) or decode_layer_pattern.match(name)
        if m is None:
            continue
        prefix_map.setdefault(m.group("prefix"), []).append((int(m.group("idx")), name))
    if not prefix_map:
        raise FileNotFoundError(f"No layer axmodel found under {model_dir}")

    prefix = max(prefix_map.items(), key=lambda kv: len(kv[1]))[0]
    layer_items = sorted(prefix_map[prefix], key=lambda it: it[0])
    seen = {}
    for idx, name in layer_items:
        if idx in seen:
            raise ValueError(
                f"Layer {idx} of prefix {prefix!r} matches several files under {model_dir}: "
                f"{seen[idx]}, {name}"
            )
        seen[idx] = name
    if max_layers is not None:
        layer_items = layer_items[:max_layers]
    layer_paths = [os.path.join(model_dir, name) for _, name in layer_items]

    post_name = None
    for name in names:
        m = post_pattern.match(name)
        if m and m.group("prefix") == prefix:
            post_name = name
            break
    if post_name is None:
        raise FileNotFoundError(f"No post axmodel found for prefix {prefix} under {model_dir}")
    return LayerFiles(layer_paths=layer_paths, post_path=os.path.join(model_dir, post_name))
class MiniCPMTextAxModelRunner:
    """Drive the MiniCPM-V-4.6 text decoder through per-layer AX650 axmodels.

    One axengine ``InferenceSession`` is opened per decoder layer.
    - Shape group 0 of each session is used for single-token decode.
    - Shape groups 1..N, when present, are used for chunked prefill.
    - A separate post axmodel, opened lazily, maps a hidden state to logits.

    Layers whose ``layer_types`` entry is not ``"full_attention"`` are treated
    as linear-attention layers. Their cache outputs replace the previous state
    wholesale, and during prefill they are replayed token by token through the
    decode graph.
    """

    def __init__(
        self,
        hf_model: str,
        axmodel_dir: str,
        embed_bin: Optional[str],
        max_layers: Optional[int],
        kv_cache_len: int = 255,
    ):
        """Load config, processor, embedding table and every layer session.

        Args:
            hf_model: Directory with ``config.json`` and the processor files.
            axmodel_dir: Directory holding the compiled layer and post axmodels.
            embed_bin: Raw bf16 embedding table. Defaults to
                ``model.embed_tokens.weight.bfloat16.bin`` in ``axmodel_dir``.
            max_layers: Load only the first N decoder layers (debugging aid).
            kv_cache_len: Decode KV cache length of the full-attention layers.

        Raises:
            FileNotFoundError: If the embedding bin or the layer files are missing.
            ValueError: If no decoder layer ends up selected (e.g. ``max_layers=0``).
        """
        self.hf_model = hf_model
        self.axmodel_dir = axmodel_dir
        self.text_cfg, self.eos_token_id = load_text_config(hf_model)
        self.hidden_size = int(self.text_cfg["hidden_size"])
        self.vocab_size = int(self.text_cfg["vocab_size"])
        self.kv_cache_len = int(kv_cache_len)
        self.layer_types = list(self.text_cfg.get("layer_types") or [])
        self.num_attention_heads = int(self.text_cfg["num_attention_heads"])
        self.num_key_value_heads = int(self.text_cfg["num_key_value_heads"])
        self.head_dim = int(self.text_cfg.get("head_dim") or (self.hidden_size // self.num_attention_heads))
        self.full_attn_kv_dim = self.num_key_value_heads * self.head_dim
        self.processor = AutoProcessor.from_pretrained(hf_model, trust_remote_code=True)
        self.layer_files = detect_layer_files(axmodel_dir, max_layers=max_layers)
        if not self.layer_files.layer_paths:
            raise ValueError(f"No decoder layer selected under {axmodel_dir} (max_layers={max_layers})")
        if embed_bin is None:
            candidate = os.path.join(axmodel_dir, "model.embed_tokens.weight.bfloat16.bin")
            if not os.path.exists(candidate):
                raise FileNotFoundError(
                    "Embedding bin not found under axmodel_dir, please pass --embed-bin explicitly"
                )
            embed_bin = candidate
        self.embed_bin = embed_bin
        # The table is stored as raw uint16 words and reinterpreted as bfloat16.
        self.embed_matrix = np.memmap(embed_bin, mode="r", dtype=np.uint16).view(bfloat16).reshape(
            self.vocab_size, self.hidden_size
        )
        # close() needs this bookkeeping before any session exists, so a failure
        # while opening layer N (or probing its I/O) still releases layers 0..N-1.
        self.decoder_sessions = []
        self.post_session = None
        self._closed = False
        atexit.register(self.close)
        try:
            for path in self.layer_files.layer_paths:
                self.decoder_sessions.append(InferenceSession(path))
            self._collect_layer_io_specs()
        except BaseException:
            self.close()
            raise

    def _collect_layer_io_specs(self):
        """Record decode (group 0) and prefill (groups 1..63) I/O specs per layer.

        Also derives the layer-0 shortcuts (``decode_input_shapes``,
        ``prefill_len``, ``hidden_dtype``, ...) used by the rest of the class.
        """
        self.layer_decode_input_shapes = []
        self.layer_decode_input_dtypes = []
        self.layer_prefill_input_shapes = []
        self.layer_prefill_input_dtypes = []
        self.layer_decode_output_names = []
        self.layer_prefill_output_names = []
        for layer_idx, session in enumerate(self.decoder_sessions):
            decode_inputs = list(session.get_inputs(shape_group=0))
            decode_input_shapes = {x.name: tuple(x.shape) for x in decode_inputs}
            decode_input_dtypes = {x.name: dtype_from_axengine(x.dtype) for x in decode_inputs}
            # Some decode-only AX650 exports hide `indices` / `mask` from get_inputs(),
            # but axengine still validates them as required runtime inputs.
            decode_input_shapes.setdefault("indices", (1, 1))
            decode_input_dtypes.setdefault("indices", np.uint32)
            decode_input_shapes.setdefault("mask", (1, 1))
            decode_input_dtypes.setdefault("mask", bfloat16)
            if layer_idx < len(self.layer_types) and self.layer_types[layer_idx] == "full_attention":
                # Full-attention caches are sized from kv_cache_len rather than
                # from the shape reported for group 0.
                kv_shape = (1, self.kv_cache_len, self.full_attn_kv_dim)
                decode_input_shapes["K_cache"] = kv_shape
                decode_input_shapes["V_cache"] = kv_shape
            self.layer_decode_input_shapes.append(decode_input_shapes)
            self.layer_decode_input_dtypes.append(decode_input_dtypes)
            self.layer_decode_output_names.append([x.name for x in session.get_outputs(shape_group=0)])
            prefill_shape_groups = []
            prefill_dtype_groups = []
            prefill_output_groups = []
            # The number of prefill groups is unknown; probe 1, 2, ... until
            # axengine (or dtype mapping) rejects one.
            for shape_group in range(1, 64):
                try:
                    prefill_inputs = list(session.get_inputs(shape_group=shape_group))
                    group_shapes = {x.name: tuple(x.shape) for x in prefill_inputs}
                    group_dtypes = {x.name: dtype_from_axengine(x.dtype) for x in prefill_inputs}
                    group_outputs = [x.name for x in session.get_outputs(shape_group=shape_group)]
                except Exception:
                    break
                # Append only complete groups so the three lists stay aligned.
                prefill_shape_groups.append(group_shapes)
                prefill_dtype_groups.append(group_dtypes)
                prefill_output_groups.append(group_outputs)
            self.layer_prefill_input_shapes.append(prefill_shape_groups)
            self.layer_prefill_input_dtypes.append(prefill_dtype_groups)
            self.layer_prefill_output_names.append(prefill_output_groups)
        self.decode_input_shapes = self.layer_decode_input_shapes[0]
        self.decode_input_dtypes = self.layer_decode_input_dtypes[0]
        self.prefill_input_shapes = (
            self.layer_prefill_input_shapes[0][0]
            if self.layer_prefill_input_shapes and self.layer_prefill_input_shapes[0]
            else {}
        )
        self.prefill_input_dtypes = (
            self.layer_prefill_input_dtypes[0][0]
            if self.layer_prefill_input_dtypes and self.layer_prefill_input_dtypes[0]
            else {}
        )
        # 0 means "no prefill graph": run_prefill refuses to run in that case.
        self.prefill_len = int(self.prefill_input_shapes["input"][1]) if "input" in self.prefill_input_shapes else 0
        self.decode_output_names = self.layer_decode_output_names[0]
        self.prefill_output_names = (
            self.layer_prefill_output_names[0][0]
            if self.layer_prefill_output_names and self.layer_prefill_output_names[0]
            else []
        )
        self.hidden_dtype = self.decode_input_dtypes["input"]

    def close(self):
        """Release all axengine sessions once; later calls are no-ops.

        Registered with ``atexit`` by ``__init__`` and called by ``main``.
        """
        if getattr(self, "_closed", False):
            return
        for session in getattr(self, "decoder_sessions", []):
            release_ax_inference_session(session)
        if getattr(self, "post_session", None) is not None:
            release_ax_inference_session(self.post_session)
        self.decoder_sessions = []
        self.post_session = None
        self._closed = True

    def tokenize_prompt(self, prompt: str) -> List[int]:
        """Wrap *prompt* in one user chat turn and return its token ids.

        Uses the processor's chat template with ``add_generation_prompt=True``.
        """
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="np",
        )
        return inputs["input_ids"][0].astype(np.int64).tolist()

    def decode_tokens(self, token_ids: Sequence[int]) -> str:
        """Detokenize *token_ids*, keeping special tokens and raw spacing."""
        return self.processor.decode(
            list(token_ids),
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )

    def embed_token(self, token_id: int) -> np.ndarray:
        """Return the embedding row of *token_id* as a ``(1, 1, hidden_size)`` array."""
        return np.asarray(self.embed_matrix[int(token_id)], dtype=self.hidden_dtype).reshape(1, 1, self.hidden_size)

    def alloc_layer_states(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Allocate zeroed per-layer K/V state buffers from the decode input specs."""
        k_states = [
            np.zeros(shapes["K_cache"], dtype=dtypes["K_cache"])
            for shapes, dtypes in zip(self.layer_decode_input_shapes, self.layer_decode_input_dtypes)
        ]
        v_states = [
            np.zeros(shapes["V_cache"], dtype=dtypes["V_cache"])
            for shapes, dtypes in zip(self.layer_decode_input_shapes, self.layer_decode_input_dtypes)
        ]
        return k_states, v_states

    @staticmethod
    def make_feed(shapes: dict, values: dict) -> dict:
        """Keep only the entries of *values* whose name is an input in *shapes*."""
        return {name: value for name, value in values.items() if name in shapes}

    @staticmethod
    def _write_cache_rows(
        cache: np.ndarray,
        start: int,
        rows: np.ndarray,
        label: str,
        count: Optional[int] = None,
    ):
        """Store *rows* into ``cache[:, start:start + count, :]``.

        ``count`` defaults to ``rows.shape[1]``. Raises ``RuntimeError`` when the
        range runs past the cache. Plain slicing would clip the destination and,
        for a single row, silently write nothing.
        """
        if count is None:
            count = rows.shape[1]
        end = start + count
        capacity = cache.shape[1]
        if end > capacity:
            raise RuntimeError(
                f"{label}: positions [{start}, {end}) exceed the KV cache capacity of {capacity}; "
                "shorten the prompt/generation or use axmodels built with a larger --kv-cache-len"
            )
        cache[:, start:end, :] = rows

    def is_linear_layer(self, layer_idx: int) -> bool:
        """True unless ``layer_types`` marks *layer_idx* as ``"full_attention"``."""
        return layer_idx >= len(self.layer_types) or self.layer_types[layer_idx] != "full_attention"

    def prefill_history_capacity(self, shapes: dict) -> int:
        """Return how many already-cached tokens a prefill shape group can attend to.

        Returns -1 for an empty spec.
        - For a 3-D ``mask`` the capacity is the mask width minus the input
          length.
        - Otherwise it is ``K_cache``'s second dimension.
        - Otherwise it is 0.
        """
        if not shapes:
            return -1
        input_len = int(shapes.get("input", (1, self.prefill_len))[1])
        mask_shape = shapes.get("mask")
        if mask_shape is not None and len(mask_shape) == 3:
            return max(0, int(mask_shape[-1]) - input_len)
        k_shape = shapes.get("K_cache")
        if k_shape is not None and len(k_shape) >= 2:
            return int(k_shape[1])
        return 0

    def select_prefill_shape_group(self, layer_idx: int, history_len: int) -> int:
        """Pick the 1-based prefill shape group for a chunk after *history_len* tokens.

        The choice proceeds as follows:
        - Group 1 is used when there is no history.
        - Otherwise the group with the smallest capacity that still fits
          *history_len* is used.
        - Failing that, the group with the largest capacity is used.

        Raises:
            RuntimeError: If the layer has no prefill group at all.
        """
        groups = self.layer_prefill_input_shapes[layer_idx]
        if not groups:
            raise RuntimeError(f"layer {layer_idx} has no prefill shape_group")
        if history_len <= 0:
            return 1
        candidates = []
        for offset, shapes in enumerate(groups):
            cap = self.prefill_history_capacity(shapes)
            if cap >= history_len:
                candidates.append((cap, offset + 1))
        if candidates:
            return min(candidates)[1]
        # Linear-attention layers may prune reusable warm groups. Reuse the
        # largest available group when its state shape is independent of history.
        return max(range(1, len(groups) + 1), key=lambda gid: self.prefill_history_capacity(groups[gid - 1]))

    def inspect(self):
        """Print runner configuration and decode/prefill/post I/O specs.

        Layer-0 I/O is printed in detail; cache and mask shapes are summarised
        across all layers.
        """
        print("hf_model:", self.hf_model)
        print("axmodel_dir:", self.axmodel_dir)
        print("embed_bin:", self.embed_bin)
        print("num_layers:", len(self.decoder_sessions))
        print("hidden_size:", self.hidden_size)
        print("vocab_size:", self.vocab_size)
        print("eos_token_id:", self.eos_token_id)
        print("kv_cache_len:", self.kv_cache_len)
        print("prefill_len:", self.prefill_len)
        print("decode_inputs:", sorted((k, v, str(self.decode_input_dtypes[k])) for k, v in self.decode_input_shapes.items()))
        print(
            "decode_outputs:",
            [(x.name, tuple(x.shape), str(x.dtype)) for x in self.decoder_sessions[0].get_outputs(shape_group=0)],
        )
        unique_k_shapes = sorted({tuple(spec["K_cache"]) for spec in self.layer_decode_input_shapes})
        unique_v_shapes = sorted({tuple(spec["V_cache"]) for spec in self.layer_decode_input_shapes})
        print("decode_k_cache_shapes:", unique_k_shapes)
        print("decode_v_cache_shapes:", unique_v_shapes)
        if self.prefill_input_shapes:
            print("prefill_group_count_layer0:", len(self.layer_prefill_input_shapes[0]))
            unique_prefill_k_shapes = sorted(
                {
                    tuple(spec["K_cache"])
                    for groups in self.layer_prefill_input_shapes
                    for spec in groups
                    if "K_cache" in spec
                }
            )
            unique_prefill_v_shapes = sorted(
                {
                    tuple(spec["V_cache"])
                    for groups in self.layer_prefill_input_shapes
                    for spec in groups
                    if "V_cache" in spec
                }
            )
            unique_prefill_mask_shapes = sorted(
                {
                    tuple(spec["mask"])
                    for groups in self.layer_prefill_input_shapes
                    for spec in groups
                    if "mask" in spec
                }
            )
            print("prefill_inputs_layer0:", sorted((k, v, str(self.prefill_input_dtypes[k])) for k, v in self.prefill_input_shapes.items()))
            print("prefill_k_cache_shapes:", unique_prefill_k_shapes)
            print("prefill_v_cache_shapes:", unique_prefill_v_shapes)
            print("prefill_mask_shapes:", unique_prefill_mask_shapes)
            print(
                "prefill_outputs:",
                [(x.name, tuple(x.shape), str(x.dtype)) for x in self.decoder_sessions[0].get_outputs(shape_group=1)],
            )
        else:
            print("prefill_inputs: []")
            print("prefill_outputs: []")
        post_session = self.get_post_session()
        print("post_inputs:", [(x.name, tuple(x.shape), str(x.dtype)) for x in post_session.get_inputs()])
        print("post_outputs:", [(x.name, tuple(x.shape), str(x.dtype)) for x in post_session.get_outputs()])

    def get_post_session(self):
        """Return the post axmodel session, opening it on first use."""
        if self.post_session is None:
            self.post_session = InferenceSession(self.layer_files.post_path)
        return self.post_session

    def run_prefill(self, token_ids: Sequence[int], verbose: bool = False, return_states: bool = False):
        """Consume the prompt chunk by chunk and return the last token's hidden state.

        The prompt is split into ``prefill_len``-token chunks, each zero-padded
        to the graph length.
        - Full-attention layers run the prefill shape group chosen by
          ``select_prefill_shape_group`` for the history cached so far.
        - Linear layers replay each chunk token by token through the decode
          graph.

        Args:
            token_ids: Prompt token ids (at least one).
            verbose: Print per-layer tensor statistics.
            return_states: Also return the K/V state lists for further decoding.

        Returns:
            ``last_hidden`` (the final layer output sliced to the last real
            token, ``[:, chunk_len - 1:chunk_len, :]``), or
            ``(k_states, v_states, last_hidden)`` when ``return_states`` is True.

        Raises:
            RuntimeError: If the axmodels expose no prefill shape group, or the
                prompt does not fit in a full-attention KV cache.
            ValueError: If ``token_ids`` is empty.
        """
        if self.prefill_len <= 0:
            raise RuntimeError(
                "The decoder axmodels expose no prefill shape_group; use the decode_replay prompt mode instead"
            )
        if len(token_ids) == 0:
            raise ValueError("run_prefill needs at least one prompt token")
        k_states, v_states = self.alloc_layer_states()
        last_hidden = None
        for start in range(0, len(token_ids), self.prefill_len):
            chunk_ids = token_ids[start : start + self.prefill_len]
            chunk_len = len(chunk_ids)
            data = np.zeros((1, self.prefill_len, self.hidden_size), dtype=self.hidden_dtype)
            data[0, :chunk_len, :] = np.asarray(
                self.embed_matrix[np.asarray(chunk_ids, dtype=np.int64)], dtype=self.hidden_dtype
            )
            for layer_idx, session in enumerate(self.decoder_sessions):
                if self.is_linear_layer(layer_idx):
                    # Replay real tokens one at a time; padding rows stay zero.
                    out = np.zeros_like(data)
                    for j in range(chunk_len):
                        out[:, j : j + 1, :] = self.run_single_layer_decode_step(
                            layer_idx,
                            data[:, j : j + 1, :],
                            start + j,
                            k_states,
                            v_states,
                            verbose=False,
                        )
                    data = out
                    if verbose:
                        print(
                            f"prefill chunk={start // self.prefill_len} layer={layer_idx} "
                            f"linear_decode_replay tokens={chunk_len} {tensor_stats(data)}"
                        )
                    continue
                shape_group = self.select_prefill_shape_group(layer_idx, start)
                layer_shapes = self.layer_prefill_input_shapes[layer_idx][shape_group - 1]
                layer_dtypes = self.layer_prefill_input_dtypes[layer_idx][shape_group - 1]
                if not layer_shapes:
                    raise RuntimeError(f"layer {layer_idx} has no prefill shape_group={shape_group}")
                history_cap = self.prefill_history_capacity(layer_shapes)
                history_len = min(start, history_cap)
                indices = None
                if "indices" in layer_shapes:
                    indices = np.zeros(layer_shapes["indices"], dtype=layer_dtypes["indices"])
                    indices.reshape(-1)[:chunk_len] = np.arange(start, start + chunk_len, dtype=np.uint32)
                mask = None
                if "mask" in layer_shapes:
                    if self.layer_types[layer_idx] == "full_attention" and len(layer_shapes["mask"]) == 3:
                        # Query q sees the first history_len + q + 1 columns.
                        mask = np.full(layer_shapes["mask"], -65536.0, dtype=np.float32)
                        for q in range(chunk_len):
                            mask[:, q, : history_len + q + 1] = 0.0
                    else:
                        mask = np.zeros(layer_shapes["mask"], dtype=np.float32)
                        mask.reshape(-1)[:chunk_len] = 1.0
                    mask = mask.astype(layer_dtypes["mask"])
                k_feed = k_states[layer_idx]
                v_feed = v_states[layer_idx]
                if self.layer_types[layer_idx] == "full_attention":
                    # The prefill graph takes a history-sized cache, not the
                    # kv_cache_len-sized decode buffer.
                    k_feed = np.zeros(layer_shapes["K_cache"], dtype=layer_dtypes["K_cache"])
                    v_feed = np.zeros(layer_shapes["V_cache"], dtype=layer_dtypes["V_cache"])
                    if history_len > 0:
                        k_feed[:, :history_len, :] = k_states[layer_idx][:, :history_len, :]
                        v_feed[:, :history_len, :] = v_states[layer_idx][:, :history_len, :]
                outputs = session.run(
                    None,
                    self.make_feed(
                        layer_shapes,
                        {
                            "K_cache": k_feed,
                            "V_cache": v_feed,
                            **({"indices": indices} if indices is not None else {}),
                            "input": data.astype(layer_dtypes["input"], copy=False),
                            **({"mask": mask} if mask is not None else {}),
                        },
                    ),
                    shape_group=shape_group,
                )
                output_map = dict(zip(self.layer_prefill_output_names[layer_idx][shape_group - 1], outputs))
                k_out = output_map.get("K_cache_out")
                if k_out is not None:
                    if self.layer_types[layer_idx] == "full_attention":
                        self._write_cache_rows(
                            k_states[layer_idx],
                            start,
                            k_out[:, :chunk_len, :],
                            f"prefill layer {layer_idx} K_cache",
                            count=chunk_len,
                        )
                    else:
                        k_states[layer_idx] = k_out
                v_out = output_map.get("V_cache_out")
                if v_out is not None:
                    if self.layer_types[layer_idx] == "full_attention":
                        self._write_cache_rows(
                            v_states[layer_idx],
                            start,
                            v_out[:, :chunk_len, :],
                            f"prefill layer {layer_idx} V_cache",
                            count=chunk_len,
                        )
                    else:
                        v_states[layer_idx] = v_out
                data = output_map["output"]
                ensure_finite(f"prefill layer {layer_idx} output", data)
                if verbose:
                    print(
                        f"prefill chunk={start // self.prefill_len} layer={layer_idx} "
                        f"shape_group={shape_group} history_len={history_len} {tensor_stats(data)}"
                    )
            last_hidden = data[:, chunk_len - 1 : chunk_len, :]
        if return_states:
            return k_states, v_states, last_hidden
        return last_hidden

    def run_single_layer_decode_step(
        self,
        layer_idx: int,
        hidden: np.ndarray,
        position: int,
        k_states: List[np.ndarray],
        v_states: List[np.ndarray],
        verbose: bool = False,
    ) -> np.ndarray:
        """Run one token through one layer's decode graph (shape group 0).

        Updates ``k_states[layer_idx]`` / ``v_states[layer_idx]`` in place.
        - Full-attention layers returning a partial cache get the new rows
          written at *position*.
        - Any other cache output replaces the stored state.

        Raises:
            RuntimeError: If a full-attention write would fall outside the KV
                cache, or the layer output is non-finite.
        """
        session = self.decoder_sessions[layer_idx]
        layer_shapes = self.layer_decode_input_shapes[layer_idx]
        layer_dtypes = self.layer_decode_input_dtypes[layer_idx]
        # is_linear_layer() tolerates layer_types shorter than the layer count.
        is_full_attention = not self.is_linear_layer(layer_idx)
        indices = None
        if "indices" in layer_shapes:
            indices = np.zeros(layer_shapes["indices"], dtype=layer_dtypes["indices"])
            indices.reshape(-1)[0] = position
        mask = None
        if "mask" in layer_shapes:
            if is_full_attention and len(layer_shapes["mask"]) == 3:
                # Unmask the cached past (clamped to the mask width) and the
                # last column.
                mask = np.full(layer_shapes["mask"], -65536.0, dtype=np.float32)
                valid_past = min(position, layer_shapes["mask"][-1] - 1)
                if valid_past > 0:
                    mask[:, :, :valid_past] = 0.0
                mask[:, :, -1:] = 0.0
            else:
                mask = np.ones(layer_shapes["mask"], dtype=np.float32)
            mask = mask.astype(layer_dtypes["mask"])
        outputs = session.run(
            None,
            self.make_feed(
                layer_shapes,
                {
                    "K_cache": k_states[layer_idx],
                    "V_cache": v_states[layer_idx],
                    **({"indices": indices} if indices is not None else {}),
                    "input": hidden,
                    **({"mask": mask} if mask is not None else {}),
                },
            ),
            shape_group=0,
        )
        output_map = dict(zip(self.layer_decode_output_names[layer_idx], outputs))
        k_out = output_map.get("K_cache_out")
        if k_out is not None:
            if is_full_attention and k_states[layer_idx].shape != k_out.shape:
                self._write_cache_rows(k_states[layer_idx], position, k_out, f"decode layer {layer_idx} K_cache")
            else:
                k_states[layer_idx] = k_out
        v_out = output_map.get("V_cache_out")
        if v_out is not None:
            if is_full_attention and v_states[layer_idx].shape != v_out.shape:
                self._write_cache_rows(v_states[layer_idx], position, v_out, f"decode layer {layer_idx} V_cache")
            else:
                v_states[layer_idx] = v_out
        out = output_map["output"]
        ensure_finite(f"decode step={position} layer={layer_idx} output", out)
        if verbose:
            print(f"decode step={position} layer={layer_idx} {tensor_stats(out)}")
        return out

    def run_decode_step(
        self,
        hidden: np.ndarray,
        position: int,
        k_states: List[np.ndarray],
        v_states: List[np.ndarray],
        verbose: bool = False,
    ) -> np.ndarray:
        """Run one token at *position* through every loaded layer and return the result."""
        data = hidden
        for layer_idx in range(len(self.decoder_sessions)):
            data = self.run_single_layer_decode_step(layer_idx, data, position, k_states, v_states, verbose=verbose)
        return data

    def decode_replay_prompt(
        self,
        token_ids: Sequence[int],
        limit_prompt_tokens: Optional[int] = None,
        verbose: bool = False,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]:
        """Feed the prompt one token at a time through the decode graphs.

        Args:
            token_ids: Prompt token ids.
            limit_prompt_tokens: If given, only this many leading tokens are used.
            verbose: Print per-layer tensor statistics.

        Returns:
            ``(k_states, v_states, last_hidden)`` after the last prompt token.

        Raises:
            ValueError: If no prompt token remains to replay.
        """
        if limit_prompt_tokens is not None:
            token_ids = token_ids[:limit_prompt_tokens]
        if len(token_ids) == 0:
            raise ValueError("decode_replay_prompt needs at least one prompt token")
        k_states, v_states = self.alloc_layer_states()
        last_hidden = None
        for pos, token_id in enumerate(token_ids):
            last_hidden = self.run_decode_step(self.embed_token(int(token_id)), pos, k_states, v_states, verbose=verbose)
        return k_states, v_states, last_hidden

    def run_post(self, hidden: np.ndarray) -> np.ndarray:
        """Run the post axmodel on *hidden* and return its first output (logits)."""
        logits = self.get_post_session().run(None, {"input": hidden})[0]
        ensure_finite("post logits", logits)
        return logits

    def greedy_next_token(self, hidden: np.ndarray) -> int:
        """Return the argmax token id of the logits produced from *hidden*."""
        logits = self.run_post(hidden)
        logits = np.asarray(logits, dtype=np.float32).reshape(-1)
        return int(np.argmax(logits))

    def generate(self, token_ids: Sequence[int], max_new_tokens: int, prompt_mode: str, verbose: bool = False):
        """Greedily generate up to *max_new_tokens* tokens and print them.

        The prompt is consumed with ``run_prefill`` when *prompt_mode* is
        ``"prefill"``, and with ``decode_replay_prompt`` otherwise. Generation
        stops early at ``eos_token_id``. No decode pass runs after the final
        token, because its output would never be used.
        """
        if prompt_mode == "prefill":
            k_states, v_states, last_hidden = self.run_prefill(token_ids, verbose=verbose, return_states=True)
        else:
            k_states, v_states, last_hidden = self.decode_replay_prompt(token_ids, verbose=verbose)
        prompt_len = len(token_ids)
        generated = []
        for step in range(max_new_tokens):
            next_token = self.greedy_next_token(last_hidden)
            generated.append(next_token)
            print(f"gen step={step} token_id={next_token} piece={self.decode_tokens([next_token])!r}")
            if next_token == self.eos_token_id or step == max_new_tokens - 1:
                break
            last_hidden = self.run_decode_step(
                self.embed_token(next_token),
                prompt_len + step,
                k_states,
                v_states,
                verbose=verbose,
            )
        print("generated_ids:", generated)
        print("generated_text:", self.decode_tokens(generated))
def parse_args():
    """Parse and sanity-check the command line.

    Rejects the following (argparse prints usage and exits with status 2):
    - non-positive ``--max-layers``, ``--limit-prompt-tokens`` and
      ``--kv-cache-len``;
    - a negative ``--max-new-tokens``.

    These values would otherwise fail much later, with slicing or shape errors
    far from their cause.
    """
    parser = argparse.ArgumentParser(
        description="MiniCPM-V-4.6 AX650 text-only axmodel Python runner/debugger"
    )
    parser.add_argument("--hf-model", default=DEFAULT_HF_MODEL, help="Tokenizer/config path for AutoProcessor")
    parser.add_argument("--axmodel-dir", required=True, help="Compiled axmodel directory")
    parser.add_argument("--embed-bin", default=None, help="Embedding bf16 bin path")
    parser.add_argument(
        "--mode",
        default="inspect",
        choices=["inspect", "prefill", "decode_replay", "generate"],
        help="Execution mode",
    )
    parser.add_argument("--prompt", default="你好,请做一个简短自我介绍。", help="User prompt")
    parser.add_argument("--prompt-file", default=None, help="Read user prompt from a UTF-8 text file")
    parser.add_argument("--max-layers", type=int, default=None, help="Only load the first N decoder layers")
    parser.add_argument(
        "--limit-prompt-tokens",
        type=int,
        default=None,
        help="Only consume the first N prompt tokens in decode_replay mode",
    )
    parser.add_argument(
        "--prompt-mode",
        default="decode_replay",
        choices=["decode_replay", "prefill"],
        help="How to consume the prompt before generation",
    )
    parser.add_argument("--max-new-tokens", type=int, default=16, help="Generation length for --mode generate")
    parser.add_argument("--kv-cache-len", type=int, default=255, help="Decode KV cache length used at compile time")
    parser.add_argument("--verbose", action="store_true", help="Print per-layer tensor stats")
    args = parser.parse_args()
    for option, value in (
        ("--max-layers", args.max_layers),
        ("--limit-prompt-tokens", args.limit_prompt_tokens),
        ("--kv-cache-len", args.kv_cache_len),
    ):
        if value is not None and value < 1:
            parser.error(f"{option} must be >= 1, got {value}")
    if args.max_new_tokens < 0:
        parser.error(f"--max-new-tokens must be >= 0, got {args.max_new_tokens}")
    return args
def main():
    """Command-line entry point.

    Tokenizes the prompt and prints it, then runs the selected ``--mode``:
    - ``inspect``: print model/IO metadata;
    - ``prefill`` / ``decode_replay``: run the prompt once and report the last
      hidden state, the logits and the greedy next token;
    - ``generate``: greedy generation.

    The runner's sessions are always released on exit.
    """
    args = parse_args()
    if args.prompt_file:
        with open(args.prompt_file, "r", encoding="utf-8") as f:
            args.prompt = f.read()
    runner = MiniCPMTextAxModelRunner(
        hf_model=args.hf_model,
        axmodel_dir=args.axmodel_dir,
        embed_bin=args.embed_bin,
        max_layers=args.max_layers,
        kv_cache_len=args.kv_cache_len,
    )
    try:
        token_ids = runner.tokenize_prompt(args.prompt)
        print("prompt_token_count:", len(token_ids))
        print("prompt_token_ids:", token_ids)
        print("prompt_template_repr:", runner.decode_tokens(token_ids).encode("unicode_escape").decode())
        if args.mode == "inspect":
            runner.inspect()
            return
        if args.mode in ("prefill", "decode_replay"):
            if args.mode == "prefill":
                hidden = runner.run_prefill(token_ids, verbose=args.verbose)
            else:
                _, _, hidden = runner.decode_replay_prompt(
                    token_ids,
                    limit_prompt_tokens=args.limit_prompt_tokens,
                    verbose=args.verbose,
                )
            print(f"{args.mode}_last_hidden:", tensor_stats(hidden))
            logits = runner.run_post(hidden)
            print("post_logits:", tensor_stats(logits))
            # Same argmax as runner.greedy_next_token(), but it reuses the logits
            # above instead of running the post axmodel a second time.
            print("greedy_next_token:", int(np.argmax(np.asarray(logits, dtype=np.float32).reshape(-1))))
            return
        runner.generate(
            token_ids,
            max_new_tokens=args.max_new_tokens,
            prompt_mode=args.prompt_mode,
            verbose=args.verbose,
        )
    finally:
        runner.close()


if __name__ == "__main__":
    main()