import argparse
import os

import numpy as np
import onnxruntime
import PIL.Image
import torch
from axengine import InferenceSession
from janus.models import VLChatProcessor
from janus.models.modeling_vlm import MultiModalityConfig
from loguru import logger
from ml_dtypes import bfloat16
from tqdm import tqdm
from transformers import AutoConfig


parser = argparse.ArgumentParser(description="Model configuration parameters")
parser.add_argument("--tokenizer_dir", type=str, default="Janus-Pro-1B",
                    help="Path to the HuggingFace model directory")
parser.add_argument("--axmodel_path", type=str, default="janus_pro_1B_axmodel",
                    help="Path to the compiled axmodel files of the LLM decoder layers")
args = parser.parse_args()

tokenizer_dir = args.tokenizer_dir
axmodel_path = args.axmodel_path

| """ONNX MODEL""" |
| gen_vision_model_decode = onnxruntime.InferenceSession("./img_gen_onnx/gen_vision_model_decode_sim.onnx", providers=["CPUExecutionProvider"]) |
| gen_aligner = onnxruntime.InferenceSession("./img_gen_onnx/gen_aligner.onnx", providers=["CPUExecutionProvider"]) |
| gen_head = onnxruntime.InferenceSession("./img_gen_onnx/post_head.onnx", providers=["CPUExecutionProvider"]) |
| post_norm = onnxruntime.InferenceSession("./img_gen_onnx/post_norm.onnx", providers=["CPUExecutionProvider"]) |
| """ONNX MODEL""" |
|
|
| """EMBEDINGs""" |
| embeds = np.load(f"{axmodel_path}/model.embed_tokens.weight.npy") |
| gen_embed = np.load("./embeds/gen_embed.npy") |
| codebook_entry_embedding = torch.load('./embeds/codebook_entry_embedding.pt', map_location=torch.device('cpu')) |
| """EMBEDINGs""" |
|
|
|
|
def prefill(
    cfg,
    prefill_decoder_sessions,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 1,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
):
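    """Run the prompt once through every decoder layer to fill the KV caches
    for both CFG branches, then sample the first image token.

    Returns the aligned embedding of that token (the next decode input), the
    prompt token ids, the image-token buffer, and the per-row KV caches.
    """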
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    # Build the classifier-free-guidance batch: even rows keep the full prompt
    # (conditional), odd rows pad everything between the first and last token
    # (unconditional).
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = embeds[tokens.numpy()]
    batch, token_len, seq_dim = inputs_embeds.shape
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int)
    prefill_len = 640
    token_ids = tokens
|
|
    lastN = 1023
    kv_dim = cfg.hidden_size // cfg.num_attention_heads * cfg.num_key_value_heads
    batch_k_caches = {}
    batch_v_caches = {}

    # One (1, lastN, kv_dim) K/V buffer per layer, per batch row.
    for bid in range(batch):
        batch_k_caches[bid] = [
            np.zeros((1, lastN, kv_dim), dtype=bfloat16)
            for _ in range(cfg.num_hidden_layers)
        ]
        batch_v_caches[bid] = [
            np.zeros((1, lastN, kv_dim), dtype=bfloat16)
            for _ in range(cfg.num_hidden_layers)
        ]
    # Additive causal mask over the prefill window; -65536 acts as -inf.
    mask = np.zeros((1, prefill_len, prefill_len)) - 65536
    for j in range(token_len):
        mask[:, j, :j + 1] = 0
    mask = mask.astype(bfloat16)
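    # A worked example of the additive mask, assuming token_len == 3 and
    # prefill_len == 5 (-65536 behaves as -inf in the bfloat16 softmax):
    #
    #   row 0: [     0, -65536, -65536, -65536, -65536]
    #   row 1: [     0,      0, -65536, -65536, -65536]
    #   row 2: [     0,      0,      0, -65536, -65536]
    #   rows 3..4 stay fully masked, since only the first token_len query
    #   positions carry real tokens.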
|
|
    # Position indices for the prefill window; slots past the prompt are zeroed.
    indices = np.array(list(range(prefill_len)), np.uint32).reshape((1, prefill_len))
    indices[:, token_len:] = 0
    hidden_states = np.zeros((batch, token_len, cfg.hidden_size)).astype(bfloat16)
|
|
    # Run the prompt through every decoder layer, one batch row at a time,
    # capturing each layer's K/V outputs as the initial cache contents.
    for bid in range(batch):
        data = np.zeros((1, prefill_len, cfg.hidden_size)).astype(bfloat16)
        data[:, 0:token_len] = inputs_embeds[bid].astype(bfloat16)
        k_caches = batch_k_caches[bid]
        v_caches = batch_v_caches[bid]

        for i in range(cfg.num_hidden_layers):
            input_feed = {
                "K_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
                "V_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
                "indices": indices,
                "input": data,
                "mask": mask,
            }
            outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=1)
            k_caches[i][:, :token_len, :] = outputs[0][:, :token_len, :]
            v_caches[i][:, :token_len, :] = outputs[1][:, :token_len, :]
            data[:, :token_len] = outputs[2][:, :token_len, :]

        hidden_states[bid] = data[:, :token_len]
        batch_k_caches[bid] = k_caches
        batch_v_caches[bid] = v_caches
|
|
    # Final norm and output head run on CPU; only the last position is needed.
    hidden_states = post_norm.run(["output"], {"input": hidden_states[:, -1:, :].astype(np.float32)})[0]
    logits = gen_head.run(["output"], {"input": hidden_states[:, -1, :]})[0]
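    # Classifier-free guidance: even rows are conditional, odd rows
    # unconditional, and the guided logits are
    #   logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond),
    # which reduces to logit_cond at cfg_weight == 1 and moves further from the
    # unconditional distribution as cfg_weight grows (default 5 here).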
    logits = torch.from_numpy(logits)
    logit_cond = logits[0::2, :]
    logit_uncond = logits[1::2, :]
    logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
    probs = torch.softmax(logits / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    generated_tokens[:, 0] = next_token.squeeze(dim=-1)
    next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

    # Look up the sampled token's image embedding and project it through the
    # gen-aligner into the language model's hidden space for the next step.
    gen_embed_res = np.take(gen_embed, next_token.numpy().tolist(), axis=0)
    img_embeds = gen_aligner.run(["output"], {"input": gen_embed_res})[0]
    inputs_embeds = np.expand_dims(img_embeds, axis=1)
    return inputs_embeds, token_ids, generated_tokens, batch_k_caches, batch_v_caches
|
|
|
|
@torch.inference_mode()
def generate(
    cfg,
    prefill_decoder_sessions,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 1,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
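    """Sample the remaining image tokens autoregressively, reusing the KV
    caches built by prefill(), then decode the completed token grid into RGB
    images under generated_samples/.
    """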
    inputs_embeds, token_ids, generated_tokens, batch_k_caches, batch_v_caches = prefill(
        cfg, prefill_decoder_sessions, vl_chat_processor,
        prompt, temperature, parallel_size, cfg_weight, image_token_num_per_image
    )

    logger.debug("prefill completed!")
    token_len = token_ids.shape[1]
|
|
    lastN = 1023
    batch = parallel_size * 2

    # Decode-phase attention mask: one query row over lastN + 1 key slots.
    mask = np.zeros((1, 1, lastN + 1), dtype=np.float32).astype(bfloat16)
    mask[:, :, :lastN] -= 65536
    mask[:, :, :token_len] = 0
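    # Slot layout of the decode mask: positions [0, token_len) are open from
    # the start, positions [token_len, lastN) are re-opened one per generated
    # token via `mask[..., start_index] = 0` in the loop below, and slot lastN
    # (never subtracted above) stays open for the token currently being fed in.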
|
|
    for image_token_i in tqdm(range(1, image_token_num_per_image), desc="ImageToken"):
        start_index = image_token_i + token_len - 1
        indices = np.array([start_index], np.uint32).reshape((1, 1))
        hidden_states = np.zeros((batch, 1, cfg.hidden_size)).astype(bfloat16)
        # Both CFG rows share the same image-token embedding.
        assert (inputs_embeds[0] == inputs_embeds[1]).all()
|
|
        for bid in range(batch):
            k_caches = batch_k_caches[bid]
            v_caches = batch_v_caches[bid]
            data = inputs_embeds[:1, ...].astype(bfloat16)

            for i in range(cfg.num_hidden_layers):
                input_feed = {
                    "K_cache": k_caches[i],
                    "V_cache": v_caches[i],
                    "indices": indices,
                    "input": data,
                    "mask": mask,
                }
                outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=0)
                k_caches[i][:, start_index, :] = outputs[0][:, :, :]
                v_caches[i][:, start_index, :] = outputs[1][:, :, :]
                data = outputs[2]

            hidden_states[bid] = data
            batch_k_caches[bid] = k_caches
            batch_v_caches[bid] = v_caches

        # Open the slot just written so later steps can attend to it.
        mask[..., start_index] = 0
|
|
        hidden_states = post_norm.run(["output"], {"input": hidden_states.astype(np.float32)})[0]
        logits = gen_head.run(["output"], {"input": hidden_states[:, -1, :]})[0]

        # Same CFG mixing and sampling as in prefill().
        logits = torch.from_numpy(logits)
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, image_token_i] = next_token.squeeze(dim=-1)
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

        gen_embed_res = np.take(gen_embed, next_token.numpy().tolist(), axis=0)
        img_embeds = gen_aligner.run(["output"], {"input": gen_embed_res})[0]
        inputs_embeds = np.expand_dims(img_embeds, axis=1)
|
|
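    # Map the sampled ids back through the VQ codebook into a latent grid:
    # with the defaults, img_size // patch_size == 24 tokens per side, so the
    # 576 tokens form a 24x24 grid with 8 latent channels, which the ONNX
    # vision decoder turns into an image in [-1, 1].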
    indices = generated_tokens.to(dtype=torch.int)
    shape = [parallel_size, 8, img_size // patch_size, img_size // patch_size]
    z_q = codebook_entry_embedding[indices]
    z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
    z_q = z_q.permute(0, 3, 1, 2)  # NHWC -> NCHW for the decoder
    dec = gen_vision_model_decode.run(["image"], {"quant": z_q.to(dtype=torch.float32).numpy()})[0]
    dec = dec.transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)  # rescale [-1, 1] to 8-bit RGB
    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
|
|
    os.makedirs("generated_samples", exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join("generated_samples", "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)
|
|
config: MultiModalityConfig = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(tokenizer_dir)

description = "A close-up high-contrast photo of Sydney Opera House sitting next to Eiffel tower, under a blue night sky of roiling energy, exploding yellow stars, and radiating swirls of blue."
|
|
conversation = [
    {
        "role": "User",
        "content": description,
    },
    {"role": "Assistant", "content": ""},
]

sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag
|
|
cfg = config.language_config

prefill_decoder_sessions = []
for i in tqdm(range(cfg.num_hidden_layers), desc="Init InferenceSession"):
    session = InferenceSession(
        f"{axmodel_path}/llama_p640_l{i}_together.axmodel"
    )
    prefill_decoder_sessions.append(session)
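# Each decoder layer ships as its own axmodel; `p640` in the filename
# presumably matches the prefill_len of 640 used in prefill(). The shape_group
# argument passed to run() switches between the 640-token prefill graph
# (shape_group=1) and the single-token decode graph (shape_group=0).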
|
|
logger.info("model load done!")
|
|
generate(
    cfg,
    prefill_decoder_sessions,
    vl_chat_processor,
    prompt,
)
|
|