# inf2_dir/app/tmp/main_old.py
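"""FastAPI inference server for the MASP video-description pipeline on AWS Inferentia2.

Incoming base64-encoded video frames are encoded by a Neuron-compiled EVA-CLIP
vision tower, fused through a Q-Former and vision projector on CPU, spliced into
the prompt embeddings, and passed to a Mistral backend for text generation.
"""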
from typing import Dict
from fastapi import FastAPI, Request, HTTPException
from backend_model import MistralModel
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from transformers import AutoTokenizer
from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \
DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates
from llava.model.multimodal_encoder.qformer import BertConfig, BertLMHeadModel, BertModel
from llava.model.multimodal_projector.builder import build_vision_projector
from llava.model.utils import LayerNorm
from llava.model.multimodal_encoder.eva_clip_encoder import EvaClipVisionTower
from llava.mm_utils import tokenizer_image_token, process_images_v2
import torch
import numpy as np
from PIL import Image
import os
import msgpack
from io import BytesIO
import base64
import torch.nn as nn
from transformers import AutoConfig
from collections import OrderedDict
import torch_neuronx
NUM_SEGMENTS = 10  # number of video frames sampled uniformly per request
def generate_input_ids(tokenizer):
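    """Build input_ids for the fixed video-description prompt using the 'thoth'
    conversation template; the video placeholder token is later replaced by the
    projected image features. Returns (input_ids, conv)."""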
conv = conv_templates['thoth'].copy()
qs = "Describe the following video in detail."
qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
return input_ids, conv
def generate_images(frame_folder, image_processor, model_cfg):
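    """Load frames from a folder, subsample to at most NUM_SEGMENTS frames, and
    preprocess them for the vision tower."""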
images = load_frames(frame_folder)
if len(images) > NUM_SEGMENTS:
images = uniform_sample(images, NUM_SEGMENTS)
return process_images_v2(images, image_processor, model_cfg)
def uniform_sample(frames, num_segments):
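    """Uniformly sample `num_segments` frames across the sequence (indices may
    repeat if fewer frames are available)."""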
indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
frames = [frames[ind] for ind in indices]
return frames
def load_frames(frames_dir):
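    """Load all '*.jpg' frames from `frames_dir`, sorted by their numeric filenames."""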
results = []
image_files = [(int(os.path.splitext(img)[0]), img) for img in os.listdir(frames_dir) if img.endswith('jpg')]
image_files = sorted(image_files, key=lambda img: img[0])
for frame_name in image_files:
image_path = f"{frames_dir}/{frame_name[1]}"
image = Image.open(image_path).convert('RGB')
results.append(image)
return results
class MASPVision(torch.nn.Module):
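    """Vision side of the pipeline: EVA-CLIP vision tower, LayerNorm, a Q-Former
    (BERT encoder with cross-attention and learned query tokens), per-frame
    position embeddings, and a projector into the language model's embedding space."""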
def __init__(self, config):
super().__init__()
# device = 'cuda:0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_map = {"": 0}
vision_tower = EvaClipVisionTower("eva-vit-g", config, delay_load=True)
vision_tower.load_model(device_map=device_map)
vision_tower.to(device=device, dtype=torch.float16)
image_processor = Blip2ImageTrainProcessor(
image_size=config.img_size,
is_training=False)
cross_attention_freq = 2
vision_width = vision_tower.hidden_size
num_query_token = config.num_query_token
ln_vision = LayerNorm(vision_width)
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
# Qformer = BertLMHeadModel(config=encoder_config)
self.bert = BertModel(encoder_config, add_pooling_layer=False)
self.bert.embeddings.word_embeddings = None
self.bert.embeddings.position_embeddings = None
for layer in self.bert.encoder.layer:
layer.output = None
layer.intermediate = None
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
frame_position_encoding = nn.Embedding(
config.max_num_segments,
encoder_config.hidden_size
)
mm_projector = build_vision_projector(config)
self.vision_tower = vision_tower
# self.qformer = Qformer
self.projector = mm_projector
self.query_tokens = query_tokens
self.ln_vision = ln_vision
self.frame_position_encoding = frame_position_encoding
def forward(self, images):
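        """Encode raw frames: vision tower -> LayerNorm -> Q-Former cross-attention
        over learned query tokens -> per-frame position embeddings -> projection
        into the language model's embedding space."""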
# images: [num_frames, patches, 3, image_size, image_size]
image_features = self.vision_tower(images.flatten(0, 1))
image_features = self.ln_vision(image_features)
attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
image_features.device) # [num_frames * num_patches, 256]
query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) # [num_frames * num_patches, 32, 768]
dtype_ = self.vision_tower.dtype
image_features = self.bert(
query_embeds=query_tokens.to(dtype_),
encoder_hidden_states=image_features.to(dtype_),
encoder_attention_mask=attn_mask,
return_dict=True
).last_hidden_state.to(dtype_)
frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1) # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames * num_patches, 1, 768]
return self.projector(image_features)
    # zheng add
    def forward_features(self, images, image_features):
        """Variant of forward() that takes precomputed vision-tower features (e.g.
        from the Neuron-compiled tower) instead of running self.vision_tower."""
        # images: [num_frames, patches, 3, image_size, image_size]
        # image_features: vision-tower output, e.g. [num_frames * num_patches, 256, 1408]
image_features = self.ln_vision(image_features)
attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
image_features.device) # [num_frames * num_patches, 256]
query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) # [num_frames * num_patches, 32, 768]
dtype_ = self.vision_tower.dtype
image_features = self.bert(
query_embeds=query_tokens.to(dtype_),
encoder_hidden_states=image_features.to(dtype_),
encoder_attention_mask=attn_mask,
return_dict=True
).last_hidden_state.to(dtype_)
frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1) # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames * num_patches, 1, 768]
return self.projector(image_features)
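# Model artifacts: the Neuron-compiled EVA-ViT vision tower, the Q-Former/projector
# state dict, the language model's token embedding table, and the EVA-ViT-g weights.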
WEIGHT_ROOT = '/root/masp_models_inf2'
tokenizer_dir = '../tokenizer_dir'
NEURON_VISION_PATH = os.path.join(WEIGHT_ROOT, 'neuron_eva_vit_base.pt')
VISION_STATE_DICT = os.path.join(WEIGHT_ROOT, 'new_vision_state_dict.pth')
EMBED_TOKEN_PATH = os.path.join(WEIGHT_ROOT, 'embed_tokens.pth')
EVA_VIT_PATH = os.path.join(WEIGHT_ROOT, 'eva_vit_g.pth')
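# Global initialisation: the tokenizer, prompt, vision modules, and token embedding
# table below are loaded once at import time, before the app starts serving requests.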
app = FastAPI()
mistral_model = MistralModel()
config = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
config.vit_model_path = EVA_VIT_PATH
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
special_tokens=True)
input_ids, conv = generate_input_ids(tokenizer)
input_ids = input_ids[0].to('cpu') # [token_len]
image_processor = Blip2ImageTrainProcessor(
image_size=config.img_size,
is_training=False)
vision_module = MASPVision(config=config)
# new_vision_state_dict = torch.load('new_vision_state_dict.pth')
new_vision_state_dict = torch.load(VISION_STATE_DICT, map_location='cpu')
vision_module.load_state_dict(new_vision_state_dict)
vision_module = vision_module.eval()
vision_module = vision_module.to('cpu')
# vision_module.to(torch.float16)
# zheng add
vision_module.to(torch.float32)
vision_module_neuron = torch.jit.load(NEURON_VISION_PATH)
vision_module_neuron = vision_module_neuron.eval()
padding_idx = config.pad_token_id
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
embed_weight = torch.load(EMBED_TOKEN_PATH)
embed_tokens.load_state_dict(embed_weight)
embed_tokens = embed_tokens.eval()
embed_tokens.to(torch.float16).to('cpu')
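# The /generate endpoint expects a JSON body of the form
#   {"images": [<base64-encoded frame images>], "parameters": {...}}
# where "parameters" is forwarded unchanged to MistralModel.generate().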
@app.post("/generate")
async def generate(request: Request) -> Dict[str, str]:
"""
Generate text using the Mistral model.
Args:
request (Request): The incoming request object.
Returns:
Dict[str, str]: A dictionary containing the generated text or an error message.
"""
try:
request_payload = await request.json()
packed_data = request_payload.get("images")
parameters = request_payload.get("parameters", {})
#unpacked_data = msgpack.unpackb(packed_data, raw=False)
unpacked_data = [base64.b64decode(item) for item in packed_data]
input_images = [Image.open(BytesIO(byte_data)).convert('RGB') for byte_data in unpacked_data]
input_images = uniform_sample(input_images, NUM_SEGMENTS)
input_images = process_images_v2(input_images, image_processor, config)
with torch.inference_mode():
# get image feature
# image_features = vision_module(images).flatten(0, 1) # [num_frames * num_patches * num_query_token, 4096]
            # Run each frame through the Neuron-compiled vision tower and collect the
            # results; the first token of each output is dropped, keeping the patch
            # features (e.g. 256 per frame).
            frame_features = []
            for image in input_images:
                output = vision_module_neuron(image)
                frame_features.append(output[:, 1:].to(torch.float32))
            image_features = torch.cat(frame_features, dim=0)
# zheng [70, 256, 1408]
image_features = vision_module.forward_features(input_images, image_features)
image_features = image_features.flatten(0, 1)
print(image_features.shape) # zheng [70, 32, 4096]
            image_features = image_features.to(device='cpu', dtype=torch.float16)  # .to() returns a new tensor; keep the result
image_features_numpy = image_features.detach().cpu().numpy()
# image_features_saved = np.load('image_features_numpy.npy')
# print(np.sum(image_features_numpy -image_features_saved ))
# image_features_numpy = image_features.detach().cpu().numpy
# np.save('image_features_numpy.npy', image_features_numpy)
# print('images features shape', image_features.shape)
# image_features = loaded(images).flatten(0, 1)
# concat with text features
            vision_token_index = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]
            pre_text_token = embed_tokens(input_ids[:vision_token_index]) # zheng [32, 4096]
            post_text_token = embed_tokens(input_ids[vision_token_index + 1:])
inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze(
0) # [1, num_token, 4096]
inputs = inputs_embeds.detach().cpu().numpy().tolist()
if not inputs:
raise HTTPException(status_code=400, detail="No input provided")
generated_text = mistral_model.generate(inputs, parameters)
return {"generated_text": generated_text}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") from e
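
# Example client call (illustrative sketch; host, port, frame paths and the
# "max_new_tokens" parameter are assumptions, not part of this file):
#
#   import base64, requests
#   frames = [base64.b64encode(open(p, 'rb').read()).decode() for p in frame_paths]
#   resp = requests.post('http://localhost:8000/generate',
#                        json={'images': frames, 'parameters': {'max_new_tokens': 256}})
#   print(resp.json()['generated_text'])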