|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import os |
|
|
from diffusers.utils import export_to_video |
|
|
import torchvision |
|
|
from torchvision.io import read_video |
|
|
import numpy as np |
|
|
import imageio |
|
|
from einops import rearrange |
|
|
import math |
|
|
import cv2 |
|
|
|
|
|
def save_video(test_video_out, outdir, name='sample_grid', fps=8):
    """
    Tile a batch of videos into a grid and save it as an mp4.

    Args:
        test_video_out: [B, T, C, H, W] tensor, values assumed in [0, 1].
        outdir: output directory; created if it does not exist.
        name: output file stem; written as '<name>.mp4'.
        fps: frames per second of the written video.
    """
    grid = reshape_video_grid(test_video_out)  # -> [T, C, H', W']
    # detach + cpu so CUDA tensors and tensors requiring grad are accepted
    frames = grid.detach().cpu().numpy()
    # clip BEFORE the uint8 cast: out-of-range floats would otherwise wrap around
    frames = (np.clip(frames.transpose(0, 2, 3, 1), 0.0, 1.0) * 255).astype(np.uint8)
    os.makedirs(outdir, exist_ok=True)
    imageio.mimwrite(os.path.join(outdir, f'{name}.mp4'), frames, fps=fps)
|
|
|
|
|
def wave_func(values, wave_pos, wave_length=1.0):
    """Cosine-squared falloff within the wave band, zero outside.

    The band is centered at `wave_pos` and extends `wave_length` on each
    side; the response is 1 at the center and falls to 0 at the edges.
    """
    offset = (values - wave_pos) / wave_length
    inside = np.abs(offset) <= 1.0
    # Clip keeps the cos argument in range for the (discarded) outside points.
    profile = np.cos(np.clip(offset, -1.0, 1.0) * np.pi / 2.0) ** 2
    return np.where(inside, profile, 0.0).astype(np.float32)
|
|
|
|
|
def generate_wave_video(image_tensor: torch.Tensor,
                        depth_tensor: torch.Tensor,
                        batch_idx: int = 0,
                        frame_idx: int = 0,
                        n_frames: int = 24,
                        wave_length: float = 1.0,
                        wave_color=(255, 255, 255),
                        wave_color_front=(255, 230, 200),
                        wave_color_back=(200, 220, 255),
                        use_gradient_color: bool = True,
                        pre_frames: int = 24) -> torch.Tensor:
    """
    Generate a depth-driven wave propagation video.

    A cosine-squared band (see `wave_func`) sweeps from the minimum to the
    maximum depth value, highlighting the pixels whose depth currently falls
    inside the band. `pre_frames` copies of the untouched image are
    prepended before the sweep starts.

    Args:
        image_tensor: [B, F, 3, H, W] RGB frames, values assumed in [0, 1].
        depth_tensor: [B, F, 1, H, W] depth maps.
        batch_idx, frame_idx: which image/depth pair to animate.
        n_frames: number of sweep steps; the sweep emits n_frames + 1 frames.
        wave_length: band half-width in depth units.
        wave_color: flat RGB highlight used when use_gradient_color is False.
        wave_color_front, wave_color_back: RGB endpoints blended by normalized
            depth when use_gradient_color is True (near -> front, far -> back).
        use_gradient_color: choose gradient vs. flat highlight color.
        pre_frames: number of unmodified frames prepended before the sweep.

    Returns:
        torch.Tensor of shape [1, pre_frames + n_frames + 1, 3, H, W],
        float32 in [0.0, 1.0].

    Raises:
        ValueError: if n_frames < 1.
    """
    assert image_tensor.ndim == 5 and image_tensor.shape[2] == 3
    assert depth_tensor.ndim == 5 and depth_tensor.shape[2] == 1
    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1, got {n_frames}")

    image = image_tensor[batch_idx, frame_idx].detach().cpu().numpy()
    depth = depth_tensor[batch_idx, frame_idx, 0].detach().cpu().numpy()

    # Work in HWC uint8-range floats for the blending below.
    image = np.transpose(image, (1, 2, 0)).astype(np.float32) * 255.0
    depth = depth.astype(np.float32)

    assert image.shape[:2] == depth.shape

    min_depth, max_depth = depth.min(), depth.max()
    # Degenerate (constant) depth map: widen the range so the sweep still runs.
    if max_depth - min_depth < 1e-5:
        max_depth = min_depth + 1.0

    if use_gradient_color:
        front = np.array(wave_color_front, dtype=np.float32)
        back = np.array(wave_color_back, dtype=np.float32)
        depth_norm = (depth - min_depth) / (max_depth - min_depth)
        # Per-pixel color: linear blend between the near and far endpoint.
        wave_color_map = (1. - depth_norm[..., None]) * front + depth_norm[..., None] * back
    else:
        wave_color_map = np.array(wave_color, dtype=np.float32).reshape(1, 1, 3)

    frames_np = []

    # Hold the untouched image before the sweep begins.
    initial_frame = np.clip(image, 0, 255).astype(np.uint8)
    frames_np.extend([initial_frame] * pre_frames)

    for i in range(n_frames + 1):
        ratio = i / n_frames
        curr_depth = (max_depth - min_depth) * ratio + min_depth

        # Per-pixel blend weight of the highlight color at this sweep position.
        wave = wave_func(depth, curr_depth, wave_length)[..., None]
        wave = np.clip(wave, 0.0, 1.0)

        frame = image * (1.0 - wave) + wave * wave_color_map
        frame = np.clip(frame, 0, 255).astype(np.uint8)
        frames_np.append(frame)

    # Back to [1, T, 3, H, W] float in [0, 1].
    frames_np = np.stack(frames_np, axis=0).astype(np.float32) / 255.0
    frames_np = np.transpose(frames_np, (0, 3, 1, 2))
    frames_tensor = torch.from_numpy(frames_np)
    frames_tensor = frames_tensor[None]
    return frames_tensor
