Lyra / src /utils /visu.py
Muhammad Taqi Raza
adding lyra files
af758d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
from diffusers.utils import export_to_video
import torchvision
from torchvision.io import read_video
import numpy as np
import imageio
from einops import rearrange
import math
import cv2
def save_video(test_video_out, outdir, name='sample_grid', fps=8):
test_video_out = reshape_video_grid(test_video_out)
test_video_out = test_video_out.numpy()
test_video_out = (test_video_out.transpose(0,2,3,1) * 255).astype(np.uint8)
imageio.mimwrite(os.path.join(outdir, f'{name}.mp4'), test_video_out, fps=fps)
def wave_func(values, wave_pos, wave_length=1.0):
"""Cosine-squared falloff within wave band, zero outside."""
dist = (values - wave_pos) / wave_length
mask = np.abs(dist) <= 1.0
wave = np.zeros_like(values, dtype=np.float32)
wave[mask] = np.cos(dist[mask] * np.pi / 2.0) ** 2
return wave
def generate_wave_video(image_tensor: torch.Tensor,
depth_tensor: torch.Tensor,
batch_idx: int = 0,
frame_idx: int = 0,
n_frames: int = 24,
wave_length: float = 1.0,
wave_color=(255, 255, 255),
wave_color_front = [255, 230, 200],
wave_color_back = [200, 220, 255],
use_gradient_color: bool = True,
pre_frames: int = 24) -> torch.Tensor:
"""
Generates a wave propagation video and returns it as a torch.Tensor
in shape [T, 3, H, W], range [0.0, 1.0].
"""
assert image_tensor.ndim == 5 and image_tensor.shape[2] == 3
assert depth_tensor.ndim == 5 and depth_tensor.shape[2] == 1
image = image_tensor[batch_idx, frame_idx].detach().cpu().numpy() # (3, H, W)
depth = depth_tensor[batch_idx, frame_idx, 0].detach().cpu().numpy() # (H, W)
image = np.transpose(image, (1, 2, 0)).astype(np.float32) * 255.0 # (H, W, 3)
depth = depth.astype(np.float32)
assert image.shape[:2] == depth.shape
min_depth, max_depth = depth.min(), depth.max()
if max_depth - min_depth < 1e-5:
max_depth = min_depth + 1.0
if use_gradient_color:
wave_color_front = np.array(wave_color_front, dtype=np.float32) # Warm white
wave_color_back = np.array(wave_color_back, dtype=np.float32) # Cool metallic blue
depth_norm = (depth - min_depth) / (max_depth - min_depth)
wave_color_map = (1. - depth_norm[..., None]) * wave_color_front + depth_norm[..., None] * wave_color_back
else:
wave_color_map = np.array(wave_color, dtype=np.float32).reshape(1, 1, 3)
frames_np = []
# Pre-video: hold initial frame
initial_frame = np.clip(image, 0, 255).astype(np.uint8)
frames_np.extend([initial_frame] * pre_frames)
# Wave animation
for i in range(n_frames + 1):
ratio = i / n_frames
curr_depth = (max_depth - min_depth) * ratio + min_depth
wave = wave_func(depth, curr_depth, wave_length)[..., None]
wave = np.clip(wave, 0.0, 1.0)
frame = image * (1.0 - wave) + wave * wave_color_map
frame = np.clip(frame, 0, 255).astype(np.uint8)
frames_np.append(frame)
# Convert frames to torch.Tensor in [0,1], shape: (T, 3, H, W)
frames_np = np.stack(frames_np, axis=0).astype(np.float32) / 255.0 # (T, H, W, 3)
frames_np = np.transpose(frames_np, (0, 3, 1, 2)) # (T, 3, H, W)
frames_tensor = torch.from_numpy(frames_np)
frames_tensor = frames_tensor[None]
return frames_tensor
def create_depth_visu(x, cmap='jet', data_range=None, out_float=True, min_max_perc=[0.01, 0.99]): #min_max_perc=[0., 1.]):
B, T, C, H, W = x.shape
dtype = x.dtype
device = x.device
if data_range is None:
x_flat = x.view(x.shape[0], -1)
x_flat = x_flat.cpu().numpy()
x_min = np.percentile(x_flat, min_max_perc[0]*100) #x_flat.amin(1).view(-1, 1, 1, 1, 1)
x_max = np.percentile(x_flat, min_max_perc[1]*100) #x_flat.amax(1).view(-1, 1, 1, 1, 1)
x = x.clip(x_min, x_max)
else:
x_min, x_max = data_range
x = (x - x_min) / (x_max - x_min)
x = rearrange(x, 'b t c h w -> (b t) h w c')
x_np = x.cpu().numpy()
x_np = (x_np * 255.0).astype(np.uint8)
if cmap == "jet":
color_map = cv2.COLORMAP_JET
elif cmap == "inferno":
color_map = cv2.COLORMAP_INFERNO
x_np = [cv2.applyColorMap(x_np_i, color_map) for x_np_i in x_np]
x = torch.from_numpy(np.array(x_np))
x = rearrange(x, '(b t) h w c -> b t c h w', b=B)
x = x.to(device=device, dtype=dtype)
if out_float:
x = x/255
return x
def reshape_video_grid(video_tensor):
b, t, c, h, w = video_tensor.shape
N1 = N2 = int(math.sqrt(b))
if N1 * N2 != b:
N1 = 1
N2 = b
assert N1 * N2 == b, "Batch size must be a perfect square"
# Rearrange using einops
grid_video = rearrange(video_tensor, "(N1 N2) t c h w -> t c (N1 h) (N2 w)", N1=N1, N2=N2)
return grid_video