Spaces:
Runtime error
Runtime error
| from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import ( | |
| retrieve_timesteps, | |
| retrieve_latents, | |
| ) | |
| import torch | |
| from functools import partial | |
| from diffusers import DDPMScheduler | |
| from model.pipeline_sdxl import StableDiffusionXLImg2ImgPipeline | |
# Device on which create_xts draws its Gaussian noise with torch.randn; the
# noise is moved to the latent's device afterwards. Alternative value: "cuda".
SAMPLING_DEVICE = "cpu"
# Passed as `sample_mode` to retrieve_latents in encode_image.
# Accepted values per the original note: "argmax" or "sample".
VAE_SAMPLE = "argmax"
# Not referenced by any function visible in this file section.
# NOTE(review): the original note suggests Image.LANCZOS as an alternative,
# so this is presumably a PIL resampling filter -- confirm against callers.
RESIZE_TYPE = None
# Device used for the image tensor, the loaded pipeline and timestep
# retrieval: "cuda" when torch reports CUDA as available, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
def encode_image(image, pipe, generator):
    """Encode an image into scaled VAE latents using the pipeline's VAE.

    The image is run through ``pipe.image_processor.preprocess`` and moved to
    the module-level ``device`` in the pipeline's dtype. When the VAE config
    sets ``force_upcast``, both the image and the VAE are switched to float32
    for the encode call, and the VAE is put back to the pipeline dtype
    afterwards.

    Args:
        image: Input accepted by ``pipe.image_processor.preprocess``.
        pipe: Pipeline object exposing ``dtype``, ``image_processor`` and
            ``vae``.
        generator: Random generator forwarded to ``retrieve_latents``.

    Returns:
        The latents in the pipeline dtype, multiplied by the VAE's
        ``scaling_factor``.
    """
    vae = pipe.vae
    target_dtype = pipe.dtype

    pixels = pipe.image_processor.preprocess(image).to(
        device=device, dtype=target_dtype
    )

    upcast = vae.config.force_upcast
    if upcast:
        # Encode in full precision when the VAE config asks for it.
        pixels = pixels.float()
        vae.to(dtype=torch.float32)

    latents = retrieve_latents(
        vae.encode(pixels), generator=generator, sample_mode=VAE_SAMPLE
    )

    if upcast:
        # Restore the VAE to the dtype the rest of the pipeline runs in.
        vae.to(target_dtype)

    return vae.config.scaling_factor * latents.to(target_dtype)
def create_xts(
    noise_shift_delta,
    noise_timesteps,
    generator,
    scheduler,
    timesteps,
    x_0,
):
    """Build the list of noised versions of ``x_0`` along a timestep schedule.

    When ``noise_timesteps`` is None it is derived from ``timesteps`` by
    subtracting ``int(noise_shift_delta * (timesteps[0] - timesteps[1]))``
    from every entry. The noise timesteps are then cut at the first entry
    that is ``<= 0``; only the entries before that point receive noise.

    Args:
        noise_shift_delta: Multiplier applied to the gap between the first
            two timesteps to obtain the shift.
        noise_timesteps: Explicit noise timesteps, or None to derive them.
        generator: Random generator passed to ``torch.randn``.
        scheduler: Object providing ``add_noise(samples, noise, timesteps)``.
        timesteps: Indexable sequence of timesteps; needs at least two
            entries when ``noise_timesteps`` is None.
        x_0: Clean latent; the code expands it with
            ``expand(n, -1, -1, -1)``, so it must be a 4-D tensor whose
            leading dimension can be expanded.

    Returns:
        A list holding one noised sample (with a leading dimension of 1) per
        retained noise timestep, then ``x_0`` repeated
        ``len(timesteps) - cutoff`` times, then one final ``x_0``.
    """
    if noise_timesteps is None:
        shift = int(noise_shift_delta * (timesteps[0] - timesteps[1]))
        noise_timesteps = [t - shift for t in timesteps]

    # Index of the first non-positive noise timestep; from there on the
    # clean latent is used as-is.
    cutoff = next(
        (idx for idx, t in enumerate(noise_timesteps) if t <= 0),
        len(noise_timesteps),
    )
    kept_timesteps = noise_timesteps[:cutoff]

    batch = x_0.expand(len(kept_timesteps), -1, -1, -1)
    # Noise is drawn on SAMPLING_DEVICE and then moved next to the latent.
    noise = torch.randn(
        batch.size(), generator=generator, device=SAMPLING_DEVICE
    ).to(x_0.device)

    noised = scheduler.add_noise(batch, noise, torch.IntTensor(kept_timesteps))

    x_ts = [sample.unsqueeze(dim=0) for sample in noised]
    x_ts += [x_0] * (len(timesteps) - cutoff)
    x_ts += [x_0]
    return x_ts
