# app.py — VISTA demo (FEAT)
# Last change: "update UI" by ssoxye (commit 800c0e5)
import os
import sys
import glob
# ---------------------------------------------------------
# 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
# ---------------------------------------------------------
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
print("[BOOT] ROOT =", ROOT, flush=True)
print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
import tempfile
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, Tuple, List, Dict
import gradio as gr
import torch
import numpy as np
import cv2
import imageio
from PIL import Image, ImageOps
from transformers import pipeline
from huggingface_hub import hf_hub_download
import diffusers3
print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers3.models.controlnet import ControlNetModel
from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
StableDiffusionXLControlNetImg2ImgPipeline,
)
from ip_adapter import IPAdapterXL
# extractor
from preprocess.simple_extractor import run as run_simple_extractor
# =========================
# HF Hub repo ids
# =========================
# SDXL base checkpoint and the depth ControlNet used for conditioning.
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"
# assets dataset repo (overridable via the ASSETS_REPO env var)
ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
ASSETS_REPO_TYPE = "dataset"
# NOTE: instantiated at import time, so the first import is slow (model download).
depth_estimator = pipeline("depth-estimation")
def asset_path(relpath: str) -> str:
    """Resolve *relpath* inside the assets dataset repo to a locally cached file."""
    return hf_hub_download(repo_id=ASSETS_REPO, filename=relpath, repo_type=ASSETS_REPO_TYPE)
@lru_cache(maxsize=1)
def get_assets():
    """Download (once) and cache the external checkpoints this app needs.

    Returns:
        (image_encoder_dir, ip_ckpt, schp_ckpt) as local filesystem paths.
    """
    print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)
    encoder_weight = asset_path("image_encoder/model.safetensors")
    asset_path("image_encoder/config.json")  # fetched for its side effect only
    encoder_dir = os.path.dirname(encoder_weight)
    ip_adapter_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
    schp_checkpoint = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
    for tag, value in (
        ("image_encoder_dir", encoder_dir),
        ("ip_ckpt", ip_adapter_ckpt),
        ("schp_ckpt", schp_checkpoint),
    ):
        print(f"[ASSETS] {tag} =", value, flush=True)
    return encoder_dir, ip_adapter_ckpt, schp_checkpoint
# =========================
# Example assets for the Gradio UI (person / style / sketch lists kept separate)
# =========================
def _is_image_file(p: str) -> bool:
ext = os.path.splitext(p.lower())[1]
return ext in (".png", ".jpg", ".jpeg", ".webp")
def build_ui_example_lists(root_dir: str = ROOT) -> Dict[str, List[str]]:
    """Collect example image paths for the Gradio UI.

    Scans {root_dir}/examples/{person,style,sketch}, keeping only files with
    image extensions; each list is sorted by filename. Missing directories
    simply yield empty lists.
    """
    def _images_in(subdir: str) -> List[str]:
        # one-liner glob + filter, shared by all three categories
        pattern = os.path.join(root_dir, "examples", subdir, "*")
        return [p for p in sorted(glob.glob(pattern)) if _is_image_file(p)]

    return {
        "persons": _images_in("person"),
        "styles": _images_in("style"),
        "sketches": _images_in("sketch"),
    }
DEFAULT_STEPS = 40  # default diffusion step count shown in the UI slider
DEBUG_SAVE = False  # when True, intermediate masks/depth maps are written to disk
# Working canvas size, set per-request by run_one() via compute_hw_from_person().
# Several helpers read these globals and raise if run_one() has not run yet.
H: Optional[int] = None
W: Optional[int] = None
@dataclass
class Paths:
    """File paths for one try-on request."""
    person_path: str            # required person photo
    depth_path: Optional[str]   # sketch/guide image (optional)
    style_path: Optional[str]   # style reference image (optional)
    output_path: str            # where the result image is written
def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
    """cv2.imread that raises FileNotFoundError instead of silently returning None."""
    loaded = cv2.imread(path, flag)
    if loaded is None:
        raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
    return loaded
