| from typing import Dict |
| from fastapi import FastAPI, Request, HTTPException |
| from backend_model import MistralModel |
| from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor |
| from transformers import AutoTokenizer |
| from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \ |
| DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN |
| from llava.conversation import conv_templates |
| from llava.model.multimodal_encoder.qformer import BertConfig, BertLMHeadModel, BertModel |
| from llava.model.multimodal_projector.builder import build_vision_projector |
| from llava.model.utils import LayerNorm |
| from llava.model.multimodal_encoder.eva_clip_encoder import EvaClipVisionTower |
| from llava.mm_utils import tokenizer_image_token, process_images_v2 |
| import torch |
| import numpy as np |
| from PIL import Image |
| import os |
| import msgpack |
| from io import BytesIO |
| import base64 |
| import torch.nn as nn |
| from transformers import AutoConfig |
| from collections import OrderedDict |
|
|
| import torch_neuronx |
|
|
| NUM_SEGMENTS = 10 |
|
|
def generate_input_ids(tokenizer):
    """Build the fixed video-description prompt and tokenize it.

    A copy of the ``'thoth'`` conversation template receives one user turn
    (video start/placeholder/end tokens, a newline, then the instruction) and
    an empty assistant turn to be completed by the model.

    Args:
        tokenizer: tokenizer handed to ``tokenizer_image_token``.

    Returns:
        A ``(input_ids, conv)`` tuple: the ``tokenizer_image_token`` result
        with a leading dimension added via ``unsqueeze(0)``, and the populated
        conversation object.
    """
    video_placeholder = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN
    question = video_placeholder + '\n' + "Describe the following video in detail."

    conv = conv_templates['thoth'].copy()
    turns = ((conv.roles[0], question), (conv.roles[1], None))
    for role, message in turns:
        conv.append_message(role, message)

    token_ids = tokenizer_image_token(
        conv.get_prompt(), tokenizer, MM_TOKEN_INDEX, return_tensors='pt'
    )
    return token_ids.unsqueeze(0), conv
|
|
|
|
def generate_images(frame_folder, image_processor, model_cfg):
    """Load the frames of a folder and preprocess them for the vision model.

    When the folder holds more than ``NUM_SEGMENTS`` frames they are reduced
    to ``NUM_SEGMENTS`` by uniform sampling; shorter sequences are passed
    through untouched.

    Args:
        frame_folder: directory read by ``load_frames``.
        image_processor: processor forwarded to ``process_images_v2``.
        model_cfg: model configuration forwarded to ``process_images_v2``.

    Returns:
        Whatever ``process_images_v2`` returns for the selected frames.
    """
    frames = load_frames(frame_folder)
    too_many = len(frames) > NUM_SEGMENTS
    selected = uniform_sample(frames, NUM_SEGMENTS) if too_many else frames
    return process_images_v2(selected, image_processor, model_cfg)
|
|
|
|
def uniform_sample(frames, num_segments):
    """Pick ``num_segments`` items from ``frames`` at evenly spaced positions.

    Positions come from ``np.linspace(0, len(frames) - 1, num_segments)``
    truncated to integers, so the first and last frame are always included
    (for ``num_segments >= 2``) and a sequence shorter than ``num_segments``
    is padded by repeating frames.

    Args:
        frames: indexable sequence of frames.
        num_segments: number of items to return.

    Returns:
        A new list with ``num_segments`` frames, or an empty list when
        ``frames`` is empty.
    """
    # With no frames linspace(0, -1, n) would start at index 0 and raise
    # IndexError; there is simply nothing to sample.
    if len(frames) == 0:
        return []
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    return [frames[index] for index in indices]
|
|
|
|
def load_frames(frames_dir):
    """Load every ``<number>.jpg`` file of a directory as an RGB image.

    Files are ordered by the integer value of their stem (``2.jpg`` before
    ``10.jpg``), not lexicographically.

    Args:
        frames_dir: directory containing the extracted frames.

    Returns:
        A list of RGB ``PIL.Image`` objects in frame order; empty when the
        directory has no ``.jpg`` files.

    Raises:
        ValueError: if a ``.jpg`` file has a non-integer stem.
        OSError: if the directory or one of the images cannot be read.
    """
    # Match the real extension: a bare 'jpg' suffix also accepted names such
    # as "clipjpg", which then failed in int().
    numbered_files = [
        (int(os.path.splitext(file_name)[0]), file_name)
        for file_name in os.listdir(frames_dir)
        if file_name.endswith('.jpg')
    ]
    numbered_files.sort(key=lambda entry: entry[0])

    frames = []
    for _, file_name in numbered_files:
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released right away instead of lingering until GC.
        with Image.open(os.path.join(frames_dir, file_name)) as source:
            frames.append(source.convert('RGB'))
    return frames
