SaiSriTejaKuppa committed on
Commit
ae1956b
·
verified ·
1 Parent(s): 28a75c8

v2 version of handler.py

Browse files

Added the info about the inputs as additional parameters

Files changed (2) hide show
  1. handler.py +18 -15
  2. test_handler.py +7 -3
handler.py CHANGED
@@ -10,32 +10,35 @@ class EndpointHandler():
10
  self.processor = BlipProcessor.from_pretrained(path)
11
  self.model = BlipForConditionalGeneration.from_pretrained(path).to("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
14
  """
15
- data args:
16
- image_url (:obj: `str`): URL of the image to caption
17
- prompt (:obj: `str`, optional): Text prompt for conditional captioning
18
  Return:
19
- A :obj:`list` with caption as `dict`
 
20
  """
21
- # Get inputs from the data
22
- image_url = data.get("image_url")
23
- prompt = data.get("prompt", "") # Optional prompt for conditional captioning
 
 
 
 
24
 
25
  # Load image from URL and ensure RGB format
26
  image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
27
 
28
- # Conditional or Unconditional Captioning
29
  if prompt:
30
- # Conditional captioning
31
- inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
32
  else:
33
- # Unconditional captioning
34
- inputs = self.processor(image, return_tensors="pt").to(self.model.device)
35
 
36
  # Generate caption
37
- out = self.model.generate(**inputs)
38
  caption = self.processor.decode(out[0], skip_special_tokens=True)
39
 
40
  # Return the generated caption
41
- return [{"caption": caption}]
 
10
  self.processor = BlipProcessor.from_pretrained(path)
11
  self.model = BlipForConditionalGeneration.from_pretrained(path).to("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ def __call__(self, data: Any) -> Dict[str, str]:
14
  """
15
+ Args:
16
+ data (:obj:):
17
+ includes the input data and the parameters for the inference.
18
  Return:
19
+ A :obj:`dict`:. The object returned should be a dict like {"caption": "Generated caption for the image"} containing:
20
+ - "caption": The generated caption as a string.
21
  """
22
+ # Extract inputs and parameters
23
+ inputs = data.pop("inputs", data)
24
+ parameters = data.pop("parameters", {"mode": "image"})
25
+
26
+ # Get image URL and prompt from the inputs
27
+ image_url = inputs.get("image_url")
28
+ prompt = inputs.get("prompt", "") # Optional prompt for conditional captioning
29
 
30
  # Load image from URL and ensure RGB format
31
  image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
32
 
33
+ # Process inputs with or without a prompt
34
  if prompt:
35
+ processed_inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
 
36
  else:
37
+ processed_inputs = self.processor(image, return_tensors="pt").to(self.model.device)
 
38
 
39
  # Generate caption
40
+ out = self.model.generate(**processed_inputs)
41
  caption = self.processor.decode(out[0], skip_special_tokens=True)
42
 
43
  # Return the generated caption
44
+ return {"caption": caption}
test_handler.py CHANGED
@@ -6,13 +6,17 @@ my_handler = EndpointHandler()
6
 
7
  # Sample payload for conditional captioning
8
  conditional_payload = {
9
- "image_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg",
10
- "prompt": "a photography of"
 
 
11
  }
12
 
13
  # Sample payload for unconditional captioning
14
  unconditional_payload = {
15
- "image_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
 
 
16
  }
17
 
18
  # Run the handler for both cases and print the outputs
 
6
 
7
  # Sample payload for conditional captioning
8
  conditional_payload = {
9
+ "inputs": {
10
+ "image_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg",
11
+ "prompt": "a photography of"
12
+ }
13
  }
14
 
15
  # Sample payload for unconditional captioning
16
  unconditional_payload = {
17
+ "inputs": {
18
+ "image_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
19
+ }
20
  }
21
 
22
  # Run the handler for both cases and print the outputs