def _pad_or_crop_to_width_np(arr: np.ndarray, target_width: int, pad_value):
"""
arr: HxWxC or HxW
target_width로 center crop 또는 좌/우 padding(비대칭 포함)해서 정확히 맞춤.
"""
if arr.ndim not in (2, 3):
raise ValueError(f"arr must be 2D or 3D, got shape={arr.shape}")
h = arr.shape[0]
w = arr.shape[1]
if w == target_width:
return arr
if w > target_width:
left = (w - target_width) // 2
return arr[:, left:left + target_width] if arr.ndim == 2 else arr[:, left:left + target_width, :]
# w < target_width: pad
total = target_width - w
left = total // 2
right = total - left # ✅ remainder를 오른쪽이 먹어서 항상 정확히 target_width
if arr.ndim == 2:
return cv2.copyMakeBorder(
arr, 0, 0, left, right,
borderType=cv2.BORDER_CONSTANT,
value=pad_value,
)
else:
return cv2.copyMakeBorder(
arr, 0, 0, left, right,
borderType=cv2.BORDER_CONSTANT,
value=pad_value,
)
def apply_parsing_white_mask_to_person_cv2(
    person_pil: Image.Image,
    parsing_img: Image.Image
) -> np.ndarray:
    """Keep person pixels only where the parsing mask equals 255; white elsewhere.

    The mask is nearest-resized to the person's size if needed. Returns a BGR
    uint8 array (OpenCV channel order).
    """
    rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
    parsing = np.array(parsing_img.convert("L"), dtype=np.uint8)
    if parsing.shape[:2] != rgb.shape[:2]:
        parsing = cv2.resize(parsing, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
    keep = parsing == 255
    out_rgb = np.full_like(rgb, 255, dtype=np.uint8)
    out_rgb[keep] = rgb[keep]
    return cv2.cvtColor(out_rgb, cv2.COLOR_RGB2BGR)
def remove_small_white_components(
    parsing_img: Image.Image,
    *,
    white_threshold: int = 128,
    min_white_area: int = 150,
    use_open: bool = False,
    open_ksize: int = 3,
    morph_iters: int = 1,
) -> Image.Image:
    """Clean a parsing mask by dropping tiny white blobs.

    Steps:
      1) binarize (pixels >= white_threshold become white/foreground)
      2) remove every 8-connected white component smaller than min_white_area
      3) optionally run a gentle MORPH_OPEN to erase residual specks/spikes
         (OPEN never grows the white area — CLOSE is deliberately not used)
    """
    if not isinstance(parsing_img, Image.Image):
        raise TypeError("parsing_img must be a PIL.Image.Image")
    gray = np.array(parsing_img.convert("L"), dtype=np.uint8)
    mask = np.where(gray >= int(white_threshold), 255, 0).astype(np.uint8)
    # 1) keep only white components whose pixel area clears the threshold
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    survivors = [
        lab
        for lab in range(1, num_labels)
        if int(stats[lab, cv2.CC_STAT_AREA]) >= int(min_white_area)
    ]
    mask = np.where(np.isin(labels, survivors), 255, 0).astype(np.uint8)
    # 2) optional light OPEN pass
    if use_open and int(open_ksize) > 1:
        ksize = int(open_ksize)
        if ksize % 2 == 0:
            ksize += 1  # structuring elements need odd sizes
        elem = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, elem, iterations=int(morph_iters))
    return Image.fromarray(mask, mode="L")
def compute_hw_from_person(person_path: str):
    """Scale the person image so its height is 1024; the width follows the
    aspect ratio but is clamped to at most 1024. Returns (new_h, new_w)."""
    h0, w0 = _imread_or_raise(person_path).shape[:2]
    scaled_w = int(round(w0 * (1024.0 / float(h0))))
    return 1024, min(scaled_w, 1024)