|
|
|
|
|
def create_depth_visu(x, cmap='jet', data_range=None, out_float=True, min_max_perc=[0.01, 0.99]):
    """
    Colorize a [B, T, C, H, W] tensor (e.g. depth) with an OpenCV colormap.

    Args:
        x: [B, T, C, H, W] tensor. C is expected to be 1 for depth maps
           (cv2.applyColorMap needs a 1- or 3-channel 8-bit image).
        cmap: 'jet' or 'inferno'.
        data_range: optional (min, max) used for normalization; when None,
           robust min/max are taken as global percentiles of x.
        out_float: if True, return floats in [0, 1]; else uint8-range values.
        min_max_perc: lower/upper percentiles used when data_range is None.

    Returns:
        [B, T, 3, H, W] tensor on x's device/dtype.
        NOTE(review): cv2 colormaps are BGR-ordered — confirm downstream
        consumers expect that channel order.

    Raises:
        ValueError: for an unsupported cmap.
    """
    B, T, C, H, W = x.shape
    dtype = x.dtype
    device = x.device
    if data_range is None:
        # Robust range from global percentiles (ignores outliers).
        x_flat = x.detach().cpu().numpy().ravel()
        x_min = np.percentile(x_flat, min_max_perc[0] * 100)
        x_max = np.percentile(x_flat, min_max_perc[1] * 100)
        x = x.clip(x_min, x_max)
    else:
        x_min, x_max = data_range
    # BUG FIX: the percentile branch previously skipped this normalization,
    # so raw values were multiplied by 255 and wrapped around in uint8.
    x = (x - x_min) / max(float(x_max) - float(x_min), 1e-8)
    x = x.clip(0.0, 1.0)  # guard the uint8 cast against out-of-range values
    x = rearrange(x, 'b t c h w -> (b t) h w c')
    x_np = x.cpu().numpy()
    x_np = (x_np * 255.0).astype(np.uint8)
    if cmap == "jet":
        color_map = cv2.COLORMAP_JET
    elif cmap == "inferno":
        color_map = cv2.COLORMAP_INFERNO
    else:
        # Previously an unknown cmap left color_map unbound (NameError).
        raise ValueError(f"Unsupported cmap: {cmap!r}")
    x_np = [cv2.applyColorMap(x_np_i, color_map) for x_np_i in x_np]
    x = torch.from_numpy(np.array(x_np))
    x = rearrange(x, '(b t) h w c -> b t c h w', b=B)
    x = x.to(device=device, dtype=dtype)
    if out_float:
        x = x/255
    return x
|
|
|
|
|
def reshape_video_grid(video_tensor):
    """
    Tile a batch of videos into a single grid video.

    Args:
        video_tensor: [B, T, C, H, W]. If B is a perfect square, clips are
            arranged in a sqrt(B) x sqrt(B) grid; otherwise they are laid
            out in a single row of width B.

    Returns:
        torch.Tensor of shape [T, C, N1*H, N2*W].
    """
    b, t, c, h, w = video_tensor.shape
    # math.isqrt avoids float-precision issues of int(math.sqrt(b)).
    n_rows = math.isqrt(b)
    if n_rows * n_rows == b:
        n_cols = n_rows
    else:
        # Not a perfect square: fall back to one row containing every clip.
        n_rows, n_cols = 1, b

    # Equivalent to einops '(N1 N2) t c h w -> t c (N1 h) (N2 w)'.
    grid = video_tensor.reshape(n_rows, n_cols, t, c, h, w)
    grid = grid.permute(2, 3, 0, 4, 1, 5)  # -> t, c, N1, h, N2, w
    return grid.reshape(t, c, n_rows * h, n_cols * w)