File size: 4,298 Bytes
68810f0
 
ae17ecb
b1e0003
68810f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb5dbe5
68810f0
 
 
 
 
 
96d308a
68810f0
0616b25
68810f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a20c12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys

sys.path.append("/repository/BLIP")
sys.path.append("/repository")
sys.path.append("BLIP")

from typing import  Dict, List, Any
from PIL import Image
import requests
import torch
import base64
import os
from io import BytesIO
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers.modeling_utils import PreTrainedModel

from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch
import numpy as np
from BLIP.models.blip import BLIP_Base,load_checkpoint

# Select the compute device once at import time; the model and every input
# tensor below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler():
    """Serve BLIP multimodal embeddings.

    Encodes an (image, text) request into a single feature vector with BLIP's
    image-grounded text encoder. When the request carries no image, the text
    is re-encoded on its own and injected in place of the visual embedding so
    the cross-attention still receives meaningful context.
    """

    def __init__(self, pretrain_path="/repository/blip_base.pth"):
        """Load the BLIP base encoder and build the image preprocessing pipeline.

        Args:
            pretrain_path: filesystem path to the BLIP checkpoint.
        """
        image_size = 224
        self.blip_encoder_path = pretrain_path
        self.blip_encoder = BLIP_Base(
            image_size=image_size,
            vit='base',
            med_config='/repository/configs/med_config.json'
        )
        # Fix: honor the pretrain_path argument instead of a hard-coded path
        # (previously the parameter was silently ignored).
        load_checkpoint(self.blip_encoder, pretrain_path)
        self.blip_encoder.eval()
        self.blip_encoder = self.blip_encoder.to(device)

        # CLIP-style normalization constants matching BLIP's pretraining.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))
        ])

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """Embed one (text, optional image) pair.

        Args:
            data: request payload; expects ``data["inputs"]`` (or ``data``
                itself) to carry ``"text"``, ``"has_image"`` and, when
                ``has_image`` is truthy, a base64-encoded ``"image"``.
        Return:
            A dict like ``{"feature_vector": [[...]]}`` where the value is the
            [CLS] embedding of the image-grounded text encoder, as nested
            Python lists of floats.
        """
        inputs = data.pop("inputs", data)
        # Popped for request-format compatibility; the mode parameter is
        # currently unused by this handler.
        data.pop("parameters", {"mode": "image"})

        text = [inputs["text"]]
        has_image = inputs["has_image"]

        # Decode the base64 image, or use a blank canvas when absent.
        # .convert("RGB") guards against RGBA/palette/grayscale uploads that
        # would otherwise break the 3-channel Normalize above.
        if has_image:
            image = Image.open(BytesIO(base64.b64decode(inputs['image']))).convert("RGB")
        else:
            image = Image.new("RGB", (224, 224))
        image = self.transform(image).unsqueeze(0).to(device)

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            image_embeds = self.blip_encoder.visual_encoder(image)

            # With no real image, append a [PAD]-filled dummy sentence so
            # padding="longest" stretches the batch to the visual sequence
            # length; the text embedding can then stand in for the image one.
            if not has_image:
                target_len = image_embeds.shape[1]
                # Subtract 2 because the tokenizer adds start and end tokens.
                text = text + [" [PAD]" * (target_len - 2)]

            text = self.blip_encoder.tokenizer(
                tuple(text), return_tensors="pt", padding="longest"
            ).to(device)

            # Keep only the first row; the optional dummy row existed solely
            # to force the padding length.
            text_input_ids = text.input_ids[:1]
            text_attention_mask = text.attention_mask[:1]

            # No image: substitute the pure-text encoding as "visual" context.
            if not has_image:
                text_embeds = self.blip_encoder.text_encoder(
                    text_input_ids,
                    attention_mask=text_attention_mask,
                    return_dict=True,
                    mode='text'
                )
                image_embeds[0] = text_embeds.last_hidden_state[0]

            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            # Swap the start token for the [ENC] token. text_input_ids is a
            # view of text.input_ids, so this in-place write is visible to the
            # slice passed below.
            text.input_ids[:, 0] = self.blip_encoder.tokenizer.enc_token_id

            # Image-grounded text embedding via cross-attention over the
            # visual (or substituted textual) context.
            output = self.blip_encoder.text_encoder(
                text_input_ids,
                attention_mask=text_attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True
            )

        # The [CLS] position of the last hidden state is the pooled feature.
        hidden_state = output.last_hidden_state[:, 0, :].detach()
        return {"feature_vector": hidden_state.tolist()}