import torch
import numpy as np
from typing import List, Tuple, Optional, Dict
from pathlib import Path
from tqdm import tqdm
from axengine import InferenceSession
from ml_dtypes import bfloat16
from transformers import AutoTokenizer, AutoConfig
import json
from loguru import logger
class KVCacheTools:
"""
k, v cache 的本地保存和加载
"""
def __init__(self, axmodel_num: int, dtype=np.float32):
self.axmodel_num = axmodel_num
self.dtype = dtype
def save_kvcache(
self,
target_dir: str,
system_prompt: str,
precompute_len: int,
k_caches: List[np.ndarray],
v_caches: List[np.ndarray],
metadata: Optional[Dict] = None
) -> bool:
try:
target_path = Path(target_dir)
target_path.mkdir(parents=True, exist_ok=True)
for i, (k, v) in enumerate(zip(k_caches, v_caches)):
k.astype(self.dtype).tofile(target_path / f"k_cache_{i}.bin")
v.astype(self.dtype).tofile(target_path / f"v_cache_{i}.bin")
config = {
"precompute_len": precompute_len,
"system_prompt": system_prompt,
"axmodel_num": self.axmodel_num,
"dtype": str(self.dtype),
"metadata": metadata or {},
}
with open(target_path / "config.json", "w", encoding="utf8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
return True
except Exception as e:
print(f"Save failed: {str(e)}")
return False
def load_kvcache(
self,
cache_dir: str
    ) -> Tuple[
        Tuple[List[np.ndarray], List[np.ndarray]],
        str,
        int,
        Dict
    ]:
try:
cache_path = Path(cache_dir)
k_caches, v_caches = [], []
with open(cache_path / "config.json") as f:
config = json.load(f)
if config["axmodel_num"] != self.axmodel_num:
raise ValueError(
f"Model layer mismatch: "
f"Expected {self.axmodel_num}, got {config['axmodel_num']}"
)
for i in range(self.axmodel_num):
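                # NOTE: the trailing 256 below is a hard-coded per-token kv width;
                # it must match the kv dim of the axmodel that produced these caches.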
k_data = np.fromfile(cache_path / f"k_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
v_data = np.fromfile(cache_path / f"v_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
k_caches.append(k_data)
v_caches.append(v_data)
return (
(k_caches, v_caches),
config["system_prompt"],
config["precompute_len"],
config.get("metadata", {})
)
except Exception as e:
print(f"Load failed: {str(e)}")
            raise SystemExit(1)
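

# A minimal usage sketch for KVCacheTools (illustrative only: the directory path,
# layer count, and cache shapes below are assumptions, not values from this repo).
def _kvcache_roundtrip_example():
    tools = KVCacheTools(axmodel_num=2, dtype=np.float32)
    # Two fake per-layer caches with a hypothetical per-token kv width of 256,
    # matching the hard-coded reshape in load_kvcache above.
    k = [np.zeros((1, 8, 256), dtype=np.float32) for _ in range(2)]
    v = [np.zeros((1, 8, 256), dtype=np.float32) for _ in range(2)]
    tools.save_kvcache("./kvcache_demo", "demo system prompt", precompute_len=8,
                       k_caches=k, v_caches=v, metadata={"note": "example"})
    (k2, v2), sys_prompt, precompute_len, meta = tools.load_kvcache("./kvcache_demo")
    assert precompute_len == 8 and len(k2) == 2 and k2[0].shape == (1, 8, 256)
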
class InferManager:
def __init__(self, hf_model_path: str, axmodel_path: str):
self.device = "cpu"
self.hf_model_path = hf_model_path
self.axmodel_path = axmodel_path
self.hf_config = AutoConfig.from_pretrained(self.hf_model_path, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_path, trust_remote_code=True, use_fast=False)
self.system_prompt = "你的名字叫小智(allen), 你是一个人畜无害的 AI 助手. 深圳市今天(4月1日)阴天, 愚人节, 气温在 14°C 至 19°C 之间, 微风."
self.embeds = np.load(f"{self.axmodel_path}/model.embed_tokens.weight.npy")
def build_system_prompt(self):
messages = [
{"role": "system", "content": self.system_prompt},
# {"role": "user", "content": prompt}
]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
self.system_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
self.system_input_ids = self.system_inputs.input_ids[0].cpu().numpy().tolist()
self.system_input_embeds = np.take(self.embeds, self.system_input_ids, axis=0)
self.system_input_ids_len = len(self.system_input_ids)
self.model_inputs = {
"input_ids": self.system_input_ids,
"input_embeds": self.system_input_embeds,
"input_ids_len": self.system_input_ids_len
}
self.precompute_len = self.system_input_ids_len
# logger.info(f"system prompt prompt ids len: {self.system_input_ids_len}")
def encoder_prompt(self, prompt):
text = f'<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
input_ids = model_inputs.input_ids[0].cpu().numpy().tolist()
input_embeds = np.take(self.embeds, input_ids, axis=0)
input_ids_len = len(input_ids)
# logger.info(f"user prompt token_len: {input_ids_len}")
model_inputs = {
"message": text,
"model_inputs": model_inputs,
"input_ids": input_ids,
"input_embeds": input_embeds,
"input_ids_len": input_ids_len
}
return model_inputs
    def build_kvcache(self, kv_cache_len: int = 2559):
        """Allocate zero-initialised per-layer k/v caches of shape (1, kv_cache_len, kv_dim)."""
kv_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads * self.hf_config.num_key_value_heads
self.k_caches = [
np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
for _ in range(self.hf_config.num_hidden_layers)
]
self.v_caches = [
np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
for _ in range(self.hf_config.num_hidden_layers)
]
def get_kvcache(self):
return [self.k_caches, self.v_caches]
def update_kvcache(self, update_kv_cache):
self.k_caches = update_kv_cache[0]
self.v_caches = update_kv_cache[1]
def get_tokenizer(self):
return self.tokenizer
def get_system_prompt(self):
return self.system_prompt
def set_system_prompt(self, prompt):
self.system_prompt = prompt
    def build_infer_model(self):
self.prefill_decoder_sessins = []
for i in tqdm(range(self.hf_config.num_hidden_layers), desc="Init InferenceSession"):
session = InferenceSession(
f"{self.axmodel_path}/qwen2_p128_l{i}_together.axmodel"
)
self.prefill_decoder_sessins.append(session)
self.post_process_session = InferenceSession(
f"{self.axmodel_path}/qwen2_post.axmodel"
)
print("The models have been loaded!")
def get_infer_session(self):
return [self.prefill_decoder_sessins, self.post_process_session]
@staticmethod
def _top_p(probs: np.ndarray, p: float) -> np.ndarray:
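        # Nucleus (top-p) filtering: walk tokens in descending probability and,
        # once the running total has reached p, zero out every remaining token,
        # then renormalise. Toy example (illustrative values, not from this repo):
        #   probs = [0.5, 0.3, 0.15, 0.05], p = 0.8
        #   -> keeps 0.5 and 0.3, zeroes the rest, returns [0.625, 0.375, 0.0, 0.0]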
sorted_indices = np.argsort(probs)
filtered = probs.copy()
cumulative = 0
for idx in sorted_indices[::-1]:
if cumulative >= p:
filtered[idx] = 0
cumulative += filtered[idx]
return filtered / cumulative
@staticmethod
def _softmax(logits: np.ndarray) -> np.ndarray:
logits = logits - logits.max()
exp_logits = np.exp(logits)
return (exp_logits / np.sum(exp_logits)).astype(np.float64)
def post_process(self, logits, top_k=1, top_p=0.9, temperature=0.6):
logits = logits.astype(np.float32).flatten()
candidate_indices = np.argpartition(logits, -top_k)[-top_k:]
candidate_logits = logits[candidate_indices] / temperature
candidate_probs = self._softmax(candidate_logits)
candidate_probs = self._top_p(candidate_probs, top_p)
candidate_probs = candidate_probs.astype(np.float64) / candidate_probs.sum()
chosen_idx = np.random.multinomial(1, candidate_probs).argmax()
next_token = candidate_indices[chosen_idx]
return next_token, candidate_indices, candidate_probs
def gen_slice_indices(self, token_len, prefill=128, expand=128):
remaining = max(0, token_len - prefill)
extra_blocks = (remaining + expand - 1) // expand
return list(range(extra_blocks + 1))
def prefill(
self,
model_inputs,
slice_len=128,
        precompute_len=0,  # must be 0 when prefilling the system prompt
):
"""
Prefill step for chunked inference.
"""
token_ids = model_inputs["input_ids"]
token_embeds = model_inputs["input_embeds"]
token_len = model_inputs["input_ids_len"]
seq_len = len(token_ids)
        slice_indices = list(range(seq_len // slice_len + 1))
print(f"slice_indices: {slice_indices}")
# total_prefill_len = (
# slice_len * slice_indices[-1]
# if slice_indices[-1] != 0
# else slice_len
# )
# slice_indices = self.gen_slice_indices(seq_len)
total_prefill_len = slice_len * (slice_indices[-1] + 1)
kv_mask_expand_len = 128
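        # The prompt is processed in fixed-size chunks of `slice_len` tokens. For each
        # chunk, the already-filled part of the per-layer caches is fed back in as
        # K_cache/V_cache, and the chunk's new k/v entries are written into the caches
        # at offset precompute_len + slice_index * slice_len.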
if total_prefill_len > 0:
for slice_index in slice_indices:
if slice_index == 0:
current_slice_len = slice_len
else:
current_slice_len = kv_mask_expand_len
indices = np.array(
list(
range(
precompute_len + slice_index * slice_len,
precompute_len + (slice_index + 1) * slice_len,
)
),
np.uint32,
).reshape((1, slice_len))
indices[:, min(token_len, slice_len):] = 0
mask = (
np.zeros((1, slice_len, current_slice_len * slice_index + slice_len))
- 65536
)
data = np.zeros((1, slice_len, self.hf_config.hidden_size)).astype(bfloat16)
for i, t in enumerate(
range(
slice_index * slice_len,
(slice_index + 1) * slice_len,
)
):
if t < len(token_ids):
# mask[:, i, 0: slice_index * slice_len + i + 1] = 0
data[:, i : i + 1, :] = (
token_embeds[t]
.reshape((1, 1, self.hf_config.hidden_size))
.astype(bfloat16)
)
if t < len(token_ids) + precompute_len:
mask[:, i, 0: slice_index * slice_len + i + 1] = 0
if slice_index == slice_indices[-1]:
                    curlen_procd = token_len - slice_index * slice_len  # number of valid tokens in this final (possibly partial) slice
else:
curlen_procd = slice_len
mask = mask.astype(bfloat16)
for i in range(self.hf_config.num_hidden_layers):
input_feed = {
"K_cache": (
self.k_caches[i][:, 0: current_slice_len * slice_index, :]
if slice_index
else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
),
"V_cache": (
self.v_caches[i][:, 0: current_slice_len * slice_index, :]
if slice_index
else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
),
"indices": indices,
"input": data,
"mask": mask,
}
outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=slice_index + 1)
self.k_caches[i][
:,
slice_index
* slice_len + precompute_len : slice_index
* slice_len + curlen_procd + precompute_len,
:,
] = outputs[0][:, :curlen_procd, :]
self.v_caches[i][
:,
slice_index
* slice_len + precompute_len: slice_index
* slice_len + curlen_procd + precompute_len,
:,
] = outputs[1][:, :curlen_procd, :]
data = outputs[2]
print("slice prefill done", slice_index)
else:
print("No prefill needed.")
# return "Calculated the kv cache of the system prompt."
return (self.k_caches, self.v_caches)
def decode(
self,
token_ids,
prefill_len=128,
slice_len=128
    ):
        """
        Replay `token_ids` into the kv cache position by position, then keep
        sampling new tokens until EOS is hit or the cache is exhausted.
        """
token_len = len(token_ids)
# set to decoder
print("answer: >> ", end='', flush=True)
kv_cache_len = 2559
mask = np.zeros((1, 1, kv_cache_len + 1), dtype=np.float32).astype(bfloat16)
mask[:, :, :kv_cache_len] -= 65536
if prefill_len > 0:
mask[:, :, :token_len + self.precompute_len] = 0
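        # Mask convention: -65536 means "not attendable"; a position is opened
        # (set to 0) once its kv-cache slot has been filled.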
for start_indice in range(kv_cache_len):
if self.precompute_len > 0 and start_indice < self.precompute_len:
continue
next_token = token_ids[start_indice - self.precompute_len]
indices = np.array([start_indice], np.uint32).reshape((1, 1))
data = self.embeds[next_token, :].reshape((1, 1, self.hf_config.hidden_size)).astype(bfloat16)
for i in range(self.hf_config.num_hidden_layers):
input_feed = {
"K_cache": self.k_caches[i],
"V_cache": self.v_caches[i],
"indices": indices,
"input": data,
"mask": mask,
}
outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=0)
self.k_caches[i][:, start_indice, :] = outputs[0][:, :, :]
self.v_caches[i][:, start_indice, :] = outputs[1][:, :, :]
data = outputs[2]
mask[..., start_indice] = 0
            if start_indice >= token_len + self.precompute_len - 1:
                post_out = self.post_process_session.run(None, {"input": data})[0]
                next_token, possible_tokens, possible_soft = self.post_process(post_out)
token_ids.append(next_token)
print(self.tokenizer.decode(next_token, skip_special_tokens=True), end='', flush=True)
if next_token == self.tokenizer.eos_token_id and start_indice > token_len + self.precompute_len:
# print("\n>> HINT: The next_token encountered EOS token, generation completed.")
break
print("\n")
self.precompute_len = len(token_ids) + self.precompute_len - 1
return self.tokenizer.decode(token_ids[self.precompute_len - 1:], skip_special_tokens=True)
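

# A rough end-to-end sketch of how this module is meant to be driven (illustrative
# only: the model paths, cache directory, and prompt below are assumptions, not
# values shipped with this file).
if __name__ == "__main__":
    hf_model_path = "./Qwen2-1.5B-Instruct"   # hypothetical HF model directory
    axmodel_path = "./qwen2_axmodel"          # hypothetical axmodel directory
    manager = InferManager(hf_model_path, axmodel_path)
    manager.build_infer_model()
    manager.build_kvcache(kv_cache_len=2559)
    manager.build_system_prompt()
    # Prefill the system prompt once, then optionally persist its kv cache.
    manager.prefill(manager.model_inputs, slice_len=128, precompute_len=0)
    cache_tools = KVCacheTools(axmodel_num=manager.hf_config.num_hidden_layers, dtype=bfloat16)
    cache_tools.save_kvcache("./system_kvcache", manager.get_system_prompt(),
                             manager.precompute_len, *manager.get_kvcache())
    # Answer a single user turn by decoding against the cached system prompt.
    user_inputs = manager.encoder_prompt("Hello, who are you?")
    answer = manager.decode(user_inputs["input_ids"])
    print(answer)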