| """ | |
| Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py | |
| But accepts preloaded model to avoid slowness in use and CUDA forking issues | |
| Loader that loads image captions | |
| By default, the loader utilizes the pre-trained BLIP image captioning model. | |
| https://huggingface.co/Salesforce/blip-image-captioning-base | |
| """ | |
from importlib.metadata import distribution, PackageNotFoundError
from typing import List, Union, Any, Tuple

import requests
from langchain.docstore.document import Document
from langchain.document_loaders import ImageCaptionLoader

from utils import get_device, NullContext, clear_torch_cache

try:
    assert distribution('bitsandbytes') is not None
    have_bitsandbytes = True
except (PackageNotFoundError, AssertionError):
    have_bitsandbytes = False


class H2OImageCaptionLoader(ImageCaptionLoader):
    """Loader that loads the captions of an image"""

    def __init__(self, path_images: Union[str, List[str]] = None,
                 blip_processor: str = None,
                 blip_model: str = None,
                 caption_gpu=True,
                 load_in_8bit=True,
                 # True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8
                 load_half=False,
                 load_gptq='',
                 load_exllama=False,
                 use_safetensors=False,
                 revision=None,
                 min_new_tokens=20,
                 max_tokens=50):
        if blip_model is None or blip_processor is None:
            blip_processor = "Salesforce/blip-image-captioning-base"
            blip_model = "Salesforce/blip-image-captioning-base"
        super().__init__(path_images, blip_processor, blip_model)
        self.blip_processor = blip_processor
        self.blip_model = blip_model
        self.processor = None
        self.model = None
        self.caption_gpu = caption_gpu
        self.context_class = NullContext
        self.device = 'cpu'
        self.load_in_8bit = load_in_8bit and have_bitsandbytes  # only for blip2
        self.load_half = load_half
        self.load_gptq = load_gptq
        self.load_exllama = load_exllama
        self.use_safetensors = use_safetensors
        self.revision = revision
        self.gpu_id = 'auto'
        # default prompt
        self.prompt = "image of"
        self.min_new_tokens = min_new_tokens
        self.max_tokens = max_tokens

    def set_context(self):
        if get_device() == 'cuda' and self.caption_gpu:
            import torch
            n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
            if n_gpus > 0:
                self.context_class = torch.device
                self.device = 'cuda'
            else:
                self.device = 'cpu'
        else:
            self.device = 'cpu'

    def load_model(self):
        try:
            import transformers
        except ImportError:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            )
        self.set_context()
        if self.model:
            # compare device types: model.device is a torch.device, self.device is a str
            if not self.load_in_8bit and self.model.device.type != self.device:
                self.model.to(self.device)
            return self
        if self.caption_gpu:
            if self.gpu_id == 'auto':
                # blip2 has issues with multi-GPU. Error says need to somehow set language model in device map
                # device_map = 'auto'
                device_map = {"": 0}
            else:
                if self.device == 'cuda':
                    device_map = {"": self.gpu_id}
                else:
                    device_map = {"": 'cpu'}
        else:
            device_map = {"": 'cpu'}
        import torch
        with torch.no_grad():
            with self.context_class(self.device):
                context_class_cast = NullContext if self.device == 'cpu' else torch.autocast
                with context_class_cast(self.device):
                    if 'blip2' in self.blip_processor.lower():
                        from transformers import Blip2Processor, Blip2ForConditionalGeneration
                        if self.load_half and not self.load_in_8bit:
                            # the processor holds no weights, so only the model is cast to half precision
                            self.processor = Blip2Processor.from_pretrained(self.blip_processor,
                                                                            device_map=device_map)
                            self.model = Blip2ForConditionalGeneration.from_pretrained(self.blip_model,
                                                                                       device_map=device_map).half()
                        else:
                            self.processor = Blip2Processor.from_pretrained(self.blip_processor,
                                                                            load_in_8bit=self.load_in_8bit,
                                                                            device_map=device_map,
                                                                            )
                            self.model = Blip2ForConditionalGeneration.from_pretrained(self.blip_model,
                                                                                       load_in_8bit=self.load_in_8bit,
                                                                                       device_map=device_map)
                    else:
                        from transformers import BlipForConditionalGeneration, BlipProcessor
                        self.load_half = False  # not supported
                        if self.caption_gpu:
                            if device_map == 'auto':
                                # Blip doesn't support device_map='auto'
                                if self.device == 'cuda':
                                    if self.gpu_id == 'auto':
                                        device_map = {"": 0}
                                    else:
                                        device_map = {"": self.gpu_id}
                                else:
                                    device_map = {"": 'cpu'}
                        else:
                            device_map = {"": 'cpu'}
                        self.processor = BlipProcessor.from_pretrained(self.blip_processor, device_map=device_map)
                        self.model = BlipForConditionalGeneration.from_pretrained(self.blip_model,
                                                                                  device_map=device_map)
        return self

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Load from a list of image files
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        if self.processor is None or self.model is None:
            self.load_model()
        results = []
        for path_image in self.image_paths:
            caption, metadata = self._get_captions_and_metadata(
                model=self.model, processor=self.processor, path_image=path_image,
                prompt=prompt,
            )
            doc = Document(page_content=caption, metadata=metadata)
            results.append(doc)
        return results

    def unload_model(self):
        if hasattr(self, 'model') and hasattr(self.model, 'cpu'):
            self.model.cpu()
            clear_torch_cache()

    def _get_captions_and_metadata(
            self, model: Any, processor: Any, path_image: str,
            prompt=None) -> Tuple[str, dict]:
        """
        Helper function for getting the captions and metadata of an image
        """
        if prompt is None:
            prompt = self.prompt
        try:
            from PIL import Image
        except ImportError:
            raise ValueError(
                "`PIL` package not found, please install with `pip install pillow`"
            )
        try:
            if path_image.startswith("http://") or path_image.startswith("https://"):
                image = Image.open(requests.get(path_image, stream=True).raw).convert(
                    "RGB"
                )
            else:
                image = Image.open(path_image).convert("RGB")
        except Exception:
            raise ValueError(f"Could not get image data for {path_image}")
        import torch
        with torch.no_grad():
            with self.context_class(self.device):
                context_class_cast = NullContext if self.device == 'cpu' else torch.autocast
                with context_class_cast(self.device):
                    if self.load_half:
                        # FIXME: RuntimeError: "slow_conv2d_cpu" not implemented for 'Half'
                        inputs = processor(image, prompt, return_tensors="pt")  # .half()
                    else:
                        inputs = processor(image, prompt, return_tensors="pt")
                    # rough chars-to-tokens estimate so min_length covers the prompt plus min_new_tokens
                    min_length = len(prompt) // 4 + self.min_new_tokens
                    self.max_tokens = max(self.max_tokens, min_length)
                    output = model.generate(**inputs, min_length=min_length, max_length=self.max_tokens)
        caption: str = processor.decode(output[0], skip_special_tokens=True)
        # strip the prompt prefix from the decoded caption, if present
        prompti = caption.find(prompt)
        if prompti >= 0:
            caption = caption[prompti + len(prompt):]
        metadata: dict = {"image_path": path_image}
        return caption, metadata
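

# Minimal usage sketch, kept out of the class on purpose. The file name
# "example.jpg" and the settings chosen here are hypothetical; the class and
# methods are the ones defined above.
if __name__ == '__main__':
    loader = H2OImageCaptionLoader(caption_gpu=False)  # force CPU for portability
    loader.load_model()  # load the BLIP processor/model once; reused across load() calls
    loader.set_image_paths(['example.jpg'])  # hypothetical local image path
    for doc in loader.load(prompt="image of"):
        print(doc.metadata['image_path'], '->', doc.page_content)
    loader.unload_model()  # move the model back to CPU and clear the torch cache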