import os import sys # --------------------------------------------------------- # 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces # --------------------------------------------------------- ROOT = os.path.dirname(os.path.abspath(__file__)) if ROOT not in sys.path: sys.path.insert(0, ROOT) print("[BOOT] ROOT =", ROOT, flush=True) print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True) import tempfile from dataclasses import dataclass from functools import lru_cache from typing import Optional, Tuple import gradio as gr import torch import numpy as np import cv2 import imageio from PIL import Image, ImageOps from transformers import pipeline from huggingface_hub import hf_hub_download # Show where diffusers3 is imported from (helps diagnose import collisions on Spaces) import diffusers3 print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", ""), flush=True) from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel from diffusers3.models.controlnet import ControlNetModel from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import ( StableDiffusionXLControlNetImg2ImgPipeline, ) from ip_adapter import IPAdapterXL # extractor from preprocess.simple_extractor import run as run_simple_extractor # ========================= # HF Hub repo ids # ========================= BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0" CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0" # assets dataset repo ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets") ASSETS_REPO_TYPE = "dataset" depth_estimator = pipeline("depth-estimation") def asset_path(relpath: str) -> str: return hf_hub_download( repo_id=ASSETS_REPO, repo_type=ASSETS_REPO_TYPE, filename=relpath, ) @lru_cache(maxsize=1) def get_assets(): print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True) image_encoder_weight = asset_path("image_encoder/model.safetensors") _ = asset_path("image_encoder/config.json") image_encoder_dir = os.path.dirname(image_encoder_weight) ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin") schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth") print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True) print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True) print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True) return image_encoder_dir, ip_ckpt, schp_ckpt DEFAULT_STEPS = 40 DEBUG_SAVE = False H: Optional[int] = None W: Optional[int] = None @dataclass class Paths: person_path: str depth_path: str style_path: str output_path: str def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR): img = cv2.imread(path, flag) if img is None: raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})") return img def apply_parsing_white_mask_to_person_cv2( person_pil: Image.Image, parsing_img: Image.Image ) -> np.ndarray: """ person_pil(RGB) 크기에 parsing_img(L) 마스크를 맞춰서 흰색(255) 영역만 person을 남기고 나머지는 흰색 배경으로 만드는 함수. - parsing_img는 person 크기에 반드시 맞춰야 함 (NEAREST) """ person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8) # parsing 마스크 (L) mask = np.array(parsing_img.convert("L"), dtype=np.uint8) # ✅ 핵심: 크기 불일치 해결 (H,W) 맞춤 if mask.shape[0] != person_rgb.shape[0] or mask.shape[1] != person_rgb.shape[1]: mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST) white_mask = (mask == 255) result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8) result_rgb[white_mask] = person_rgb[white_mask] result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR) return result_bgr def compute_hw_from_person(person_path: str): img = _imread_or_raise(person_path) orig_h, orig_w = img.shape[:2] scale = 1024.0 / float(orig_h) new_h = 1024 new_w = int(round(orig_w * scale)) if new_w > 1024: new_w = 1024 return new_h, new_w def invert_sketch_area(sketch_pil: Image.Image) -> Image.Image: return ImageOps.invert(sketch_pil.convert("L")).convert("RGB") def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image: global H, W if H is None or W is None: raise RuntimeError("Global H/W not set.") img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE) img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST) _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV) contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) filled = np.zeros_like(binary) cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED) filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB) return Image.fromarray(filled_rgb) def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image: a = np.array(img1.convert("RGB"), dtype=np.uint8) b = np.array(img2.convert("RGB"), dtype=np.uint8) white_a = np.all(a == 255, axis=-1) white_b = np.all(b == 255, axis=-1) out = a.copy() out[white_a | white_b] = 255 return Image.fromarray(out) def preprocess_mask(mask_img: Image.Image) -> Image.Image: global H, W m = np.array(mask_img.convert("L"), dtype=np.uint8) if (H is not None) and (W is not None): m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY) target_width = 1024 h, w = m.shape[:2] if w < target_width: total_padding = target_width - w left_padding = total_padding // 2 right_padding = total_padding - left_padding m = cv2.copyMakeBorder( m, top=0, bottom=0, left=left_padding, right=right_padding, borderType=cv2.BORDER_CONSTANT, value=0, ) elif w > target_width: left = (w - target_width) // 2 m = m[:, left:left + target_width] kernel = np.ones((17, 17), np.uint8) m = cv2.dilate(m, kernel, iterations=1) if DEBUG_SAVE: cv2.imwrite("mask_final_1024.png", m) return Image.fromarray(m, mode="L").convert("RGB") def make_depth(depth_path: str) -> Image.Image: global H, W if H is None or W is None: raise RuntimeError("Global H/W not set. Call run_one() first.") depth_img = _imread_or_raise(depth_path, 0) inverted_depth = cv2.bitwise_not(depth_img) contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) filled_depth = inverted_depth.copy() cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED) # ✅ resize는 전역 (W,H) filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA) height, width = filled_depth.shape total_padding = 1024 - width left_padding = total_padding // 2 right_padding = total_padding - left_padding padded_depth = cv2.copyMakeBorder( filled_depth, 0, 0, left_padding, right_padding, borderType=cv2.BORDER_CONSTANT, value=0, ) inverted_image = ImageOps.invert(Image.fromarray(padded_depth)) with torch.inference_mode(): image_depth = depth_estimator(inverted_image)["depth"] if DEBUG_SAVE: image_depth.save("depth.png") return image_depth def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray: target_h, target_w = 1024, 768 h, w = arr.shape[:2] if h != target_h: arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA) h, w = arr.shape[:2] if w < target_w: pad = (target_w - w) // 2 arr = cv2.copyMakeBorder(arr, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=[255, 255, 255]) w = arr.shape[1] left = (w - target_w) // 2 return arr[:, left:left + target_w] def save_cropped(imgs, out_path: str): np_imgs = [np.asarray(im) for im in imgs] cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs] out = np.concatenate(cropped, axis=1) os.makedirs(os.path.dirname(out_path), exist_ok=True) imageio.imsave(out_path, out) @lru_cache(maxsize=1) def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]: device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float32 # 현재 너 설정 유지 print(f"[PIPE] device={device}, dtype={dtype}", flush=True) controlnet = ControlNetModel.from_pretrained( CONTROLNET_ID, torch_dtype=dtype, use_safetensors=True, ).to(device) vae = AutoencoderKL.from_pretrained( BASE_MODEL_ID, subfolder="vae", torch_dtype=dtype, use_safetensors=True, ).to(device) unet = UNet2DConditionModel.from_pretrained( BASE_MODEL_ID, subfolder="unet", torch_dtype=dtype, use_safetensors=True, ).to(device) pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( BASE_MODEL_ID, controlnet=controlnet, vae=vae, unet=unet, torch_dtype=dtype, use_safetensors=True, add_watermarker=False, ).to(device) if device == "cuda": try: pipe.vae.to(dtype=dtype) if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"): pipe.vae.config.force_upcast = False except Exception as e: print("[PIPE] VAE dtype cast failed:", repr(e), flush=True) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.enable_attention_slicing() try: pipe.enable_xformers_memory_efficient_attention() except Exception as e: print("[PIPE] xformers not enabled:", repr(e), flush=True) return pipe, device, dtype def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS): """ Returns: images(list[PIL]), mask_pil(PIL), depth_map(PIL), person_pil(PIL), garment_pil(PIL), garment_mask_pil(PIL) """ global H, W pipe, device, _dtype = get_pipe_and_device() image_encoder_dir, ip_ckpt, schp_ckpt = get_assets() H, W = compute_hw_from_person(paths.person_path) res = run_simple_extractor( category="Upper-clothes", input_path=os.path.abspath(paths.person_path), model_restore=schp_ckpt, ) parsing_img = res["images"][0] if res.get("images") else None if parsing_img is None: raise RuntimeError("run_simple_extractor returned no parsing images.") sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path) merged_img = merge_white_regions_or(parsing_img, sketch_area) mask_pil = preprocess_mask(merged_img) person_bgr = _imread_or_raise(paths.person_path) person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA) target_width = 1024 cur_w = person_bgr.shape[1] if cur_w < target_width: total = target_width - cur_w left = total // 2 right = total - left padded_person = cv2.copyMakeBorder( person_bgr, 0, 0, left, right, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255] ) elif cur_w > target_width: left = (cur_w - target_width) // 2 padded_person = person_bgr[:, left:left + target_width] else: padded_person = person_bgr person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB) person_pil = Image.fromarray(person_rgb) depth_map = make_depth(paths.depth_path) personn = Image.open(paths.person_path) garment_ = apply_parsing_white_mask_to_person_cv2( personn, parsing_img ) garment_rgb = cv2.cvtColor(garment_, cv2.COLOR_BGR2RGB) # ✅ (중요) garment_는 원본 person 크기일 수 있으니 전역 (W,H)로 맞춘 뒤 padding garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA) target_width = 1024 # ✅ 고정 padding = (target_width - person_bgr.shape[1]) // 2 garment_rgb = cv2.copyMakeBorder( garment_rgb, top=0, bottom=0, left=padding, right=padding, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255], ) garment_pil = Image.fromarray(garment_rgb) gm = np.array(parsing_img.convert("L"), dtype=np.uint8) gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_AREA) gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB) cur_w2 = gm.shape[1] if cur_w2 < target_width: total = target_width - cur_w2 left = total // 2 right = total - left gm = cv2.copyMakeBorder(gm, 0, 0, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0]) elif cur_w2 > target_width: left2 = (cur_w2 - target_width) // 2 gm = gm[:, left2:left2 + target_width] garment_mask_pil = Image.fromarray(gm) # --- sanity sizes (optional) print( "[SIZE] person:", person_pil.size, "mask:", mask_pil.size, "depth:", depth_map.size, "garment:", garment_pil.size, "gmask:", garment_mask_pil.size, flush=True ) ip_model = IPAdapterXL( pipe, image_encoder_dir, ip_ckpt, device, mask_pil, person_pil, content_scale=0.3, style_scale=0.5, garment_images=garment_pil, garment_mask=garment_mask_pil, ) if device == "cuda": pipe.to(dtype=torch.float32) try: for _, proc in pipe.unet.attn_processors.items(): proc.to(dtype=torch.float32) except Exception: pass style_img = Image.open(paths.style_path).convert("RGB") with torch.inference_mode(): images = ip_model.generate( pil_image=style_img, image=person_pil, control_image=depth_map, strength=1.0, num_samples=1, num_inference_steps=int(steps), shape_prompt="", prompt=prompt or "", num=0, scale=None, controlnet_conditioning_scale=0.7, guidance_scale=7.5, ) save_cropped(images, paths.output_path) return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil def set_seed(seed: int): if seed is None or seed < 0: return np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed): print("[UI] infer_web called", flush=True) if person_fp is None or sketch_fp is None or style_fp is None: raise gr.Error("person / sketch / style 이미지를 모두 업로드해야 합니다.") set_seed(int(seed) if seed is not None else -1) tmp_dir = tempfile.mkdtemp(prefix="vista_demo_") out_path = os.path.join(tmp_dir, "result.png") paths = Paths( person_path=person_fp, depth_path=sketch_fp, style_path=style_fp, output_path=out_path, ) _, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil = run_one( paths, prompt=prompt, steps=int(steps) ) out_img = Image.open(out_path).convert("RGB") return out_img, out_path, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo: gr.Markdown("## VISTA Demo\nperson / sketch(guide) / style 입력으로 결과를 생성합니다.") with gr.Row(): person_in = gr.Image(label="Person Image", type="filepath") sketch_in = gr.Image(label="Sketch / Guide (depth_path)", type="filepath") style_in = gr.Image(label="Style Image", type="filepath") with gr.Row(): prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2) steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps") seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0) run_btn = gr.Button("Run") out_img = gr.Image(label="Output", type="pil") out_file = gr.File(label="Download result.png") gr.Markdown("### Debug Visualizations (mask/depth/etc)") with gr.Row(): dbg_mask = gr.Image(label="mask_pil", type="pil") dbg_depth = gr.Image(label="depth_map", type="pil") with gr.Row(): dbg_person = gr.Image(label="person_pil", type="pil") dbg_garment = gr.Image(label="garment_pil", type="pil") dbg_gmask = gr.Image(label="garment_mask_pil", type="pil") run_btn.click( fn=infer_web, inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in], outputs=[out_img, out_file, dbg_mask, dbg_depth, dbg_person, dbg_garment, dbg_gmask], ) demo.queue() if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)