File size: 1,776 Bytes
c20c148
 
282632a
 
c20c148
1b16559
c20c148
 
282632a
 
85f869b
1b16559
 
c20c148
 
 
282632a
c20c148
87bac41
da1bd63
 
1b16559
da1bd63
c20c148
 
 
6274f48
282632a
e402f49
282632a
 
 
 
e402f49
282632a
 
 
 
c20c148
b791816
0c630c6
bb3026b
c20c148
282632a
 
 
 
c20c148
 
 
 
 
282632a
 
c20c148
 
282632a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import huggingface_hub
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline

from DPT.dpt.models import DPTDepthModel
from ip_adapter import IPAdapter, IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook


def setup(
    base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder_path="sdxl_models/image_encoder",
    ip_ckpt="sdxl_models/ip-adapter_sdxl.bin",
    controlnet_path="diffusers/controlnet-depth-sdxl-1.0",
    device="cuda",
    model_depth_path="DPT/weights/dpt_hybrid-midas-501f0c75.pt",
    depth_backbone="vitb_rn50_384",
):
    """Download the IP-Adapter files and build the models this module needs.

    The function runs these steps in order:

    1. Downloads the ``models/`` and ``sdxl_models/`` folders of the
       ``h94/IP-Adapter`` Hub repository into the current directory.
    2. Loads the ControlNet and the SDXL ControlNet inpaint pipeline in
       ``float16`` and moves both to ``device``.
    3. Passes the pipeline's UNet through ``register_cross_attention_hook``
       and wraps the pipeline in an ``IPAdapterXL``.
    4. Builds a ``DPTDepthModel`` and switches it to eval mode.

    Args:
        base_model_path: Model id or path given to
            ``StableDiffusionXLControlNetInpaintPipeline.from_pretrained``.
        image_encoder_path: Image encoder location handed to ``IPAdapterXL``.
        ip_ckpt: IP-Adapter checkpoint location handed to ``IPAdapterXL``.
        controlnet_path: Model id or path given to
            ``ControlNetModel.from_pretrained``.
        device: Device the ControlNet and the pipeline are moved to; also
            forwarded to ``IPAdapterXL``.
        model_depth_path: Weights path given to ``DPTDepthModel``.
        depth_backbone: Backbone name given to ``DPTDepthModel``.

    Returns:
        A two-element list ``[ip_model, depth_model]``: the ``IPAdapterXL``
        wrapper followed by the ``DPTDepthModel``.
    """
    # Only the two model folders of the IP-Adapter repo are needed; they are
    # placed in the current directory with symlinks disabled.
    huggingface_hub.snapshot_download(
        repo_id='h94/IP-Adapter',
        allow_patterns=['models/**', 'sdxl_models/**'],
        local_dir='./',
        local_dir_use_symlinks=False,
    )

    torch.cuda.empty_cache()

    # Loading options shared by the ControlNet and the SDXL pipeline.
    fp16_safetensors = {"use_safetensors": True, "torch_dtype": torch.float16}

    # SDXL inpaint pipeline conditioned by the ControlNet.
    depth_controlnet = ControlNetModel.from_pretrained(controlnet_path, **fp16_safetensors)
    depth_controlnet = depth_controlnet.to(device)

    sdxl_pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        base_model_path,
        controlnet=depth_controlnet,
        add_watermarker=False,
        **fp16_safetensors,
    )
    sdxl_pipe = sdxl_pipe.to(device)
    sdxl_pipe.unet = register_cross_attention_hook(sdxl_pipe.unet)

    ip_adapter_model = IPAdapterXL(sdxl_pipe, image_encoder_path, ip_ckpt, device)

    # Depth model. It is not moved to `device` here; only eval mode is set.
    depth_model = DPTDepthModel(
        path=model_depth_path,
        backbone=depth_backbone,
        non_negative=True,
        enable_attention_hooks=False,
    )
    depth_model.eval()

    return [ip_adapter_model, depth_model]