|
|
|
|
class MASPVision(torch.nn.Module):
    """Vision branch: EVA-CLIP tower, query transformer, frame embedding, projector.

    The submodule/parameter attribute names (``bert``, ``vision_tower``,
    ``projector``, ``query_tokens``, ``ln_vision``,
    ``frame_position_encoding``) are the keys of the checkpoint loaded with
    ``load_state_dict``, so they must not be renamed.
    """

    def __init__(self, config):
        """Build all submodules from ``config``.

        Args:
            config: model configuration. ``num_query_token`` and
                ``max_num_segments`` are read here; the object is also passed
                to ``EvaClipVisionTower`` and ``build_vision_projector``.
        """
        super().__init__()

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # NOTE(review): device index 0 is hard-coded even when CUDA is not
        # available -- confirm load_model() tolerates this on CPU-only hosts.
        device_map = {"": 0}

        vision_tower = EvaClipVisionTower("eva-vit-g", config, delay_load=True)
        vision_tower.load_model(device_map=device_map)
        vision_tower.to(device=device, dtype=torch.float16)

        cross_attention_freq = 2
        vision_width = vision_tower.hidden_size
        num_query_token = config.num_query_token
        ln_vision = LayerNorm(vision_width)

        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token

        self.bert = BertModel(encoder_config, add_pooling_layer=False)
        # Only query embeddings are fed in (see forward_features), so the text
        # embedding tables and the per-layer text feed-forward are dropped.
        self.bert.embeddings.word_embeddings = None
        self.bert.embeddings.position_embeddings = None
        for layer in self.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        frame_position_encoding = nn.Embedding(
            config.max_num_segments,
            encoder_config.hidden_size,
        )

        mm_projector = build_vision_projector(config)

        self.vision_tower = vision_tower
        self.projector = mm_projector
        self.query_tokens = query_tokens
        self.ln_vision = ln_vision
        self.frame_position_encoding = frame_position_encoding

    def forward(self, images):
        """Encode ``images`` with the vision tower and run the feature pipeline.

        Args:
            images: tensor whose first two dimensions are flattened into one
                before it is passed to the vision tower (presumably
                ``(frames, crops, C, H, W)`` -- TODO confirm).

        Returns:
            The projector output, see ``forward_features``.
        """
        tower_features = self.vision_tower(images.flatten(0, 1))
        return self.forward_features(images, tower_features)

    def forward_features(self, a, b):
        """Run everything after the vision tower on precomputed tower features.

        This lets the vision tower run elsewhere (e.g. a separately compiled
        module) while the rest of the pipeline runs in this module.

        Args:
            a: the image tensor; only ``a.shape[0]`` and ``a.shape[1]`` are
                used, to build one frame index per flattened sample.
            b: vision-tower features for ``a`` flattened over its first two
                dimensions.

        Returns:
            The projector output for the query-token features, after adding
            the frame position embedding.
        """
        images = a
        image_features = self.ln_vision(b)

        # Every vision token may be attended to.
        attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(
            image_features.device)
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
        dtype_ = self.vision_tower.dtype
        image_features = self.bert(
            query_embeds=query_tokens.to(dtype_),
            encoder_hidden_states=image_features.to(dtype_),
            encoder_attention_mask=attn_mask,
            return_dict=True
        ).last_hidden_state.to(dtype_)

        # Index along dim 0 of `images`, repeated shape[1] times so each
        # flattened sample carries the index of the frame it came from.
        frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
        frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1)
        # unsqueeze(-2) broadcasts the per-frame embedding over the query tokens.
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)
        return self.projector(image_features)
|
|
|
|
# ---------------------------------------------------------------------------
# Service bootstrap: everything below runs once, at import time.
# ---------------------------------------------------------------------------

# Locations of the exported weights and tokenizer.
WEIGHT_ROOT = '/root/masp_models_inf2'
tokenizer_dir = '../tokenizer_dir'
NEURON_VISION_PATH = os.path.join(WEIGHT_ROOT, "./neuron_eva_vit_base.pt")
VISION_STATE_DICT = os.path.join(WEIGHT_ROOT, 'new_vision_state_dict.pth')
EMBED_TOKEN_PATH = os.path.join(WEIGHT_ROOT, 'embed_tokens.pth')
EVA_VIT_PATH = os.path.join(WEIGHT_ROOT, 'eva_vit_g.pth')

