jacstrong committed on
Commit
dabc20d
·
verified ·
1 Parent(s): 01d3c03

Update handler.py to accept a single image. The image will be provided via a web URL.

Browse files
Files changed (1) hide show
  1. handler.py +32 -24
handler.py CHANGED
@@ -1,7 +1,7 @@
1
- from typing import Dict, List, Any
2
  from PIL import Image
3
  import torch
4
- import os
5
  from io import BytesIO
6
  from transformers import BlipForConditionalGeneration, BlipProcessor
7
 
@@ -9,40 +9,48 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
- # load the optimized model
13
-
14
- self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
15
  self.model = BlipForConditionalGeneration.from_pretrained(
16
  "Salesforce/blip-image-captioning-base"
17
  ).to(device)
18
  self.model.eval()
19
- self.model = self.model.to(device)
20
-
21
-
22
 
23
  def __call__(self, data: Any) -> Dict[str, Any]:
24
  """
25
  Args:
26
- data (:obj:):
27
- includes the input data and the parameters for the inference.
28
  Return:
29
- A :obj:`dict`:. The object returned should be a dict of one list like {"captions": ["A hugging face at the office"]} containing :
30
  - "caption": A string corresponding to the generated caption.
31
  """
32
- inputs = data.pop("inputs", data)
33
- parameters = data.pop("parameters", {})
34
-
35
- raw_images = [Image.open(BytesIO(_img)) for _img in inputs]
36
-
37
- processed_image = self.processor(images=raw_images, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
38
  processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
 
 
39
  processed_image = {**processed_image, **parameters}
40
-
41
  with torch.no_grad():
42
- out = self.model.generate(
43
- **processed_image
44
- )
45
- captions = self.processor.batch_decode(out, skip_special_tokens=True)
46
- # postprocess the prediction
47
- return {"captions": captions}
48
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
  from PIL import Image
3
  import torch
4
+ import requests
5
  from io import BytesIO
6
  from transformers import BlipForConditionalGeneration, BlipProcessor
7
 
 
9
 
10
class EndpointHandler():
    def __init__(self, path=""):
        # Load the BLIP processor and captioning model once at startup;
        # the model is moved to the module-level `device` (GPU if available).
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        ).to(device)
        self.model.eval()

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`):
                Request payload. Expected keys:
                  - "image": URL of the image to caption (required).
                  - "parameters": optional dict of generation kwargs
                    forwarded to ``model.generate``.
        Return:
            A :obj:`dict` containing either:
              - "caption": the generated caption string, or
              - "error": a message describing what went wrong
                (missing URL or failed download).
        """
        # Extract the image URL and optional generation parameters.
        image_url = data.get("image")
        parameters = data.get("parameters", {})

        if not image_url:
            return {"error": "Missing 'image' field in request body."}

        try:
            # Download the image. The timeout prevents the endpoint from
            # hanging indefinitely on a slow or unresponsive host
            # (requests has no default timeout).
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            raw_image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Boundary handler: report the failure to the caller as data
            # rather than crashing the endpoint.
            return {"error": f"Failed to fetch image from URL: {str(e)}"}

        # Preprocess the image into model-ready tensors.
        processed_image = self.processor(images=raw_image, return_tensors="pt")
        processed_image["pixel_values"] = processed_image["pixel_values"].to(device)

        # Merge caller-supplied generation parameters into the kwargs.
        processed_image = {**processed_image, **parameters}

        with torch.no_grad():
            out = self.model.generate(**processed_image)

        # Decode the first (only) generated sequence into text.
        caption = self.processor.decode(out[0], skip_special_tokens=True)

        return {"caption": caption}