from typing import List

import torch
from PIL import Image


def make_large(img: Image.Image, size: int = 300) -> Image.Image:
    """
    Resizes an image to a square of ``size`` x ``size`` pixels.

    Nearest-neighbour resampling is used, so pixels are replicated rather
    than interpolated.

    :param img: image to resize
    :param size: edge length of the square output, in pixels
    :return: resized image
    """
    return img.resize((size, size), Image.NEAREST)


def tensor_to_rgba_image(tensor: torch.Tensor) -> List[Image.Image]:
    """
    Converts a tensor to RGBA PIL image(s).

    Single-channel input is replicated to RGB, and inputs without an alpha
    channel get a fully opaque one. Scaled values outside [0, 255] are
    clamped instead of wrapping around in the 8-bit conversion.

    :param tensor: Tensor with values in [0, 1], shape (N, C, H, W) or (C, H, W)
    :return: RGBA PIL images, one per batch element.
    :raises ValueError: if the tensor is not 3- or 4-dimensional, or if it
        does not have 1, 3, or 4 channels
    """
    # detach first: .numpy() refuses tensors that still require grad
    tensor = tensor.detach().to('cpu')
    if tensor.ndim == 3:  # (C, H, W)
        tensor = tensor.unsqueeze(0)  # add batch dim
    elif tensor.ndim != 4:
        raise ValueError("Expected tensor of shape (N, C, H, W) or (C, H, W)")

    images: List[Image.Image] = []
    for img in tensor:  # iterate over batch
        channels = img.shape[0]
        if channels not in (1, 3, 4):
            raise ValueError("Expected tensor with 1, 3, or 4 channels")

        if channels == 1:  # grayscale → replicate RGB
            img = img.expand(3, -1, -1)
        if channels != 4:  # no alpha yet → add full alpha
            alpha = torch.ones(1, *img.shape[1:])
            img = torch.cat((img, alpha), dim=0)

        # clamp after scaling so out-of-range values saturate at 0 / 255
        # instead of overflowing the uint8 cast
        pixels = (img * 255).clamp(0, 255).byte()
        array = pixels.permute(1, 2, 0).numpy()  # (H, W, 4)
        # mode is inferred as RGBA from the (H, W, 4) uint8 array
        images.append(Image.fromarray(array))
    return images


def normalize_to_unit(images: torch.Tensor) -> torch.Tensor:
    """
    Normalizes images from [-1, 1] to [0, 1] range.

    Results are clamped, so inputs outside [-1, 1] map to 0 or 1.

    :param images: images to normalize
    :return: normalized images
    """
    # [-1,1] -> [0,1]
    return ((images + 1) / 2).clamp(0, 1)