# VISTA / feat_file.py
import os
import argparse
from dataclasses import dataclass
from typing import Optional
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    UniPCMultistepScheduler,
)
from ip_adapter import IPAdapterXL
import cv2
import numpy as np
import imageio
from PIL import Image, ImageOps
from transformers import pipeline
from preprocess.simple_extractor import run as run_simple_extractor
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32
DEBUG_SAVE = False
DEFAULT_STEPS = 40
# =========================
# Global resize params (per the requirements)
# - the (H, W) of the person image after scaling it to height=1024 is
#   shared globally across the pipeline
# =========================
H: Optional[int] = None  # always 1024
W: Optional[int] = None  # width computed with the aspect ratio preserved
def compute_hw_from_person(person_path: str):
"""
person 원본 이미지 기준:
- height가 정확히 1024가 되도록 스케일
- aspect ratio 유지
=> H=1024, W=round(orig_w * (1024/orig_h))
"""
img = cv2.imread(person_path)
if img is None:
raise FileNotFoundError(f"cv2.imread failed: {person_path} (exists={os.path.exists(person_path)})")
orig_h, orig_w = img.shape[:2]
scale = 1024.0 / float(orig_h)
new_h = 1024
new_w = int(round(orig_w * scale))
return new_h, new_w
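# Worked example for compute_hw_from_person (hypothetical 2048x1536 H x W input):
#   scale = 1024 / 2048 = 0.5  ->  H = 1024, W = round(1536 * 0.5) = 768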
controlnet = ControlNetModel.from_pretrained(
controlnet_path,
variant="fp16",
use_safetensors=True,
    torch_dtype=dtype,
).to(device)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
base_model_path,
controlnet=controlnet,
use_safetensors=True,
    torch_dtype=dtype,
add_watermarker=False,
).to(device)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_attention_slicing()
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception:
pass
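# Low-VRAM sketch (an assumption, not part of the original flow; requires
# `accelerate`): instead of `.to(device)` above, one could call
#   pipe.enable_model_cpu_offload()
# to stream submodules to the GPU on demand, trading speed for memory.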
depth_estimator = pipeline("depth-estimation")
@dataclass
class Paths:
person_path: str
depth_path: str
style_path: str
output_path: str
def _ensure_exists(path: str, name: str):
if not os.path.exists(path):
raise FileNotFoundError(f"{name} not found: {path}")
def apply_parsing_white_mask_to_person_cv2(
person_pil: Image.Image,
parsing_img: Image.Image
) -> np.ndarray:
person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
white_mask = mask == 255
result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
result_rgb[white_mask] = person_rgb[white_mask]
result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
return result_bgr
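# Sketch of the masking rule above on a hypothetical 1x2 RGB image:
#   person = [[(10, 10, 10), (20, 20, 20)]], mask = [[255, 0]]
#   -> result = [[(10, 10, 10), (255, 255, 255)]]
# pixels where the parsing mask is exactly 255 keep the person's colour;
# everything else is forced to white.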
def _imread_or_raise(path: str, flags=None):
img = cv2.imread(path, flags) if flags is not None else cv2.imread(path)
if img is None:
raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
return img
def invert_sketch_area(sketch_area: Image.Image) -> Image.Image:
gray = sketch_area.convert("L")
arr = np.array(gray, dtype=np.uint8)
inverted = 255 - arr
return Image.fromarray(inverted, mode="L")
def merge_white_regions_or(
parsing_img: Image.Image,
sketch_area: Image.Image
) -> np.ndarray:
p_img = parsing_img.convert("L")
s_img = sketch_area.convert("L")
p = np.array(p_img, dtype=np.uint8)
s = np.array(s_img, dtype=np.uint8)
merged = np.where(
(p == 255) | (s == 255),
255,
0
).astype(np.uint8)
return merged
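# Sketch of the OR-merge above on hypothetical 1x3 masks:
#   p = [255, 0, 0], s = [0, 255, 0]  ->  merged = [255, 255, 0]
# note that only the exact value 255 counts as white; 254 and below map to 0.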
def preprocess_mask(mask: np.ndarray) -> Image.Image:
    # per the requirements, the target padded width is fixed at "always 1024"
    # (kept as in the original; assumes the input mask is at most 1024 wide)
height, width = mask.shape
total_padding = 1024 - width
left_padding = total_padding // 2
right_padding = total_padding - left_padding
padded_mask = cv2.copyMakeBorder(
mask, 0, 0, left_padding, right_padding,
borderType=cv2.BORDER_CONSTANT,
value=0,
)
kernel = np.ones((17, 17), np.uint8)
dilated_mask = cv2.dilate(padded_mask, kernel, iterations=1)
if DEBUG_SAVE:
cv2.imwrite("padded_mask.png", padded_mask)
cv2.imwrite("padded_mask_dilated.png", dilated_mask)
return Image.fromarray(dilated_mask)
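# Padding arithmetic in preprocess_mask for a hypothetical width of 683:
#   total = 1024 - 683 = 341, left = 341 // 2 = 170, right = 341 - 170 = 171
# -> padded width is exactly 1024; the 17x17 dilation then grows the white
# region by up to 8 px in every direction.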
def make_depth(depth_path: str) -> Image.Image:
global H, W
if H is None or W is None:
raise RuntimeError("Global H/W not set. Call run_one() first.")
depth_img = _imread_or_raise(depth_path, 0)
inverted_depth = cv2.bitwise_not(depth_img)
contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
filled_depth = inverted_depth.copy()
cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
    # ✅ resize uses the global (W, H)
filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
height, width = filled_depth.shape
total_padding = 1024 - width
left_padding = total_padding // 2
right_padding = total_padding - left_padding
padded_depth = cv2.copyMakeBorder(
filled_depth, 0, 0, left_padding, right_padding,
borderType=cv2.BORDER_CONSTANT,
value=0,
)
inverted_image = ImageOps.invert(Image.fromarray(padded_depth))
with torch.inference_mode():
image_depth = depth_estimator(inverted_image)["depth"]
if DEBUG_SAVE:
image_depth.save("depth.png")
return image_depth
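# Note: the transformers "depth-estimation" pipeline returns a dict whose
# "depth" entry is a PIL image, so `image_depth` can be passed directly to the
# ControlNet pipeline as the conditioning image.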
def fill_sketch_from_image_path_to_pil(
image_path: str,
threshold: int = 127,
) -> Image.Image:
global H, W
if H is None or W is None:
raise RuntimeError("Global H/W not set. Call run_one() first.")
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Failed to load image: {image_path}")
    # ✅ resize uses the global (W, H)
img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
_, binary = cv2.threshold(
img,
threshold,
255,
cv2.THRESH_BINARY_INV
)
contours, _ = cv2.findContours(
binary,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE
)
result = np.full_like(img, 255, dtype=np.uint8)
cv2.drawContours(
result,
contours,
contourIdx=-1,
color=0,
thickness=-1
)
pil_image = Image.fromarray(result, mode="L")
return pil_image
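# How the fill above works: with the default threshold=127, THRESH_BINARY_INV
# maps dark strokes (<= 127) to 255; findContours traces their outer outlines;
# drawContours with thickness=-1 then floods each outline black on a white
# canvas, turning an outline sketch into a solid silhouette.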
def center_crop_lr_to_700x1024(arr: np.ndarray) -> np.ndarray:
    # center-crop (or reflect-pad) the width to 700 after forcing height to 1024
    h, w = arr.shape[:2]
    target_w, target_h = 700, 1024
if h != target_h:
arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
h, w = arr.shape[:2]
if w < target_w:
pad = (target_w - w) // 2
arr = cv2.copyMakeBorder(arr, 0, 0, pad, target_w - w - pad, cv2.BORDER_REFLECT_101)
h, w = arr.shape[:2]
left = (w - target_w) // 2
right = left + target_w
return arr[:, left:right]
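# Crop arithmetic for a 1024-wide input: left = (1024 - 700) // 2 = 162,
# right = 862, so the returned slice is arr[:, 162:862] (700x1024 output).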
def save_cropped(imgs, out_path: str):
    np_imgs = [np.asarray(im) for im in imgs]
    cropped = [center_crop_lr_to_700x1024(x) for x in np_imgs]
    out = np.concatenate(cropped, axis=1)
    out_dir = os.path.dirname(out_path)
    if out_dir:  # guard: os.makedirs("") raises FileNotFoundError
        os.makedirs(out_dir, exist_ok=True)
    imageio.imwrite(out_path, out)  # imsave is a deprecated alias of imwrite
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
    global H, W
    # validate all inputs up front, before any of them is read
    _ensure_exists(paths.person_path, "person_path")
    _ensure_exists(paths.depth_path, "depth_path")
    _ensure_exists(paths.style_path, "style_path")
    category = 'Upper-clothes'
    PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
    person_path_abs = os.path.abspath(
        os.path.join(PROJECT_ROOT, paths.person_path)
    )
    # ✅ set the global H/W: the (H, W) of the person image scaled to height=1024
    H, W = compute_hw_from_person(paths.person_path)
    print('person_path_abs: ', person_path_abs)
    print(f'[global] H={H}, W={W} (from person scaled to height=1024)')
    res = run_simple_extractor(
        category=category,
        input_path=person_path_abs,
        model_restore="./preprocess/ckpts/exp-schp-201908301523-atr.pth"
    )
    parsing_img = res["images"][0] if res["images"] else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")
    sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    sketch_area_inv = invert_sketch_area(sketch_area)
    merged_img = merge_white_regions_or(parsing_img, sketch_area_inv)
    mask_pil = preprocess_mask(merged_img)
    # =========================
    # person: resize to the global (W, H);
    # target padded width stays fixed at 1024, as in the original
    # =========================
person_bgr = _imread_or_raise(paths.person_path)
person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
if DEBUG_SAVE:
cv2.imwrite("person.png", person_bgr)
    target_width = 1024  # ✅ fixed
    total_pad = target_width - person_bgr.shape[1]
    pad_left = total_pad // 2
    pad_right = total_pad - pad_left  # split asymmetrically so odd totals still yield exactly 1024
    padded_person = cv2.copyMakeBorder(
        person_bgr,
        top=0, bottom=0,
        left=pad_left, right=pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB)
person_pil = Image.fromarray(person_rgb)
depth_map = make_depth(paths.depth_path)
    # =========================
    # garment: original logic kept (resized to the global (W, H) below so its
    # size matches the parsing/mask images)
    # =========================
    person_orig = Image.open(paths.person_path)
    garment_ = apply_parsing_white_mask_to_person_cv2(
        person_orig,
        parsing_img
    )
garment_rgb = cv2.cvtColor(garment_, cv2.COLOR_BGR2RGB)
    # ✅ (important) garment_ may still be at the original person size, so
    # resize it to the global (W, H) before padding
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
    garment_rgb = cv2.copyMakeBorder(
        garment_rgb,
        top=0, bottom=0,
        left=pad_left, right=pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
garment_pil = Image.fromarray(garment_rgb)
if DEBUG_SAVE:
garment_pil.save('./garment_pil.png')
    # =========================
    # garment mask: resize to the global (W, H); target padded width fixed at 1024
    # =========================
    garment_mask_gray = np.array(parsing_img.convert("L"), dtype=np.uint8)
    garment_mask_gray = cv2.resize(garment_mask_gray, (W, H), interpolation=cv2.INTER_AREA)
    # the original code used COLOR_BGR2RGB here, which can raise on a 2D input;
    # COLOR_GRAY2RGB is the safe conversion for a single-channel mask
    garment_mask_rgb = cv2.cvtColor(garment_mask_gray, cv2.COLOR_GRAY2RGB)
    garment_mask_rgb = cv2.copyMakeBorder(
        garment_mask_rgb,
        top=0, bottom=0,
        left=pad_left, right=pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
garment_mask_pil = Image.fromarray(garment_mask_rgb)
if DEBUG_SAVE:
garment_mask_pil.save("garment_mask.png")
    # =========================
    # IPAdapterXL construction/call: kept exactly as in the original
    # (this section must not raise)
    # =========================
ip_model = IPAdapterXL(
pipe,
image_encoder_path,
ip_ckpt,
device,
mask_pil,
person_pil,
content_scale=0.3,
style_scale=0.5,
garment_images=garment_pil,
garment_mask=garment_mask_pil,
)
    style_img = Image.open(paths.style_path)
    if DEBUG_SAVE:
        person_pil.save('./person_pil.png')
        mask_pil.save('./mask_pil.png')
        garment_mask_pil.save('./garment_mask_pil.png')
with torch.inference_mode():
images = ip_model.generate(
pil_image=style_img,
image=person_pil,
control_image=depth_map,
strength=1.0,
num_samples=1,
num_inference_steps=int(steps),
shape_prompt="",
prompt=prompt or "",
num=0,
            scale=None,  # ✅ kept as in the original (avoids set_scale-related errors)
controlnet_conditioning_scale=0.7,
guidance_scale=7.5,
)
save_cropped(images, paths.output_path)
print(f"Saved: {paths.output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FEAT pipeline (single sample, file paths)")
parser.add_argument(
"--person-path",
type=str,
default="./DATA_input/Garment/person/1_048392_0.jpg",
)
parser.add_argument(
"--depth-path",
type=str,
default="./DATA_input/Garment/sketch/1_048392_0.png",
)
parser.add_argument(
"--style-path",
type=str,
default="./DATA_input/Garment/style/1_00.jpg",
)
parser.add_argument(
"--output-path",
type=str,
default="./00.png",
)
parser.add_argument("--prompt", type=str, default="upper garment", help="single prompt string (optional)")
parser.add_argument("--steps", type=int, default=DEFAULT_STEPS)
parser.add_argument("--debug-save", action="store_true", help="save debug intermediate images (slow)")
args = parser.parse_args()
DEBUG_SAVE = bool(args.debug_save)
paths = Paths(
person_path=args.person_path,
depth_path=args.depth_path,
style_path=args.style_path,
output_path=args.output_path,
)
run_one(paths, prompt=args.prompt, steps=args.steps)
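# Example CLI invocation (assumes the repo's DATA_input layout and checkpoints):
#   python feat_file.py \
#       --person-path ./DATA_input/Garment/person/1_048392_0.jpg \
#       --depth-path  ./DATA_input/Garment/sketch/1_048392_0.png \
#       --style-path  ./DATA_input/Garment/style/1_00.jpg \
#       --output-path ./00.png --steps 40 --debug-save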