| | """ |
| | Copyright (c) 2023, salesforce.com, inc. |
| | All rights reserved. |
| | SPDX-License-Identifier: BSD-3-Clause |
| | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| | """ |
| | import numpy as np |
| | import PIL |
| | import torch |
| | from diffusers.utils.pil_utils import PIL_INTERPOLATION |
| | from PIL import Image |
| |
|
| | from lavis.common.annotator.canny import CannyDetector |
| | from lavis.common.annotator.util import HWC3, resize_image |
| |
|
# Module-level Canny detector, instantiated once at import time; preprocess_canny
# calls it as apply_canny(image, low_threshold, high_threshold).
| | apply_canny = CannyDetector() |
| |
|
| |
|
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: Channel-last array, either a single image with 3 dimensions
            (H, W, C) or a batch with 4 dimensions (N, H, W, C). Values are
            multiplied by 255, rounded and cast to uint8 without clipping,
            so they are presumably expected to lie in [0, 1].

    Returns:
        A list with one PIL image per batch entry; a single input image
        yields a one-element list.
    """
    if images.ndim == 3:
        # Promote a single image to a batch of one.
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")

    if images.shape[-1] == 1:
        # Image.fromarray does not accept (H, W, 1) arrays; dropping the
        # channel axis yields a 2-D uint8 array, i.e. a grayscale image.
        return [Image.fromarray(image[..., 0]) for image in images]

    return [Image.fromarray(image) for image in images]
| |
|
| |
|
def preprocess_canny(
    input_image: np.ndarray,
    image_resolution: int,
    low_threshold: int,
    high_threshold: int,
):
    """
    Build a Canny edge control image from an input array.

    The input goes through HWC3 and resize_image (using image_resolution),
    then through the module-level apply_canny detector with the two
    thresholds. The detector output is passed through HWC3 once more
    (presumably to normalise it to a 3-channel HWC layout -- TODO confirm)
    and returned as a PIL image.
    """
    resized = resize_image(HWC3(input_image), image_resolution)
    edges = apply_canny(resized, low_threshold, high_threshold)
    return PIL.Image.fromarray(HWC3(edges))
| |
|
| |
|
def generate_canny(
    cond_image_input, low_threshold, high_threshold, image_resolution=512
):
    """
    Produce a Canny edge control image for a conditioning image.

    Args:
        cond_image_input: Image-like input; it is converted with np.array
            and cast to uint8 before processing.
        low_threshold: Lower threshold forwarded to preprocess_canny.
        high_threshold: Upper threshold forwarded to preprocess_canny.
        image_resolution: Resolution forwarded to preprocess_canny.
            Defaults to 512, the value that was previously hard-coded.

    Returns:
        The PIL image returned by preprocess_canny.
    """
    cond_image_input = np.array(cond_image_input).astype(np.uint8)

    return preprocess_canny(
        cond_image_input,
        image_resolution,
        low_threshold=low_threshold,
        high_threshold=high_threshold,
    )
| |
|
| |
|
def prepare_cond_image(
    image, width, height, batch_size, device, do_classifier_free_guidance=True
):
    """
    Turn a conditioning image into a batched tensor on the target device.

    Args:
        image: A torch.Tensor (used as-is), a single PIL image, a list of
            PIL images, or a list of tensors (concatenated along dim 0).
        width: Target width used when resizing PIL images.
        height: Target height used when resizing PIL images.
        batch_size: Number of times the single conditioning image is
            repeated along the batch dimension.
        device: Device the resulting tensor is moved to.
        do_classifier_free_guidance: When true, the batch is duplicated
            (concatenated with itself) before being returned.

    Returns:
        A tensor whose first dimension is batch_size, or 2 * batch_size
        when do_classifier_free_guidance is true. PIL input is converted to
        float32 in [0, 1] with layout (N, 3, height, width).

    Raises:
        TypeError: If a non-tensor input contains neither PIL images nor
            tensors.
        NotImplementedError: If the prepared image batch holds more than
            one image (or none).
    """
    if not isinstance(image, torch.Tensor):
        if isinstance(image, Image.Image):
            image = [image]

        first = image[0]
        if isinstance(first, Image.Image):
            arrays = [
                np.array(
                    pil_image.convert("RGB").resize(
                        (width, height), resample=PIL_INTERPOLATION["lanczos"]
                    )
                )
                for pil_image in image
            ]
            # (N, H, W, 3) uint8 -> float32 in [0, 1] -> (N, 3, H, W).
            batch = np.stack(arrays, axis=0).astype(np.float32) / 255.0
            image = torch.from_numpy(batch.transpose(0, 3, 1, 2))
        elif isinstance(first, torch.Tensor):
            image = torch.cat(image, dim=0)
        else:
            # Fail early with a clear message instead of an AttributeError
            # further down when `.shape` is accessed on a plain sequence.
            raise TypeError(
                "image must be a torch.Tensor, a PIL image, or a sequence of "
                f"either; got a sequence of {type(first).__name__}"
            )

    image_batch_size = image.shape[0]

    if image_batch_size != 1:
        # Only the "one conditioning image shared by the whole batch" case
        # is supported.
        raise NotImplementedError(
            "prepare_cond_image only supports a single conditioning image, "
            f"got a batch of {image_batch_size}"
        )

    image = image.repeat_interleave(batch_size, dim=0)
    image = image.to(device=device)

    if do_classifier_free_guidance:
        image = torch.cat([image] * 2)

    return image
| |
|