import os
import sys
import glob
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

print("[BOOT] ROOT =", ROOT, flush=True)
print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
import tempfile
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, Tuple, List, Dict

import gradio as gr
import torch
import numpy as np
import cv2
import imageio
from PIL import Image, ImageOps
from transformers import pipeline
from huggingface_hub import hf_hub_download

import diffusers3
print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers3.models.controlnet import ControlNetModel
from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img import (
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from ip_adapter import IPAdapterXL

from preprocess.simple_extractor import run as run_simple_extractor
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"

ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
ASSETS_REPO_TYPE = "dataset"
# Depth-estimation pipeline, loaded once at import time and reused per request.
depth_estimator = pipeline("depth-estimation")
def asset_path(relpath: str) -> str:
    return hf_hub_download(
        repo_id=ASSETS_REPO,
        repo_type=ASSETS_REPO_TYPE,
        filename=relpath,
    )
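# Example: asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin") fetches that file
# from the assets dataset repo (cached locally by huggingface_hub) and returns
# its local path.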
@lru_cache(maxsize=1)
def get_assets():
    print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)

    image_encoder_weight = asset_path("image_encoder/model.safetensors")
    _ = asset_path("image_encoder/config.json")
    image_encoder_dir = os.path.dirname(image_encoder_weight)

    ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
    schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")

    print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True)
    print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True)
    print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True)
    return image_encoder_dir, ip_ckpt, schp_ckpt
def _is_image_file(p: str) -> bool:
    ext = os.path.splitext(p.lower())[1]
    return ext in (".png", ".jpg", ".jpeg", ".webp")
def build_ui_example_lists(root_dir: str = ROOT) -> Dict[str, List[str]]:
    """
    Returns a dict of example filepaths:
    - persons : {root}/examples/person/*
    - styles  : {root}/examples/style/*
    - sketches: {root}/examples/sketch/*  (optional)
    """
    person_dir = os.path.join(root_dir, "examples", "person")
    style_dir = os.path.join(root_dir, "examples", "style")
    sketch_dir = os.path.join(root_dir, "examples", "sketch")

    persons = [p for p in sorted(glob.glob(os.path.join(person_dir, "*"))) if _is_image_file(p)]
    styles = [p for p in sorted(glob.glob(os.path.join(style_dir, "*"))) if _is_image_file(p)]
    sketches = [p for p in sorted(glob.glob(os.path.join(sketch_dir, "*"))) if _is_image_file(p)]

    return {"persons": persons, "styles": styles, "sketches": sketches}
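# Missing example folders are harmless: glob returns an empty list and the
# corresponding gr.Examples block is simply skipped in the UI below.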
DEFAULT_STEPS = 40
DEBUG_SAVE = False

# Working height/width, set per request in run_one() from the person image.
H: Optional[int] = None
W: Optional[int] = None
@dataclass
class Paths:
    person_path: str
    depth_path: Optional[str]
    style_path: Optional[str]
    output_path: str
def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
    img = cv2.imread(path, flag)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
    return img
def _pad_or_crop_to_width_np(arr: np.ndarray, target_width: int, pad_value):
    """
    arr: HxWxC or HxW
    Center-crop, or pad left/right (possibly asymmetrically), so the width
    matches target_width exactly.
    """
    if arr.ndim not in (2, 3):
        raise ValueError(f"arr must be 2D or 3D, got shape={arr.shape}")

    w = arr.shape[1]

    if w == target_width:
        return arr

    if w > target_width:
        left = (w - target_width) // 2
        return arr[:, left:left + target_width]

    total = target_width - w
    left = total // 2
    right = total - left
    return cv2.copyMakeBorder(
        arr, 0, 0, left, right,
        borderType=cv2.BORDER_CONSTANT,
        value=pad_value,
    )
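# Example (illustrative, shapes only): an 800px-wide RGB image is padded
# symmetrically up to width 1024:
#   _pad_or_crop_to_width_np(np.zeros((1024, 800, 3), np.uint8), 1024,
#                            pad_value=[255, 255, 255]).shape == (1024, 1024, 3)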
def apply_parsing_white_mask_to_person_cv2(
    person_pil: Image.Image,
    parsing_img: Image.Image,
) -> np.ndarray:
    """Keep person pixels where the parsing mask is white; fill the rest with white. Returns BGR."""
    person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
    mask = np.array(parsing_img.convert("L"), dtype=np.uint8)

    if mask.shape[:2] != person_rgb.shape[:2]:
        mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)

    white_mask = (mask == 255)
    result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
    result_rgb[white_mask] = person_rgb[white_mask]
    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    return result_bgr
def remove_small_white_components(
    parsing_img: Image.Image,
    *,
    white_threshold: int = 128,
    min_white_area: int = 150,
    use_open: bool = False,
    open_ksize: int = 3,
    morph_iters: int = 1,
) -> Image.Image:
    """
    - Binarize with white as foreground.
    - Remove only the small white blobs via connected components.
    - Optionally apply a very mild OPEN to drop tiny dots/spurs
      (CLOSE is not used, since it would grow the white region).
    """
    if not isinstance(parsing_img, Image.Image):
        raise TypeError("parsing_img must be a PIL.Image.Image")

    arr = np.array(parsing_img.convert("L"), dtype=np.uint8)
    mask = np.where(arr >= int(white_threshold), 255, 0).astype(np.uint8)

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    keep = np.zeros_like(mask)
    for lab in range(1, num_labels):
        area = int(stats[lab, cv2.CC_STAT_AREA])
        if area >= int(min_white_area):
            keep[labels == lab] = 255
    mask = keep

    if use_open and int(open_ksize) > 1:
        k = int(open_ksize)
        if k % 2 == 0:
            k += 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=int(morph_iters))

    return Image.fromarray(mask, mode="L")
def compute_hw_from_person(person_path: str):
    img = _imread_or_raise(person_path)
    orig_h, orig_w = img.shape[:2]
    scale = 1024.0 / float(orig_h)
    new_h = 1024
    new_w = int(round(orig_w * scale))
    if new_w > 1024:
        new_w = 1024
    return new_h, new_w
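# Example: a 2048x1536 (HxW) person photo yields H=1024, W=768; unusually wide
# inputs are clamped to W=1024.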
def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set.")
    img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
    img = cv2.bitwise_not(img)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
    _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled = np.zeros_like(binary)
    cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED)
    filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(filled_rgb)
def _resize_pil_nearest(img: Image.Image, size_wh: Tuple[int, int], *, force_mode: Optional[str] = None) -> Image.Image:
    """
    Resize a PIL image to (W, H) using INTER_NEAREST (safe for masks).
    size_wh: (width, height)
    """
    w, h = int(size_wh[0]), int(size_wh[1])
    if force_mode is not None:
        img = img.convert(force_mode)
    arr = np.array(img, dtype=np.uint8)

    if arr.ndim == 2:
        resized = cv2.resize(arr, (w, h), interpolation=cv2.INTER_NEAREST)
        return Image.fromarray(resized, mode="L")
    elif arr.ndim == 3 and arr.shape[2] == 3:
        resized = cv2.resize(arr, (w, h), interpolation=cv2.INTER_NEAREST)
        return Image.fromarray(resized, mode="RGB")
    else:
        raise ValueError(f"Unsupported image array shape: {arr.shape}")
def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image:
    a = np.array(img1.convert("RGB"), dtype=np.uint8)
    b = np.array(img2.convert("RGB"), dtype=np.uint8)

    if a.shape[:2] != b.shape[:2]:
        b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_NEAREST)

    white_a = np.all(a == 255, axis=-1)
    white_b = np.all(b == 255, axis=-1)
    out = a.copy()
    out[white_a | white_b] = 255
    return Image.fromarray(out, mode="RGB")
def preprocess_mask(mask_img: Image.Image) -> Image.Image:
    global H, W
    m = np.array(mask_img.convert("L"), dtype=np.uint8)

    if (H is not None) and (W is not None):
        m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)

    _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)

    target_width = 1024
    m = _pad_or_crop_to_width_np(m, target_width, pad_value=0)

    # Dilate so the mask extends slightly beyond the garment boundary.
    kernel = np.ones((12, 12), np.uint8)
    m = cv2.dilate(m, kernel, iterations=1)

    if DEBUG_SAVE:
        cv2.imwrite("mask_final_1024.png", m)

    return Image.fromarray(m, mode="L").convert("RGB")
def make_depth(depth_path: str) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    depth_img = _imread_or_raise(depth_path, cv2.IMREAD_GRAYSCALE)

    _, depth_bin = cv2.threshold(depth_img, 127, 255, cv2.THRESH_BINARY)

    # Fill the outer contours of the sketch into a solid silhouette.
    contours, _ = cv2.findContours(depth_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled_depth = np.zeros_like(depth_bin)
    cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED)

    filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_NEAREST)
    _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY)

    filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)

    # Lightly erode the silhouette so it stays inside the sketch outline.
    erode_ksize = 5
    erode_iters = 1
    if erode_ksize is not None and erode_ksize > 1 and erode_iters > 0:
        if erode_ksize % 2 == 0:
            erode_ksize += 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_ksize, erode_ksize))
        filled_depth = cv2.erode(filled_depth, kernel, iterations=erode_iters)
        _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY)

    # White-on-black -> black-on-white, then run monocular depth estimation.
    inverted_image = ImageOps.invert(Image.fromarray(filled_depth))

    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]

    return image_depth
def _edges_from_parsing(parsing_img: Image.Image) -> np.ndarray:
    m = np.array(parsing_img.convert("L"), dtype=np.uint8)
    _, m_bin = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
    edges = cv2.Canny(m_bin, 50, 150)
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    return edges.astype(np.uint8)
def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    depth_img = _edges_from_parsing(parsing_img)
    contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    filled_depth = depth_img.copy()
    cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED)

    filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
    filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)

    inverted_image = ImageOps.invert(Image.fromarray(filled_depth))

    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]

    if DEBUG_SAVE:
        image_depth.save("depth.png")

    return image_depth
def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
    target_h, target_w = 1024, 768
    h, w = arr.shape[:2]
    if h != target_h:
        arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
        h, w = arr.shape[:2]
    if w < target_w:
        pad = (target_w - w) // 2
        arr = cv2.copyMakeBorder(arr, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=[255, 255, 255])
        w = arr.shape[1]
    left = (w - target_w) // 2
    return arr[:, left:left + target_w]
def save_cropped(imgs, out_path: str):
    np_imgs = [np.asarray(im) for im in imgs]
    cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs]
    out = np.concatenate(cropped, axis=1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    imageio.imsave(out_path, out)


def _read_hw(path: str) -> Tuple[int, int]:
    img = _imread_or_raise(path)
    h, w = img.shape[:2]
    return h, w
def _center_crop_lr_to_aspect(arr: np.ndarray, target_aspect: float, *, pad_value=255) -> np.ndarray:
    """
    arr: HxWxC (RGB) or HxW
    target_aspect = target_w / target_h
    - Height (H) is preserved.
    - Crop equally from left/right to reach target_aspect.
    - If the current width falls short, pad left/right instead.
    """
    if arr.ndim == 2:
        arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)

    h, w = arr.shape[:2]
    if h <= 0 or w <= 0:
        raise ValueError(f"Invalid image shape: {arr.shape}")

    desired_w = int(round(h * float(target_aspect)))
    if desired_w <= 0:
        desired_w = 1

    if w >= desired_w:
        left = (w - desired_w) // 2
        right = left + desired_w
        return arr[:, left:right]

    total = desired_w - w
    left_pad = total // 2
    right_pad = total - left_pad
    return cv2.copyMakeBorder(
        arr,
        0, 0,
        left_pad, right_pad,
        borderType=cv2.BORDER_CONSTANT,
        value=[pad_value, pad_value, pad_value],
    )
def save_output_match_person(imgs, out_path: str, person_path: str):
    """
    - Center-crop the output imgs (usually length 1) left/right to the person
      image's original aspect ratio.
    - Resize to the person's original (W, H).
    - If imgs contains several images, concatenate them horizontally before saving.
    """
    person_h, person_w = _read_hw(person_path)
    target_aspect = float(person_w) / float(person_h)

    np_imgs = []
    for im in imgs:
        if isinstance(im, Image.Image):
            arr = np.asarray(im.convert("RGB"), dtype=np.uint8)
        else:
            arr = np.asarray(im, dtype=np.uint8)
            if arr.ndim == 2:
                arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)

        cropped = _center_crop_lr_to_aspect(arr, target_aspect, pad_value=255)
        resized = cv2.resize(cropped, (person_w, person_h), interpolation=cv2.INTER_AREA)
        np_imgs.append(resized)

    out = np.concatenate(np_imgs, axis=1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    imageio.imsave(out_path, out)
@lru_cache(maxsize=1)
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    print(f"[PIPE] device={device}, dtype={dtype}", flush=True)

    controlnet = ControlNetModel.from_pretrained(
        CONTROLNET_ID,
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)

    vae = AutoencoderKL.from_pretrained(
        BASE_MODEL_ID,
        subfolder="vae",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)

    unet = UNet2DConditionModel.from_pretrained(
        BASE_MODEL_ID,
        subfolder="unet",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)

    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        BASE_MODEL_ID,
        controlnet=controlnet,
        vae=vae,
        unet=unet,
        torch_dtype=dtype,
        use_safetensors=True,
        add_watermarker=False,
    ).to(device)

    if device == "cuda":
        try:
            pipe.vae.to(dtype=dtype)
            if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
                pipe.vae.config.force_upcast = False
        except Exception as e:
            print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)

    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_attention_slicing()
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print("[PIPE] xformers not enabled:", repr(e), flush=True)

    return pipe, device, dtype
# Map UI category names to the parsing extractor's category names.
_UI_TO_EXTRACTOR_CATEGORY = {
    "Upper-body": "Upper-cloth",
    "Lower-body": "Bottom",
    "Dress": "Dress",
}
def _has_valid_file(path: Optional[str]) -> bool:
    return (
        isinstance(path, str)
        and len(path) > 0
        and os.path.exists(path)
    )
def _resolve_content_style_scales(style_present: bool, prompt_present: bool) -> Tuple[float, float]:
    """
    Policy for (content_scale, style_scale):
    - no style image        -> (0.0, 0.0)
    - style but no prompt   -> (0.35, 0.65)
    - style and prompt      -> (0.25, 0.5)
    """
    if not style_present:
        return 0.0, 0.0
    if not prompt_present:
        return 0.35, 0.65
    return 0.25, 0.5
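# Note: when no style image is provided, run_one() below falls back to using
# the garment crop itself as the style reference, with both scales forced to 0.0.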
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
    global H, W
    pipe, device, _dtype = get_pipe_and_device()
    image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()

    H, W = compute_hw_from_person(paths.person_path)

    extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress")

    res = run_simple_extractor(
        category=extractor_category,
        input_path=os.path.abspath(paths.person_path),
        model_restore=schp_ckpt,
    )
    parsing_img = res["images"][0] if res.get("images") else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")

    parsing_img = remove_small_white_components(
        parsing_img,
        white_threshold=128,
        min_white_area=150,
        use_open=False,
    )

    if parsing_img.size != (W, H):
        parsing_img = _resize_pil_nearest(parsing_img, (W, H), force_mode="L")

    use_depth_path = _has_valid_file(paths.depth_path)

    if use_depth_path:
        sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    else:
        sketch_area = parsing_img.convert("RGB")

    merged_img = merge_white_regions_or(parsing_img, sketch_area)
    mask_pil = preprocess_mask(merged_img)

    # Person image resized to (W, H) and padded/cropped to width 1024.
    person_bgr = _imread_or_raise(paths.person_path)
    person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
    person_bgr = _pad_or_crop_to_width_np(person_bgr, 1024, pad_value=[255, 255, 255])
    person_rgb = cv2.cvtColor(person_bgr, cv2.COLOR_BGR2RGB)
    person_pil = Image.fromarray(person_rgb)

    # Depth control image: from the user sketch if given, else from parsing edges.
    if use_depth_path:
        depth_map = make_depth(paths.depth_path)
    else:
        depth_map = make_depth_from_parsing_edges(parsing_img)

    # Garment crop: person pixels under the parsing mask on a white background.
    person_full = Image.open(paths.person_path).convert("RGB")
    garment_bgr = apply_parsing_white_mask_to_person_cv2(person_full, parsing_img)
    garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
    garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
    garment_pil = Image.fromarray(garment_rgb)

    # Garment mask at the same geometry (black padding).
    gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
    gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
    gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
    gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0])
    garment_mask_pil = Image.fromarray(gm)

    style_present = _has_valid_file(paths.style_path)
    prompt_present = (prompt is not None) and (str(prompt).strip() != "")
    content_scale, style_scale = _resolve_content_style_scales(style_present, prompt_present)

    print(
        "[SIZE] person:", person_pil.size,
        "mask:", mask_pil.size,
        "depth:", depth_map.size,
        "garment:", garment_pil.size,
        "gmask:", garment_mask_pil.size,
        "ui_category:", category,
        "extractor_category:", extractor_category,
        "style_present:", style_present,
        "prompt_present:", prompt_present,
        "content_scale:", content_scale,
        "style_scale:", style_scale,
        flush=True,
    )

    ip_model = IPAdapterXL(
        pipe,
        image_encoder_dir,
        ip_ckpt,
        device,
        mask_pil,
        person_pil,
        content_scale=content_scale,
        style_scale=style_scale,
        garment_images=garment_pil,
        garment_mask=garment_mask_pil,
    )

    if device == "cuda":
        pipe.to(dtype=torch.float32)
        try:
            for _, proc in pipe.unet.attn_processors.items():
                proc.to(dtype=torch.float32)
        except Exception:
            pass

    # Fall back to the garment crop as the style reference when no style image is given.
    if style_present:
        style_img = Image.open(paths.style_path).convert("RGB")
    else:
        style_img = garment_pil

    # Prepend the garment category to the user prompt.
    if prompt is not None and str(prompt).strip() != "":
        prompt = extractor_category + " with " + str(prompt).strip()
    else:
        prompt = extractor_category

    with torch.inference_mode():
        images = ip_model.generate(
            pil_image=style_img,
            image=person_pil,
            control_image=depth_map,
            strength=1.0,
            num_samples=1,
            num_inference_steps=int(steps),
            shape_prompt="",
            prompt=prompt or "",
            num=0,
            scale=None,
            controlnet_conditioning_scale=0.7,
            guidance_scale=7.5,
        )

    save_output_match_person(images, paths.output_path, paths.person_path)
    return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
def set_seed(seed: int):
    if seed is None or seed < 0:
        return
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
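# With the UI default seed of -1, set_seed() is a no-op, so each run samples a
# fresh random result.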
def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
    print("[UI] infer_web called", flush=True)

    if person_fp is None:
        raise gr.Error("A person image is required (style/sketch are optional).")

    if category not in ("Upper-body", "Lower-body", "Dress"):
        raise gr.Error(f"Invalid category: {category}")

    set_seed(int(seed) if seed is not None else -1)

    tmp_dir = tempfile.mkdtemp(prefix="vista_demo_")
    out_path = os.path.join(tmp_dir, "result.png")

    paths = Paths(
        person_path=person_fp,
        depth_path=sketch_fp,
        style_path=style_fp,
        output_path=out_path,
    )

    # run_one saves the final image to out_path; its returned debug artifacts
    # are unused here, so return exactly the two values wired to click() below.
    run_one(paths, prompt=prompt, steps=int(steps), category=category)

    out_img = Image.open(out_path).convert("RGB")
    return out_img, out_path
with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
    gr.Markdown(
        "The first run can take a while because the models have to load.<br>"
        "Make sure the selected category matches the garment type you want to try on.",
        elem_classes="tight_md",
    )

    category_toggle = gr.Radio(
        choices=["Dress", "Upper-body", "Lower-body"],
        value="Dress",
        label="Category",
        interactive=True,
    )

    ex = build_ui_example_lists(ROOT)
    person_examples = [[p] for p in ex["persons"]]
    style_examples = [[p] for p in ex["styles"]]
    sketch_examples = [[p] for p in ex["sketches"]]

    with gr.Row():
        with gr.Column(scale=1):
            person_in = gr.Image(label="Person Image (required)", type="filepath")
            if person_examples:
                gr.Markdown("#### Examples")
                gr.Examples(
                    examples=person_examples,
                    inputs=[person_in],
                    examples_per_page=8,
                )

        with gr.Column(scale=1):
            style_in = gr.Image(label="Style Image (optional)", type="filepath")
            if style_examples:
                gr.Markdown("#### Examples")
                gr.Examples(
                    examples=style_examples,
                    inputs=[style_in],
                    examples_per_page=8,
                )

        with gr.Column(scale=1):
            out_img = gr.Image(label="Output", type="pil")

    with gr.Accordion("Sketch / Guide (optional)", open=False):
        sketch_in = gr.Image(
            label="Sketch / Guide (pair it with the matching person example: person 1 ↔ sketch 1). The sketch must be aligned with the person's body.",
            type="filepath",
        )
        if sketch_examples:
            gr.Markdown("#### Examples")
            gr.Examples(
                examples=sketch_examples,
                inputs=[sketch_in],
                examples_per_page=8,
            )

    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            value="",
            placeholder="e.g. crystal, lace, button, …",
            lines=2,
        )
        steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
        seed_in = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)

    run_btn = gr.Button("Run")
    out_file = gr.File(label="Download result.png")

    run_btn.click(
        fn=infer_web,
        inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in, category_toggle],
        outputs=[out_img, out_file],
    )
demo.queue()
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)