import sys
sys.path.append("/repository/BLIP")
sys.path.append("/repository")
sys.path.append("BLIP")
from typing import Dict, List, Any
from PIL import Image
import requests
import torch
import base64
import os
from io import BytesIO
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers.modeling_utils import PreTrainedModel
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch
import numpy as np
from BLIP.models.blip import BLIP_Base,load_checkpoint
# Run on GPU when available; all tensors/models below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class EndpointHandler():
    """Custom inference handler that embeds an (image, text) pair with a BLIP encoder.

    The handler returns an image-grounded text embedding: the text is encoded by
    BLIP's text encoder with cross-attention over the visual-encoder output. When
    the request carries no image, the text's own (text-mode) embedding is injected
    in place of the image embedding so the cross-attention still has a target.
    """

    def __init__(self, pretrain_path="/repository/blip_base.pth"):
        """Load the BLIP base encoder and build the image preprocessing pipeline.

        Args:
            pretrain_path: filesystem path of the BLIP checkpoint to load.
        """
        image_size = 224
        self.blip_encoder_path = pretrain_path
        self.blip_encoder = BLIP_Base(
            image_size=image_size,
            vit='base',
            med_config='/repository/configs/med_config.json'
        )
        # BUG FIX: previously the checkpoint path was hardcoded here, so the
        # `pretrain_path` argument was silently ignored. Load from the stored path
        # (default value is unchanged, so existing callers behave identically).
        load_checkpoint(self.blip_encoder, self.blip_encoder_path)
        self.blip_encoder.eval()
        self.blip_encoder = self.blip_encoder.to(device)
        # CLIP-style normalization constants, as used by BLIP's preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """
        Args:
            data (:obj:):
                includes the input data and the parameters for the inference.
                Expected shape: {"inputs": {"text": str, "has_image": bool,
                "image": base64-encoded bytes (required when has_image)}}.
        Return:
            A :obj:`dict`:. The object returned should be a dict like {"feature_vector": [0.6331314444541931,0.8802216053009033,...,-0.7866355180740356,]} containing :
            - "feature_vector": A list of floats corresponding to the image embedding.
        """
        inputs = data.pop("inputs", data)
        # NOTE(review): `parameters` is popped (mutating `data`) but never used;
        # kept for request-schema compatibility.
        parameters = data.pop("parameters", {"mode": "image"})
        # Get text
        text = [inputs["text"]]
        has_image = inputs["has_image"]
        # Decode base64 image to PIL and transform; fall back to a black 224x224
        # placeholder when the request has no image.
        image = Image.open(BytesIO(base64.b64decode(inputs['image']))) if has_image else Image.new("RGB", (224, 224))
        image = self.transform(image).unsqueeze(0).to(device)
        # Inference only: no autograd graph needed (output is detached anyway).
        with torch.no_grad():
            # Get image embeddings
            image_embeds = self.blip_encoder.visual_encoder(image)
            # Add a fake pad text to the batch so padding="longest" stretches the
            # real text's token sequence to the image-embedding length when there
            # is no associated image.
            if not has_image:
                target_len = image_embeds.shape[1]
                # subtract 2 from target since tokenizer will add a start and end token
                pad = " [PAD]" * (target_len - 2)
                text = text + [pad]
            text = self.blip_encoder.tokenizer(tuple(text), return_tensors="pt", padding="longest").to(device)
            # [:1] keeps only the real text; NOTE these slices are *views* of
            # text.input_ids / text.attention_mask, so the [ENC] token swap below
            # is visible through them.
            text_input_ids = text.input_ids[:1]
            text_attention_mask = text.attention_mask[:1]
            # Inject text embedding as image embedding if no associated image
            if not has_image:
                text_embeds = self.blip_encoder.text_encoder(text_input_ids, attention_mask=text_attention_mask,
                                                             return_dict=True, mode='text')
                image_embeds[0] = text_embeds.last_hidden_state[0]
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
            # Switch out start token with encode [ENC] token (propagates into
            # text_input_ids via the view aliasing noted above).
            text.input_ids[:, 0] = self.blip_encoder.tokenizer.enc_token_id
            # Get image-grounded text embedding (text encoder cross-attends over
            # the image embeddings).
            output = self.blip_encoder.text_encoder(text_input_ids,
                                                    attention_mask=text_attention_mask,
                                                    encoder_hidden_states=image_embeds,
                                                    encoder_attention_mask=image_atts,
                                                    return_dict=True
                                                    )
        # [CLS]-position hidden state is the pair embedding.
        hidden_state = output.last_hidden_state
        hidden_state = hidden_state[:, 0, :].detach()
        return {"feature_vector": hidden_state.tolist()}