File size: 1,271 Bytes
236a466
 
 
9ace47e
 
 
 
 
 
 
 
 
 
236a466
 
 
1775fa1
236a466
9ace47e
 
dba2b08
9ace47e
 
236a466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ace47e
 
 
236a466
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from typing import Dict, List, Any
from transformers import pipeline
import base64
from PIL import Image
import io

def base64_to_pil(base64_image):
    """Turn a base64-encoded image payload into a PIL image.

    Args:
        base64_image: Base64 data (text or bytes) holding an encoded image file.

    Returns:
        The image object produced by ``PIL.Image.open`` over an in-memory
        buffer of the decoded bytes.

    Raises:
        Whatever ``base64.b64decode`` or ``Image.open`` raise on malformed data.
    """
    raw_bytes = base64.b64decode(base64_image)
    return Image.open(io.BytesIO(raw_bytes))


# Choose the inference device once, at import time: the first CUDA GPU when
# PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def is_base64(s):
    """Return True if *s* is a canonical base64-encoded string.

    *s* is decoded strictly and the result re-encoded; only text that
    survives the round trip unchanged (standard alphabet, correct padding,
    no whitespace) is accepted.

    Args:
        s: Candidate payload. Anything that is not a ``str`` (bytes, PIL
            images, NumPy arrays, dicts, ``None``) is rejected up front.

    Returns:
        ``True`` for canonical base64 text, ``False`` otherwise. Never raises
        for the input kinds listed above.
    """
    # Only text qualifies. Without this guard a buffer-like object such as a
    # NumPy array is accepted by b64decode, and the final comparison becomes
    # ``str == ndarray``, which yields an element-wise array instead of a bool
    # and makes the caller's ``if`` raise.
    if not isinstance(s, str):
        return False
    try:
        decoded = base64.b64decode(s, validate=True)
    except ValueError:
        # binascii.Error (bad alphabet / padding) subclasses ValueError, and
        # non-ASCII text raises ValueError directly.
        return False
    # Reject non-canonical encodings that still decode (e.g. stray low bits).
    return base64.b64encode(decoded).decode('utf-8') == s

class EndpointHandler():
    """Image-captioning request handler built on a transformers pipeline.

    The BLIP captioning pipeline is loaded once at construction; each call
    then handles one request payload. (Presumably this follows the Hugging
    Face Inference Endpoints custom-handler convention -- confirm against
    the deployment.)
    """

    def __init__(self, path=""):
        """Load the captioning pipeline onto the module-level ``device``.

        Args:
            path: Accepted for interface compatibility but not used; the model
                is always ``Salesforce/blip-image-captioning-large``.
        """
        # Build the pipeline up front so each request only pays for inference.
        self.pipeline = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-large",
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Caption the image carried in a request payload.

        Args:
            data: Request dict. ``data["inputs"]`` holds the image: either a
                base64-encoded string, which is decoded to a PIL image here, or
                another input passed to the pipeline unchanged (the contract
                lists ``str``, ``PIL.Image`` and ``np.array``). If the key is
                missing, the dict itself is used as the input.
                ``data["parameters"]``, if present, is a dict of keyword
                arguments forwarded to the pipeline call.
                Both keys are popped, so the caller's dict is mutated.

        Returns:
            Whatever the pipeline returns; the caller is expected to
            serialize it.
        """
        inputs = data.pop("inputs", data)
        # Optional per-request pipeline kwargs; None or {} means no extras,
        # which keeps the call identical to a plain ``self.pipeline(inputs)``.
        parameters = data.pop("parameters", None) or {}

        # Raw base64 text is decoded here; every other input goes through as-is.
        if is_base64(inputs):
            inputs = base64_to_pil(inputs)

        return self.pipeline(inputs, **parameters)