|
|
import time |
|
|
import os |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import concurrent.futures |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch_neuronx |
|
|
import transformers |
|
|
from transformers import AutoConfig |
|
|
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN |
|
|
from llava.conversation import conv_templates |
|
|
from llava.model.utils import LayerNorm |
|
|
from llava.mm_utils import tokenizer_image_token, process_images_v2 |
|
|
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor |
|
|
|
|
|
from typing import Dict |
|
|
from fastapi import FastAPI, Request, HTTPException |
|
|
from backend_model import MistralModel |
|
|
|
|
|
|
|
|
# Silence HuggingFace transformers logging below error level.
transformers.logging.set_verbosity_error()

# Number of frames uniformly sampled from each incoming video.
NUM_SEGMENTS = 10

# Root directory holding model configs and Neuron-compiled weights on disk.
WEIGHT_ROOT = '/root/inf2_dir/'

# Directory containing the LLaVA-Mistral config/tokenizer files.
CONFIG_DIR = os.path.join(WEIGHT_ROOT, "llava-mistral_videollava_ptv12_350k_samep_65k_sopv2_065th_sopv1_fps2_scratch")

# TorchScript vision encoder compiled for Neuron (batch size 7 per the filename).
NEURON_VISION_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", "neuron_eva_vit_batch7.pth")

# TorchScript Q-Former/BERT compiled for Neuron.
NEURON_BERT_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", "neuron_bert.pth")

# Linear projector mapping vision features into the LM embedding space.
PROJECTOR_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'projector.pth')

# Token-embedding weights of the language model.
EMBED_TOKEN_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'embed_tokens.pth')

# Learned Q-Former query tokens.
QUERY_TOKEN_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'query_tokens.pth')

# LayerNorm applied to raw vision-encoder output.
LAYERNORM_SAVE_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'ln_state_dict.pth')

# Per-frame position-encoding embedding weights.
POSITION_ENCODING_SAVE_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'frame_position_encoding.pth')
|
|
|
|
|
|
|
|
def generate_input_ids(tokenizer):
    """
    Generate input token IDs for the fixed video-description prompt.

    Builds a 'thoth' conversation whose user turn is the video placeholder
    tokens followed by a fixed instruction, then tokenizes the rendered
    prompt with the multimodal placeholder mapped to MM_TOKEN_INDEX.

    Args:
        tokenizer (AutoTokenizer): Tokenizer instance.

    Returns:
        torch.Tensor: Input token IDs with a leading batch dimension of 1
        (shape (1, seq_len)); the video placeholder position holds
        MM_TOKEN_INDEX.
    """
    conv = conv_templates['thoth'].copy()
    qs = "Please describe this video in detail."
    # Wrap the instruction with the start/placeholder/end video tokens the
    # model expects to see around visual content.
    qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
    conv.append_message(conv.roles[0], qs)
    # Empty assistant slot so get_prompt() ends at the generation position.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(
        0)
    return input_ids
|
|
|
|
|
|
|
|
def uniform_sample(frames, num_segments):
    """
    Uniformly sample ``num_segments`` frames from the provided list.

    Sample indices are evenly spaced from the first to the last frame;
    frames are repeated when ``num_segments`` exceeds the number of
    available frames.

    Args:
        frames (list): List of frame images.
        num_segments (int): Number of segments to sample.

    Returns:
        list: List of ``num_segments`` sampled frames, or ``[]`` when
        ``frames`` is empty.
    """
    # Guard against an empty input: np.linspace(0, -1, n).astype(int) would
    # otherwise produce index -1 and raise IndexError on the empty list.
    if not frames:
        return []
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    return [frames[ind] for ind in indices]
|
|
|
|
|
|
|
|
def image_open_byteio(byte_data):
    """
    Decode raw image bytes into an RGB PIL image.

    Args:
        byte_data (bytes): Encoded image bytes (e.g. a base64-decoded frame).

    Returns:
        PIL.Image.Image: The decoded image, converted to RGB.
    """
    buffer = BytesIO(byte_data)
    return Image.open(buffer).convert('RGB')
|
|
|
|
|
|
|
|
app = FastAPI()

# Backend Mistral language model; loads its tokenizer/weights at import time.
mistral_model = MistralModel(model_name=CONFIG_DIR)

config = AutoConfig.from_pretrained(CONFIG_DIR, trust_remote_code=True)

# Reuse the backend model's tokenizer for prompt construction.
tokenizer = mistral_model.tokenizer

# Fixed prompt token IDs containing the MM_TOKEN_INDEX video placeholder;
# computed once since every request uses the same prompt.
input_ids = generate_input_ids(tokenizer)

# Drop the batch dimension; downstream code indexes this as a 1-D sequence.
input_ids = input_ids[0].to('cpu')

