| """ |
| Copyright (c) 2022, salesforce.com, inc. |
| All rights reserved. |
| SPDX-License-Identifier: BSD-3-Clause |
| For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| """ |
|
|
| import re |
|
|
| from .registry import registry |
| from .base_processor import BaseProcessor |
| from .randaugment import RandomAugment |
| from omegaconf import OmegaConf |
| from torchvision import transforms |
| from torchvision.transforms.functional import InterpolationMode |
|
|
|
|
class BlipImageBaseProcessor(BaseProcessor):
    """Base class for BLIP image processors; owns the shared normalization step.

    Subclasses call ``super().__init__(mean=..., std=...)`` and then append
    ``self.normalize`` to their ``transforms.Compose`` pipeline.
    """

    # Per-channel statistics used when the caller passes ``mean=None`` /
    # ``std=None``. These are the widely used ImageNet statistics.
    DEFAULT_MEAN = (0.485, 0.456, 0.406)
    DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(self, mean=None, std=None):
        """Build ``self.normalize`` from the given or default channel statistics.

        Args:
            mean: per-channel mean passed to ``transforms.Normalize``; ``None``
                selects ``DEFAULT_MEAN``.
            std: per-channel standard deviation passed to
                ``transforms.Normalize``; ``None`` selects ``DEFAULT_STD``.
        """
        # Explicit values from callers/configs are honored. Only a missing
        # value falls back to the class defaults.
        if mean is None:
            mean = self.DEFAULT_MEAN
        if std is None:
            std = self.DEFAULT_STD

        self.normalize = transforms.Normalize(mean, std)
|
|
|
|
@registry.register_processor("blip_caption")
class BlipCaptionProcessor(BaseProcessor):
    """Clean a free-text caption and prepend a fixed prompt to it.

    Args:
        prompt: string prepended verbatim to every cleaned caption.
        max_words: cleaned captions are cut down to at most this many
            space-separated tokens.
    """

    # Punctuation characters that are replaced by a single space.
    _PUNCT_PATTERN = re.compile(r"([.!\"()*#:;~])")
    # Any run of two or more whitespace characters is collapsed to one space.
    _MULTI_SPACE_PATTERN = re.compile(r"\s{2,}")

    def __init__(self, prompt="", max_words=50):
        self.prompt = prompt
        self.max_words = max_words

    def __call__(self, caption):
        """Return ``self.prompt`` followed by the cleaned ``caption``."""
        return self.prompt + self.pre_caption(caption)

    @classmethod
    def from_config(cls, cfg=None):
        """Create a processor from an OmegaConf config; missing keys use defaults."""
        config = OmegaConf.create() if cfg is None else cfg
        return cls(
            prompt=config.get("prompt", ""),
            max_words=config.get("max_words", 50),
        )

    def pre_caption(self, caption):
        """Lower-case, de-punctuate, collapse whitespace and truncate ``caption``.

        Note: a single non-space whitespace character (e.g. one ``"\\t"``)
        is neither collapsed nor treated as a word separator, because only
        runs of two or more are collapsed and words are split on ``" "``.
        """
        text = self._PUNCT_PATTERN.sub(" ", caption.lower())
        text = self._MULTI_SPACE_PATTERN.sub(" ", text)
        text = text.rstrip("\n").strip(" ")

        tokens = text.split(" ")
        if len(tokens) <= self.max_words:
            return text
        return " ".join(tokens[: self.max_words])
|
|
|
|
@registry.register_processor("blip2_image_train")
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    """Training image transform: bicubic resize to a square, ToTensor, normalize.

    ``min_scale`` and ``max_scale`` are accepted (and read by ``from_config``)
    but the pipeline does not use them: no random cropping or flipping is
    applied here.
    """

    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        super().__init__(mean=mean, std=std)

        # NOTE(review): min_scale / max_scale are unused below; presumably kept
        # so configs that set them still load. Confirm before removing.
        target_size = (image_size, image_size)
        pipeline = [
            transforms.Resize(target_size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Apply ``self.transform`` to ``item`` and return the result.

        Assumes ``item`` is an image type accepted by Resize/ToTensor
        (e.g. a PIL image). TODO confirm against callers.
        """
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Create a processor from an OmegaConf config; missing keys use defaults."""
        if cfg is None:
            cfg = OmegaConf.create()

        init_kwargs = {
            "image_size": cfg.get("image_size", 224),
            "mean": cfg.get("mean", None),
            "std": cfg.get("std", None),
            "min_scale": cfg.get("min_scale", 0.5),
            "max_scale": cfg.get("max_scale", 1.0),
        }
        return cls(**init_kwargs)
|
|
|
|
@registry.register_processor("blip2_image_eval")
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation image transform: bicubic resize to a square, ToTensor, normalize."""

    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        resize = transforms.Resize(
            (image_size, image_size), interpolation=InterpolationMode.BICUBIC
        )
        self.transform = transforms.Compose(
            [resize, transforms.ToTensor(), self.normalize]
        )

    def __call__(self, item):
        """Apply ``self.transform`` to ``item`` and return the result.

        Assumes ``item`` is an image type accepted by Resize/ToTensor
        (e.g. a PIL image). TODO confirm against callers.
        """
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Create a processor from an OmegaConf config; missing keys use defaults."""
        config = OmegaConf.create() if cfg is None else cfg

        image_size = config.get("image_size", 224)
        mean = config.get("mean", None)
        std = config.get("std", None)
        return cls(image_size=image_size, mean=mean, std=std)