import torch
import tqdm
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
import matplotlib


def colorize_depth(depth, cmap="Spectral"):
    # Map scalar depth values in [0, 1] to RGB via a matplotlib colormap.
    cm = matplotlib.colormaps[cmap]
    # (..., H, W) -> (..., H, W, 3): apply the colormap and drop the alpha channel.
    depth_colored = cm(depth, bytes=False)[..., 0:3]  # values in [0, 1]
    return depth_colored
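
# Illustrative example (shapes assumed, not from the original source):
#   depth = np.random.rand(4, 64, 64)   # normalized disparity in [0, 1]
#   rgb = colorize_depth(depth)         # -> (4, 64, 64, 3), floats in [0, 1]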


class DAVOutput(BaseOutput):
    r"""
    Output class for the DAV video depth pipeline.

    Args:
        disparity (`np.ndarray`):
            Predicted disparity maps of shape `(num_frames, height, width)`, normalized to [0, 1].
        disparity_colored (`np.ndarray`):
            Colorized disparity maps of shape `(num_frames, height, width, 3)` as uint8.
        image (`np.ndarray`):
            Input RGB frames of shape `(num_frames, height, width, 3)` as uint8.
    """

    disparity: np.ndarray
    disparity_colored: np.ndarray
    image: np.ndarray


class DAVPipeline(DiffusionPipeline):
    def __init__(self, vae, unet, unet_interp, scheduler):
        super().__init__()
        self.register_modules(
            vae=vae, unet=unet, unet_interp=unet_interp, scheduler=scheduler
        )

    def encode(self, input):
        # (B, N, C, H, W) -> (B*N, C, H, W) for the 2D VAE encoder.
        num_frames = input.shape[1]
        input = input.flatten(0, 1)
        latent = self.vae.encode(input.to(self.vae.dtype)).latent_dist.mode()
        latent = latent * self.vae.config.scaling_factor
        # Restore the frame dimension: (B, N, c, h, w).
        latent = latent.reshape(-1, num_frames, *latent.shape[1:])
        return latent

    def decode(self, latents, decode_chunk_size=16):
        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
        num_frames = latents.shape[1]
        latents = latents.flatten(0, 1)
        latents = latents / self.vae.config.scaling_factor

        # decode decode_chunk_size frames at a time to avoid OOM
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            frame = self.vae.decode(
                latents[i : i + decode_chunk_size].to(self.vae.dtype),
                num_frames=num_frames_in,
            ).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # [batch, frames, channels, height, width]
        frames = frames.reshape(-1, num_frames, *frames.shape[1:])
        return frames.to(torch.float32)

    def single_infer(self, rgb, position_ids=None, num_inference_steps=None):
        rgb_latent = self.encode(rgb)
        # Start from pure Gaussian noise in latent space.
        noise_latent = torch.randn_like(rgb_latent)

        self.scheduler.set_timesteps(num_inference_steps, device=rgb.device)
        timesteps = self.scheduler.timesteps

        # No text/image prompt: pass a zero embedding to cross-attention.
        image_embeddings = torch.zeros((noise_latent.shape[0], 1, 1024)).to(
            noise_latent
        )

        for t in timesteps:
            # Condition on the RGB latents by concatenation along channels.
            latent_model_input = torch.cat([noise_latent, rgb_latent], dim=2)
            # (batch_size, num_frames, 4, h, w)
            model_output = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=image_embeddings,
                position_ids=position_ids,
            ).sample
            # compute the previous noisy sample x_t -> x_{t-1}
            noise_latent = self.scheduler.step(
                model_output, t, noise_latent
            ).prev_sample

        return noise_latent

    def single_interp_infer(
        self, rgb, masked_depth_latent, mask, num_inference_steps=None
    ):
        rgb_latent = self.encode(rgb)
        noise_latent = torch.randn_like(rgb_latent)

        self.scheduler.set_timesteps(num_inference_steps, device=rgb.device)
        timesteps = self.scheduler.timesteps

        image_embeddings = torch.zeros((noise_latent.shape[0], 1, 1024)).to(
            noise_latent
        )

        for t in timesteps:
            # Condition on the RGB latents plus the two known key-frame depth
            # latents and a binary mask marking which frames are observed.
            latent_model_input = torch.cat(
                [noise_latent, rgb_latent, masked_depth_latent, mask], dim=2
            )
            # (batch_size, num_frames, 4, h, w)
            model_output = self.unet_interp(
                latent_model_input, t, encoder_hidden_states=image_embeddings
            ).sample
            # compute the previous noisy sample x_t -> x_{t-1}
            noise_latent = self.scheduler.step(
                model_output, t, noise_latent
            ).prev_sample

        return noise_latent

    def __call__(
        self,
        image,
        num_frames,
        num_overlap_frames,
        num_interp_frames,
        decode_chunk_size,
        num_inference_steps,
    ):
        self.vae.to(dtype=torch.float16)

        # (1, N, 3, H, W)
        image = image.unsqueeze(0)
        B, N = image.shape[:2]
        rgb = image * 2 - 1  # [0, 1] -> [-1, 1]

        if N <= num_frames or N <= num_interp_frames + 2 - num_overlap_frames:
            # Short clip: denoise all frames in a single pass.
            depth_latent = self.single_infer(
                rgb, num_inference_steps=num_inference_steps
            )
        else:
            # Long clip: denoise key frames first, then fill each gap with the
            # interpolation UNet and cross-fade the overlapping windows.
            assert 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
            assert num_frames % 2 == 0

            # Collect key frames in pairs (i, i + num_interp_frames + 1); each
            # pair brackets one interpolation window.
            key_frame_indices = []
            for i in range(0, N, num_interp_frames + 2 - num_overlap_frames):
                if (
                    i + num_interp_frames + 1 >= N
                    or len(key_frame_indices) >= num_frames
                ):
                    break
                key_frame_indices.append(i)
                key_frame_indices.append(i + num_interp_frames + 1)

            key_frame_indices = torch.tensor(key_frame_indices, device=rgb.device)
            sorted_key_frame_indices, origin_indices = torch.sort(key_frame_indices)
            key_rgb = rgb[:, sorted_key_frame_indices]
            # Denoise all key frames jointly; position ids preserve their
            # original temporal locations.
            key_depth_latent = self.single_infer(
                key_rgb,
                sorted_key_frame_indices.unsqueeze(0).repeat(B, 1),
                num_inference_steps=num_inference_steps,
            )
            # Restore the paired (unsorted) ordering.
            key_depth_latent = key_depth_latent[:, origin_indices]
            torch.cuda.empty_cache()

            depth_latent = []
            pre_latent = None
            for i in tqdm.tqdm(range(0, len(key_frame_indices), 2)):
                frame1 = key_depth_latent[:, i]
                frame2 = key_depth_latent[:, i + 1]
                # Known depth latents at the window's endpoints; the mask marks
                # those two frames as observed.
                masked_depth_latent = torch.zeros(
                    (B, num_interp_frames + 2, *key_depth_latent.shape[2:])
                ).to(key_depth_latent)
                masked_depth_latent[:, 0] = frame1
                masked_depth_latent[:, -1] = frame2
                mask = torch.zeros_like(masked_depth_latent)
                mask[:, [0, -1]] = 1.0

                latent = self.single_interp_infer(
                    rgb[:, key_frame_indices[i] : key_frame_indices[i + 1] + 1],
                    masked_depth_latent,
                    mask,
                    num_inference_steps=num_inference_steps,
                )
                # Keep only the interpolated frames, dropping the two key frames.
                latent = latent[:, 1:-1]

                if pre_latent is not None:
                    # Linearly cross-fade the frames shared with the previous window.
                    overlap_a = pre_latent[
                        :, pre_latent.shape[1] - (num_overlap_frames - 2) :
                    ]
                    overlap_b = latent[:, : (num_overlap_frames - 2)]
                    ratio = (
                        torch.linspace(0, 1, num_overlap_frames - 2)
                        .to(overlap_a)
                        .view(1, -1, 1, 1, 1)
                    )
                    overlap = overlap_a * (1 - ratio) + overlap_b * ratio
                    pre_latent[
                        :, pre_latent.shape[1] - (num_overlap_frames - 2) :
                    ] = overlap
                    depth_latent.append(pre_latent)

                pre_latent = latent[:, (num_overlap_frames - 2) if i > 0 else 0 :]
                torch.cuda.empty_cache()

            depth_latent.append(pre_latent)
            depth_latent = torch.cat(depth_latent, dim=1)
            # Discard frames outside the first and last key frames so the
            # images line up with the interpolated latents.
            image = image[:, key_frame_indices[0] + 1 : key_frame_indices[-1]]
            assert depth_latent.shape[1] == image.shape[1]

        disparity = self.decode(depth_latent, decode_chunk_size=decode_chunk_size)
        # Average the three decoded channels into a single-channel map.
        disparity = disparity.mean(dim=2, keepdim=False)
        # [-1, 1] -> [0, 1]
        disparity = torch.clamp(disparity * 0.5 + 0.5, 0.0, 1.0)
        # (N, H, W)
        disparity = disparity.squeeze(0)

        # Invert and normalize for visualization, then colorize to (N, H, W, 3).
        min_d, max_d = disparity.min(), disparity.max()
        disparity_colored = torch.clamp(
            (max_d - disparity) / (max_d - min_d), 0.0, 1.0
        )
        disparity_colored = colorize_depth(disparity_colored.cpu().numpy())
        disparity_colored = (disparity_colored * 255).astype(np.uint8)

        image = image.squeeze(0).permute(0, 2, 3, 1).cpu().numpy()
        image = (image * 255).astype(np.uint8)
        disparity = disparity.cpu().numpy()

        return DAVOutput(
            disparity=disparity,
            disparity_colored=disparity_colored,
            image=image,
        )
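

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. The checkpoint
    # path is a placeholder: `from_pretrained` expects a repo saved with
    # matching vae/unet/unet_interp/scheduler components, and the parameter
    # values below are illustrative, not verified defaults.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = DAVPipeline.from_pretrained("path/to/dav-checkpoint").to(device)

    # The pipeline expects frames as an (N, 3, H, W) float tensor in [0, 1].
    frames = torch.rand(16, 3, 256, 256, device=device)
    with torch.no_grad():
        output = pipe(
            frames,
            num_frames=16,          # max key frames for long clips (must be even)
            num_overlap_frames=3,   # frames blended between adjacent windows
            num_interp_frames=16,   # frames generated between each key-frame pair
            decode_chunk_size=8,    # VAE decode batch size (memory trade-off)
            num_inference_steps=5,  # denoising steps per diffusion pass
        )
    print(output.disparity.shape, output.disparity_colored.shape)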