File size: 5,723 Bytes
af758d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import os
from diffusers.utils import export_to_video
import torchvision
from torchvision.io import read_video
import numpy as np
import imageio
from einops import rearrange
import math
import cv2

def save_video(test_video_out, outdir, name='sample_grid', fps=8):
    """Tile a batch of videos into one grid and write it as ``<outdir>/<name>.mp4``.

    Args:
        test_video_out: Tensor that ``reshape_video_grid`` unpacks as
            ``(b, t, c, h, w)``. Values are multiplied by 255 before the uint8
            cast, so they are presumably expected in ``[0, 1]``; anything
            outside that interval is saturated.
        outdir: Directory receiving the file. It is created when missing.
        name: Output file name without the ``.mp4`` extension.
        fps: Frame rate forwarded to ``imageio.mimwrite``.
    """
    grid = reshape_video_grid(test_video_out)  # (t, c, H, W)

    # detach()/cpu() keep .numpy() valid for CUDA tensors and for tensors that
    # still track gradients; both are no-ops for plain CPU tensors.
    frames = grid.detach().cpu().numpy()

    # Saturate rather than letting out-of-range values wrap around in the
    # uint8 conversion below.
    frames = np.clip(frames, 0.0, 1.0)
    frames = (frames.transpose(0, 2, 3, 1) * 255).astype(np.uint8)  # (t, H, W, c)

    # An empty outdir means "current directory"; os.makedirs('') would raise.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    imageio.mimwrite(os.path.join(outdir, f'{name}.mp4'), frames, fps=fps)

def wave_func(values, wave_pos, wave_length=1.0):
    """Evaluate a cos^2 bump centred on ``wave_pos`` for each entry of ``values``.

    The offset of every entry from ``wave_pos`` is expressed in units of
    ``wave_length``. Entries at most one unit away receive
    ``cos(offset * pi / 2) ** 2``; all remaining entries stay at 0. The result
    is a float32 array shaped like ``values``.
    """
    offset = (values - wave_pos) / wave_length
    in_band = np.abs(offset) <= 1.0

    profile = np.zeros_like(values, dtype=np.float32)
    # Only the in-band entries are evaluated; the rest keep their zero fill.
    phase = offset[in_band] * np.pi / 2.0
    profile[in_band] = np.cos(phase) ** 2
    return profile

def generate_wave_video(image_tensor: torch.Tensor,
                        depth_tensor: torch.Tensor,
                        batch_idx: int = 0,
                        frame_idx: int = 0,
                        n_frames: int = 24,
                        wave_length: float = 1.0,
                        wave_color=(255, 255, 255),
                        wave_color_front=(255, 230, 200),
                        wave_color_back=(200, 220, 255),
                        use_gradient_color: bool = True,
                        pre_frames: int = 24) -> torch.Tensor:
    """
    Render a coloured band that sweeps through one frame from its smallest to
    its largest depth value.

    A single frame is picked with ``batch_idx`` / ``frame_idx``. The clip first
    repeats that frame ``pre_frames`` times and then adds ``n_frames + 1``
    frames in which the band position moves linearly from the minimum to the
    maximum depth; per-pixel blending weights come from ``wave_func``.

    Args:
        image_tensor: 5-D tensor with 3 channels in dim 2, indexed as
            ``[batch, frame, 3, H, W]``. Values are multiplied by 255, so they
            are presumably expected in ``[0, 1]``.
        depth_tensor: 5-D tensor with 1 channel in dim 2, indexed as
            ``[batch, frame, 1, H, W]``, with the same ``H, W`` as the image.
        batch_idx: Batch entry to visualise.
        frame_idx: Frame of that entry to visualise.
        n_frames: Number of sweep steps; ``n_frames + 1`` frames are rendered.
            ``0`` renders a single frame with the band at the minimum depth.
        wave_length: Half-width of the band, in depth units.
        wave_color: Band colour (0-255 per channel) used when
            ``use_gradient_color`` is False.
        wave_color_front: Band colour at the minimum depth (gradient mode).
        wave_color_back: Band colour at the maximum depth (gradient mode).
        use_gradient_color: Interpolate the band colour between
            ``wave_color_front`` and ``wave_color_back`` by normalised depth.
        pre_frames: Number of unmodified copies of the frame placed first.

    Returns:
        float32 tensor of shape ``[1, pre_frames + n_frames + 1, 3, H, W]``
        with values in ``[0.0, 1.0]`` (note the leading singleton batch dim).
    """
    assert image_tensor.ndim == 5 and image_tensor.shape[2] == 3, \
        "image_tensor must be 5-D with 3 channels in dim 2"
    assert depth_tensor.ndim == 5 and depth_tensor.shape[2] == 1, \
        "depth_tensor must be 5-D with 1 channel in dim 2"

    image = image_tensor[batch_idx, frame_idx].detach().cpu().numpy()  # (3, H, W)
    depth = depth_tensor[batch_idx, frame_idx, 0].detach().cpu().numpy()  # (H, W)

    image = np.transpose(image, (1, 2, 0)).astype(np.float32) * 255.0  # (H, W, 3)
    depth = depth.astype(np.float32)

    assert image.shape[:2] == depth.shape, "image and depth must share H and W"

    min_depth, max_depth = depth.min(), depth.max()
    if max_depth - min_depth < 1e-5:
        # (Nearly) constant depth: widen the range so the normalisation below
        # never divides by ~0.
        max_depth = min_depth + 1.0
    depth_span = max_depth - min_depth

    if use_gradient_color:
        front = np.array(wave_color_front, dtype=np.float32)
        back = np.array(wave_color_back, dtype=np.float32)
        depth_norm = (depth - min_depth) / depth_span
        # Per-pixel band colour: front colour at min depth, back colour at max depth.
        wave_color_map = (1. - depth_norm[..., None]) * front + depth_norm[..., None] * back
    else:
        wave_color_map = np.array(wave_color, dtype=np.float32).reshape(1, 1, 3)

    # Pre-video: hold the untouched frame.
    initial_frame = np.clip(image, 0, 255).astype(np.uint8)
    frames_np = [initial_frame] * pre_frames

    # Guard the ratio against n_frames == 0, which would divide by zero while
    # the loop below still runs once.
    steps = max(n_frames, 1)

    # Wave animation
    for i in range(n_frames + 1):
        ratio = i / steps
        curr_depth = depth_span * ratio + min_depth

        wave = wave_func(depth, curr_depth, wave_length)[..., None]
        wave = np.clip(wave, 0.0, 1.0)  # defensive: keep blend weights in [0, 1]

        frame = image * (1.0 - wave) + wave * wave_color_map
        frames_np.append(np.clip(frame, 0, 255).astype(np.uint8))

    # Convert frames to a torch.Tensor in [0, 1].
    video = np.stack(frames_np, axis=0).astype(np.float32) / 255.0  # (T, H, W, 3)
    video = np.transpose(video, (0, 3, 1, 2))                       # (T, 3, H, W)
    return torch.from_numpy(video)[None]                            # (1, T, 3, H, W)

