import os
from typing import Any, Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import make_grid, save_image

try:
    from kornia.morphology import opening
except ImportError:
    from kornia.morphology import open as opening
def exist(val: Any) -> bool:
    """Report whether *val* holds something other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only ``None`` is treated as absent.
    """
    if val is None:
        return False
    return True
def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a morphological opening to *x* with a ``k x k`` all-ones kernel.

    Args:
        x: tensor handed to kornia's ``opening`` (presumably ``(B, C, H, W)``
            as kornia expects — confirm against callers).
        k: side length of the square structuring element. ``0`` disables
            the operation.

    Returns:
        *x* itself (same object) when ``k == 0``; otherwise the opened
        tensor, computed without gradient tracking.
    """
    # Guard clause: a zero-sized kernel means "no filtering".
    if k == 0:
        return x
    structuring_element = torch.ones(k, k, device=x.device)
    # The opening is used as a non-differentiable post-process, so autograd
    # is switched off for it.
    with torch.no_grad():
        return opening(x, structuring_element)
def make_grid_images(images: list[torch.Tensor], **kwargs) -> torch.Tensor:
    """Join several image batches along dim 3 and tile the result into a grid.

    Args:
        images: batches to place next to each other; they must agree in every
            dimension except dim 3 (presumably width for ``(B, C, H, W)``
            batches — confirm).
        **kwargs: forwarded unchanged to torchvision's ``make_grid``.

    Returns:
        The grid tensor produced by ``make_grid``.
    """
    return make_grid(torch.cat(images, dim=3), **kwargs)
def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
    """Save a batch of generated images next to the matching real images.

    ``gen`` and ``real`` are concatenated along dim 3 (presumably width for
    ``(B, C, H, W)`` batches — confirm), tiled with ``make_grid`` and written
    to *path* with ``save_image``.

    Args:
        images: ``(gen, real)`` pair; both batches must agree in every
            dimension except dim 3.
        path: destination file path handed to ``save_image``.
        **kwargs: forwarded unchanged to ``make_grid``.
    """
    gen, real = images
    grid = make_grid(torch.cat((gen, real), dim=3), **kwargs)
    # Clamp before quantising. The previous numpy round-trip cast
    # ``grid * 255`` to uint8 with no clamp, so pixels outside [0, 1] wrapped
    # around instead of saturating. It also quantised twice. ``save_image``
    # now performs the single conversion to 8-bit.
    save_image(grid.clamp(0, 1), path)
def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
    """Save several image batches side by side in a single grid image.

    Despite the name, any number of batches is accepted. They are
    concatenated along dim 3 (presumably width for ``(B, C, H, W)`` batches —
    confirm), tiled with ``make_grid`` and written to *path*.

    Args:
        images: batches to join; they must agree in every dimension except 3.
        path: destination file path handed to ``save_image``.
        **kwargs: forwarded unchanged to ``make_grid``.
    """
    grid = make_grid(torch.cat(images, dim=3), **kwargs)
    # Clamp before quantising: the former unclamped float -> uint8 numpy cast
    # wrapped out-of-range pixels around. ``save_image`` does the only 8-bit
    # conversion now.
    save_image(grid.clamp(0, 1), path)
def plot_images(images: torch.Tensor) -> None:
    """Show a batch of images as one horizontal strip in a matplotlib figure.

    The batch entries (iterated along dim 0) are concatenated along their last
    dimension, channels are moved last, and the strip is shown with
    ``plt.imshow``. Presumably *images* is ``(B, C, H, W)`` with a channel
    count ``imshow`` accepts — confirm against callers.

    Args:
        images: batch of image tensors; may live on any device and may still
            require grad.
    """
    plt.figure(figsize=(32, 32))
    # Detach so tensors that still track gradients can be converted to numpy.
    frames = images.detach().cpu()
    strip = torch.cat(list(frames), dim=-1)
    plt.imshow(strip.permute(1, 2, 0).numpy())
    plt.show()
def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
    """Plot a metric history and save it as ``<path>/<metric_name>.png``.

    Args:
        metric_name: used as the plot title, the y-axis label and the file
            stem.
        metrics: one value per entry, plotted against its index (labelled
            "Epoch"). Tensors may live on any device and may still require
            grad; plain Python numbers are accepted too.
        path: existing directory the PNG is written into.
    """
    # ``as_tensor`` lets plain numbers through, and ``detach`` avoids the
    # "numpy() on a tensor that requires grad" error for stored losses.
    # Conversion runs before the figure exists so a bad value cannot leak one.
    values = [torch.as_tensor(m).detach().cpu().numpy() for m in metrics]
    fig = plt.figure(figsize=(32, 32))
    try:
        plt.plot(values)
        plt.title(metric_name)
        plt.xlabel("Epoch")
        plt.ylabel(metric_name)
        plt.savefig(os.path.join(path, f"{metric_name}.png"))
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
def norm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Normalise *img* per channel with torchvision's ``transforms.Normalize``.

    Normalize computes ``(img - mean) / std`` channel-wise; ``denorm`` with the
    same arguments maps the result back.

    Args:
        img: float image tensor accepted by ``Normalize`` (channels in dim -3).
        mean: per-channel means; tuple default avoids a shared mutable default.
        std: per-channel standard deviations.

    Returns:
        The normalised tensor.
    """
    return transforms.Normalize(mean, std)(img)
def denorm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Undo ``norm``: compute ``img * std + mean`` per channel.

    Args:
        img: normalised image tensor, presumably ``(B, C, H, W)``. The stats
            are reshaped to ``(1, C, 1, 1)``, so a 3-D input comes back with
            an extra leading dimension of size 1.
        mean: per-channel means used in ``norm``; a single scalar also works.
        std: per-channel standard deviations used in ``norm``.

    Returns:
        The de-normalised tensor, on ``img``'s device. Floating-point inputs
        keep their dtype; integer inputs yield the default float dtype.
    """
    # Build the stats in the input's floating dtype, as Normalize does in
    # ``norm``, so half-precision inputs are not silently upcast to float32.
    dtype = img.dtype if img.is_floating_point() else None
    mean_t = torch.as_tensor(mean, dtype=dtype, device=img.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=img.device)
    # [None][..., None, None] turns (C,) into (1, C, 1, 1) for broadcasting
    # over the batch and spatial dims (a scalar becomes (1, 1, 1)).
    return img * std_t[None][..., None, None] + mean_t[None][..., None, None]