gunnit committed on
Commit
da1bd63
·
verified ·
1 Parent(s): bf8b732

Update processing/setup.py

Browse files
Files changed (1) hide show
  1. processing/setup.py +19 -38
processing/setup.py CHANGED
@@ -1,71 +1,52 @@
1
  import huggingface_hub
2
  import torch
3
- from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline, DDIMScheduler, AutoencoderKL
 
4
  from DPT.dpt.models import DPTDepthModel
5
- from ip_adapter import IPAdapter, IPAdapterXL
6
  from ip_adapter.utils import register_cross_attention_hook
7
 
8
- def setup(base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
 
9
  image_encoder_path="sdxl_models/image_encoder",
10
- ip_ckpt="sdxl_models/ip-adapter_sdxl.bin",
11
- controlnet_path="diffusers/controlnet-depth-sdxl-1.0",
12
  device="cuda",
13
  model_depth_path="DPT/weights/dpt_hybrid-midas-501f0c75.pt",
14
  depth_backbone="vitb_rn50_384"):
15
- """Set up the processing module."""
16
  huggingface_hub.snapshot_download(
17
- repo_id='h94/IP-Adapter',
18
- allow_patterns=['models/**', 'sdxl_models/**'],
 
 
 
19
  local_dir='./',
20
  local_dir_use_symlinks=False,
21
  )
22
 
23
  torch.cuda.empty_cache()
24
 
25
- # # Load scheduler
26
- # noise_scheduler = DDIMScheduler(
27
- # num_train_timesteps=1000,
28
- # beta_start=0.00085,
29
- # beta_end=0.012,
30
- # beta_schedule="scaled_linear",
31
- # clip_sample=False,
32
- # set_alpha_to_one=False,
33
- # steps_offset=1,
34
- # )
35
-
36
- # Load VAE
37
- vae_model_path = "stabilityai/sd-vae-ft-mse"
38
- vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
39
-
40
- # Load ControlNet model with depth conditioning
41
  controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=True,
42
  torch_dtype=torch.float16).to(device)
43
- controlnet.conditioning_scale = 1.0 # Optional: Adjust as needed
44
-
45
- # Load SDXL pipeline with additional components
46
- pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
47
  base_model_path,
48
  controlnet=controlnet,
49
  use_safetensors=True,
50
- torch_dtype=torch.float16,
51
- # scheduler=noise_scheduler,
52
- vae=vae,
53
- add_watermarker=False,
54
  ).to(device)
55
-
56
- # Register cross-attention hook for IP Adapter
57
  pipe.unet = register_cross_attention_hook(pipe.unet)
58
 
59
- # Load IP Adapter
60
  ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
61
 
62
- # Initialize Depth Model
63
  model = DPTDepthModel(
64
  path=model_depth_path,
65
  backbone=depth_backbone,
66
  non_negative=True,
67
  enable_attention_hooks=False,
68
- ).to(device)
69
  model.eval()
70
 
71
- return [ip_model, model]
 
1
import huggingface_hub
import torch
from diffusers import (
    ControlNetModel,
    SD3ControlNetModel,
    StableDiffusion3ControlNetPipeline,
    StableDiffusion3Pipeline,
)

from DPT.dpt.models import DPTDepthModel
from ip_adapter import IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook
 
9
+
10
def setup(base_model_path="stabilityai/stable-diffusion-3.5-medium",
          image_encoder_path="sdxl_models/image_encoder",
          ip_ckpt="sdxl_models/ip-adapter_3.5.bin",
          controlnet_path="diffusers/controlnet-depth-sd3.5",
          device="cuda",
          model_depth_path="DPT/weights/dpt_hybrid-midas-501f0c75.pt",
          depth_backbone="vitb_rn50_384"):
    """Set up the processing module for Stable Diffusion 3.5.

    Downloads the required checkpoints, builds the depth-conditioned
    SD3.5 ControlNet pipeline with the IP-Adapter cross-attention hook,
    and loads the DPT depth estimator.

    Args:
        base_model_path: HF repo id of the SD3.5 base model.
        image_encoder_path: local path to the IP-Adapter image encoder.
        ip_ckpt: local path to the IP-Adapter checkpoint.
        controlnet_path: HF repo id of the depth ControlNet for SD3.5.
        device: torch device string for all pipeline components.
        model_depth_path: local path to the DPT depth weights.
        depth_backbone: DPT backbone identifier.

    Returns:
        [ip_model, model]: the IP-Adapter-wrapped pipeline and the DPT
        depth model (in eval mode).
    """
    huggingface_hub.snapshot_download(
        repo_id='stabilityai/stable-diffusion-3.5',
        allow_patterns=[
            'models/**',
            'sd3.5_models/**',
        ],
        local_dir='./',
        local_dir_use_symlinks=False,
    )

    torch.cuda.empty_cache()

    # SD3.x ControlNets are MMDiT-based: they must be loaded as
    # SD3ControlNetModel and driven by StableDiffusion3ControlNetPipeline.
    # StableDiffusion3Pipeline.from_pretrained() does not accept a
    # `controlnet` argument, so the previous wiring could not work.
    controlnet = SD3ControlNetModel.from_pretrained(
        controlnet_path, use_safetensors=True, torch_dtype=torch.float16
    ).to(device)

    pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        use_safetensors=True,
        torch_dtype=torch.float16,
    ).to(device)

    # SD3 pipelines have no `unet`; the denoiser is `pipe.transformer`
    # (accessing pipe.unet raises AttributeError).
    # NOTE(review): register_cross_attention_hook and IPAdapterXL were
    # written for SDXL UNets — confirm they handle the SD3 MMDiT
    # transformer; an SD3-specific IP-Adapter checkpoint is likely needed.
    pipe.transformer = register_cross_attention_hook(pipe.transformer)

    ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

    # Initialize the DPT depth model. Restore the `.to(device)` that was
    # dropped in the 3.5 migration — without it the depth model stays on
    # CPU while its inputs come from `device`.
    model = DPTDepthModel(
        path=model_depth_path,
        backbone=depth_backbone,
        non_negative=True,
        enable_attention_hooks=False,
    ).to(device)
    model.eval()

    return [ip_model, model]