# VISTA app.py — Gradio demo entry point for HF Spaces (commit 1a497a0, "update garment pil").
import os
import sys

# ---------------------------------------------------------
# 0) Make sure local packages (diffusers3, preprocess, etc.) are importable on HF Spaces
# ---------------------------------------------------------
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
print("[BOOT] ROOT =", ROOT, flush=True)
print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)
import tempfile
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, Tuple
import gradio as gr
import torch
import numpy as np
import cv2
import imageio
from PIL import Image, ImageOps
from transformers import pipeline
from huggingface_hub import hf_hub_download
# Show where diffusers3 is imported from (helps diagnose import collisions on Spaces)
import diffusers3
print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers3.models.controlnet import ControlNetModel
from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import (
StableDiffusionXLControlNetImg2ImgPipeline,
)
from ip_adapter import IPAdapterXL
# extractor
from preprocess.simple_extractor import run as run_simple_extractor
# =========================
# HF Hub repo ids
# =========================
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"
# assets dataset repo (overridable via the ASSETS_REPO environment variable)
ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
ASSETS_REPO_TYPE = "dataset"
# Depth-estimation pipeline, created eagerly at import time.
# NOTE(review): no explicit model or device is given — this uses the
# transformers default depth model; confirm that is intended for Spaces.
depth_estimator = pipeline("depth-estimation")
def asset_path(relpath: str) -> str:
    """Fetch a single file from the assets dataset repo (cached by hf_hub)."""
    return hf_hub_download(
        repo_id=ASSETS_REPO,
        repo_type=ASSETS_REPO_TYPE,
        filename=relpath,
    )
@lru_cache(maxsize=1)
def get_assets():
    """Download checkpoint assets once and return their local locations.

    Returns:
        (image_encoder_dir, ip_ckpt, schp_ckpt): directory holding the image
        encoder weights+config, the IP-Adapter checkpoint path, and the SCHP
        human-parsing checkpoint path.
    """
    print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)
    encoder_weights = asset_path("image_encoder/model.safetensors")
    _ = asset_path("image_encoder/config.json")  # ensure config.json sits next to the weights
    encoder_dir = os.path.dirname(encoder_weights)
    ip_adapter_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
    parsing_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")
    print("[ASSETS] image_encoder_dir =", encoder_dir, flush=True)
    print("[ASSETS] ip_ckpt =", ip_adapter_ckpt, flush=True)
    print("[ASSETS] schp_ckpt =", parsing_ckpt, flush=True)
    return encoder_dir, ip_adapter_ckpt, parsing_ckpt
DEFAULT_STEPS = 40   # default number of diffusion inference steps
DEBUG_SAVE = False   # when True, intermediate mask/depth images are written to disk
# Global working size (height, width) of the resized person image.
# Set once per request by run_one() via compute_hw_from_person().
H: Optional[int] = None
W: Optional[int] = None
@dataclass
class Paths:
    """File locations for a single generation run."""
    person_path: str  # person photo
    depth_path: str   # sketch / guide image (source of the depth control)
    style_path: str   # style reference image
    output_path: str  # destination for the final result PNG
def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
    """Load an image via OpenCV, raising FileNotFoundError instead of returning None."""
    loaded = cv2.imread(path, flag)
    if loaded is not None:
        return loaded
    raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
def apply_parsing_white_mask_to_person_cv2(
    person_pil: Image.Image,
    parsing_img: Image.Image,
) -> np.ndarray:
    """Keep the person pixels only where the parsing mask is pure white.

    The parsing mask (L mode) is resized to the person's size with NEAREST so
    label edges stay crisp; everything outside the white (255) region is
    replaced with a white background.

    Returns:
        BGR uint8 array the same size as person_pil.
    """
    person = np.array(person_pil.convert("RGB"), dtype=np.uint8)
    mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
    # Resolve any (H, W) mismatch between mask and person.
    if mask.shape[:2] != person.shape[:2]:
        mask = cv2.resize(mask, (person.shape[1], person.shape[0]), interpolation=cv2.INTER_NEAREST)
    keep = mask == 255
    result = np.full_like(person, 255)
    result[keep] = person[keep]
    return cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
def compute_hw_from_person(person_path: str):
    """Derive the global working size from the person image.

    Height is fixed at 1024; width is scaled proportionally and capped at
    1024. Returns (new_h, new_w).
    """
    person = _imread_or_raise(person_path)
    orig_h, orig_w = person.shape[:2]
    scaled_w = int(round(orig_w * (1024.0 / float(orig_h))))
    return 1024, min(scaled_w, 1024)