def load_pipeline(fp16, cache_dir):
    """Load the SDXL-Turbo img2img pipeline with the project's custom UNet.

    The UNet and the pipeline are loaded from ``stabilityai/sdxl-turbo``,
    the pipeline is moved to the module-level ``device``, and its scheduler
    is replaced by a ``DDPMScheduler`` loaded from the same repository.

    Args:
        fp16: When truthy, load weights with ``torch_dtype=torch.float16``
            and ``variant="fp16"``.
        cache_dir: Cache directory forwarded to every ``from_pretrained``
            call, including the scheduler's.

    Returns:
        The configured ``StableDiffusionXLImg2ImgPipeline``.
    """
    model_id = "stabilityai/sdxl-turbo"
    kwargs = (
        {
            "torch_dtype": torch.float16,
            "variant": "fp16",
        }
        if fp16
        else {}
    )

    # Imported inside the function, as in the original code.
    from model.unet_sdxl import OursUNet2DConditionModel

    unet = OursUNet2DConditionModel.from_pretrained(
        model_id,
        subfolder="unet",
        cache_dir=cache_dir,
        safety_checker=None,
        **kwargs,
    )
    pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        model_id,
        unet=unet,
        cache_dir=cache_dir,
        safety_checker=None,
        **kwargs,
    )
    pipeline = pipeline.to(device)
    # Pass cache_dir here too so the scheduler config is resolved from the
    # same cache as the UNet and pipeline rather than the default location.
    pipeline.scheduler = DDPMScheduler.from_pretrained(  # type: ignore
        model_id,
        subfolder="scheduler",
        cache_dir=cache_dir,
    )
    return pipeline
def set_pipeline(pipeline: StableDiffusionXLImg2ImgPipeline, num_timesteps, generator, config):
    """Configure the pipeline's scheduler and bind default call arguments.

    Two modes, selected by ``config.timesteps``:

    * Explicit list given: the scheduler is set to exactly those timesteps
      and ``pipeline.__call__`` is wrapped with ``timesteps``,
      ``guidance_scale``, ``denoising_start=0`` and ``strength=1``.
    * None: timesteps are derived from ``config.num_steps_inversion`` and
      ``config.step_start`` via ``retrieve_timesteps`` and
      ``pipeline.get_timesteps``, and ``pipeline.__call__`` is wrapped with
      ``num_inference_steps``, ``guidance_scale``, ``generator``,
      ``denoising_start`` and ``strength=0``.

    NOTE(review): the wrapper is stored as an instance attribute named
    ``__call__``. Python resolves ``pipeline(...)`` on the type, not the
    instance, so the bound defaults only apply when callers invoke
    ``pipeline.__call__(...)`` explicitly. Calling this function repeatedly
    nests the wrappers.

    Args:
        pipeline: Pipeline to configure; it is modified in place.
        num_timesteps: Accepted but not used by this function.
        generator: Random generator bound into the call defaults; only used
            when ``config.timesteps`` is None.
        config: Object providing ``timesteps``, ``guidance_scale`` and, when
            ``timesteps`` is None, ``step_start`` and ``num_steps_inversion``.

    Returns:
        A tuple ``(timesteps, config)`` where ``timesteps`` is a list of
        0-dimensional tensors, one per scheduled timestep.
    """
    if config.timesteps is not None:
        # Explicit schedule supplied by the caller.
        timesteps = torch.tensor(config.timesteps, dtype=torch.int64)
        call_defaults = {
            "timesteps": timesteps,
            "guidance_scale": config.guidance_scale,
            "denoising_start": 0,
            "strength": 1,
        }
        pipeline.__call__ = partial(pipeline.__call__, **call_defaults)
        pipeline.scheduler.set_timesteps(timesteps=config.timesteps)
    else:
        # Derive the schedule from the number of inversion steps.
        denoising_start = config.step_start / config.num_steps_inversion
        _, num_inference_steps = retrieve_timesteps(
            pipeline.scheduler, config.num_steps_inversion, device, None
        )
        timesteps, _ = pipeline.get_timesteps(
            num_inference_steps=num_inference_steps,
            device=device,
            denoising_start=denoising_start,
            strength=0,
        )
        timesteps = timesteps.type(torch.int64)
        call_defaults = {
            "num_inference_steps": config.num_steps_inversion,
            "guidance_scale": config.guidance_scale,
            "generator": generator,
            "denoising_start": denoising_start,
            "strength": 0,
        }
        pipeline.__call__ = partial(pipeline.__call__, **call_defaults)
        pipeline.scheduler.set_timesteps(timesteps=timesteps.cpu())

    # Hand the schedule back as a list of scalar tensors.
    return [torch.tensor(t) for t in timesteps.tolist()], config