import os import matplotlib import matplotlib.font_manager as font_manager import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.patches import Ellipse from numpy.linalg import norm from PIL import Image from scipy.ndimage import sobel FONT_PATH = "assets/fonts/linux_libertine/LinLibertine_R.ttf" # Make font loading optional for deployment environments try: font_manager.fontManager.addfont(FONT_PATH) FONT_PROP = font_manager.FontProperties(fname=FONT_PATH).get_name() plt.rcParams["font.family"] = FONT_PROP plt.rcParams["text.usetex"] = True except (FileNotFoundError, OSError): # Use default font if custom font is not available FONT_PROP = "DejaVu Sans" plt.rcParams["font.family"] = FONT_PROP plt.rcParams["text.usetex"] = False # Disable LaTeX if custom font unavailable matplotlib.rcParams["font.size"] = 16 matplotlib.rcParams["axes.titlesize"] = 16 matplotlib.rcParams["figure.titlesize"] = 16 matplotlib.rcParams["legend.fontsize"] = 16 matplotlib.rcParams["legend.title_fontsize"] = 16 matplotlib.rcParams["xtick.labelsize"] = 14 matplotlib.rcParams["ytick.labelsize"] = 14 ALLOWED_IMAGE_FILE_FORMATS = [".jpeg", ".jpg", ".png"] ALLOWED_IMAGE_TYPES = {"RGB": 3, "RGBA": 3, "L": 1} PLOT_DPI = 72.0 GAUSSIAN_ZOOM = 5 GAUSSIAN_COLOR = "#80ed99" def get_psnr(image1, image2, max_value=1.0): mse = torch.mean((image1 - image2) ** 2) if mse.item() <= 1e-7: return float("inf") psnr = 20 * torch.log10(max_value / torch.sqrt(mse)) return psnr def get_grid(h, w, x_lim=np.asarray([0, 1]), y_lim=np.asarray([0, 1])): x = torch.linspace(x_lim[0], x_lim[1], steps=w + 1)[:-1] + 0.5 / w y = torch.linspace(y_lim[0], y_lim[1], steps=h + 1)[:-1] + 0.5 / h grid_x, grid_y = torch.meshgrid(x, y, indexing="xy") grid = torch.stack([grid_x, grid_y], dim=-1) return grid def compute_image_gradients(image): gy, gx = [], [] for image_channel in image: gy.append(sobel(image_channel, 0)) gx.append(sobel(image_channel, 1)) gy = norm(np.stack(gy, axis=0), ord=2, axis=0).astype(np.float32) gx = norm(np.stack(gx, axis=0), ord=2, axis=0).astype(np.float32) return gy, gx def load_images(load_path, downsample_ratio=None, gamma=None): """ Load target images or textures from a directory or a single file. """ image_list = [] image_path_list = [] image_fname_list = [] num_channels_list = [] if ( os.path.isfile(load_path) and os.path.splitext(load_path)[1].lower() in ALLOWED_IMAGE_FILE_FORMATS ): image_path_list.append(load_path) elif os.path.isdir(load_path): for file in sorted(os.listdir(load_path), key=str.lower): if os.path.splitext(file)[1].lower() in ALLOWED_IMAGE_FILE_FORMATS: image_path_list.append(os.path.join(load_path, file)) if len(image_path_list) == 0: raise FileNotFoundError(f"No supported image file found at '{load_path}'") for image_path in image_path_list: image_fname_list.append(os.path.splitext(os.path.basename(image_path))[0]) image = Image.open(image_path) # Warning: Only support images of type L, RGB, or RGBA in JPEG or PNG format if image.mode not in ALLOWED_IMAGE_TYPES: raise TypeError( f"Only support images of type {list(ALLOWED_IMAGE_TYPES.keys())} in JPEG or PNG format" ) num_channels = ALLOWED_IMAGE_TYPES[image.mode] num_channels_list.append(num_channels) if downsample_ratio is not None: image = image.resize( ( round(image.width / downsample_ratio), round(image.height / downsample_ratio), ), resample=Image.Resampling.BILINEAR, ) # Warning: Assume 8 bit color depth image = np.asarray(image, dtype=np.float32) / 255.0 if gamma is not None: image = np.power(image, gamma) if len(image.shape) == 2: image = np.expand_dims(image, axis=2) image = image.transpose(2, 0, 1) image = image[:num_channels] image_list.append(image) return np.concatenate(image_list, axis=0), num_channels_list, image_fname_list def to_output_format(image, gamma): if len(image.shape) not in [2, 3]: raise ValueError(f"Wrong image format: shape = {image.shape}") if isinstance(image, torch.Tensor): image = image.detach().cpu().clone().numpy() if len(image.shape) == 3 and image.shape[2] not in [1, 3]: image = image.transpose(1, 2, 0) if image.shape[2] not in [1, 3]: raise ValueError(f"Wrong image format: shape = {image.shape}") if len(image.shape) == 3 and image.shape[2] == 1: image = image.squeeze(axis=2) image = np.clip(image, 0.0, 1.0) if gamma is not None: image = np.power(image, 1.0 / gamma) image = (255.0 * image).astype(np.uint8) return image def save_image(image, save_path, gamma=None, zoom=None): image = to_output_format(image, gamma) image = Image.fromarray(image) if zoom is not None and zoom > 0.0: width, height = image.size image = image.resize( (round(width * zoom), round(height * zoom)), resample=Image.Resampling.BOX ) image.save(save_path) def separate_image_channels(images, input_channels): if len(images) != sum(input_channels): raise ValueError( f"Incompatible number of channels: {len(images):d} vs {sum(input_channels):d}" ) image_list = [] curr_channel = 0 for num_channels in input_channels: image_list.append(images[curr_channel : curr_channel + num_channels]) curr_channel += num_channels return image_list def visualize_gaussians( filepath, xy, scale, rot, feat, img_h, img_w, input_channels, alpha=0.8, gamma=None ): """ Visualize Gaussians as colored elliptical disks. """ if feat.shape[1] != sum(input_channels): raise ValueError( f"Incompatible number of channels: {feat.shape[1]:d} vs {sum(input_channels):d}" ) xy = xy.detach().cpu().clone().numpy() y, x = xy[:, 1] * img_h, xy[:, 0] * img_w scale = GAUSSIAN_ZOOM * scale.detach().cpu().clone().numpy() rot = rot.detach().cpu().clone().numpy() if gamma is not None: feat = torch.pow(feat, 1.0 / gamma) feat = np.clip(feat.detach().cpu().clone().numpy(), 0.0, 1.0) curr_channel = 0 for image_id, num_channels in enumerate(input_channels, 1): curr_feat = feat[:, curr_channel : curr_channel + num_channels] fig = plt.figure() fig.set_dpi(PLOT_DPI) fig.set_size_inches(w=img_w / PLOT_DPI, h=img_h / PLOT_DPI, forward=False) ax = plt.gca() for gid in range(len(xy)): ellipse = Ellipse( xy=(x[gid], y[gid]), width=scale[gid, 0], height=scale[gid, 1], angle=rot[gid, 0] * 180 / np.pi, alpha=alpha, ec=None, fc=curr_feat[gid], lw=None, ) ax.add_patch(ellipse) plt.xlim(0, img_w) plt.ylim(img_h, 0) plt.axis("off") plt.tight_layout() suffix = "" if len(input_channels) == 1 else f"_{image_id:d}" plt.savefig( f"{filepath}{suffix}.png", bbox_inches="tight", pad_inches=0, dpi=PLOT_DPI ) plt.close() curr_channel += num_channels def visualize_added_gaussians( filepath, images, old_xy, new_xy, input_channels, size=500, every_n=5, alpha=0.8, gamma=None, ): """ Visualize the positions of added Gaussians during error-guided progressive optimization. """ if len(images) != sum(input_channels): raise ValueError( f"Incompatible number of channels: {len(images):d} vs {sum(input_channels):d}" ) image_height, image_width = images.shape[1:] old_xy = old_xy.detach().cpu().clone().numpy()[::every_n] new_xy = new_xy.detach().cpu().clone().numpy()[::every_n] old_x, old_y = old_xy[:, 0] * image_width, old_xy[:, 1] * image_height new_x, new_y = new_xy[:, 0] * image_width, new_xy[:, 1] * image_height curr_channel = 0 for image_id, num_channels in enumerate(input_channels, 1): image = images[curr_channel : curr_channel + num_channels] image = to_output_format(image, gamma) fig = plt.figure() fig.set_dpi(PLOT_DPI) fig.set_size_inches( w=image_width / PLOT_DPI, h=image_height / PLOT_DPI, forward=False ) plt.imshow(Image.fromarray(image), cmap="gray", vmin=0, vmax=255) plt.scatter(old_x, old_y, s=size, c="#ef476f", marker="o", alpha=alpha) # red plt.scatter(new_x, new_y, s=size, c="#06d6a0", marker="o", alpha=alpha) # green plt.xlim(0, image_width) plt.ylim(image_height, 0) plt.axis("off") plt.tight_layout() suffix = "" if len(input_channels) == 1 else f"_{image_id:d}" plt.savefig( f"{filepath}{suffix}.png", bbox_inches="tight", pad_inches=0, dpi=PLOT_DPI ) plt.close() curr_channel += num_channels