| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | from typing import Union, List |
| |
|
| | |
def pil2tensor(image: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
    """Turn a PIL image, or a list of PIL images, into a float32 tensor.

    Pixel values are divided by 255.0, so 8-bit channels end up in [0, 1].
    A single image gains a leading batch axis of size 1. A list is converted
    element by element and the results are concatenated along dim 0, so
    every element must yield the same trailing shape (``torch.cat``
    requirement).
    """
    if not isinstance(image, list):
        scaled = np.array(image).astype(np.float32) / 255.0
        batched = torch.from_numpy(scaled)
        return batched.unsqueeze(0)
    # Each element already carries its own batch axis; join along it.
    return torch.cat(list(map(pil2tensor, image)), dim=0)
| |
|
| |
|
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
    """Turn a NumPy array, or a list of arrays, into a float32 tensor.

    The array is cast to float32 and divided by 255.0, so uint8 inputs map
    to [0, 1]. A single array gains a leading batch axis of size 1. A list
    is converted element by element and concatenated along dim 0, so every
    element must yield the same trailing shape.
    """
    if not isinstance(img_np, list):
        normalized = img_np.astype(np.float32) / 255.0
        return torch.from_numpy(normalized).unsqueeze(0)
    pieces = [np2tensor(arr) for arr in img_np]
    return torch.cat(pieces, dim=0)
| |
|
| |
|
def tensor2np(tensor: torch.Tensor) -> Union[np.ndarray, List[np.ndarray]]:
    """Convert an image tensor to uint8 NumPy array(s).

    Values are multiplied by 255, clipped to [0, 255] and truncated to
    ``np.uint8``.

    Args:
        tensor: Image tensor on any device. Tensors with at most three
            dimensions are treated as one image; higher-rank tensors are
            split along dim 0.

    Returns:
        A single ``np.uint8`` array when ``tensor.ndim <= 3``, otherwise a
        list holding one array per entry along dim 0.
    """
    # detach() lets tensors that require grad be converted. A single .cpu()
    # call moves the whole batch to host at once instead of once per item.
    host = tensor.detach().cpu()
    # NumPy has no bfloat16 dtype, so .numpy() would raise on it.
    if host.dtype == torch.bfloat16:
        host = host.float()
    pixels = np.clip(255.0 * host.numpy(), 0, 255).astype(np.uint8)
    if pixels.ndim <= 3:
        return pixels
    # One array per entry along the leading axis.
    return list(pixels)
| | |
def tensor2pil(image: torch.Tensor) -> List[Image.Image]:
    """Convert an image tensor (optionally batched) to a list of PIL images.

    Values are multiplied by 255, clipped to [0, 255] and truncated to uint8
    before being handed to ``Image.fromarray``.

    Args:
        image: Image tensor on any device. If it has more than three
            dimensions and more than one entry along dim 0, each entry is
            converted separately (recursively, so extra leading axes are
            flattened). Otherwise it is treated as a single image.

    Returns:
        A list of PIL images, one per image found in ``image``.
    """
    # detach() lets tensors that require grad be converted. A single .cpu()
    # call moves the whole batch to host at once instead of once per item.
    host = image.detach().cpu()
    # NumPy has no bfloat16 dtype, so .numpy() would raise on it.
    if host.dtype == torch.bfloat16:
        host = host.float()
    pixels = np.clip(255.0 * host.numpy(), 0, 255).astype(np.uint8)

    def _split(arr: np.ndarray) -> List[Image.Image]:
        # Rank > 3 with more than one leading entry: convert each entry.
        if arr.ndim > 3 and arr.shape[0] > 1:
            return [img for item in arr for img in _split(item)]
        # NOTE(review): squeeze() removes *every* size-1 axis, so an image
        # whose height or width is 1 also loses that axis. This is kept for
        # backward compatibility; confirm whether callers rely on it.
        return [Image.fromarray(arr.squeeze())]

    return _split(pixels)