"""FastAPI inference server for the MASP video-description model on AWS Neuron (Inf2).

The EVA-CLIP vision tower runs as a pre-compiled Neuron TorchScript module, the
Q-Former and projector run on CPU, and the resulting multimodal embeddings are
handed to a Mistral backend for text generation.
"""
import base64
import os
from collections import OrderedDict
from io import BytesIO
from typing import Dict

import msgpack
import numpy as np
import torch
import torch.nn as nn
import torch_neuronx
from PIL import Image
from fastapi import FastAPI, Request, HTTPException
from transformers import AutoConfig, AutoTokenizer

from backend_model import MistralModel
from llava.constants import (MM_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
                             DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VIDEO_TOKEN,
                             DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN)
from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token, process_images_v2
from llava.model.multimodal_encoder.eva_clip_encoder import EvaClipVisionTower
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model.multimodal_encoder.qformer import BertConfig, BertLMHeadModel, BertModel
from llava.model.multimodal_projector.builder import build_vision_projector
from llava.model.utils import LayerNorm

NUM_SEGMENTS = 10  # number of video frames fed to the vision encoder


def generate_input_ids(tokenizer):
    """Build the fixed video-description prompt and tokenize it."""
    conv = conv_templates['thoth'].copy()
    qs = "Describe the following video in detail."
    qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
    return input_ids, conv


def generate_images(frame_folder, image_processor, model_cfg):
    """Load frames from a folder, sample NUM_SEGMENTS of them, and preprocess."""
    images = load_frames(frame_folder)
    if len(images) > NUM_SEGMENTS:
        images = uniform_sample(images, NUM_SEGMENTS)
    return process_images_v2(images, image_processor, model_cfg)


def uniform_sample(frames, num_segments):
    """Pick `num_segments` frames evenly spaced over the input sequence."""
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    return [frames[ind] for ind in indices]


def load_frames(frames_dir):
    """Load all `<index>.jpg` frames in a directory, sorted by numeric index."""
    results = []
    image_files = [(int(os.path.splitext(img)[0]), img) for img in os.listdir(frames_dir) if img.endswith('jpg')]
    image_files = sorted(image_files, key=lambda img: img[0])
    for _, file_name in image_files:
        image = Image.open(os.path.join(frames_dir, file_name)).convert('RGB')
        results.append(image)
    return results


class MASPVision(torch.nn.Module):
    """Vision side of the model: EVA-CLIP tower + Q-Former pooling + projector."""

    def __init__(self, config):
        super().__init__()
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        device_map = {"": 0}
        vision_tower = EvaClipVisionTower("eva-vit-g", config, delay_load=True)
        vision_tower.load_model(device_map=device_map)
        vision_tower.to(device=device, dtype=torch.float16)
        image_processor = Blip2ImageTrainProcessor(
            image_size=config.img_size, is_training=False)

        cross_attention_freq = 2
        vision_width = vision_tower.hidden_size
        num_query_token = config.num_query_token
        ln_vision = LayerNorm(vision_width)
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert a cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        # Qformer = BertLMHeadModel(config=encoder_config)
        self.bert = BertModel(encoder_config, add_pooling_layer=False)
        # Only learned query tokens are fed in, so the unused text-token
        # embeddings and text feed-forward sublayers are removed.
        self.bert.embeddings.word_embeddings = None
        self.bert.embeddings.position_embeddings = None
        for layer in self.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        frame_position_encoding = nn.Embedding(
            config.max_num_segments, encoder_config.hidden_size)
        mm_projector = build_vision_projector(config)

        self.vision_tower = vision_tower
        # self.qformer = Qformer
        self.projector = mm_projector
        self.query_tokens = query_tokens
        self.ln_vision = ln_vision
        self.frame_position_encoding = frame_position_encoding

    def forward(self, images):
        # images: [num_frames, patches, 3, image_size, image_size]
        image_features = self.vision_tower(images.flatten(0, 1))
        image_features = self.ln_vision(image_features)
        attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
            image_features.device)  # [num_frames * num_patches, 256]
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)  # [num_frames * num_patches, 32, 768]
        dtype_ = self.vision_tower.dtype
        image_features = self.bert(
            query_embeds=query_tokens.to(dtype_),
            encoder_hidden_states=image_features.to(dtype_),
            encoder_attention_mask=attn_mask,
            return_dict=True,
        ).last_hidden_state.to(dtype_)
        frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
        frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1)  # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames, 1, 768]
        return self.projector(image_features)

    def forward_features(self, images, image_features):
        """Same as `forward`, but takes vision-tower features that were already
        computed (here, by the Neuron-compiled encoder) instead of raw images."""
        # images: [num_frames, patches, 3, image_size, image_size]
        image_features = self.ln_vision(image_features)
        attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
            image_features.device)  # [num_frames * num_patches, 256]
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)  # [num_frames * num_patches, 32, 768]
        dtype_ = self.vision_tower.dtype
        image_features = self.bert(
            query_embeds=query_tokens.to(dtype_),
            encoder_hidden_states=image_features.to(dtype_),
            encoder_attention_mask=attn_mask,
            return_dict=True,
        ).last_hidden_state.to(dtype_)
        frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
        frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1)  # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames, 1, 768]
        return self.projector(image_features)


WEIGHT_ROOT = '/root/masp_models_inf2'
tokenizer_dir = '../tokenizer_dir'
NEURON_VISION_PATH = os.path.join(WEIGHT_ROOT, 'neuron_eva_vit_base.pt')
VISION_STATE_DICT = os.path.join(WEIGHT_ROOT, 'new_vision_state_dict.pth')
EMBED_TOKEN_PATH = os.path.join(WEIGHT_ROOT, 'embed_tokens.pth')
EVA_VIT_PATH = os.path.join(WEIGHT_ROOT, 'eva_vit_g.pth')

app = FastAPI()
mistral_model = MistralModel()

config = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
config.vit_model_path = EVA_VIT_PATH
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
tokenizer.add_tokens(
    [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
    special_tokens=True)
input_ids, conv = generate_input_ids(tokenizer)
input_ids = input_ids[0].to('cpu')  # [token_len]
image_processor = Blip2ImageTrainProcessor(
    image_size=config.img_size, is_training=False)

# Q-Former / projector run on CPU in float32; their weights come from a
# separately exported state dict.
vision_module = MASPVision(config=config)
new_vision_state_dict = torch.load(VISION_STATE_DICT, map_location='cpu')
vision_module.load_state_dict(new_vision_state_dict)
vision_module = vision_module.eval()
vision_module = vision_module.to('cpu')
vision_module.to(torch.float32)
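
# --- Illustrative sketch (not part of the serving path): how a Neuron artifact
# like NEURON_VISION_PATH can be produced. The helper name and the example
# input shape ([patches, 3, img_size, img_size] for one frame, matching the
# per-frame calls in /generate below) are assumptions; the actual export script
# for this model may differ. `torch_neuronx.trace` compiles a module for
# Inferentia and returns a TorchScript module that `torch.jit.save` /
# `torch.jit.load` can round-trip, which is how this server loads the encoder.
def _compile_vision_tower_for_neuron(vision_tower, num_patches, img_size, out_path):
    example_frame = torch.zeros(num_patches, 3, img_size, img_size)
    traced = torch_neuronx.trace(vision_tower.eval(), example_frame)
    torch.jit.save(traced, out_path)
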
# Pre-compiled Neuron (TorchScript) EVA-CLIP encoder; it replaces the raw
# vision tower at inference time.
vision_module_neuron = torch.jit.load(NEURON_VISION_PATH)
vision_module_neuron = vision_module_neuron.eval()

# Token-embedding table of the language model, used to embed the text part of
# the prompt before splicing in the visual features.
padding_idx = config.pad_token_id
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
embed_weight = torch.load(EMBED_TOKEN_PATH)
embed_tokens.load_state_dict(embed_weight)
embed_tokens = embed_tokens.eval()
embed_tokens.to(torch.float16).to('cpu')


@app.post("/generate")
async def generate(request: Request) -> Dict[str, str]:
    """
    Generate text using the Mistral model.

    The request body is expected to contain:
      * "images": a list of base64-encoded frame bytes, and
      * "parameters": an optional dict of generation parameters.

    Returns:
        Dict[str, str]: A dictionary containing the generated text.
    """
    try:
        request_payload = await request.json()
        packed_data = request_payload.get("images")
        parameters = request_payload.get("parameters", {})
        if not packed_data:
            raise HTTPException(status_code=400, detail="No input provided")

        # Decode the frames and preprocess them exactly as at training time.
        # unpacked_data = msgpack.unpackb(packed_data, raw=False)
        unpacked_data = [base64.b64decode(item) for item in packed_data]
        input_images = [Image.open(BytesIO(byte_data)).convert('RGB') for byte_data in unpacked_data]
        input_images = uniform_sample(input_images, NUM_SEGMENTS)
        input_images = process_images_v2(input_images, image_processor, config)

        with torch.inference_mode():
            # Vision-tower features, frame by frame on Neuron, dropping the CLS token.
            frame_features = []
            for image in input_images:
                output = vision_module_neuron(image)
                frame_features.append(output[:, 1:].to(torch.float32))
            image_features = torch.cat(frame_features, dim=0)  # [num_frames * num_patches, 256, 1408]

            # Q-Former pooling + projection to the language-model width.
            image_features = vision_module.forward_features(input_images, image_features)
            image_features = image_features.flatten(0, 1)  # [num_frames * num_patches, 32, 4096]
            image_features = image_features.to(device='cpu', dtype=torch.float16)

            # Splice the visual features into the prompt at the video placeholder token.
            vision_token_indice = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]
            pre_text_token = embed_tokens(input_ids[:vision_token_indice])
            post_text_token = embed_tokens(input_ids[vision_token_indice + 1:])
            inputs_embeds = torch.cat(
                [pre_text_token, image_features, post_text_token]).unsqueeze(0)  # [1, num_token, 4096]
            inputs = inputs_embeds.detach().cpu().numpy().tolist()

        generated_text = mistral_model.generate(inputs, parameters)
        return {"generated_text": generated_text}
    except HTTPException:
        # Re-raise client errors unchanged instead of wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
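
# --- Illustrative sketch (not used by the server): a minimal client for the
# /generate endpoint. The URL, frame paths, and the "max_new_tokens" parameter
# are placeholder assumptions; only the payload layout ("images" as a list of
# base64-encoded frame bytes plus an optional "parameters" dict) mirrors what
# generate() above actually expects.
def _example_client(frame_paths, server_url="http://localhost:8000/generate"):
    import requests  # local import: the server itself does not depend on requests

    encoded_frames = []
    for path in frame_paths:
        with open(path, "rb") as f:
            encoded_frames.append(base64.b64encode(f.read()).decode("utf-8"))

    payload = {"images": encoded_frames, "parameters": {"max_new_tokens": 512}}
    response = requests.post(server_url, json=payload, timeout=600)
    response.raise_for_status()
    return response.json()["generated_text"]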