import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.controlnets.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer


def prep_control_image(cond_values: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Rescale a conditioning image to [0, 1] and move it to the target device.

    The ControlNet pipeline expects control images in [0, 1]; inputs normalized
    to [-1, 1] (e.g. straight from a training dataloader) are rescaled first.
    """
    x = cond_values
    if x.min() < 0:
        x = (x * 0.5 + 0.5).clamp(0, 1)
    x = x.to(device=device, dtype=torch.float32)
    return x


def build_controlnet_pipe(
    base_model_name: str,
    controlnet: ControlNetModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    device: torch.device,
    weight_dtype: torch.dtype,
    use_unipc: bool = True,
) -> StableDiffusionControlNetPipeline:
    """Assemble a StableDiffusionControlNetPipeline from already-loaded components."""
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_name,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=weight_dtype,
    )
    # Optionally swap in the UniPC multistep scheduler, which typically needs
    # fewer inference steps than the checkpoint's default scheduler.
    if use_unipc:
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=True)
    return pipe
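

# Hypothetical usage sketch (not part of the original module): the repo IDs,
# prompt, and output path below are assumptions; substitute your own checkpoints.
# It loads the individual components, builds the pipeline, and runs one
# ControlNet-guided generation from a dummy conditioning tensor.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.float16 if device.type == "cuda" else torch.float32

    base_model_name = "runwayml/stable-diffusion-v1-5"   # assumed base checkpoint
    controlnet_name = "lllyasviel/sd-controlnet-canny"   # assumed ControlNet checkpoint

    controlnet = ControlNetModel.from_pretrained(controlnet_name, torch_dtype=weight_dtype)
    vae = AutoencoderKL.from_pretrained(base_model_name, subfolder="vae", torch_dtype=weight_dtype)
    unet = UNet2DConditionModel.from_pretrained(base_model_name, subfolder="unet", torch_dtype=weight_dtype)
    text_encoder = CLIPTextModel.from_pretrained(base_model_name, subfolder="text_encoder", torch_dtype=weight_dtype)
    tokenizer = CLIPTokenizer.from_pretrained(base_model_name, subfolder="tokenizer")

    pipe = build_controlnet_pipe(
        base_model_name,
        controlnet,
        vae,
        unet,
        text_encoder,
        tokenizer,
        device,
        weight_dtype,
    )

    # Dummy conditioning tensor in [-1, 1] standing in for e.g. a Canny edge map
    # from a training dataloader; prep_control_image rescales it to [0, 1].
    cond = torch.rand(1, 3, 512, 512) * 2.0 - 1.0
    control_image = prep_control_image(cond, device)

    result = pipe(
        prompt="a photo of a house",
        image=control_image,
        num_inference_steps=20,
        generator=torch.Generator(device=device).manual_seed(0),
    )
    result.images[0].save("controlnet_sample.png")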