# MeshAnythingV2ForRoblox: MeshAnything/miche/michelangelo/models/asl_diffusion/inference_utils.py
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple, Union

import torch
from diffusers.schedulers import DDIMScheduler
from tqdm import tqdm

# Public API of this module; karra_sample is an unfinished stub and stays private.
__all__ = ["ddim_sample"]
def ddim_sample(ddim_scheduler: "DDIMScheduler",
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):
    """Iteratively denoise random latents with a DDIM scheduler.

    This is a generator: after every scheduler step it yields
    ``(latents, t)``, so callers can watch intermediates; the last yield
    carries the fully denoised sample.

    Args:
        ddim_scheduler: a diffusers ``DDIMScheduler`` (or any object with the
            same ``init_noise_sigma`` / ``set_timesteps`` / ``timesteps`` /
            ``step`` interface).
        diffusion_model: noise-prediction network, invoked as
            ``forward(latent_model_input, timesteps, cond)``.
        shape: per-sample latent shape, excluding the batch dimension.
        cond: conditioning tensor. With classifier-free guidance enabled it
            must hold unconditional and conditional embeddings concatenated
            along the batch dimension (so its batch size is 2 * bsz).
        steps: number of denoising steps; must be positive.
        eta: DDIM stochasticity parameter in [0, 1] (0 = deterministic).
        guidance_scale: classifier-free guidance weight.
        do_classifier_free_guidance: when True, latents are duplicated and the
            noise prediction is split into (uncond, cond) halves for guidance.
        generator: optional RNG for reproducible initial noise.
        device: device the timestep tensors are placed on.
        disable_prog: suppress the tqdm progress bar (default True).

    Yields:
        ``(latents, t)`` — current latents and current scheduler timestep.

    Raises:
        ValueError: if ``steps`` is not positive.
    """
    # assert is stripped under `python -O`; validate explicitly instead.
    if steps <= 0:
        raise ValueError(f"{steps} must > 0.")

    # init latents: with CFG the cond batch holds (uncond, cond) pairs, so the
    # true sample count is half of cond's batch dimension.
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2
    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * ddim_scheduler.init_noise_sigma

    # set timesteps
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    # extra kwargs for the scheduler step, since not all schedulers share a
    # signature; eta (η) is only used with the DDIMScheduler, within [0, 1].
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator,
    }

    # Only wrap with tqdm when a progress bar is actually wanted; this avoids
    # the wrapper overhead (and the tqdm name lookup) in the common quiet case.
    if disable_prog:
        iterator = timesteps
    else:
        iterator = tqdm(timesteps, desc="DDIM Sampling:", leave=False)

    # reverse diffusion loop: x_t -> x_{t-1}
    for t in iterator:
        # expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )

        # Broadcast the scalar timestep to one entry per batch element.
        # t is already a 0-dim tensor; `.to(...).expand(...)` avoids the
        # copy-construct warning that `torch.tensor([t], ...)` would emit.
        timestep_tensor = t.to(device=device, dtype=torch.long).expand(
            latent_model_input.shape[0]
        )
        # predict the noise residual (original calls .forward directly,
        # bypassing nn.Module hooks — preserved for compatibility)
        noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)

        # perform guidance: uncond + scale * (cond - uncond)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # compute the previous noisy sample x_t -> x_t-1
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
def karra_sample():
    """Karras-sampler entry point — currently an unimplemented stub.

    NOTE(review): name looks like a typo for ``karras_sample``; confirm with
    callers before renaming. Intentionally returns None, like the original.
    """
    return None