ssoxye committed
Commit 61345be · 1 Parent(s): da21aff

Fix diffusers3 import + lazy assets

Files changed (1):
  app.py (+53 -19)
app.py CHANGED
@@ -1,4 +1,16 @@
 import os
+import sys
+
+# ---------------------------------------------------------
+# 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
+# ---------------------------------------------------------
+ROOT = os.path.dirname(os.path.abspath(__file__))
+if ROOT not in sys.path:
+    sys.path.insert(0, ROOT)
+
+print("[BOOT] ROOT =", ROOT, flush=True)
+print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
+
 import tempfile
 from dataclasses import dataclass
 from functools import lru_cache
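Note: prepending ROOT ensures the local diffusers3/ and preprocess/ folders shadow any pip-installed package of the same name. A quick way to verify the shadowing (hypothetical snippet, not part of the commit):

    import importlib.util

    spec = importlib.util.find_spec("diffusers3")
    # Expect a path under ROOT, not under site-packages
    print(spec.origin if spec else "diffusers3 not found")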
@@ -14,6 +26,11 @@ from PIL import Image, ImageOps
 from huggingface_hub import hf_hub_download

 from diffusers import UniPCMultistepScheduler
+
+# Show where diffusers3 is imported from (helps diagnose import collisions on Spaces)
+import diffusers3
+print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
+
 from diffusers3.models.controlnet import ControlNetModel
 from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
     StableDiffusionXLControlNetImg2ImgPipeline,
@@ -30,7 +47,7 @@ from preprocess.simple_extractor import run as run_simple_extractor
 BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
 CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"

-# your assets dataset repo (the one created in step 1)
+# your assets dataset repo (stores the weights)
 ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
 ASSETS_REPO_TYPE = "dataset"  # uploaded as a dataset repo
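Because the repo id goes through os.getenv, a fork can point at its own weights repo by setting a Space variable rather than editing code; for example (hypothetical value):

    import os

    # e.g. set ASSETS_REPO=yourname/your_assets in the Space's Variables panel
    print(os.getenv("ASSETS_REPO", "soye/VISTA_assets"))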
 
@@ -50,14 +67,29 @@ def asset_path(relpath: str) -> str:
     )


-# the image encoder needs a "folder path", so
-# download both model and config and point at the same folder.
-_IMAGE_ENCODER_WEIGHT = asset_path("image_encoder/model.safetensors")
-_IMAGE_ENCODER_CONFIG = asset_path("image_encoder/config.json")
-IMAGE_ENCODER_DIR = os.path.dirname(_IMAGE_ENCODER_WEIGHT)
-
-IP_CKPT = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
-SCHP_CKPT = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
+@lru_cache(maxsize=1)
+def get_assets():
+    """
+    Lazily downloads required assets on first use.
+
+    Returns:
+        (image_encoder_dir, ip_ckpt_path, schp_ckpt_path)
+    """
+    print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)
+
+    # Image encoder folder is needed by IPAdapterXL
+    image_encoder_weight = asset_path("image_encoder/model.safetensors")
+    _ = asset_path("image_encoder/config.json")  # ensure config exists locally
+    image_encoder_dir = os.path.dirname(image_encoder_weight)
+
+    ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
+    schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
+
+    print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True)
+    print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True)
+    print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True)
+    return image_encoder_dir, ip_ckpt, schp_ckpt

 DEFAULT_STEPS = 40
 DEBUG_SAVE = False
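The asset_path helper itself sits outside the visible hunks (only its closing parenthesis shows at the top of this one); given the hf_hub_download import, a plausible sketch of what it wraps, with repo_id/repo_type assumed to come from the module constants:

    from huggingface_hub import hf_hub_download

    def asset_path(relpath: str) -> str:
        # hf_hub_download caches under HF_HOME, so repeated calls are cheap;
        # combined with lru_cache on get_assets(), each file downloads once per process.
        return hf_hub_download(
            repo_id=ASSETS_REPO,
            repo_type=ASSETS_REPO_TYPE,
            filename=relpath,
        )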
@@ -84,10 +116,8 @@ def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):

 def compute_hw_from_person(person_path: str):
     """
-    Same idea as the original code:
     - height fixed at 1024, W computed to preserve the aspect ratio
-    - but if W grows past 1024, the padding turns negative (a weakness of the
-      original code), so cap W at 1024 for demo stability.
+    - W capped at 1024 for demo stability.
     """
     img = _imread_or_raise(person_path)
     orig_h, orig_w = img.shape[:2]
@@ -186,6 +216,8 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.float16 if device == "cuda" else torch.float32

+    print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
+
     cn_kwargs = dict(torch_dtype=dtype, use_safetensors=True)
     if dtype == torch.float16:
         cn_kwargs["variant"] = "fp16"
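cn_kwargs presumably feeds the ControlNet loader just below this hunk; under that assumption the call would look roughly like:

    # the fp16 weights variant is only requested when running float16 on CUDA
    controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs)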
@@ -203,8 +235,8 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
     pipe.enable_attention_slicing()
     try:
         pipe.enable_xformers_memory_efficient_attention()
-    except Exception:
-        pass
+    except Exception as e:
+        print("[PIPE] xformers not enabled:", repr(e), flush=True)

     return pipe, device, dtype
@@ -213,13 +245,16 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
     global H, W
     pipe, device, _dtype = get_pipe_and_device()

+    # lazy assets download here (NOT at import time)
+    image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()
+
     H, W = compute_hw_from_person(paths.person_path)

     # parsing extractor (call shape kept as in the original)
     res = run_simple_extractor(
         category="Upper-clothes",
         input_path=os.path.abspath(paths.person_path),
-        model_restore=SCHP_CKPT,
+        model_restore=schp_ckpt,
     )
     parsing_img = res["images"][0] if res.get("images") else None
     if parsing_img is None:
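Since get_assets() carries lru_cache(maxsize=1), only the first request per process touches the network; later calls return the cached tuple. Illustration (hypothetical, not in the commit):

    first = get_assets()    # downloads, or resolves from the local HF cache
    second = get_assets()   # instant; same tuple object
    assert first is second
    print(get_assets.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=1, currsize=1)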
@@ -253,8 +288,7 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):

     depth_map = make_depth(paths.depth_path)

-    # garment / garment_mask: in this flow they could be built from the person image,
-    # but here we only match size/padding based on parsing_img and pass it through
+    # garment / garment_mask
     garment_pil = person_pil.copy()

     gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
@@ -271,8 +305,8 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):

     ip_model = IPAdapterXL(
         pipe,
-        IMAGE_ENCODER_DIR,
-        IP_CKPT,
+        image_encoder_dir,
+        ip_ckpt,
         device,
         mask_pil,
         person_pil,
@@ -358,4 +392,4 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:

 demo.queue()
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(server_name="0.0.0.0", server_port=7860)
 