from collections import OrderedDict
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch_neuronx  # registers the Neuron runtime ops needed to load the compiled TorchScript module
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \
    DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token, process_images_v2, KeywordsStoppingCriteria
from llava.model.multimodal_encoder.eva_clip_encoder import EvaClipVisionTower
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model.multimodal_encoder.qformer import BertConfig, BertLMHeadModel, BertModel
from llava.model.multimodal_projector.builder import build_vision_projector
from llava.model.utils import LayerNorm

NUM_SEGMENTS = 10


def generate_input_ids(tokenizer):
    """Build the video-description prompt with the 'thoth' conversation template and tokenize it."""
    conv = conv_templates['thoth'].copy()
    qs = "Describe the following video in detail."
    qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
    return input_ids, conv


def generate_images(frame_folder, image_processor, model_cfg):
    """Load the extracted frames, subsample to NUM_SEGMENTS, and preprocess them."""
    images = load_frames(frame_folder)
    if len(images) > NUM_SEGMENTS:
        images = uniform_sample(images, NUM_SEGMENTS)
    return process_images_v2(images, image_processor, model_cfg)


def uniform_sample(frames, num_segments):
    """Pick num_segments frames evenly spaced across the clip."""
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    return [frames[ind] for ind in indices]


def load_frames(frames_dir):
    """Load every .jpg frame in frames_dir, sorted by its numeric file name."""
    results = []
    image_files = [(int(os.path.splitext(img)[0]), img) for img in os.listdir(frames_dir) if img.endswith('jpg')]
    image_files = sorted(image_files, key=lambda img: img[0])
    for frame_name in image_files:
        image_path = f"{frames_dir}/{frame_name[1]}"
        image = Image.open(image_path).convert('RGB')
        results.append(image)
    return results
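
# --- Hedged sketch (not part of the original flow): how the inputs_embeds assembled in
# __main__ below would typically be consumed. No language model is loaded in this script,
# so `language_model` is a hypothetical argument, and the stop-string choice for the
# 'thoth' template is an assumption; adjust both before using.
def generate_with_language_model(language_model, tokenizer, conv, input_ids, inputs_embeds,
                                 max_new_tokens=512):
    """Sketch: decode a caption from precomputed inputs_embeds with a LLaVA-style LM."""
    # Assumption: the template exposes sep/sep2 like other llava conversation templates.
    stop_str = getattr(conv, 'sep2', None) or conv.sep
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids.unsqueeze(0))
    with torch.inference_mode():
        output_ids = language_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            stopping_criteria=[stopping_criteria],
        )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]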
class MASPVision(torch.nn.Module):
    """EVA-CLIP vision tower + Q-Former + projector (the vision side of MASP, without the language model)."""

    def __init__(self, config):
        super().__init__()
        # device = 'cuda:0'
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        device_map = {"": 0}

        config.vit_model_path = 'eva_vit_g.pth'
        vision_tower = EvaClipVisionTower("eva-vit-g", config, delay_load=True)
        vision_tower.load_model(device_map=device_map)
        vision_tower.to(device=device, dtype=torch.float16)

        image_processor = Blip2ImageTrainProcessor(
            image_size=config.img_size,
            is_training=False)

        cross_attention_freq = 2
        vision_width = vision_tower.hidden_size
        num_query_token = config.num_query_token
        ln_vision = LayerNorm(vision_width)

        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        # Qformer = BertLMHeadModel(config=encoder_config)
        self.bert = BertModel(encoder_config, add_pooling_layer=False)
        self.bert.embeddings.word_embeddings = None
        self.bert.embeddings.position_embeddings = None
        for layer in self.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        frame_position_encoding = nn.Embedding(
            config.max_num_segments, encoder_config.hidden_size
        )
        mm_projector = build_vision_projector(config)

        self.vision_tower = vision_tower
        # self.qformer = Qformer
        self.projector = mm_projector
        self.query_tokens = query_tokens
        self.ln_vision = ln_vision
        self.frame_position_encoding = frame_position_encoding

    def forward(self, images):
        # images: [num_frames, patches, 3, image_size, image_size]
        image_features = self.vision_tower(images.flatten(0, 1))
        image_features = self.ln_vision(image_features)
        attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
            image_features.device)  # [num_frames * num_patches, 256]
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)  # [num_frames * num_patches, 32, 768]
        dtype_ = self.vision_tower.dtype
        image_features = self.bert(
            query_embeds=query_tokens.to(dtype_),
            encoder_hidden_states=image_features.to(dtype_),
            encoder_attention_mask=attn_mask,
            return_dict=True
        ).last_hidden_state.to(dtype_)
        frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
        frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1)  # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames, 1, 768]
        return self.projector(image_features)

    # zheng add
    def forward_features(self, images, image_features):
        """Same as forward(), but takes vision-tower features computed externally
        (here: by the Neuron-compiled ViT) instead of running self.vision_tower."""
        # images: [num_frames, patches, 3, image_size, image_size]
        image_features = self.ln_vision(image_features)
        attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
            image_features.device)  # [num_frames * num_patches, 256]
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)  # [num_frames * num_patches, 32, 768]
        dtype_ = self.vision_tower.dtype
        image_features = self.bert(
            query_embeds=query_tokens.to(dtype_),
            encoder_hidden_states=image_features.to(dtype_),
            encoder_attention_mask=attn_mask,
            return_dict=True
        ).last_hidden_state.to(dtype_)
        frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
        frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1)  # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames, 1, 768]
        return self.projector(image_features)
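
# --- Hedged sketch (assumption, not the original compilation script): one plausible way
# the Neuron-compiled backbone "./neuron_eva_vit_base.pt" loaded below could be produced
# with torch_neuronx.trace. The example input shape [num_patches, 3, img_size, img_size]
# mirrors how the module is called per frame in __main__, but the real compilation
# settings are not part of this file.
def compile_vision_tower_for_neuron(vision_tower, example_frame, output_path="./neuron_eva_vit_base.pt"):
    """Trace the vision tower for AWS Neuron and save it as TorchScript."""
    vision_tower = vision_tower.to(torch.float32).eval()
    traced = torch_neuronx.trace(vision_tower, example_frame)
    torch.jit.save(traced, output_path)
    return traced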
if __name__ == '__main__':
    frame_folder = './v12044gd0000cl5c6rfog65i2eoqcqig'
    tokenizer_dir = '../tokenizer_dir'
    # device = 'cuda:0'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    config = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
        special_tokens=True)
    image_processor = Blip2ImageTrainProcessor(
        image_size=config.img_size,
        is_training=False)

    input_ids, conv = generate_input_ids(tokenizer)
    # images = generate_images(frame_folder, image_processor, config).to(device).half()  # [num_frames, patches, 3, image_size, image_size]
    # zheng: keep the frames in float32 for the Neuron-compiled vision tower
    images = generate_images(frame_folder, image_processor, config).to(device)

    vision_module = MASPVision(config=config)
    input_ids = input_ids[0].to(device)  # [token_len]

    # new_vision_state_dict = torch.load('new_vision_state_dict.pth')
    new_vision_state_dict = torch.load('new_vision_state_dict.pth', map_location=device)
    # vision_state_dict = torch.load('masp_vision_statedict.pth', map_location="cuda:0")
    # new_vision_state_dict = OrderedDict()
    # for k, v in vision_state_dict.items():
    #     if 'qformer' in k:
    #         new_key = k[8:]
    #         new_vision_state_dict[new_key] = v
    #     else:
    #         new_vision_state_dict[k] = v
    vision_module.load_state_dict(new_vision_state_dict)
    vision_module = vision_module.eval()
    vision_module = vision_module.to(device)
    # vision_module.to(torch.float16)  # zheng
    vision_module.to(torch.float32)  # zheng add: run the Q-Former and projector in float32

    # Pre-compiled EVA-ViT backbone for AWS Neuron, loaded as TorchScript.
    vision_module_neuron = torch.jit.load("./neuron_eva_vit_base.pt")
    vision_module_neuron = vision_module_neuron.eval()
    # output = vision_module_neuron(images)

    padding_idx = config.pad_token_id
    embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
    embed_weight = torch.load('embed_tokens.pth')
    embed_tokens.load_state_dict(embed_weight)
    embed_tokens = embed_tokens.eval()
    embed_tokens.to(torch.float16).to(device)

    # vision_module = vision_module.eval()
    # vision_state_dict = vision_module.state_dict()
    # torch.save(vision_state_dict, 'masp_vision_statedict.pth')

    # inference
    # print(images.shape)  # [10, 7, 3, 224, 224]
    # dummy_images = torch.rand(10, 7, 3, 224, 224).to(model.device)
    # scripted_vision_module = torch.jit.script(vision_module)
    # print('begin to trace')
    # traced_vision_module = torch.jit.trace(vision_module, (images))
    # traced_vision_module.save('traced_vision_module.pt')
    # loaded = torch.jit.load('traced_vision_module.pt')

    start = time.time()
    with torch.inference_mode():
        # get image features
        # image_features = vision_module(images).flatten(0, 1)  # [num_frames * num_patches * num_query_token, 4096]
        # Run the Neuron-compiled ViT one frame at a time and drop the CLS token.
        frame_features = []
        for image in images:
            output = vision_module_neuron(image)
            output = output[:, 1:].to(torch.float32)
            frame_features.append(output)
        image_features = torch.cat(frame_features, dim=0)  # zheng: [70, 256, 1408]

        image_features = vision_module.forward_features(images, image_features)
        image_features = image_features.flatten(0, 1)
        print(image_features.shape)  # zheng: [70, 32, 4096]
        image_features = image_features.to(device=device, dtype=torch.float16)
        image_features_numpy = image_features.detach().cpu().numpy()
        # image_features_saved = np.load('image_features_numpy.npy')
        # print(np.sum(image_features_numpy - image_features_saved))
        # np.save('image_features_numpy.npy', image_features_numpy)
        # print('images features shape', image_features.shape)
        # image_features = loaded(images).flatten(0, 1)

        # concat with text features
        vision_token_indice = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]
        pre_text_token = embed_tokens(input_ids[:vision_token_indice])  # zheng: [32, 4096]
        post_text_token = embed_tokens(input_ids[vision_token_indice + 1:])
        inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze(0)  # [1, num_token, 4096]
    print("Inference time:", time.time() - start)

    input_embeds_numpy = inputs_embeds.detach().cpu().numpy()
    image_embeds_saved = np.load('inputs_embeds.npy')
    diff = np.sum(input_embeds_numpy - image_embeds_saved)
    print('diff with saved in the disk', diff)
    # print('inputs embeds numpy shape', input_embeds_numpy.shape)
    # np.save('inputs_embeds.npy', input_embeds_numpy)
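
    # Hypothetical follow-ups using the sketch helpers defined above (neither runs here:
    # no Neuron compilation is performed in this script and no language model is loaded):
    #   compile_vision_tower_for_neuron(vision_module.vision_tower, images[0])
    #   caption = generate_with_language_model(language_model, tokenizer, conv, input_ids, inputs_embeds)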