def invert_sketch_area(sketch_pil: Image.Image) -> Image.Image:
    """Return the grayscale negation of the sketch as an RGB image."""
    inverted = ImageOps.invert(sketch_pil.convert("L"))
    return inverted.convert("RGB")
def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
    """Binarize a sketch, fill its outer contours solid white, return RGB PIL.

    Depends on the global working size (H, W) set by run_one().
    """
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set.")
    gray = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
    gray = cv2.resize(gray, (W, H), interpolation=cv2.INTER_NEAREST)
    # Inverted threshold: dark sketch strokes become foreground.
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    canvas = np.zeros_like(binary)
    cv2.drawContours(canvas, contours, -1, 255, thickness=cv2.FILLED)
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_GRAY2RGB))
def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image:
    """Union the pure-white regions of two images (pixel-wise OR on white)."""
    first = np.array(img1.convert("RGB"), dtype=np.uint8)
    second = np.array(img2.convert("RGB"), dtype=np.uint8)
    is_white = np.all(first == 255, axis=-1) | np.all(second == 255, axis=-1)
    merged = first.copy()
    merged[is_white] = 255
    return Image.fromarray(merged)
def preprocess_mask(mask_img: Image.Image) -> Image.Image:
    """Binarize a mask, fit it onto the 1024-wide canvas, and dilate it.

    Resizes to the global (W, H) when available, thresholds at 127, then
    center-pads (or center-crops) the width to exactly 1024 and dilates with
    a 17x17 kernel to give the inpainting region some slack.
    """
    global H, W
    mask = np.array(mask_img.convert("L"), dtype=np.uint8)
    if (H is not None) and (W is not None):
        mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    target_width = 1024
    width = mask.shape[1]
    if width < target_width:
        pad_left = (target_width - width) // 2
        pad_right = target_width - width - pad_left
        mask = cv2.copyMakeBorder(
            mask,
            top=0, bottom=0,
            left=pad_left, right=pad_right,
            borderType=cv2.BORDER_CONSTANT,
            value=0,
        )
    elif width > target_width:
        start = (width - target_width) // 2
        mask = mask[:, start:start + target_width]
    mask = cv2.dilate(mask, np.ones((17, 17), np.uint8), iterations=1)
    if DEBUG_SAVE:
        cv2.imwrite("mask_final_1024.png", mask)
    return Image.fromarray(mask, mode="L").convert("RGB")
def make_depth(depth_path: str) -> Image.Image:
    """Build the depth control image from the sketch at depth_path.

    Inverts the sketch, fills its outer contours, resizes to the global
    (W, H), center-pads the width to 1024, re-inverts, and runs the HF depth
    estimator. Requires run_one() to have set H/W first.

    NOTE(review): assumes W <= 1024 (guaranteed by compute_hw_from_person);
    a wider image would make the padding negative and fail.
    """
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")
    sketch = _imread_or_raise(depth_path, 0)
    inverted = cv2.bitwise_not(sketch)
    contours, _ = cv2.findContours(inverted, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled = inverted.copy()
    cv2.drawContours(filled, contours, -1, (255), thickness=cv2.FILLED)
    # Resize to the global working size before padding.
    filled = cv2.resize(filled, (W, H), interpolation=cv2.INTER_AREA)
    pad_total = 1024 - filled.shape[1]
    pad_left = pad_total // 2
    padded = cv2.copyMakeBorder(
        filled, 0, 0, pad_left, pad_total - pad_left,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )
    control = ImageOps.invert(Image.fromarray(padded))
    with torch.inference_mode():
        depth = depth_estimator(control)["depth"]
    if DEBUG_SAVE:
        depth.save("depth.png")
    return depth
def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
    """Fit an image array to exactly 768x1024 (width x height).

    Height is resized to 1024 when needed; width is then white-padded up to
    768 or center-cropped down to 768.

    Args:
        arr: HxW or HxWxC uint8 image array.
    Returns:
        Array of shape (1024, 768[, C]).
    """
    target_h, target_w = 1024, 768
    h, w = arr.shape[:2]
    if h != target_h:
        arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
        h, w = arr.shape[:2]
    if w < target_w:
        # FIX: split the deficit asymmetrically. The old symmetric pad of
        # (target_w - w) // 2 on each side left the image 1 px short of
        # target_w when the deficit was odd; the final slice then started at
        # left = -1 and returned a 1-px-wide result.
        deficit = target_w - w
        left_pad = deficit // 2
        pad_width = [(0, 0)] * arr.ndim
        pad_width[1] = (left_pad, deficit - left_pad)
        arr = np.pad(arr, pad_width, mode="constant", constant_values=255)
        w = arr.shape[1]
    left = (w - target_w) // 2
    return arr[:, left:left + target_w]
def save_cropped(imgs, out_path: str):
    """Center-crop each image to 768x1024, concatenate horizontally, save to out_path."""
    np_imgs = [np.asarray(im) for im in imgs]
    cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs]
    out = np.concatenate(cropped, axis=1)
    # FIX: os.makedirs("") raises FileNotFoundError when out_path is a bare
    # filename with no directory component — only create when there is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.imsave(out_path, out)
