Spaces:
Paused
Paused
| """ | |
| Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py | |
| But accepts preloaded model to avoid slowness in use and CUDA forking issues | |
| Loader that uses Pix2Struct models to caption images | |
| """ | |
| from typing import List, Union, Any, Tuple | |
| from langchain.docstore.document import Document | |
| from langchain.document_loaders import ImageCaptionLoader | |
| from utils import get_device, clear_torch_cache | |
| from PIL import Image | |
class H2OPix2StructLoader(ImageCaptionLoader):
    """Loader that extracts text from images using a Pix2Struct model.

    The processor and model are created by ``load_model`` (or implicitly on
    the first ``load`` call) and kept on the instance so that they can be
    reused across calls and released again with ``unload_model``.
    """

    def __init__(self, path_images: Union[str, List[str], None] = None,
                 model_type="google/pix2struct-textcaps-base",
                 max_new_tokens=50):
        """Store the configuration; no model is loaded here.

        Args:
            path_images: A single image path, a list of image paths, or None.
                Forwarded unchanged to the parent constructor.
            model_type: Name passed to ``from_pretrained`` for both the
                processor and the model.
            max_new_tokens: Value passed as ``max_new_tokens`` to
                ``model.generate`` for each image.
        """
        super().__init__(path_images)
        self._pix2struct_model = None
        self._pix2struct_processor = None
        self._model_type = model_type
        self._max_new_tokens = max_new_tokens

    def set_context(self):
        """Choose the device used for the model and its inputs.

        Sets ``self.device`` to ``'cuda'`` when ``get_device()`` reports
        ``'cuda'`` and torch sees at least one GPU, otherwise to ``'cpu'``.
        When CUDA is chosen, ``self.context_class`` is set to ``torch.device``.
        """
        self.device = 'cpu'
        if get_device() != 'cuda':
            return

        import torch
        # is_available must be *called*: the bare function object is always
        # truthy, which would make this guard a no-op.
        n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        if n_gpus > 0:
            self.context_class = torch.device
            self.device = 'cuda'

    def load_model(self):
        """Ensure the Pix2Struct model is present and on the selected device.

        If a model is already attached to the instance it is only moved to
        ``self.device``; otherwise the processor and the model are created
        with ``from_pretrained(self._model_type)``.

        Returns:
            This loader, so that calls can be chained.

        Raises:
            ValueError: If the ``transformers`` package cannot be imported.
        """
        try:
            from transformers import AutoProcessor, Pix2StructForConditionalGeneration
        except ImportError as e:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from e

        if self._pix2struct_model is not None:
            # A model attached before set_context() ever ran would otherwise
            # hit an AttributeError on self.device.
            if not hasattr(self, 'device'):
                self.set_context()
            self._pix2struct_model = self._pix2struct_model.to(self.device)
            return self

        self.set_context()
        self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
        self._pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
            self._model_type).to(self.device)
        return self

    def unload_model(self):
        """Move the model to CPU (if it supports ``.cpu()``) and clear the torch cache."""
        if hasattr(self._pix2struct_model, 'cpu'):
            self._pix2struct_model.cpu()
            clear_torch_cache()

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Load from a list of image files

        A single string is wrapped in a one-element list; any other value is
        stored as given.
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        """Caption every image in ``self.image_paths``.

        The model is loaded first if none is attached yet.

        Args:
            prompt: Accepted but not used by this implementation.

        Returns:
            One ``Document`` per image, with the generated caption as
            ``page_content`` and ``{"image_path": ...}`` as metadata.
        """
        if self._pix2struct_model is None:
            self.load_model()
        results = []
        for path_image in self.image_paths:
            caption, metadata = self._get_captions_and_metadata(
                processor=self._pix2struct_processor,
                model=self._pix2struct_model,
                path_image=path_image,
            )
            results.append(Document(page_content=caption, metadata=metadata))
        return results

    def _get_captions_and_metadata(
            self, processor: Any, model: Any, path_image: str) -> Tuple[str, dict]:
        """
        Helper function for getting the captions and metadata of an image

        Args:
            processor: Processor used to encode the image and decode the
                generated ids; falls back to the instance's processor when None.
            model: Model used for generation; falls back to the instance's
                model when None.
            path_image: Path of the image file to caption.

        Returns:
            Tuple of the generated text and ``{"image_path": path_image}``.

        Raises:
            ValueError: If the image cannot be opened.
        """
        if processor is None:
            processor = self._pix2struct_processor
        if model is None:
            model = self._pix2struct_model

        try:
            image = Image.open(path_image)
        except Exception as e:
            raise ValueError(f"Could not get image data for {path_image}") from e

        # Close the underlying file as soon as the processor has consumed the
        # pixels, so handles are not leaked when captioning many images.
        with image:
            inputs = processor(images=image, return_tensors="pt")

        inputs = inputs.to(self.device)
        generated_ids = model.generate(**inputs, max_new_tokens=self._max_new_tokens)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        metadata: dict = {"image_path": path_image}
        return generated_text, metadata