Spaces:
Running
Running
| from typing import List | |
| import torch | |
| from PIL import Image | |
def make_large(img: Image.Image, size=300):
    """
    Resizes a PIL image to a ``size`` x ``size`` square.

    Nearest-neighbour resampling is used, so no interpolation is applied and
    pixel edges stay hard. The aspect ratio of the input is not preserved.

    :param img: image to resize
    :param size: edge length, in pixels, of the square output
    :return: the resized image (a new PIL image)
    """
    target_dims = (size, size)
    return img.resize(target_dims, resample=Image.NEAREST)
def tensor_to_rgba_image(tensor: torch.Tensor) -> List[Image.Image]:
    """
    Converts a tensor to RGBA PIL image(s).

    Channel handling per image:
      * 1 channel  -> grey value replicated to R, G, B; fully opaque alpha
      * 3 channels -> RGB; fully opaque alpha
      * 4 channels -> used as RGBA unchanged

    :param tensor: Tensor with values in [0, 1], shape (N, C, H, W) or (C, H, W).
        Values outside [0, 1] are clamped before quantisation to 8 bits.
    :return: list of RGBA PIL images, one per batch element (a one-element list
        for a (C, H, W) input).
    :raises ValueError: if the tensor is not 3- or 4-dimensional, or if the
        channel count is not 1, 3 or 4.
    """
    # Detach so tensors tracked by autograd can be converted, move to CPU for
    # numpy, and use one float dtype so the alpha plane and cat() always agree.
    tensor = tensor.detach().cpu().float()
    if tensor.ndim == 3:  # (C, H, W) -> (1, C, H, W)
        tensor = tensor.unsqueeze(0)
    elif tensor.ndim != 4:
        raise ValueError(
            "Expected tensor of shape (N, C, H, W) or (C, H, W), "
            f"got {tuple(tensor.shape)}"
        )

    images: List[Image.Image] = []
    for img in tensor:  # img: (C, H, W)
        channels = img.shape[0]
        if channels == 1:  # grayscale -> replicate to RGB + opaque alpha
            alpha = img.new_ones((1, *img.shape[1:]))
            img = torch.cat((img.expand(3, -1, -1), alpha), dim=0)
        elif channels == 3:  # RGB -> add opaque alpha
            alpha = img.new_ones((1, *img.shape[1:]))
            img = torch.cat((img, alpha), dim=0)
        elif channels != 4:  # 4 channels are already RGBA
            raise ValueError("Expected tensor with 1, 3, or 4 channels")

        # Clamp first so small float overshoot cannot wrap around in the uint8
        # cast, and round instead of truncating to avoid a darkening bias.
        pixels = (img.clamp(0, 1) * 255).round().to(torch.uint8)
        array = pixels.permute(1, 2, 0).numpy()  # (H, W, 4) uint8
        # A uint8 (H, W, 4) array is inferred as RGBA by Pillow; passing
        # ``mode`` explicitly is deprecated since Pillow 11.3.
        images.append(Image.fromarray(array))
    return images
def normalize_to_unit(images: torch.Tensor) -> torch.Tensor:
    """
    Maps images from the [-1, 1] range onto [0, 1].

    Computes (images + 1) / 2 and clamps the result to [0, 1], so inputs
    outside [-1, 1] saturate at the bounds.

    :param images: tensor of images, expected to hold values in [-1, 1]
    :return: tensor of the same shape with values in [0, 1]
    """
    shifted = images.add(1)
    halved = shifted.div(2)  # true division, so integer inputs become float
    return halved.clamp(min=0, max=1)