Hitesh Satwani commited on
Commit
140c70d
·
1 Parent(s): 6544ea3

updated handler

Browse files
Files changed (1) hide show
  1. handler.py +13 -8
handler.py CHANGED
@@ -1,29 +1,34 @@
1
  import io
2
  import torch
 
3
  from PIL import Image
4
  from transformers import BlipProcessor, BlipForConditionalGeneration
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
8
- # Load the processor and model from the specified directory
9
  self.processor = BlipProcessor.from_pretrained(model_dir)
10
  self.model = BlipForConditionalGeneration.from_pretrained(model_dir)
11
  self.model.eval()
12
- # Move model to GPU if available
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  self.model.to(self.device)
15
 
16
  def __call__(self, data):
17
- # Read the image from the incoming request
18
- image_bytes = data
19
- image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
 
 
 
 
 
 
 
 
20
 
21
- # Preprocess the image
22
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
23
 
24
- # Generate caption
25
  with torch.no_grad():
26
  out = self.model.generate(**inputs)
27
- caption = self.processor.decode(out[0], skip_special_tokens=True)
28
 
 
29
  return {"caption": caption}
 
1
  import io
2
  import torch
3
+ import requests
4
  from PIL import Image
5
  from transformers import BlipProcessor, BlipForConditionalGeneration
6
 
7
  class EndpointHandler:
8
  def __init__(self, model_dir):
 
9
  self.processor = BlipProcessor.from_pretrained(model_dir)
10
  self.model = BlipForConditionalGeneration.from_pretrained(model_dir)
11
  self.model.eval()
 
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  self.model.to(self.device)
14
 
15
  def __call__(self, data):
16
+ # Expecting: { "image_url": "https://example.com/image.jpg" }
17
+ if isinstance(data, dict) and "image_url" in data:
18
+ image_url = data["image_url"]
19
+ try:
20
+ response = requests.get(image_url)
21
+ response.raise_for_status()
22
+ image = Image.open(io.BytesIO(response.content)).convert("RGB")
23
+ except Exception as e:
24
+ return {"error": f"Failed to load image from URL: {str(e)}"}
25
+ else:
26
+ return {"error": "Please provide an 'image_url' in the JSON payload."}
27
 
 
28
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
29
 
 
30
  with torch.no_grad():
31
  out = self.model.generate(**inputs)
 
32
 
33
+ caption = self.processor.decode(out[0], skip_special_tokens=True)
34
  return {"caption": caption}