# inf2_dir/app/main.py
import time
import os
import base64
from io import BytesIO
import concurrent.futures
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch_neuronx
import transformers
from transformers import AutoConfig
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates
from llava.model.utils import LayerNorm
from llava.mm_utils import tokenizer_image_token, process_images_v2
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from typing import Dict
from fastapi import FastAPI, Request, HTTPException
from backend_model import MistralModel
# Suppress transformers logging
transformers.logging.set_verbosity_error()
NUM_SEGMENTS = 10 # Number of frame segments to use
WEIGHT_ROOT = '/root/inf2_dir/' # Root directory for model weights
CONFIG_DIR = os.path.join(WEIGHT_ROOT, "llava-mistral_videollava_ptv12_350k_samep_65k_sopv2_065th_sopv1_fps2_scratch") # Model config and tokenizer directory
NEURON_VISION_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", "neuron_eva_vit_batch7.pth") # Vision model weights (Neuron format)
NEURON_BERT_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", "neuron_bert.pth") # BERT model weights (Neuron format)
PROJECTOR_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'projector.pth') # Projector weights
EMBED_TOKEN_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'embed_tokens.pth') # Embedding weights
QUERY_TOKEN_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'query_tokens.pth') # Learned query token weights
LAYERNORM_SAVE_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'ln_state_dict.pth') # Vision LayerNorm weights
POSITION_ENCODING_SAVE_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'frame_position_encoding.pth') # Frame position encoding weights
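# The two neuron_*.pth artifacts above are TorchScript modules (loaded with torch.jit.load below).
# A minimal sketch of how such an artifact could be produced offline with torch_neuronx is shown
# for reference; `vision_module` and `example_frames` are illustrative names, not part of this file:
#   traced = torch_neuronx.trace(vision_module, example_frames)
#   torch.jit.save(traced, NEURON_VISION_PATH)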
def generate_input_ids(tokenizer):
"""
Generate input token IDs and conversation template for the model.
Args:
tokenizer (AutoTokenizer): Tokenizer instance.
Returns:
tuple: (input_ids, conv)
input_ids (torch.Tensor): Input token IDs.
conv (Conversation): Conversation template.
"""
conv = conv_templates['thoth'].copy() # Copy the conversation template
qs = "Please describe this video in detail."
qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs # Prepend video tokens
conv.append_message(conv.roles[0], qs) # Add the question to the conversation
conv.append_message(conv.roles[1], None) # Add a placeholder for the response
prompt = conv.get_prompt() # Get the conversation prompt
input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0) # Tokenize and add batch dimension
return input_ids
def uniform_sample(frames, num_segments):
"""
Uniformly sample frames from the provided list.
Args:
frames (list): List of frame images.
num_segments (int): Number of segments to sample.
Returns:
list: List of sampled frames.
"""
indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int) # Calculate indices for uniform sampling
frames = [frames[ind] for ind in indices] # Sample frames based on indices
return frames
# Decode raw image bytes into an RGB PIL image
def image_open_byteio(byte_data):
output = Image.open(BytesIO(byte_data)).convert('RGB')
return output
# Create FastAPI app
app = FastAPI()
mistral_model = MistralModel(model_name=CONFIG_DIR) # Load Mistral model
# Load model configuration and tokenizer
config = AutoConfig.from_pretrained(CONFIG_DIR, trust_remote_code=True)
tokenizer = mistral_model.tokenizer
input_ids = generate_input_ids(tokenizer) # Generate the default prompt's input token IDs
input_ids = input_ids[0].to('cpu') # [token_len]
# Initialize image processor and vision module
image_processor = Blip2ImageTrainProcessor(image_size=config.img_size, is_training=False)
# vision_module = MASPVision()
# new_vision_state_dict = torch.load(VISION_STATE_DICT, map_location='cpu') # Load vision state dict
# vision_module.load_state_dict(new_vision_state_dict)
with torch_neuronx.experimental.neuron_cores_context(start_nc=0, nc_count=2): # Pin the vision model to NeuronCores 0-1
vision_module_neuron = torch.jit.load(NEURON_VISION_PATH)
vision_module_neuron = vision_module_neuron.eval()
# Load embedding weights and set up embedding module
padding_idx = config.pad_token_id
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
embed_weight = torch.load(EMBED_TOKEN_PATH)
embed_tokens.load_state_dict(embed_weight)
embed_tokens = embed_tokens.eval()
embed_tokens.to(torch.float16).to('cpu')
# LayerNorm applied to the vision encoder output
vision_width = 1408
ln_vision = LayerNorm(vision_width)
ln_vision_weight = torch.load(LAYERNORM_SAVE_PATH)
ln_vision.load_state_dict(ln_vision_weight)
ln_vision = ln_vision.eval()
ln_vision = ln_vision.to(torch.float32)
# Learned query tokens (32 queries of width 768) fed to the Neuron BERT module below;
# the random init is immediately overwritten by the saved weights
num_query_token = 32
query_tokens = nn.Parameter(torch.zeros(1, num_query_token, 768))
query_tokens.data.normal_(mean=0.0, std=0.02)
query_tokens_weight = torch.load(QUERY_TOKEN_PATH)['query_tokens']
query_tokens.data = query_tokens_weight
# Learned positional embedding over the sampled frame index (up to 10 frames)
frame_position_encoding = nn.Embedding(10, 768)
frame_position_encoding_weight = torch.load(POSITION_ENCODING_SAVE_PATH)
frame_position_encoding.load_state_dict(frame_position_encoding_weight)
# Linear projector from the multimodal hidden size to the LLM hidden size
projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
projector_weight = torch.load(PROJECTOR_PATH)
projector.load_state_dict(projector_weight)
# BERT-style query transformer over vision features (TorchScript, Neuron-compiled)
neuron_bert = torch.jit.load(NEURON_BERT_PATH)
neuron_bert = neuron_bert.eval()
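# With all components loaded, an "images" request flows through:
# frames -> Neuron vision encoder -> ln_vision -> Neuron BERT with query_tokens ->
# frame_position_encoding -> projector -> spliced between text embeddings -> MistralModel.generate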
def vision_module_neuron_output(image):
"""
Get output from the vision module (Neuron format).
Args:
image (torch.Tensor): Input image.
Returns:
torch.Tensor: Vision module output.
"""
output = vision_module_neuron(image)
return output
@app.post("/generate")
async def generate(request: Request) -> Dict[str, str]:
"""
Generate text using the Mistral model.
Args:
request (Request): The incoming request object.
Returns:
Dict[str, str]: A dictionary containing the generated text or an error message.
"""
try:
s1 = time.time()
request_payload = await request.json()
request_payload_keys = request_payload.keys()
if "images" in request_payload_keys: # If input is a list of images
packed_data = request_payload.get("images")
# unpacked_data = [base64.b64decode(item) for item in packed_data] # Decode base64 images
# input_images = [Image.open(BytesIO(byte_data)).convert('RGB') for byte_data in unpacked_data] # Load images
with concurrent.futures.ThreadPoolExecutor() as executor:
unpacked_data = list(executor.map(base64.b64decode, packed_data))
with concurrent.futures.ThreadPoolExecutor() as executor:
input_images = list(executor.map(image_open_byteio, unpacked_data))
input_images = uniform_sample(input_images, NUM_SEGMENTS) # Sample frames
input_images = process_images_v2(input_images, image_processor, config) # Process images
print("s1 - input_images time: ", time.time() - s1)
si = time.time()
with torch.inference_mode(): # Enable inference mode
with concurrent.futures.ThreadPoolExecutor() as executor: # Use thread pool for parallel processing
image_features_list = list(executor.map(vision_module_neuron_output, input_images))
image_features = torch.cat(image_features_list, dim=0) # Concatenate image features
print("si - image_features neuron time: ", time.time() - si)
s2 = time.time()
# projector = nn.Linear(config.mm_hidden_size, config.hidden_size) # Initialize projector
# projector_weight = torch.load(PROJECTOR_DIR)
# projector.load_state_dict(projector_weight)
# image_features = vision_module.forward_features(input_images, image_features,
# projector) # Process vision features
image_features = ln_vision(image_features)
attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device)
query_tokens_inputs = query_tokens.expand(image_features.shape[0], -1, -1)
image_features = neuron_bert(
query_tokens_inputs.to(torch.float32),
image_features.to(torch.float32),
attn_mask.to(torch.int64)
)["last_hidden_state"].to(torch.float32)
frame_ids = torch.arange(input_images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
frame_ids = frame_ids.repeat(1, input_images.shape[1]).flatten(0, 1) # [num_frames * num_patches]
image_features += frame_position_encoding(frame_ids).unsqueeze(-2) # [num_frames, 1, 768]
projected_features = projector(image_features)
image_features = projected_features.flatten(0, 1)
print("image_features shape: ", image_features.shape)
image_features = image_features.to(device='cpu', dtype=torch.float16) # Convert to float16 and move to CPU
print("s2 - image_features prepare time: ", time.time() - s2)
s3 = time.time()
vision_token_indice = torch.where(input_ids == MM_TOKEN_INDEX)[0][0] # Get index of vision token
pre_text_token = embed_tokens(input_ids[:vision_token_indice]) # Embed tokens before vision token
post_text_token = embed_tokens(input_ids[vision_token_indice + 1:]) # Embed tokens after vision token
print("s3 - text_token time: ", time.time() - s3)
s4 = time.time()
inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze(0) # Concatenate input embeddings and add batch dimension
inputs = inputs_embeds.detach().cpu().numpy().tolist() # Convert to list for input to Mistral model
print("s4 - inputs time: ", time.time() - s4)
elif "inputs" in request_payload_keys: # If input is normal text or embedding
si = time.time()
inputs = request_payload.get("inputs")
else:
raise HTTPException(status_code=400, detail="Please provide correct input")
if not inputs:
raise HTTPException(status_code=400, detail="No input provided")
s5 = time.time()
parameters = request_payload.get("parameters", {}) # Get additional parameters
generated_text = mistral_model.generate(inputs, parameters) # Generate text using Mistral model
print("s5 - generated_text time: ", time.time() - s5)
print("total inference time: ", time.time() - si)
return {"generated_text": generated_text}
except HTTPException:
raise # Propagate explicit HTTP errors (e.g., 400 bad input) unchanged
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
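# Illustrative client call, assuming the app is served with something like
# `uvicorn main:app --host 0.0.0.0 --port 8000` (launch command and port are not defined in this file):
#   import base64, requests
#   frames = [base64.b64encode(open(p, "rb").read()).decode() for p in frame_paths]
#   resp = requests.post("http://localhost:8000/generate", json={"images": frames, "parameters": {}})
#   print(resp.json()["generated_text"])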