ckandemir committed
Commit 57aeb22 · 1 Parent(s): f7500e2

Update handler.py

Files changed (1)
  1. handler.py +27 -28
handler.py CHANGED
@@ -8,7 +8,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 class EndpointHandler():
     def __init__(self, path=""):
-        # load the optimized model
         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
         self.model = BlipForConditionalGeneration.from_pretrained(
             "Salesforce/blip-image-captioning-large"
@@ -17,31 +16,31 @@ class EndpointHandler():
         self.model = self.model.to(device)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Args:
-            data (dict):
-                Should contain:
-                    - 'images': List[bytes] of images.
-                    - 'texts': List[str] of associated texts. (Optional for unconditional captioning)
-        Return:
-            A dict with key "captions" and associated list of generated captions.
-        """
         images = data.get("images")
-        texts = data.get("texts", ["a photography of"] * len(images))  # Default to "a photography of" if not provided
-
-        raw_images = [Image.open(BytesIO(_img)).convert("RGB") for _img in images]
-
-        # Here, process both image and text
-        processed_inputs = [self.processor(img, txt, return_tensors="pt") for img, txt in zip(raw_images, texts)]
-        processed_inputs = {
-            "pixel_values": torch.cat([inp["pixel_values"] for inp in processed_inputs], dim=0).to(device),
-            "input_ids": torch.cat([inp["input_ids"] for inp in processed_inputs], dim=0).to(device),
-            "attention_mask": torch.cat([inp["attention_mask"] for inp in processed_inputs], dim=0).to(device)
-        }
-
-        with torch.no_grad():
-            out = self.model.generate(**processed_inputs)
-
-        captions = self.processor.batch_decode(out, skip_special_tokens=True)
-
-        return {"captions": captions}
+        # Check if images is None or empty and handle it appropriately
+        if not images:
+            return {"captions": []}
+
+        # Default to "a photography of" if texts not provided
+        texts = data.get("texts", ["a photography of"] * len(images))
+
+        try:
+            raw_images = [Image.open(BytesIO(_img)).convert("RGB") for _img in images]
+            processed_inputs = [
+                self.processor(img, txt, return_tensors="pt") for img, txt in zip(raw_images, texts)
+            ]
+            processed_inputs = {
+                "pixel_values": torch.cat([inp["pixel_values"] for inp in processed_inputs], dim=0).to(device),
+                "input_ids": torch.cat([inp["input_ids"] for inp in processed_inputs], dim=0).to(device),
+                "attention_mask": torch.cat([inp["attention_mask"] for inp in processed_inputs], dim=0).to(device)
+            }
+
+            with torch.no_grad():
+                out = self.model.generate(**processed_inputs)
+
+            captions = self.processor.batch_decode(out, skip_special_tokens=True)
+            return {"captions": captions}
+        except Exception as e:
+            # Handle or log the exception and optionally return an error message
+            print(f"Error during processing: {str(e)}")
+            return {"error": str(e)}
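For context, a minimal local smoke test of the three paths this commit changes: the new empty-input guard, the default-prompt captioning path, and the new try/except error path. This is a sketch, not part of the commit; it assumes handler.py's unshown imports (lines 1-7) are intact and that it runs from the repo root. The image path example.jpg is a placeholder.

from handler import EndpointHandler

handler = EndpointHandler()

# Empty payload: the new guard returns an empty caption list instead of raising.
print(handler({"images": []}))  # -> {"captions": []}

# One image captioned with the default prompt ("a photography of").
# "example.jpg" is a hypothetical local file, not part of the repository.
with open("example.jpg", "rb") as f:
    print(handler({"images": [f.read()]}))  # -> {"captions": ["a photography of ..."]}

# Malformed bytes: Image.open raises, the new except block reports the error.
print(handler({"images": [b"not an image"]}))  # -> {"error": "..."}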