Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from tqdm import tqdm | |
| # from seva.geometry import get_camera_dist | |
| from typing import Union | |
def get_camera_dist(
    source_c2ws: torch.Tensor,  # N x 3 x 4
    target_c2ws: torch.Tensor,  # M x 3 x 4
    mode: str = "translation",
):
    """Return the pairwise distance between two sets of camera poses.

    Args:
        source_c2ws: Source poses; the top-left 3x3 of each entry is read as
            the rotation and column 3 as the position.
        target_c2ws: Target poses, indexed the same way.
        mode: ``"translation"`` gives the Euclidean distance between the
            position columns; ``"rotation"`` gives the angle, in degrees,
            recovered from the trace of ``R_source @ R_target^T``.

    Returns:
        A tensor with one row per source pose and one column per target pose.

    Raises:
        NotImplementedError: If ``mode`` is neither of the supported values.
    """
    if mode == "translation":
        source_pos = source_c2ws[:, None, :3, 3]
        target_pos = target_c2ws[None, :, :3, 3]
        return torch.norm(source_pos - target_pos, dim=-1)

    if mode == "rotation":
        source_rot = source_c2ws[:, None, :3, :3]
        target_rot_t = target_c2ws[None, :, :3, :3].transpose(-1, -2)
        relative = torch.matmul(source_rot, target_rot_t)
        trace = relative.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1)
        # Clamp guards acos against values pushed just outside [-1, 1].
        cos_angle = ((trace - 1) / 2).clamp(-1, 1)
        return torch.acos(cos_angle) * (180 / torch.pi)

    raise NotImplementedError(
        f"Mode {mode} is not implemented for finding nearest source indices."
    )
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dims.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    trailing = (None,) * missing
    return x[(..., *trailing)]
