| | """ |
| | Copyright (c) 2022, salesforce.com, inc. |
| | All rights reserved. |
| | SPDX-License-Identifier: BSD-3-Clause |
| | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| | """ |
| |
|
| | from omegaconf import OmegaConf |
| | from torchvision import transforms |
| | from torchvision.transforms.functional import InterpolationMode |
| |
|
| | from lavis.common.registry import registry |
| | from lavis.processors.base_processor import BaseProcessor |
| | from lavis.processors.blip_processors import BlipImageBaseProcessor |
| |
|
| |
|
@registry.register_processor("blip_diffusion_inp_image_train")
@registry.register_processor("blip_diffusion_inp_image_eval")
class BlipDiffusionInputImageProcessor(BlipImageBaseProcessor):
    """Image processor registered under both the train and eval input keys.

    The pipeline is: bicubic resize to ``image_size``, center crop to
    ``image_size``, conversion to a tensor, then ``self.normalize``.
    """

    def __init__(self, image_size=224, mean=None, std=None):
        """Assemble the transform pipeline.

        Args:
            image_size: Size passed to both the resize and the center crop.
            mean: Forwarded unchanged to the base-class initializer.
            std: Forwarded unchanged to the base-class initializer.
        """
        # NOTE(review): ``self.normalize`` is not defined in this class; it is
        # presumably created by the base initializer from mean/std - confirm.
        super().__init__(mean=mean, std=std)

        bicubic_resize = transforms.Resize(
            image_size, interpolation=InterpolationMode.BICUBIC
        )
        center_crop = transforms.CenterCrop(image_size)
        pipeline = [
            bicubic_resize,
            center_crop,
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Run ``item`` through the composed transform and return the result."""
        processed = self.transform(item)
        return processed

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config object.

        Args:
            cfg: Config exposing ``get``; an empty OmegaConf config is used
                when ``None``. Reads the keys ``image_size`` (default 224),
                ``mean`` and ``std`` (both default ``None``).

        Returns:
            A new instance of ``cls``.
        """
        cfg = OmegaConf.create() if cfg is None else cfg

        return cls(
            image_size=cfg.get("image_size", 224),
            mean=cfg.get("mean", None),
            std=cfg.get("std", None),
        )
| |
|
| |
|
@registry.register_processor("blip_diffusion_tgt_image_train")
class BlipDiffusionTargetImageProcessor(BaseProcessor):
    """Image processor registered under the target-image train key.

    The pipeline is: bicubic resize to ``image_size``, center crop to
    ``image_size``, conversion to a tensor, then normalization with a
    mean of 0.5 and a standard deviation of 0.5.
    """

    def __init__(self, image_size=512):
        """Assemble the transform pipeline.

        Args:
            image_size: Size passed to both the resize and the center crop.
        """
        super().__init__()

        bicubic_resize = transforms.Resize(
            image_size, interpolation=InterpolationMode.BICUBIC
        )
        center_crop = transforms.CenterCrop(image_size)
        pipeline = [
            bicubic_resize,
            center_crop,
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Run ``item`` through the composed transform and return the result."""
        processed = self.transform(item)
        return processed

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config object.

        Args:
            cfg: Config exposing ``get``; an empty OmegaConf config is used
                when ``None``. Reads the key ``image_size`` (default 512).

        Returns:
            A new instance of ``cls``.
        """
        cfg = OmegaConf.create() if cfg is None else cfg

        return cls(image_size=cfg.get("image_size", 512))
| |
|