| | """Utils functions for visualization.""" |
| |
|
| | import torch |
| | import torchvision.transforms.functional as F |
| | from einops import rearrange |
| | from PIL import Image, ImageDraw, ImageFont |
| |
|
def make_viz_from_samples(
    original_images,
    reconstructed_images
):
    """Builds side-by-side comparison images of originals and reconstructions.

    Each output image places three panels next to each other along the width:
    the original, the reconstruction, and their absolute difference.

    Args:
        original_images: A torch.Tensor, original images. Values are clamped
            to [0, 1] before being scaled to [0, 255].
        reconstructed_images: A torch.Tensor, reconstructed images. Values are
            clamped to [0, 1] before being scaled to [0, 255].

    Returns:
        A tuple (images_for_saving, images_for_logging). images_for_saving is
        a list of PIL images, one per batch element. images_for_logging is a
        uint8 CPU torch.Tensor laid out as (b, c, h, 3 * w).
    """

    def _to_pixel_range(batch):
        # Clamp to [0, 1], scale to [0, 255], then move to the CPU.
        return (torch.clamp(batch, 0.0, 1.0) * 255.0).cpu()

    originals = _to_pixel_range(original_images)
    reconstructions = _to_pixel_range(reconstructed_images)
    abs_error = (originals - reconstructions).abs()

    # Stacking yields a leading axis of size 3; with l1=1 all three panels
    # end up concatenated along the width axis.
    panels = torch.stack([originals, reconstructions, abs_error])
    side_by_side = rearrange(
        panels,
        "(l1 l2) b c h w -> b c (l1 h) (l2 w)",
        l1=1,
    )
    images_for_logging = side_by_side.byte()

    images_for_saving = []
    for panel_row in images_for_logging:
        images_for_saving.append(F.to_pil_image(panel_row))

    return images_for_saving, images_for_logging
| |
|
| |
|
def make_viz_from_samples_generation(
    generated_images,
):
    """Tiles a batch of generated images into a single grid image.

    The batch is laid out in two rows when it can be split into two equal
    rows. Otherwise (an odd batch size, including a single image) all images
    are placed in one row instead of failing inside `rearrange`.

    Args:
        generated_images: A torch.Tensor of generated images indexed as
            (batch, c, h, w). Values are clamped to [0, 1] before being scaled
            to [0, 255].

    Returns:
        A tuple (images_for_saving, images_for_logging). images_for_saving is
        a single PIL image of the grid. images_for_logging is the same grid as
        a uint8 CPU torch.Tensor laid out as (c, rows * h, cols * w).
    """
    generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0

    # "(l1 l2)" requires the batch axis to be divisible by l1, so a fixed
    # l1=2 raises for odd batch sizes. Fall back to one row in that case.
    num_rows = 2 if generated.shape[0] % 2 == 0 else 1
    images_for_logging = rearrange(
        generated,
        "(l1 l2) c h w -> c (l1 h) (l2 w)",
        l1=num_rows)

    images_for_logging = images_for_logging.cpu().byte()
    images_for_saving = F.to_pil_image(images_for_logging)

    return images_for_saving, images_for_logging
| |
|
| |
|
def make_viz_from_samples_t2i_generation(
    generated_images,
    captions,
):
    """Tiles generated images into a grid and lists the captions underneath.

    The batch is laid out in two rows when it can be split into two equal
    rows. Otherwise (an odd batch size, including a single image) all images
    are placed in one row instead of failing inside `rearrange`.

    Args:
        generated_images: A torch.Tensor of generated images indexed as
            (batch, c, h, w). Values are clamped to [0, 1] before being scaled
            to [0, 255].
        captions: A sized iterable of caption texts; each one is drawn on its
            own line below the image grid.

    Returns:
        A tuple (new_image, images_for_logging). new_image is an RGB PIL image
        with the grid at the top and the captions in white on a black strip
        below it. images_for_logging is the grid alone (without captions) as a
        uint8 CPU torch.Tensor laid out as (c, rows * h, cols * w).
    """
    generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0

    # "(l1 l2)" requires the batch axis to be divisible by l1, so a fixed
    # l1=2 raises for odd batch sizes. Fall back to one row in that case.
    num_rows = 2 if generated.shape[0] % 2 == 0 else 1
    images_for_logging = rearrange(
        generated,
        "(l1 l2) c h w -> c (l1 h) (l2 w)",
        l1=num_rows)

    images_for_logging = images_for_logging.cpu().byte()
    images_for_saving = F.to_pil_image(images_for_logging)

    # Extend the canvas downwards with a black strip for the captions.
    line_height = 20
    padding = 10
    width, height = images_for_saving.size
    caption_height = line_height * len(captions) + padding
    new_image = Image.new("RGB", (width, height + caption_height), "black")
    new_image.paste(images_for_saving, (0, 0))

    # Write one caption per line, starting just below the image grid.
    draw = ImageDraw.Draw(new_image)
    font = ImageFont.load_default()
    for i, caption in enumerate(captions):
        text_origin = (10, height + padding + i * line_height)
        draw.text(text_origin, caption, fill="white", font=font)

    return new_image, images_for_logging