@lru_cache(maxsize=1)
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
    """Build the SDXL ControlNet img2img pipeline once and cache it.

    Returns:
        (pipe, device, dtype): the assembled pipeline, the device string
        ("cuda" when available, else "cpu"), and the torch dtype used for
        every sub-module (float32).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32  # keep the current (full-precision) setting
    print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
    # Depth ControlNet guiding the SDXL generation.
    controlnet = ControlNetModel.from_pretrained(
        CONTROLNET_ID,
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        BASE_MODEL_ID,
        subfolder="vae",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        BASE_MODEL_ID,
        subfolder="unet",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    # Pipeline class comes from the local diffusers3 package (see imports).
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        BASE_MODEL_ID,
        controlnet=controlnet,
        vae=vae,
        unet=unet,
        torch_dtype=dtype,
        use_safetensors=True,
        add_watermarker=False,
    ).to(device)
    if device == "cuda":
        try:
            # Keep the VAE in the chosen dtype and disable diffusers'
            # force-upcast (moot while dtype is fp32, but kept so a later
            # switch to fp16 behaves consistently).
            pipe.vae.to(dtype=dtype)
            if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
                pipe.vae.config.force_upcast = False
        except Exception as e:
            print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_attention_slicing()
    try:
        # xformers is optional; fall back silently to default attention.
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print("[PIPE] xformers not enabled:", repr(e), flush=True)
    return pipe, device, dtype
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
    """Run one try-on generation end to end.

    Sets the global working size (H, W), builds the inpainting mask, depth
    control image, and the padded person/garment images, then runs the
    IP-Adapter SDXL pipeline and saves the cropped result to
    paths.output_path.

    Returns:
        images (list[PIL]), mask_pil, depth_map, person_pil, garment_pil,
        garment_mask_pil
    """
    global H, W
    pipe, device, _dtype = get_pipe_and_device()
    image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()

    # Working size is derived from the person image (height fixed to 1024).
    H, W = compute_hw_from_person(paths.person_path)

    # --- human parsing (upper-clothes region) ---
    res = run_simple_extractor(
        category="Upper-clothes",
        input_path=os.path.abspath(paths.person_path),
        model_restore=schp_ckpt,
    )
    parsing_img = res["images"][0] if res.get("images") else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")

    # --- inpainting mask: parsing region OR filled sketch region ---
    sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    merged_img = merge_white_regions_or(parsing_img, sketch_area)
    mask_pil = preprocess_mask(merged_img)

    # --- person image resized to (W, H), fitted onto the 1024-wide canvas ---
    person_bgr = _imread_or_raise(paths.person_path)
    person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
    target_width = 1024
    cur_w = person_bgr.shape[1]
    if cur_w < target_width:
        total = target_width - cur_w
        left = total // 2
        right = total - left
        padded_person = cv2.copyMakeBorder(
            person_bgr, 0, 0, left, right,
            borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )
    elif cur_w > target_width:
        left = (cur_w - target_width) // 2
        padded_person = person_bgr[:, left:left + target_width]
    else:
        padded_person = person_bgr
    person_pil = Image.fromarray(cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB))

    depth_map = make_depth(paths.depth_path)

    # --- garment image: person pixels inside the parsing mask, white elsewhere ---
    # FIX: close the source image handle, and split the width deficit into
    # left/right halves like the person canvas above — the old symmetric
    # `padding` on both sides produced a 1023-wide canvas whenever
    # (1024 - W) was odd, mismatching the 1024-wide person image.
    with Image.open(paths.person_path) as person_src:
        garment_bgr = apply_parsing_white_mask_to_person_cv2(person_src, parsing_img)
    garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
    # garment may still be at the original person size; fit to (W, H) first
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
    deficit = max(target_width - garment_rgb.shape[1], 0)
    pad_left = deficit // 2
    garment_rgb = cv2.copyMakeBorder(
        garment_rgb,
        top=0, bottom=0,
        left=pad_left, right=deficit - pad_left,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
    garment_pil = Image.fromarray(garment_rgb)

    # --- garment mask on the same 1024-wide canvas (black padding) ---
    gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
    gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_AREA)
    gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
    cur_w2 = gm.shape[1]
    if cur_w2 < target_width:
        total = target_width - cur_w2
        left = total // 2
        right = total - left
        gm = cv2.copyMakeBorder(gm, 0, 0, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    elif cur_w2 > target_width:
        left2 = (cur_w2 - target_width) // 2
        gm = gm[:, left2:left2 + target_width]
    garment_mask_pil = Image.fromarray(gm)

    # --- sanity sizes (optional)
    print(
        "[SIZE] person:", person_pil.size,
        "mask:", mask_pil.size,
        "depth:", depth_map.size,
        "garment:", garment_pil.size,
        "gmask:", garment_mask_pil.size,
        flush=True
    )

    ip_model = IPAdapterXL(
        pipe,
        image_encoder_dir,
        ip_ckpt,
        device,
        mask_pil,
        person_pil,
        content_scale=0.3,
        style_scale=0.5,
        garment_images=garment_pil,
        garment_mask=garment_mask_pil,
    )

    if device == "cuda":
        # Force fp32 across the UNet attention processors for the adapter.
        pipe.to(dtype=torch.float32)
        try:
            for proc in pipe.unet.attn_processors.values():
                proc.to(dtype=torch.float32)
        except Exception:
            pass

    style_img = Image.open(paths.style_path).convert("RGB")
    with torch.inference_mode():
        images = ip_model.generate(
            pil_image=style_img,
            image=person_pil,
            control_image=depth_map,
            strength=1.0,
            num_samples=1,
            num_inference_steps=int(steps),
            shape_prompt="",
            prompt=prompt or "",
            num=0,
            scale=None,
            controlnet_conditioning_scale=0.7,
            guidance_scale=7.5,
        )
    save_cropped(images, paths.output_path)
    return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
def set_seed(seed: int):
    """Seed numpy and torch RNGs; None or a negative seed means 'stay random'."""
    if seed is not None and seed >= 0:
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed):
    """Gradio click handler: validate inputs, run the pipeline, return all outputs."""
    print("[UI] infer_web called", flush=True)
    if person_fp is None or sketch_fp is None or style_fp is None:
        raise gr.Error("person / sketch / style 이미지를 모두 업로드해야 합니다.")
    set_seed(int(seed) if seed is not None else -1)
    work_dir = tempfile.mkdtemp(prefix="vista_demo_")
    result_path = os.path.join(work_dir, "result.png")
    run_paths = Paths(
        person_path=person_fp,
        depth_path=sketch_fp,
        style_path=style_fp,
        output_path=result_path,
    )
    _, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil = run_one(
        run_paths, prompt=prompt, steps=int(steps)
    )
    result_img = Image.open(result_path).convert("RGB")
    return result_img, result_path, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil
# ---------------------------------------------------------
# Gradio UI wiring (executes at import time on HF Spaces)
# ---------------------------------------------------------
with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
    # Header (Korean): "generates a result from person / sketch(guide) / style inputs"
    gr.Markdown("## VISTA Demo\nperson / sketch(guide) / style 입력으로 결과를 생성합니다.")
    with gr.Row():
        person_in = gr.Image(label="Person Image", type="filepath")
        sketch_in = gr.Image(label="Sketch / Guide (depth_path)", type="filepath")
        style_in = gr.Image(label="Style Image", type="filepath")
    with gr.Row():
        prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2)
        steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
        seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)
    run_btn = gr.Button("Run")
    out_img = gr.Image(label="Output", type="pil")
    out_file = gr.File(label="Download result.png")
    gr.Markdown("### Debug Visualizations (mask/depth/etc)")
    with gr.Row():
        dbg_mask = gr.Image(label="mask_pil", type="pil")
        dbg_depth = gr.Image(label="depth_map", type="pil")
    with gr.Row():
        dbg_person = gr.Image(label="person_pil", type="pil")
        dbg_garment = gr.Image(label="garment_pil", type="pil")
        dbg_gmask = gr.Image(label="garment_mask_pil", type="pil")
    # Wire the Run button to the inference handler; outputs must match the
    # 7-tuple returned by infer_web.
    run_btn.click(
        fn=infer_web,
        inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in],
        outputs=[out_img, out_file, dbg_mask, dbg_depth, dbg_person, dbg_garment, dbg_gmask],
    )
# Enable request queueing (needed for long-running GPU jobs on Spaces).
demo.queue()
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)