import sys

sys.path.append("/repository/BLIP")
sys.path.append("/repository")
sys.path.append("BLIP")

import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from BLIP.models.blip import BLIP_Base, load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, pretrain_path: str = "/repository/blip_base.pth"):
        image_size = 224

        # Load the BLIP base encoder and its pretrained weights.
        self.blip_encoder_path = pretrain_path
        self.blip_encoder = BLIP_Base(
            image_size=image_size,
            vit="base",
            med_config="/repository/configs/med_config.json",
        )
        load_checkpoint(self.blip_encoder, pretrain_path)
        self.blip_encoder.eval()
        self.blip_encoder = self.blip_encoder.to(device)

        # Preprocessing pipeline; the normalization constants are the CLIP
        # image mean/std that BLIP was trained with.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ])

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """
        Args:
            data (:obj:`dict`): includes the input data and the parameters for the inference.

        Return:
            A :obj:`dict` like
            {"feature_vector": [0.6331314444541931, 0.8802216053009033, ..., -0.7866355180740356]}
            containing:
            - "feature_vector": a list of floats corresponding to the image embedding.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {"mode": "image"})  # currently unused

        # Get the text and whether an image accompanies it.
        text = [inputs["text"]]
        has_image = inputs["has_image"]

        # Decode the base64 image to PIL and transform it; fall back to a
        # blank image when no image was supplied.
        image = (
            Image.open(BytesIO(base64.b64decode(inputs["image"])))
            if has_image
            else Image.new("RGB", (224, 224))
        )
        image = self.transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            # Get image embeddings.
            image_embeds = self.blip_encoder.visual_encoder(image)

            # If there is no associated image, add a fake all-[PAD] text to the
            # batch so that "padding='longest'" below pads the real text out to
            # the image embedding length.
            if not has_image:
                target_len = image_embeds.shape[1]
                # Subtract 2 because the tokenizer adds a start and an end token.
                pad = " [PAD]" * (target_len - 2)
                text = text + [pad]

            text = self.blip_encoder.tokenizer(text, return_tensors="pt", padding="longest").to(device)

            # [:1] keeps only the real text, dropping the fake padding text.
            text_input_ids = text.input_ids[:1]
            text_attention_mask = text.attention_mask[:1]

            # If there is no associated image, inject the text-only embedding
            # in place of the image embedding.
            if not has_image:
                text_embeds = self.blip_encoder.text_encoder(
                    text_input_ids,
                    attention_mask=text_attention_mask,
                    return_dict=True,
                    mode="text",
                )
                image_embeds[0] = text_embeds.last_hidden_state[0]

            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            # Swap the start token for the [ENC] token to select the
            # image-grounded (multimodal) encoder mode.
            text_input_ids[:, 0] = self.blip_encoder.tokenizer.enc_token_id

            # Get the image-grounded text embedding.
            output = self.blip_encoder.text_encoder(
                text_input_ids,
                attention_mask=text_attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

        # Use the [ENC] token's final hidden state as the feature vector.
        hidden_state = output.last_hidden_state[:, 0, :].detach()
        return {"feature_vector": hidden_state.tolist()}
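

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch of how the handler expects its payload,
# not part of the endpoint contract. It assumes the checkpoint and config
# paths above exist locally; "example.jpg" is a hypothetical test image,
# so swap in your own file.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Text-only request: the handler substitutes a text embedding for the image.
    text_only = handler({"inputs": {"text": "a dog playing fetch", "has_image": False}})
    print(len(text_only["feature_vector"][0]))  # embedding width, e.g. 768

    # Image + text request: the image must be sent base64-encoded.
    with open("example.jpg", "rb") as f:  # hypothetical test image
        b64_image = base64.b64encode(f.read()).decode("utf-8")
    grounded = handler(
        {"inputs": {"text": "a dog playing fetch", "has_image": True, "image": b64_image}}
    )
    print(len(grounded["feature_vector"][0]))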