# Source: processing/setup.py — commit 282632a ("Update processing/setup.py", by gunnit, verified).
import huggingface_hub
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from DPT.dpt.models import DPTDepthModel
from ip_adapter import IPAdapter, IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook
def setup(base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
          image_encoder_path="sdxl_models/image_encoder",
          ip_ckpt="sdxl_models/ip-adapter_sdxl.bin",
          controlnet_path="diffusers/controlnet-depth-sdxl-1.0",
          device="cuda",
          model_depth_path="DPT/weights/dpt_hybrid-midas-501f0c75.pt",
          depth_backbone="vitb_rn50_384"):
    """Set up the processing module.

    Downloads the IP-Adapter weights and builds an SDXL ControlNet
    inpainting pipeline wrapped in an IP-Adapter. Also constructs a DPT
    depth model.

    Args:
        base_model_path: Hugging Face repo id or local path of the SDXL base
            model loaded by ``StableDiffusionXLControlNetInpaintPipeline``.
        image_encoder_path: Path of the image encoder handed to
            ``IPAdapterXL``. The default points into the ``sdxl_models/``
            folder fetched by the download step below.
        ip_ckpt: Path of the IP-Adapter checkpoint handed to ``IPAdapterXL``.
            The default also points into the downloaded ``sdxl_models/``.
        controlnet_path: Repo id or local path of the ControlNet model.
        device: Device the ControlNet, the SDXL pipeline and the IP-Adapter
            are placed on.
        model_depth_path: Path of the DPT depth-model weights file.
        depth_backbone: Backbone name passed to ``DPTDepthModel``.

    Returns:
        A two-element list ``[ip_model, depth_model]``:
        - the ``IPAdapterXL`` instance wrapping the SDXL pipeline;
        - the ``DPTDepthModel``, switched to eval mode.
    """
    # Fetch the ``models/`` and ``sdxl_models/`` folders of the
    # h94/IP-Adapter repo into the current working directory. That is where
    # the relative default paths above resolve.
    # NOTE(review): newer huggingface_hub releases deprecate and ignore
    # ``local_dir_use_symlinks``. Verify against the pinned version.
    huggingface_hub.snapshot_download(
        repo_id='h94/IP-Adapter',
        allow_patterns=[
            'models/**',
            'sdxl_models/**',
        ],
        local_dir='./',
        local_dir_use_symlinks=False,
    )
    # Release unoccupied cached CUDA memory before loading the large models.
    torch.cuda.empty_cache()

    # Load the SDXL ControlNet inpainting pipeline in half precision.
    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        use_safetensors=True,
        torch_dtype=torch.float16,
    ).to(device)
    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        use_safetensors=True,
        torch_dtype=torch.float16,
        add_watermarker=False,
    ).to(device)
    # The helper's return value replaces the pipeline's UNet.
    # NOTE(review): what the hook records is defined in ip_adapter.utils.
    pipe.unet = register_cross_attention_hook(pipe.unet)
    ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

    # Get the depth model ready.
    # NOTE(review): unlike the SDXL models, the depth model is neither moved
    # to ``device`` nor cast to float16 here. Confirm callers handle its
    # placement.
    model = DPTDepthModel(
        path=model_depth_path,
        backbone=depth_backbone,
        non_negative=True,
        enable_attention_hooks=False,
    )
    model.eval()
    return [ip_model, model]