import torch
from typing import Dict, Any
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
from src.flair.pipelines import utils
import tqdm


class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):

    def to(self, device, *args, **kwargs):
        # Move only the transformer and the VAE; extra arguments are accepted for
        # compatibility with DiffusionPipeline.to(...) and otherwise ignored.
        self.transformer.to(device)
        self.vae.to(device)
        return self

    def get_timesteps(self, n_steps, device, ts_min=0):
        # Create a linear schedule for timesteps from 1 down to ts_min.
        timesteps = torch.linspace(1, ts_min, n_steps + 2, device=device, dtype=torch.float32)
        return timesteps[1:-1]  # exclude the first and last timesteps
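
    # For reference, get_timesteps(3, device) with the default ts_min=0 evaluates
    # linspace(1, 0, 5) and drops the endpoints, returning tensor([0.7500, 0.5000, 0.2500]).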

    def single_step(
        self,
        img_latent: torch.Tensor,
        t: torch.Tensor,
        kwargs: Dict[str, Any],
        is_noised_latent=False,
    ):
        if "noise" in kwargs:
            # Re-use a provided noise tensor, blended with fresh Gaussian noise according
            # to the requested alpha(t) schedule so that the result keeps unit variance.
            noise = kwargs["noise"].detach()
            alpha = kwargs["inv_alpha"]
            if alpha == "tsqrt":
                alpha = t**0.5  # * 0.75
            elif alpha == "t":
                alpha = t
            elif alpha == "sine":
                alpha = torch.sin(t * 3.141592653589793 / 2)
            elif alpha == "1-t":
                alpha = 1 - t
            elif alpha == "1-t*0.5":
                alpha = (1 - t) * 0.5
            elif alpha == "1-t*0.9":
                alpha = (1 - t) * 0.9
            elif alpha == "t**1/3":
                alpha = t ** (1 / 3)
            elif alpha == "(1-t)**0.5":
                alpha = (1 - t) ** 0.5
| elif alpha == "((1-t)*0.8)**0.5": | |
| alpha = (1-t*0.8)**0.5 | |
| elif alpha == "(1-t)**2": | |
| alpha = (1-t)**2 | |
| # alpha = t * kwargs["inv_alpha"] | |
| noise = (alpha) * noise + (1-alpha**2)**0.5 * torch.randn_like(img_latent) | |
| # noise = noise / noise.std() | |
| # noise = noise / (1- 2*alpha*(1-alpha))**0.5 | |
| # noise = noise + alpha * torch.randn_like(img_latent) | |
| else: | |
| noise = torch.randn_like(img_latent) | |
| if is_noised_latent: | |
| noised_latent = img_latent | |
| else: | |
| noised_latent = t * noise + (1 - t) * img_latent | |
| latent_model_input = torch.cat([noised_latent] * 2) if self.do_classifier_free_guidance else noised_latent | |
| # broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
| timestep = t.expand(latent_model_input.shape[0]) | |
| noise_pred = self.transformer( | |
| hidden_states=latent_model_input.to(img_latent.dtype), | |
| timestep=(timestep*1000).to(img_latent.dtype), | |
| encoder_hidden_states=kwargs["prompt_embeds"].repeat(img_latent.shape[0], 1, 1), | |
| pooled_projections=kwargs["pooled_prompt_embeds"].repeat(img_latent.shape[0], 1), | |
| joint_attention_kwargs=None, | |
| return_dict=False, | |
| )[0] | |
| if self.do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| eps = utils.v_to_eps(noise_pred, t, noised_latent) | |
| return eps, noise, (1 - t), t, noise_pred | |
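
    # The conversion above relies on the rectified-flow relation x_t = t * eps + (1 - t) * x_0
    # with the model predicting v ≈ eps - x_0, so that eps = x_t + (1 - t) * v. utils.v_to_eps
    # is assumed to implement this relation, i.e. roughly eps = noised_latent + (1 - t) * noise_pred.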

    def encode(self, img):
        # Encode the image into latent space.
        img_latent = self.vae.encode(img, return_dict=False)[0]
        if hasattr(img_latent, "sample"):
            img_latent = img_latent.sample()
        img_latent = (img_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return img_latent

    def decode(self, img_latent):
        # Decode the latent representation back to image space.
        img = self.vae.decode(img_latent / self.vae.config.scaling_factor + self.vae.config.shift_factor, return_dict=False)[0]
        return img

    def denoise(self, pseudo_inv, kwargs, inverse=False):
        # Build a linear sigma/timestep schedule from 1 to 0 (flipped for inversion).
        timesteps = torch.linspace(1, 0, kwargs["n_steps"], device=pseudo_inv.device, dtype=pseudo_inv.dtype)
        sigmas = timesteps
        if inverse:
            timesteps = timesteps.flip(0)
            sigmas = sigmas.flip(0)
        # Euler steps along the velocity field predicted by the transformer.
        for i, t in tqdm.tqdm(enumerate(timesteps[:-1]), desc="Denoising", total=len(timesteps) - 1):
            eps, noise, _, t, v = self.single_step(
                pseudo_inv,
                t.to(pseudo_inv.device),  # single_step rescales the timestep by 1000 internally
                kwargs,
                is_noised_latent=True,
            )
            # Euler update: x <- x + v * (sigma_{i+1} - sigma_i)
            sigma_next = sigmas[i + 1]
            sigma_t = sigmas[i]
            pseudo_inv = pseudo_inv + v * (sigma_next - sigma_t)
        return pseudo_inv
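
# Minimal usage sketch (illustrative only, not part of the wrapper itself). It shows one way
# the class above could be driven end to end: encode an image, run the Euler loop in denoise(),
# and decode the result. The checkpoint name, the direct call to encode_prompt, and setting
# _guidance_scale by hand are assumptions about how the surrounding FLAIR code prepares its
# inputs, not something this file defines.
if __name__ == "__main__":
    device = "cuda"
    pipe = SD3Wrapper.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint
        torch_dtype=torch.float16,
    )
    pipe.text_encoder.to(device)
    pipe.text_encoder_2.to(device)
    pipe.text_encoder_3.to(device)
    pipe.to(device)  # the override above moves only the transformer and the VAE
    pipe._guidance_scale = 1.0  # assumed: the CFG properties read this attribute; 1.0 disables CFG

    with torch.no_grad():
        prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
            prompt="a photo of a cat", prompt_2=None, prompt_3=None, device=device
        )
        kwargs = {
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "n_steps": 28,
        }
        img = torch.randn(1, 3, 512, 512, device=device, dtype=torch.float16)  # stand-in for a real image in [-1, 1]
        latent = pipe.encode(img)
        latent = pipe.denoise(latent, kwargs)
        out = pipe.decode(latent)
        print(out.shape)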