app = FastAPI()
mistral_model = MistralModel()

config = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
config.vit_model_path = EVA_VIT_PATH

tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
tokenizer.add_tokens(
    [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
    special_tokens=True)

# The prompt is the same for every request, so it is tokenized once here.
input_ids, conv = generate_input_ids(tokenizer)
input_ids = input_ids[0].to('cpu')

image_processor = Blip2ImageTrainProcessor(
    image_size=config.img_size,
    is_training=False)

# Eager vision module: used by the endpoint through forward_features().
vision_module = MASPVision(config=config)
new_vision_state_dict = torch.load(VISION_STATE_DICT, map_location='cpu')
vision_module.load_state_dict(new_vision_state_dict)
vision_module = vision_module.eval()
vision_module = vision_module.to('cpu')
vision_module.to(torch.float32)

# Pre-compiled (TorchScript) vision tower loaded from NEURON_VISION_PATH.
vision_module_neuron = torch.jit.load(NEURON_VISION_PATH)
vision_module_neuron = vision_module_neuron.eval()

# Token embedding table used to embed the text part of the prompt.
padding_idx = config.pad_token_id
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
# map_location keeps this consistent with the vision checkpoint above and
# avoids a failure when the file was saved from a device this host lacks.
embed_weight = torch.load(EMBED_TOKEN_PATH, map_location='cpu')
embed_tokens.load_state_dict(embed_weight)
embed_tokens = embed_tokens.eval()
embed_tokens.to(torch.float16).to('cpu')
|
|
def _decode_images(request_payload):
    """Validate the request body and decode its base64 images to RGB PIL images."""
    if not isinstance(request_payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    packed_data = request_payload.get("images")
    if not packed_data or not isinstance(packed_data, list):
        raise HTTPException(status_code=400, detail="No input provided")

    try:
        return [
            Image.open(BytesIO(base64.b64decode(item))).convert('RGB')
            for item in packed_data
        ]
    except (TypeError, ValueError, OSError) as e:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") from e


def _build_inputs_embeds(frames):
    """Turn decoded frames into the nested-list embedding sequence for the LLM."""
    sampled_frames = uniform_sample(frames, NUM_SEGMENTS)
    pixel_values = process_images_v2(sampled_frames, image_processor, config)

    with torch.inference_mode():
        # The compiled tower is called once per frame; the first token of each
        # output is dropped (presumably the CLS token -- TODO confirm).
        tower_outputs = [
            vision_module_neuron(image)[:, 1:].to(torch.float32)
            for image in pixel_values
        ]
        image_features = torch.cat(tower_outputs, dim=0)

        image_features = vision_module.forward_features(pixel_values, image_features)
        image_features = image_features.flatten(0, 1)
        # Tensor.to() is not in-place: the result has to be assigned for the
        # features to actually match the float16 token embeddings.
        image_features = image_features.to(device='cpu', dtype=torch.float16)

        # Splice the visual features in place of the multimodal placeholder
        # token of the pre-tokenized prompt.
        vision_token_index = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]
        pre_text_token = embed_tokens(input_ids[:vision_token_index])
        post_text_token = embed_tokens(input_ids[vision_token_index + 1:])
        inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze(0)

    return inputs_embeds.detach().cpu().numpy().tolist()


@app.post("/generate")
async def generate(request: Request) -> Dict[str, str]:
    """
    Describe a video, given as base64-encoded frames, with the Mistral model.

    The JSON body must contain ``images`` (a non-empty list of base64 strings)
    and may contain ``parameters`` (passed through to the model's
    ``generate``).

    Args:
        request (Request): The incoming request object.

    Returns:
        Dict[str, str]: ``{"generated_text": ...}`` with the model output.

    Raises:
        HTTPException: 400 for a malformed body, missing/empty ``images`` or
            undecodable image data; 500 for any failure during inference.
    """
    try:
        try:
            request_payload = await request.json()
        except ValueError as e:
            raise HTTPException(status_code=400, detail="Request body is not valid JSON") from e

        input_images = _decode_images(request_payload)
        parameters = request_payload.get("parameters", {})

        inputs = _build_inputs_embeds(input_images)
        generated_text = mistral_model.generate(inputs, parameters)
        return {"generated_text": generated_text}
    except HTTPException:
        # Client errors raised above must keep their own status code instead
        # of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") from e
|
|