VibeToken / utils /viz_utils.py
APGASU's picture
scripts
7bef20f verified
"""Utils functions for visualization."""
import torch
import torchvision.transforms.functional as F
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
def make_viz_from_samples(
    original_images,
    reconstructed_images
):
    """Build side-by-side comparison panels for a batch of reconstructions.

    Every sample becomes one horizontal strip laid out as
    [original | reconstruction | absolute difference]. Both inputs are
    clamped to [0, 1] and scaled to [0, 255] before the difference is taken,
    and the result is truncated to uint8.

    Args:
        original_images: A torch.Tensor of original images, indexed as
            (batch, channels, height, width) by the rearrange pattern below.
        reconstructed_images: A torch.Tensor of reconstructed images; it must
            have the same shape as ``original_images`` (required by
            ``torch.stack``).

    Returns:
        A tuple ``(images_for_saving, images_for_logging)``:
        ``images_for_saving`` is a list with one PIL image per sample, and
        ``images_for_logging`` is a uint8 CPU tensor of shape
        (batch, channels, height, 3 * width) holding the same strips.
    """
    recon = (torch.clamp(reconstructed_images, 0.0, 1.0) * 255.0).cpu()

    # torch.clamp returns a fresh tensor, so scaling it in place never
    # touches the caller's original_images.
    orig = torch.clamp(original_images, 0.0, 1.0)
    orig.mul_(255.0)
    orig = orig.cpu()

    panels = torch.stack([orig, recon, (orig - recon).abs()])

    # l1=1 keeps a single panel row, so the three panels are concatenated
    # along the width axis of each sample.
    images_for_logging = rearrange(
        panels, "(l1 l2) b c h w -> b c (l1 h) (l2 w)", l1=1
    ).byte()
    images_for_saving = list(map(F.to_pil_image, images_for_logging))
    return images_for_saving, images_for_logging
def make_viz_from_samples_generation(
    generated_images,
    num_rows=2,
):
    """Tile a batch of generated images into a single grid image.

    Images are clamped to [0, 1], scaled to [0, 255] and arranged row-major
    into a grid with ``num_rows`` rows. When the batch size is not a multiple
    of the row count, the last row is padded with black images so the grid
    stays rectangular. With the default ``num_rows=2`` and an even batch of at
    least two images, the output matches the previous fixed two-row layout.

    Args:
        generated_images: A torch.Tensor of images, indexed as
            (batch, channels, height, width) by the rearrange pattern below.
        num_rows: Number of grid rows (default 2). Capped at the batch size so
            a single image is not padded out to an extra row.

    Returns:
        A tuple ``(images_for_saving, images_for_logging)``: the grid as a PIL
        image, and the same grid as a uint8 CPU tensor of shape
        (channels, rows * height, cols * width).

    Raises:
        ValueError: If ``generated_images`` is empty or ``num_rows`` < 1.
    """
    num_images = generated_images.shape[0]
    if num_images == 0:
        raise ValueError("generated_images must contain at least one image.")
    if num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}.")

    rows = min(num_rows, num_images)
    cols = -(-num_images // rows)  # ceiling division

    generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0

    num_pad = rows * cols - num_images
    if num_pad:
        # einops needs the batch to split exactly into rows * cols; fill the
        # incomplete last row with zero (black) images.
        padding = generated.new_zeros((num_pad, *generated.shape[1:]))
        generated = torch.cat([generated, padding], dim=0)

    images_for_logging = rearrange(
        generated,
        "(l1 l2) c h w -> c (l1 h) (l2 w)",
        l1=rows)
    images_for_logging = images_for_logging.cpu().byte()
    images_for_saving = F.to_pil_image(images_for_logging)
    return images_for_saving, images_for_logging
def make_viz_from_samples_t2i_generation(
    generated_images,
    captions,
    num_rows=2,
):
    """Tile generated text-to-image samples into a grid with captions below.

    The images are clamped to [0, 1], scaled to [0, 255] and laid out
    row-major in a grid with ``num_rows`` rows. The last row is padded with
    black images when the batch does not fill it. The grid is pasted onto a
    black RGB canvas, and the captions are drawn in white underneath it, one
    text row per caption line.

    Args:
        generated_images: A torch.Tensor of images, indexed as
            (batch, channels, height, width) by the rearrange pattern below.
        captions: A sequence of caption strings drawn below the grid in order.
            A caption containing newlines gets one text row per line, so it
            never overlaps the following caption or runs off the canvas.
        num_rows: Number of grid rows (default 2); capped at the batch size.

    Returns:
        A tuple ``(new_image, images_for_logging)``: the captioned PIL RGB
        image, and the uncaptioned grid as a uint8 CPU tensor of shape
        (channels, rows * height, cols * width).

    Raises:
        ValueError: If ``generated_images`` is empty or ``num_rows`` < 1.
    """
    num_images = generated_images.shape[0]
    if num_images == 0:
        raise ValueError("generated_images must contain at least one image.")
    if num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}.")

    rows = min(num_rows, num_images)
    cols = -(-num_images // rows)  # ceiling division

    generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0

    num_pad = rows * cols - num_images
    if num_pad:
        # einops needs the batch to split exactly into rows * cols; fill the
        # incomplete last row with zero (black) images.
        padding = generated.new_zeros((num_pad, *generated.shape[1:]))
        generated = torch.cat([generated, padding], dim=0)

    images_for_logging = rearrange(
        generated,
        "(l1 l2) c h w -> c (l1 h) (l2 w)",
        l1=rows)
    images_for_logging = images_for_logging.cpu().byte()
    grid_image = F.to_pil_image(images_for_logging)

    # Expand multi-line captions so each line gets its own slot; an empty
    # caption still reserves one (blank) row, as before.
    caption_lines = [
        line
        for caption in captions
        for line in (caption.splitlines() or [caption])
    ]

    line_height = 20  # vertical space reserved per caption line, in pixels
    margin = 10  # gap below the grid and left inset of the text, in pixels

    width, height = grid_image.size
    new_height = height + margin + line_height * len(caption_lines)
    new_image = Image.new("RGB", (width, new_height), "black")
    new_image.paste(grid_image, (0, 0))

    draw = ImageDraw.Draw(new_image)
    font = ImageFont.load_default()
    for i, line in enumerate(caption_lines):
        draw.text(
            (margin, height + margin + i * line_height),
            line,
            fill="white",
            font=font,
        )
    return new_image, images_for_logging