"""VISTA demo app for Hugging Face Spaces: sketch- and style-guided garment
editing with SDXL ControlNet (depth) and IP-Adapter."""
import os
import sys

# Make the repo root importable before any local imports.
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

print("[BOOT] ROOT =", ROOT, flush=True)
print("[BOOT] sys.path[:5] =", sys.path[:5], flush=True)

import tempfile
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, Tuple

import gradio as gr
import torch
import numpy as np
import cv2
import imageio
from PIL import Image, ImageOps
from transformers import pipeline

from huggingface_hub import hf_hub_download

# diffusers3 is expected to be a local/vendored diffusers variant that provides
# the customized ControlNet img2img pipeline imported below.
import diffusers3
print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)

from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel

from diffusers3.models.controlnet import ControlNetModel
from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img import (
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from ip_adapter import IPAdapterXL

from preprocess.simple_extractor import run as run_simple_extractor

BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
CONTROLNET_ID = "diffusers/controlnet-depth-sdxl-1.0"

ASSETS_REPO = os.getenv("ASSETS_REPO", "soye/VISTA_assets")
ASSETS_REPO_TYPE = "dataset"

# Loaded once at import time; reused by make_depth() below.
depth_estimator = pipeline("depth-estimation")
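# Note: pipeline("depth-estimation") without an explicit model id pulls
# whatever default checkpoint the installed transformers version ships;
# pin a model name here if you need reproducible depth maps across setups.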


def asset_path(relpath: str) -> str:
    return hf_hub_download(
        repo_id=ASSETS_REPO,
        repo_type=ASSETS_REPO_TYPE,
        filename=relpath,
    )


@lru_cache(maxsize=1)
def get_assets():
    print("[ASSETS] Downloading assets from:", ASSETS_REPO, flush=True)

    image_encoder_weight = asset_path("image_encoder/model.safetensors")
    _ = asset_path("image_encoder/config.json")
    image_encoder_dir = os.path.dirname(image_encoder_weight)

    ip_ckpt = asset_path("ip_adapter/ip-adapter_sdxl_vit-h.bin")
    schp_ckpt = asset_path("preprocess_ckpts/exp-schp-201908301523-atr.pth")

    print("[ASSETS] image_encoder_dir =", image_encoder_dir, flush=True)
    print("[ASSETS] ip_ckpt =", ip_ckpt, flush=True)
    print("[ASSETS] schp_ckpt =", schp_ckpt, flush=True)
    return image_encoder_dir, ip_ckpt, schp_ckpt
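
# Example (hypothetical repo name): point the demo at a different assets
# snapshot by setting the environment variable before launch:
#   ASSETS_REPO=your-org/your_assets python app.py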

DEFAULT_STEPS = 40
DEBUG_SAVE = False

# Working resolution, set per request by run_one() from the person image.
H: Optional[int] = None
W: Optional[int] = None


@dataclass
class Paths:
    person_path: str
    depth_path: str
    style_path: str
    output_path: str


def _imread_or_raise(path: str, flag=cv2.IMREAD_COLOR):
    img = cv2.imread(path, flag)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
    return img


def apply_parsing_white_mask_to_person_cv2(
    person_pil: Image.Image,
    parsing_img: Image.Image,
) -> np.ndarray:
    """
    Resize the parsing mask (L mode) to match person_pil (RGB), then keep the
    person pixels only where the mask is white (255) and replace everything
    else with a white background.

    - parsing_img must be resized to the person's size (NEAREST).
    """
    person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
    mask = np.array(parsing_img.convert("L"), dtype=np.uint8)

    if mask.shape[0] != person_rgb.shape[0] or mask.shape[1] != person_rgb.shape[1]:
        mask = cv2.resize(mask, (person_rgb.shape[1], person_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)

    white_mask = (mask == 255)

    result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
    result_rgb[white_mask] = person_rgb[white_mask]

    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    return result_bgr


def compute_hw_from_person(person_path: str):
    # Scale so the person image is 1024 pixels tall, capping width at 1024.
    img = _imread_or_raise(person_path)
    orig_h, orig_w = img.shape[:2]
    scale = 1024.0 / float(orig_h)
    new_h = 1024
    new_w = int(round(orig_w * scale))
    if new_w > 1024:
        new_w = 1024
    return new_h, new_w
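# Worked example: a 2000x1500 (HxW) photo scales by 1024/2000 = 0.512, giving
# new_w = round(1500 * 0.512) = 768, so (H, W) = (1024, 768).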


def invert_sketch_area(sketch_pil: Image.Image) -> Image.Image:
    return ImageOps.invert(sketch_pil.convert("L")).convert("RGB")


def fill_sketch_from_image_path_to_pil(image_path: str) -> Image.Image:
    # Binarize the sketch and fill its outer contours to get a solid region.
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set.")
    img = _imread_or_raise(image_path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
    _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled = np.zeros_like(binary)
    cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED)
    filled_rgb = cv2.cvtColor(filled, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(filled_rgb)


def merge_white_regions_or(img1: Image.Image, img2: Image.Image) -> Image.Image:
    a = np.array(img1.convert("RGB"), dtype=np.uint8)
    b = np.array(img2.convert("RGB"), dtype=np.uint8)
    white_a = np.all(a == 255, axis=-1)
    white_b = np.all(b == 255, axis=-1)
    out = a.copy()
    out[white_a | white_b] = 255
    return Image.fromarray(out)
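
# A pixel counts as mask-positive only if it is pure white (255, 255, 255) in
# either input, so the union of the parsed garment and the filled sketch is
# what ultimately drives the inpainting mask below.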


def preprocess_mask(mask_img: Image.Image) -> Image.Image:
    global H, W
    m = np.array(mask_img.convert("L"), dtype=np.uint8)

    if (H is not None) and (W is not None):
        m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)

    _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)

    # Center-pad or crop to a fixed 1024-pixel width.
    target_width = 1024
    h, w = m.shape[:2]

    if w < target_width:
        total_padding = target_width - w
        left_padding = total_padding // 2
        right_padding = total_padding - left_padding
        m = cv2.copyMakeBorder(
            m,
            top=0, bottom=0,
            left=left_padding, right=right_padding,
            borderType=cv2.BORDER_CONSTANT,
            value=0,
        )
    elif w > target_width:
        left = (w - target_width) // 2
        m = m[:, left:left + target_width]

    # Dilate so the inpainting region safely covers the garment boundary.
    kernel = np.ones((17, 17), np.uint8)
    m = cv2.dilate(m, kernel, iterations=1)

    if DEBUG_SAVE:
        cv2.imwrite("mask_final_1024.png", m)

    return Image.fromarray(m, mode="L").convert("RGB")


def make_depth(depth_path: str) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    depth_img = _imread_or_raise(depth_path, cv2.IMREAD_GRAYSCALE)

    # Invert, then fill the sketch contours so the region reads as a solid shape.
    inverted_depth = cv2.bitwise_not(depth_img)
    contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    filled_depth = inverted_depth.copy()
    cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED)

    filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)

    # Center-pad to a 1024-pixel width.
    height, width = filled_depth.shape
    total_padding = 1024 - width
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding

    padded_depth = cv2.copyMakeBorder(
        filled_depth, 0, 0, left_padding, right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )

    inverted_image = ImageOps.invert(Image.fromarray(padded_depth))

    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]

    if DEBUG_SAVE:
        image_depth.save("depth.png")

    return image_depth


def center_crop_lr_to_768x1024(arr: np.ndarray) -> np.ndarray:
    target_h, target_w = 1024, 768
    h, w = arr.shape[:2]
    if h != target_h:
        arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
        h, w = arr.shape[:2]
    if w < target_w:
        # Split padding asymmetrically so an odd total still reaches target_w.
        pad_left = (target_w - w) // 2
        pad_right = target_w - w - pad_left
        arr = cv2.copyMakeBorder(arr, 0, 0, pad_left, pad_right, cv2.BORDER_CONSTANT, value=[255, 255, 255])
        w = arr.shape[1]
    left = (w - target_w) // 2
    return arr[:, left:left + target_w]


def save_cropped(imgs, out_path: str):
    np_imgs = [np.asarray(im) for im in imgs]
    cropped = [center_crop_lr_to_768x1024(x) for x in np_imgs]
    out = np.concatenate(cropped, axis=1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    imageio.imsave(out_path, out)
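
# save_cropped tiles the generated samples horizontally; with num_samples=1
# (as used in run_one below) the saved PNG is a single 768x1024 image.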


@lru_cache(maxsize=1)
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    print(f"[PIPE] device={device}, dtype={dtype}", flush=True)

    controlnet = ControlNetModel.from_pretrained(
        CONTROLNET_ID,
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)

    vae = AutoencoderKL.from_pretrained(
        BASE_MODEL_ID,
        subfolder="vae",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)

    unet = UNet2DConditionModel.from_pretrained(
        BASE_MODEL_ID,
        subfolder="unet",
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)

    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        BASE_MODEL_ID,
        controlnet=controlnet,
        vae=vae,
        unet=unet,
        torch_dtype=dtype,
        use_safetensors=True,
        add_watermarker=False,
    ).to(device)

    if device == "cuda":
        try:
            # Keep the VAE in fp32 and skip the automatic upcast since the
            # whole pipeline already runs in float32.
            pipe.vae.to(dtype=dtype)
            if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
                pipe.vae.config.force_upcast = False
        except Exception as e:
            print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)

    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_attention_slicing()
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print("[PIPE] xformers not enabled:", repr(e), flush=True)

    return pipe, device, dtype
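
# get_pipe_and_device() and get_assets() are lru_cache'd with maxsize=1, so the
# heavy model loads and asset downloads happen once per process; subsequent
# requests reuse the same pipeline object.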


def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
    """
    Returns:
        images (list[PIL]), mask_pil (PIL), depth_map (PIL), person_pil (PIL),
        garment_pil (PIL), garment_mask_pil (PIL)
    """
    global H, W
    pipe, device, _dtype = get_pipe_and_device()
    image_encoder_dir, ip_ckpt, schp_ckpt = get_assets()

    H, W = compute_hw_from_person(paths.person_path)

    # Human parsing: isolate the upper-clothes region of the person image.
    res = run_simple_extractor(
        category="Upper-clothes",
        input_path=os.path.abspath(paths.person_path),
        model_restore=schp_ckpt,
    )
    parsing_img = res["images"][0] if res.get("images") else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")

    # Inpainting mask = parsed garment OR filled sketch region.
    sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    merged_img = merge_white_regions_or(parsing_img, sketch_area)
    mask_pil = preprocess_mask(merged_img)
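
    # Bring the person image onto the same 1024-wide canvas as the mask and
    # depth map prepared above (resize to (W, H), then center-pad or crop).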
    person_bgr = _imread_or_raise(paths.person_path)
    person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)

    target_width = 1024
    cur_w = person_bgr.shape[1]
    if cur_w < target_width:
        total = target_width - cur_w
        left = total // 2
        right = total - left
        padded_person = cv2.copyMakeBorder(
            person_bgr, 0, 0, left, right,
            borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )
    elif cur_w > target_width:
        left = (cur_w - target_width) // 2
        padded_person = person_bgr[:, left:left + target_width]
    else:
        padded_person = person_bgr

    person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB)
    person_pil = Image.fromarray(person_rgb)

    depth_map = make_depth(paths.depth_path)
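
    # Garment reference: white-out everything outside the parsed garment, then
    # pad to the same 1024-wide canvas so all conditioning inputs line up.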
    person_orig = Image.open(paths.person_path)

    garment_bgr = apply_parsing_white_mask_to_person_cv2(
        person_orig,
        parsing_img,
    )

    garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)

    target_width = 1024
    # Split padding asymmetrically so an odd total still yields exactly 1024.
    pad_total = target_width - garment_rgb.shape[1]
    pad_left = pad_total // 2
    pad_right = pad_total - pad_left

    garment_rgb = cv2.copyMakeBorder(
        garment_rgb,
        top=0, bottom=0,
        left=pad_left, right=pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
    garment_pil = Image.fromarray(garment_rgb)

    # Garment mask for the IP-Adapter: resize with NEAREST to keep it binary,
    # then pad/crop to the same 1024-wide canvas (black background).
    gm = np.array(parsing_img.convert("L"), dtype=np.uint8)
    gm = cv2.resize(gm, (W, H), interpolation=cv2.INTER_NEAREST)
    gm = cv2.cvtColor(gm, cv2.COLOR_GRAY2RGB)
    cur_w2 = gm.shape[1]
    if cur_w2 < target_width:
        total = target_width - cur_w2
        left = total // 2
        right = total - left
        gm = cv2.copyMakeBorder(gm, 0, 0, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    elif cur_w2 > target_width:
        left2 = (cur_w2 - target_width) // 2
        gm = gm[:, left2:left2 + target_width]
    garment_mask_pil = Image.fromarray(gm)

    print(
        "[SIZE] person:", person_pil.size,
        "mask:", mask_pil.size,
        "depth:", depth_map.size,
        "garment:", garment_pil.size,
        "gmask:", garment_mask_pil.size,
        flush=True,
    )

    ip_model = IPAdapterXL(
        pipe,
        image_encoder_dir,
        ip_ckpt,
        device,
        mask_pil,
        person_pil,
        content_scale=0.3,
        style_scale=0.5,
        garment_images=garment_pil,
        garment_mask=garment_mask_pil,
    )
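    # Note: this IPAdapterXL signature (mask/person positional arguments,
    # content_scale/style_scale, garment_* keywords) belongs to this repo's
    # customized ip_adapter module, not the upstream IP-Adapter API.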

    if device == "cuda":
        # Force everything back to float32; the IP-Adapter attention
        # processors may have been loaded in a different dtype.
        pipe.to(dtype=torch.float32)
        try:
            for _, proc in pipe.unet.attn_processors.items():
                proc.to(dtype=torch.float32)
        except Exception:
            pass

    style_img = Image.open(paths.style_path).convert("RGB")

    with torch.inference_mode():
        images = ip_model.generate(
            pil_image=style_img,
            image=person_pil,
            control_image=depth_map,
            strength=1.0,
            num_samples=1,
            num_inference_steps=int(steps),
            shape_prompt="",
            prompt=prompt or "",
            num=0,
            scale=None,
            controlnet_conditioning_scale=0.7,
            guidance_scale=7.5,
        )

    save_cropped(images, paths.output_path)

    return images, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil


def set_seed(seed: int):
    if seed is None or seed < 0:
        return
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
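
# set_seed seeds numpy and torch globally; full per-request determinism still
# depends on the pipeline not creating its own unseeded generators internally.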


def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed):
    print("[UI] infer_web called", flush=True)
    if person_fp is None or sketch_fp is None or style_fp is None:
        raise gr.Error("Please upload all three images: person, sketch, and style.")

    set_seed(int(seed) if seed is not None else -1)

    tmp_dir = tempfile.mkdtemp(prefix="vista_demo_")
    out_path = os.path.join(tmp_dir, "result.png")

    paths = Paths(
        person_path=person_fp,
        depth_path=sketch_fp,
        style_path=style_fp,
        output_path=out_path,
    )

    _, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil = run_one(
        paths, prompt=prompt, steps=int(steps)
    )

    out_img = Image.open(out_path).convert("RGB")
    return out_img, out_path, mask_pil, depth_map, person_pil, garment_pil, garment_mask_pil


with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
    gr.Markdown("## VISTA Demo\nGenerates a result from person / sketch (guide) / style inputs.")

    with gr.Row():
        person_in = gr.Image(label="Person Image", type="filepath")
        sketch_in = gr.Image(label="Sketch / Guide (depth_path)", type="filepath")
        style_in = gr.Image(label="Style Image", type="filepath")

    with gr.Row():
        prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2)
        steps_in = gr.Slider(1, 80, value=DEFAULT_STEPS, step=1, label="Steps")
        seed_in = gr.Number(label="Seed (-1=random)", value=-1, precision=0)

    run_btn = gr.Button("Run")

    out_img = gr.Image(label="Output", type="pil")
    out_file = gr.File(label="Download result.png")

    gr.Markdown("### Debug Visualizations (mask/depth/etc)")
    with gr.Row():
        dbg_mask = gr.Image(label="mask_pil", type="pil")
        dbg_depth = gr.Image(label="depth_map", type="pil")

    with gr.Row():
        dbg_person = gr.Image(label="person_pil", type="pil")
        dbg_garment = gr.Image(label="garment_pil", type="pil")
        dbg_gmask = gr.Image(label="garment_mask_pil", type="pil")

    run_btn.click(
        fn=infer_web,
        inputs=[person_in, sketch_in, style_in, prompt_in, steps_in, seed_in],
        outputs=[out_img, out_file, dbg_mask, dbg_depth, dbg_person, dbg_garment, dbg_gmask],
    )


demo.queue()
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)