# Frame preprocessor (resize/normalize) in eval mode.
image_processor = Blip2ImageTrainProcessor(
    image_size=config.img_size,
    is_training=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the TorchScript vision encoder pinned to two Neuron cores.
with torch_neuronx.experimental.neuron_cores_context(start_nc=0, nc_count=2):
    vision_module_neuron = torch.jit.load(NEURON_VISION_PATH)
    vision_module_neuron = vision_module_neuron.eval()

# Language-model token embedding table, run on CPU in float16.
padding_idx = config.pad_token_id
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
embed_weight = torch.load(EMBED_TOKEN_PATH)
embed_tokens.load_state_dict(embed_weight)
embed_tokens = embed_tokens.eval()
# nn.Module.to() is in-place, so the discarded return value is fine here.
embed_tokens.to(torch.float16).to('cpu')

# LayerNorm over the vision encoder's feature dimension (1408 matches the
# EVA-ViT hidden size used by the compiled encoder — TODO confirm).
vision_width = 1408
ln_vision = LayerNorm(vision_width)
ln_vision_weight = torch.load(LAYERNORM_SAVE_PATH)
ln_vision.load_state_dict(ln_vision_weight)
ln_vision = ln_vision.eval()
ln_vision = ln_vision.to(torch.float32)

# Learned Q-Former query tokens; the random init below is immediately
# overwritten by the checkpointed weights.
num_query_token = 32
query_tokens = nn.Parameter(
    torch.zeros(1, num_query_token, 768)
)
query_tokens.data.normal_(mean=0.0, std=0.02)
query_tokens_weight = torch.load(QUERY_TOKEN_PATH)['query_tokens']
query_tokens.data = query_tokens_weight

# Per-frame position encoding; 10 rows match NUM_SEGMENTS sampled frames.
frame_position_encoding = nn.Embedding(10, 768)
frame_position_encoding_weight = torch.load(POSITION_ENCODING_SAVE_PATH)
frame_position_encoding.load_state_dict(frame_position_encoding_weight)

# Linear projection from multimodal feature size into LM hidden size.
projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
projector_weight = torch.load(PROJECTOR_PATH)
projector.load_state_dict(projector_weight)

# Neuron-compiled Q-Former/BERT that cross-attends queries to frame features.
neuron_bert = torch.jit.load(NEURON_BERT_PATH)
neuron_bert = neuron_bert.eval()
|
|
def vision_module_neuron_output(image):
    """
    Run the Neuron-compiled vision encoder on a preprocessed image tensor.

    Args:
        image (torch.Tensor): Preprocessed image input for the encoder.

    Returns:
        torch.Tensor: Vision features produced by the Neuron module.
    """
    return vision_module_neuron(image)
|
|
|
|
|
|
|
|
@app.post("/generate") |
|
|
async def generate(request: Request) -> Dict[str, str]: |
|
|
""" |
|
|
Generate text using the Mistral model. |
|
|
|
|
|
Args: |
|
|
request (Request): The incoming request object. |
|
|
|
|
|
Returns: |
|
|
Dict[str, str]: A dictionary containing the generated text or an error message. |
|
|
""" |
|
|
try: |
|
|
s1 = time.time() |
|
|
request_payload = await request.json() |
|
|
request_payload_keys = request_payload.keys() |
|
|
|
|
|
if "images" in request_payload_keys: |
|
|
packed_data = request_payload.get("images") |
|
|
|
|
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor: |
|
|
unpacked_data = list(executor.map(base64.b64decode, packed_data)) |
|
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor: |
|
|
input_images = list(executor.map(image_open_byteio, unpacked_data)) |
|
|
|
|
|
input_images = uniform_sample(input_images, NUM_SEGMENTS) |
|
|
input_images = process_images_v2(input_images, image_processor, config) |
|
|
print("s1 - input_images time: ", time.time() - s1) |
|
|
|
|
|
si = time.time() |
|
|
with torch.inference_mode(): |
|
|
with concurrent.futures.ThreadPoolExecutor() as executor: |
|
|
image_features_list = list(executor.map(vision_module_neuron_output, input_images)) |
|
|
image_features = torch.cat(image_features_list, dim=0) |
|
|
print("si - image_features neuron time: ", time.time() - si) |
|
|
|
|
|
s2 = time.time() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
image_features = ln_vision(image_features) |
|
|
attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device) |
|
|
query_tokens_inputs = query_tokens.expand(image_features.shape[0], -1, -1) |
|
|
|
|
|
image_features = neuron_bert( |
|
|
query_tokens_inputs.to(torch.float32), |
|
|
image_features.to(torch.float32), |
|
|
attn_mask.to(torch.int64) |
|
|
)["last_hidden_state"].to(torch.float32) |
|
|
|
|
|
frame_ids = torch.arange(input_images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1) |
|
|
frame_ids = frame_ids.repeat(1, input_images.shape[1]).flatten(0, 1) |
|
|
image_features += frame_position_encoding(frame_ids).unsqueeze(-2) |
|
|
projected_features = projector(image_features) |
|
|
|
|
|
image_features = projected_features.flatten(0, 1) |
|
|
print(image_features.shape) |
|
|
|
|
|
image_features.to(device='cpu', dtype=torch.float16) |
|
|
print("s2 - image_features prepare time: ", time.time() - s2) |
|
|
|
|
|
s3 = time.time() |
|
|
vision_token_indice = torch.where(input_ids == MM_TOKEN_INDEX)[0][0] |
|
|
pre_text_token = embed_tokens(input_ids[:vision_token_indice]) |
|
|
post_text_token = embed_tokens(input_ids[vision_token_indice + 1:]) |
|
|
print("s3 - text_token time: ", time.time() - s3) |
|
|
|
|
|
s4 = time.time() |
|
|
inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze( |
|
|
0) |
|
|
inputs = inputs_embeds.detach().cpu().numpy().tolist() |
|
|
print("s4 - inputs time: ", time.time() - s4) |
|
|
|
|
|
elif "inputs" in request_payload_keys: |
|
|
si = time.time() |
|
|
inputs = request_payload.get("inputs") |
|
|
else: |
|
|
raise HTTPException(status_code=400, detail="Please provide correct input") |
|
|
|
|
|
if not inputs: |
|
|
raise HTTPException(status_code=400, detail="No input provided") |
|
|
|
|
|
s5 = time.time() |
|
|
parameters = request_payload.get("parameters", {}) |
|
|
generated_text = mistral_model.generate(inputs, parameters) |
|
|
print("s5 - generated_text time: ", time.time() - s5) |
|
|
print("total inference time: ", time.time() - si) |
|
|
return {"generated_text": generated_text} |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") |
|
|
|
|
|
|