import argparse
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from moge.model.v1 import MoGeModel

from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer, Cache4D
from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline
from cosmos_predict1.diffusion.inference.gen3c_single_image import (
    create_parser as create_parser_base,
    validate_args as validate_args_base,
    _predict_moge_depth,
    _predict_moge_depth_from_tensor,
)
from cosmos_predict1.utils import log, misc
from cosmos_predict1.utils.distributed import device_with_rank, get_rank, is_rank0
from cosmos_predict1.utils.io import save_video


def create_parser():
    return create_parser_base()


def validate_args(args: argparse.Namespace):
    validate_args_base(args)
    assert args.batch_input_path is None, "Unsupported in persistent mode"
    assert args.prompt is not None, "Prompt is required in persistent mode (but it may be the empty string)"
    assert args.input_image_path is None, "Image should be provided directly by value in persistent mode"
    assert args.trajectory in (None, 'none'), \
        "Trajectory should be provided directly by value in persistent mode; set --trajectory=none"
    assert not args.video_save_name, \
        f"Video save name will be set automatically for each inference request. Found string: \"{args.video_save_name}\""


def resize_intrinsics(intrinsics: np.ndarray | torch.Tensor,
                      old_size: tuple[int, int], new_size: tuple[int, int],
                      crop_size: tuple[int, int] | None = None) -> np.ndarray | torch.Tensor:
    """Rescale pixel-space intrinsics from `old_size` to `new_size`, both given as (height, width).

    If `crop_size` is given, the principal point is additionally shifted to account
    for a center crop of that size applied after resizing.
    """
    if isinstance(intrinsics, np.ndarray):
        intrinsics_copy = np.copy(intrinsics)
    elif isinstance(intrinsics, torch.Tensor):
        intrinsics_copy = intrinsics.clone()
    else:
        raise ValueError(f"Invalid intrinsics type: {type(intrinsics)}")

    # Row 0 (fx, skew, cx) scales with width; row 1 (fy, cy) scales with height.
    intrinsics_copy[:, 0, :] *= new_size[1] / old_size[1]
    intrinsics_copy[:, 1, :] *= new_size[0] / old_size[0]
    if crop_size is not None:
        # A center crop shifts the principal point by half the cropped-away margin.
        intrinsics_copy[:, 0, -1] -= (new_size[1] - crop_size[1]) / 2
        intrinsics_copy[:, 1, -1] -= (new_size[0] - crop_size[0]) / 2
    return intrinsics_copy
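
# Example (hypothetical numbers): resizing 720x1280 intrinsics to 704x1280
# leaves fx and cx unchanged and scales fy and cy by 704/720:
#
#   K = np.array([[[1000.0, 0.0, 640.0],
#                  [0.0, 1000.0, 360.0],
#                  [0.0, 0.0, 1.0]]])
#   K_new = resize_intrinsics(K, old_size=(720, 1280), new_size=(704, 1280))
#   assert np.isclose(K_new[0, 1, 1], 1000.0 * 704 / 720)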


class Gen3cPersistentModel:
    """Helper class to run Gen3C image-to-video or video-to-video inference.

    The models are loaded only once and the instance can be reused for multiple
    inputs. It handles the main video-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts/images/videos from input
    - Generating videos from prompts and images/videos
    - Saving the generated videos and corresponding prompts to disk

    Args:
        args (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts/images/videos, save paths)
            - Performance options (model offloading settings)

    Inference saves:
    - Generated MP4 video files
    - Text files containing the processed prompts
    """

    @torch.no_grad()
    def __init__(self, args: argparse.Namespace):
        misc.set_random_seed(args.seed)
        validate_args(args)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if args.num_gpus > 1:
            from megatron.core import parallel_state

            from cosmos_predict1.utils import distributed

            distributed.init()
            parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus)
            process_group = parallel_state.get_context_parallel_group()

        # Number of frames produced by one forward pass of the pipeline, and the
        # number of frames shared between consecutive batches.
        self.frames_per_batch = 121
        self.inference_overlap_frames = 1

        pipeline = Gen3cPipeline(
            inference_type="video2world",
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_name="Gen3C-Cosmos-7B",
            prompt_upsampler_dir=args.prompt_upsampler_dir,
            enable_prompt_upsampler=not args.disable_prompt_upsampler,
            offload_network=args.offload_diffusion_transformer,
            offload_tokenizer=args.offload_tokenizer,
            offload_text_encoder_model=args.offload_text_encoder_model,
            offload_prompt_upsampler=args.offload_prompt_upsampler,
            offload_guardrail_models=args.offload_guardrail_models,
            disable_guardrail=args.disable_guardrail,
            guidance=args.guidance,
            num_steps=args.num_steps,
            height=args.height,
            width=args.width,
            fps=args.fps,
            num_video_frames=self.frames_per_batch,
            seed=args.seed,
        )
        if args.num_gpus > 1:
            pipeline.model.net.enable_context_parallel(process_group)

        self.args = args
        self.frame_buffer_max = pipeline.model.frame_buffer_max
        self.generator = torch.Generator(device=device).manual_seed(args.seed)
        self.sample_n_frames = pipeline.model.chunk_size
        self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
        self.pipeline = pipeline
        self.device = device
        self.device_with_rank = device_with_rank(self.device)

        self.cache: Cache3D_Buffer | Cache4D | None = None
        self.model_was_seeded = False

        # Image(s) used to seed the model, with shape [B, C, T=1, H, W] and
        # values in [-1, 1].
        self.seeding_image: torch.Tensor | None = None

    @torch.no_grad()
    def seed_model_from_values(self,
                               images_np: np.ndarray,
                               depths_np: np.ndarray | None,
                               world_to_cameras_np: np.ndarray,
                               focal_lengths_np: np.ndarray,
                               principal_point_rel_np: np.ndarray,
                               resolutions: np.ndarray,
                               masks_np: np.ndarray | None = None):
        """Seed the 3D cache from one or more frames given by value.

        Images are expected with shape [N, H, W, 3] and values in [0, 1]; focal
        lengths are in pixels and principal points are relative to the image size.
        (`resolutions` is currently only shape-checked.)
        """
        import torchvision.transforms.functional as transforms_F

        n = images_np.shape[0]
        assert images_np.shape[-1] == 3
        assert world_to_cameras_np.shape == (n, 4, 4)
        assert focal_lengths_np.shape == (n, 2)
        assert principal_point_rel_np.shape == (n, 2)
        assert resolutions.shape == (n, 2)
        assert (depths_np is None) or (depths_np.shape == images_np.shape[:-1])
        assert (masks_np is None) or (masks_np.shape == images_np.shape[:-1])

        if n == 1:
            # Single-image seeding: estimate depth, camera pose and intrinsics with MoGe.
            assert depths_np is None, \
                "Not supported yet: directly providing pre-estimated depth values along with a single image."

            input_image_np = images_np[0, ...] * 255.0
            del images_np

            (
                moge_image_b1chw_float,
                moge_depth_b11hw,
                moge_mask_b11hw,
                moge_initial_w2c_b144,
                moge_intrinsics_b133,
            ) = _predict_moge_depth(
                input_image_np, self.args.height, self.args.width, self.device_with_rank, self.moge_model
            )

            input_image = moge_image_b1chw_float[:, 0].clone()
            self.cache = Cache3D_Buffer(
                frame_buffer_max=self.frame_buffer_max,
                generator=self.generator,
                noise_aug_strength=self.args.noise_aug_strength,
                input_image=input_image,
                input_depth=moge_depth_b11hw[:, 0],
                input_w2c=moge_initial_w2c_b144[:, 0],
                input_intrinsics=moge_intrinsics_b133[:, 0],
                filter_points_threshold=self.args.filter_points_threshold,
                foreground_masking=self.args.foreground_masking,
            )

            # Map pixel values from [0, 255] to approximately [-1, 1].
            seeding_image = input_image_np.transpose(2, 0, 1)[None, ...] / 128.0 - 1.0
            seeding_image = torch.from_numpy(seeding_image).to(self.device_with_rank)

            # Report the MoGe-estimated camera parameters back to the caller.
            estimated_w2c_b44_np = moge_initial_w2c_b144.cpu().numpy()[:, 0, ...]
            moge_intrinsics_b133_np = moge_intrinsics_b133.cpu().numpy()
            estimated_focal_lengths_b2_np = np.stack([moge_intrinsics_b133_np[:, 0, 0, 0],
                                                      moge_intrinsics_b133_np[:, 0, 1, 1]], axis=1)
            estimated_principal_point_rel_b2_np = moge_intrinsics_b133_np[:, 0, :2, 2]

        else:
            if depths_np is None:
                raise NotImplementedError("Seeding from multiple frames requires providing depth values.")
            if masks_np is None:
                raise NotImplementedError("Seeding from multiple frames requires providing mask values.")

            image_bchw_float = torch.from_numpy(images_np.transpose(0, 3, 1, 2).astype(np.float32)).to(self.device_with_rank)
            # Map values from [0, 1] to [-1, 1].
            image_bchw_float = (image_bchw_float * 2.0) - 1.0
            del images_np

            depth_b1hw = torch.from_numpy(depths_np[:, None, ...].astype(np.float32)).to(self.device_with_rank)
            mask_b1hw = torch.from_numpy(masks_np[:, None, ...].astype(np.float32)).to(self.device_with_rank)
            initial_w2c_b44 = torch.from_numpy(world_to_cameras_np).to(self.device_with_rank)

            # Assemble pixel-space intrinsics from the focal lengths and relative principal points.
            intrinsics_b33_np = np.zeros((n, 3, 3), dtype=np.float32)
            intrinsics_b33_np[:, 0, 0] = focal_lengths_np[:, 0]
            intrinsics_b33_np[:, 1, 1] = focal_lengths_np[:, 1]
            intrinsics_b33_np[:, 0, 2] = principal_point_rel_np[:, 0] * self.args.width
            intrinsics_b33_np[:, 1, 2] = principal_point_rel_np[:, 1] * self.args.height
            intrinsics_b33_np[:, 2, 2] = 1.0
            intrinsics_b33 = torch.from_numpy(intrinsics_b33_np).to(self.device_with_rank)
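            # Per-frame layout of the assembled intrinsics, with fx, fy, cx, cy
            # in pixels at the working resolution:
            #   [[fx,  0, cx],
            #    [ 0, fy, cy],
            #    [ 0,  0,  1]]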

            self.cache = Cache4D(
                input_image=image_bchw_float.clone(),
                input_depth=depth_b1hw,
                input_mask=mask_b1hw,
                input_w2c=initial_w2c_b44,
                input_intrinsics=intrinsics_b33,
                filter_points_threshold=self.args.filter_points_threshold,
                foreground_masking=self.args.foreground_masking,
                input_format=["F", "C", "H", "W"],
            )

            # Camera parameters were provided by the caller; return them unchanged.
            seeding_image = image_bchw_float
            estimated_w2c_b44_np = world_to_cameras_np
            estimated_focal_lengths_b2_np = focal_lengths_np
            estimated_principal_point_rel_b2_np = principal_point_rel_np

        # Resize the seeding image to the pipeline's working resolution if needed.
        if (seeding_image.shape[2] != self.H) or (seeding_image.shape[3] != self.W):
            seeding_image = transforms_F.resize(
                seeding_image,
                size=(self.H, self.W),
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )

        # Add a singleton time dimension: [B, C, H, W] -> [B, C, T=1, H, W].
        self.seeding_image = seeding_image[:, :, None, ...]

        working_resolutions_b2_np = np.tile([[self.args.width, self.args.height]], (n, 1))
        return (
            estimated_w2c_b44_np,
            estimated_focal_lengths_b2_np,
            estimated_principal_point_rel_b2_np,
            working_resolutions_b2_np,
        )
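
    # A minimal single-image seeding sketch (hypothetical values; the identity
    # pose and the focal length below are illustrative, not defaults of this module):
    #
    #   image = frame_hwc_float01[None]                      # [1, H, W, 3] in [0, 1]
    #   w2c, focals, pp_rel, res = model.seed_model_from_values(
    #       images_np=image,
    #       depths_np=None,                                  # single image: depth comes from MoGe
    #       world_to_cameras_np=np.eye(4, dtype=np.float32)[None],
    #       focal_lengths_np=np.array([[1000.0, 1000.0]], dtype=np.float32),
    #       principal_point_rel_np=np.array([[0.5, 0.5]], dtype=np.float32),
    #       resolutions=np.array([[1280, 720]], dtype=np.float32),  # only shape-checked
    #   )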

    @torch.no_grad()
    def inference_on_cameras(self, view_cameras_w2cs: np.ndarray, view_camera_intrinsics: np.ndarray,
                             fps: int | float,
                             overlap_frames: int = 1,
                             return_estimated_depths: bool = False,
                             video_save_quality: int = 5,
                             save_buffer: bool | None = None) -> dict | None:
        """Generate a video along the given camera trajectory.

        Frames are generated auto-regressively in chunks of `self.sample_n_frames`,
        with `overlap_frames` frames shared between consecutive chunks.
        """
        self.pipeline.fps = int(fps)
        del fps
        save_buffer = save_buffer if (save_buffer is not None) else self.args.save_buffer

        video_save_name = self.args.video_save_name
        if not video_save_name:
            video_save_name = f"video_{time.strftime('%Y-%m-%d_%H-%M-%S')}"
        video_save_path = os.path.join(self.args.video_save_folder, f"{video_save_name}.mp4")
        os.makedirs(self.args.video_save_folder, exist_ok=True)

        cache_is_multiframe = isinstance(self.cache, Cache4D)

        # Old and new sizes match here, so this mainly batches the cameras and
        # moves them to the inference device.
        view_cameras_w2cs, view_camera_intrinsics = self.prepare_camera_for_inference(
            view_cameras_w2cs, view_camera_intrinsics,
            old_size=(self.H, self.W), new_size=(self.H, self.W)
        )

        n_frames_total = view_cameras_w2cs.shape[1]
        num_ar_iterations = (n_frames_total - overlap_frames) // (self.sample_n_frames - overlap_frames)
        log.info(f"Generating {n_frames_total} frames will take {num_ar_iterations} auto-regressive iterations")

        log.info(f"Generating frames 0 - {self.sample_n_frames} (out of {n_frames_total} total)...")
        rendered_warp_images, rendered_warp_masks = self.cache.render_cache(
            view_cameras_w2cs[:, 0:self.sample_n_frames],
            view_camera_intrinsics[:, 0:self.sample_n_frames],
            start_frame_idx=0,
        )

        all_rendered_warps = []
        all_predicted_depth = []
        if save_buffer:
            all_rendered_warps.append(rendered_warp_images.clone().cpu())

        current_prompt = self.args.prompt
        if current_prompt is None and self.args.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            return

        # For a multi-frame (4D) cache, seed generation from the first frame only.
        starting_frame = self.seeding_image
        if cache_is_multiframe:
            starting_frame = starting_frame[0].unsqueeze(0)

        generated_output = self.pipeline.generate(
            prompt=current_prompt,
            image_path=starting_frame,
            negative_prompt=self.args.negative_prompt,
            rendered_warp_images=rendered_warp_images,
            rendered_warp_masks=rendered_warp_masks,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            return
        video, _ = generated_output

        def depth_for_frame(frame: np.ndarray | torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """Estimate MoGe depth for a generated HWC frame with values in [0, 255]."""
            last_frame_hwc_0_255 = torch.tensor(frame, device=self.device_with_rank)
            pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0
            pred_depth, pred_mask = _predict_moge_depth_from_tensor(
                pred_image_for_depth_chw_0_1, self.moge_model
            )
            return pred_depth, pred_mask, pred_image_for_depth_chw_0_1
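
        # E.g. depth_for_frame(video[-1]) estimates depth for the newest generated
        # frame; `video` holds HWC frames as returned by the pipeline.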

        # Depth of the newest frame is needed either to update the 3D cache before
        # the next chunk, or to report estimated depths to the caller.
        need_depth_of_latest_frame = return_estimated_depths or (num_ar_iterations > 1 and not cache_is_multiframe)
        if need_depth_of_latest_frame:
            pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video[-1])

            if return_estimated_depths:
                # Depth is only estimated for the last frame of each chunk; all
                # other entries are NaN placeholders.
                depths_batch_0 = np.full((video.shape[0], 1, self.H, self.W), fill_value=np.nan,
                                         dtype=np.float32)
                depths_batch_0[-1, ...] = pred_depth.cpu().numpy()
                all_predicted_depth.append(depths_batch_0)
                del depths_batch_0

        # Auto-regressive generation of the remaining chunks.
        for num_iter in range(1, num_ar_iterations):
            start_frame_idx = num_iter * (self.sample_n_frames - overlap_frames)
            end_frame_idx = start_frame_idx + self.sample_n_frames
            log.info(f"Generating frames {start_frame_idx} - {end_frame_idx} (out of {n_frames_total} total)...")

            if cache_is_multiframe:
                # The 4D cache already holds all input frames; only the conditioning
                # image needs to be refreshed from the last generated frame.
                pred_image_for_depth_chw_0_1 = torch.tensor(
                    video[-1], device=self.device_with_rank
                ).permute(2, 0, 1) / 255.0
            else:
                # Extend the 3D cache with the last generated frame and its estimated depth.
                self.cache.update_cache(
                    new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1,
                    new_depth=pred_depth,
                    new_w2c=view_cameras_w2cs[:, start_frame_idx],
                    new_intrinsics=view_camera_intrinsics[:, start_frame_idx],
                )

            current_segment_w2cs = view_cameras_w2cs[:, start_frame_idx:end_frame_idx]
            current_segment_intrinsics = view_camera_intrinsics[:, start_frame_idx:end_frame_idx]

            cache_start_frame_idx = 0
            if cache_is_multiframe:
                # Clamp the window so it never reads past the cached input frames.
                cache_start_frame_idx = min(
                    start_frame_idx,
                    self.cache.input_frame_count() - (end_frame_idx - start_frame_idx)
                )

            rendered_warp_images, rendered_warp_masks = self.cache.render_cache(
                current_segment_w2cs,
                current_segment_intrinsics,
                start_frame_idx=cache_start_frame_idx,
            )

            if save_buffer:
                all_rendered_warps.append(rendered_warp_images[:, overlap_frames:].clone().cpu())

            # [C, H, W] in [0, 1] -> [B=1, C, T=1, H, W] in [-1, 1].
            pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1

            generated_output = self.pipeline.generate(
                prompt=current_prompt,
                image_path=pred_image_for_depth_bcthw_minus1_1,
                negative_prompt=self.args.negative_prompt,
                rendered_warp_images=rendered_warp_images,
                rendered_warp_masks=rendered_warp_masks,
            )
            if generated_output is None:
                log.critical("Guardrail blocked video2world generation.")
                return
            video_new, _ = generated_output

            # Drop the overlapping frames from the new chunk before appending.
            video = np.concatenate([video, video_new[overlap_frames:]], axis=0)

            need_depth_of_latest_frame = return_estimated_depths or ((num_iter < num_ar_iterations - 1) and not cache_is_multiframe)
            if need_depth_of_latest_frame:
                pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video_new[-1])
                if return_estimated_depths:
                    depths_batch_i = np.full((video_new.shape[0] - overlap_frames, 1, self.H, self.W),
                                             fill_value=np.nan, dtype=np.float32)
                    depths_batch_i[-1, ...] = pred_depth.cpu().numpy()
                    all_predicted_depth.append(depths_batch_i)
                    del depths_batch_i

        if is_rank0():
            final_video_to_save = video
            final_width = self.args.width

            if save_buffer and all_rendered_warps:
                squeezed_warps = [t.squeeze(0) for t in all_rendered_warps]

                if squeezed_warps:
                    # Pad all chunks to the same number of warp buffers.
                    n_max = max(t.shape[1] for t in squeezed_warps)

                    padded_t_list = []
                    for sq_t in squeezed_warps:
                        current_n_i = sq_t.shape[1]
                        padding_needed_dim1 = n_max - current_n_i
                        # Pad dim 1 (the buffer dimension) of the [T, N, C, H, W] tensor;
                        # F.pad takes pairs starting from the last dimension.
                        pad_spec = (0, 0,                     # W
                                    0, 0,                     # H
                                    0, 0,                     # C
                                    0, padding_needed_dim1,   # N
                                    0, 0)                     # T
                        padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0)
                        padded_t_list.append(padded_t)

                    full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0)

                    # Stack the warp buffers side by side along the width axis.
                    T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape
                    buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4)
                    buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim)
                    buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0
                    buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8)
                    buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1))

                    final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2)
                    final_width = self.args.width * (1 + n_max)
                    log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}")
                else:
                    log.info("No warp buffers to save.")

            save_video(
                video=final_video_to_save,
                fps=self.pipeline.fps,
                H=self.args.height,
                W=final_width,
                video_save_quality=video_save_quality,
                video_save_path=video_save_path,
            )
            log.info(f"Saved video to {video_save_path}")

        if return_estimated_depths:
            predicted_depth = np.concatenate(all_predicted_depth, axis=0)
        else:
            predicted_depth = None

        # [T, H, W, C] -> [B=1, T, C, H, W]
        video = video.transpose(0, 3, 1, 2)[None, ...]

        # The assembled video already excludes overlapping frames, so the
        # "no_overlap" entries currently alias the same objects.
        rendered_warp_images_no_overlap = rendered_warp_images
        video_no_overlap = video
        return {
            "rendered_warp_images": rendered_warp_images,
            "video": video,
            "rendered_warp_images_no_overlap": rendered_warp_images_no_overlap,
            "video_no_overlap": video_no_overlap,
            "predicted_depth": predicted_depth,
            "video_save_path": video_save_path,
        }
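
    # A minimal trajectory-inference sketch (hypothetical values): after seeding,
    # pass per-frame world-to-camera matrices and pixel-space intrinsics as numpy
    # arrays of shape [T, 4, 4] and [T, 3, 3]:
    #
    #   out = model.inference_on_cameras(w2cs_T44, intrinsics_T33, fps=24)
    #   video = out["video"]  # numpy, shape [1, T, C, H, W]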

    def prepare_camera_for_inference(self, view_cameras: np.ndarray, view_camera_intrinsics: np.ndarray,
                                     old_size: tuple[int, int], new_size: tuple[int, int]):
        """Batch the cameras and move them to the inference device.

        Old and new sizes should be given as (height, width). Returns world-to-camera
        matrices of shape [B, T, 4, 4] and intrinsics of shape [B, T, 3, 3].
        """
        if isinstance(view_cameras, np.ndarray):
            view_cameras = torch.from_numpy(view_cameras).float().contiguous()
        if view_cameras.ndim == 3:
            view_cameras = view_cameras.unsqueeze(dim=0)

        if isinstance(view_camera_intrinsics, np.ndarray):
            view_camera_intrinsics = torch.from_numpy(view_camera_intrinsics).float().contiguous()

        view_camera_intrinsics = resize_intrinsics(view_camera_intrinsics, old_size, new_size)
        view_camera_intrinsics = view_camera_intrinsics.unsqueeze(dim=0)
        assert view_camera_intrinsics.ndim == 4

        return view_cameras.to(self.device_with_rank), \
            view_camera_intrinsics.to(self.device_with_rank)

    def get_cache_input_depths(self) -> torch.Tensor | None:
        if self.cache is None:
            return None
        return self.cache.input_depth

    @property
    def W(self) -> int:
        return self.args.width

    @property
    def H(self) -> int:
        return self.args.height

    def clear_cache(self) -> None:
        self.cache = None
        self.model_was_seeded = False

    def cleanup(self) -> None:
        if self.args.num_gpus > 1:
            rank = get_rank()
            log.info(f"Model cleanup: destroying model parallel group on rank={rank}.",
                     rank0_only=False)
            from megatron.core import parallel_state
            parallel_state.destroy_model_parallel()

            import torch.distributed as dist
            dist.destroy_process_group()

            log.info(f"Destroyed model parallel group on rank={rank}.", rank0_only=False)
        else:
            log.info("Model cleanup: nothing to do (no parallelism).", rank0_only=False)