nndly46 commited on
Commit
728da96
·
verified ·
1 Parent(s): 5878807

Create image_captioner.py

Browse files
Files changed (1) hide show
  1. image_captioner.py +67 -0
image_captioner.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ import requests
3
+
4
+ import logging
5
+
6
+ from haystack import Document, component
7
+ from haystack.lazy_imports import LazyImport
8
+ from PIL import Image
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
13
+ import torch
14
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, BlipProcessor, BlipForConditionalGeneration
15
+ from PIL import Image
16
+
17
+ @component
18
+ class ImageCaptioner:
19
+ def __init__(
20
+ self,
21
+ model_name: str = "Salesforce/blip-image-captioning-base",
22
+ ):
23
+ torch_and_transformers_import.check()
24
+ self.model_name = model_name
25
+
26
+ if model_name == "nlpconnect/vit-gpt2-image-captioning":
27
+ self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
28
+ self.feature_extractor = ViTImageProcessor.from_pretrained(model_name)
29
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+ max_length = 16
31
+ num_beams = 4
32
+ self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
33
+ else:
34
+ self.processor = BlipProcessor.from_pretrained(model_name)
35
+ self.model = BlipForConditionalGeneration.from_pretrained(model_name)
36
+
37
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ self.model.to(self.device)
39
+
40
+ @component.output_types(caption=str)
41
+ def run(self, image_file_path: str) -> List[Document]:
42
+
43
+ i_image = Image.open(image_file_path)
44
+ if i_image.mode != "RGB":
45
+ i_image = i_image.convert(mode="RGB")
46
+
47
+ preds = []
48
+ if self.model_name == "nlpconnect/vit-gpt2-image-captioning":
49
+ pixel_values = self.feature_extractor(images=[i_image], return_tensors="pt").pixel_values
50
+ pixel_values = pixel_values.to(self.device)
51
+
52
+ output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
53
+
54
+ preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
55
+ preds = [pred.strip() for pred in preds]
56
+ else:
57
+
58
+ inputs = self.processor([i_image], return_tensors="pt")
59
+ output_ids = self.model.generate(**inputs)
60
+ preds = self.processor.batch_decode(output_ids, skip_special_tokens=True)
61
+ preds = [pred.strip() for pred in preds]
62
+
63
+ # captions: List[Document] = []
64
+ # for caption, image_file_path in zip(preds, image_file_paths):
65
+ # document = Document(content=caption, meta={"image_path": image_file_path})
66
+ # captions.append(document)
67
+ return {"caption": preds[0]}