File size: 2,266 Bytes
6dd6674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c5888e
 
 
 
 
 
 
 
 
 
 
6dd6674
 
9c5888e
b15ede3
9b13dd8
b15ede3
6dd6674
 
b15ede3
6dd6674
930edb0
 
 
6dd6674
9c5888e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Dict, Any, List
from transformers import AutoProcessor, BlipForConditionalGeneration
from PIL import Image
import torch
import io
import base64

class EndpointHandler:
    """Inference Endpoints handler that captions a base64-encoded image with BLIP."""

    def __init__(
        self,
        path: str = "",
        processor_path: str = "Salesforce/blip-image-captioning-base",
    ):
        """Load the BLIP captioning model and its processor.

        Args:
            path: Location the model weights are loaded from
                (passed to ``BlipForConditionalGeneration.from_pretrained``).
            processor_path: Location the processor is loaded from. Defaults to
                the hub checkpoint this handler has always used.
                NOTE(review): if ``path`` ships its own processor files,
                passing ``path`` here keeps model and processor in sync.
        """
        # use_fast=False requests the non-"fast" processor implementation.
        self.processor = AutoProcessor.from_pretrained(processor_path, use_fast=False)
        self.model = BlipForConditionalGeneration.from_pretrained(path)
        # Inference only: switch the model out of training mode.
        self.model.eval()

        # Keyword arguments passed to model.generate(). Requests may override
        # any of these keys (and only these keys) via "generation_args".
        self.default_args = {
            "max_new_tokens": 30,
            "temperature": 0.4,
            "do_sample": True,
            "top_k": 40,
            "top_p": 0.4,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a caption for a single image.

        Args:
            data: Request payload with keys:
                - "inputs": the image as base64 (an optional
                  ``data:<mime>;base64,`` prefix is accepted), or the literal
                  string "wake" for a warm-up response that skips the model.
                - "generation_args" (optional): mapping overriding keys of
                  ``self.default_args``. Unknown keys and ``None`` values are
                  ignored, and ``null`` means "no overrides".

        Returns:
            A one-element list holding ``{"generated_caption": str}``,
            ``{"status": "woken"}`` or ``{"error": str}``. Failures are
            reported in the response rather than raised.
        """
        image_data = data.get("inputs")
        if image_data is None:
            return [{"error": "Missing 'inputs' key"}]
        # Wake-up request: lets a client start the server without inference.
        if image_data == "wake":
            return [{"status": "woken"}]

        # A JSON null here previously crashed with TypeError; treat it as empty.
        overrides = data.get("generation_args") or {}
        if not isinstance(overrides, dict):
            return [{"error": "'generation_args' must be an object"}]
        generation_args = dict(self.default_args)
        for key in self.default_args:
            value = overrides.get(key)
            if value is not None:
                generation_args[key] = value

        # Drop a data-URL header. Otherwise b64decode silently discards ':' ';' ','
        # and decodes the header's letters as image bytes.
        if isinstance(image_data, str) and image_data.startswith("data:"):
            image_data = image_data.partition(",")[2]

        try:
            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
        except Exception as e:
            return [{"error": f"Image decoding failed: {e}"}]

        try:
            inputs = self.processor(image, return_tensors="pt")
            with torch.no_grad():
                output_tokens = self.model.generate(**inputs, **generation_args)
            caption = self.processor.decode(output_tokens[0], skip_special_tokens=True)
        except Exception as e:
            return [{"error": f"Inference failed: {e}"}]
        return [{"generated_caption": caption}]