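"""Inference script for the LatentRecon model.

Loads a trained checkpoint, runs the transformer over a multi-view test
dataloader, and saves predicted RGB/depth videos, optional Gaussian splat
.ply files, and per-sample metadata JSON under ``config.out_dir_inference``.
"""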
import os
import re

import einops
import torch
from accelerate import PartialState
from accelerate.logging import get_logger
from omegaconf import OmegaConf
from tqdm import tqdm

from src.models.data import get_multi_dataloader
from src.models.recon.model_latent_recon import LatentRecon
from src.models.utils.data import write_dict_to_json
from src.models.utils.misc import dtype_map, seed_everything, load_and_merge_configs
from src.models.utils.model import (
    decode_multi_view_latents,
    encode_latent_time_vae,
    encode_multi_view_video,
    encode_plucker_vae,
    encode_video,
    load_vae,
)
from src.models.utils.render import get_plucker_embedding_and_rays, save_ply, save_ply_orig
from src.models.utils.train import get_most_recent_checkpoint
from src.utils.visu import create_depth_visu, generate_wave_video, save_video

logger = get_logger(__name__, log_level="INFO")


def load_model(ckpt_path, config, weight_dtype):
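    """Build the VAE and the LatentRecon transformer and restore checkpoint weights.

    Returns the transformer and VAE in eval mode on the local device, together
    with the accelerate ``PartialState`` for the current process.
    """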
    distributed_state = PartialState()
    device = distributed_state.device
    vae = load_vae(config.vae_backbone, config.vae_path)
    transformer = LatentRecon(config)

    # Checkpoints use the DeepSpeed layout; load to CPU first so a checkpoint
    # saved on another device can always be restored.
    data = torch.load(ckpt_path, map_location="cpu")
    transformer.load_state_dict(data["module"])

    transformer.to(device=device, dtype=weight_dtype)
    vae.to(device=device, dtype=weight_dtype)
    transformer.eval()
    vae.eval()
    return transformer, vae, distributed_state


def main(config, **kwargs):
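    """Run inference once, or once per manual target ("bullet time") index.

    When a list of target indices is given (directly or via start/num/stride),
    the model loaded on the first pass is cached in ``kwargs`` and reused for
    the remaining indices.
    """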
    target_index_manual = config.target_index_manual
    if target_index_manual is None and config.target_index_manual_start_idx is not None:
        target_index_manual = list(range(
            config.target_index_manual_start_idx,
            config.target_index_manual_start_idx + config.target_index_manual_num_idx,
            config.target_index_manual_stride,
        ))
    if target_index_manual is not None and not isinstance(target_index_manual, int):
        for target_index_i in target_index_manual:
            print(f"Bullet time {target_index_i}")
            config.target_index_manual = target_index_i
            transformer, vae, distributed_state, ckpt_path = main_single(config, **kwargs)
            kwargs['transformer'] = transformer
            kwargs['vae'] = vae
            kwargs['distributed_state'] = distributed_state
            kwargs['ckpt_path'] = ckpt_path
    else:
        main_single(config, **kwargs)


def main_single(
    config,
    seed: int = 0,
    transformer=None,
    vae=None,
    distributed_state=None,
    ckpt_path=None,
):
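    """Run inference for a single configuration.

    ``transformer``, ``vae``, ``distributed_state`` and ``ckpt_path`` may be
    passed in to reuse an already-loaded model (see ``main``); they are
    returned so callers can cache them.
    """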
    weight_dtype = torch.bfloat16
    out_fps = config.out_fps
    seed_everything(seed)
    outdir = config.out_dir_inference

    if isinstance(config.config_path, str):
        main_config = OmegaConf.load(config.config_path)
    else:
        main_config = load_and_merge_configs(config.config_path)

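    # Resolve the checkpoint path: an explicit config.ckpt_path wins; otherwise
    # fall back to the most recent "checkpoint-*" directory in the training
    # output_dir, using the DeepSpeed model-states file inside it.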
    ckpt_name = None
    if ckpt_path is None:
        if config.ckpt_path is None:
            ckpt_model_sub_path = 'pytorch_model/mp_rank_00_model_states.pt'
            ckpts_path = main_config.output_dir
            ckpt_name = config.ckpt_name
            if ckpt_name is None:
                ckpt_name = get_most_recent_checkpoint(ckpts_path)
            ckpt_path = os.path.join(ckpts_path, ckpt_name, ckpt_model_sub_path)
        else:
            ckpt_path = config.ckpt_path
    if ckpt_name is None:
        has_ckpt_name = re.search(r"(checkpoint-\d+)", ckpt_path)
        if has_ckpt_name:
            ckpt_name = has_ckpt_name.group(1)
    if ckpt_name is not None:
        outdir = os.path.join(outdir, ckpt_name)
    if os.path.isfile(ckpt_path):
        print(f"Found ckpt at path {ckpt_path}")
    else:
        raise ValueError(f"Could not find ckpt at path {ckpt_path}")
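
    # Copy inference-time overrides onto the training config that drives the
    # dataloader and the model.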
    if config.set_manual_time_idx:
        main_config.set_manual_time_idx = config.set_manual_time_idx

    if config.static_view_indices_fixed is not None:
        main_config.static_view_indices_fixed = config.static_view_indices_fixed
        outdir = os.path.join(outdir, f"static_view_indices_fixed_{'_'.join(map(str, config.static_view_indices_fixed))}")
        main_config.static_view_indices_sampling = 'fixed'
        main_config.num_input_multi_views = len(config.static_view_indices_fixed)

    if config.target_index_subsample is not None:
        main_config.target_index_subsample = config.target_index_subsample

    gaussians_scale_factor = None

    # Colors for the depth-wave visualization (warm front, cool back, gradient).
    wave_color_dict = {
        'wave_color_front': [255, 230, 200],
        'wave_color_back': [200, 220, 255],
        'use_gradient_color': True,
    }
    wave_length = 0.4

    do_eval = config.do_eval
    if do_eval:
        # Evaluation only needs the per-view prediction videos.
        config.save_grid = False
        config.save_gt_input = False
        config.save_gt_depth = False
        config.save_video_input = False
        config.save_rgb_decoding = False
        config.save_gaussians = False
        config.save_gaussians_orig = False

    # Inference runs one sample at a time.
    main_config.batch_size = 1
    main_config.gs_view_chunk_size = 1
    main_config.num_train_images = 1

    if config.dataset_name is not None:
        main_config.data_mode = [[config.dataset_name, 1]]
        outdir = os.path.join(outdir, config.dataset_name)

    main_config.target_index_manual = config.target_index_manual
    if config.target_index_manual is not None:
        outdir = os.path.join(outdir, str(config.target_index_manual))

    if config.num_test_images is not None:
        main_config.num_test_images = config.num_test_images

    main_config.use_depth = config.use_depth

    _, test_dataloader = get_multi_dataloader(main_config)
    # Load the model unless a previously loaded one was passed in.
    if transformer is None and vae is None and distributed_state is None:
        transformer, vae, distributed_state = load_model(ckpt_path, main_config, weight_dtype)

    step_test_sum = 0
    step_test_sum_dataset = 0
    test_video_out = []
    test_video_out_rgb = []
    test_video_in = []

    outdir_raw = os.path.join(outdir, "raw")
    outdir_meta = os.path.join(outdir, "meta")
    outdir_grid = os.path.join(outdir, "grid")
    outdir_full = os.path.join(outdir, "full_output")
    outdir_3dgs = os.path.join(outdir, "main_gaussians_renderings")
    for d in [outdir, outdir_raw, outdir_meta, outdir_grid, outdir_full, outdir_3dgs]:
        os.makedirs(d, exist_ok=True)
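
    # Main inference loop: batch_size is forced to 1 above, so each step handles
    # a single sample. Samples whose metadata file already exists are skipped,
    # which makes interrupted runs resumable.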
    for idx, batch_test in tqdm(enumerate(test_dataloader)):
        batch_file_name = batch_test['file_name']

        meta_data_sample = {'file_name': batch_file_name}
        meta_data_out_path = os.path.join(outdir_meta, f'sample_{idx}.json')
        if os.path.isfile(meta_data_out_path):
            continue

        if do_eval:
            # Skip samples whose per-view videos were all written by an earlier run.
            eval_file_exists = True
            for view_idx in range(main_config.num_input_multi_views):
                outdir_view_idx = os.path.join(outdir, str(view_idx))
                out_file_name_view = batch_test['file_name']
                assert len(out_file_name_view) == 1, f"Expected a single file_name, got {len(out_file_name_view)}"
                out_file_name_view = out_file_name_view[0]
                out_file_path_view = os.path.join(outdir_view_idx, out_file_name_view)
                if not os.path.isfile(f"{out_file_path_view}.mp4"):
                    eval_file_exists = False
                    break
            if eval_file_exists:
                print(f"Skipping {out_file_name_view}")
                continue
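
        # Move tensors to the device; keep camera and index tensors in their
        # original dtype and cast everything else to the model weight dtype.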
        for batch_k, batch_v in batch_test.items():
            if not isinstance(batch_v, torch.Tensor):
                continue
            batch_test[batch_k] = batch_v.to(distributed_state.device)
            if batch_k not in ['intrinsics_input', 'c2ws_input', 'cam_view', 'intrinsics', 'file_name']:
                batch_test[batch_k] = batch_test[batch_k].to(weight_dtype)

        if main_config.compute_plucker_cuda:
            # Build Plücker ray embeddings on the GPU from the input cameras.
            batch_test['plucker_embedding'], batch_test['rays_os'], batch_test['rays_ds'] = get_plucker_embedding_and_rays(
                batch_test['intrinsics_input'],
                batch_test['c2ws_input'],
                main_config.img_size,
                main_config.patch_size_out_factor,
                batch_test['flip_flag'],
                get_batch_index=False,
                dtype=dtype_map[main_config.compute_plucker_dtype],
                out_dtype=weight_dtype,
            )

        if 'num_input_multi_views' in batch_test:
            assert (batch_test['num_input_multi_views'][0] == batch_test['num_input_multi_views']).all(), \
                "Variable multi-view does not support batch size > 1"
            num_input_multi_views = int(batch_test['num_input_multi_views'][0].item())
            batch_test['num_input_multi_views'] = num_input_multi_views
        else:
            num_input_multi_views = main_config.num_input_multi_views
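
        # Obtain the model input: precomputed RGB latents when available,
        # otherwise encode the input video with the VAE (unless the RGB decoder
        # consumes pixels directly).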
        if 'rgb_latents' in batch_test:
            model_input = batch_test['rgb_latents'].to(weight_dtype)
            batch_test['images_input_embed'] = model_input
            video = None
        else:
            video = batch_test['images_input_vae']
            if main_config.use_rgb_decoder:
                model_input = video
            else:
                model_input = encode_multi_view_video(vae, video, num_input_multi_views, main_config.vae_backbone)
            batch_test['images_input_embed'] = model_input
            if main_config.time_embedding_vae:
                batch_test = encode_latent_time_vae(batch_test, lambda x: encode_video(vae, x, main_config.vae_backbone), main_config.img_size)
            if main_config.plucker_embedding_vae:
                batch_test = encode_plucker_vae(batch_test, lambda x: encode_multi_view_video(vae, x, num_input_multi_views, main_config.vae_backbone))

        with torch.no_grad():
            model_output = transformer(batch_test)

        pred_images = model_output['images_pred'].cpu()
        pred_depths = create_depth_visu(model_output['depths_pred']).cpu()
        if 'depths_output' in batch_test:
            gt_depths = create_depth_visu(batch_test['depths_output'].to(pred_depths.dtype)).cpu()
        else:
            gt_depths = None
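
        # Optionally decode the input latents back to RGB as a VAE
        # reconstruction check, shown next to the input video.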
        if config.save_rgb_decoding:
            with torch.no_grad():
                reconstructed_latents = decode_multi_view_latents(vae, model_input, num_input_multi_views, main_config.vae_backbone)
            if video is None:
                video = reconstructed_latents
            else:
                video = torch.cat((reconstructed_latents, video), -1)

        if config.save_gaussians:
            out_dir_gaussians = os.path.join(outdir, 'gaussians')
            os.makedirs(out_dir_gaussians, exist_ok=True)
            path_gaussians = os.path.join(out_dir_gaussians, f'gaussians_{idx}.ply')
            save_ply(model_output['gaussians'], path_gaussians, scale_factor=gaussians_scale_factor)

        if config.save_gaussians_orig:
            out_dir_gaussians_orig = os.path.join(outdir, 'gaussians_orig')
            os.makedirs(out_dir_gaussians_orig, exist_ok=True)
            path_gaussians_orig = os.path.join(out_dir_gaussians_orig, f'gaussians_{idx}.ply')
            save_ply_orig(model_output['gaussians'], path_gaussians_orig, scale_factor=gaussians_scale_factor)
        # Free the (large) Gaussian tensors before building visualizations.
        del model_output['gaussians']

        pred_images_views = einops.rearrange(pred_images, 'b (v t) c h w -> v b t c h w', v=num_input_multi_views)
        if not do_eval:
            # Render the "wave" depth-sweep visualization and save per-view videos.
            use_gradient_color = wave_color_dict['use_gradient_color']
            if 'wave_color' in wave_color_dict:
                wave_color = wave_color_dict['wave_color']
                wave_color_front = None
                wave_color_back = None
            else:
                wave_color = None
                wave_color_front = wave_color_dict['wave_color_front']
                wave_color_back = wave_color_dict['wave_color_back']
            pred_images_wave = generate_wave_video(
                model_output['images_pred'],
                model_output['depths_pred'],
                wave_length=wave_length,
                wave_color=wave_color,
                use_gradient_color=use_gradient_color,
                wave_color_front=wave_color_front,
                wave_color_back=wave_color_back,
            )
            pred_images_rgb = torch.cat((pred_images_wave, pred_images), 1)
            save_video(pred_images_rgb, outdir_3dgs, name=f'rgb_{idx}', fps=out_fps)
            save_video(pred_images_wave, outdir_raw, name=f'rgb_wave_{idx}', fps=out_fps)
            for view_idx, pred_images_view in enumerate(pred_images_views):
                save_video(pred_images_view, outdir_raw, name=f'rgb_{idx}_view_idx_{view_idx}', fps=out_fps)

        if do_eval:
            # For evaluation, save one prediction video per input view.
            for view_idx, pred_images_view in enumerate(pred_images_views):
                outdir_view_idx = os.path.join(outdir, str(view_idx))
                out_file_name_view = batch_test['file_name']
                assert len(out_file_name_view) == 1, f"Expected a single file_name, got {len(out_file_name_view)}"
                out_file_name_view = out_file_name_view[0]
                os.makedirs(outdir_view_idx, exist_ok=True)
                save_video(pred_images_view, outdir_view_idx, name=out_file_name_view, fps=out_fps)

        # Assemble side-by-side grid panels: prediction, optional GT, optional
        # input video, and the depth visualizations.
        images_grid_list = [pred_images]

        if config.save_gt_input:
            gt_images = batch_test['images_output'].cpu()
            images_grid_list.append(gt_images)

        if config.save_video_input and video is not None:
            video_norm = ((video + 1) / 2).cpu().float()
            if video_norm.shape == pred_images.shape:
                images_grid_list.append(video_norm)
            else:
                if config.save_grid:
                    test_video_in.append(video_norm)
                save_video(video_norm, outdir_raw, name=f'input_{idx}', fps=out_fps)

        images_grid_list.append(pred_depths)
        if config.save_gt_depth and gt_depths is not None:
            images_grid_list.append(gt_depths)
        pred_images_out = torch.cat(images_grid_list, -1)
        step_test_sum += pred_images_out.shape[0]
        if config.save_grid:
            # do_eval disables save_grid above, so pred_images_rgb is defined here.
            test_video_out.append(pred_images_out)
            test_video_out_rgb.append(pred_images_rgb)

        if not do_eval:
            save_video(pred_images_out, outdir_full, name=f'sample_{idx}', fps=out_fps)
            write_dict_to_json(meta_data_sample, meta_data_out_path)

        if step_test_sum >= config.num_grid_samples:
            # Flush the accumulated grid videos and reset the accumulators.
            if config.save_grid:
                test_video_out = torch.cat(test_video_out, 0)
                save_video(test_video_out, outdir_grid, name=f'sample_grid_{step_test_sum_dataset}', fps=out_fps)
                test_video_out_rgb = torch.cat(test_video_out_rgb, 0)
                save_video(test_video_out_rgb, outdir_grid, name=f'rgb_grid_{step_test_sum_dataset}', fps=out_fps)
                if len(test_video_in) != 0:
                    test_video_in = torch.cat(test_video_in, 0)
                    save_video(test_video_in, outdir_grid, name=f'input_grid_{step_test_sum_dataset}', fps=out_fps)
            step_test_sum = 0
            step_test_sum_dataset += 1
            test_video_out = []
            test_video_out_rgb = []
            test_video_in = []
        print(f"Saved batch index {idx} to {outdir}")

    print(f"Saved all results to {outdir}")
    return transformer, vae, distributed_state, ckpt_path


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default=None)
    parser.add_argument('--config_default', type=str, default='configs/inference/default.yaml')
    # Unknown arguments become OmegaConf dotlist overrides, e.g. `out_fps=24`.
    args, unknown = parser.parse_known_args()

    config_paths = [args.config_default] + ([args.config] if args.config is not None else [])
    config = load_and_merge_configs(config_paths)
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(config, cli)
    main(config)
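
# Example invocation (illustrative only; the script and config file names are
# placeholders, not taken from this repo):
#   python inference.py --config configs/inference/example.yaml do_eval=true out_fps=24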