File size: 1,669 Bytes
192a99b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import io
from typing import Dict, List, Any

import requests
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

class EndpointHandler():
    """Inference handler that captions an image with a BLIP captioning model.

    The processor and model are loaded once from ``path`` at construction
    time; each call downloads one image by URL and returns one caption.
    """

    # Seconds to wait on the image server before giving up, so a stalled URL
    # cannot block the handler indefinitely. Override on the class/instance
    # if larger images or slower hosts need more time.
    request_timeout = 10

    def __init__(self, path="./"):
        """Load the BLIP processor and captioning model.

        Args:
            path: Model location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.processor = BlipProcessor.from_pretrained(path)
        # Run on the GPU when one is visible, otherwise fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = BlipForConditionalGeneration.from_pretrained(path).to(device)

    def _load_image(self, image_url: str) -> "Image.Image":
        """Download ``image_url`` and decode it as an RGB PIL image.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
            requests.Timeout: if the server does not respond within
                ``request_timeout`` seconds.
            PIL.UnidentifiedImageError: if the body is not a decodable image.
        """
        # The context manager releases the connection even when decoding fails.
        with requests.get(image_url, timeout=self.request_timeout) as response:
            # Fail loudly on error pages instead of handing HTML to PIL.
            response.raise_for_status()
            # ``content`` is already content-decoded (gzip/deflate), unlike
            # ``response.raw``; buffering it also gives PIL a seekable file.
            image = Image.open(io.BytesIO(response.content))
            # convert() forces the lazy decode while the buffer is still alive.
            return image.convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            image_url (:obj:`str`): URL of the image to caption (required)
            prompt (:obj:`str`, optional): Text prompt for conditional
                captioning; empty or missing means unconditional captioning
            parameters (:obj:`dict`, optional): Extra keyword arguments
                forwarded to ``model.generate`` (e.g. ``max_new_tokens``)
        Return:
            A :obj:`list` holding one ``{"caption": str}`` dict
        Raises:
            ValueError: if ``image_url`` is missing/empty/not a string, or
                ``parameters`` is not a dict
        """
        # Validate the request before doing any network or model work.
        image_url = data.get("image_url")
        if not isinstance(image_url, str) or not image_url:
            raise ValueError("'image_url' must be a non-empty string")
        prompt = data.get("prompt", "")
        generate_kwargs = data.get("parameters") or {}
        if not isinstance(generate_kwargs, dict):
            raise ValueError("'parameters' must be a dict of generate() keyword arguments")

        image = self._load_image(image_url)

        # Conditional captioning when a prompt is given, unconditional otherwise.
        if prompt:
            inputs = self.processor(image, prompt, return_tensors="pt")
        else:
            inputs = self.processor(image, return_tensors="pt")
        inputs = inputs.to(self.model.device)

        out = self.model.generate(**inputs, **generate_kwargs)
        # A single image was submitted, so only the first sequence is decoded.
        caption = self.processor.decode(out[0], skip_special_tokens=True)

        return [{"caption": caption}]