def append_zero(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` with a single zero (same dtype/device as ``x``) appended."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])
def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """Return ``(x - denoised) / sigma`` with ``sigma`` broadcast over ``x``.

    ``sigma`` is padded with trailing singleton axes up to ``x.ndim`` before
    the division.
    """
    residual = x - denoised
    sigma_broadcast = append_dims(sigma, x.ndim)
    return residual / sigma_broadcast
def make_betas(
    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
) -> np.ndarray:
    """Build a beta schedule that is linear in ``sqrt(beta)``.

    The square roots of ``linear_start`` and ``linear_end`` are interpolated
    linearly over ``num_timesteps`` points in float64 and then squared.

    Returns:
        A float64 NumPy array of length ``num_timesteps``.
    """
    sqrt_start = linear_start**0.5
    sqrt_end = linear_end**0.5
    sqrt_betas = torch.linspace(
        sqrt_start, sqrt_end, num_timesteps, dtype=torch.float64
    )
    return (sqrt_betas**2).numpy()
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    """Pick ``num_substeps`` ascending integer steps ending at ``max_step - 1``.

    Points are laid out from ``max_step - 1`` down towards 0 (0 excluded),
    truncated to integers, and then reversed into ascending order.
    """
    descending = np.linspace(max_step - 1, 0, num_substeps, endpoint=False)
    truncated = descending.astype(int)
    return truncated[::-1]
class EpsScaling:
    """Scaling coefficients that treat the network output as a noise estimate."""

    def __call__(
        self, sigma: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(c_skip, c_out, c_in, c_noise)`` for the given ``sigma``.

        * ``c_skip`` is all ones.
        * ``c_out`` is ``-sigma``.
        * ``c_in`` is ``1 / sqrt(sigma^2 + 1)``.
        * ``c_noise`` is a copy of ``sigma``.
        """
        c_noise = sigma.clone()
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_out = -sigma
        c_skip = torch.ones_like(sigma)
        return c_skip, c_out, c_in, c_noise
class DDPMDiscretization:
    """Sigma table derived from the cumulative alpha product of a beta schedule."""

    def __init__(
        self,
        linear_start: float = 5e-06,
        linear_end: float = 0.012,
        num_timesteps: int = 1000,
        log_snr_shift: Union[float, None] = 2.4,
    ):
        """Precompute ``alphas_cumprod`` for ``num_timesteps`` steps.

        Args:
            linear_start: First beta handed to ``make_betas``.
            linear_end: Last beta handed to ``make_betas``.
            num_timesteps: Length of the schedule.
            log_snr_shift: If not ``None``, every sigma is multiplied by
                ``exp(log_snr_shift)`` in ``get_sigmas``.
        """
        self.num_timesteps = num_timesteps
        self.log_snr_shift = log_snr_shift
        betas = make_betas(
            num_timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
        )
        # first alpha here is on data side
        self.alphas_cumprod = np.cumprod(1.0 - betas, axis=0)

    def get_sigmas(self, n: int, device: Union[str, torch.device] = "cpu") -> torch.Tensor:
        """Return ``n`` sigmas as a float32 tensor, in reverse schedule order.

        Raises:
            ValueError: If ``n`` exceeds ``num_timesteps``.
        """
        if n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        elif n < self.num_timesteps:
            # Subsample the schedule at roughly equally spaced steps.
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        else:
            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if self.log_snr_shift is not None:
            sigmas = sigmas * np.exp(self.log_snr_shift)

        as_tensor = torch.tensor(sigmas, dtype=torch.float32, device=device)
        return torch.flip(as_tensor, (0,))

    def __call__(
        self,
        n: int,
        do_append_zero: bool = True,
        flip: bool = False,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the sigma table, optionally with a trailing zero and/or flipped."""
        sigmas = self.get_sigmas(n, device=device)
        if do_append_zero:
            sigmas = append_zero(sigmas)
        if flip:
            sigmas = torch.flip(sigmas, (0,))
        return sigmas
class DiscreteDenoiser:
    """Wrap a network so it is evaluated on a discrete table of sigmas.

    Incoming sigmas are snapped to the nearest table entry, and the index of
    that entry (not the sigma itself) is passed to the network as the noise
    conditioning.
    """

    sigmas: torch.Tensor

    def __init__(
        self,
        discretization: DDPMDiscretization,
        num_idx: int = 1000,
        device: Union[str, torch.device] = "cpu",
    ):
        self.scaling = EpsScaling()
        self.discretization = discretization
        self.num_idx = num_idx
        self.device = device
        self.register_sigmas()

    def register_sigmas(self):
        """Build the sigma table from the discretization (no trailing zero, flipped)."""
        self.sigmas = self.discretization(
            self.num_idx, do_append_zero=False, flip=True, device=self.device
        )

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        """Return, for each sigma, the index of the closest table entry."""
        # Table entries run along dim 0 and the query sigmas broadcast along dim 1.
        gaps = (sigma - self.sigmas[:, None]).abs()
        nearest = gaps.argmin(dim=0)
        return nearest.view(sigma.shape)

    def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
        """Look up the table sigma(s) at ``idx``."""
        return self.sigmas[idx]

    def __call__(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """Run ``network`` on the scaled input and return the denoised estimate.

        If ``cond`` has a ``"replace"`` entry, it is split along dim 1 into
        ``input.shape[1]`` value channels plus one mask channel, and the
        masked positions of ``input`` are overwritten with those values
        before the network is called.
        """
        # Snap to the discrete table before computing any scaling.
        snapped = self.idx_to_sigma(self.sigma_to_idx(sigma))
        sigma_shape = snapped.shape
        sigma = append_dims(snapped, input.ndim)

        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))

        if "replace" in cond:
            x, mask = cond.get("replace").split((input.shape[1], 1), dim=1)
            input = input * (1 - mask) + x * mask

        prediction = network(input * c_in, c_noise, cond, **additional_model_inputs)
        return prediction * c_out + input * c_skip
class ConstantScaleRule:
    """Scale rule that hands the guidance scale back unchanged."""

    def __call__(self, scale: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
        """Return ``scale`` as-is (the same object, no copy)."""
        return scale
class MultiviewScaleRule:
    """Scale rule that lowers the guidance scale for frames matching an input view."""

    def __init__(self, min_scale: float = 1.0):
        self.min_scale = min_scale

    def __call__(
        self,
        scale: Union[float, torch.Tensor],
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``scale`` with ``min_scale`` substituted on "close" frames.

        A frame counts as close when all three hold:
        its smallest rotation distance to any input frame is below 10
        degrees, its smallest translation distance to any input frame is
        below 1e-5, and its ``K`` is exactly equal to the ``K`` of at least
        one input frame.

        Raises:
            ValueError: If ``scale`` is neither a tensor nor a ``float``.
        """
        c2w_input = c2w[input_frame_mask]
        K_input = K[input_frame_mask]

        nearest_rotation = (
            get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
        )
        nearest_translation = (
            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
        )
        K_delta = (K[:, None] - K_input[None]).flatten(-2)
        shares_K = (K_delta == 0).all(-1).any(-1)

        close_frame = (
            (nearest_rotation < 10.0) & (nearest_translation < 1e-5) & shares_K
        )

        if isinstance(scale, torch.Tensor):
            # Work on a copy so the caller's tensor is left alone.
            scale = scale.clone()
            scale[close_frame] = self.min_scale
            return scale
        if isinstance(scale, float):
            return torch.where(close_frame, self.min_scale, scale)
        raise ValueError(f"Invalid scale type {type(scale)}.")
class ConstantScaleSchedule:
    """Schedule that applies the same guidance scale at every noise level."""

    def __call__(
        self, sigma: Union[float, torch.Tensor], scale: Union[float, torch.Tensor]
    ) -> Union[float, torch.Tensor]:
        """Return ``scale``, broadcast against ``sigma`` when ``sigma`` is a tensor.

        Raises:
            ValueError: If ``sigma`` is neither a ``float`` nor a tensor.
        """
        if isinstance(sigma, float):
            return scale
        if not isinstance(sigma, torch.Tensor):
            raise ValueError(f"Invalid sigma type {type(sigma)}.")

        if sigma.ndim == 1 and isinstance(scale, torch.Tensor):
            # Give a 1-D sigma trailing axes so it broadcasts against scale.
            sigma = append_dims(sigma, scale.ndim)
        return scale * torch.ones_like(sigma)
class ConstantGuidance:
    """Classifier-free guidance mix: ``uncond + scale * (cond - uncond)``."""

    def __call__(
        self,
        uncond: torch.Tensor,
        cond: torch.Tensor,
        scale: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Blend the two predictions; a 1-D tensor ``scale`` is padded to ``cond.ndim``."""
        is_flat_tensor = isinstance(scale, torch.Tensor) and scale.ndim == 1
        if is_flat_tensor:
            scale = append_dims(scale, cond.ndim)
        delta = cond - uncond
        return uncond + scale * delta
class VanillaCFG:
    """Classifier-free guidance with a constant scale rule, schedule and mixer."""

    def __init__(self):
        self.scale_rule = ConstantScaleRule()
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self, x: torch.Tensor, sigma: Union[float, torch.Tensor], scale: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """Split ``x`` into (unconditional, conditional) halves and mix them."""
        x_u, x_c = x.chunk(2)
        adjusted = self.scale_rule(scale)
        scale_value = self.scale_schedule(sigma, adjusted)
        return self.guidance(x_u, x_c, scale_value)

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """Duplicate ``x`` and ``s`` and merge the two conditioning dicts.

        Entries under the known tensor keys are concatenated along dim 0 with
        the unconditional value first; every other entry must compare equal
        in ``c`` and ``uc`` and is passed through once.
        """
        batched_keys = ["vector", "crossattn", "concat", "replace", "dense_vector"]
        c_out = {}
        for key, cond_value in c.items():
            if key in batched_keys:
                c_out[key] = torch.cat((uc[key], cond_value), 0)
            else:
                assert cond_value == uc[key]
                c_out[key] = cond_value
        return torch.cat([x, x]), torch.cat([s, s]), c_out
class MultiviewCFG(VanillaCFG):
    """Classifier-free guidance whose scale is adjusted per frame by camera pose."""

    def __init__(self, cfg_min: float = 1.0):
        # Sets every component directly; the base initializer is not invoked.
        self.scale_min = cfg_min
        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(  # type: ignore
        self,
        x: torch.Tensor,
        sigma: Union[float, torch.Tensor],
        scale: Union[float, torch.Tensor],
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Split ``x`` into (unconditional, conditional) halves and mix them.

        ``c2w``, ``K`` and ``input_frame_mask`` are forwarded to the scale
        rule, which decides the per-frame guidance scale.
        """
        x_u, x_c = x.chunk(2)
        per_frame_scale = self.scale_rule(scale, c2w, K, input_frame_mask)
        scale_value = self.scale_schedule(sigma, per_frame_scale)
        return self.guidance(x_u, x_c, scale_value)
class MultiviewTemporalCFG(MultiviewCFG):
    """Multiview guidance whose scale also grows with frame-index distance to an input frame."""

    def __init__(self, num_frames: int, cfg_min: float = 1.0):
        super().__init__(cfg_min=cfg_min)
        self.num_frames = num_frames
        frame_ids = torch.arange(num_frames)
        # distance_matrix[i, j] == |j - i|
        self.distance_matrix = (frame_ids[None] - frame_ids[:, None]).abs()

    def __call__(
        self,
        x: torch.Tensor,
        sigma: Union[float, torch.Tensor],
        scale: Union[float, torch.Tensor],
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Ramp ``scale`` from ``scale_min`` by temporal distance, then defer to the parent.

        ``input_frame_mask`` is regrouped into chunks of ``num_frames``. For
        each frame, the index distance to the nearest masked-in frame of its
        chunk is computed, normalized by the chunk's largest such distance
        (at least 1), and used to interpolate between ``scale_min`` and
        ``scale``.
        """
        input_frame_mask = rearrange(
            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
        )

        pairwise = self.distance_matrix[None].to(x.device)
        # Push non-input frames out of reach so they never win the min below.
        penalty = (~input_frame_mask[:, None]) * self.num_frames
        min_distance = (pairwise + penalty).min(-1)[0]

        normalizer = min_distance.max(-1, keepdim=True)[0].clamp(min=1)
        min_distance = min_distance / normalizer

        scale = min_distance * (scale - self.scale_min) + self.scale_min
        scale = rearrange(scale, "b t ... -> (b t) ...")
        scale = append_dims(scale, x.ndim)

        flat_mask = input_frame_mask.flatten(0, 1)
        return super().__call__(x, sigma, scale, c2w, K, flat_mask)
class EulerEDMSampler(object):
    """Euler sampler over a discretized sigma schedule with guidance.

    Each step optionally inflates the current sigma (churn), adds matching
    noise, queries the denoiser on the guider-prepared batch, and takes one
    Euler step towards the next sigma.
    """

    def __init__(
        self,
        discretization: "DDPMDiscretization",
        guider: "Union[VanillaCFG, MultiviewCFG, MultiviewTemporalCFG]",
        num_steps: Union[int, None] = None,
        verbose: bool = False,
        device: Union[str, torch.device] = "cuda",
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=float("inf"),
        s_noise=1.0,
    ):
        """Store the sampler configuration.

        Args:
            discretization: Callable returning the sigma schedule for a given
                number of steps and device.
            guider: Object providing ``prepare_inputs`` and a guidance
                ``__call__``.
            num_steps: Default number of steps; may be overridden per call.
            verbose: Enables the progress bar (together with the per-call flag).
            device: Device the sigma schedule is created on.
            s_churn: Total churn budget, spread over the steps.
            s_tmin: Lower sigma bound for applying churn.
            s_tmax: Upper sigma bound for applying churn.
            s_noise: Multiplier on the noise drawn each step.
        """
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device
        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def prepare_sampling_loop(
        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: Union[int, None] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
        """Build the sigma schedule and scale the starting sample.

        Returns:
            ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``x`` is a new
            tensor equal to the input times ``sqrt(1 + sigmas[0]^2)`` and
            ``s_in`` is a vector of ones with one entry per batch item.
        """
        num_steps = num_steps or self.num_steps
        assert num_steps is not None, "num_steps must be specified"
        sigmas = self.discretization(num_steps, device=self.device)
        # Out-of-place on purpose: an in-place `*=` would silently rescale the
        # caller's tensor and fails on leaf tensors that require grad.
        x = x * torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> "Union[range, tqdm]":
        """Return the step-index iterable, wrapped in a progress bar if enabled."""
        sigma_generator = range(num_sigmas - 1)
        # Both the instance flag and the per-call flag must be set.
        if self.verbose and verbose:
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas - 1,
                desc="Sampling",
                leave=False,
            )
        return sigma_generator

    def sampler_step(
        self,
        sigma: torch.Tensor,
        next_sigma: torch.Tensor,
        denoiser,
        x: torch.Tensor,
        scale: Union[float, torch.Tensor],
        cond: dict,
        uc: dict,
        gamma: float = 0.0,
        **guider_kwargs,
    ) -> torch.Tensor:
        """Advance ``x`` from ``sigma`` to ``next_sigma`` with one Euler step."""
        # The small offset keeps sigma_hat strictly above sigma, so the square
        # root below stays positive and the later division is never by zero.
        sigma_hat = sigma * (gamma + 1.0) + 1e-6
        eps = torch.randn_like(x) * self.s_noise
        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        return x + dt * d

    def __call__(
        self,
        denoiser,
        x: torch.Tensor,
        scale: Union[float, torch.Tensor],
        cond: dict,
        uc: Union[dict, None] = None,
        num_steps: Union[int, None] = None,
        verbose: bool = True,
        **guider_kwargs,
    ) -> torch.Tensor:
        """Run the full sampling loop and return the final sample.

        When ``uc`` is ``None`` the conditional dict is reused as the
        unconditional one. Extra keyword arguments are forwarded to the
        guider on every step.
        """
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x,
            cond,
            uc,
            num_steps,
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            # Churn is applied only while sigma lies inside [s_tmin, s_tmax].
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
        return x
def create_samplers(
    guider_types: Union[int, list[int]],
    discretization,
    num_frames: Union[list[int], None],
    num_steps: int=50,
    cfg_min: float = 1.2,
    device: Union[str, torch.device] = "cuda"
):
    """Build one ``EulerEDMSampler`` per requested guider type.

    Args:
        guider_types: A single type id or a list/tuple of them, where
            0 is ``VanillaCFG``, 1 is ``MultiviewCFG`` and 2 is
            ``MultiviewTemporalCFG``.
        discretization: Passed through to every sampler.
        num_frames: Per-sampler frame counts; read at the sampler's position
            for type 2 and must not be ``None`` in that case.
        num_steps: Number of sampling steps for every sampler.
        cfg_min: Minimum guidance scale given to the multiview guiders.
        device: Device passed to every sampler.

    Returns:
        A list of samplers in the same order as ``guider_types``.

    Raises:
        ValueError: If a guider type is not one of the known ids.
    """
    guider_mapping = {
        0: VanillaCFG,
        1: MultiviewCFG,
        2: MultiviewTemporalCFG,
    }

    if not isinstance(guider_types, (list, tuple)):
        guider_types = [guider_types]

    samplers = []
    for i, guider_type in enumerate(guider_types):
        if guider_type not in guider_mapping:
            raise ValueError(
                f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}"
            )

        # Each guider takes a different constructor signature.
        if guider_type == 2:
            assert num_frames is not None
            guider_args = (num_frames[i], cfg_min)
        elif guider_type > 0:
            guider_args = (cfg_min,)
        else:
            guider_args = ()

        guider = guider_mapping[guider_type](*guider_args)
        samplers.append(
            EulerEDMSampler(
                discretization=discretization,
                guider=guider,
                num_steps=num_steps,
                s_churn=0.0,
                s_tmin=0.0,
                s_tmax=999.0,
                s_noise=1.0,
                verbose=True,
                device=device,
            )
        )
    return samplers