Spaces:
Sleeping
Sleeping
| import torch | |
| import einops | |
| import numpy as np | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from functools import partial | |
| from torchdiffeq import odeint | |
| from unet import UNetModel | |
| from diffusers import AutoencoderKL | |
def exists(val):
    """Return ``True`` for any value except ``None`` (falsy values like 0 or "" still count)."""
    if val is None:
        return False
    return True
class DepthFM(nn.Module):
    """Depth estimator that integrates an ODE from image latents to depth latents.

    Images are encoded with the Stable Diffusion v1.5 VAE. The latent is then
    transported from t=0 to t=1 by an ODE whose right-hand side is the UNet
    restored from the checkpoint. Finally it is decoded back through the VAE
    and reduced to a single-channel depth map normalized per sample to [0, 1].
    """

    def __init__(self, ckpt_path: str):
        """Build the VAE and UNet and restore the DepthFM weights.

        Args:
            ckpt_path: path to a checkpoint dict containing the keys
                ``noising_step``, ``empty_text_embedding``, ``ldm_hparams``
                and ``state_dict``.
        """
        super().__init__()
        vae_id = "runwayml/stable-diffusion-v1-5"
        self.vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae")
        # Multiplier applied to VAE latents in `encode` and undone in `decode`.
        self.scale_factor = 0.18215

        # set with checkpoint
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # Diffusion step (see `q_sample`, 1000-step schedule) used to noise the
        # image latent before solving; skipped when not positive.
        self.noising_step = ckpt['noising_step']
        # Text-encoder embedding repeated per batch as cross-attention context.
        self.empty_text_embed = ckpt['empty_text_embedding']
        self.model = UNetModel(**ckpt['ldm_hparams'])
        self.model.load_state_dict(ckpt['state_dict'])

    def ode_fn(self, t: Tensor, x: Tensor, **kwargs):
        """Right-hand side dx/dt = model(x, t) handed to the ODE solver.

        A single-element ``t`` (the solver passes one time value) is broadcast
        to one time per batch element before calling the UNet. ``kwargs`` are
        the conditioning inputs forwarded to the model unchanged.
        """
        if t.numel() == 1:
            t = t.expand(x.size(0))
        return self.model(x=x, t=t, **kwargs)

    def generate(self, z: Tensor, num_steps: int = 4, n_intermediates: int = 0, **kwargs):
        """
        ODE solving from z0 (ims) to z1 (depth).

        Args:
            z: starting latent at t=0.
            num_steps: number of fixed Euler steps over [0, 1]
                (step size ``1 / num_steps``).
            n_intermediates: number of extra, evenly spaced times inside (0, 1)
                at which the solution is also returned.
            **kwargs: conditioning forwarded to the model at every step.

        Returns:
            The latent at t=1 when ``n_intermediates`` is 0; otherwise the
            solver output at all ``n_intermediates + 2`` times (first dimension
            indexes time).
        """
        # rtol/atol only matter for adaptive methods; Euler uses step_size.
        ode_kwargs = dict(method="euler", rtol=1e-5, atol=1e-5, options=dict(step_size=1.0 / num_steps))

        # t specifies which intermediate times should the solver return
        # e.g. t = [0, 0.5, 1] means return the solution at t=0, t=0.5 and t=1
        # but it also specifies the number of steps for fixed step size methods
        t = torch.linspace(0, 1, n_intermediates + 2, device=z.device, dtype=z.dtype)

        # allow conditioning information for model
        ode_fn = partial(self.ode_fn, **kwargs)

        ode_results = odeint(ode_fn, z, t, **ode_kwargs)

        if n_intermediates > 0:
            return ode_results
        return ode_results[-1]

    def forward(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """
        Args:
            ims: Tensor of shape (b, 3, h, w) in range [-1, 1]
            num_steps: number of Euler steps used to solve the ODE
            ensemble_size: if > 1, the single input image is repeated this many
                times and the decoded predictions are averaged (requires b == 1);
                the copies only differ through the noise added when
                ``noising_step > 0``
        Returns:
            depth: Tensor of shape (b, 1, h, w) in range [0, 1]
        """
        if ensemble_size > 1:
            assert ims.shape[0] == 1, "Ensemble mode only supported with batch size 1"
            ims = ims.repeat(ensemble_size, 1, 1, 1)

        bs, dev = ims.shape[0], ims.device

        ims_z = self.encode(ims, sample_posterior=False)

        # as_tensor avoids a redundant copy (and torch's copy-construct warning)
        # when the checkpoint already stores a tensor; casting to the latent
        # dtype keeps the UNet inputs consistent under half precision.
        conditioning = (
            torch.as_tensor(self.empty_text_embed)
            .to(device=dev, dtype=ims_z.dtype)
            .repeat(bs, 1, 1)
        )
        context = ims_z

        x_source = ims_z
        if self.noising_step > 0:
            x_source = q_sample(x_source, self.noising_step)

        # solve ODE
        depth_z = self.generate(x_source, num_steps=num_steps, context=context, context_ca=conditioning)

        depth = self.decode(depth_z)
        # Average the decoded channels into a single depth channel.
        depth = depth.mean(dim=1, keepdim=True)

        if ensemble_size > 1:
            depth = depth.mean(dim=0, keepdim=True)

        # The decoder output is exponentiated before normalization (presumably
        # it is log-depth -- TODO confirm); then each depth map is min-max
        # normalized to range [0, 1].
        depth = per_sample_min_max_normalization(depth.exp())

        return depth

    @torch.no_grad()
    def predict_depth(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """ Inference method for DepthFM.

        Same as ``forward`` but runs without building an autograd graph.
        """
        return self.forward(ims, num_steps, ensemble_size)

    def encode(self, x: Tensor, sample_posterior: bool = True):
        """Encode images into scaled VAE latents.

        Args:
            x: image batch passed to the VAE encoder.
            sample_posterior: draw a sample from the latent distribution if
                True, otherwise take its mode (deterministic).
        Returns:
            Latent tensor multiplied by ``self.scale_factor``.
        """
        posterior = self.vae.encode(x)
        if sample_posterior:
            z = posterior.latent_dist.sample()
        else:
            z = posterior.latent_dist.mode()
        # normalize latent code
        z = z * self.scale_factor
        return z

    def decode(self, z: Tensor):
        """Undo the latent scaling from `encode` and decode with the VAE."""
        z = 1.0 / self.scale_factor * z
        return self.vae.decode(z).sample
def sigmoid(x):
    """Logistic function ``1 / (1 + e^(-x))`` evaluated with NumPy (scalars or arrays)."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def cosine_log_snr(t, eps=0.00001):
    """
    Return the log signal-to-noise ratio of the cosine schedule at time ``t``.

    Computes ``-2 * log(tan(pi * t / 2) + eps)`` for ``t`` in [0, 1] (a Python
    float or a NumPy array). No resolution-dependent shift is applied.

    Args:
        t: normalized diffusion time; 0 is clean data, 1 is pure noise.
        eps: small offset added inside the log so that t = 0 (where tan is 0)
            gives a large finite value instead of ``log(0) = -inf``.
    """
    return -2 * np.log(np.tan((np.pi * t) / 2) + eps)
def cosine_alpha_bar(t):
    """Signal fraction alpha_bar(t) of the cosine schedule.

    The cosine log-SNR is squashed through the logistic function; mathematically
    this equals cos^2(pi * t / 2), up to the small ``eps`` used inside
    ``cosine_log_snr``.
    """
    log_snr = cosine_log_snr(t)
    return sigmoid(log_snr)
def q_sample(x_start: torch.Tensor, t: int, noise: torch.Tensor = None, n_diffusion_timesteps: int = 1000):
    """
    Forward-diffuse ``x_start`` to step ``t``, i.e. draw a sample from q(x_t | x_0).

    Returns ``sqrt(a) * x_start + sqrt(1 - a) * noise``, where ``a`` is the
    cosine-schedule alpha_bar at ``t / n_diffusion_timesteps``.

    Args:
        x_start: clean input tensor.
        t: diffusion step out of ``n_diffusion_timesteps``.
        noise: optional noise tensor; standard normal noise shaped like
            ``x_start`` is drawn when omitted.
        n_diffusion_timesteps: total number of steps used to normalize ``t``.
    Returns:
        Tensor with the shape, device and dtype of ``x_start``.
    """
    # Draw the noise before anything else so RNG consumption order is unchanged.
    eps = torch.randn_like(x_start) if noise is None else noise

    schedule_value = cosine_alpha_bar(t / n_diffusion_timesteps)
    alpha_bar = torch.tensor(schedule_value).to(x_start.device).to(x_start.dtype)

    signal_coef = torch.sqrt(alpha_bar)
    noise_coef = torch.sqrt(1 - alpha_bar)
    return signal_coef * x_start + noise_coef * eps
def per_sample_min_max_normalization(x):
    """ Normalize each sample in a batch independently
    with min-max normalization to [0, 1].

    Every sample (index along dim 0) is flattened, shifted by its own minimum
    and divided by its own range. A constant sample (max == min) is mapped to
    all zeros instead of producing NaN from 0 / 0.

    Args:
        x: tensor whose first dimension is the batch.
    Returns:
        Tensor with the same shape as ``x``.
    """
    import math

    bs, *shape = x.shape
    # One row per sample; prod of an empty shape is 1, so 1-D input becomes (b, 1).
    x_ = x.reshape(bs, math.prod(shape))
    min_val = x_.amin(dim=1, keepdim=True)
    max_val = x_.amax(dim=1, keepdim=True)
    value_range = max_val - min_val
    # Constant samples have a zero numerator too, so dividing by 1 yields 0;
    # non-degenerate ranges are left untouched.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    x_ = (x_ - min_val) / value_range
    return x_.reshape(bs, *shape)