# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
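
"""Inference script for multi-view latent 3D Gaussian reconstruction.

Loads a trained LatentRecon transformer together with a video VAE, runs them
over the test split, and exports rendered RGB/depth videos, optional Gaussian
PLY files, and grid visualizations.
"""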

from tqdm import tqdm
import os
import re
import torch
from accelerate import PartialState
import einops
from omegaconf import OmegaConf
from accelerate.logging import get_logger

from src.models.recon.model_latent_recon import LatentRecon
from src.utils.visu import create_depth_visu, generate_wave_video, save_video
from src.models.data import get_multi_dataloader
from src.models.utils.model import encode_latent_time_vae, encode_plucker_vae
from src.models.utils.render import get_plucker_embedding_and_rays, save_ply, save_ply_orig
from src.models.utils.model import load_vae, encode_multi_view_video, encode_video, decode_multi_view_latents
from src.models.utils.data import write_dict_to_json
from src.models.utils.misc import dtype_map, seed_everything, load_and_merge_configs
from src.models.utils.train import get_most_recent_checkpoint

logger = get_logger(__name__, log_level="INFO")

def load_model(ckpt_path, config, weight_dtype):
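    """Load the VAE and LatentRecon transformer, restore weights from
    ``ckpt_path``, cast both to ``weight_dtype`` on the local device,
    and return them in eval mode together with the PartialState."""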
    # Load model
    distributed_state = PartialState()
    device = distributed_state.device
    vae = load_vae(config.vae_backbone, config.vae_path)
    transformer = LatentRecon(config)

    # Load ckpt
    data = torch.load(ckpt_path, map_location="cpu")  # load on CPU first; cast to device below
    transformer.load_state_dict(data["module"])

    # Cast model
    transformer.to(device=device, dtype=weight_dtype)
    vae.to(device=device, dtype=weight_dtype)
    transformer.eval()
    vae.eval()
    return transformer, vae, distributed_state

def main(
    config,
    **kwargs
):
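    """Entry point. If a list/range of bullet-time indices is configured,
    run ``main_single`` once per index and reuse the loaded model between
    runs; otherwise run a single pass."""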
    # For dynamic scenes, loop over all target (bullet) times
    target_index_manual = config.target_index_manual
    if target_index_manual is None and config.target_index_manual_start_idx is not None:
        target_index_manual = list(range(
            config.target_index_manual_start_idx,
            config.target_index_manual_start_idx + config.target_index_manual_num_idx,
            config.target_index_manual_stride,
        ))
    if target_index_manual is not None and not isinstance(target_index_manual, int):
        for target_index_i in target_index_manual:
            print(f"Bullet time {target_index_i}")
            config.target_index_manual = target_index_i
            transformer, vae, distributed_state, ckpt_path = main_single(config, **kwargs)
            # Cache the loaded components so later bullet times reuse the same model
            kwargs['transformer'] = transformer
            kwargs['vae'] = vae
            kwargs['distributed_state'] = distributed_state
            kwargs['ckpt_path'] = ckpt_path
    else:
        main_single(config, **kwargs)

def main_single(
    config,
    seed: int = 0,
    transformer = None,
    vae = None,
    distributed_state = None,
    ckpt_path = None,
):
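    """Run inference over the test set for a single configuration and bullet
    time, exporting renderings to ``config.out_dir_inference``. Returns the
    model components so callers can reuse them across bullet times."""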
    weight_dtype = torch.bfloat16
    out_fps = config.out_fps
    g = torch.Generator()
    g.manual_seed(seed)
    seed_everything(seed)
    outdir = config.out_dir_inference

    # Either a single config path is given, or a list of paths to merge
    if isinstance(config.config_path, str):
        main_config = OmegaConf.load(config.config_path)
    else:
        main_config = load_and_merge_configs(config.config_path)

    # Get the latest checkpoint if none is given (ckpt_name looks like 'checkpoint-15000')
    ckpt_name = None
    if ckpt_path is None:
        if config.ckpt_path is None:
            ckpt_model_sub_path = 'pytorch_model/mp_rank_00_model_states.pt'
            ckpts_path = main_config.output_dir
            ckpt_name = config.ckpt_name
            if ckpt_name is None:
                ckpt_name = get_most_recent_checkpoint(ckpts_path)
            ckpt_path = os.path.join(ckpts_path, ckpt_name, ckpt_model_sub_path)
            
        else:
            ckpt_path = config.ckpt_path
    if ckpt_name is None:
        has_ckpt_name = re.search(r"(checkpoint-\d+)", ckpt_path)
        if has_ckpt_name:
            ckpt_name = has_ckpt_name.group(1)
    if ckpt_name is not None:
        outdir = os.path.join(outdir, ckpt_name)
    if os.path.isfile(ckpt_path):
        print(f"Found ckpt at path {ckpt_path}")
    else:
        raise ValueError(f"Could not find ckpt at path {ckpt_path}")
    
    # For dynamic scenes, render all camera viewpoints, not only the one from the bullet time
    if config.set_manual_time_idx:
        main_config.set_manual_time_idx = config.set_manual_time_idx
    
    # Set view indices
    if config.static_view_indices_fixed is not None:
        main_config.static_view_indices_fixed = config.static_view_indices_fixed
        view_indices_str = '_'.join(str(v) for v in config.static_view_indices_fixed)
        outdir = os.path.join(outdir, f"static_view_indices_fixed_{view_indices_str}")
        main_config.static_view_indices_sampling = 'fixed'
        main_config.num_input_multi_views = len(config.static_view_indices_fixed)
    
    # Subsample the output views
    if config.target_index_subsample is not None:
        main_config.target_index_subsample = config.target_index_subsample
    gaussians_scale_factor = None

    # Define wave visualization parameters
    wave_color_dict = {'wave_color_front': [255, 230, 200], 'wave_color_back': [200, 220, 255], "use_gradient_color": True}
    wave_length = 0.4

    # Export only RGB results for evaluation
    do_eval = config.do_eval
    if do_eval:
        config.save_grid = False
        config.save_gt_input = False
        config.save_gt_depth = False
        config.save_video_input = False
        config.save_rgb_decoding = False
        config.save_gaussians = False
        config.save_gaussians_orig = False
    
    # Generate each sample independently
    main_config.batch_size = 1
    main_config.gs_view_chunk_size = 1

    # We are not using the train data loader
    main_config.num_train_images = 1

    # Set test dataset, otherwise take validation set
    if config.dataset_name is not None:
        main_config.data_mode = [[config.dataset_name, 1]]
        outdir = os.path.join(outdir, config.dataset_name)
    
    # Set bullet time manually
    main_config.target_index_manual = config.target_index_manual
    if config.target_index_manual is not None:
        outdir = os.path.join(outdir, str(config.target_index_manual))
    
    # Set number of test scenes, else take as defined in training config
    if config.num_test_images is not None:
        main_config.num_test_images = config.num_test_images
    
    # Set depth usage (depth was only used for supervision during training)
    main_config.use_depth = config.use_depth

    # Get data loader and model
    train_dataloader, test_dataloader = get_multi_dataloader(main_config)
    if transformer is None and vae is None and distributed_state is None:
        transformer, vae, distributed_state = load_model(ckpt_path, main_config, weight_dtype)
    
    # Set up for grid visualization
    step_test_sum = 0
    step_test_sum_dataset = 0
    test_video_out = []
    test_video_out_rgb = []
    test_video_in = []

    # Output dirs
    outdir_raw = os.path.join(outdir, "raw")
    outdir_meta = os.path.join(outdir, "meta")
    outdir_grid = os.path.join(outdir, "grid")
    outdir_full = os.path.join(outdir, "full_output")
    outdir_3dgs = os.path.join(outdir, "main_gaussians_renderings")
    for d in [outdir, outdir_raw, outdir_meta, outdir_grid, outdir_full, outdir_3dgs]:
        os.makedirs(d, exist_ok=True)
    
    # Loop over test set
    for idx, batch_test in tqdm(enumerate(test_dataloader)):
        
        # File name identifies the sample for metadata and skip checks
        batch_file_name = batch_test['file_name']
        
        # Skip if already generated
        meta_data_sample = {'file_name': batch_file_name}
        meta_data_out_path = os.path.join(outdir_meta, f'sample_{idx}.json')
        if os.path.isfile(meta_data_out_path):
            continue
        
        # Check if file exists for eval
        if do_eval:
            eval_file_exists = True
            for view_idx in range(main_config.num_input_multi_views):
                outdir_view_idx = os.path.join(outdir, str(view_idx))
                out_file_name_view = batch_test['file_name']
                assert len(out_file_name_view) == 1, f"More than 1 file_names: {len(out_file_name_view)}"
                out_file_name_view = out_file_name_view[0]
                out_file_path_view = os.path.join(outdir_view_idx, out_file_name_view)
                if not os.path.isfile(f"{out_file_path_view}.mp4"):
                    eval_file_exists = False
                    break
            if eval_file_exists:
                print(f"Skipping {out_file_name_view}")
                continue
        
        # Move to device and cast tensors
        for batch_k, batch_v in batch_test.items():
            if not isinstance(batch_v, torch.Tensor):
                continue
            batch_test[batch_k] = batch_v.to(distributed_state.device)
            # Keep camera/meta tensors in full precision for rendering; cast the rest
            if batch_k not in ['intrinsics_input', 'c2ws_input', 'cam_view', 'intrinsics', 'file_name']:
                batch_test[batch_k] = batch_test[batch_k].to(weight_dtype)
        
        # Compute Plücker embeddings in the configured dtype (float64 matches the old CPU results)
        if main_config.compute_plucker_cuda:
            batch_test['plucker_embedding'], batch_test['rays_os'], batch_test['rays_ds'] = get_plucker_embedding_and_rays(
                batch_test['intrinsics_input'],
                batch_test['c2ws_input'],
                main_config.img_size,
                main_config.patch_size_out_factor,
                batch_test['flip_flag'],
                get_batch_index=False,
                dtype=dtype_map[main_config.compute_plucker_dtype],
                out_dtype=weight_dtype
                )
        
        # All samples in a batch must share the same number of input views
        if 'num_input_multi_views' in batch_test:
            assert (batch_test['num_input_multi_views'][0] == batch_test['num_input_multi_views']).all(), \
                "Variable numbers of input views are not supported with batch size > 1"
            num_input_multi_views = int(batch_test['num_input_multi_views'][0].item())
            batch_test['num_input_multi_views'] = num_input_multi_views
        
        # Encode video
        if 'rgb_latents' in batch_test:
            model_input = batch_test['rgb_latents'].to(weight_dtype)
            batch_test['images_input_embed'] = model_input
            video = None
        else:
            video = batch_test['images_input_vae']
            if main_config.use_rgb_decoder:
                model_input = video
            else:
                model_input = encode_multi_view_video(vae, video, num_input_multi_views, main_config.vae_backbone)
            batch_test['images_input_embed'] = model_input
        if main_config.time_embedding_vae:
            batch_test = encode_latent_time_vae(batch_test, lambda x: encode_video(vae, x, main_config.vae_backbone), main_config.img_size)
        if main_config.plucker_embedding_vae:
            batch_test = encode_plucker_vae(batch_test, lambda x: encode_multi_view_video(vae, x, num_input_multi_views, main_config.vae_backbone))
        
        # Reconstruct latents and render from 3DGS
        with torch.no_grad():
            model_output = transformer(batch_test)
        
        # Get RGB and depth from 3DGS
        pred_images = model_output['images_pred'].cpu()
        pred_depths = create_depth_visu(model_output['depths_pred']).cpu()
        if 'depths_output' in batch_test:
            gt_depths = create_depth_visu(batch_test['depths_output'].to(pred_depths.dtype)).cpu()
        else:
            gt_depths = None

        # RGB VAE decoding as reference
        if config.save_rgb_decoding:
            with torch.no_grad():
                reconstructed_latents = decode_multi_view_latents(vae, model_input, num_input_multi_views, main_config.vae_backbone)
            if video is None:
                video = reconstructed_latents
            else:
                video = torch.cat((reconstructed_latents, video), -1)
        
        # Export Gaussians by writing the raw tensor to a PLY file
        if config.save_gaussians:
            out_dir_gaussians = os.path.join(outdir, 'gaussians')
            os.makedirs(out_dir_gaussians, exist_ok=True)
            path_gaussians = os.path.join(out_dir_gaussians, f'gaussians_{idx}.ply')
            save_ply(model_output['gaussians'], path_gaussians, scale_factor=gaussians_scale_factor)
        
        # Export Gaussians in the original PLY format (used for USDZ with Isaac)
        if config.save_gaussians_orig:
            out_dir_gaussians_orig = os.path.join(outdir, 'gaussians_orig')
            os.makedirs(out_dir_gaussians_orig, exist_ok=True)
            path_gaussians_orig = os.path.join(out_dir_gaussians_orig, f'gaussians_{idx}.ply')
            save_ply_orig(model_output['gaussians'], path_gaussians_orig, scale_factor=gaussians_scale_factor)
        del model_output['gaussians']  # free memory; only the renderings are used below

        # Split the stacked view dimension: (b, v*t, c, h, w) -> (v, b, t, c, h, w)
        pred_images_views = einops.rearrange(pred_images, 'b (v t) c h w -> v b t c h w', v=num_input_multi_views)

        # Wave propagation visualization
        if not do_eval:
            use_gradient_color = wave_color_dict['use_gradient_color']
            if 'wave_color' in wave_color_dict:
                wave_color = wave_color_dict['wave_color']
                wave_color_front = None
                wave_color_back = None
            else:
                wave_color = None
                wave_color_front = wave_color_dict['wave_color_front']
                wave_color_back = wave_color_dict['wave_color_back']
            pred_images_wave = generate_wave_video(
                model_output['images_pred'],
                model_output['depths_pred'],
                wave_length=wave_length,
                wave_color=wave_color,
                use_gradient_color=use_gradient_color,
                wave_color_front=wave_color_front,
                wave_color_back=wave_color_back,
            )
            pred_images_rgb = torch.cat((pred_images_wave, pred_images), 1)
            save_video(pred_images_rgb, outdir_3dgs, name=f'rgb_{idx}', fps=out_fps)
            save_video(pred_images_wave, outdir_raw, name=f'rgb_wave_{idx}', fps=out_fps)
            for view_idx, pred_images_view in enumerate(pred_images_views):
                save_video(pred_images_view, outdir_raw, name=f'rgb_{idx}_view_idx_{view_idx}', fps=out_fps)
        
        # Export evaluation rendering with the corresponding filename
        if do_eval:
            for view_idx, pred_images_view in enumerate(pred_images_views):
                outdir_view_idx = os.path.join(outdir, str(view_idx))
                out_file_name_view = batch_test['file_name']
                assert len(out_file_name_view) == 1, f"More than 1 file_names: {len(out_file_name_view)}"
                out_file_name_view = out_file_name_view[0]
                os.makedirs(outdir_view_idx, exist_ok=True)
                save_video(pred_images_view, outdir_view_idx, name=out_file_name_view, fps=out_fps)
        
        # Add main 3DGS renderings to the grid
        images_grid_list = [pred_images]

        # Add ground-truth / video-model RGB reference
        if config.save_gt_input:
            gt_images = batch_test['images_output'].cpu()
            images_grid_list.append(gt_images)
        
        # Input video
        if config.save_video_input and video is not None:
            video_norm = ((video + 1) / 2).cpu().float()  # map [-1, 1] -> [0, 1]
            if video_norm.shape == pred_images.shape:
                images_grid_list.append(video_norm)
            else:
                if config.save_grid:
                    test_video_in.append(video_norm)
                save_video(video_norm, outdir_raw, name=f'input_{idx}', fps=out_fps)
        
        # Add images for concatenated visualizations
        images_grid_list.append(pred_depths)
        if config.save_gt_depth and gt_depths is not None:
            images_grid_list.append(gt_depths)
        pred_images_out = torch.cat(images_grid_list, -1)
        step_test_sum += pred_images_out.shape[0]
        if config.save_grid:
            test_video_out.append(pred_images_out)
            test_video_out_rgb.append(pred_images_rgb)
        
        # Write main sample and metadata
        if not do_eval:
            save_video(pred_images_out, outdir_full, name=f'sample_{idx}', fps=out_fps)
            write_dict_to_json(meta_data_sample, meta_data_out_path)

        # Export grid and reset counters
        if step_test_sum >= config.num_grid_samples:
            if config.save_grid:
                test_video_out = torch.cat(test_video_out, 0)
                save_video(test_video_out, outdir_grid, name=f'sample_grid_{step_test_sum_dataset}', fps=out_fps)
                test_video_out_rgb = torch.cat(test_video_out_rgb, 0)
                save_video(test_video_out_rgb, outdir_grid, name=f'rgb_grid_{step_test_sum_dataset}', fps=out_fps)
                if len(test_video_in) != 0:
                    test_video_in = torch.cat(test_video_in, 0)
                    save_video(test_video_in, outdir_grid, name=f'input_grid_{step_test_sum_dataset}', fps=out_fps)
            step_test_sum = 0
            step_test_sum_dataset += 1
            test_video_out = []
            test_video_out_rgb = []
            test_video_in = []
        print(f"Saved batch index {idx} to {outdir}")
        
    print(f"Saved all results to {outdir}")
    return transformer, vae, distributed_state, ckpt_path

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default=None)
    parser.add_argument('--config_default', type=str, default='configs/inference/default.yaml')
    args, unknown = parser.parse_known_args()
    config = load_and_merge_configs([args.config_default, args.config])
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(config, cli)
    main(config)
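
# Example invocation (script name, paths, and overrides are illustrative):
#   python inference.py --config configs/inference/my_run.yaml \
#       dataset_name=my_dataset num_test_images=10
# Unknown CLI arguments are merged into the config as OmegaConf dotlist
# overrides (e.g., ``target_index_manual=5``).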