# NOTE(review): removed Hugging Face Spaces page-scrape artifact
# ("Spaces: Sleeping") that preceded the actual Python module below.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.controlnets.controlnet import ControlNetModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
def prep_control_image(
    cond_values: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Normalize a conditioning-image tensor to [0, 1] and move it to *device*.

    If any value is negative, the tensor is assumed to use the [-1, 1]
    convention and is remapped to [0, 1] with clamping; otherwise it is
    passed through with only the device/dtype cast.

    Args:
        cond_values: Conditioning image tensor, in either [-1, 1] or [0, 1].
        device: Target device for the returned tensor.
        dtype: Target dtype; defaults to ``torch.float32`` (the original
            hard-coded behavior), now parameterized for callers running
            in half precision.

    Returns:
        The (possibly rescaled) tensor on *device* with dtype *dtype*.
    """
    x = cond_values
    # NOTE(review): range detection is data-dependent — a tensor that is
    # genuinely in [-1, 1] but happens to be all non-negative will NOT be
    # rescaled. Callers should supply tensors with a known, consistent range.
    if x.min() < 0:
        # Map [-1, 1] -> [0, 1]; clamp guards against slight overshoot.
        x = (x * 0.5 + 0.5).clamp(0, 1)
    return x.to(device=device, dtype=dtype)
def build_controlnet_pipe(
    base_model_name: str,
    controlnet: ControlNetModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    device: torch.device,
    weight_dtype: torch.dtype,
    use_unipc: bool = True,
) -> StableDiffusionControlNetPipeline:
    """Assemble a StableDiffusionControlNetPipeline from pre-loaded parts.

    The pipeline is resolved from *base_model_name*, but the VAE, UNet,
    text encoder, tokenizer, and ControlNet are all overridden with the
    instances passed in. The safety checker is disabled.

    Args:
        base_model_name: Hub id or local path of the base SD model.
        controlnet: ControlNet providing the conditioning branch.
        vae: Pre-loaded autoencoder.
        unet: Pre-loaded denoising UNet.
        text_encoder: Pre-loaded CLIP text encoder.
        tokenizer: Matching CLIP tokenizer.
        device: Device the assembled pipeline is moved to.
        weight_dtype: torch dtype used when loading any remaining weights.
        use_unipc: When True, replace the scheduler with a
            UniPCMultistepScheduler built from the existing scheduler config.

    Returns:
        The assembled pipeline on *device*, with progress bars disabled.
    """
    overrides = dict(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=weight_dtype,
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_name, **overrides
    )
    if use_unipc:
        # Rebuild the scheduler from the current config so the base model's
        # timestep settings carry over to UniPC.
        pipeline.scheduler = UniPCMultistepScheduler.from_config(
            pipeline.scheduler.config
        )
    pipeline.set_progress_bar_config(disable=True)
    return pipeline.to(device)