def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
    """Turn a sketch/guide image into a solid filled region mask (RGB PIL).

    Reads the sketch as grayscale, inverts it, nearest-resizes to the global
    (W, H) canvas, thresholds, then fills the outer contours so enclosed areas
    become solid white. Requires the module globals H/W (set by run_one()).

    Raises:
        RuntimeError: if H/W have not been set yet.
    """
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set.")
    img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
    img = cv2.bitwise_not(img)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
    # NOTE(review): after the bitwise_not above, THRESH_BINARY_INV selects the
    # originally-light pixels as foreground — confirm this polarity is intended
    # for the expected sketch style (dark strokes on white background).
    _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled = np.zeros_like(binary)
    cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED)
    filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(filled_rgb)
def _resize_pil_nearest(img: Image.Image, size_wh: Tuple[int, int], *, force_mode: Optional[str] = None) -> Image.Image:
    """Nearest-neighbour resize to (width, height) — safe for masks, since no
    interpolation blending is introduced.

    Raises:
        ValueError: for array shapes other than HxW or HxWx3.
    """
    width, height = int(size_wh[0]), int(size_wh[1])
    if force_mode is not None:
        img = img.convert(force_mode)
    pixels = np.array(img, dtype=np.uint8)
    if pixels.ndim == 2:
        mode = "L"
    elif pixels.ndim == 3 and pixels.shape[2] == 3:
        mode = "RGB"
    else:
        raise ValueError(f"Unsupported image array shape: {pixels.shape}")
    resized = cv2.resize(pixels, (width, height), interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized, mode=mode)
def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image:
    """OR-combine the pure-white (255,255,255) pixels of two RGB images.

    img2 is nearest-resized to img1's size when the shapes differ, so the
    element-wise comparison can never hit a numpy broadcasting error.
    """
    base = np.array(img1.convert("RGB"), dtype=np.uint8)
    other = np.array(img2.convert("RGB"), dtype=np.uint8)
    if base.shape[:2] != other.shape[:2]:
        other = cv2.resize(other, (base.shape[1], base.shape[0]), interpolation=cv2.INTER_NEAREST)
    is_white = np.all(base == 255, axis=-1) | np.all(other == 255, axis=-1)
    merged = base.copy()
    merged[is_white] = 255
    return Image.fromarray(merged, mode="RGB")
def preprocess_mask(mask_img: Image.Image) -> Image.Image:
    """Normalize a merged garment mask for the pipeline.

    Nearest-resizes to the global (W, H) canvas when it is set, binarizes at
    127, pads/crops the width to exactly 1024, then dilates with a 12x12
    kernel so the mask extends slightly past garment boundaries. Returns an
    RGB PIL image.
    """
    global H, W
    m = np.array(mask_img.convert("L"), dtype=np.uint8)
    if (H is not None) and (W is not None):
        m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
    _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
    target_width = 1024
    m = _pad_or_crop_to_width_np(m, target_width, pad_value=0)
    # grow the mask a bit to comfortably cover garment edges
    kernel = np.ones((12, 12), np.uint8)
    m = cv2.dilate(m, kernel, iterations=1)
    if DEBUG_SAVE:
        cv2.imwrite("mask_final_1024.png", m)
    return Image.fromarray(m, mode="L").convert("RGB")
