# llava_mistral_0531/app/tmp/predict_vision.py
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from transformers import AutoTokenizer
from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \
DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates
from llava.model.multimodal_encoder.qformer import BertConfig, BertLMHeadModel, BertModel
from llava.model.multimodal_projector.builder import build_vision_projector
from llava.model.utils import LayerNorm
from llava.model.multimodal_encoder.eva_clip_encoder import EvaClipVisionTower
import torch
from llava.mm_utils import tokenizer_image_token, process_images_v2, KeywordsStoppingCriteria
import numpy as np
from PIL import Image
import os
import torch.nn as nn
from transformers import AutoConfig
from collections import OrderedDict
import torch_neuronx
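# NOTE: torch_neuronx is not called directly in this script, but importing it is
# assumed to register the Neuron runtime/ops so torch.jit.load can deserialize the
# Neuron-compiled TorchScript module used below.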
NUM_SEGMENTS = 10
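# Build the chat prompt with the 'thoth' conversation template, wrap the video
# placeholder with the video start/end tokens, and tokenize it so the placeholder
# becomes MM_TOKEN_INDEX in input_ids.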
def generate_input_ids(tokenizer):
conv = conv_templates['thoth'].copy()
qs = "Describe the following video in detail."
qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
return input_ids, conv
def generate_images(frame_folder, image_processor, model_cfg):
images = load_frames(frame_folder)
if len(images) > NUM_SEGMENTS:
images = uniform_sample(images, NUM_SEGMENTS)
return process_images_v2(images, image_processor, model_cfg)
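# Uniformly sub-sample num_segments frame indices across the clip (linspace over
# [0, len(frames) - 1], cast to int), so long videos are reduced to NUM_SEGMENTS frames.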
def uniform_sample(frames, num_segments):
indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
frames = [frames[ind] for ind in indices]
return frames
def load_frames(frames_dir):
    # Frames are expected to be numbered JPEGs (e.g. 0.jpg, 1.jpg, ...), so they are
    # sorted on the integer stem rather than lexicographically.
    results = []
    image_files = [(int(os.path.splitext(img)[0]), img) for img in os.listdir(frames_dir) if img.endswith('.jpg')]
    image_files = sorted(image_files, key=lambda img: img[0])
    for _, frame_name in image_files:
        image_path = os.path.join(frames_dir, frame_name)
        image = Image.open(image_path).convert('RGB')
        results.append(image)
    return results
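# Vision branch of the model: EVA-CLIP ViT features -> LayerNorm -> Q-Former
# cross-attention over learned query tokens -> per-frame position embedding ->
# MLP projector into the language model's hidden size.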
class MASPVision(torch.nn.Module):
def __init__(self, config):
super().__init__()
# device = 'cuda:0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_map = {"": 0}
config.vit_model_path = 'eva_vit_g.pth'
vision_tower = EvaClipVisionTower("eva-vit-g", config, delay_load=True)
vision_tower.load_model(device_map=device_map)
vision_tower.to(device=device, dtype=torch.float16)
image_processor = Blip2ImageTrainProcessor(
image_size=config.img_size,
is_training=False)
cross_attention_freq = 2
vision_width = vision_tower.hidden_size
num_query_token = config.num_query_token
ln_vision = LayerNorm(vision_width)
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
# Qformer = BertLMHeadModel(config=encoder_config)
self.bert = BertModel(encoder_config, add_pooling_layer=False)
self.bert.embeddings.word_embeddings = None
self.bert.embeddings.position_embeddings = None
for layer in self.bert.encoder.layer:
layer.output = None
layer.intermediate = None
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
frame_position_encoding = nn.Embedding(
config.max_num_segments,
encoder_config.hidden_size
)
mm_projector = build_vision_projector(config)
self.vision_tower = vision_tower
# self.qformer = Qformer
self.projector = mm_projector
self.query_tokens = query_tokens
self.ln_vision = ln_vision
self.frame_position_encoding = frame_position_encoding
def forward(self, images):
# images: [num_frames, patches, 3, image_size, image_size]
image_features = self.vision_tower(images.flatten(0, 1))
image_features = self.ln_vision(image_features)
attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
image_features.device) # [num_frames * num_patches, 256]
query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) # [num_frames * num_patches, 32, 768]
dtype_ = self.vision_tower.dtype
image_features = self.bert(
query_embeds=query_tokens.to(dtype_),
encoder_hidden_states=image_features.to(dtype_),
encoder_attention_mask=attn_mask,
return_dict=True
).last_hidden_state.to(dtype_)
frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1) # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames * num_patches, 1, 768], broadcast over query tokens
return self.projector(image_features)
    # Same as forward(), but the raw ViT patch features are precomputed externally
    # (here by the Neuron-compiled vision tower) and passed in directly.
    def forward_features(self, images, image_features):
        # images: [num_frames, patches, 3, image_size, image_size]
        # image_features: [num_frames * patches, 256, vision_width], e.g. [70, 256, 1408]
image_features = self.ln_vision(image_features)
attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
image_features.device) # [num_frames * num_patches, 256]
query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) # [num_frames * num_patches, 32, 768]
dtype_ = self.vision_tower.dtype
image_features = self.bert(
query_embeds=query_tokens.to(dtype_),
encoder_hidden_states=image_features.to(dtype_),
encoder_attention_mask=attn_mask,
return_dict=True
).last_hidden_state.to(dtype_)
frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1) # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames * num_patches, 1, 768], broadcast over query tokens
return self.projector(image_features)
if __name__ == '__main__':
frame_folder = './v12044gd0000cl5c6rfog65i2eoqcqig'
tokenizer_dir = '../tokenizer_dir'
# device = 'cuda:0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
config = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
special_tokens=True)
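    # The image/video boundary tokens are registered as special tokens so the prompt
    # built in generate_input_ids tokenizes them as single tokens.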
image_processor = Blip2ImageTrainProcessor(
image_size=config.img_size,
is_training=False)
input_ids, conv = generate_input_ids(tokenizer)
# images = generate_images(frame_folder, image_processor, config).to(device).half() # [num_frames, patches, 3, image_size, image_size]
    # frames are kept in their original precision (no .half()); they feed the Neuron-compiled tower below
images = generate_images(frame_folder, image_processor, config).to(device)
vision_module = MASPVision(config=config)
input_ids = input_ids[0].to(device) # [token_len]
# new_vision_state_dict = torch.load('new_vision_state_dict.pth')
new_vision_state_dict = torch.load('new_vision_state_dict.pth', map_location=device)
# vision_state_dict = torch.load('masp_vision_statedict.pth', map_location="cuda:0")
# new_vision_state_dict = OrderedDict()
# for k, v in vision_state_dict.items():
# if 'qformer' in k:
# new_key = k[8:]
# new_vision_state_dict[new_key] = v
# else:
# new_vision_state_dict[k] = v
vision_module.load_state_dict(new_vision_state_dict)
vision_module = vision_module.eval()
vision_module = vision_module.to(device)
    # vision_module.to(torch.float16)
    # Keep the Q-Former / projector path in float32; the features coming out of the
    # Neuron-compiled ViT are cast to float32 below before entering forward_features.
    vision_module.to(torch.float32)
    # Load the EVA-ViT vision tower that was pre-compiled for AWS Neuron.
    vision_module_neuron = torch.jit.load("./neuron_eva_vit_base.pt")
    vision_module_neuron = vision_module_neuron.eval()
# output=vision_module_neuron(images)
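    # Hypothetical sketch (not run here) of how a Neuron-compiled vision tower such as
    # neuron_eva_vit_base.pt could be produced with torch_neuronx.trace; the example
    # input shape [patches, 3, img_size, img_size] mirrors the per-frame call below.
    #
    #   example_frame = torch.rand(7, 3, config.img_size, config.img_size)
    #   traced = torch_neuronx.trace(vision_module.vision_tower, example_frame)
    #   traced.save("neuron_eva_vit_base.pt")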
padding_idx = config.pad_token_id
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
embed_weight = torch.load('embed_tokens.pth')
embed_tokens.load_state_dict(embed_weight)
embed_tokens = embed_tokens.eval()
embed_tokens.to(torch.float16).to(device)
# vision_module = vision_module.eval()
# vision_state_dict = vision_module.state_dict()
# torch.save(vision_state_dict, 'masp_vision_statedict.pth')
    # inference
# print(images.shape) # [10, 7, 3, 224, 224]
# dummy_images = torch.rand(10, 7, 3, 224, 224).to(model.device)
# scripted_vision_module = torch.jit.script(vision_module)
# print('begin to trace')
# traced_vision_module = torch.jit.trace(vision_module, (images))
# traced_vision_module.save('traced_vision_module.pt')
# loaded = torch.jit.load('traced_vision_module.pt')
import time
start = time.time()
with torch.inference_mode():
# get image feature
# image_features = vision_module(images).flatten(0, 1) # [num_frames * num_patches * num_query_token, 4096]
        # Run the Neuron-compiled EVA-ViT once per frame, drop the first (CLS) token,
        # and collect the patch features, then concatenate along the frame axis.
        frame_features = []
        for image in images:
            output = vision_module_neuron(image)
            frame_features.append(output[:, 1:].to(torch.float32))
        image_features = torch.cat(frame_features, dim=0)  # [num_frames * patches, 256, 1408], e.g. [70, 256, 1408]
image_features = vision_module.forward_features(images, image_features)
image_features = image_features.flatten(0, 1)
        print(image_features.shape)  # [num_frames * patches, num_query_token, 4096], e.g. [70, 32, 4096]
        # .to() on a tensor is not in-place; reassign so the float16 copy is actually used downstream
        image_features = image_features.to(device=device, dtype=torch.float16)
image_features_numpy = image_features.detach().cpu().numpy()
# image_features_saved = np.load('image_features_numpy.npy')
# print(np.sum(image_features_numpy -image_features_saved ))
# image_features_numpy = image_features.detach().cpu().numpy
# np.save('image_features_numpy.npy', image_features_numpy)
# print('images features shape', image_features.shape)
# image_features = loaded(images).flatten(0, 1)
# concat with text features
        vision_token_index = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]
        pre_text_token = embed_tokens(input_ids[:vision_token_index])  # [num_pre_tokens, 4096]
        post_text_token = embed_tokens(input_ids[vision_token_index + 1:])
inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze(
0) # [1, num_token, 4096]
print("Inference time:", time.time() - start)
input_embeds_numpy = inputs_embeds.detach().cpu().numpy()
        # Parity check against a previously saved reference (see the commented-out np.save below).
        image_embeds_saved = np.load('inputs_embeds.npy')
        diff = np.sum(input_embeds_numpy - image_embeds_saved)
        print('diff with saved reference on disk:', diff)
# print('inputs embeds numpy shape', input_embeds_numpy.shape)
# np.save('inputs_embeds.npy', input_embeds_numpy)