Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Union | |
| # import skvideo.io | |
| import imageio | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms as tf | |
| from einops import rearrange, repeat | |
| from jaxtyping import Float, UInt8 | |
| from matplotlib.figure import Figure | |
| from PIL import Image | |
| from torch import Tensor | |
| from optgs.misc.io import CustomPath | |
# Float image tensor accepted by the helpers in this module, in one of three
# layouts: a bare 2-D image (height, width), a single image with a channel axis
# (channel, height, width), or a batch (batch, channel, height, width).
FloatImage = Union[
    Float[Tensor, "height width"],  # no channel axis
    Float[Tensor, "channel height width"],  # one image
    Float[Tensor, "batch channel height width"],  # batch of images
]
def fig_to_image(
    fig: Figure,
    dpi: int = 100,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Rasterise a Matplotlib figure into an RGB float tensor.

    Args:
        fig: Figure to render.
        dpi: Resolution to render at. The output has
            ``int(height_inches * dpi)`` rows and ``int(width_inches * dpi)``
            columns, independent of ``fig.dpi``.
        device: Device on which the returned tensor is allocated.

    Returns:
        Float32 tensor of shape (3, H, W) with values in [0, 1]. The alpha
        channel of the rendered RGBA buffer is dropped.
    """
    # Local import: the module-level import block does not bring in `io`.
    import io

    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="raw", dpi=dpi)
        data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)

    # The raw buffer is rendered at `dpi`, which need not equal `fig.dpi`
    # (e.g. a caller-supplied dpi, or HiDPI/inline backends that scale
    # fig.dpi). So derive the pixel size from the figure size in inches, not
    # from `fig.bbox`. int() truncates to whole pixels, like the Agg canvas.
    width_in, height_in = fig.get_size_inches()
    w = int(width_in * dpi)
    h = int(height_in * dpi)

    data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4)
    # Keep only RGB before converting, to avoid materialising the alpha plane.
    return torch.tensor(data[:3], device=device, dtype=torch.float32) / 255
def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
    """Turn a float image tensor into an (H, W, C) uint8 NumPy array.

    Supported input layouts:
      * (H, W): treated as one channel and replicated to 3 channels;
      * (C, H, W): C must be 1, 3 or 4 (1 is replicated to 3 channels);
      * (B, C, H, W): the batch is tiled side by side along the width.

    Values are clamped to [0, 1], scaled by 255 and rounded half-up, which
    matches torchvision.utils.save_image (3DGS-LM's path).
    """
    if image.ndim == 4:
        # Lay the batch out horizontally: (b, c, h, w) -> (c, h, b * w).
        image = rearrange(image, "b c h w -> c h (b w)")
    elif image.ndim == 2:
        # Give a bare 2-D image a leading channel axis.
        image = rearrange(image, "h w -> () h w")

    # Unpacking also rejects anything that is not 3-D at this point.
    num_channels, _, _ = image.shape
    if num_channels == 1:
        image = repeat(image, "() h w -> c h w", c=3)
    assert image.shape[0] in (3, 4)

    # Adding 0.5 before the uint8 cast gives round-half-up quantisation.
    scaled = image.detach().clip(min=0, max=1) * 255 + 0.5
    quantised = scaled.clip(0, 255).type(torch.uint8)
    channels_last = rearrange(quantised, "c h w -> h w c")
    return channels_last.cpu().numpy()
def save_image(
    image: FloatImage,
    path: Union[Path, str],
) -> None:
    """Write a float image (values expected in [0, 1]) to ``path``.

    Any missing parent directories are created first. The image goes through
    ``prep_image``, so every layout it accepts is accepted here too.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    pixels = prep_image(image)
    Image.fromarray(pixels).save(destination)
def load_image(
    path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
    """Load an image file as a float RGB tensor of shape (3, H, W) in [0, 1].

    The image is converted to RGB with Pillow before tensor conversion:
      * an alpha channel is dropped;
      * greyscale and 1-bit images are replicated to three channels;
      * palette images are expanded to their actual colours;
      * CMYK and other modes follow Pillow's own conversion rules.

    This guarantees exactly three colour channels. Slicing the raw tensor
    cannot do that: it would return palette indices for "P" images or
    C/M/Y planes for CMYK.
    """
    # The context manager closes the file even for lazily loaded multi-frame
    # formats. convert() loads the pixels and returns an independent copy.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return tf.ToTensor()(rgb)
| # def save_video( | |
| # images: list[FloatImage], | |
| # path: Union[Path, str], | |
| # fps: None | int = None | |
| # ) -> None: | |
| # """Save an image. Assumed to be in range 0-1.""" | |
| # # Create the parent directory if it doesn't already exist. | |
| # path = Path(path) | |
| # path.parent.mkdir(exist_ok=True, parents=True) | |
| # # prepare frames as uint8 HxWx3 numpy arrays in range 0-255 | |
| # frames = [prep_image(img) for img in images] | |
| # outputdict = {'-pix_fmt': 'yuv420p', '-crf': '23', '-vf': 'setpts=1.*PTS'} | |
| # if fps is not None: | |
| # outputdict['-r'] = str(fps) | |
| # # pass a string filename | |
| # writer = skvideo.io.FFmpegWriter(str(path), outputdict=outputdict) | |
| # for frame in frames: | |
| # writer.writeFrame(frame) | |
| # writer.close() | |
def save_video(images, path, fps=None, iterations=None):
    """Encode a sequence of float images as a video (or animated image) file.

    Args:
        images: Sized sequence of image tensors, in any layout accepted by
            ``prep_image``, with values in [0, 1]. Sequences with fewer than
            three frames are skipped: the function returns without writing.
        path: Output file. Parent directories are created if missing, and
            imageio picks the writer from the file extension.
        fps: Output frame rate. ``None`` means 30.
        iterations: Optional per-frame labels, one per image. When given, the
            text ``"Iter <label>"`` is drawn near the top-left of each frame.
    """
    if len(images) < 3:
        return
    if iterations is not None:
        assert len(images) == len(iterations), (
            f"got {len(images)} images but {len(iterations)} iteration labels"
        )

    path = CustomPath(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if fps is None:
        fps = 30

    # prep_image already returns uint8 (H, W, C) arrays. cv2 drawing also
    # needs C-contiguous memory that it can modify in place.
    frames = [np.ascontiguousarray(prep_image(img)) for img in images]

    # Write the iteration number on each frame if given.
    if iterations is not None:
        for frame, iter_num in zip(frames, iterations):
            # Opaque white in every channel. A 3-tuple colour is padded with 0
            # by OpenCV, which would make the text transparent on RGBA frames.
            white = (255,) * frame.shape[2]
            cv2.putText(
                frame,
                f"Iter {iter_num}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                white,
                2,
                cv2.LINE_AA,
            )

    # TODO Naama: videos cannot be saved with odd dimensions
    with imageio.get_writer(str(path), fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)