# --- Legacy make_depth (kept for reference; superseded by the erosion variant below) ---
# def make_depth(depth_path: str) -> Image.Image:
# global H, W
# if H is None or W is None:
# raise RuntimeError("Global H/W not set. Call run_one() first.")
# depth_img = _imread_or_raise(depth_path, 0) # grayscale
# # (선택) 입력이 완전한 0/255가 아니라면 이진화로 고정
# _, depth_bin = cv2.threshold(depth_img, 127, 255, cv2.THRESH_BINARY)
# # 컨투어 채우기가 "두꺼워 보임"의 원인일 수도 있어, 유지/제거 선택 가능
# # 1) 채우기 유지 (holes 메우는 목적이라면)
# contours, _ = cv2.findContours(depth_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# filled_depth = np.zeros_like(depth_bin)
# cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED)
# # 2) 채우기 제거하고 싶으면 위 3줄 대신 이걸 사용:
# # filled_depth = depth_bin
# # ✅ 마스크 리사이즈는 NEAREST (경계 번짐/팽창 느낌 방지)
# filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_NEAREST)
# # (선택) 리사이즈 후에도 0/255 강제
# _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY)
# filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
# inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
# with torch.inference_mode():
# image_depth = depth_estimator(inverted_image)["depth"]
# if DEBUG_SAVE:
# image_depth.save("depth.png")
# return image_depth
def make_depth(depth_path: str) -> Image.Image:
    """Build a ControlNet depth map from a sketch/guide image.

    Binarizes the sketch, fills its outer contours, resizes to the global
    (W, H) canvas, pads/crops to 1024 wide, erodes slightly, inverts, and runs
    the transformers depth-estimation pipeline on the result.

    Raises:
        RuntimeError: if the module globals H/W are unset.
    """
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")
    depth_img = _imread_or_raise(depth_path, 0)  # grayscale
    # Force a clean 0/255 mask in case the input is not strictly binary.
    _, depth_bin = cv2.threshold(depth_img, 127, 255, cv2.THRESH_BINARY)
    # Fill outer contours (closes holes inside the shape).
    contours, _ = cv2.findContours(depth_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled_depth = np.zeros_like(depth_bin)
    cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED)
    # NEAREST for the mask resize (avoids gray blending/bloating at borders).
    filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_NEAREST)
    # Re-binarize after the resize, just in case.
    _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY)
    filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
    # Erode (the opposite of dilate): shrink the white region a little.
    erode_ksize = 5  # 3/5/7... larger shrinks more
    erode_iters = 1  # 1-2 recommended
    if erode_ksize is not None and erode_ksize > 1 and erode_iters > 0:
        if erode_ksize % 2 == 0:
            erode_ksize += 1  # kernel size must be odd
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_ksize, erode_ksize))
        filled_depth = cv2.erode(filled_depth, kernel, iterations=erode_iters)
    # Safety: binarize once more after the morphology.
    _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY)
    inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]
    return image_depth
def _edges_from_parsing(parsing_img: Image.Image) -> np.ndarray:
    """Canny edges of the binarized parsing mask, thickened by one 3x3 dilation."""
    gray = np.array(parsing_img.convert("L"), dtype=np.uint8)
    binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
    edge_map = cv2.Canny(binary, 50, 150)
    edge_map = cv2.dilate(edge_map, np.ones((3, 3), np.uint8), iterations=1)
    return edge_map.astype(np.uint8)
def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image:
    """Build a ControlNet depth map from the parsing mask's edges (no sketch).

    Mirrors make_depth(): fill the edge contours, resize to the global (W, H),
    pad/crop to 1024 wide, invert, and run the depth estimator.

    Fix: the mask was resized with INTER_AREA, which blends a binary mask into
    gray values — the rest of this file deliberately uses INTER_NEAREST for
    masks (see make_depth / preprocess_mask). Switched to INTER_NEAREST and
    re-binarized afterwards for consistency.

    Raises:
        RuntimeError: if the module globals H/W are unset.
    """
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")
    edge_mask = _edges_from_parsing(parsing_img)
    contours, _ = cv2.findContours(edge_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled_depth = edge_mask.copy()
    cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
    # NEAREST keeps the mask strictly binary (no gray halo at the borders).
    filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_NEAREST)
    _, filled_depth = cv2.threshold(filled_depth, 127, 255, cv2.THRESH_BINARY)
    filled_depth = _pad_or_crop_to_width_np(filled_depth, 1024, pad_value=0)
    inverted_image = ImageOps.invert(Image.fromarray(filled_depth))
    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]
    if DEBUG_SAVE:
        image_depth.save("depth.png")
    return image_depth
