import json
from pathlib import Path
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
from tqdm import tqdm
from axengine import InferenceSession
from ml_dtypes import bfloat16
from transformers import AutoTokenizer, AutoConfig
from loguru import logger


class KVCacheTools:
    """
    Save and load k/v caches on the local filesystem.
    """

    def __init__(self, axmodel_num: int, dtype=np.float32):
        self.axmodel_num = axmodel_num
        self.dtype = dtype

    def save_kvcache(
        self,
        target_dir: str,
        system_prompt: str,
        precompute_len: int,
        k_caches: List[np.ndarray],
        v_caches: List[np.ndarray],
        metadata: Optional[Dict] = None
    ) -> bool:
        try:
            target_path = Path(target_dir)
            target_path.mkdir(parents=True, exist_ok=True)
            # Dump every layer's k/v cache as a raw binary file.
            for i, (k, v) in enumerate(zip(k_caches, v_caches)):
                k.astype(self.dtype).tofile(target_path / f"k_cache_{i}.bin")
                v.astype(self.dtype).tofile(target_path / f"v_cache_{i}.bin")
            config = {
                "precompute_len": precompute_len,
                "system_prompt": system_prompt,
                "axmodel_num": self.axmodel_num,
                "dtype": str(self.dtype),
                "metadata": metadata or {},
            }
            with open(target_path / "config.json", "w", encoding="utf8") as f:
                json.dump(config, f, indent=2, ensure_ascii=False)
            return True
        except Exception as e:
            print(f"Save failed: {str(e)}")
            return False

    def load_kvcache(
        self, cache_dir: str
    ) -> Tuple[Tuple[List[np.ndarray], List[np.ndarray]], str, int, Dict]:
        try:
            cache_path = Path(cache_dir)
            k_caches, v_caches = [], []
            with open(cache_path / "config.json") as f:
                config = json.load(f)
            if config["axmodel_num"] != self.axmodel_num:
                raise ValueError(
                    f"Model layer mismatch: "
                    f"Expected {self.axmodel_num}, got {config['axmodel_num']}"
                )
            for i in range(self.axmodel_num):
                # NOTE: the kv cache width (256) is hardcoded here and must match the exported model.
                k_data = np.fromfile(cache_path / f"k_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
                v_data = np.fromfile(cache_path / f"v_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
                k_caches.append(k_data)
                v_caches.append(v_data)
            return (
                (k_caches, v_caches),
                config["system_prompt"],
                config["precompute_len"],
                config.get("metadata", {})
            )
        except Exception as e:
            print(f"Load failed: {str(e)}")
            exit()


class InferManager:

    def __init__(self, hf_model_path: str, axmodel_path: str):
        self.device = "cpu"
        self.hf_model_path = hf_model_path
        self.axmodel_path = axmodel_path
        self.hf_config = AutoConfig.from_pretrained(self.hf_model_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_path, trust_remote_code=True, use_fast=False)
        self.system_prompt = "你的名字叫小智(allen), 你是一个人畜无害的 AI 助手. 深圳市今天(4月1日)阴天, 愚人节, 气温在 14°C 至 19°C 之间, 微风."
        self.embeds = np.load(f"{self.axmodel_path}/model.embed_tokens.weight.npy")

    def build_system_prompt(self):
        messages = [
            {"role": "system", "content": self.system_prompt},
            # {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        self.system_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        self.system_input_ids = self.system_inputs.input_ids[0].cpu().numpy().tolist()
        self.system_input_embeds = np.take(self.embeds, self.system_input_ids, axis=0)
        self.system_input_ids_len = len(self.system_input_ids)
        self.model_inputs = {
            "input_ids": self.system_input_ids,
            "input_embeds": self.system_input_embeds,
            "input_ids_len": self.system_input_ids_len
        }
        self.precompute_len = self.system_input_ids_len
        # logger.info(f"system prompt ids len: {self.system_input_ids_len}")

    def encoder_prompt(self, prompt):
        text = f'<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        input_ids = model_inputs.input_ids[0].cpu().numpy().tolist()
        input_embeds = np.take(self.embeds, input_ids, axis=0)
        input_ids_len = len(input_ids)
        # logger.info(f"user prompt token_len: {input_ids_len}")
        model_inputs = {
            "message": text,
            "model_inputs": model_inputs,
            "input_ids": input_ids,
            "input_embeds": input_embeds,
            "input_ids_len": input_ids_len
        }
        return model_inputs

    def build_kvcache(self, kv_cache_len: int = 2559):
        # Per-layer kv width = head_dim * num_key_value_heads (GQA layout).
        kv_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads * self.hf_config.num_key_value_heads
        self.k_caches = [
            np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
            for _ in range(self.hf_config.num_hidden_layers)
        ]
        self.v_caches = [
            np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
            for _ in range(self.hf_config.num_hidden_layers)
        ]

    def get_kvcache(self):
        return [self.k_caches, self.v_caches]

    def update_kvcache(self, update_kv_cache):
        self.k_caches = update_kv_cache[0]
        self.v_caches = update_kv_cache[1]

    def get_tokenizer(self):
        return self.tokenizer

    def get_system_prompt(self):
        return self.system_prompt

    def set_system_prompt(self, prompt):
        self.system_prompt = prompt

    def build_infer_model(self):
        self.prefill_decoder_sessins = []
        for i in tqdm(range(self.hf_config.num_hidden_layers), desc="Init InferenceSession"):
            session = InferenceSession(
                f"{self.axmodel_path}/qwen2_p128_l{i}_together.axmodel"
            )
            self.prefill_decoder_sessins.append(session)
        self.post_process_session = InferenceSession(
            f"{self.axmodel_path}/qwen2_post.axmodel"
        )
        print("The models have been loaded!")

    def get_infer_session(self):
        return [self.prefill_decoder_sessins, self.post_process_session]

    @staticmethod
    def _top_p(probs: np.ndarray, p: float) -> np.ndarray:
        # Walk the tokens in descending probability order, zero out everything
        # after the cumulative mass reaches p, then renormalize.
        sorted_indices = np.argsort(probs)
        filtered = probs.copy()
        cumulative = 0
        for idx in sorted_indices[::-1]:
            if cumulative >= p:
                filtered[idx] = 0
            cumulative += filtered[idx]
        return filtered / cumulative

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        logits = logits - logits.max()
        exp_logits = np.exp(logits)
        return (exp_logits / np.sum(exp_logits)).astype(np.float64)

    def post_process(self, logits, top_k=1, top_p=0.9, temperature=0.6):
        logits = logits.astype(np.float32).flatten()
        candidate_indices = np.argpartition(logits, -top_k)[-top_k:]
        candidate_logits = logits[candidate_indices] / temperature
        candidate_probs = self._softmax(candidate_logits)
        candidate_probs = self._top_p(candidate_probs, top_p)
        candidate_probs = candidate_probs.astype(np.float64) / candidate_probs.sum()
        chosen_idx = np.random.multinomial(1, candidate_probs).argmax()
        next_token = candidate_indices[chosen_idx]
        return next_token, candidate_indices, candidate_probs

    def gen_slice_indices(self, token_len, prefill=128, expand=128):
        remaining = max(0, token_len - prefill)
        extra_blocks = (remaining + expand - 1) // expand
        return list(range(extra_blocks + 1))

    def prefill(
        self,
        model_inputs,
        slice_len=128,
        precompute_len=0,  # must be 0 when prefilling the system prompt itself
    ):
        """
        Prefill step for chunked inference.
        """
        token_ids = model_inputs["input_ids"]
        token_embeds = model_inputs["input_embeds"]
        token_len = model_inputs["input_ids_len"]
        seq_len = len(token_ids)
        slice_indices = [i for i in range(seq_len // slice_len + 1)]
        print(f"slice_indices: {slice_indices}")
        # total_prefill_len = (
        #     slice_len * slice_indices[-1]
        #     if slice_indices[-1] != 0
        #     else slice_len
        # )
        # slice_indices = self.gen_slice_indices(seq_len)
        total_prefill_len = slice_len * (slice_indices[-1] + 1)
        kv_mask_expand_len = 128

        if total_prefill_len > 0:
            for slice_index in slice_indices:
                if slice_index == 0:
                    current_slice_len = slice_len
                else:
                    current_slice_len = kv_mask_expand_len
                # Rotary position indices of this slice, offset by the tokens
                # that are already in the cache (precompute_len).
                indices = np.array(
                    list(
                        range(
                            precompute_len + slice_index * slice_len,
                            precompute_len + (slice_index + 1) * slice_len,
                        )
                    ),
                    np.uint32,
                ).reshape((1, slice_len))
                indices[:, min(token_len, slice_len):] = 0
                # Causal mask: -65536 blocks attention, 0 allows it.
                mask = (
                    np.zeros((1, slice_len, current_slice_len * slice_index + slice_len))
                    - 65536
                )
                data = np.zeros((1, slice_len, self.hf_config.hidden_size)).astype(bfloat16)
                for i, t in enumerate(
                    range(
                        slice_index * slice_len,
                        (slice_index + 1) * slice_len,
                    )
                ):
                    if t < len(token_ids):
                        # mask[:, i, 0: slice_index * slice_len + i + 1] = 0
                        data[:, i : i + 1, :] = (
                            token_embeds[t]
                            .reshape((1, 1, self.hf_config.hidden_size))
                            .astype(bfloat16)
                        )
                    if t < len(token_ids) + precompute_len:
                        mask[:, i, 0: slice_index * slice_len + i + 1] = 0

                if slice_index == slice_indices[-1]:
                    # curlen_procd is the number of valid tokens processed in this slice
                    curlen_procd = token_len - slice_index * slice_len
                else:
                    curlen_procd = slice_len

                mask = mask.astype(bfloat16)
                for i in range(self.hf_config.num_hidden_layers):
                    input_feed = {
                        "K_cache": (
                            self.k_caches[i][:, 0: current_slice_len * slice_index, :]
                            if slice_index
                            else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
                        ),
                        "V_cache": (
                            self.v_caches[i][:, 0: current_slice_len * slice_index, :]
                            if slice_index
                            else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
                        ),
                        "indices": indices,
                        "input": data,
                        "mask": mask,
                    }
                    outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=slice_index + 1)
                    # Write the freshly computed k/v into the cache, right after
                    # the precomputed (system prompt) region.
                    self.k_caches[i][
                        :,
                        slice_index * slice_len + precompute_len : slice_index * slice_len + curlen_procd + precompute_len,
                        :,
                    ] = outputs[0][:, :curlen_procd, :]
                    self.v_caches[i][
                        :,
                        slice_index * slice_len + precompute_len : slice_index * slice_len + curlen_procd + precompute_len,
                        :,
                    ] = outputs[1][:, :curlen_procd, :]
                    data = outputs[2]
                print("slice prefill done", slice_index)
        else:
            print("No prefill needed.")
        # return "Calculated the kv cache of the system prompt."
        return (self.k_caches, self.v_caches)

    def decode(
        self,
        token_ids,
        prefill_len=128,
        slice_len=128
    ):
        token_len = len(token_ids)
        # switch to single-token decode mode
        print("answer: >> ", end='', flush=True)
        kv_cache_len = 2559
        mask = np.zeros((1, 1, kv_cache_len + 1), dtype=np.float32).astype(bfloat16)
        mask[:, :, :kv_cache_len] -= 65536
        if prefill_len > 0:
            mask[:, :, :token_len + self.precompute_len] = 0

        for start_indice in range(kv_cache_len):
            # Skip positions already covered by the precomputed system prompt cache.
            if self.precompute_len > 0 and start_indice < self.precompute_len:
                continue
            next_token = token_ids[start_indice - self.precompute_len]
            indices = np.array([start_indice], np.uint32).reshape((1, 1))
            data = self.embeds[next_token, :].reshape((1, 1, self.hf_config.hidden_size)).astype(bfloat16)
            for i in range(self.hf_config.num_hidden_layers):
                input_feed = {
                    "K_cache": self.k_caches[i],
                    "V_cache": self.v_caches[i],
                    "indices": indices,
                    "input": data,
                    "mask": mask,
                }
                outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=0)
                self.k_caches[i][:, start_indice, :] = outputs[0][:, :, :]
                self.v_caches[i][:, start_indice, :] = outputs[1][:, :, :]
                data = outputs[2]
            mask[..., start_indice] = 0
            if start_indice < token_len + self.precompute_len - 1:
                # Still consuming the prompt: no sampling yet.
                pass
            else:
                post_out = self.post_process_session.run(None, {"input": data})[0]
                next_token, possible_tokens, possible_soft = self.post_process(post_out)
                token_ids.append(next_token)
                print(self.tokenizer.decode(next_token, skip_special_tokens=True), end='', flush=True)
            if next_token == self.tokenizer.eos_token_id and start_indice > token_len + self.precompute_len:
                # print("\n>> HINT: The next_token encountered EOS token, generation completed.")
                break
        print("\n")
        # Remember where the generated answer starts before advancing precompute_len,
        # otherwise the slice below would point past the end of token_ids.
        answer_start = token_len
        self.precompute_len = len(token_ids) + self.precompute_len - 1
        return self.tokenizer.decode(token_ids[answer_start:], skip_special_tokens=True)
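
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It wires the two classes together in the order the methods above suggest:
# precompute the system-prompt KV cache, optionally persist it with
# KVCacheTools, then decode a user prompt. The model paths and the output
# directory are placeholders, and the exact call sequence (e.g. whether the
# user prompt should also be prefilled before decoding) may need adjusting
# for a real deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hf_model_path = "path/to/hf_model"      # placeholder: HF checkpoint directory
    axmodel_path = "path/to/axmodel_dir"    # placeholder: exported axmodel directory

    infer = InferManager(hf_model_path, axmodel_path)
    infer.build_infer_model()
    infer.build_kvcache(kv_cache_len=2559)

    # Precompute the KV cache of the system prompt (precompute_len must be 0 here).
    infer.build_system_prompt()
    infer.prefill(infer.model_inputs, precompute_len=0)

    # Optionally persist the precomputed cache so later runs can skip this step.
    kv_tools = KVCacheTools(axmodel_num=infer.hf_config.num_hidden_layers, dtype=bfloat16)
    k_caches, v_caches = infer.get_kvcache()
    kv_tools.save_kvcache(
        "kvcache_out",              # placeholder output directory
        infer.get_system_prompt(),
        infer.precompute_len,
        k_caches,
        v_caches,
    )

    # Encode a user prompt and let decode() consume it token by token, then generate.
    user_inputs = infer.encoder_prompt("Hello, who are you?")
    answer = infer.decode(list(user_inputs["input_ids"]))
    print(answer)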