Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from diffusers import DiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDIMScheduler | |
| from diffusers.utils import BaseOutput | |
class MarigoldDepthOutput(BaseOutput):
    """Result container returned by ``MarigoldPipeline.__call__``."""

    # Depth prediction as a NumPy array.
    depth_np: np.ndarray
    # The same prediction rendered as a PIL image.
    depth_image: Image.Image
class MarigoldPipeline(DiffusionPipeline):
    """Depth pipeline that denoises a depth latent conditioned on an RGB latent.

    The input image is encoded with the VAE, concatenated channel-wise with a
    noisy depth latent, and iteratively denoised by the UNet under the given
    scheduler. The final latent is decoded by the VAE and averaged over its
    output channels to obtain a single-channel depth map.
    """

    # Scaling applied to VAE latents on encode and undone again on decode.
    latent_scale_factor = 0.18215

    def __init__(self, unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: DDIMScheduler):
        """Store the three model components on the pipeline.

        Args:
            unet: Denoising network; receives the RGB and depth latents
                concatenated along the channel axis.
            vae: Autoencoder used to encode the image and decode the depth.
            scheduler: Scheduler that drives the denoising loop.
        """
        super().__init__()
        # register_modules (rather than bare attribute assignment) also records
        # the components in the pipeline config, so save/load keeps working.
        self.register_modules(unet=unet, vae=vae, scheduler=scheduler)

    @torch.no_grad()  # inference only; also required for the final .numpy()
    def __call__(
        self,
        input_image: Image.Image,
        denoising_steps: int = 10,
        save_path: str = None,
        encoder_hidden_states: torch.Tensor = None,
    ) -> MarigoldDepthOutput:
        """Predict a depth map for a single image.

        Args:
            input_image: Image to process; converted to RGB internally.
            denoising_steps: Number of scheduler steps to run (>= 1).
            save_path: If given (and non-empty), the depth image is saved there.
            encoder_hidden_states: Optional conditioning tensor forwarded to the
                UNet. This pipeline has no text encoder of its own, so a UNet
                built with cross-attention needs the caller to supply it.
                NOTE(review): upstream Marigold appears to condition on the
                embedding of an empty prompt -- confirm before relying on it.

        Returns:
            MarigoldDepthOutput with ``depth_np`` as a float32 ``(H, W)`` array
            in ``[0, 1]`` and ``depth_image`` as the matching 8-bit image.

        Raises:
            ValueError: If ``denoising_steps`` is smaller than 1.
        """
        if denoising_steps < 1:
            raise ValueError(f"denoising_steps must be >= 1, got {denoising_steps}")

        device = self.device
        dtype = self.unet.dtype

        # Image preprocessing: HWC uint8 -> 1xCxHxW in [-1, 1], using the model
        # dtype (plain NumPy division would silently produce float64).
        pixels = np.array(input_image.convert("RGB"), dtype=np.float32)
        rgb_norm = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
        rgb_norm = (rgb_norm / 255.0 * 2.0 - 1.0).to(device=device, dtype=dtype)

        # Encode image
        rgb_latent = self._encode_rgb(rgb_norm)

        # Initial depth map (noise)
        depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=rgb_latent.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)

        # Denoising loop; the scheduler must be told the step count first,
        # otherwise scheduler.step() has no inference schedule to follow.
        self.scheduler.set_timesteps(denoising_steps, device=device)
        for t in self.scheduler.timesteps:
            unet_input = torch.cat([rgb_latent, depth_latent], dim=1)
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=encoder_hidden_states
            ).sample
            depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample

        # Decode depth map
        depth = self._decode_depth(depth_latent)

        # Clip before rescaling to [0, 1] so the uint8 cast below cannot wrap.
        depth = (depth.clamp(-1.0, 1.0) + 1.0) / 2.0
        # Drop the batch and channel axes -> (H, W); float() covers half precision.
        depth_np = depth[0, 0].float().cpu().numpy()
        # A 2-D uint8 array is interpreted by PIL as a grayscale ('L') image.
        depth_image = Image.fromarray((depth_np * 255).astype(np.uint8))

        # Save the depth map image if a path is provided
        if save_path:
            depth_image.save(save_path)

        return MarigoldDepthOutput(depth_np=depth_np, depth_image=depth_image)

    def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """Encode a batched RGB tensor into a scaled latent.

        Only the first half of the encoder moments (the mean) is used; the
        second half is discarded, so the encoding is deterministic.
        """
        moments = self.vae.encoder(rgb_in)
        # Some VAE configurations have no quant_conv layer; skip it when absent.
        quant_conv = getattr(self.vae, "quant_conv", None)
        if quant_conv is not None:
            moments = quant_conv(moments)
        mean, _ = torch.chunk(moments, 2, dim=1)
        return mean * self.latent_scale_factor

    def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """Decode a depth latent and average the decoder output over channels.

        Returns a tensor with a single channel (``keepdim=True``).
        """
        # Undo the scaling applied in _encode_rgb before handing to the decoder.
        z = depth_latent / self.latent_scale_factor
        post_quant_conv = getattr(self.vae, "post_quant_conv", None)
        if post_quant_conv is not None:
            z = post_quant_conv(z)
        stacked = self.vae.decoder(z)
        return stacked.mean(dim=1, keepdim=True)
# Instantiate the model components and the pipeline.
# NOTE(review): these constructors build randomly initialised networks, so the
# predicted depth is meaningless until pretrained weights are loaded.
# NOTE(review): UNet2DConditionModel also expects text conditioning
# (encoder_hidden_states), which is not supplied here -- confirm how it
# should be provided.
# The pipeline feeds the UNet the RGB latent and the depth latent concatenated
# along the channel axis (4 + 4), so the UNet must accept 8 input channels
# instead of the default 4.
unet_model = UNet2DConditionModel(in_channels=8)
vae_model = AutoencoderKL()
scheduler = DDIMScheduler()
pipeline = MarigoldPipeline(unet=unet_model, vae=vae_model, scheduler=scheduler)

# Load an image and predict the depth map
output_path = 'path_to_save_image.jpg'  # Specify the path where you want to save the depth image
# The context manager closes the underlying file once inference is done.
with Image.open('path_to_your_image.jpg') as input_image:
    output = pipeline(input_image, denoising_steps=10, save_path=output_path)