def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
    """Resize height to 1024, then center-crop (or white-pad) width to 768.

    Bug fix: the old code padded `(target_w - w) // 2` on BOTH sides, which
    loses one column whenever the deficit is odd — the width ends up at
    target_w - 1, making `left = -1` in the final slice and producing a wrong
    crop. Padding is now asymmetric (remainder goes to the right), so the
    width always reaches exactly target_w.
    """
    target_h, target_w = 1024, 768
    h, w = arr.shape[:2]
    if h != target_h:
        arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
        h, w = arr.shape[:2]
    if w < target_w:
        deficit = target_w - w
        left_pad = deficit // 2
        right_pad = deficit - left_pad  # odd remainder goes to the right
        arr = cv2.copyMakeBorder(arr, 0, 0, left_pad, right_pad,
                                 cv2.BORDER_CONSTANT, value=[255, 255, 255])
        w = arr.shape[1]
    left = (w - target_w) // 2
    return arr[:, left:left + target_w]
def save_cropped(imgs, out_path: str):
    """Center-crop each image to 768x1024, concat horizontally, write to out_path."""
    tiles = [center_crop_lr_to_768x1024(np.asarray(im)) for im in imgs]
    sheet = np.concatenate(tiles, axis=1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    imageio.imsave(out_path, sheet)
def _read_hw(path: str) -> Tuple[int, int]:
    """(height, width) of the image at *path*; raises if the file is unreadable."""
    return _imread_or_raise(path).shape[:2]  # BGR load, only the shape is used
def _center_crop_lr_to_aspect(arr: np.ndarray, target_aspect: float, *, pad_value=255) -> np.ndarray:
"""
arr: HxWxC (RGB) or HxW
target_aspect = target_w / target_h
- 높이(H)는 유지
- 좌/우를 동일 비율로 crop해서 target_aspect에 맞춤
- 만약 현재 폭이 부족하면 좌/우 padding으로 맞춤
"""
if arr.ndim == 2:
arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
h, w = arr.shape[:2]
if h <= 0 or w <= 0:
raise ValueError(f"Invalid image shape: {arr.shape}")
desired_w = int(round(h * float(target_aspect)))
if desired_w <= 0:
desired_w = 1
# 폭이 충분하면 좌/우 crop
if w >= desired_w:
left = (w - desired_w) // 2
right = left + desired_w
return arr[:, left:right]
# 폭이 부족하면 좌/우 padding (요청은 crop이지만 안전장치)
total = desired_w - w
left_pad = total // 2
right_pad = total - left_pad
return cv2.copyMakeBorder(
arr,
0, 0,
left_pad, right_pad,
borderType=cv2.BORDER_CONSTANT,
value=[pad_value, pad_value, pad_value],
)
def save_output_match_person(imgs, out_path: str, person_path: str):
    """Save pipeline output(s) cropped and resized to match the person photo.

    - Center-crops each output left/right to the person image's aspect ratio
      (white-padding if the output is too narrow).
    - Resizes to the person's original (W, H).
    - Concatenates horizontally when imgs has multiple entries, then writes
      the result to out_path.
    """
    person_h, person_w = _read_hw(person_path)
    target_aspect = float(person_w) / float(person_h)
    np_imgs = []
    for im in imgs:
        if isinstance(im, Image.Image):
            arr = np.asarray(im.convert("RGB"), dtype=np.uint8)
        else:
            # Defensive: accept raw numpy arrays as well.
            arr = np.asarray(im, dtype=np.uint8)
        if arr.ndim == 2:
            arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
        cropped = _center_crop_lr_to_aspect(arr, target_aspect, pad_value=255)
        resized = cv2.resize(cropped, (person_w, person_h), interpolation=cv2.INTER_AREA)
        np_imgs.append(resized)
    out = np.concatenate(np_imgs, axis=1)  # no-op join when there is one image
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    imageio.imsave(out_path, out)
@lru_cache(maxsize=1)
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
    """Build (once) the SDXL ControlNet img2img pipeline.

    Returns (pipe, device, dtype). Cached via lru_cache so the heavy model
    loading happens only on the first call.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32  # full precision even on GPU; run_one() also forces fp32
    print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
    controlnet = ControlNetModel.from_pretrained(
        CONTROLNET_ID,
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        BASE_MODEL_ID,
        subfolder="vae",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        BASE_MODEL_ID,
        subfolder="unet",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        BASE_MODEL_ID,
        controlnet=controlnet,
        vae=vae,
        unet=unet,
        torch_dtype=dtype,
        use_safetensors=True,
        add_watermarker=False,
    ).to(device)
    if device == "cuda":
        # Keep the VAE in the working dtype and disable its automatic fp32 upcast.
        try:
            pipe.vae.to(dtype=dtype)
            if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
                pipe.vae.config.force_upcast = False
        except Exception as e:
            print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_attention_slicing()
    # xformers is optional; fall back silently when unavailable.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print("[PIPE] xformers not enabled:", repr(e), flush=True)
    return pipe, device, dtype
# Maps the UI category labels to the extractor's internal category strings.
_UI_TO_EXTRACTOR_CATEGORY = {
    "Upper-body": "Upper-cloth",
    "Lower-body": "Bottom",
    "Dress": "Dress",
}
def _has_valid_file(path: Optional[str]) -> bool:
return (
path is not None
and isinstance(path, str)
and len(path) > 0
and os.path.exists(path)
)
def _resolve_content_style_scales(style_present: bool, prompt_present: bool) -> Tuple[float, float]:
"""
요구사항:
- style image 없으면: (0.0, 0.0)
- prompt 없으면: (0.4, 0.6)
- 둘 다 있으면: 기존 유지 (0.3, 0.5)
"""
if not style_present:
return 0.0, 0.0
if not prompt_present:
return 0.35, 0.65
return 0.25, 0.5
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
    """Run one try-on generation end to end.

    Pipeline: human parsing (SCHP) -> mask building -> depth control map ->
    IP-Adapter generation -> save output matched to the person photo's size.

    Args:
        paths: input/output file locations (style/sketch may be None).
        prompt: optional free text; the garment category is prefixed to it.
        steps: diffusion step count.
        category: one of the UI labels ("Upper-body", "Lower-body", "Dress").

    Returns:
        (images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil)

    Side effects: sets the module globals H/W and writes paths.output_path.
    """
    global H, W
    pipe, device, _dtype = get_pipe_and_device()
    image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()
    H, W = compute_hw_from_person(paths.person_path)
    extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress")
    # Human parsing to locate the garment region.
    res = run_simple_extractor(
        category=extractor_category,
        input_path=os.path.abspath(paths.person_path),
        model_restore=schp_ckpt,
    )
    parsing_img = res["images"][0] if res.get("images") else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")
    parsing_img = remove_small_white_components(
        parsing_img,
        white_threshold=128,
        min_white_area=150,  # tune between ~30 and ~200 depending on the data
        use_open=False,
    )
    # IMPORTANT: extractor output size can differ from (W,H). Align before OR-merge.
    if parsing_img.size != (W, H):
        parsing_img = _resize_pil_nearest(parsing_img, (W, H), force_mode="L")
    use_depth_path = _has_valid_file(paths.depth_path)
    if use_depth_path:
        sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    else:
        sketch_area = parsing_img.convert("RGB")
    merged_img = merge_white_regions_or(parsing_img, sketch_area)
    mask_pil = preprocess_mask(merged_img)
    # Person image: resize to (W,H), pad/crop the width to 1024, make RGB PIL.
    person_bgr = _imread_or_raise(paths.person_path)
    person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
    person_bgr = _pad_or_crop_to_width_np(person_bgr, 1024, pad_value=[255, 255, 255])
    person_rgb = cv2.cvtColor(person_bgr, cv2.COLOR_BGR2RGB)
    person_pil = Image.fromarray(person_rgb)
    # Depth control image: from the sketch when given, else from parsing edges.
    if use_depth_path:
        depth_map = make_depth(paths.depth_path)
    else:
        depth_map = make_depth_from_parsing_edges(parsing_img)
    # Garment image (key step: force the 1024-wide canvas from here on).
    personn = Image.open(paths.person_path).convert("RGB")
    garment_bgr = apply_parsing_white_mask_to_person_cv2(personn, parsing_img)
    garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
    garment_rgb = _pad_or_crop_to_width_np(garment_rgb, 1024, pad_value=[255, 255, 255])
    garment_pil = Image.fromarray(garment_rgb)
    # Garment mask (padded to the same 1024 width).
    gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
    gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
    gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
    gm = _pad_or_crop_to_width_np(gm, 1024, pad_value=[0, 0, 0])
    garment_mask_pil = Image.fromarray(gm)
    # Content/style scales depend on whether a style image and/or prompt exist.
    style_present = _has_valid_file(paths.style_path)
    prompt_present = (prompt is not None) and (str(prompt).strip() != "")
    content_scale, style_scale = _resolve_content_style_scales(style_present, prompt_present)
    print(
        "[SIZE] person:", person_pil.size,
        "mask:", mask_pil.size,
        "depth:", depth_map.size,
        "garment:", garment_pil.size,
        "gmask:", garment_mask_pil.size,
        "ui_category:", category,
        "extractor_category:", extractor_category,
        "style_present:", style_present,
        "prompt_present:", prompt_present,
        "content_scale:", content_scale,
        "style_scale:", style_scale,
        flush=True
    )
    ip_model = IPAdapterXL(
        pipe,
        image_encoder_dir,
        ip_ckpt,
        device,
        mask_pil,
        person_pil,
        content_scale=content_scale,
        style_scale=style_scale,
        garment_images=garment_pil,
        garment_mask=garment_mask_pil,
    )
    if device == "cuda":
        # Force fp32 end to end (matches the dtype chosen in get_pipe_and_device).
        pipe.to(dtype=torch.float32)
        try:
            for _, proc in pipe.unet.attn_processors.items():
                proc.to(dtype=torch.float32)
        except Exception:
            pass
    # When no style image is given, substitute a harmless stand-in: the scales
    # are 0.0 so it has no effect, but generate() must not receive None.
    if style_present:
        style_img = Image.open(paths.style_path).convert("RGB")
    else:
        style_img = garment_pil
    # Prompt: prefix the extractor category.
    if prompt is not None and str(prompt).strip() != "":
        prompt = extractor_category + " with " + str(prompt).strip()
    else:
        prompt = extractor_category
    with torch.inference_mode():
        images = ip_model.generate(
            pil_image=style_img,
            image=person_pil,
            control_image=depth_map,
            strength=1.0,
            num_samples=1,
            num_inference_steps=int(steps),
            shape_prompt="",
            prompt=prompt or "",
            num=0,
            scale=None,
            controlnet_conditioning_scale=0.7,
            guidance_scale=7.5,
        )
    # save_cropped(images, paths.output_path)
    # return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
    save_output_match_person(images, paths.output_path, paths.person_path)
    return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
def set_seed(seed: int):
    """Seed numpy/torch for deterministic runs; None or a negative seed is a no-op."""
    if seed is not None and seed >= 0:
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
    """Gradio click handler: validate inputs, run the pipeline, return results.

    Only the person image is required; style and sketch are optional.

    Returns exactly two values — the output PIL image and the saved file
    path — matching the two components wired to ``run_btn.click``
    (``outputs=[out_img, out_file]``). The previous version returned seven
    values, which made Gradio raise a "too many output values" error on every
    run since the debug output components are commented out.

    Raises:
        gr.Error: when the person image is missing or the category is invalid.
    """
    print("[UI] infer_web called", flush=True)
    # person is required; style/sketch are optional
    if person_fp is None:
        raise gr.Error("person 이미지는 필수입니다. (style/sketch는 선택)")
    if category not in ("Upper-body", "Lower-body", "Dress"):
        raise gr.Error(f"Invalid category: {category}")
    set_seed(int(seed) if seed is not None else -1)
    tmp_dir = tempfile.mkdtemp(prefix="vista_demo_")
    out_path = os.path.join(tmp_dir, "result.png")
    paths = Paths(
        person_path=person_fp,
        depth_path=sketch_fp,
        style_path=style_fp,  # may be None
        output_path=out_path,
    )
    # run_one() writes the final image to out_path as a side effect; the
    # intermediate visualizations it returns are not needed by the UI.
    run_one(paths, prompt=prompt, steps=int(steps), category=category)
    out_img = Image.open(out_path).convert("RGB")
    return out_img, out_path
# ---------------- Gradio UI layout ----------------
with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
    # gr.Markdown("첫 실행은 모델 로딩 때문에 시간이 오래 걸릴 수 있습니다.")
    # gr.Markdown("category랑 try-on 하려는 옷 종류를 꼭 맞춰주세요.")
    # Intro text (Korean): "the first run can be slow because of model loading;
    # make sure the category matches the garment you are trying on".
    gr.Markdown(
        "첫 실행은 모델 로딩 때문에 시간이 오래 걸릴 수 있습니다.<br>"
        "category랑 try-on 하려는 옷 종류를 꼭 맞춰주세요.",
        elem_classes="tight_md",
    )
    category_toggle = gr.Radio(
        choices=["Dress", "Upper-body", "Lower-body"],
        value="Dress",
        label="Category",
        interactive=True,
    )
    # Example lists (separate per input)
    ex = build_ui_example_lists(ROOT)
    person_examples = [[p] for p in ex["persons"]]
    style_examples = [[p] for p in ex["styles"]]
    sketch_examples = [[p] for p in ex["sketches"]]
    # One row: Person / Style / Output
    with gr.Row():
        # -------- Person column --------
        with gr.Column(scale=1):
            person_in = gr.Image(label="Person Image (required)", type="filepath")
            if person_examples:
                gr.Markdown("#### Examples")
                gr.Examples(
                    examples=person_examples,
                    inputs=[person_in],
                    examples_per_page=8,
                )
        # -------- Style column --------
        with gr.Column(scale=1):
            style_in = gr.Image(label="Style Image (optional)", type="filepath")
            if style_examples:
                gr.Markdown("#### Examples")
                gr.Examples(
                    examples=style_examples,
                    inputs=[style_in],
                    examples_per_page=8,
                )
        # -------- Output column --------
        with gr.Column(scale=1):
            out_img = gr.Image(label="Output", type="pil")
    # Optional sketch/guide input, collapsed by default.
    with gr.Accordion("Sketch / Guide (optional)", open=False):
        sketch_in = gr.Image(
            label="Sketch / Guide (person과 같은 번호로 매칭하세요: person 1 ↔ sketch 1). 스케치는 person 인체와 정렬되어야 합니다.",
            type="filepath",
        )
        if sketch_examples:
            gr.Markdown("#### Examples")
            gr.Examples(
                examples=sketch_examples,
                inputs=[sketch_in],
                examples_per_page=8,
            )
    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            value="",
            placeholder="ex) crystal, lace, button, …",
            lines=2,
        )
        steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
        seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
    run_btn = gr.Button("Run")
    out_file = gr.File(label="Download result.png")
    # gr.Markdown("### Debug Visualizations (mask/depth/etc)")
    # with gr.Row():
    #     dbg_mask = gr.Image(label="mask_pil", type="pil")
    #     dbg_depth = gr.Image(label="depth_map", type="pil")
    # with gr.Row():
    #     dbg_person = gr.Image(label="person_pil", type="pil")
    #     dbg_garment = gr.Image(label="garment_pil", type="pil")
    #     dbg_gmask = gr.Image(label="garment_mask_pil", type="pil")
    run_btn.click(
        fn=infer_web,
        inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in, category_toggle],
        outputs=[out_img, out_file],
    )
# Enable request queuing (required for long-running GPU jobs on Spaces).
demo.queue()
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)