# Canny_ControlNet / pipeline.py
# Uploaded by ICGenAIShare07 - commit 67ea22c ("Upload pipeline.py").
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.controlnets.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer
def prep_control_image(cond_values: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Prepare a conditioning-image tensor for the ControlNet pipeline.

    If any value in ``cond_values`` is negative, the tensor is treated as
    normalized to ``[-1, 1]``. It is then mapped onto ``[0, 1]`` and clamped.
    Tensors with no negative values (including empty tensors) are passed
    through without rescaling. The result is always moved to ``device`` as
    ``float32``.

    NOTE(review): the range detection is a heuristic. A ``[-1, 1]`` tensor
    that happens to contain no negative values is left unscaled. A tensor in
    ``[0, 255]`` is also left unscaled. Confirm the caller's normalization.

    Args:
        cond_values: Conditioning image tensor. Presumably the output of a
            ``Normalize([0.5], [0.5])``-style transform; verify against the
            dataloader.
        device: Device the returned tensor is placed on.

    Returns:
        A ``float32`` tensor on ``device`` with the same shape as
        ``cond_values``.
    """
    x = cond_values
    # Tensor.min() raises on a tensor with no elements, so probe the value
    # range only when there is data. An empty tensor needs no rescaling.
    if x.numel() > 0 and x.min() < 0:
        # Map [-1, 1] onto [0, 1]; the clamp absorbs values slightly outside.
        x = (x * 0.5 + 0.5).clamp(0, 1)
    return x.to(device=device, dtype=torch.float32)
def build_controlnet_pipe(
    base_model_name: str,
    controlnet: ControlNetModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    device: torch.device,
    weight_dtype: torch.dtype,
    use_unipc: bool = True,
) -> StableDiffusionControlNetPipeline:
    """Assemble a Stable Diffusion ControlNet pipeline around existing modules.

    The given VAE, UNet, text encoder, tokenizer and ControlNet are handed to
    ``StableDiffusionControlNetPipeline.from_pretrained``. The safety checker
    is explicitly disabled. Anything not passed here, including the scheduler
    whose config is reused below, is taken from ``base_model_name``.

    Args:
        base_model_name: Model id or path passed to ``from_pretrained``.
        controlnet: ControlNet module to plug into the pipeline.
        vae: Autoencoder used by the pipeline.
        unet: Denoising UNet used by the pipeline.
        text_encoder: CLIP text encoder used by the pipeline.
        tokenizer: CLIP tokenizer used by the pipeline.
        device: Device the assembled pipeline is moved to.
        weight_dtype: Forwarded as ``torch_dtype`` to ``from_pretrained``.
            NOTE(review): presumably this only affects components loaded
            from ``base_model_name``, not the modules passed in; confirm.
        use_unipc: When true, replace the base scheduler with a
            ``UniPCMultistepScheduler`` built from the same config.

    Returns:
        The pipeline on ``device`` with its progress bar disabled.
    """
    supplied_components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "unet": unet,
        "controlnet": controlnet,
        "safety_checker": None,
    }
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_name,
        torch_dtype=weight_dtype,
        **supplied_components,
    )

    if use_unipc:
        # Swap the scheduler class while keeping the base model's config.
        base_scheduler_config = pipeline.scheduler.config
        pipeline.scheduler = UniPCMultistepScheduler.from_config(base_scheduler_config)

    pipeline = pipeline.to(device)
    # Keep logs quiet during repeated (e.g. validation) inference calls.
    pipeline.set_progress_bar_config(disable=True)
    return pipeline