ssoxye committed
Commit 689a987 · Parent: 1d216b1

Clean Space repo (code only) + gradio app

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +3 -0
  2. .ipynb_checkpoints/README-checkpoint.md +14 -0
  3. .ipynb_checkpoints/app-checkpoint.py +422 -0
  4. .ipynb_checkpoints/requirements-checkpoint.txt +20 -0
  5. diffusers3/Untitled.ipynb +58 -0
  6. diffusers3/__init__.py +934 -0
  7. diffusers3/__pycache__/__init__.cpython-310.pyc +0 -0
  8. diffusers3/__pycache__/__init__.cpython-38.pyc +0 -0
  9. diffusers3/__pycache__/callbacks.cpython-310.pyc +0 -0
  10. diffusers3/__pycache__/callbacks.cpython-38.pyc +0 -0
  11. diffusers3/__pycache__/configuration_utils.cpython-310.pyc +0 -0
  12. diffusers3/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  13. diffusers3/__pycache__/dependency_versions_check.cpython-310.pyc +0 -0
  14. diffusers3/__pycache__/dependency_versions_check.cpython-38.pyc +0 -0
  15. diffusers3/__pycache__/dependency_versions_table.cpython-310.pyc +0 -0
  16. diffusers3/__pycache__/dependency_versions_table.cpython-38.pyc +0 -0
  17. diffusers3/__pycache__/image_processor.cpython-310.pyc +0 -0
  18. diffusers3/__pycache__/image_processor.cpython-38.pyc +0 -0
  19. diffusers3/callbacks.py +156 -0
  20. diffusers3/commands/__init__.py +27 -0
  21. diffusers3/commands/diffusers_cli.py +43 -0
  22. diffusers3/commands/env.py +180 -0
  23. diffusers3/commands/fp16_safetensors.py +132 -0
  24. diffusers3/configuration_utils.py +720 -0
  25. diffusers3/dependency_versions_check.py +34 -0
  26. diffusers3/dependency_versions_table.py +46 -0
  27. diffusers3/experimental/README.md +5 -0
  28. diffusers3/experimental/__init__.py +1 -0
  29. diffusers3/experimental/rl/__init__.py +1 -0
  30. diffusers3/experimental/rl/value_guided_sampling.py +153 -0
  31. diffusers3/image_processor.py +1103 -0
  32. diffusers3/loaders/__init__.py +100 -0
  33. diffusers3/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
  34. diffusers3/loaders/__pycache__/__init__.cpython-38.pyc +0 -0
  35. diffusers3/loaders/__pycache__/ip_adapter.cpython-310.pyc +0 -0
  36. diffusers3/loaders/__pycache__/ip_adapter.cpython-38.pyc +0 -0
  37. diffusers3/loaders/__pycache__/lora_base.cpython-310.pyc +0 -0
  38. diffusers3/loaders/__pycache__/lora_base.cpython-38.pyc +0 -0
  39. diffusers3/loaders/__pycache__/lora_conversion_utils.cpython-310.pyc +0 -0
  40. diffusers3/loaders/__pycache__/lora_conversion_utils.cpython-38.pyc +0 -0
  41. diffusers3/loaders/__pycache__/lora_pipeline.cpython-310.pyc +0 -0
  42. diffusers3/loaders/__pycache__/lora_pipeline.cpython-38.pyc +0 -0
  43. diffusers3/loaders/__pycache__/peft.cpython-310.pyc +0 -0
  44. diffusers3/loaders/__pycache__/peft.cpython-38.pyc +0 -0
  45. diffusers3/loaders/__pycache__/single_file.cpython-310.pyc +0 -0
  46. diffusers3/loaders/__pycache__/single_file.cpython-38.pyc +0 -0
  47. diffusers3/loaders/__pycache__/single_file_model.cpython-310.pyc +0 -0
  48. diffusers3/loaders/__pycache__/single_file_model.cpython-38.pyc +0 -0
  49. diffusers3/loaders/__pycache__/single_file_utils.cpython-310.pyc +0 -0
  50. diffusers3/loaders/__pycache__/single_file_utils.cpython-38.pyc +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ models/
+ sdxl_models/
+ preprocess/ckpts/
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: VISTA
+ emoji: 🖼
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.44.0
+ app_file: app.py
+ pinned: false
+ license: cc-by-nc-4.0
+ short_description: VISTA Demo Page
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,422 @@
+ import os
+ import tempfile
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from typing import Optional, Tuple
+
+ import gradio as gr
+ import torch
+
+ # === Imports used in the original code, unchanged ===
+ from diffusers import UniPCMultistepScheduler
+ from diffusers3.models.controlnet import ControlNetModel
+ from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
+     StableDiffusionXLControlNetImg2ImgPipeline,
+ )
+ from ip_adapter import IPAdapterXL
+
+ import cv2
+ import numpy as np
+ import imageio
+ from PIL import Image, ImageOps
+
+ from preprocess.simple_extractor import run_simple_extractor
+
+
+ # =========================
+ # User environment/path settings (Option A: assets included in the repo)
+ # =========================
+ # base/controlnet are downloaded from the HF Hub (kept as-is)
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+ controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
+
+ # The two assets below must be included as-is in the Space repo
+ image_encoder_path = "models/image_encoder"
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
+
+ DEFAULT_STEPS = 30
+ DEBUG_SAVE = False
+
+ # Keeping the original structure: resize person to height=1024, then use H, W globally
+ H: Optional[int] = None
+ W: Optional[int] = None
+
+
+ # =========================
+ # Utility: check local assets
+ # =========================
+ def _ensure_exists(path: str, name: str):
+     if not os.path.exists(path):
+         raise FileNotFoundError(f"{name} not found: {path}")
+
+
+ def check_local_assets():
+     _ensure_exists(image_encoder_path, "image_encoder_path")
+     _ensure_exists(ip_ckpt, "ip_ckpt")
+
+
+ # =========================
+ # Lazy loading: pipe/controlnet are loaded only once
+ # =========================
+ @lru_cache(maxsize=1)
+ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
+     check_local_assets()
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.float16 if device == "cuda" else torch.float32
+
+     cn_kwargs = dict(
+         torch_dtype=dtype,
+         use_safetensors=True,
+     )
+     if dtype == torch.float16:
+         cn_kwargs["variant"] = "fp16"
+
+     controlnet = ControlNetModel.from_pretrained(controlnet_path, **cn_kwargs)
+     pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+         base_model_path,
+         controlnet=controlnet,
+         use_safetensors=True,
+         torch_dtype=dtype,
+         add_watermarker=False,
+     ).to(device)
+
+     pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+     pipe.enable_attention_slicing()
+
+     # Use xformers if available; otherwise silently skip
+     try:
+         pipe.enable_xformers_memory_efficient_attention()
+     except Exception:
+         pass
+
+     return pipe, device, dtype
+
+
+ # =========================
+ # Original functions — only the required minimum, kept as-is
+ # =========================
+ @dataclass
+ class Paths:
+     person_path: str
+     depth_path: str  # effectively the sketch/guide image (kept from the original code)
+     style_path: str
+     output_path: str
+
+
+ def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
+     img = cv2.imread(path, flag)
+     if img is None:
+         raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
+     return img
+
+
+ def compute_hw_from_person(person_path: str):
+     """
+     Based on the original person image:
+     - scale so that the height is exactly 1024
+     - preserve the aspect ratio
+     => H=1024, W=round(orig_w * (1024/orig_h))
+     Safety: if W exceeds 1024, cap it at 1024 (prevents negative padding)
+     """
+     img = cv2.imread(person_path)
+     if img is None:
+         raise FileNotFoundError(f"cv2.imread failed: {person_path} (exists={os.path.exists(person_path)})")
+
+     orig_h, orig_w = img.shape[:2]
+     target_h = 1024
+     scale = target_h / float(orig_h)
+     target_w = int(round(orig_w * scale))
+
+     if target_w > 1024:
+         target_w = 1024  # favor demo stability (prevents negative padding)
+
+     return target_h, target_w
+
+
+ def invert_sketch_area(sketch_pil: Image.Image) -> Image.Image:
+     return ImageOps.invert(sketch_pil.convert("L")).convert("RGB")
+
+
+ def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
+     global H, W
+     if H is None or W is None:
+         raise RuntimeError("Global H/W not set. Call run_one() first.")
+
+     img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
+     if img is None:
+         raise ValueError(f"Failed to load image: {image_path}")
+
+     img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
+
+     # Assumes white=background / black=lines; adjust the threshold if needed
+     threshold = 127
+     _, binary = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY_INV)
+
+     contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     filled = np.zeros_like(binary)
+     cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED)
+
+     filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB)
+     return Image.fromarray(filled_rgb)
+
+
+ def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image:
+     """
+     OR-merge the white (255) regions of the two images (preserving the original intent)
+     """
+     a = np.array(img1.convert("RGB"), dtype=np.uint8)
+     b = np.array(img2.convert("RGB"), dtype=np.uint8)
+
+     white_a = np.all(a == 255, axis=-1)
+     white_b = np.all(b == 255, axis=-1)
+     out = a.copy()
+     out[white_b] = 255
+     out[white_a] = 255
+     return Image.fromarray(out)
+
+
+ def preprocess_mask(mask_img: Image.Image) -> Image.Image:
+     """
+     Mask preprocessing: convert to L and apply a threshold (bare minimum)
+     """
+     m = np.array(mask_img.convert("L"), dtype=np.uint8)
+     # Handles both white-on-black and black-on-white cases: simple threshold
+     _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
+     return Image.fromarray(m).convert("RGB")
+
+
+ def make_depth(depth_path: str) -> Image.Image:
+     """
+     The 'depth' in the original code is actually the sketch-based guide-image generation logic.
+     """
+     global H, W
+     if H is None or W is None:
+         raise RuntimeError("Global H/W not set. Call run_one() first.")
+
+     depth_img = _imread_or_raise(depth_path, 0)
+     inverted_depth = cv2.bitwise_not(depth_img)
+     contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     filled_depth = inverted_depth.copy()
+     cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
+
+     filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
+     filled_depth_rgb = cv2.cvtColor(filled_depth, cv2.COLOR_GRAY2RGB)
+
+     # Pad (or crop) to width=1024
+     target_width = 1024
+     cur_w = filled_depth_rgb.shape[1]
+     if cur_w < target_width:
+         padding = (target_width - cur_w) // 2
+         filled_depth_rgb = cv2.copyMakeBorder(
+             filled_depth_rgb, 0, 0, padding, padding,
+             borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]
+         )
+     elif cur_w > target_width:
+         left = (cur_w - target_width) // 2
+         filled_depth_rgb = filled_depth_rgb[:, left:left + target_width]
+
+     return Image.fromarray(filled_depth_rgb)
+
+
+ def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
+     """
+     Center-crop left/right to width=768, height=1024 (preserving the original intent)
+     """
+     target_h, target_w = 1024, 768
+     h, w = arr.shape[:2]
+     if h != target_h:
+         # Safety: height is expected to be 1024; resize if it differs
+         arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
+         h, w = arr.shape[:2]
+     if w < target_w:
+         # Pad if too narrow
+         pad = (target_w - w) // 2
+         arr = cv2.copyMakeBorder(arr, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=[255, 255, 255])
+         w = arr.shape[1]
+     left = (w - target_w) // 2
+     return arr[:, left:left + target_w]
+
+
+ def save_cropped(imgs, out_path: str):
+     np_imgs = [np.asarray(im) for im in imgs]
+     cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs]
+     out = np.concatenate(cropped, axis=1)
+     os.makedirs(os.path.dirname(out_path), exist_ok=True)
+     imageio.imsave(out_path, out)
+
+
+ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
+     """
+     Keeps the original inference function structure; the only change is that the
+     pipe comes from the lazy loader, so it runs reliably on a Space.
+     """
+     global H, W
+
+     pipe, device, _dtype = get_pipe_and_device()
+
+     # Set the global H/W
+     H, W = compute_hw_from_person(paths.person_path)
+
+     # ===== parsing/segmentation (original flow preserved) =====
+     res = run_simple_extractor(
+         person_path=paths.person_path,
+         category="Upper-clothes",
+     )
+     parsing_img = res["images"][0] if res.get("images") else None
+     if parsing_img is None:
+         raise RuntimeError("run_simple_extractor returned no parsing images.")
+
+     # Build the mask, incorporating the sketch (=depth_path)
+     sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
+     sketch_area_inv = invert_sketch_area(sketch_area)
+     merged_img = merge_white_regions_or(parsing_img, sketch_area_inv)
+     mask_pil = preprocess_mask(merged_img)
+
+     # Create the control image (=depth_map)
+     depth_map = make_depth(paths.depth_path)
+
+     # ===== person/garment preprocessing (original flow preserved: pad/crop to width=1024) =====
+     person_bgr = _imread_or_raise(paths.person_path)
+     person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
+
+     target_width = 1024
+     cur_w = person_bgr.shape[1]
+     if cur_w < target_width:
+         padding = (target_width - cur_w) // 2
+         padded_person = cv2.copyMakeBorder(
+             person_bgr, 0, 0, padding, padding,
+             borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255]
+         )
+     elif cur_w > target_width:
+         left = (cur_w - target_width) // 2
+         padded_person = person_bgr[:, left:left + target_width]
+     else:
+         padded_person = person_bgr
+
+     person_pil = Image.fromarray(cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB))
+
+     # Garment image/mask are based on parsing_img (original flow, simplified: matched to the same size)
+     garment_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
+     garment_pil = Image.fromarray(garment_rgb)
+
+     garment_mask_bgr = np.array(parsing_img.convert("L"), dtype=np.uint8)
+     garment_mask_bgr = cv2.resize(garment_mask_bgr, (W, H), interpolation=cv2.INTER_AREA)
+     garment_mask_rgb = cv2.cvtColor(garment_mask_bgr, cv2.COLOR_GRAY2RGB)
+
+     # Apply the same padding/crop
+     cur_w2 = garment_mask_rgb.shape[1]
+     if cur_w2 < target_width:
+         padding2 = (target_width - cur_w2) // 2
+         garment_mask_rgb = cv2.copyMakeBorder(
+             garment_mask_rgb, 0, 0, padding2, padding2,
+             borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]
+         )
+     elif cur_w2 > target_width:
+         left2 = (cur_w2 - target_width) // 2
+         garment_mask_rgb = garment_mask_rgb[:, left2:left2 + target_width]
+
+     garment_mask_pil = Image.fromarray(garment_mask_rgb)
+
+     # ===== IPAdapterXL call (original structure preserved) =====
+     ip_model = IPAdapterXL(
+         pipe,
+         image_encoder_path,
+         ip_ckpt,
+         device,
+         mask_pil,
+         person_pil,
+         content_scale=0.3,
+         style_scale=0.5,
+         garment_images=garment_pil,
+         garment_mask=garment_mask_pil,
+     )
+
+     style_img = Image.open(paths.style_path).convert("RGB")
+
+     with torch.inference_mode():
+         images = ip_model.generate(
+             pil_image=style_img,
+             image=person_pil,
+             control_image=depth_map,
+             strength=1.0,
+             num_samples=1,
+             num_inference_steps=int(steps),
+             shape_prompt="",
+             prompt=prompt or "",
+             num=0,
+             scale=None,
+             controlnet_conditioning_scale=0.7,
+             guidance_scale=7.5,
+         )
+
+     save_cropped(images, paths.output_path)
+
+
+ # =========================
+ # Gradio UI
+ # =========================
+ def set_seed(seed: int):
+     if seed is None or seed < 0:
+         return
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+
+ def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, debug_save):
+     global DEBUG_SAVE
+     DEBUG_SAVE = bool(debug_save)
+
+     if person_fp is None or sketch_fp is None or style_fp is None:
+         raise gr.Error("You must upload all three images: person / sketch(guide) / style.")
+
+     set_seed(int(seed) if seed is not None else -1)
+
+     tmp_dir = tempfile.mkdtemp(prefix="feat_demo_")
+     out_path = os.path.join(tmp_dir, "result.png")
+
+     paths = Paths(
+         person_path=person_fp,
+         depth_path=sketch_fp,
+         style_path=style_fp,
+         output_path=out_path,
+     )
+
+     # The lazy load is triggered here
+     run_one(paths, prompt=prompt, steps=int(steps))
+
+     out_img = Image.open(out_path).convert("RGB")
+     return out_img, out_path
+
+
+ with gr.Blocks(title="FEAT Demo (HF Spaces)") as demo:
+     gr.Markdown("## FEAT Demo\nGenerates a result from person / sketch(guide) / style inputs.")
+
+     with gr.Row():
+         person_in = gr.Image(label="Person Image", type="filepath")
+         sketch_in = gr.Image(label="Sketch / Guide Image (depth_path)", type="filepath")
+         style_in = gr.Image(label="Style Image", type="filepath")
+
+     with gr.Row():
+         prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2)
+         steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
+
+     with gr.Row():
+         seed_in = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
+         debug_in = gr.Checkbox(label="Debug Save (slow)", value=False)
+
+     run_btn = gr.Button("Run")
+     out_img = gr.Image(label="Output (stitched/cropped)", type="pil")
+     out_file = gr.File(label="Download result.png")
+
+     run_btn.click(
+         fn=infer_web,
+         inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in, debug_in],
+         outputs=[out_img, out_file],
+     )
+
+ # Recommended on Spaces: queue() for stability/concurrency
+ demo.queue()
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,20 @@
+ gradio>=4.0,<5.0
+
+ # core
+ numpy
+ Pillow
+ imageio
+
+ # your pinned libs (safe)
+ diffusers==0.32.2
+ transformers==4.46.3
+
+ # opencv (spaces friendly)
+ opencv-python-headless==4.10.0.84
+
+ # torch: pinning +cu121 often breaks in the Spaces environment, so a version range is recommended
+ torch>=2.3,<2.4
+
+ # commonly needed (by diffusers/transformers)
+ accelerate
+ safetensors
diffusers3/Untitled.ipynb ADDED
@@ -0,0 +1,58 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "19479786-a4e0-4ec4-a1a6-2d9b1259c6d1",
+    "metadata": {
+     "tags": []
+    },
+    "outputs": [
+     {
+      "ename": "ImportError",
+      "evalue": "attempted relative import beyond top-level package",
+      "output_type": "error",
+      "traceback": [
+       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+       "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
+       "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpipelines\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcontrolnet\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpipeline_controlnet_sd_xl\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m StableDiffusionXLControlNetPipeline\n",
+       "File \u001b[0;32m~/data/diffusers/src/diffusers/pipelines/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TYPE_CHECKING\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 4\u001b[0m DIFFUSERS_SLOW_IMPORT,\n\u001b[1;32m 5\u001b[0m OptionalDependencyNotAvailable,\n\u001b[1;32m 6\u001b[0m _LazyModule,\n\u001b[1;32m 7\u001b[0m get_objects_from_module,\n\u001b[1;32m 8\u001b[0m is_flax_available,\n\u001b[1;32m 9\u001b[0m is_k_diffusion_available,\n\u001b[1;32m 10\u001b[0m is_librosa_available,\n\u001b[1;32m 11\u001b[0m is_note_seq_available,\n\u001b[1;32m 12\u001b[0m is_onnx_available,\n\u001b[1;32m 13\u001b[0m is_sentencepiece_available,\n\u001b[1;32m 14\u001b[0m is_torch_available,\n\u001b[1;32m 15\u001b[0m is_torch_npu_available,\n\u001b[1;32m 16\u001b[0m is_transformers_available,\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# These modules contain pipelines from multiple libraries/frameworks\u001b[39;00m\n\u001b[1;32m 21\u001b[0m _dummy_objects \u001b[38;5;241m=\u001b[39m {}\n",
+       "\u001b[0;31mImportError\u001b[0m: attempted relative import beyond top-level package"
+      ]
+     }
+    ],
+    "source": [
+     "from pipelines.controlnet.pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "0c6ff172-beff-470f-aabd-12440d1333b0",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
diffusers3/__init__.py ADDED
@@ -0,0 +1,934 @@
+ __version__ = "0.31.0.dev0"
+
+ from typing import TYPE_CHECKING
+
+ from .utils import (
+     DIFFUSERS_SLOW_IMPORT,
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     is_flax_available,
+     is_k_diffusion_available,
+     is_librosa_available,
+     is_note_seq_available,
+     is_onnx_available,
+     is_scipy_available,
+     is_sentencepiece_available,
+     is_torch_available,
+     is_torchsde_available,
+     is_transformers_available,
+ )
+
+
+ # Lazy Import based on
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
+
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
+ # and is used to defer the actual importing for when the objects are requested.
+ # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
+
+ _import_structure = {
+     "configuration_utils": ["ConfigMixin"],
+     "loaders": ["FromOriginalModelMixin"],
+     "models": [],
+     "pipelines": [],
+     "schedulers": [],
+     "utils": [
+         "OptionalDependencyNotAvailable",
+         "is_flax_available",
+         "is_inflect_available",
+         "is_invisible_watermark_available",
+         "is_k_diffusion_available",
+         "is_k_diffusion_version",
+         "is_librosa_available",
+         "is_note_seq_available",
+         "is_onnx_available",
+         "is_scipy_available",
+         "is_torch_available",
+         "is_torchsde_available",
+         "is_transformers_available",
+         "is_transformers_version",
+         "is_unidecode_available",
+         "logging",
+     ],
+ }
+
+ try:
+     if not is_onnx_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_onnx_objects  # noqa F403
+
+     _import_structure["utils.dummy_onnx_objects"] = [
+         name for name in dir(dummy_onnx_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["pipelines"].extend(["OnnxRuntimeModel"])
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_pt_objects  # noqa F403
+
+     _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
+
+ else:
+     _import_structure["models"].extend(
+         [
+             "AsymmetricAutoencoderKL",
+             "AuraFlowTransformer2DModel",
+             "AutoencoderKL",
+             "AutoencoderKLCogVideoX",
+             "AutoencoderKLTemporalDecoder",
+             "AutoencoderOobleck",
+             "AutoencoderTiny",
+             "CogVideoXTransformer3DModel",
+             "ConsistencyDecoderVAE",
+             "ControlNetModel",
+             "ControlNetXSAdapter",
+             "DiTTransformer2DModel",
+             "FluxControlNetModel",
+             "FluxMultiControlNetModel",
+             "FluxTransformer2DModel",
+             "HunyuanDiT2DControlNetModel",
+             "HunyuanDiT2DModel",
+             "HunyuanDiT2DMultiControlNetModel",
+             "I2VGenXLUNet",
+             "Kandinsky3UNet",
+             "LatteTransformer3DModel",
+             "LuminaNextDiT2DModel",
+             "ModelMixin",
+             "MotionAdapter",
+             "MultiAdapter",
+             "PixArtTransformer2DModel",
+             "PriorTransformer",
+             "SD3ControlNetModel",
+             "SD3MultiControlNetModel",
+             "SD3Transformer2DModel",
+             "SparseControlNetModel",
+             "StableAudioDiTModel",
+             "StableCascadeUNet",
+             "T2IAdapter",
+             "T5FilmDecoder",
+             "Transformer2DModel",
+             "UNet1DModel",
+             "UNet2DConditionModel",
+             "UNet2DModel",
+             "UNet3DConditionModel",
+             "UNetControlNetXSModel",
+             "UNetMotionModel",
+             "UNetSpatioTemporalConditionModel",
+             "UVit2DModel",
+             "VQModel",
+         ]
+     )
+
+     _import_structure["optimization"] = [
+         "get_constant_schedule",
+         "get_constant_schedule_with_warmup",
+         "get_cosine_schedule_with_warmup",
+         "get_cosine_with_hard_restarts_schedule_with_warmup",
+         "get_linear_schedule_with_warmup",
+         "get_polynomial_decay_schedule_with_warmup",
+         "get_scheduler",
+     ]
+     _import_structure["pipelines"].extend(
+         [
+             "AudioPipelineOutput",
+             "AutoPipelineForImage2Image",
+             "AutoPipelineForInpainting",
+             "AutoPipelineForText2Image",
+             "ConsistencyModelPipeline",
+             "DanceDiffusionPipeline",
+             "DDIMPipeline",
+             "DDPMPipeline",
+             "DiffusionPipeline",
+             "DiTPipeline",
+             "ImagePipelineOutput",
+             "KarrasVePipeline",
+             "LDMPipeline",
+             "LDMSuperResolutionPipeline",
+             "PNDMPipeline",
+             "RePaintPipeline",
+             "ScoreSdeVePipeline",
+             "StableDiffusionMixin",
+         ]
+     )
+     _import_structure["schedulers"].extend(
+         [
+             "AmusedScheduler",
+             "CMStochasticIterativeScheduler",
+             "CogVideoXDDIMScheduler",
+             "CogVideoXDPMScheduler",
+             "DDIMInverseScheduler",
+             "DDIMParallelScheduler",
+             "DDIMScheduler",
+             "DDPMParallelScheduler",
+             "DDPMScheduler",
+             "DDPMWuerstchenScheduler",
+             "DEISMultistepScheduler",
+             "DPMSolverMultistepInverseScheduler",
+             "DPMSolverMultistepScheduler",
+             "DPMSolverSinglestepScheduler",
+             "EDMDPMSolverMultistepScheduler",
+             "EDMEulerScheduler",
+             "EulerAncestralDiscreteScheduler",
+             "EulerDiscreteScheduler",
+             "FlowMatchEulerDiscreteScheduler",
+             "FlowMatchHeunDiscreteScheduler",
+             "HeunDiscreteScheduler",
+             "IPNDMScheduler",
+             "KarrasVeScheduler",
+             "KDPM2AncestralDiscreteScheduler",
+             "KDPM2DiscreteScheduler",
+             "LCMScheduler",
+             "PNDMScheduler",
+             "RePaintScheduler",
+             "SASolverScheduler",
+             "SchedulerMixin",
+             "ScoreSdeVeScheduler",
+             "TCDScheduler",
+             "UnCLIPScheduler",
+             "UniPCMultistepScheduler",
+             "VQDiffusionScheduler",
+         ]
+     )
+     _import_structure["training_utils"] = ["EMAModel"]
+
+ try:
+     if not (is_torch_available() and is_scipy_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_torch_and_scipy_objects  # noqa F403
+
+     _import_structure["utils.dummy_torch_and_scipy_objects"] = [
+         name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
+
+ try:
+     if not (is_torch_available() and is_torchsde_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_torch_and_torchsde_objects  # noqa F403
+
+     _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
+         name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
+
+ try:
+     if not (is_torch_available() and is_transformers_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_torch_and_transformers_objects  # noqa F403
+
+     _import_structure["utils.dummy_torch_and_transformers_objects"] = [
+         name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["pipelines"].extend(
+         [
+             "AltDiffusionImg2ImgPipeline",
+             "AltDiffusionPipeline",
+             "AmusedImg2ImgPipeline",
+             "AmusedInpaintPipeline",
+             "AmusedPipeline",
+             "AnimateDiffControlNetPipeline",
+             "AnimateDiffPAGPipeline",
+             "AnimateDiffPipeline",
+             "AnimateDiffSDXLPipeline",
+             "AnimateDiffSparseControlNetPipeline",
+             "AnimateDiffVideoToVideoControlNetPipeline",
+             "AnimateDiffVideoToVideoPipeline",
+             "AudioLDM2Pipeline",
+             "AudioLDM2ProjectionModel",
+             "AudioLDM2UNet2DConditionModel",
+             "AudioLDMPipeline",
+             "AuraFlowPipeline",
+             "BlipDiffusionControlNetPipeline",
+             "BlipDiffusionPipeline",
+             "CLIPImageProjection",
+             "CogVideoXPipeline",
+             "CogVideoXVideoToVideoPipeline",
+             "CycleDiffusionPipeline",
+             "FluxControlNetPipeline",
+             "FluxImg2ImgPipeline",
+             "FluxInpaintPipeline",
+             "FluxPipeline",
+             "HunyuanDiTControlNetPipeline",
+             "HunyuanDiTPAGPipeline",
+             "HunyuanDiTPipeline",
+             "I2VGenXLPipeline",
+             "IFImg2ImgPipeline",
+             "IFImg2ImgSuperResolutionPipeline",
+             "IFInpaintingPipeline",
+             "IFInpaintingSuperResolutionPipeline",
+             "IFPipeline",
+             "IFSuperResolutionPipeline",
+             "ImageTextPipelineOutput",
+             "Kandinsky3Img2ImgPipeline",
+             "Kandinsky3Pipeline",
+             "KandinskyCombinedPipeline",
+             "KandinskyImg2ImgCombinedPipeline",
+             "KandinskyImg2ImgPipeline",
+             "KandinskyInpaintCombinedPipeline",
+             "KandinskyInpaintPipeline",
+             "KandinskyPipeline",
+             "KandinskyPriorPipeline",
+             "KandinskyV22CombinedPipeline",
+             "KandinskyV22ControlnetImg2ImgPipeline",
+             "KandinskyV22ControlnetPipeline",
+             "KandinskyV22Img2ImgCombinedPipeline",
+             "KandinskyV22Img2ImgPipeline",
+             "KandinskyV22InpaintCombinedPipeline",
+             "KandinskyV22InpaintPipeline",
+             "KandinskyV22Pipeline",
+             "KandinskyV22PriorEmb2EmbPipeline",
+             "KandinskyV22PriorPipeline",
+             "LatentConsistencyModelImg2ImgPipeline",
+             "LatentConsistencyModelPipeline",
+             "LattePipeline",
+             "LDMTextToImagePipeline",
+             "LEditsPPPipelineStableDiffusion",
+             "LEditsPPPipelineStableDiffusionXL",
+             "LuminaText2ImgPipeline",
+             "MarigoldDepthPipeline",
+             "MarigoldNormalsPipeline",
+             "MusicLDMPipeline",
+             "PaintByExamplePipeline",
+             "PIAPipeline",
+             "PixArtAlphaPipeline",
+             "PixArtSigmaPAGPipeline",
+             "PixArtSigmaPipeline",
+             "SemanticStableDiffusionPipeline",
+             "ShapEImg2ImgPipeline",
+             "ShapEPipeline",
+             "StableAudioPipeline",
+             "StableAudioProjectionModel",
+             "StableCascadeCombinedPipeline",
+             "StableCascadeDecoderPipeline",
+             "StableCascadePriorPipeline",
+             "StableDiffusion3ControlNetInpaintingPipeline",
+             "StableDiffusion3ControlNetPipeline",
+             "StableDiffusion3Img2ImgPipeline",
+             "StableDiffusion3InpaintPipeline",
+             "StableDiffusion3PAGPipeline",
+             "StableDiffusion3Pipeline",
+             "StableDiffusionAdapterPipeline",
+             "StableDiffusionAttendAndExcitePipeline",
+             "StableDiffusionControlNetImg2ImgPipeline",
+             "StableDiffusionControlNetInpaintPipeline",
+             "StableDiffusionControlNetPAGPipeline",
+             "StableDiffusionControlNetPipeline",
+             "StableDiffusionControlNetXSPipeline",
+             "StableDiffusionDepth2ImgPipeline",
+             "StableDiffusionDiffEditPipeline",
+             "StableDiffusionGLIGENPipeline",
+             "StableDiffusionGLIGENTextImagePipeline",
+             "StableDiffusionImageVariationPipeline",
+             "StableDiffusionImg2ImgPipeline",
+             "StableDiffusionInpaintPipeline",
+             "StableDiffusionInpaintPipelineLegacy",
+             "StableDiffusionInstructPix2PixPipeline",
+             "StableDiffusionLatentUpscalePipeline",
+             "StableDiffusionLDM3DPipeline",
+             "StableDiffusionModelEditingPipeline",
+             "StableDiffusionPAGPipeline",
+             "StableDiffusionPanoramaPipeline",
+             "StableDiffusionParadigmsPipeline",
+             "StableDiffusionPipeline",
+             "StableDiffusionPipelineSafe",
+             "StableDiffusionPix2PixZeroPipeline",
+             "StableDiffusionSAGPipeline",
+             "StableDiffusionUpscalePipeline",
+             "StableDiffusionXLAdapterPipeline",
+             "StableDiffusionXLControlNetImg2ImgPipeline",
+             "StableDiffusionXLControlNetInpaintPipeline",
+             "StableDiffusionXLControlNetPAGImg2ImgPipeline",
+             "StableDiffusionXLControlNetPAGPipeline",
+             "StableDiffusionXLControlNetPipeline",
+             "StableDiffusionXLControlNetXSPipeline",
+             "StableDiffusionXLImg2ImgPipeline",
+             "StableDiffusionXLInpaintPipeline",
+             "StableDiffusionXLInstructPix2PixPipeline",
+             "StableDiffusionXLPAGImg2ImgPipeline",
+             "StableDiffusionXLPAGInpaintPipeline",
+             "StableDiffusionXLPAGPipeline",
+             "StableDiffusionXLPipeline",
+             "StableUnCLIPImg2ImgPipeline",
+             "StableUnCLIPPipeline",
+             "StableVideoDiffusionPipeline",
+             "TextToVideoSDPipeline",
+             "TextToVideoZeroPipeline",
+             "TextToVideoZeroSDXLPipeline",
+             "UnCLIPImageVariationPipeline",
+             "UnCLIPPipeline",
+             "UniDiffuserModel",
+             "UniDiffuserPipeline",
+             "UniDiffuserTextDecoder",
+             "VersatileDiffusionDualGuidedPipeline",
+             "VersatileDiffusionImageVariationPipeline",
+             "VersatileDiffusionPipeline",
+             "VersatileDiffusionTextToImagePipeline",
+             "VideoToVideoSDPipeline",
+             "VQDiffusionPipeline",
+             "WuerstchenCombinedPipeline",
+             "WuerstchenDecoderPipeline",
+             "WuerstchenPriorPipeline",
+         ]
+     )
+
+ try:
+     if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_torch_and_transformers_and_k_diffusion_objects  # noqa F403
+
+     _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
+         name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
+
+ try:
+     if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_torch_and_transformers_and_sentencepiece_objects  # noqa F403
+
+     _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
+         name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
+
+ try:
+     if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_torch_and_transformers_and_onnx_objects  # noqa F403
+
+     _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
+         name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["pipelines"].extend(
+         [
+             "OnnxStableDiffusionImg2ImgPipeline",
+             "OnnxStableDiffusionInpaintPipeline",
+             "OnnxStableDiffusionInpaintPipelineLegacy",
+             "OnnxStableDiffusionPipeline",
+             "OnnxStableDiffusionUpscalePipeline",
+             "StableDiffusionOnnxPipeline",
+         ]
+     )
+
+ try:
+     if not (is_torch_available() and is_librosa_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_torch_and_librosa_objects  # noqa F403
+
+     _import_structure["utils.dummy_torch_and_librosa_objects"] = [
+         name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
+     ]
+
+ else:
+     _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
+
+ try:
+     if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403
+
+     _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
+         name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
+     ]
+
+
+ else:
+     _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
+
+ try:
+     if not is_flax_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_flax_objects  # noqa F403
+
+     _import_structure["utils.dummy_flax_objects"] = [
+         name for name in dir(dummy_flax_objects) if not name.startswith("_")
+     ]
+
+
+ else:
+     _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
+     _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
+     _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
+     _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
+     _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
+     _import_structure["schedulers"].extend(
+         [
+             "FlaxDDIMScheduler",
+             "FlaxDDPMScheduler",
+             "FlaxDPMSolverMultistepScheduler",
+             "FlaxEulerDiscreteScheduler",
+             "FlaxKarrasVeScheduler",
+             "FlaxLMSDiscreteScheduler",
+             "FlaxPNDMScheduler",
+             "FlaxSchedulerMixin",
+             "FlaxScoreSdeVeScheduler",
+         ]
+     )
+
+
+ try:
+     if not (is_flax_available() and is_transformers_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_flax_and_transformers_objects  # noqa F403
+
+     _import_structure["utils.dummy_flax_and_transformers_objects"] = [
+         name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
+     ]
+
+
+ else:
+     _import_structure["pipelines"].extend(
+         [
+             "FlaxStableDiffusionControlNetPipeline",
+             "FlaxStableDiffusionImg2ImgPipeline",
+             "FlaxStableDiffusionInpaintPipeline",
+             "FlaxStableDiffusionPipeline",
+             "FlaxStableDiffusionXLPipeline",
+         ]
+     )
+
+ try:
+     if not (is_note_seq_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils import dummy_note_seq_objects  # noqa F403
+
+     _import_structure["utils.dummy_note_seq_objects"] = [
+         name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
+     ]
+
+
+ else:
+     _import_structure["pipelines"].extend(["MidiProcessor"])
+
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+     from .configuration_utils import ConfigMixin
+
+     try:
+         if not is_onnx_available():
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_onnx_objects import *  # noqa F403
+     else:
+         from .pipelines import OnnxRuntimeModel
+
+     try:
+         if not is_torch_available():
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_pt_objects import *  # noqa F403
+     else:
+         from .models import (
+             AsymmetricAutoencoderKL,
+             AuraFlowTransformer2DModel,
+             AutoencoderKL,
+             AutoencoderKLCogVideoX,
+             AutoencoderKLTemporalDecoder,
+             AutoencoderOobleck,
+             AutoencoderTiny,
+             CogVideoXTransformer3DModel,
+             ConsistencyDecoderVAE,
+             ControlNetModel,
+             ControlNetXSAdapter,
+             DiTTransformer2DModel,
+             FluxControlNetModel,
+             FluxMultiControlNetModel,
+             FluxTransformer2DModel,
+             HunyuanDiT2DControlNetModel,
+             HunyuanDiT2DModel,
+             HunyuanDiT2DMultiControlNetModel,
+             I2VGenXLUNet,
+             Kandinsky3UNet,
+             LatteTransformer3DModel,
+             LuminaNextDiT2DModel,
+             ModelMixin,
+             MotionAdapter,
+             MultiAdapter,
+             PixArtTransformer2DModel,
+             PriorTransformer,
+             SD3ControlNetModel,
+             SD3MultiControlNetModel,
+             SD3Transformer2DModel,
+             SparseControlNetModel,
+             StableAudioDiTModel,
+             T2IAdapter,
+             T5FilmDecoder,
+             Transformer2DModel,
+             UNet1DModel,
+             UNet2DConditionModel,
+             UNet2DModel,
+             UNet3DConditionModel,
+             UNetControlNetXSModel,
+             UNetMotionModel,
+             UNetSpatioTemporalConditionModel,
+             UVit2DModel,
+             VQModel,
+         )
+         from .optimization import (
+             get_constant_schedule,
+             get_constant_schedule_with_warmup,
+             get_cosine_schedule_with_warmup,
+             get_cosine_with_hard_restarts_schedule_with_warmup,
+             get_linear_schedule_with_warmup,
+             get_polynomial_decay_schedule_with_warmup,
+             get_scheduler,
+         )
+         from .pipelines import (
+             AudioPipelineOutput,
+             AutoPipelineForImage2Image,
+             AutoPipelineForInpainting,
+             AutoPipelineForText2Image,
+             BlipDiffusionControlNetPipeline,
+             BlipDiffusionPipeline,
+             CLIPImageProjection,
+             ConsistencyModelPipeline,
+             DanceDiffusionPipeline,
+             DDIMPipeline,
+             DDPMPipeline,
+             DiffusionPipeline,
+             DiTPipeline,
+             ImagePipelineOutput,
+             KarrasVePipeline,
+             LDMPipeline,
+             LDMSuperResolutionPipeline,
+             PNDMPipeline,
+             RePaintPipeline,
+             ScoreSdeVePipeline,
+             StableDiffusionMixin,
+         )
+         from .schedulers import (
+             AmusedScheduler,
+             CMStochasticIterativeScheduler,
+             CogVideoXDDIMScheduler,
+             CogVideoXDPMScheduler,
+             DDIMInverseScheduler,
+             DDIMParallelScheduler,
+             DDIMScheduler,
+             DDPMParallelScheduler,
+             DDPMScheduler,
+             DDPMWuerstchenScheduler,
+             DEISMultistepScheduler,
+             DPMSolverMultistepInverseScheduler,
+             DPMSolverMultistepScheduler,
+             DPMSolverSinglestepScheduler,
+             EDMDPMSolverMultistepScheduler,
+             EDMEulerScheduler,
+             EulerAncestralDiscreteScheduler,
+             EulerDiscreteScheduler,
+             FlowMatchEulerDiscreteScheduler,
+             FlowMatchHeunDiscreteScheduler,
+             HeunDiscreteScheduler,
+             IPNDMScheduler,
+             KarrasVeScheduler,
+             KDPM2AncestralDiscreteScheduler,
+             KDPM2DiscreteScheduler,
+             LCMScheduler,
+             PNDMScheduler,
+             RePaintScheduler,
+             SASolverScheduler,
+             SchedulerMixin,
+             ScoreSdeVeScheduler,
+             TCDScheduler,
+             UnCLIPScheduler,
+             UniPCMultistepScheduler,
+             VQDiffusionScheduler,
+         )
+         from .training_utils import EMAModel
+
+     try:
+         if not (is_torch_available() and is_scipy_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_torch_and_scipy_objects import *  # noqa F403
+     else:
+         from .schedulers import LMSDiscreteScheduler
+
+     try:
+         if not (is_torch_available() and is_torchsde_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_torch_and_torchsde_objects import *  # noqa F403
+     else:
+         from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler
+
+     try:
+         if not (is_torch_available() and is_transformers_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
+     else:
+         from .pipelines import (
+             AltDiffusionImg2ImgPipeline,
+             AltDiffusionPipeline,
+             AmusedImg2ImgPipeline,
+             AmusedInpaintPipeline,
+             AmusedPipeline,
+             AnimateDiffControlNetPipeline,
+             AnimateDiffPAGPipeline,
+             AnimateDiffPipeline,
+             AnimateDiffSDXLPipeline,
+             AnimateDiffSparseControlNetPipeline,
+             AnimateDiffVideoToVideoControlNetPipeline,
+             AnimateDiffVideoToVideoPipeline,
+             AudioLDM2Pipeline,
+             AudioLDM2ProjectionModel,
+             AudioLDM2UNet2DConditionModel,
+             AudioLDMPipeline,
+             AuraFlowPipeline,
+             CLIPImageProjection,
+             CogVideoXPipeline,
+             CogVideoXVideoToVideoPipeline,
+             CycleDiffusionPipeline,
+             FluxControlNetPipeline,
+             FluxImg2ImgPipeline,
+             FluxInpaintPipeline,
+             FluxPipeline,
+             HunyuanDiTControlNetPipeline,
+             HunyuanDiTPAGPipeline,
+             HunyuanDiTPipeline,
+             I2VGenXLPipeline,
+             IFImg2ImgPipeline,
+             IFImg2ImgSuperResolutionPipeline,
+             IFInpaintingPipeline,
+             IFInpaintingSuperResolutionPipeline,
+             IFPipeline,
+             IFSuperResolutionPipeline,
+             ImageTextPipelineOutput,
+             Kandinsky3Img2ImgPipeline,
+             Kandinsky3Pipeline,
+             KandinskyCombinedPipeline,
+             KandinskyImg2ImgCombinedPipeline,
+             KandinskyImg2ImgPipeline,
+             KandinskyInpaintCombinedPipeline,
+             KandinskyInpaintPipeline,
+             KandinskyPipeline,
+             KandinskyPriorPipeline,
+             KandinskyV22CombinedPipeline,
+             KandinskyV22ControlnetImg2ImgPipeline,
+             KandinskyV22ControlnetPipeline,
+             KandinskyV22Img2ImgCombinedPipeline,
+             KandinskyV22Img2ImgPipeline,
+             KandinskyV22InpaintCombinedPipeline,
+             KandinskyV22InpaintPipeline,
+             KandinskyV22Pipeline,
+             KandinskyV22PriorEmb2EmbPipeline,
+             KandinskyV22PriorPipeline,
+             LatentConsistencyModelImg2ImgPipeline,
+             LatentConsistencyModelPipeline,
+             LattePipeline,
+             LDMTextToImagePipeline,
+             LEditsPPPipelineStableDiffusion,
+             LEditsPPPipelineStableDiffusionXL,
+             LuminaText2ImgPipeline,
+             MarigoldDepthPipeline,
+             MarigoldNormalsPipeline,
+             MusicLDMPipeline,
+             PaintByExamplePipeline,
+             PIAPipeline,
+             PixArtAlphaPipeline,
+             PixArtSigmaPAGPipeline,
+             PixArtSigmaPipeline,
+             SemanticStableDiffusionPipeline,
+             ShapEImg2ImgPipeline,
+             ShapEPipeline,
+             StableAudioPipeline,
+             StableAudioProjectionModel,
+             StableCascadeCombinedPipeline,
+             StableCascadeDecoderPipeline,
+             StableCascadePriorPipeline,
+             StableDiffusion3ControlNetPipeline,
+             StableDiffusion3Img2ImgPipeline,
+             StableDiffusion3InpaintPipeline,
+             StableDiffusion3PAGPipeline,
+             StableDiffusion3Pipeline,
+             StableDiffusionAdapterPipeline,
+             StableDiffusionAttendAndExcitePipeline,
+             StableDiffusionControlNetImg2ImgPipeline,
+             StableDiffusionControlNetInpaintPipeline,
+             StableDiffusionControlNetPAGPipeline,
+             StableDiffusionControlNetPipeline,
+             StableDiffusionControlNetXSPipeline,
+             StableDiffusionDepth2ImgPipeline,
+             StableDiffusionDiffEditPipeline,
+             StableDiffusionGLIGENPipeline,
+             StableDiffusionGLIGENTextImagePipeline,
+             StableDiffusionImageVariationPipeline,
+             StableDiffusionImg2ImgPipeline,
+             StableDiffusionInpaintPipeline,
+             StableDiffusionInpaintPipelineLegacy,
+             StableDiffusionInstructPix2PixPipeline,
+             StableDiffusionLatentUpscalePipeline,
+             StableDiffusionLDM3DPipeline,
+             StableDiffusionModelEditingPipeline,
+             StableDiffusionPAGPipeline,
+             StableDiffusionPanoramaPipeline,
+             StableDiffusionParadigmsPipeline,
+             StableDiffusionPipeline,
+             StableDiffusionPipelineSafe,
+             StableDiffusionPix2PixZeroPipeline,
+             StableDiffusionSAGPipeline,
+             StableDiffusionUpscalePipeline,
+             StableDiffusionXLAdapterPipeline,
+             StableDiffusionXLControlNetImg2ImgPipeline,
+             StableDiffusionXLControlNetInpaintPipeline,
+             StableDiffusionXLControlNetPAGImg2ImgPipeline,
+             StableDiffusionXLControlNetPAGPipeline,
+             StableDiffusionXLControlNetPipeline,
+             StableDiffusionXLControlNetXSPipeline,
+             StableDiffusionXLImg2ImgPipeline,
+             StableDiffusionXLInpaintPipeline,
+             StableDiffusionXLInstructPix2PixPipeline,
+             StableDiffusionXLPAGImg2ImgPipeline,
+             StableDiffusionXLPAGInpaintPipeline,
+             StableDiffusionXLPAGPipeline,
+             StableDiffusionXLPipeline,
+             StableUnCLIPImg2ImgPipeline,
+             StableUnCLIPPipeline,
+             StableVideoDiffusionPipeline,
+             TextToVideoSDPipeline,
+             TextToVideoZeroPipeline,
+             TextToVideoZeroSDXLPipeline,
+             UnCLIPImageVariationPipeline,
+             UnCLIPPipeline,
+             UniDiffuserModel,
+             UniDiffuserPipeline,
+             UniDiffuserTextDecoder,
+             VersatileDiffusionDualGuidedPipeline,
+             VersatileDiffusionImageVariationPipeline,
+             VersatileDiffusionPipeline,
+             VersatileDiffusionTextToImagePipeline,
+             VideoToVideoSDPipeline,
+             VQDiffusionPipeline,
+             WuerstchenCombinedPipeline,
+             WuerstchenDecoderPipeline,
+             WuerstchenPriorPipeline,
+         )
+
+     try:
+         if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
+     else:
+         from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
+
+     try:
+         if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import *  # noqa F403
+     else:
+         from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
+     try:
+         if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
+     else:
+         from .pipelines import (
+             OnnxStableDiffusionImg2ImgPipeline,
+             OnnxStableDiffusionInpaintPipeline,
+             OnnxStableDiffusionInpaintPipelineLegacy,
+             OnnxStableDiffusionPipeline,
+             OnnxStableDiffusionUpscalePipeline,
+             StableDiffusionOnnxPipeline,
+         )
+
+     try:
+         if not (is_torch_available() and is_librosa_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_torch_and_librosa_objects import *  # noqa F403
+     else:
+         from .pipelines import AudioDiffusionPipeline, Mel
+
+     try:
+         if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_transformers_and_torch_and_note_seq_objects import *  # noqa F403
+     else:
+         from .pipelines import SpectrogramDiffusionPipeline
+
+     try:
+         if not is_flax_available():
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_flax_objects import *  # noqa F403
+     else:
+         from .models.controlnet_flax import FlaxControlNetModel
+         from .models.modeling_flax_utils import FlaxModelMixin
+         from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
+         from .models.vae_flax import FlaxAutoencoderKL
+         from .pipelines import FlaxDiffusionPipeline
+         from .schedulers import (
+             FlaxDDIMScheduler,
+             FlaxDDPMScheduler,
+             FlaxDPMSolverMultistepScheduler,
+             FlaxEulerDiscreteScheduler,
+             FlaxKarrasVeScheduler,
+             FlaxLMSDiscreteScheduler,
+             FlaxPNDMScheduler,
+             FlaxSchedulerMixin,
+             FlaxScoreSdeVeScheduler,
+         )
+
+     try:
+         if not (is_flax_available() and is_transformers_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
+     else:
+         from .pipelines import (
+             FlaxStableDiffusionControlNetPipeline,
+             FlaxStableDiffusionImg2ImgPipeline,
+             FlaxStableDiffusionInpaintPipeline,
+             FlaxStableDiffusionPipeline,
+             FlaxStableDiffusionXLPipeline,
+         )
+
+     try:
+         if not (is_note_seq_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from .utils.dummy_note_seq_objects import *  # noqa F403
+     else:
+         from .pipelines import MidiProcessor
+
+ else:
+     import sys
+
+     sys.modules[__name__] = _LazyModule(
+         __name__,
+         globals()["__file__"],
+         _import_structure,
+         module_spec=__spec__,
+         extra_objects={"__version__": __version__},
+     )
diffusers3/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (18.8 kB).
 
diffusers3/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (20.9 kB).
 
diffusers3/__pycache__/callbacks.cpython-310.pyc ADDED
Binary file (6.18 kB). View file
 
diffusers3/__pycache__/callbacks.cpython-38.pyc ADDED
Binary file (6.2 kB). View file
 
diffusers3/__pycache__/configuration_utils.cpython-310.pyc ADDED
Binary file (24.5 kB). View file
 
diffusers3/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (24.7 kB). View file
 
diffusers3/__pycache__/dependency_versions_check.cpython-310.pyc ADDED
Binary file (660 Bytes). View file
 
diffusers3/__pycache__/dependency_versions_check.cpython-38.pyc ADDED
Binary file (665 Bytes). View file
 
diffusers3/__pycache__/dependency_versions_table.cpython-310.pyc ADDED
Binary file (1.41 kB). View file
 
diffusers3/__pycache__/dependency_versions_table.cpython-38.pyc ADDED
Binary file (1.27 kB). View file
 
diffusers3/__pycache__/image_processor.cpython-310.pyc ADDED
Binary file (34.1 kB). View file
 
diffusers3/__pycache__/image_processor.cpython-38.pyc ADDED
Binary file (34.3 kB). View file
 
diffusers3/callbacks.py ADDED
@@ -0,0 +1,156 @@
+ from typing import Any, Dict, List
+
+ from .configuration_utils import ConfigMixin, register_to_config
+ from .utils import CONFIG_NAME
+
+
+ class PipelineCallback(ConfigMixin):
+     """
+     Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
+     custom callbacks and ensures that all callbacks have a consistent interface.
+
+     Please implement the following:
+         `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able
+             to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
+         `callback_fn`: This method defines the core functionality of your callback.
+     """
+
+     config_name = CONFIG_NAME
+
+     @register_to_config
+     def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
+         super().__init__()
+
+         if (cutoff_step_ratio is None and cutoff_step_index is None) or (
+             cutoff_step_ratio is not None and cutoff_step_index is not None
+         ):
+             raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")
+
+         if cutoff_step_ratio is not None and (
+             not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
+         ):
+             raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
+
+     @property
+     def tensor_inputs(self) -> List[str]:
+         raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
+
+     def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
+         raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
+
+     def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+         return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
+
+
+ class MultiPipelineCallbacks:
+     """
+     This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
+     provides a unified interface for calling all of them.
+     """
+
+     def __init__(self, callbacks: List[PipelineCallback]):
+         self.callbacks = callbacks
+
+     @property
+     def tensor_inputs(self) -> List[str]:
+         return [input for callback in self.callbacks for input in callback.tensor_inputs]
+
+     def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+         """
+         Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
+         """
+         for callback in self.callbacks:
+             callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)
+
+         return callback_kwargs
+
+
+ class SDCFGCutoffCallback(PipelineCallback):
+     """
+     Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
+     `cutoff_step_index`), this callback will disable the CFG.
+
+     Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+     """
+
+     tensor_inputs = ["prompt_embeds"]
+
+     def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+         cutoff_step_ratio = self.config.cutoff_step_ratio
+         cutoff_step_index = self.config.cutoff_step_index
+
+         # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+         cutoff_step = (
+             cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+         )
+
+         if step_index == cutoff_step:
+             prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+             prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
+
+             pipeline._guidance_scale = 0.0
+
+             callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+         return callback_kwargs
+
+
+ class SDXLCFGCutoffCallback(PipelineCallback):
+     """
+     Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
+     `cutoff_step_index`), this callback will disable the CFG.
+
+     Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+     """
+
+     tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+
+     def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+         cutoff_step_ratio = self.config.cutoff_step_ratio
+         cutoff_step_index = self.config.cutoff_step_index
+
+         # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+         cutoff_step = (
+             cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+         )
+
+         if step_index == cutoff_step:
+             prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+             prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
+
+             add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
+             add_text_embeds = add_text_embeds[-1:]  # "-1" denotes the embeddings for conditional pooled text tokens
+
+             add_time_ids = callback_kwargs[self.tensor_inputs[2]]
+             add_time_ids = add_time_ids[-1:]  # "-1" denotes the embeddings for conditional added time vector
+
+             pipeline._guidance_scale = 0.0
+
+             callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+             callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+             callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+         return callback_kwargs
+
+
+ class IPAdapterScaleCutoffCallback(PipelineCallback):
+     """
+     Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by
+     `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.
+
+     Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
+     """
+
+     tensor_inputs = []
+
+     def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+         cutoff_step_ratio = self.config.cutoff_step_ratio
+         cutoff_step_index = self.config.cutoff_step_index
+
+         # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+         cutoff_step = (
+             cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+         )
+
+         if step_index == cutoff_step:
+             pipeline.set_ip_adapter_scale(0.0)
+         return callback_kwargs
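For context, a minimal usage sketch for the cutoff callbacks defined above, assuming the public `diffusers` API (the vendored `diffusers3` copy exposes the same classes); the model id and prompt are illustrative:

```python
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.callbacks import MultiPipelineCallbacks, SDXLCFGCutoffCallback

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Drop classifier-free guidance after 40% of the denoising steps.
callback = MultiPipelineCallbacks([SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)])

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=30,
    callback_on_step_end=callback,
    callback_on_step_end_tensor_inputs=callback.tensor_inputs,
).images[0]
```

Because `cutoff_step_ratio` and `cutoff_step_index` are mutually exclusive, pass exactly one of them.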
diffusers3/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from argparse import ArgumentParser
+
+
+ class BaseDiffusersCLICommand(ABC):
+     @staticmethod
+     @abstractmethod
+     def register_subcommand(parser: ArgumentParser):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def run(self):
+         raise NotImplementedError()
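A hypothetical subcommand showing the contract this base class imposes (the `hello` command and its factory are invented for illustration):

```python
from argparse import ArgumentParser, Namespace

from diffusers.commands import BaseDiffusersCLICommand


def hello_command_factory(args: Namespace):
    return HelloCommand(args.name)


class HelloCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello")
        hello_parser.add_argument("--name", type=str, default="world")
        # The dispatcher looks up this `func` default to build the command.
        hello_parser.set_defaults(func=hello_command_factory)

    def __init__(self, name: str):
        self.name = name

    def run(self):
        print(f"Hello, {self.name}!")
```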
diffusers3/commands/diffusers_cli.py ADDED
@@ -0,0 +1,43 @@
+ #!/usr/bin/env python
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import ArgumentParser
+
+ from .env import EnvironmentCommand
+ from .fp16_safetensors import FP16SafetensorsCommand
+
+
+ def main():
+     parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
+     commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+     # Register commands
+     EnvironmentCommand.register_subcommand(commands_parser)
+     FP16SafetensorsCommand.register_subcommand(commands_parser)
+
+     # Let's go
+     args = parser.parse_args()
+
+     if not hasattr(args, "func"):
+         parser.print_help()
+         exit(1)
+
+     # Run
+     service = args.func(args)
+     service.run()
+
+
+ if __name__ == "__main__":
+     main()
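Note that `main()` dispatches on the `func` default each subcommand installs on its parser. Invoking it programmatically is equivalent to the shell call `diffusers-cli env` (the `sys.argv` override below is only for illustration):

```python
import sys

from diffusers.commands.diffusers_cli import main

sys.argv = ["diffusers-cli", "env"]  # simulate the CLI invocation
main()
```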
diffusers3/commands/env.py ADDED
@@ -0,0 +1,180 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import platform
+ import subprocess
+ from argparse import ArgumentParser
+
+ import huggingface_hub
+
+ from .. import __version__ as version
+ from ..utils import (
+     is_accelerate_available,
+     is_bitsandbytes_available,
+     is_flax_available,
+     is_google_colab,
+     is_peft_available,
+     is_safetensors_available,
+     is_torch_available,
+     is_transformers_available,
+     is_xformers_available,
+ )
+ from . import BaseDiffusersCLICommand
+
+
+ def info_command_factory(_):
+     return EnvironmentCommand()
+
+
+ class EnvironmentCommand(BaseDiffusersCLICommand):
+     @staticmethod
+     def register_subcommand(parser: ArgumentParser) -> None:
+         download_parser = parser.add_parser("env")
+         download_parser.set_defaults(func=info_command_factory)
+
+     def run(self) -> dict:
+         hub_version = huggingface_hub.__version__
+
+         safetensors_version = "not installed"
+         if is_safetensors_available():
+             import safetensors
+
+             safetensors_version = safetensors.__version__
+
+         pt_version = "not installed"
+         pt_cuda_available = "NA"
+         if is_torch_available():
+             import torch
+
+             pt_version = torch.__version__
+             pt_cuda_available = torch.cuda.is_available()
+
+         flax_version = "not installed"
+         jax_version = "not installed"
+         jaxlib_version = "not installed"
+         jax_backend = "NA"
+         if is_flax_available():
+             import flax
+             import jax
+             import jaxlib
+
+             flax_version = flax.__version__
+             jax_version = jax.__version__
+             jaxlib_version = jaxlib.__version__
+             jax_backend = jax.lib.xla_bridge.get_backend().platform
+
+         transformers_version = "not installed"
+         if is_transformers_available():
+             import transformers
+
+             transformers_version = transformers.__version__
+
+         accelerate_version = "not installed"
+         if is_accelerate_available():
+             import accelerate
+
+             accelerate_version = accelerate.__version__
+
+         peft_version = "not installed"
+         if is_peft_available():
+             import peft
+
+             peft_version = peft.__version__
+
+         bitsandbytes_version = "not installed"
+         if is_bitsandbytes_available():
+             import bitsandbytes
+
+             bitsandbytes_version = bitsandbytes.__version__
+
+         xformers_version = "not installed"
+         if is_xformers_available():
+             import xformers
+
+             xformers_version = xformers.__version__
+
+         platform_info = platform.platform()
+
+         is_google_colab_str = "Yes" if is_google_colab() else "No"
+
+         accelerator = "NA"
+         if platform.system() in {"Linux", "Windows"}:
+             try:
+                 sp = subprocess.Popen(
+                     ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
+                     stdout=subprocess.PIPE,
+                     stderr=subprocess.PIPE,
+                 )
+                 out_str, _ = sp.communicate()
+                 out_str = out_str.decode("utf-8")
+
+                 if len(out_str) > 0:
+                     accelerator = out_str.strip()
+             except FileNotFoundError:
+                 pass
+         elif platform.system() == "Darwin":  # Mac OS
+             try:
+                 sp = subprocess.Popen(
+                     ["system_profiler", "SPDisplaysDataType"],
+                     stdout=subprocess.PIPE,
+                     stderr=subprocess.PIPE,
+                 )
+                 out_str, _ = sp.communicate()
+                 out_str = out_str.decode("utf-8")
+
+                 start = out_str.find("Chipset Model:")
+                 if start != -1:
+                     start += len("Chipset Model:")
+                     end = out_str.find("\n", start)
+                     accelerator = out_str[start:end].strip()
+
+                 start = out_str.find("VRAM (Total):")
+                 if start != -1:
+                     start += len("VRAM (Total):")
+                     end = out_str.find("\n", start)
+                     accelerator += " VRAM: " + out_str[start:end].strip()
+             except FileNotFoundError:
+                 pass
+         else:
+             print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
+
+         info = {
+             "🤗 Diffusers version": version,
+             "Platform": platform_info,
+             "Running on Google Colab?": is_google_colab_str,
+             "Python version": platform.python_version(),
+             "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+             "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
+             "Jax version": jax_version,
+             "JaxLib version": jaxlib_version,
+             "Huggingface_hub version": hub_version,
+             "Transformers version": transformers_version,
+             "Accelerate version": accelerate_version,
+             "PEFT version": peft_version,
+             "Bitsandbytes version": bitsandbytes_version,
+             "Safetensors version": safetensors_version,
+             "xFormers version": xformers_version,
+             "Accelerator": accelerator,
+             "Using GPU in script?": "<fill in>",
+             "Using distributed or parallel set-up in script?": "<fill in>",
+         }
+
+         print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+         print(self.format_dict(info))
+
+         return info
+
+     @staticmethod
+     def format_dict(d: dict) -> str:
+         return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
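The command can also be driven without the CLI wrapper; `run()` both prints the report and returns it as a dictionary (sketch, assuming the upstream `diffusers` package layout):

```python
from diffusers.commands.env import EnvironmentCommand

info = EnvironmentCommand().run()  # prints the formatted report
print(info["PyTorch version (GPU?)"])
```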
diffusers3/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,132 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Usage example:
+     diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
+ """
+
+ import glob
+ import json
+ import warnings
+ from argparse import ArgumentParser, Namespace
+ from importlib import import_module
+
+ import huggingface_hub
+ import torch
+ from huggingface_hub import hf_hub_download
+ from packaging import version
+
+ from ..utils import logging
+ from . import BaseDiffusersCLICommand
+
+
+ def conversion_command_factory(args: Namespace):
+     if args.use_auth_token:
+         warnings.warn(
+             "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
+             " handled automatically if user is logged in."
+         )
+     return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
+
+
+ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
+     @staticmethod
+     def register_subcommand(parser: ArgumentParser):
+         conversion_parser = parser.add_parser("fp16_safetensors")
+         conversion_parser.add_argument(
+             "--ckpt_id",
+             type=str,
+             help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
+         )
+         conversion_parser.add_argument(
+             "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
+         )
+         conversion_parser.add_argument(
+             "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
+         )
+         conversion_parser.add_argument(
+             "--use_auth_token",
+             action="store_true",
+             help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
+         )
+         conversion_parser.set_defaults(func=conversion_command_factory)
+
+     def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
+         self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
+         self.ckpt_id = ckpt_id
+         self.local_ckpt_dir = f"/tmp/{ckpt_id}"
+         self.fp16 = fp16
+
+         self.use_safetensors = use_safetensors
+
+         if not self.use_safetensors and not self.fp16:
+             raise NotImplementedError(
+                 "When `use_safetensors` and `fp16` both are False, then this command is of no use."
+             )
+
+     def run(self):
+         if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
+             raise ImportError(
+                 "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
+                 " installation."
+             )
+         else:
+             from huggingface_hub import create_commit
+             from huggingface_hub._commit_api import CommitOperationAdd
+
+         model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
+         with open(model_index, "r") as f:
+             pipeline_class_name = json.load(f)["_class_name"]
+         pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
+         self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
+
+         # Load the appropriate pipeline. We could have use `DiffusionPipeline`
+         # here, but just to avoid any rough edge cases.
+         pipeline = pipeline_class.from_pretrained(
+             self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
+         )
+         pipeline.save_pretrained(
+             self.local_ckpt_dir,
+             safe_serialization=True if self.use_safetensors else False,
+             variant="fp16" if self.fp16 else None,
+         )
+         self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
+
+         # Fetch all the paths.
+         if self.fp16:
+             modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
+         elif self.use_safetensors:
+             modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
+
+         # Prepare for the PR.
+         commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
+         operations = []
+         for path in modified_paths:
+             operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
+
+         # Open the PR.
+         commit_description = (
+             "Variables converted by the [`diffusers`' `fp16_safetensors`"
+             " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
+         )
+         hub_pr_url = create_commit(
+             repo_id=self.ckpt_id,
+             operations=operations,
+             commit_message=commit_message,
+             commit_description=commit_description,
+             repo_type="model",
+             create_pr=True,
+         ).pr_url
+         self.logger.info(f"PR created here: {hub_pr_url}.")
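A hedged programmatic equivalent of the documented CLI call (requires a logged-in Hub account with permission to open a PR against the target repo):

```python
from diffusers.commands.fp16_safetensors import FP16SafetensorsCommand

# Same effect as: diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
command = FP16SafetensorsCommand("openai/shap-e", fp16=True, use_safetensors=True)
command.run()  # saves under /tmp/openai/shap-e and opens a Hub PR
```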
diffusers3/configuration_utils.py ADDED
@@ -0,0 +1,720 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ConfigMixin base class and utilities."""
+
+ import dataclasses
+ import functools
+ import importlib
+ import inspect
+ import json
+ import os
+ import re
+ from collections import OrderedDict
+ from pathlib import Path
+ from typing import Any, Dict, Tuple, Union
+
+ import numpy as np
+ from huggingface_hub import create_repo, hf_hub_download
+ from huggingface_hub.utils import (
+     EntryNotFoundError,
+     RepositoryNotFoundError,
+     RevisionNotFoundError,
+     validate_hf_hub_args,
+ )
+ from requests import HTTPError
+
+ from . import __version__
+ from .utils import (
+     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+     DummyObject,
+     deprecate,
+     extract_commit_hash,
+     http_user_agent,
+     logging,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+ class FrozenDict(OrderedDict):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         for key, value in self.items():
+             setattr(self, key, value)
+
+         self.__frozen = True
+
+     def __delitem__(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+     def setdefault(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+     def pop(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+     def update(self, *args, **kwargs):
+         raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+     def __setattr__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+         super().__setattr__(name, value)
+
+     def __setitem__(self, name, value):
+         if hasattr(self, "__frozen") and self.__frozen:
+             raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+         super().__setitem__(name, value)
+
+
+ class ConfigMixin:
+     r"""
+     Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
+     provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
+     saving classes that inherit from [`ConfigMixin`].
+
+     Class attributes:
+         - **config_name** (`str`) -- A filename under which the config should stored when calling
+           [`~ConfigMixin.save_config`] (should be overridden by parent class).
+         - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+           overridden by subclass).
+         - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+         - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
+           should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+           subclass).
+     """
+
+     config_name = None
+     ignore_for_config = []
+     has_compatibles = False
+
+     _deprecated_kwargs = []
+
+     def register_to_config(self, **kwargs):
+         if self.config_name is None:
+             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+         # Special case for `kwargs` used in deprecation warning added to schedulers
+         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+         # or solve in a more general way.
+         kwargs.pop("kwargs", None)
+
+         if not hasattr(self, "_internal_dict"):
+             internal_dict = kwargs
+         else:
+             previous_dict = dict(self._internal_dict)
+             internal_dict = {**self._internal_dict, **kwargs}
+             logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+         self._internal_dict = FrozenDict(internal_dict)
+
+     def __getattr__(self, name: str) -> Any:
+         """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+         config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+         This function is mostly copied from PyTorch's __getattr__ overwrite:
+         https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+         """
+
+         is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+         is_attribute = name in self.__dict__
+
+         if is_in_config and not is_attribute:
+             deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
+             deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
+             return self._internal_dict[name]
+
+         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+         """
+         Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
+         [`~ConfigMixin.from_config`] class method.
+
+         Args:
+             save_directory (`str` or `os.PathLike`):
+                 Directory where the configuration JSON file is saved (will be created if it does not exist).
+             push_to_hub (`bool`, *optional*, defaults to `False`):
+                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                 namespace).
+             kwargs (`Dict[str, Any]`, *optional*):
+                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+         """
+         if os.path.isfile(save_directory):
+             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         # If we save using the predefined names, we can load using `from_config`
+         output_config_file = os.path.join(save_directory, self.config_name)
+
+         self.to_json_file(output_config_file)
+         logger.info(f"Configuration saved in {output_config_file}")
+
+         if push_to_hub:
+             commit_message = kwargs.pop("commit_message", None)
+             private = kwargs.pop("private", False)
+             create_pr = kwargs.pop("create_pr", False)
+             token = kwargs.pop("token", None)
+             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+             repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+             self._upload_folder(
+                 save_directory,
+                 repo_id,
+                 token=token,
+                 commit_message=commit_message,
+                 create_pr=create_pr,
+             )
+
+     @classmethod
+     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+         r"""
+         Instantiate a Python class from a config dictionary.
+
+         Parameters:
+             config (`Dict[str, Any]`):
+                 A config dictionary from which the Python class is instantiated. Make sure to only load configuration
+                 files of compatible classes.
+             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                 Whether kwargs that are not consumed by the Python class should be returned or not.
+             kwargs (remaining dictionary of keyword arguments, *optional*):
+                 Can be used to update the configuration object (after it is loaded) and initiate the Python class.
+                 `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
+                 overwrite the same named arguments in `config`.
+
+         Returns:
+             [`ModelMixin`] or [`SchedulerMixin`]:
+                 A model or scheduler object instantiated from a config dictionary.
+
+         Examples:
+
+         ```python
+         >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+         >>> # Download scheduler from huggingface.co and cache.
+         >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+         >>> # Instantiate DDIM scheduler class with same config as DDPM
+         >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+         >>> # Instantiate PNDM scheduler class with same config as DDPM
+         >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+         ```
+         """
+         # <===== TO BE REMOVED WITH DEPRECATION
+         # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+         if "pretrained_model_name_or_path" in kwargs:
+             config = kwargs.pop("pretrained_model_name_or_path")
+
+         if config is None:
+             raise ValueError("Please make sure to provide a config as the first positional argument.")
+         # ======>
+
+         if not isinstance(config, dict):
+             deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+             if "Scheduler" in cls.__name__:
+                 deprecation_message += (
+                     f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+                     " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+                     " be removed in v1.0.0."
+                 )
+             elif "Model" in cls.__name__:
+                 deprecation_message += (
+                     f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+                     f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+                     " instead. This functionality will be removed in v1.0.0."
+                 )
+             deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+             config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+         init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+         # Allow dtype to be specified on initialization
+         if "dtype" in unused_kwargs:
+             init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+         # add possible deprecated kwargs
+         for deprecated_kwarg in cls._deprecated_kwargs:
+             if deprecated_kwarg in unused_kwargs:
+                 init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+         # Return model and optionally state and/or unused_kwargs
+         model = cls(**init_dict)
+
+         # make sure to also save config parameters that might be used for compatible classes
+         # update _class_name
+         if "_class_name" in hidden_dict:
+             hidden_dict["_class_name"] = cls.__name__
+
+         model.register_to_config(**hidden_dict)
+
+         # add hidden kwargs of compatible classes to unused_kwargs
+         unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+         if return_unused_kwargs:
+             return (model, unused_kwargs)
+         else:
+             return model
+
+     @classmethod
+     def get_config_dict(cls, *args, **kwargs):
+         deprecation_message = (
+             f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+             " removed in version v1.0.0"
+         )
+         deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+         return cls.load_config(*args, **kwargs)
+
+     @classmethod
+     @validate_hf_hub_args
+     def load_config(
+         cls,
+         pretrained_model_name_or_path: Union[str, os.PathLike],
+         return_unused_kwargs=False,
+         return_commit_hash=False,
+         **kwargs,
+     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         r"""
+         Load a model or scheduler configuration.
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+
+                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                       the Hub.
+                     - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
+                       [`~ConfigMixin.save_config`].
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                 is not used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             output_loading_info(`bool`, *optional*, defaults to `False`):
+                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                 won't be downloaded from the Hub.
+             token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                 allowed by Git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                 Whether unused keyword arguments of the config are returned.
+             return_commit_hash (`bool`, *optional*, defaults to `False`):
+                 Whether the `commit_hash` of the loaded configuration are returned.
+
+         Returns:
+             `dict`:
+                 A dictionary of all the parameters stored in a JSON configuration file.
+
+         """
+         cache_dir = kwargs.pop("cache_dir", None)
+         local_dir = kwargs.pop("local_dir", None)
+         local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
+         force_download = kwargs.pop("force_download", False)
+         proxies = kwargs.pop("proxies", None)
+         token = kwargs.pop("token", None)
+         local_files_only = kwargs.pop("local_files_only", False)
+         revision = kwargs.pop("revision", None)
+         _ = kwargs.pop("mirror", None)
+         subfolder = kwargs.pop("subfolder", None)
+         user_agent = kwargs.pop("user_agent", {})
+
+         user_agent = {**user_agent, "file_type": "config"}
+         user_agent = http_user_agent(user_agent)
+
+         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+         if cls.config_name is None:
+             raise ValueError(
+                 "`self.config_name` is not defined. Note that one should not load a config from "
+                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+             )
+
+         if os.path.isfile(pretrained_model_name_or_path):
+             config_file = pretrained_model_name_or_path
+         elif os.path.isdir(pretrained_model_name_or_path):
+             if subfolder is not None and os.path.isfile(
+                 os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             ):
+                 config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                 # Load from a PyTorch checkpoint
+                 config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+             else:
+                 raise EnvironmentError(
+                     f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+                 )
+         else:
+             try:
+                 # Load from URL or cache if already cached
+                 config_file = hf_hub_download(
+                     pretrained_model_name_or_path,
+                     filename=cls.config_name,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     local_files_only=local_files_only,
+                     token=token,
+                     user_agent=user_agent,
+                     subfolder=subfolder,
+                     revision=revision,
+                     local_dir=local_dir,
+                     local_dir_use_symlinks=local_dir_use_symlinks,
+                 )
+             except RepositoryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+                     " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+                     " token having permission to this repo with `token` or log in with `huggingface-cli login`."
+                 )
+             except RevisionNotFoundError:
+                 raise EnvironmentError(
+                     f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                     " this model name. Check the model page at"
+                     f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                 )
+             except EntryNotFoundError:
+                 raise EnvironmentError(
+                     f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                 )
+             except HTTPError as err:
+                 raise EnvironmentError(
+                     "There was a specific connection error when trying to load"
+                     f" {pretrained_model_name_or_path}:\n{err}"
+                 )
+             except ValueError:
+                 raise EnvironmentError(
+                     f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                     f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                     f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+                     " run the library in offline mode at"
+                     " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                 )
+             except EnvironmentError:
+                 raise EnvironmentError(
+                     f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                     "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                     f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                     f"containing a {cls.config_name} file"
+                 )
+
+         try:
+             # Load config dict
+             config_dict = cls._dict_from_json_file(config_file)
+
+             commit_hash = extract_commit_hash(config_file)
+         except (json.JSONDecodeError, UnicodeDecodeError):
+             raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+         if not (return_unused_kwargs or return_commit_hash):
+             return config_dict
+
+         outputs = (config_dict,)
+
+         if return_unused_kwargs:
+             outputs += (kwargs,)
+
+         if return_commit_hash:
+             outputs += (commit_hash,)
+
+         return outputs
+
+     @staticmethod
+     def _get_init_keys(input_class):
+         return set(dict(inspect.signature(input_class.__init__).parameters).keys())
+
+     @classmethod
+     def extract_init_dict(cls, config_dict, **kwargs):
+         # Skip keys that were not present in the original config, so default __init__ values were used
+         used_defaults = config_dict.get("_use_default_values", [])
+         config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
+
+         # 0. Copy origin config dict
+         original_dict = dict(config_dict.items())
+
+         # 1. Retrieve expected config attributes from __init__ signature
+         expected_keys = cls._get_init_keys(cls)
+         expected_keys.remove("self")
+         # remove general kwargs if present in dict
+         if "kwargs" in expected_keys:
+             expected_keys.remove("kwargs")
+         # remove flax internal keys
+         if hasattr(cls, "_flax_internal_args"):
+             for arg in cls._flax_internal_args:
+                 expected_keys.remove(arg)
+
+         # 2. Remove attributes that cannot be expected from expected config attributes
+         # remove keys to be ignored
+         if len(cls.ignore_for_config) > 0:
+             expected_keys = expected_keys - set(cls.ignore_for_config)
+
+         # load diffusers library to import compatible and original scheduler
+         diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+         if cls.has_compatibles:
+             compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+         else:
+             compatible_classes = []
+
+         expected_keys_comp_cls = set()
+         for c in compatible_classes:
+             expected_keys_c = cls._get_init_keys(c)
+             expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+         expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+         config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+         # remove attributes from orig class that cannot be expected
+         orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+         if (
+             isinstance(orig_cls_name, str)
+             and orig_cls_name != cls.__name__
+             and hasattr(diffusers_library, orig_cls_name)
+         ):
+             orig_cls = getattr(diffusers_library, orig_cls_name)
+             unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+             config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+         elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
+             raise ValueError(
+                 "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
+             )
+
+         # remove private attributes
+         config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+         # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+         init_dict = {}
+         for key in expected_keys:
+             # if config param is passed to kwarg and is present in config dict
+             # it should overwrite existing config dict key
+             if key in kwargs and key in config_dict:
+                 config_dict[key] = kwargs.pop(key)
+
+             if key in kwargs:
+                 # overwrite key
+                 init_dict[key] = kwargs.pop(key)
+             elif key in config_dict:
+                 # use value from config dict
+                 init_dict[key] = config_dict.pop(key)
+
+         # 4. Give nice warning if unexpected values have been passed
+         if len(config_dict) > 0:
+             logger.warning(
+                 f"The config attributes {config_dict} were passed to {cls.__name__}, "
+                 "but are not expected and will be ignored. Please verify your "
+                 f"{cls.config_name} configuration file."
+             )
+
+         # 5. Give nice info if config attributes are initialized to default because they have not been passed
+         passed_keys = set(init_dict.keys())
+         if len(expected_keys - passed_keys) > 0:
+             logger.info(
+                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+             )
+
+         # 6. Define unused keyword arguments
+         unused_kwargs = {**config_dict, **kwargs}
+
+         # 7. Define "hidden" config parameters that were saved for compatible classes
+         hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
+
+         return init_dict, unused_kwargs, hidden_config_dict
+
+     @classmethod
+     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+         with open(json_file, "r", encoding="utf-8") as reader:
+             text = reader.read()
+         return json.loads(text)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__} {self.to_json_string()}"
+
+     @property
+     def config(self) -> Dict[str, Any]:
+         """
+         Returns the config of the class as a frozen dictionary
+
+         Returns:
+             `Dict[str, Any]`: Config of the class.
+         """
+         return self._internal_dict
+
+     def to_json_string(self) -> str:
+         """
+         Serializes the configuration instance to a JSON string.
+
+         Returns:
+             `str`:
+                 String containing all the attributes that make up the configuration instance in JSON format.
+         """
+         config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+         config_dict["_class_name"] = self.__class__.__name__
+         config_dict["_diffusers_version"] = __version__
+
+         def to_json_saveable(value):
+             if isinstance(value, np.ndarray):
+                 value = value.tolist()
+             elif isinstance(value, Path):
+                 value = value.as_posix()
+             return value
+
+         config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+         # Don't save "_ignore_files" or "_use_default_values"
+         config_dict.pop("_ignore_files", None)
+         config_dict.pop("_use_default_values", None)
+
+         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+         """
+         Save the configuration instance's parameters to a JSON file.
+
+         Args:
+             json_file_path (`str` or `os.PathLike`):
+                 Path to the JSON file to save a configuration instance's parameters.
+         """
+         with open(json_file_path, "w", encoding="utf-8") as writer:
+             writer.write(self.to_json_string())
+
+
+ def register_to_config(init):
+     r"""
+     Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+     automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
+     shouldn't be registered in the config, use the `ignore_for_config` class variable
+
+     Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+     """
+
+     @functools.wraps(init)
+     def inner_init(self, *args, **kwargs):
+         # Ignore private kwargs in the init.
+         init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+         config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+         if not isinstance(self, ConfigMixin):
+             raise RuntimeError(
+                 f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+                 "not inherit from `ConfigMixin`."
+             )
+
+         ignore = getattr(self, "ignore_for_config", [])
+         # Get positional arguments aligned with kwargs
+         new_kwargs = {}
+         signature = inspect.signature(init)
+         parameters = {
+             name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+         }
+         for arg, name in zip(args, parameters.keys()):
+             new_kwargs[name] = arg
+
+         # Then add all kwargs
+         new_kwargs.update(
+             {
+                 k: init_kwargs.get(k, default)
+                 for k, default in parameters.items()
+                 if k not in ignore and k not in new_kwargs
+             }
+         )
+
+         # Take note of the parameters that were not present in the loaded config
+         if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+             new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
+         new_kwargs = {**config_init_kwargs, **new_kwargs}
+         getattr(self, "register_to_config")(**new_kwargs)
+         init(self, *args, **init_kwargs)
+
+     return inner_init
+
+
+ def flax_register_to_config(cls):
+     original_init = cls.__init__
+
+     @functools.wraps(original_init)
+     def init(self, *args, **kwargs):
+         if not isinstance(self, ConfigMixin):
+             raise RuntimeError(
+                 f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+                 "not inherit from `ConfigMixin`."
+             )
+
+         # Ignore private kwargs in the init. Retrieve all passed attributes
+         init_kwargs = dict(kwargs.items())
+
+         # Retrieve default values
+         fields = dataclasses.fields(self)
+         default_kwargs = {}
+         for field in fields:
+             # ignore flax specific attributes
+             if field.name in self._flax_internal_args:
+                 continue
+             if type(field.default) == dataclasses._MISSING_TYPE:
+                 default_kwargs[field.name] = None
+             else:
+                 default_kwargs[field.name] = getattr(self, field.name)
+
+         # Make sure init_kwargs override default kwargs
+         new_kwargs = {**default_kwargs, **init_kwargs}
+         # dtype should be part of `init_kwargs`, but not `new_kwargs`
+         if "dtype" in new_kwargs:
+             new_kwargs.pop("dtype")
+
+         # Get positional arguments aligned with kwargs
+         for i, arg in enumerate(args):
+             name = fields[i].name
+             new_kwargs[name] = arg
+
+         # Take note of the parameters that were not present in the loaded config
+         if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+             new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
+         getattr(self, "register_to_config")(**new_kwargs)
+         original_init(self, *args, **kwargs)
+
+     cls.__init__ = init
+     return cls
+
+
+ class LegacyConfigMixin(ConfigMixin):
+     r"""
+     A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+     pipeline-specific classes (like `DiTTransformer2DModel`).
+     """
+
+     @classmethod
+     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+         # To prevent dependency import problem.
+         from .models.model_loading_utils import _fetch_remapped_cls_from_config
+
+         # resolve remapping
+         remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+         return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
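To make the `ConfigMixin` / `@register_to_config` interplay concrete, a minimal sketch (the `MyScheduler` class is hypothetical; only the base-class machinery shown above is real):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config


class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
        pass  # all arguments are captured into self.config by the decorator


scheduler = MyScheduler(num_train_timesteps=500)
print(scheduler.config.num_train_timesteps)  # 500
scheduler.save_config("./my_scheduler")  # writes scheduler_config.json
restored = MyScheduler.from_config(MyScheduler.load_config("./my_scheduler"))
```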
diffusers3/dependency_versions_check.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .dependency_versions_table import deps
+ from .utils.versions import require_version, require_version_core
+
+
+ # define which module versions we always want to check at run time
+ # (usually the ones defined in `install_requires` in setup.py)
+ #
+ # order specific notes:
+ # - tqdm must be checked before tokenizers
+
+ pkgs_to_check_at_runtime = "python requests filelock numpy".split()
+ for pkg in pkgs_to_check_at_runtime:
+     if pkg in deps:
+         require_version_core(deps[pkg])
+     else:
+         raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+ def dep_version_check(pkg, hint=None):
+     require_version(deps[pkg], hint)
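A short usage sketch: `dep_version_check` lets optional integrations validate against the same pinned table at call time (the hint string is illustrative):

```python
from diffusers.dependency_versions_check import dep_version_check

# Raises if the installed version does not satisfy the pin in the table.
dep_version_check("safetensors", hint="Try: pip install -U safetensors")
```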
diffusers3/dependency_versions_table.py ADDED
@@ -0,0 +1,46 @@
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.31.0",
7
+ "compel": "compel==0.1.8",
8
+ "datasets": "datasets",
9
+ "filelock": "filelock",
10
+ "flax": "flax>=0.4.1",
11
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
12
+ "huggingface-hub": "huggingface-hub>=0.23.2",
13
+ "requests-mock": "requests-mock==1.10.0",
14
+ "importlib_metadata": "importlib_metadata",
15
+ "invisible-watermark": "invisible-watermark>=0.2.0",
16
+ "isort": "isort>=5.5.4",
17
+ "jax": "jax>=0.4.1",
18
+ "jaxlib": "jaxlib>=0.4.1",
19
+ "Jinja2": "Jinja2",
20
+ "k-diffusion": "k-diffusion>=0.0.12",
21
+ "torchsde": "torchsde",
22
+ "note_seq": "note_seq",
23
+ "librosa": "librosa",
24
+ "numpy": "numpy",
25
+ "parameterized": "parameterized",
26
+ "peft": "peft>=0.6.0",
27
+ "protobuf": "protobuf>=3.20.3,<4",
28
+ "pytest": "pytest",
29
+ "pytest-timeout": "pytest-timeout",
30
+ "pytest-xdist": "pytest-xdist",
31
+ "python": "python>=3.8.0",
32
+ "ruff": "ruff==0.1.5",
33
+ "safetensors": "safetensors>=0.3.1",
34
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
35
+ "GitPython": "GitPython<3.1.19",
36
+ "scipy": "scipy",
37
+ "onnx": "onnx",
38
+ "regex": "regex!=2019.12.17",
39
+ "requests": "requests",
40
+ "tensorboard": "tensorboard",
41
+ "torch": "torch>=1.4",
42
+ "torchvision": "torchvision",
43
+ "transformers": "transformers>=4.41.2",
44
+ "urllib3": "urllib3<=2.0.0",
45
+ "black": "black",
46
+ }
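A small sketch of how this table is consumed: it maps a bare package name to its full pip requirement string, which the runtime checks in `dependency_versions_check.py` read:

```python
from diffusers3.dependency_versions_table import deps

print(deps["torch"])            # "torch>=1.4"
print(deps["huggingface-hub"])  # "huggingface-hub>=0.23.2"
```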
diffusers3/experimental/README.md ADDED
@@ -0,0 +1,5 @@
+# 🧨 Diffusers Experimental
+
+We are adding experimental code to support novel applications and usages of the Diffusers library.
+Currently, the following experiments are supported:
+* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
diffusers3/experimental/__init__.py ADDED
@@ -0,0 +1 @@
+from .rl import ValueGuidedRLPipeline
diffusers3/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
+from .value_guided_sampling import ValueGuidedRLPipeline
diffusers3/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,153 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import tqdm
+
+from ...models.unets.unet_1d import UNet1DModel
+from ...pipelines import DiffusionPipeline
+from ...utils.dummy_pt_objects import DDPMScheduler
+from ...utils.torch_utils import randn_tensor
+
+
+class ValueGuidedRLPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    Parameters:
+        value_function ([`UNet1DModel`]):
+            A specialized UNet for fine-tuning trajectories based on reward.
+        unet ([`UNet1DModel`]):
+            UNet architecture to denoise the encoded trajectories.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
+            application is [`DDPMScheduler`].
+        env ():
+            An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
+    """
+
+    def __init__(
+        self,
+        value_function: UNet1DModel,
+        unet: UNet1DModel,
+        scheduler: DDPMScheduler,
+        env,
+    ):
+        super().__init__()
+
+        self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)
+
+        self.data = env.get_dataset()
+        self.means = {}
+        for key in self.data.keys():
+            try:
+                self.means[key] = self.data[key].mean()
+            except:  # noqa: E722
+                pass
+        self.stds = {}
+        for key in self.data.keys():
+            try:
+                self.stds[key] = self.data[key].std()
+            except:  # noqa: E722
+                pass
+        self.state_dim = env.observation_space.shape[0]
+        self.action_dim = env.action_space.shape[0]
+
+    def normalize(self, x_in, key):
+        return (x_in - self.means[key]) / self.stds[key]
+
+    def de_normalize(self, x_in, key):
+        return x_in * self.stds[key] + self.means[key]
+
+    def to_torch(self, x_in):
+        if isinstance(x_in, dict):
+            return {k: self.to_torch(v) for k, v in x_in.items()}
+        elif torch.is_tensor(x_in):
+            return x_in.to(self.unet.device)
+        return torch.tensor(x_in, device=self.unet.device)
+
+    def reset_x0(self, x_in, cond, act_dim):
+        for key, val in cond.items():
+            x_in[:, key, act_dim:] = val.clone()
+        return x_in
+
+    def run_diffusion(self, x, conditions, n_guide_steps, scale):
+        batch_size = x.shape[0]
+        y = None
+        for i in tqdm.tqdm(self.scheduler.timesteps):
+            # create batch of timesteps to pass into model
+            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
+            for _ in range(n_guide_steps):
+                with torch.enable_grad():
+                    x.requires_grad_()
+
+                    # permute to match dimension for pre-trained models
+                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
+                    grad = torch.autograd.grad([y.sum()], [x])[0]
+
+                    posterior_variance = self.scheduler._get_variance(i)
+                    model_std = torch.exp(0.5 * posterior_variance)
+                    grad = model_std * grad
+
+                grad[timesteps < 2] = 0
+                x = x.detach()
+                x = x + scale * grad
+                x = self.reset_x0(x, conditions, self.action_dim)
+
+            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+
+            # TODO: verify deprecation of this kwarg
+            x = self.scheduler.step(prev_x, i, x)["prev_sample"]
+
+            # apply conditions to the trajectory (set the initial state)
+            x = self.reset_x0(x, conditions, self.action_dim)
+            x = self.to_torch(x)
+        return x, y
+
+    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
+        # normalize the observations and create batch dimension
+        obs = self.normalize(obs, "observations")
+        obs = obs[None].repeat(batch_size, axis=0)
+
+        conditions = {0: self.to_torch(obs)}
+        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
+
+        # generate initial noise and apply our conditions (to make the trajectories start at current state)
+        x1 = randn_tensor(shape, device=self.unet.device)
+        x = self.reset_x0(x1, conditions, self.action_dim)
+        x = self.to_torch(x)
+
+        # run the diffusion process
+        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
+
+        # sort output trajectories by value
+        sorted_idx = y.argsort(0, descending=True).squeeze()
+        sorted_values = x[sorted_idx]
+        actions = sorted_values[:, :, : self.action_dim]
+        actions = actions.detach().cpu().numpy()
+        denorm_actions = self.de_normalize(actions, key="actions")
+
+        # select the action with the highest value
+        if y is not None:
+            selected_index = 0
+        else:
+            # if we didn't run value guiding, select a random action
+            selected_index = np.random.randint(0, batch_size)
+
+        denorm_actions = denorm_actions[selected_index, 0]
+        return denorm_actions
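A usage sketch for this pipeline, assuming a D4RL-style Hopper environment (the pipeline calls `env.get_dataset()` at construction, so a d4rl-compatible gym install is assumed; the checkpoint id below is the one used in upstream diffusers examples, not guaranteed by this repo):

```python
import gym

from diffusers3.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")  # requires d4rl for get_dataset()

pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",  # assumed checkpoint id
    env=env,
)

obs = env.reset()
for _ in range(100):
    # each call denoises a batch of trajectories, guides them uphill on the
    # value function, and returns the first action of the best trajectory
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, info = env.step(action)
    if done:
        break
```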
diffusers3/image_processor.py ADDED
@@ -0,0 +1,1103 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from PIL import Image, ImageFilter, ImageOps
+
+from .configuration_utils import ConfigMixin, register_to_config
+from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
+
+
+PipelineImageInput = Union[
+    PIL.Image.Image,
+    np.ndarray,
+    torch.Tensor,
+    List[PIL.Image.Image],
+    List[np.ndarray],
+    List[torch.Tensor],
+]
+
+PipelineDepthInput = PipelineImageInput
+
+
+def is_valid_image(image):
+    return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
+
+
+def is_valid_image_imagelist(images):
+    # check if the image input is one of the supported formats for image and image list:
+    # it can be one of the below 3:
+    # (1) a 4d pytorch tensor or numpy array,
+    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
+    # (3) a list of valid images
+    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
+        return True
+    elif is_valid_image(images):
+        return True
+    elif isinstance(images, list):
+        return all(is_valid_image(image) for image in images)
+    return False
+
+
+class VaeImageProcessor(ConfigMixin):
+    """
+    Image processor for VAE.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from the [`image_processor.VaeImageProcessor.preprocess`] method.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `False`):
+            Whether to binarize the image to 0/1.
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to grayscale format.
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        vae_latent_channels: int = 4,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_rgb: bool = False,
+        do_convert_grayscale: bool = False,
+    ):
+        super().__init__()
+        if do_convert_rgb and do_convert_grayscale:
+            raise ValueError(
+                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`."
+                " If you intended to convert the image into RGB format, please set `do_convert_grayscale = False`."
+                " If you intended to convert the image into grayscale format, please set `do_convert_rgb = False`."
+            )
+
+    @staticmethod
+    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+        else:
+            pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
+
+    @staticmethod
+    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+        """
+        Convert a PIL image or a list of PIL images to NumPy arrays.
+        """
+        if not isinstance(images, list):
+            images = [images]
+        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
+        images = np.stack(images, axis=0)
+
+        return images
+
+    @staticmethod
+    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
+        """
+        Convert a NumPy image to a PyTorch tensor.
+        """
+        if images.ndim == 3:
+            images = images[..., None]
+
+        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
+        return images
+
+    @staticmethod
+    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
+        """
+        Convert a PyTorch tensor to a NumPy image.
+        """
+        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+        return images
+
+    @staticmethod
+    def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
+        """
+        Normalize an image array to [-1,1].
+        """
+        return 2.0 * images - 1.0
+
+    @staticmethod
+    def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
+        """
+        Denormalize an image array to [0,1].
+        """
+        return (images / 2 + 0.5).clamp(0, 1)
+
+    @staticmethod
+    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts a PIL image to RGB format.
+        """
+        image = image.convert("RGB")
+
+        return image
+
+    @staticmethod
+    def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts a PIL image to grayscale format.
+        """
+        image = image.convert("L")
+
+        return image
+
+    @staticmethod
+    def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
+        """
+        Applies Gaussian blur to an image.
+        """
+        image = image.filter(ImageFilter.GaussianBlur(blur_factor))
+
+        return image
+
+    @staticmethod
+    def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
+        """
+        Finds a rectangular region that contains all masked areas in an image, and expands the region to match the
+        aspect ratio of the original image; for example, if the user drew a mask in a 128x32 region, and the
+        dimensions for processing are 512x512, the region will be expanded to 128x128.
+
+        Args:
+            mask_image (PIL.Image.Image): Mask image.
+            width (int): Width of the image to be processed.
+            height (int): Height of the image to be processed.
+            pad (int, optional): Padding to be added to the crop region. Defaults to 0.
+
+        Returns:
+            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked areas in an image and
+            matches the original aspect ratio.
+        """
+
+        mask_image = mask_image.convert("L")
+        mask = np.array(mask_image)
+
+        # 1. find a rectangular region that contains all masked areas in an image
+        h, w = mask.shape
+        crop_left = 0
+        for i in range(w):
+            if not (mask[:, i] == 0).all():
+                break
+            crop_left += 1
+
+        crop_right = 0
+        for i in reversed(range(w)):
+            if not (mask[:, i] == 0).all():
+                break
+            crop_right += 1
+
+        crop_top = 0
+        for i in range(h):
+            if not (mask[i] == 0).all():
+                break
+            crop_top += 1
+
+        crop_bottom = 0
+        for i in reversed(range(h)):
+            if not (mask[i] == 0).all():
+                break
+            crop_bottom += 1
+
+        # 2. add padding to the crop region
+        x1, y1, x2, y2 = (
+            int(max(crop_left - pad, 0)),
+            int(max(crop_top - pad, 0)),
+            int(min(w - crop_right + pad, w)),
+            int(min(h - crop_bottom + pad, h)),
+        )
+
+        # 3. expands crop region to match the aspect ratio of the image to be processed
+        ratio_crop_region = (x2 - x1) / (y2 - y1)
+        ratio_processing = width / height
+
+        if ratio_crop_region > ratio_processing:
+            desired_height = (x2 - x1) / ratio_processing
+            desired_height_diff = int(desired_height - (y2 - y1))
+            y1 -= desired_height_diff // 2
+            y2 += desired_height_diff - desired_height_diff // 2
+            if y2 >= mask_image.height:
+                diff = y2 - mask_image.height
+                y2 -= diff
+                y1 -= diff
+            if y1 < 0:
+                y2 -= y1
+                y1 -= y1
+            if y2 >= mask_image.height:
+                y2 = mask_image.height
+        else:
+            desired_width = (y2 - y1) * ratio_processing
+            desired_width_diff = int(desired_width - (x2 - x1))
+            x1 -= desired_width_diff // 2
+            x2 += desired_width_diff - desired_width_diff // 2
+            if x2 >= mask_image.width:
+                diff = x2 - mask_image.width
+                x2 -= diff
+                x1 -= diff
+            if x1 < 0:
+                x2 -= x1
+                x1 -= x1
+            if x2 >= mask_image.width:
+                x2 = mask_image.width
+
+        return x1, y1, x2, y2
+
+    def _resize_and_fill(
+        self,
+        image: PIL.Image.Image,
+        width: int,
+        height: int,
+    ) -> PIL.Image.Image:
+        """
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, filling the empty areas with data from the image.
+
+        Args:
+            image: The image to resize.
+            width: The width to resize the image to.
+            height: The height to resize the image to.
+        """
+
+        ratio = width / height
+        src_ratio = image.width / image.height
+
+        src_w = width if ratio < src_ratio else image.width * height // image.height
+        src_h = height if ratio >= src_ratio else image.height * width // image.width
+
+        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+        if ratio < src_ratio:
+            fill_height = height // 2 - src_h // 2
+            if fill_height > 0:
+                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+                res.paste(
+                    resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
+                    box=(0, fill_height + src_h),
+                )
+        elif ratio > src_ratio:
+            fill_width = width // 2 - src_w // 2
+            if fill_width > 0:
+                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+                res.paste(
+                    resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
+                    box=(fill_width + src_w, 0),
+                )
+
+        return res
+
+    def _resize_and_crop(
+        self,
+        image: PIL.Image.Image,
+        width: int,
+        height: int,
+    ) -> PIL.Image.Image:
+        """
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, cropping the excess.
+
+        Args:
+            image: The image to resize.
+            width: The width to resize the image to.
+            height: The height to resize the image to.
+        """
+        ratio = width / height
+        src_ratio = image.width / image.height
+
+        src_w = width if ratio > src_ratio else image.width * height // image.height
+        src_h = height if ratio <= src_ratio else image.height * width // image.width
+
+        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+        return res
+
+    def resize(
+        self,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+        height: int,
+        width: int,
+        resize_mode: str = "default",  # "default", "fill", "crop"
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
+        """
+        Resize image.
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, can be a PIL image, numpy array or pytorch tensor.
+            height (`int`):
+                The height to resize to.
+            width (`int`):
+                The width to resize to.
+            resize_mode (`str`, *optional*, defaults to `default`):
+                The resize mode to use, can be one of `default`, `fill`, or `crop`. If `default`, will resize the
+                image to fit within the specified width and height, and it may not maintain the original aspect ratio.
+                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect
+                ratio, and then center the image within the dimensions, filling the empty areas with data from the
+                image. If `crop`, will resize the image to fit within the specified width and height, maintaining the
+                aspect ratio, and then center the image within the dimensions, cropping the excess. Note that
+                resize_mode `fill` and `crop` are only supported for PIL image input.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
+                The resized image.
+        """
+        if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
+            raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
+        if isinstance(image, PIL.Image.Image):
+            if resize_mode == "default":
+                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+            elif resize_mode == "fill":
+                image = self._resize_and_fill(image, width, height)
+            elif resize_mode == "crop":
+                image = self._resize_and_crop(image, width, height)
+            else:
+                raise ValueError(f"resize_mode {resize_mode} is not supported")
+
+        elif isinstance(image, torch.Tensor):
+            image = torch.nn.functional.interpolate(
+                image,
+                size=(height, width),
+            )
+        elif isinstance(image, np.ndarray):
+            image = self.numpy_to_pt(image)
+            image = torch.nn.functional.interpolate(
+                image,
+                size=(height, width),
+            )
+            image = self.pt_to_numpy(image)
+        return image
+
+    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Create a mask.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The image input, should be a PIL image.
+
+        Returns:
+            `PIL.Image.Image`:
+                The binarized image. Values less than 0.5 are set to 0, values greater than or equal to 0.5 are set
+                to 1.
+        """
+        image[image < 0.5] = 0
+        image[image >= 0.5] = 1
+
+        return image
+
+    def get_default_height_width(
+        self,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> Tuple[int, int]:
+        """
+        Returns the height and width, downscaled to the next integer multiple of `vae_scale_factor`.
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it should
+                have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch tensor,
+                it should have shape `[batch, channel, height, width]`.
+            height (`int`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, the height of the `image` input is used.
+            width (`int`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, the width of the `image` input is used.
+        """
+
+        if height is None:
+            if isinstance(image, PIL.Image.Image):
+                height = image.height
+            elif isinstance(image, torch.Tensor):
+                height = image.shape[2]
+            else:
+                height = image.shape[1]
+
+        if width is None:
+            if isinstance(image, PIL.Image.Image):
+                width = image.width
+            elif isinstance(image, torch.Tensor):
+                width = image.shape[3]
+            else:
+                width = image.shape[2]
+
+        width, height = (
+            x - x % self.config.vae_scale_factor for x in (width, height)
+        )  # resize to integer multiple of vae_scale_factor
+
+        return height, width
+
+    def preprocess(
+        self,
+        image: PipelineImageInput,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        resize_mode: str = "default",  # "default", "fill", "crop"
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> torch.Tensor:
+        """
+        Preprocess the image input.
+
+        Args:
+            image (`PipelineImageInput`):
+                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; lists of the
+                supported formats are also accepted.
+            height (`int`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, `get_default_height_width()` is used to get the
+                default height.
+            width (`int`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, `get_default_height_width()` is used to get the
+                default width.
+            resize_mode (`str`, *optional*, defaults to `default`):
+                The resize mode, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to
+                fit within the specified width and height, and it may not maintain the original aspect ratio. If
+                `fill`, will resize the image to fit within the specified width and height, maintaining the aspect
+                ratio, and then center the image within the dimensions, filling the empty areas with data from the
+                image. If `crop`, will resize the image to fit within the specified width and height, maintaining the
+                aspect ratio, and then center the image within the dimensions, cropping the excess. Note that
+                resize_mode `fill` and `crop` are only supported for PIL image input.
+            crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
+                The crop coordinates applied to each image in the batch. If `None`, the images are not cropped.
+        """
+        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
+
+        # Expand the missing dimension for a 3-dimensional pytorch tensor or numpy array that represents a grayscale image
+        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
+            if isinstance(image, torch.Tensor):
+                # if image is a pytorch tensor it could have 2 possible shapes:
+                #   1. batch x height x width: we should insert the channel dimension at position 1
+                #   2. channel x height x width: we should insert the batch dimension at position 0
+                # however, since both the channel and batch dimensions have size 1, inserting at position 1 works for both,
+                # so for simplicity we insert a dimension of size 1 at position 1 in both cases
+                image = image.unsqueeze(1)
+            else:
+                # if it is a numpy array, it could have 2 possible shapes:
+                #   1. batch x height x width: insert the channel dimension at the last position
+                #   2. height x width x channel: insert the batch dimension at the first position
+                if image.shape[-1] == 1:
+                    image = np.expand_dims(image, axis=0)
+                else:
+                    image = np.expand_dims(image, axis=-1)
+
+        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d np.ndarray is deprecated. "
+                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
+                FutureWarning,
+            )
+            image = np.concatenate(image, axis=0)
+        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d torch.Tensor is deprecated. "
+                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
+                FutureWarning,
+            )
+            image = torch.cat(image, axis=0)
+
+        if not is_valid_image_imagelist(image):
+            raise ValueError(
+                f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
+            )
+        if not isinstance(image, list):
+            image = [image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            if crops_coords is not None:
+                image = [i.crop(crops_coords) for i in image]
+            if self.config.do_resize:
+                height, width = self.get_default_height_width(image[0], height, width)
+                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
+            if self.config.do_convert_rgb:
+                image = [self.convert_to_rgb(i) for i in image]
+            elif self.config.do_convert_grayscale:
+                image = [self.convert_to_grayscale(i) for i in image]
+            image = self.pil_to_numpy(image)  # to np
+            image = self.numpy_to_pt(image)  # to pt
+
+        elif isinstance(image[0], np.ndarray):
+            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+
+            image = self.numpy_to_pt(image)
+
+            height, width = self.get_default_height_width(image, height, width)
+            if self.config.do_resize:
+                image = self.resize(image, height, width)
+
+        elif isinstance(image[0], torch.Tensor):
+            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+
+            if self.config.do_convert_grayscale and image.ndim == 3:
+                image = image.unsqueeze(1)
+
+            channel = image.shape[1]
+            # no preprocessing is needed if the image is already in latent space
+            if channel == self.config.vae_latent_channels:
+                return image
+
+            height, width = self.get_default_height_width(image, height, width)
+            if self.config.do_resize:
+                image = self.resize(image, height, width)
+
+        # expected range [0,1], normalize to [-1,1]
+        do_normalize = self.config.do_normalize
+        if do_normalize and image.min() < 0:
+            warnings.warn(
+                "Passing `image` as a torch tensor with value range in [-1,1] is deprecated. The expected value range for an image tensor is [0,1] "
+                f"when passed as a pytorch tensor or NumPy array. You passed `image` with value range [{image.min()},{image.max()}]",
+                FutureWarning,
+            )
+            do_normalize = False
+        if do_normalize:
+            image = self.normalize(image)
+
+        if self.config.do_binarize:
+            image = self.binarize(image)
+
+        return image
+
+    def postprocess(
+        self,
+        image: torch.Tensor,
+        output_type: str = "pil",
+        do_denormalize: Optional[List[bool]] = None,
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.Tensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
+                The postprocessed image.
+        """
+        if not isinstance(image, torch.Tensor):
+            raise ValueError(
+                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+            )
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            return image
+
+        if do_denormalize is None:
+            do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+        image = torch.stack(
+            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+        )
+
+        if output_type == "pt":
+            return image
+
+        image = self.pt_to_numpy(image)
+
+        if output_type == "np":
+            return image
+
+        if output_type == "pil":
+            return self.numpy_to_pil(image)
+
+    def apply_overlay(
+        self,
+        mask: PIL.Image.Image,
+        init_image: PIL.Image.Image,
+        image: PIL.Image.Image,
+        crop_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> PIL.Image.Image:
+        """
+        Overlay the inpainting output over the original image.
+        """
+
+        width, height = image.width, image.height
+
+        init_image = self.resize(init_image, width=width, height=height)
+        mask = self.resize(mask, width=width, height=height)
+
+        init_image_masked = PIL.Image.new("RGBa", (width, height))
+        init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
+        init_image_masked = init_image_masked.convert("RGBA")
+
+        if crop_coords is not None:
+            x, y, x2, y2 = crop_coords
+            w = x2 - x
+            h = y2 - y
+            base_image = PIL.Image.new("RGBA", (width, height))
+            image = self.resize(image, height=h, width=w, resize_mode="crop")
+            base_image.paste(image, (x, y))
+            image = base_image.convert("RGB")
+
+        image = image.convert("RGBA")
+        image.alpha_composite(init_image_masked)
+        image = image.convert("RGB")
+
+        return image
+
+
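A hedged round-trip sketch (not part of the diff) for `VaeImageProcessor` as defined above: preprocess a PIL image into a [-1, 1] tensor whose sides are multiples of `vae_scale_factor`, then postprocess a model output back to PIL:

```python
import PIL.Image
import torch

from diffusers3.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

image = PIL.Image.new("RGB", (513, 769))  # deliberately not a multiple of 8
pixels = processor.preprocess(image)      # resized to 512x768 -> torch.Size([1, 3, 768, 512])
print(pixels.shape, pixels.min().item(), pixels.max().item())  # values in [-1, 1]

decoded = torch.rand(1, 3, 768, 512) * 2 - 1  # stand-in for a decoded VAE output
pil_out = processor.postprocess(decoded, output_type="pil")[0]
```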
+class VaeImageProcessorLDM3D(VaeImageProcessor):
+    """
+    Image processor for VAE LDM3D.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1].
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+    ):
+        super().__init__()
+
+    @staticmethod
+    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
+        """
+        Convert a NumPy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+        else:
+            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
+
+        return pil_images
+
+    @staticmethod
+    def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+        """
+        Convert a PIL image or a list of PIL images to NumPy arrays.
+        """
+        if not isinstance(images, list):
+            images = [images]
+
+        images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
+        images = np.stack(images, axis=0)
+        return images
+
+    @staticmethod
+    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
+        """
+        Args:
+            image: RGB-like depth image
+
+        Returns: depth map
+        """
+        return image[:, :, 1] * 2**8 + image[:, :, 2]
+
+    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
+        """
+        Convert a NumPy depth image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images_depth = images[:, :, :, 3:]
+        if images.shape[-1] == 6:
+            images_depth = (images_depth * 255).round().astype("uint8")
+            pil_images = [
+                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
+            ]
+        elif images.shape[-1] == 4:
+            images_depth = (images_depth * 65535.0).astype(np.uint16)
+            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
+        else:
+            raise Exception("Not supported")
+
+        return pil_images
+
+    def postprocess(
+        self,
+        image: torch.Tensor,
+        output_type: str = "pil",
+        do_denormalize: Optional[List[bool]] = None,
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.Tensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
+                The postprocessed image.
+        """
+        if not isinstance(image, torch.Tensor):
+            raise ValueError(
+                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+            )
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if do_denormalize is None:
+            do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+        image = torch.stack(
+            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+        )
+
+        image = self.pt_to_numpy(image)
+
+        if output_type == "np":
+            if image.shape[-1] == 6:
+                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
+            else:
+                image_depth = image[:, :, :, 3:]
+            return image[:, :, :, :3], image_depth
+
+        if output_type == "pil":
+            return self.numpy_to_pil(image), self.numpy_to_depth(image)
+        else:
+            raise Exception(f"This type {output_type} is not supported")
+
+    def preprocess(
+        self,
+        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
+        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        target_res: Optional[int] = None,
+    ) -> torch.Tensor:
+        """
+        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        """
+        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
+
+        # Expand the missing dimension for a 3-dimensional pytorch tensor or numpy array that represents a grayscale image
+        if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
+            raise Exception("This is not yet supported")
+
+        if isinstance(rgb, supported_formats):
+            rgb = [rgb]
+            depth = [depth]
+        elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
+            raise ValueError(
+                f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
+            )
+
+        if isinstance(rgb[0], PIL.Image.Image):
+            if self.config.do_convert_rgb:
+                raise Exception("This is not yet supported")
+                # rgb = [self.convert_to_rgb(i) for i in rgb]
+                # depth = [self.convert_to_depth(i) for i in depth]  # TODO: define convert_to_depth
+            if self.config.do_resize or target_res:
+                height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
+                rgb = [self.resize(i, height, width) for i in rgb]
+                depth = [self.resize(i, height, width) for i in depth]
+            rgb = self.pil_to_numpy(rgb)  # to np
+            rgb = self.numpy_to_pt(rgb)  # to pt
+
+            depth = self.depth_pil_to_numpy(depth)  # to np
+            depth = self.numpy_to_pt(depth)  # to pt
+
+        elif isinstance(rgb[0], np.ndarray):
+            rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
+            rgb = self.numpy_to_pt(rgb)
+            height, width = self.get_default_height_width(rgb, height, width)
+            if self.config.do_resize:
+                rgb = self.resize(rgb, height, width)
+
+            depth = np.concatenate(depth, axis=0) if depth[0].ndim == 4 else np.stack(depth, axis=0)
+            depth = self.numpy_to_pt(depth)
+            height, width = self.get_default_height_width(depth, height, width)
+            if self.config.do_resize:
+                depth = self.resize(depth, height, width)
+
+        elif isinstance(rgb[0], torch.Tensor):
+            raise Exception("This is not yet supported")
+            # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
+
+            # if self.config.do_convert_grayscale and rgb.ndim == 3:
+            #     rgb = rgb.unsqueeze(1)
+
+            # channel = rgb.shape[1]
+
+            # height, width = self.get_default_height_width(rgb, height, width)
+            # if self.config.do_resize:
+            #     rgb = self.resize(rgb, height, width)
+
+            # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
+
+            # if self.config.do_convert_grayscale and depth.ndim == 3:
+            #     depth = depth.unsqueeze(1)
+
+            # channel = depth.shape[1]
+            # # don't need any preprocess if the image is latents
+            # if depth == 4:
+            #     return rgb, depth
+
+            # height, width = self.get_default_height_width(depth, height, width)
+            # if self.config.do_resize:
+            #     depth = self.resize(depth, height, width)
+
+        # expected range [0,1], normalize to [-1,1]
+        do_normalize = self.config.do_normalize
+        if rgb.min() < 0 and do_normalize:
+            warnings.warn(
+                "Passing `image` as a torch tensor with value range in [-1,1] is deprecated. The expected value range for an image tensor is [0,1] "
+                f"when passed as a pytorch tensor or NumPy array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
+                FutureWarning,
+            )
+            do_normalize = False
+
+        if do_normalize:
+            rgb = self.normalize(rgb)
+            depth = self.normalize(depth)
+
+        if self.config.do_binarize:
+            rgb = self.binarize(rgb)
+            depth = self.binarize(depth)
+
+        return rgb, depth
+
+
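A small numeric sketch (not part of the diff) of the RGB-like depth encoding that `rgblike_to_depthmap` inverts: a 16-bit depth value is split across the green (high byte) and blue (low byte) channels, so the depth is recovered as `G * 2**8 + B`:

```python
import numpy as np

depth_value = 0x1234                   # 4660 in decimal
g, b = depth_value >> 8, depth_value & 0xFF

pixel = np.zeros((1, 1, 3), dtype=np.uint16)
pixel[0, 0, 1] = g                     # green channel: high byte
pixel[0, 0, 2] = b                     # blue channel: low byte

recovered = pixel[:, :, 1] * 2**8 + pixel[:, :, 2]
assert recovered[0, 0] == depth_value  # 0x12 * 256 + 0x34 == 0x1234
```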
+class IPAdapterMaskProcessor(VaeImageProcessor):
+    """
+    Image processor for IP Adapter image masks.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `True`):
+            Whether to binarize the image to 0/1.
+        do_convert_grayscale (`bool`, *optional*, defaults to `True`):
+            Whether to convert the images to grayscale format.
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = False,
+        do_binarize: bool = True,
+        do_convert_grayscale: bool = True,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            resample=resample,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+
+    @staticmethod
+    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
+        """
+        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
+        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+
+        Args:
+            mask (`torch.Tensor`):
+                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
+            batch_size (`int`):
+                The batch size.
+            num_queries (`int`):
+                The number of queries.
+            value_embed_dim (`int`):
+                The dimensionality of the value embeddings.
+
+        Returns:
+            `torch.Tensor`:
+                The downsampled mask tensor.
+        """
+        o_h = mask.shape[1]
+        o_w = mask.shape[2]
+        ratio = o_w / o_h
+        mask_h = int(math.sqrt(num_queries / ratio))
+        mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
+        mask_w = num_queries // mask_h
+
+        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
+
+        # Repeat batch_size times
+        if mask_downsample.shape[0] < batch_size:
+            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
+
+        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
+
+        downsampled_area = mask_h * mask_w
+        # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
+        # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
+        if downsampled_area < num_queries:
+            warnings.warn(
+                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
+                "Please update your masks or adjust the output size for optimal performance.",
+                UserWarning,
+            )
+            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
+        # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
+        if downsampled_area > num_queries:
+            warnings.warn(
+                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
+                "Please update your masks or adjust the output size for optimal performance.",
+                UserWarning,
+            )
+            mask_downsample = mask_downsample[:, :num_queries]
+
+        # Repeat last dimension to match SDPA output shape
+        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
+            1, 1, value_embed_dim
+        )
+
+        return mask_downsample
+
+
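A hedged shape sketch (not part of the diff) for `IPAdapterMaskProcessor.downsample`: the mask is resized so its flattened length matches the attention query count, then broadcast along the value-embedding dimension; the `num_queries` and `value_embed_dim` values below are illustrative:

```python
import torch

from diffusers3.image_processor import IPAdapterMaskProcessor

mask = torch.ones(1, 64, 96)  # (batch, height, width) binary mask, 2:3 aspect ratio
out = IPAdapterMaskProcessor.downsample(mask, batch_size=2, num_queries=1536, value_embed_dim=64)
print(out.shape)  # torch.Size([2, 1536, 64]) -- 1536 == 32 * 48, matching the mask's aspect ratio
```

When the mask's aspect ratio does not match the query grid, the flattened mask is zero-padded or truncated to `num_queries` and a warning is issued, as the code above shows.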
+class PixArtImageProcessor(VaeImageProcessor):
+    """
+    Image processor for PixArt image resize and crop.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from the [`image_processor.VaeImageProcessor.preprocess`] method.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `False`):
+            Whether to binarize the image to 0/1.
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to grayscale format.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            resample=resample,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+
+    @staticmethod
+    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+        """Returns binned height and width."""
+        ar = float(height / width)
+        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+        default_hw = ratios[closest_ratio]
+        return int(default_hw[0]), int(default_hw[1])
+
+    @staticmethod
+    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        orig_height, orig_width = samples.shape[2], samples.shape[3]
+
+        # Check if resizing is needed
+        if orig_height != new_height or orig_width != new_width:
+            ratio = max(new_height / orig_height, new_width / orig_width)
+            resized_width = int(orig_width * ratio)
+            resized_height = int(orig_height * ratio)
+
+            # Resize
+            samples = F.interpolate(
+                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+            )
+
+            # Center Crop
+            start_x = (resized_width - new_width) // 2
+            end_x = start_x + new_width
+            start_y = (resized_height - new_height) // 2
+            end_y = start_y + new_height
+            samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+        return samples
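A hedged sketch (not part of the diff) of `PixArtImageProcessor.classify_height_width_bin`: it snaps an arbitrary resolution to the closest entry in an aspect-ratio bin table. The two-entry `ratios` dict here is illustrative; real pipelines ship much larger tables:

```python
from diffusers3.image_processor import PixArtImageProcessor

ratios = {"1.0": (1024, 1024), "0.5": (704, 1408)}  # ratio -> (height, width)
h, w = PixArtImageProcessor.classify_height_width_bin(1000, 1100, ratios)
print(h, w)  # 1024 1024 -- 1000/1100 ≈ 0.91 is closest to ratio 1.0
```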
diffusers3/loaders/__init__.py ADDED
@@ -0,0 +1,100 @@
+from typing import TYPE_CHECKING
+
+from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
+from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available
+
+
+def text_encoder_lora_state_dict(text_encoder):
+    deprecate(
+        "text_encoder_load_state_dict in `models`",
+        "0.27.0",
+        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
+    )
+    state_dict = {}
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
+if is_transformers_available():
+
+    def text_encoder_attn_modules(text_encoder):
+        deprecate(
+            "text_encoder_attn_modules in `models`",
+            "0.27.0",
+            "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
+        )
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+        else:
+            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+        return attn_modules
+
+
+_import_structure = {}
+
+if is_torch_available():
+    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
+
+    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
+    _import_structure["utils"] = ["AttnProcsLayers"]
+    if is_transformers_available():
+        _import_structure["single_file"] = ["FromSingleFileMixin"]
+        _import_structure["lora_pipeline"] = [
+            "AmusedLoraLoaderMixin",
+            "StableDiffusionLoraLoaderMixin",
+            "SD3LoraLoaderMixin",
+            "StableDiffusionXLLoraLoaderMixin",
+            "LoraLoaderMixin",
+            "FluxLoraLoaderMixin",
+        ]
+        _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
+        _import_structure["ip_adapter"] = ["IPAdapterMixin"]
+
+_import_structure["peft"] = ["PeftAdapterMixin"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    if is_torch_available():
+        from .single_file_model import FromOriginalModelMixin
+        from .unet import UNet2DConditionLoadersMixin
+        from .utils import AttnProcsLayers
+
+        if is_transformers_available():
+            from .ip_adapter import IPAdapterMixin
+            from .lora_pipeline import (
+                AmusedLoraLoaderMixin,
+                FluxLoraLoaderMixin,
+                LoraLoaderMixin,
+                SD3LoraLoaderMixin,
+                StableDiffusionLoraLoaderMixin,
+                StableDiffusionXLLoraLoaderMixin,
+            )
+            from .single_file import FromSingleFileMixin
+            from .textual_inversion import TextualInversionLoaderMixin
+
+    from .peft import PeftAdapterMixin
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
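A hedged sketch (not part of the diff) of the lazy-import behavior set up above: the package replaces itself with a `_LazyModule`, so submodules listed in `_import_structure` are imported only on first attribute access:

```python
# Importing the package is cheap; heavy submodules are not loaded yet.
import diffusers3.loaders as loaders

# First attribute access triggers the real import of .ip_adapter (assuming
# torch and transformers are installed, so the entry exists in the structure).
mixin = loaders.IPAdapterMixin
print(mixin.__name__)  # "IPAdapterMixin"
```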
diffusers3/loaders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.91 kB)
diffusers3/loaders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.93 kB)
diffusers3/loaders/__pycache__/ip_adapter.cpython-310.pyc ADDED
Binary file (11.6 kB)
diffusers3/loaders/__pycache__/ip_adapter.cpython-38.pyc ADDED
Binary file (11.5 kB)
diffusers3/loaders/__pycache__/lora_base.cpython-310.pyc ADDED
Binary file (22.7 kB)
diffusers3/loaders/__pycache__/lora_base.cpython-38.pyc ADDED
Binary file (22.8 kB)
diffusers3/loaders/__pycache__/lora_conversion_utils.cpython-310.pyc ADDED
Binary file (14.9 kB)
diffusers3/loaders/__pycache__/lora_conversion_utils.cpython-38.pyc ADDED
Binary file (15.3 kB)
diffusers3/loaders/__pycache__/lora_pipeline.cpython-310.pyc ADDED
Binary file (53.1 kB)
diffusers3/loaders/__pycache__/lora_pipeline.cpython-38.pyc ADDED
Binary file (57.8 kB)
diffusers3/loaders/__pycache__/peft.cpython-310.pyc ADDED
Binary file (12.8 kB)
diffusers3/loaders/__pycache__/peft.cpython-38.pyc ADDED
Binary file (12.9 kB)
diffusers3/loaders/__pycache__/single_file.cpython-310.pyc ADDED
Binary file (14.1 kB)
diffusers3/loaders/__pycache__/single_file.cpython-38.pyc ADDED
Binary file (14.2 kB)
diffusers3/loaders/__pycache__/single_file_model.cpython-310.pyc ADDED
Binary file (9.82 kB)
diffusers3/loaders/__pycache__/single_file_model.cpython-38.pyc ADDED
Binary file (9.82 kB)
diffusers3/loaders/__pycache__/single_file_utils.cpython-310.pyc ADDED
Binary file (53.4 kB)
diffusers3/loaders/__pycache__/single_file_utils.cpython-38.pyc ADDED
Binary file (54.7 kB)