Fix diffusers3 import + lazy assets
app.py CHANGED
@@ -1,4 +1,16 @@
 import os
+import sys
+
+# ---------------------------------------------------------
+# 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
+# ---------------------------------------------------------
+ROOT = os.path.dirname(os.path.abspath(__file__))
+if ROOT not in sys.path:
+    sys.path.insert(0, ROOT)
+
+print("[BOOT] ROOT =", ROOT, flush=True)
+print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
+
 import tempfile
 from dataclasses import dataclass
 from functools import lru_cache
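Python resolves imports in sys.path order, so putting the Space root at position 0 lets the vendored diffusers3/ directory shadow any same-named package in site-packages. A minimal standalone sketch of the same check (the assert is illustrative, not part of the commit):

    import importlib
    import os
    import sys

    ROOT = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, ROOT)  # local packages now win over site-packages

    # Confirm the module really came from the local tree, not a pip-installed copy.
    mod = importlib.import_module("diffusers3")
    print("loaded from:", getattr(mod, "__file__", "<?>"))
    assert getattr(mod, "__file__", "").startswith(ROOT)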
@@ -14,6 +26,11 @@ from PIL import Image, ImageOps
 from huggingface_hub import hf_hub_download
 
 from diffusers import UniPCMultistepScheduler
+
+# Show where diffusers3 is imported from (helps diagnose import collisions on Spaces)
+import diffusers3
+print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
+
 from diffusers3.models.controlnet import ControlNetModel
 from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
     StableDiffusionXLControlNetImg2ImgPipeline,
@@ -30,7 +47,7 @@ from preprocess.simple_extractor import run as run_simple_extractor
 BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
 CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"
 
-# your assets dataset repo (
+# your assets dataset repo (weights store)
 ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
 ASSETS_REPO_TYPE = "dataset"  # uploaded as a dataset repo
 
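asset_path itself is defined in app.py outside this diff (its signature appears in the next hunk's header). A plausible shape, sketched here purely as an assumption, is a thin wrapper over hf_hub_download that fetches one file from the assets repo and returns its local cache path:

    import os

    from huggingface_hub import hf_hub_download

    ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
    ASSETS_REPO_TYPE = "dataset"

    def asset_path(relpath: str) -> str:
        # hf_hub_download caches the file locally and returns its path;
        # repo_type="dataset" is needed because the weights live in a
        # dataset repo rather than a model repo.
        return hf_hub_download(
            repo_id=ASSETS_REPO,
            filename=relpath,
            repo_type=ASSETS_REPO_TYPE,
        )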
@@ -50,14 +67,29 @@ def asset_path(relpath: str) -> str:
 )
 
 
-
-
-
-
-
+@lru_cache(maxsize=1)
+def get_assets():
+    """
+    Lazily downloads required assets on first use.
+
+    Returns:
+        (image_encoder_dir, ip_ckpt_path, schp_ckpt_path)
+    """
+    print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)
+
+    # Image encoder folder is needed by IPAdapterXL
+    image_encoder_weight = asset_path("image_encoder/model.safetensors")
+    _ = asset_path("image_encoder/config.json")  # ensure config exists locally
+    image_encoder_dir = os.path.dirname(image_encoder_weight)
+
+    ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
+    schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
+
+    print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True)
+    print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True)
+    print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True)
+    return image_encoder_dir, ip_ckpt, schp_ckpt
 
-IP_CKPT = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
-SCHP_CKPT = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
 
 DEFAULT_STEPS = 40
 DEBUG_SAVE = False
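get_assets() is the "lazy assets" half of the commit: the module-level asset_path calls are gone, so importing app.py no longer triggers any downloads, and lru_cache(maxsize=1) on a zero-argument function means the first caller does the work while every later call returns the cached tuple. The pattern in isolation:

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def get_expensive_thing() -> str:
        # Body runs exactly once per process; later calls hit the cache.
        print("downloading...")
        return "/local/path/to/asset"

    get_expensive_thing()  # prints "downloading..." and does the work
    get_expensive_thing()  # returns instantly from the cache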
@@ -84,10 +116,8 @@ def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
 
 def compute_hw_from_person(person_path: str):
     """
-    Same idea as your code:
     - height fixed at 1024; W computed to preserve the aspect ratio
-
-    cap W at 1024 for demo stability.
+    - cap W at 1024 for demo stability.
     """
     img = _imread_or_raise(person_path)
     orig_h, orig_w = img.shape[:2]
@@ -186,6 +216,8 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.float16 if device == "cuda" else torch.float32
 
+    print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
+
     cn_kwargs = dict(torch_dtype=dtype, use_safetensors=True)
     if dtype == torch.float16:
         cn_kwargs["variant"] = "fp16"
@@ -203,8 +235,8 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
     pipe.enable_attention_slicing()
     try:
         pipe.enable_xformers_memory_efficient_attention()
-    except Exception:
-        pass
+    except Exception as e:
+        print("[PIPE] xformers not enabled:", repr(e), flush=True)
 
     return pipe, device, dtype
 
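Logging the exception instead of silently passing keeps startup resilient while making it visible why xformers was skipped (typically the wheel is not installed on the Space). The same guarded-enable idea generalized, with try_enable and fake_enable as illustrative names only:

    def try_enable(enable, label: str) -> None:
        try:
            enable()
        except Exception as e:
            # Log rather than swallow, so missing optional deps are diagnosable.
            print(f"[PIPE] {label} not enabled:", repr(e), flush=True)

    def fake_enable() -> None:
        # Stand-in for pipe.enable_xformers_memory_efficient_attention()
        raise ImportError("xformers is not installed")

    try_enable(fake_enable, "xformers")  # prints the reason instead of crashing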
@@ -213,13 +245,16 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
     global H, W
     pipe, device, _dtype = get_pipe_and_device()
 
+    # lazy assets download here (NOT at import time)
+    image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()
+
     H, W = compute_hw_from_person(paths.person_path)
 
     # parsing extractor (keep the original call shape)
     res = run_simple_extractor(
         category="Upper-clothes",
         input_path=os.path.abspath(paths.person_path),
-        model_restore=SCHP_CKPT,
+        model_restore=schp_ckpt,
     )
     parsing_img = res["images"][0] if res.get("images") else None
     if parsing_img is None:
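Since get_assets() now runs inside run_one, the first generation request pays the whole download cost. If that first-request stall matters, one option (sketched as a possible follow-up, not part of this commit) is to pre-warm the cache in a daemon thread at startup:

    import threading

    def _warm_up() -> None:
        try:
            get_assets()  # populates the lru_cache; later calls return instantly
        except Exception as e:
            print("[ASSETS] warm-up failed:", repr(e), flush=True)

    # Daemon thread: never blocks process exit if the download hangs.
    threading.Thread(target=_warm_up, daemon=True).start()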
@@ -253,8 +288,7 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
 
     depth_map = make_depth(paths.depth_path)
 
-    # garment / garment_mask
-    # here, only size/padding are matched from parsing_img before passing it on
+    # garment / garment_mask
     garment_pil = person_pil.copy()
 
     gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
@@ -271,8 +305,8 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
 
     ip_model = IPAdapterXL(
         pipe,
-
-
+        image_encoder_dir,
+        ip_ckpt,
         device,
         mask_pil,
         person_pil,
@@ -358,4 +392,4 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
 
 demo.queue()
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(server_name="0.0.0.0", server_port=7860)
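Keeping server_name="0.0.0.0" and server_port=7860 explicit matches what Hugging Face Spaces expects from a Gradio app, so the app listens exactly where the Space's proxy looks for it, locally and on the Space alike.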
|