|
|
import sys |
|
|
|
|
|
sys.path.append("/repository/BLIP") |
|
|
sys.path.append("/repository") |
|
|
sys.path.append("BLIP") |
|
|
|
|
|
from typing import Dict, List, Any |
|
|
from PIL import Image |
|
|
import requests |
|
|
import torch |
|
|
import base64 |
|
|
import os |
|
|
from io import BytesIO |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
|
|
from torch import nn |
|
|
from torch.nn import CrossEntropyLoss |
|
|
import torch.nn.functional as F |
|
|
import torch |
|
|
import numpy as np |
|
|
from BLIP.models.blip import BLIP_Base,load_checkpoint |
|
|
|
|
|
# Pick GPU when available; the BLIP model and all input tensors are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
class EndpointHandler():
    """Inference-endpoint handler that embeds an (image, text) pair with a BLIP base model.

    Loads a BLIP base encoder from a checkpoint on construction; each call
    returns a multimodal feature vector for the request's text and (optional)
    base64-encoded image.
    """

    def __init__(self, pretrain_path="/repository/blip_base.pth"):
        """Build the BLIP encoder and the image preprocessing pipeline.

        Args:
            pretrain_path: path to the BLIP checkpoint to load.
        """
        # 224x224 matches the resolution the base ViT backbone expects.
        image_size = 224

        self.blip_encoder_path = pretrain_path
        self.blip_encoder = BLIP_Base(
            image_size=image_size,
            vit='base',
            med_config='/repository/configs/med_config.json'
        )
        # BUGFIX: previously loaded the hard-coded "/repository/blip_base.pth",
        # silently ignoring the `pretrain_path` argument. Default is unchanged,
        # so existing callers see identical behavior.
        load_checkpoint(self.blip_encoder, pretrain_path)
        self.blip_encoder.eval()
        self.blip_encoder = self.blip_encoder.to(device)

        # CLIP-style normalization constants used by BLIP preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))
        ])

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """
        Args:
            data (:obj:):
                includes the input data and the parameters for the inference.
        Return:
            A :obj:`dict`:. The object returned should be a dict like {"feature_vector": [0.6331314444541931,0.8802216053009033,...,-0.7866355180740356,]} containing :
            - "feature_vector": A list of floats corresponding to the image embedding.
        """
        inputs = data.pop("inputs", data)
        # NOTE(review): popped for its side effect of removing the key from
        # `data`; the value itself is currently unused.
        parameters = data.pop("parameters", {"mode": "image"})

        text = [inputs["text"]]
        has_image = inputs["has_image"]

        # Decode the request image, or substitute a blank RGB canvas when the
        # request is text-only (its features get overwritten below).
        image = Image.open(BytesIO(base64.b64decode(inputs['image']))) if has_image else Image.new("RGB", (224, 224))
        image = self.transform(image).unsqueeze(0).to(device)

        # Inference only: no_grad avoids building autograd graphs per request
        # and makes the in-place write into `image_embeds` below safe.
        with torch.no_grad():
            image_embeds = self.blip_encoder.visual_encoder(image)

            if not has_image:
                # Add a second, artificially long sentence so that
                # padding="longest" pads the real sentence up to the visual
                # sequence length. Assumes each " [PAD]" chunk tokenizes to one
                # token, plus two special tokens -- TODO confirm with tokenizer.
                target_len = image_embeds.shape[1]
                pad = " [PAD]" * (target_len - 2)
                text = text + [pad]

            text = self.blip_encoder.tokenizer(tuple(text), return_tensors="pt", padding="longest").to(device)

            # Keep only the first (real) sentence; the pad sentence existed
            # solely to force the padded length.
            text_input_ids = text.input_ids[:1]
            text_attention_mask = text.attention_mask[:1]

            if not has_image:
                # Text-only request: replace the blank-image features with
                # text-only encoder features so cross-attention attends to the
                # text's own representation instead of an empty canvas.
                text_embeds = self.blip_encoder.text_encoder(text_input_ids, attention_mask=text_attention_mask,
                                                             return_dict=True, mode='text')
                image_embeds[0] = text_embeds.last_hidden_state[0]

            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            # Swap the leading token for BLIP's encoder token to select the
            # multimodal mode. text_input_ids is a view of text.input_ids, so
            # this in-place write is visible through the slice.
            text.input_ids[:, 0] = self.blip_encoder.tokenizer.enc_token_id

            output = self.blip_encoder.text_encoder(text_input_ids,
                                                    attention_mask=text_attention_mask,
                                                    encoder_hidden_states=image_embeds,
                                                    encoder_attention_mask=image_atts,
                                                    return_dict=True
                                                    )

        # Multimodal embedding = hidden state at the first (encoder-token) position.
        hidden_state = output.last_hidden_state
        hidden_state = hidden_state[:, 0, :].detach()

        return {"feature_vector": hidden_state.tolist()}