import os import sys import glob # --------------------------------------------------------- # 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces # --------------------------------------------------------- ROOT = os.path.dirname(os.path.abspath(__file__)) if ROOT not in sys.path: sys.path.insert(0, ROOT) print("[BOOT] ROOT =", ROOT, flush=True) print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True) import tempfile from dataclasses import dataclass from functools import lru_cache from typing import Optional, Tuple, List, Dict import gradio as gr import torch import numpy as np import cv2 import imageio from PIL import Image, ImageOps from transformers import pipeline from huggingface_hub import hf_hub_download import diffusers3 print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", ""), flush=True) from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel from diffusers3.models.controlnet import ControlNetModel from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import ( StableDiffusionXLControlNetImg2ImgPipeline, ) from ip_adapter import IPAdapterXL # extractor from preprocess.simple_extractor import run as run_simple_extractor # ========================= # HF Hub repo ids # ========================= BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0" CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0" # assets dataset repo ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets") ASSETS_REPO_TYPE = "dataset" depth_estimator = pipeline("depth-estimation") def asset_path(relpath: str) -> str: return hf_hub_download( repo_id=ASSETS_REPO, repo_type=ASSETS_REPO_TYPE, filename=relpath, ) @lru_cache(maxsize=1) def get_assets(): print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True) image_encoder_weight = asset_path("image_encoder/model.safetensors") _ = asset_path("image_encoder/config.json") image_encoder_dir = os.path.dirname(image_encoder_weight) ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin") schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth") print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True) print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True) print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True) return image_encoder_dir, ip_ckpt, schp_ckpt # ========================= # Example assets for Gradio UI (✅ 분리형) # ========================= def _is_image_file(p: str) -> bool: ext = os.path.splitext(p.lower())[1] return ext in (".png", ".jpg", ".jpeg", ".webp") def build_ui_example_lists(root_dir: str = ROOT) -> Dict[str, List[str]]: """ Returns dict of example filepaths: - persons: [{root}/examples/person/*] - styles : [{root}/examples/style/*] - sketches: [{root}/examples/sketch/*] (optional) """ person_dir = os.path.join(root_dir, "examples", "person") style_dir = os.path.join(root_dir, "examples", "style") sketch_dir = os.path.join(root_dir, "examples", "sketch") persons = [p for p in sorted(glob.glob(os.path.join(person_dir, "*"))) if _is_image_file(p)] styles = [p for p in sorted(glob.glob(os.path.join(style_dir, "*"))) if _is_image_file(p)] sketches = [p for p in sorted(glob.glob(os.path.join(sketch_dir, "*"))) if _is_image_file(p)] return {"persons": persons, "styles": styles, "sketches": sketches} DEFAULT_STEPS = 40 DEBUG_SAVE = False H: Optional[int] = None W: Optional[int] = None @dataclass class Paths: person_path: str depth_path: Optional[str] # sketch(guide) optional style_path: Optional[str] # ✅ style optional (변경) output_path: str def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR): img = cv2.imread(path, flag) if img is None: raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})") return img def _pad_or_crop_to_width_np(arr: np.ndarray, target_width: int, pad_value): """ arr: HxWxC or HxW target_width로 center crop 또는 좌/우 padding(비대칭 포함)해서 정확히 맞춤. """ if arr.ndim not in (2, 3): raise ValueError(f"arr must be 2D or 3D, got shape={arr.shape}") h = arr.shape[0] w = arr.shape[1] if w == target_width: return arr if w > target_width: left = (w - target_width) // 2 return arr[:, left:left + target_width] if arr.ndim == 2 else arr[:, left:left + target_width, :] # w < target_width: pad total = target_width - w left = total // 2 right = total - left # ✅ remainder를 오른쪽이 먹어서 항상 정확히 target_width if arr.ndim == 2: return cv2.copyMakeBorder( arr, 0, 0, left, right, borderType=cv2.BORDER_CONSTANT, value=pad_value, ) else: return cv2.copyMakeBorder( arr, 0, 0, left, right, borderType=cv2.BORDER_CONSTANT, value=pad_value, ) def apply_parsing_white_mask_to_person_cv2( person_pil: Image.Image, parsing_img: Image.Image ) -> np.ndarray: person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8) mask = np.array(parsing_img.convert("L"), dtype=np.uint8) if mask.shape[:2] != person_rgb.shape[:2]: mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST) white_mask = (mask == 255) result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8) result_rgb[white_mask] = person_rgb[white_mask] result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR) return result_bgr def remove_small_white_components( parsing_img: Image.Image, *, white_threshold: int = 128, min_white_area: int = 150, use_open: bool = False, open_ksize: int = 3, morph_iters: int = 1, ) -> Image.Image: """ - 흰색(=foreground)으로 이진화 - connected components로 '작은 흰색 덩어리'만 제거 - (옵션) OPEN을 아주 약하게 적용해 작은 점/가시 제거 (흰색이 늘어나는 CLOSE는 사용 X) """ if not isinstance(parsing_img, Image.Image): raise TypeError("parsing_img must be a PIL.Image.Image") arr = np.array(parsing_img.convert("L"), dtype=np.uint8) mask = np.where(arr >= int(white_threshold), 255, 0).astype(np.uint8) # 1) 작은 흰색 연결요소 제거 num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) keep = np.zeros_like(mask) for lab in range(1, num_labels): area = int(stats[lab, cv2.CC_STAT_AREA]) if area >= int(min_white_area): keep[labels == lab] = 255 mask = keep # 2) (옵션) OPEN: 작은 흰 점/가시 제거 + 경계 약간 정리 (흰색 증가 방향 아님) if use_open and int(open_ksize) > 1: k = int(open_ksize) if k % 2 == 0: k += 1 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)) mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=int(morph_iters)) return Image.fromarray(mask, mode="L") def compute_hw_from_person(person_path: str): img = _imread_or_raise(person_path) orig_h, orig_w = img.shape[:2] scale = 1024.0 / float(orig_h) new_h = 1024 new_w = int(round(orig_w * scale)) if new_w > 1024: new_w = 1024 return new_h, new_w def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image: global H, W if H is None or W is None: raise RuntimeError("Global H/W not set.") img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE) img = cv2.bitwise_not(img) img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST) _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV) contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) filled = np.zeros_like(binary) cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED) filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB) return Image.fromarray(filled_rgb) def _resize_pil_nearest(img: Image.Image, size_wh: Tuple[int, int], *, force_mode: Optional[str] = None) -> Image.Image: """ Resize PIL image to (W,H) using INTER_NEAREST (safe for masks). size_wh: (width, height) """ w, h = int(size_wh[0]), int(size_wh[1]) if force_mode is not None: img = img.convert(force_mode) arr = np.array(img, dtype=np.uint8) if arr.ndim == 2: resized = cv2.resize(arr, (w, h), interpolation=cv2.INTER_NEAREST) return Image.fromarray(resized, mode="L") elif arr.ndim == 3 and arr.shape[2] == 3: resized = cv2.resize(arr, (w, h), interpolation=cv2.INTER_NEAREST) return Image.fromarray(resized, mode="RGB") else: raise ValueError(f"Unsupported image array shape: {arr.shape}") def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image: a = np.array(img1.convert("RGB"), dtype=np.uint8) b = np.array(img2.convert("RGB"), dtype=np.uint8) # ✅ safety: make shapes identical to avoid numpy broadcasting error if a.shape[:2] != b.shape[:2]: b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_NEAREST) white_a = np.all(a == 255, axis=-1) white_b = np.all(b == 255, axis=-1) out = a.copy() out[white_a | white_b] = 255 return Image.fromarray(out, mode="RGB") def preprocess_mask(mask_img: Image.Image) -> Image.Image: global H, W m = np.array(mask_img.convert("L"), dtype=np.uint8) if (H is not None) and (W is not None): m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY) target_width = 1024 m = _pad_or_crop_to_width_np(m, target_width, pad_value=0) kernel = np.ones((12, 12), np.uint8) m = cv2.dilate(m, kernel, iterations=1) if DEBUG_SAVE: cv2.imwrite("mask_final_1024.png", m) return Image.fromarray(m, mode="L").convert("RGB") # def make_depth(depth_path: str) -> Image.Image: # global H, W # if H is None or W is None: # raise RuntimeError("Global H/W not set. Call run_one() first.") # depth_img = _imread_or_raise(depth_path, 0) # grayscale # # (선택) 입력이 완전한 0/255가 아니라면 이진화로 고정 # _, depth_bin = cv2.threshold(depth_img, 127, 255, cv2.THRESH_BINARY) # # 컨투어 채우기가 "두꺼워 보임"의 원인일 수도 있어, 유지/제거 선택 가능 # # 1) 채우기 유지 (holes 메우는 목적이라면) # contours, _ = cv2.findContours(depth_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # filled_depth = np.zeros_like(depth_bin) # cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED) # # 2) 채우기 제거하고 싶으면 위 3줄 대신 이걸 사용: # # filled_depth = depth_bin # # ✅ 마스크 리사이즈는 NEAREST (경계 번짐/팽창 느낌 방지) # filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_NEAREST) # # (선택) 리사이즈 후에도 0/255 강제 # _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY) # filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0) # inverted_image = ImageOps.invert(Image.fromarray(filled_depth)) # with torch.inference_mode(): # image_depth = depth_estimator(inverted_image)["depth"] # if DEBUG_SAVE: # image_depth.save("depth.png") # return image_depth def make_depth(depth_path: str) -> Image.Image: global H, W if H is None or W is None: raise RuntimeError("Global H/W not set. Call run_one() first.") depth_img = _imread_or_raise(depth_path, 0) # grayscale # (선택) 입력이 완전한 0/255가 아니라면 이진화로 고정 _, depth_bin = cv2.threshold(depth_img, 127, 255, cv2.THRESH_BINARY) # 컨투어 채우기 (holes 메우는 목적) contours, _ = cv2.findContours(depth_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) filled_depth = np.zeros_like(depth_bin) cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED) # ✅ 마스크 리사이즈는 NEAREST filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_NEAREST) # (선택) 리사이즈 후에도 0/255 강제 _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY) filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0) # ✅ 여기서 침식(팽창의 반대): 흰색 영역을 조금 줄임 erode_ksize = 5 # 3/5/7... (클수록 더 많이 줄어듦) erode_iters = 1 # 1~2 추천 if erode_ksize is not None and erode_ksize > 1 and erode_iters > 0: if erode_ksize % 2 == 0: erode_ksize += 1 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_ksize, erode_ksize)) filled_depth = cv2.erode(filled_depth, kernel, iterations=erode_iters) # 안전하게 다시 이진화 _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY) inverted_image = ImageOps.invert(Image.fromarray(filled_depth)) with torch.inference_mode(): image_depth = depth_estimator(inverted_image)["depth"] return image_depth def _edges_from_parsing(parsing_img: Image.Image) -> np.ndarray: m = np.array(parsing_img.convert("L"), dtype=np.uint8) _, m_bin = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY) edges = cv2.Canny(m_bin, 50, 150) edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1) return edges.astype(np.uint8) def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image: global H, W if H is None or W is None: raise RuntimeError("Global H/W not set. Call run_one() first.") depth_img = _edges_from_parsing(parsing_img) contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) filled_depth = depth_img.copy() cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED) filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA) filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0) inverted_image = ImageOps.invert(Image.fromarray(filled_depth)) with torch.inference_mode(): image_depth = depth_estimator(inverted_image)["depth"] if DEBUG_SAVE: image_depth.save("depth.png") return image_depth def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray: target_h, target_w = 1024, 768 h, w = arr.shape[:2] if h != target_h: arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA) h, w = arr.shape[:2] if w < target_w: pad = (target_w - w) // 2 arr = cv2.copyMakeBorder(arr, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=[255, 255, 255]) w = arr.shape[1] left = (w - target_w) // 2 return arr[:, left:left + target_w] def save_cropped(imgs, out_path: str): np_imgs = [np.asarray(im) for im in imgs] cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs] out = np.concatenate(cropped, axis=1) os.makedirs(os.path.dirname(out_path), exist_ok=True) imageio.imsave(out_path, out) def _read_hw(path: str) -> Tuple[int, int]: img = _imread_or_raise(path) # BGR h, w = img.shape[:2] return h, w def _center_crop_lr_to_aspect(arr: np.ndarray, target_aspect: float, *, pad_value=255) -> np.ndarray: """ arr: HxWxC (RGB) or HxW target_aspect = target_w / target_h - 높이(H)는 유지 - 좌/우를 동일 비율로 crop해서 target_aspect에 맞춤 - 만약 현재 폭이 부족하면 좌/우 padding으로 맞춤 """ if arr.ndim == 2: arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB) h, w = arr.shape[:2] if h <= 0 or w <= 0: raise ValueError(f"Invalid image shape: {arr.shape}") desired_w = int(round(h * float(target_aspect))) if desired_w <= 0: desired_w = 1 # 폭이 충분하면 좌/우 crop if w >= desired_w: left = (w - desired_w) // 2 right = left + desired_w return arr[:, left:right] # 폭이 부족하면 좌/우 padding (요청은 crop이지만 안전장치) total = desired_w - w left_pad = total // 2 right_pad = total - left_pad return cv2.copyMakeBorder( arr, 0, 0, left_pad, right_pad, borderType=cv2.BORDER_CONSTANT, value=[pad_value, pad_value, pad_value], ) def save_output_match_person(imgs, out_path: str, person_path: str): """ - 출력 imgs(보통 길이 1)를 person 원본 비율에 맞게 좌/우 center-crop - person 원본 (W,H)로 resize - (imgs가 여러 장이면) 처리 후 가로로 concat해서 저장 """ person_h, person_w = _read_hw(person_path) target_aspect = float(person_w) / float(person_h) np_imgs = [] for im in imgs: if isinstance(im, Image.Image): arr = np.asarray(im.convert("RGB"), dtype=np.uint8) else: # 혹시 numpy가 들어오는 경우 대비 arr = np.asarray(im, dtype=np.uint8) if arr.ndim == 2: arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB) cropped = _center_crop_lr_to_aspect(arr, target_aspect, pad_value=255) resized = cv2.resize(cropped, (person_w, person_h), interpolation=cv2.INTER_AREA) np_imgs.append(resized) out = np.concatenate(np_imgs, axis=1) # imgs가 1장이면 그대로 os.makedirs(os.path.dirname(out_path), exist_ok=True) imageio.imsave(out_path, out) @lru_cache(maxsize=1) def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]: device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float32 print(f"[PIPE] device={device}, dtype={dtype}", flush=True) controlnet = ControlNetModel.from_pretrained( CONTROLNET_ID, torch_dtype=dtype, use_safetensors=True, ).to(device) vae = AutoencoderKL.from_pretrained( BASE_MODEL_ID, subfolder="vae", torch_dtype=dtype, use_safetensors=True, ).to(device) unet = UNet2DConditionModel.from_pretrained( BASE_MODEL_ID, subfolder="unet", torch_dtype=dtype, use_safetensors=True, ).to(device) pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( BASE_MODEL_ID, controlnet=controlnet, vae=vae, unet=unet, torch_dtype=dtype, use_safetensors=True, add_watermarker=False, ).to(device) if device == "cuda": try: pipe.vae.to(dtype=dtype) if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"): pipe.vae.config.force_upcast = False except Exception as e: print("[PIPE] VAE dtype cast failed:", repr(e), flush=True) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.enable_attention_slicing() try: pipe.enable_xformers_memory_efficient_attention() except Exception as e: print("[PIPE] xformers not enabled:", repr(e), flush=True) return pipe, device, dtype # UI 표기 → 내부 extractor category 문자열 매핑 _UI_TO_EXTRACTOR_CATEGORY = { "Upper-body": "Upper-cloth", "Lower-body": "Bottom", "Dress": "Dress", } def _has_valid_file(path: Optional[str]) -> bool: return ( path is not None and isinstance(path, str) and len(path) > 0 and os.path.exists(path) ) def _resolve_content_style_scales(style_present: bool, prompt_present: bool) -> Tuple[float, float]: """ 요구사항: - style image 없으면: (0.0, 0.0) - prompt 없으면: (0.4, 0.6) - 둘 다 있으면: 기존 유지 (0.3, 0.5) """ if not style_present: return 0.0, 0.0 if not prompt_present: return 0.35, 0.65 return 0.25, 0.5 def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"): global H, W pipe, device, _dtype = get_pipe_and_device() image_encoder_dir, ip_ckpt, schp_ckpt = get_assets() H, W = compute_hw_from_person(paths.person_path) extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress") res = run_simple_extractor( category=extractor_category, input_path=os.path.abspath(paths.person_path), model_restore=schp_ckpt, ) parsing_img = res["images"][0] if res.get("images") else None if parsing_img is None: raise RuntimeError("run_simple_extractor returned no parsing images.") parsing_img = remove_small_white_components( parsing_img, white_threshold=128, min_white_area=150, # 데이터에 맞게 30~200 사이 조절 use_open=False, ) # ✅ IMPORTANT: extractor output size can differ from (W,H). Align before OR-merge. if parsing_img.size != (W, H): parsing_img = _resize_pil_nearest(parsing_img, (W, H), force_mode="L") use_depth_path = _has_valid_file(paths.depth_path) if use_depth_path: sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path) else: sketch_area = parsing_img.convert("RGB") merged_img = merge_white_regions_or(parsing_img, sketch_area) mask_pil = preprocess_mask(merged_img) # person person_bgr = _imread_or_raise(paths.person_path) person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA) person_bgr = _pad_or_crop_to_width_np(person_bgr, 1024, pad_value=[255, 255, 255]) person_rgb = cv2.cvtColor(person_bgr, cv2.COLOR_BGR2RGB) person_pil = Image.fromarray(person_rgb) # depth if use_depth_path: depth_map = make_depth(paths.depth_path) else: depth_map = make_depth_from_parsing_edges(parsing_img) # garment image (✅ 여기서부터가 핵심: 1024 폭 강제) personn = Image.open(paths.person_path).convert("RGB") garment_bgr = apply_parsing_white_mask_to_person_cv2(personn, parsing_img) garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB) garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA) garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255]) garment_pil = Image.fromarray(garment_rgb) # garment mask (✅ 동일하게 1024 맞춤) gm = np.array(parsing_img.convert("L"), dtype=np.uint8) gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST) gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB) gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0]) garment_mask_pil = Image.fromarray(gm) # ✅ 조건에 따른 scale 결정 style_present = _has_valid_file(paths.style_path) prompt_present = (prompt is not None) and (str(prompt).strip() != "") content_scale, style_scale = _resolve_content_style_scales(style_present, prompt_present) print( "[SIZE] person:", person_pil.size, "mask:", mask_pil.size, "depth:", depth_map.size, "garment:", garment_pil.size, "gmask:", garment_mask_pil.size, "ui_category:", category, "extractor_category:", extractor_category, "style_present:", style_present, "prompt_present:", prompt_present, "content_scale:", content_scale, "style_scale:", style_scale, flush=True ) ip_model = IPAdapterXL( pipe, image_encoder_dir, ip_ckpt, device, mask_pil, person_pil, content_scale=content_scale, # ✅ 변경 style_scale=style_scale, # ✅ 변경 garment_images=garment_pil, garment_mask=garment_mask_pil, ) if device == "cuda": pipe.to(dtype=torch.float32) try: for _, proc in pipe.unet.attn_processors.items(): proc.to(dtype=torch.float32) except Exception: pass # ✅ style image 없을 때도 generate 입력이 None이 되지 않게 대체 if style_present: style_img = Image.open(paths.style_path).convert("RGB") else: # scale이 0이므로 영향은 없고, 함수 시그니처만 만족시키기 위한 대체값 style_img = garment_pil # prompt 구성은 기존 유지 if prompt is not None and str(prompt).strip() != "": prompt = extractor_category + " with " + str(prompt).strip() else: prompt = extractor_category with torch.inference_mode(): images = ip_model.generate( pil_image=style_img, image=person_pil, control_image=depth_map, strength=1.0, num_samples=1, num_inference_steps=int(steps), shape_prompt="", prompt=prompt or "", num=0, scale=None, controlnet_conditioning_scale=0.7, guidance_scale=7.5, ) # save_cropped(images, paths.output_path) # return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil save_output_match_person(images, paths.output_path, paths.person_path) return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil def set_seed(seed: int): if seed is None or seed < 0: return np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category): print("[UI] infer_web called", flush=True) # ✅ person만 필수, style은 선택 if person_fp is None: raise gr.Error("person 이미지는 필수입니다. (style/sketch는 선택)") if category not in ("Upper-body", "Lower-body", "Dress"): raise gr.Error(f"Invalid category: {category}") set_seed(int(seed) if seed is not None else -1) tmp_dir = tempfile.mkdtemp(prefix="vista_demo_") out_path = os.path.join(tmp_dir, "result.png") paths = Paths( person_path=person_fp, depth_path=sketch_fp, style_path=style_fp, # ✅ None 가능 output_path=out_path, ) _, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil = run_one( paths, prompt=prompt, steps=int(steps), category=category ) out_img = Image.open(out_path).convert("RGB") return out_img, out_path, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo: # gr.Markdown("첫 실행은 모델 로딩 때문에 시간이 오래 걸릴 수 있습니다.") # gr.Markdown("category랑 try-on 하려는 옷 종류를 꼭 맞춰주세요.") gr.Markdown( "첫 실행은 모델 로딩 때문에 시간이 오래 걸릴 수 있습니다.
" "category랑 try-on 하려는 옷 종류를 꼭 맞춰주세요.", elem_classes="tight_md", ) category_toggle = gr.Radio( choices=["Dress", "Upper-body", "Lower-body"], value="Dress", label="Category", interactive=True, ) # ✅ 예시 리스트(분리) ex = build_ui_example_lists(ROOT) person_examples = [[p] for p in ex["persons"]] style_examples = [[p] for p in ex["styles"]] sketch_examples = [[p] for p in ex["sketches"]] # 한 행에 Person / Style / Output with gr.Row(): # -------- Person column -------- with gr.Column(scale=1): person_in = gr.Image(label="Person Image (required)", type="filepath") if person_examples: gr.Markdown("#### Examples") gr.Examples( examples=person_examples, inputs=[person_in], examples_per_page=8, ) # -------- Style column -------- with gr.Column(scale=1): style_in = gr.Image(label="Style Image (optional)", type="filepath") if style_examples: gr.Markdown("#### Examples") gr.Examples( examples=style_examples, inputs=[style_in], examples_per_page=8, ) # -------- Output column -------- with gr.Column(scale=1): out_img = gr.Image(label="Output", type="pil") with gr.Accordion("Sketch / Guide (optional)", open=False): sketch_in = gr.Image( label="Sketch / Guide (person과 같은 번호로 매칭하세요: person 1 ↔ sketch 1). 스케치는 person 인체와 정렬되어야 합니다.", type="filepath", ) if sketch_examples: gr.Markdown("#### Examples") gr.Examples( examples=sketch_examples, inputs=[sketch_in], examples_per_page=8, ) with gr.Row(): prompt_in = gr.Textbox( label="Prompt", value="", placeholder="ex) crystal, lace, button, …", lines=2, ) steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps") seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0) run_btn = gr.Button("Run") out_file = gr.File(label="Download result.png") # gr.Markdown("### Debug Visualizations (mask/depth/etc)") # with gr.Row(): # dbg_mask = gr.Image(label="mask_pil", type="pil") # dbg_depth = gr.Image(label="depth_map", type="pil") # with gr.Row(): # dbg_person = gr.Image(label="person_pil", type="pil") # dbg_garment = gr.Image(label="garment_pil", type="pil") # dbg_gmask = gr.Image(label="garment_mask_pil", type="pil") run_btn.click( fn=infer_web, inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in, category_toggle], outputs=[out_img, out_file], ) demo.queue() if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)