import os
import argparse
from dataclasses import dataclass
from typing import Optional

import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    UniPCMultistepScheduler,
)
from ip_adapter import IPAdapterXL

import cv2
import numpy as np
import imageio
from PIL import Image, ImageOps
from transformers import pipeline

from preprocess.simple_extractor import run as run_simple_extractor

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

DEBUG_SAVE = False
DEFAULT_STEPS = 40

# Global working resolution, set once per run from the person image (height fixed to 1024).
H: Optional[int] = None
W: Optional[int] = None


def compute_hw_from_person(person_path: str):
    """
    Based on the original person image:
    - scale so that the height is exactly 1024
    - preserve the aspect ratio
    => H=1024, W=round(orig_w * (1024/orig_h))
    """
    img = cv2.imread(person_path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {person_path} (exists={os.path.exists(person_path)})")

    orig_h, orig_w = img.shape[:2]
    scale = 1024.0 / float(orig_h)
    new_h = 1024
    new_w = int(round(orig_w * scale))
    return new_h, new_w
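
# Worked example (hypothetical input): a 2048x1536 (h x w) photo gives
# scale = 1024/2048 = 0.5, so compute_hw_from_person returns (1024, 768).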


controlnet = ControlNetModel.from_pretrained(
    controlnet_path,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=dtype,
).to(device)

pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=dtype,
    add_watermarker=False,
).to(device)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_attention_slicing()

# xformers attention is optional; silently fall back to the default attention
# implementation when the package is unavailable.
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception:
    pass
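
# On low-VRAM GPUs, diffusers' pipe.enable_model_cpu_offload() is an alternative
# worth considering (a suggestion, not part of the original script); note it
# manages device placement itself, so the .to(device) calls above would be dropped.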

depth_estimator = pipeline("depth-estimation")
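
# "depth-estimation" resolves to a default checkpoint chosen by transformers,
# which can change across versions. Pinning a model makes runs reproducible, e.g.
# pipeline("depth-estimation", model="Intel/dpt-large") -- the checkpoint name is
# a suggestion, not taken from the original script.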


@dataclass
class Paths:
    person_path: str
    depth_path: str
    style_path: str
    output_path: str


def _ensure_exists(path: str, name: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"{name} not found: {path}")


def apply_parsing_white_mask_to_person_cv2(
    person_pil: Image.Image,
    parsing_img: Image.Image,
) -> np.ndarray:
    """Keep person pixels where the parsing mask is white (255); blank the rest to white."""
    person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
    mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
    white_mask = mask == 255
    result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
    result_rgb[white_mask] = person_rgb[white_mask]
    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    return result_bgr


def _imread_or_raise(path: str, flags=None):
    img = cv2.imread(path, flags) if flags is not None else cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
    return img


def invert_sketch_area(sketch_area: Image.Image) -> Image.Image:
    gray = sketch_area.convert("L")
    arr = np.array(gray, dtype=np.uint8)
    inverted = 255 - arr
    return Image.fromarray(inverted, mode="L")


def merge_white_regions_or(
    parsing_img: Image.Image,
    sketch_area: Image.Image,
) -> np.ndarray:
    """OR-combine the white (255) regions of the two masks into one uint8 mask."""
    p = np.array(parsing_img.convert("L"), dtype=np.uint8)
    s = np.array(sketch_area.convert("L"), dtype=np.uint8)

    merged = np.where((p == 255) | (s == 255), 255, 0).astype(np.uint8)
    return merged


def preprocess_mask(mask: np.ndarray) -> Image.Image:
    """Pad the mask symmetrically to width 1024 (assumes width <= 1024), then dilate it."""
    height, width = mask.shape
    total_padding = 1024 - width
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding

    padded_mask = cv2.copyMakeBorder(
        mask, 0, 0, left_padding, right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )

    # Dilate with a 17x17 kernel to grow the mask ~8 px in every direction,
    # giving the inpainting region some slack around the garment boundary.
    kernel = np.ones((17, 17), np.uint8)
    dilated_mask = cv2.dilate(padded_mask, kernel, iterations=1)

    if DEBUG_SAVE:
        cv2.imwrite("padded_mask.png", padded_mask)
        cv2.imwrite("padded_mask_dilated.png", dilated_mask)

    return Image.fromarray(dilated_mask)
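
# Quick sanity check (hypothetical toy input):
#   >>> m = np.zeros((1024, 700), np.uint8); m[400:600, 200:500] = 255
#   >>> preprocess_mask(m).size   # PIL reports (width, height)
#   (1024, 1024)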


def make_depth(depth_path: str) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    depth_img = _imread_or_raise(depth_path, 0)

    # Invert the sketch and fill its outer contours so the garment silhouette
    # becomes a solid white region.
    inverted_depth = cv2.bitwise_not(depth_img)
    contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    filled_depth = inverted_depth.copy()
    cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED)

    filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)

    # Pad symmetrically to width 1024, matching preprocess_mask().
    height, width = filled_depth.shape
    total_padding = 1024 - width
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding

    padded_depth = cv2.copyMakeBorder(
        filled_depth, 0, 0, left_padding, right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )

    inverted_image = ImageOps.invert(Image.fromarray(padded_depth))

    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]

    if DEBUG_SAVE:
        image_depth.save("depth.png")

    return image_depth
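
# The transformers depth pipeline returns {"predicted_depth": tensor, "depth": PIL.Image};
# only the PIL image is used here, so it can be passed directly as the ControlNet
# conditioning image.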


def fill_sketch_from_image_path_to_pil(
    image_path: str,
    threshold: int = 127,
) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")

    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)

    # Binarize (inverted, so strokes become white), then fill the outer
    # contours to turn the sketch outline into a solid region.
    _, binary = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result = np.full_like(img, 255, dtype=np.uint8)
    cv2.drawContours(result, contours, contourIdx=-1, color=0, thickness=-1)

    return Image.fromarray(result, mode="L")
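
# Note: the silhouette is drawn black-on-white here and inverted again by
# invert_sketch_area() in run_one() before being OR-merged with the parsing mask.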


def center_crop_lr_to_700x1024(arr: np.ndarray) -> np.ndarray:
    """Resize to height 1024, then pad or center-crop left/right to width 700."""
    # (The original name said 768; the code has always targeted width 700.)
    h, w = arr.shape[:2]
    target_w, target_h = 700, 1024
    if h != target_h:
        arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
        h, w = arr.shape[:2]
    if w < target_w:
        pad = (target_w - w) // 2
        arr = cv2.copyMakeBorder(arr, 0, 0, pad, target_w - w - pad, cv2.BORDER_REFLECT_101)
        h, w = arr.shape[:2]
    left = (w - target_w) // 2
    right = left + target_w
    return arr[:, left:right]
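
# Toy example (hypothetical): a 1024x900 RGB array comes back as 1024x700 after
# cropping 100 px from each side; a 1024x600 array is first reflect-padded to 700.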


def save_cropped(imgs, out_path: str):
    np_imgs = [np.asarray(im) for im in imgs]
    cropped = [center_crop_lr_to_700x1024(x) for x in np_imgs]
    out = np.concatenate(cropped, axis=1)
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    imageio.imsave(out_path, out)


def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
    global H, W

    # Validate inputs up front, before any model work.
    _ensure_exists(paths.person_path, "person_path")
    _ensure_exists(paths.depth_path, "depth_path")
    _ensure_exists(paths.style_path, "style_path")

    category = "Upper-clothes"
    PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

    person_path_abs = os.path.abspath(os.path.join(PROJECT_ROOT, paths.person_path))

    H, W = compute_hw_from_person(paths.person_path)

    print("person_path_abs: ", person_path_abs)
    print(f"[global] H={H}, W={W} (from person scaled to height=1024)")

    # Human parsing: extract the garment-region mask for the chosen category.
    res = run_simple_extractor(
        category=category,
        input_path=person_path_abs,
        model_restore="./preprocess/ckpts/exp-schp-201908301523-atr.pth",
    )

    parsing_img = res["images"][0] if res["images"] else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")

    # Merge the parsing mask with the (inverted) filled sketch area, then pad and dilate.
    sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    sketch_area_inv = invert_sketch_area(sketch_area)
    merged_img = merge_white_regions_or(parsing_img, sketch_area_inv)

    mask_pil = preprocess_mask(merged_img)

    # Load and resize the person image, then pad symmetrically to width 1024.
    person_bgr = _imread_or_raise(paths.person_path)
    person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
    if DEBUG_SAVE:
        cv2.imwrite("person.png", person_bgr)

    # Split the padding so left + right always sums to exactly 1024 - W
    # (the original used the same value on both sides, losing a pixel when odd).
    target_width = 1024
    pad_left = (target_width - person_bgr.shape[1]) // 2
    pad_right = target_width - person_bgr.shape[1] - pad_left
    padded_person = cv2.copyMakeBorder(
        person_bgr,
        top=0, bottom=0,
        left=pad_left, right=pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
    person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB)
    person_pil = Image.fromarray(person_rgb)

    depth_map = make_depth(paths.depth_path)

    # Cut the garment out of the original-resolution person image using the parsing mask.
    person_orig = Image.open(paths.person_path)

    garment_bgr = apply_parsing_white_mask_to_person_cv2(person_orig, parsing_img)

    garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
    garment_rgb = cv2.copyMakeBorder(
        garment_rgb,
        top=0, bottom=0,
        left=pad_left, right=pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
    garment_pil = Image.fromarray(garment_rgb)

    if DEBUG_SAVE:
        garment_pil.save("./garment_pil.png")

    # Build the garment mask at the padded 1024-wide resolution.
    garment_mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
    garment_mask = cv2.resize(garment_mask, (W, H), interpolation=cv2.INTER_AREA)
    garment_mask_rgb = cv2.cvtColor(garment_mask, cv2.COLOR_GRAY2RGB)
    garment_mask_rgb = cv2.copyMakeBorder(
        garment_mask_rgb,
        top=0, bottom=0,
        left=pad_left, right=pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
    garment_mask_pil = Image.fromarray(garment_mask_rgb)

    if DEBUG_SAVE:
        garment_mask_pil.save("garment_mask.png")

    # IP-Adapter wrapper; this variant takes mask/person/garment conditioning
    # in the constructor in addition to the stock (pipe, encoder, ckpt, device) args.
    ip_model = IPAdapterXL(
        pipe,
        image_encoder_path,
        ip_ckpt,
        device,
        mask_pil,
        person_pil,
        content_scale=0.3,
        style_scale=0.5,
        garment_images=garment_pil,
        garment_mask=garment_mask_pil,
    )

    style_img = Image.open(paths.style_path)

    if DEBUG_SAVE:
        person_pil.save("./person_pil.png")
        mask_pil.save("./mask_pil.png")
        garment_mask_pil.save("./garment_mask_pil.png")

    with torch.inference_mode():
        images = ip_model.generate(
            pil_image=style_img,
            image=person_pil,
            control_image=depth_map,
            strength=1.0,
            num_samples=1,
            num_inference_steps=int(steps),
            shape_prompt="",
            prompt=prompt or "",
            num=0,
            scale=None,
            controlnet_conditioning_scale=0.7,
            guidance_scale=7.5,
        )

    save_cropped(images, paths.output_path)
    print(f"Saved: {paths.output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FEAT pipeline (single sample, file paths)")

    parser.add_argument(
        "--person-path",
        type=str,
        default="./DATA_input/Garment/person/1_048392_0.jpg",
    )
    parser.add_argument(
        "--depth-path",
        type=str,
        default="./DATA_input/Garment/sketch/1_048392_0.png",
    )
    parser.add_argument(
        "--style-path",
        type=str,
        default="./DATA_input/Garment/style/1_00.jpg",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="./00.png",
    )

    parser.add_argument("--prompt", type=str, default="upper garment", help="single prompt string (optional)")
    parser.add_argument("--steps", type=int, default=DEFAULT_STEPS)
    parser.add_argument("--debug-save", action="store_true", help="save debug intermediate images (slow)")
    args = parser.parse_args()

    DEBUG_SAVE = bool(args.debug_save)

    paths = Paths(
        person_path=args.person_path,
        depth_path=args.depth_path,
        style_path=args.style_path,
        output_path=args.output_path,
    )

    run_one(paths, prompt=args.prompt, steps=args.steps)
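
# Example invocation (the filename run_feat.py is illustrative, not taken from
# the original script):
#   python run_feat.py \
#       --person-path ./DATA_input/Garment/person/1_048392_0.jpg \
#       --depth-path ./DATA_input/Garment/sketch/1_048392_0.png \
#       --style-path ./DATA_input/Garment/style/1_00.jpg \
#       --output-path ./00.png --steps 40 --debug-save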