def create_depth_visu(x, cmap='jet', data_range=None, out_float=True, min_max_perc=(0.01, 0.99)):
    """Colour-map a batch of single-value maps (e.g. depth) with an OpenCV colormap.

    Args:
        x: Tensor unpacked as ``(B, T, C, H, W)``.
        cmap: ``"jet"`` or ``"inferno"``.
        data_range: Optional ``(min, max)`` used for normalisation. When None,
            the range is taken from the ``min_max_perc`` percentiles computed
            over the whole tensor (all batch entries together).
        out_float: When True the colours are divided by 255 before returning.
        min_max_perc: Lower/upper percentile fractions in ``[0, 1]`` used when
            ``data_range`` is None.

    Returns:
        Tensor of shape ``(B, T, c, H, W)`` on the device and with the dtype of
        ``x``, where ``c`` is the channel count produced by
        ``cv2.applyColorMap``. Values are in ``[0, 1]`` when ``out_float`` is
        True, otherwise in ``[0, 255]``.

    Raises:
        ValueError: If ``cmap`` is not a supported colormap name.

    NOTE(review): OpenCV colormaps are documented as BGR-ordered and no channel
    swap is applied here - confirm that downstream consumers expect this.
    """
    B, _, _, _, _ = x.shape
    dtype = x.dtype
    device = x.device

    # Validate up front; previously an unknown name only surfaced later as an
    # UnboundLocalError on `color_map`.
    colormap_attrs = {"jet": "COLORMAP_JET", "inferno": "COLORMAP_INFERNO"}
    if cmap not in colormap_attrs:
        raise ValueError(
            f"Unsupported cmap {cmap!r}; expected one of {sorted(colormap_attrs)}"
        )

    x = x.detach()
    if data_range is None:
        # np.percentile flattens its input, so no reshape is needed (this also
        # avoids .view() failing on non-contiguous tensors).
        values = x.cpu().numpy()
        x_min = float(np.percentile(values, min_max_perc[0] * 100))
        x_max = float(np.percentile(values, min_max_perc[1] * 100))
    else:
        x_min, x_max = data_range

    x = (x - x_min) / (x_max - x_min)
    # A zero-width range (or NaN input) yields NaN/inf here; map those to valid
    # values and clamp so nothing out of [0, 1] reaches the uint8 cast, where
    # it would otherwise wrap around. For in-range data this equals clipping
    # to [x_min, x_max] before normalising.
    x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=0.0).clamp(0.0, 1.0)

    x = rearrange(x, 'b t c h w -> (b t) h w c')
    x_np = (x.cpu().numpy() * 255.0).astype(np.uint8)

    color_map = getattr(cv2, colormap_attrs[cmap])
    colored = [cv2.applyColorMap(frame, color_map) for frame in x_np]

    x = torch.from_numpy(np.array(colored))
    x = rearrange(x, '(b t) h w c -> b t c h w', b=B)
    x = x.to(device=device, dtype=dtype)
    if out_float:
        x = x / 255
    return x

def reshape_video_grid(video_tensor):
    """Tile a batch of videos into a single video laid out as a grid.

    Args:
        video_tensor: Tensor unpacked as ``(b, t, c, h, w)``.

    Returns:
        Tensor of shape ``(t, c, N1 * h, N2 * w)``. When ``b`` is a perfect
        square the grid is ``sqrt(b) x sqrt(b)``; otherwise all videos are
        placed side by side in a single row (``N1 = 1``, ``N2 = b``).
    """
    b, t, c, h, w = video_tensor.shape

    # math.isqrt is exact for integers, unlike int(math.sqrt(b)).
    side = math.isqrt(b)
    if side * side == b:
        rows = cols = side
    else:
        # Not a perfect square: fall back to one row holding every video.
        rows, cols = 1, b

    # Batch index runs row-major over the grid: (N1 N2) -> rows x cols.
    return rearrange(video_tensor, "(N1 N2) t c h w -> t c (N1 h) (N2 w)", N1=rows, N2=cols)