| import os |
| import argparse |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
| from diffusers import UniPCMultistepScheduler |
| from diffusers3.models.controlnet import ControlNetModel |
| from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img_img import ( |
| StableDiffusionXLControlNetImg2ImgPipeline, |
| ) |
| from ip_adapter import IPAdapterXL |
|
|
| import cv2 |
| import numpy as np |
| import imageio |
| from PIL import Image, ImageOps |
| from transformers import pipeline |
| from preprocess.simple_extractor import run as run_simple_extractor |
|
|
|
|
# Model locations. base_model_path and controlnet_path are handed to the
# from_pretrained() loaders further down in this module; image_encoder_path
# and ip_ckpt are passed to IPAdapterXL inside run_one().
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"


# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Full-precision torch dtype constant for this module.
dtype = torch.float32


# When True, the helpers write intermediate images into the working directory.
# The __main__ block overrides this from the --debug-save flag.
DEBUG_SAVE = False
# Default number of inference steps for run_one() and the --steps CLI flag.
DEFAULT_STEPS = 40


# Working resolution shared by the preprocessing helpers. Both stay None until
# run_one() assigns them from compute_hw_from_person(); make_depth() and
# fill_sketch_from_image_path_to_pil() raise RuntimeError while they are unset.
H: Optional[int] = None
W: Optional[int] = None
|
|
|
|
def compute_hw_from_person(person_path: str):
    """Derive the working resolution from the original person image.

    The image is scaled so that its height becomes exactly 1024 while the
    aspect ratio is preserved:

        H = 1024
        W = round(orig_w * (1024 / orig_h))

    Args:
        person_path: Path of the person image on disk.

    Returns:
        Tuple ``(H, W)`` of ints.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``person_path``.
    """
    # Go through the shared loader so the read-failure handling (exception
    # type and message) stays identical to the other image reads in this file.
    img = _imread_or_raise(person_path)

    orig_h, orig_w = img.shape[:2]
    scale = 1024.0 / float(orig_h)
    new_h = 1024
    new_w = int(round(orig_w * scale))
    return new_h, new_w
|
|
|
|
# Depth ControlNet. The "fp16" variant of the checkpoint is requested and the
# module is loaded with the configured ``dtype`` before moving to ``device``.
controlnet = ControlNetModel.from_pretrained(
    controlnet_path,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=dtype,
).to(device)


# SDXL img2img pipeline wired to the ControlNet above, without the watermarker.
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=dtype,
    add_watermarker=False,
).to(device)


# Replace the pipeline's default scheduler with UniPC, reusing its config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_attention_slicing()


# xformers is an optional optimisation: keep going without it, but say why it
# was skipped instead of swallowing the failure silently.
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as exc:  # best-effort: any failure here is non-fatal
    print(f"[warn] xformers memory-efficient attention not enabled: {exc}")


# Depth-estimation pipeline used by make_depth(); no explicit model is given,
# so the library's default for the "depth-estimation" task is used.
depth_estimator = pipeline("depth-estimation")
|
|
|
|
@dataclass
class Paths:
    """File locations for a single sample, as consumed by ``run_one``."""

    # Person photo; run_one() also derives the global H/W from this image.
    person_path: str
    # Sketch image; run_one() feeds it to fill_sketch_from_image_path_to_pil()
    # and to make_depth().
    depth_path: str
    # Style reference image, opened and passed to generate() as ``pil_image``.
    style_path: str
    # Destination file written by save_cropped().
    output_path: str
|
|
|
|
def _ensure_exists(path: str, name: str):
    """Raise ``FileNotFoundError`` unless ``path`` exists on disk.

    Args:
        path: Filesystem path to check.
        name: Human-readable label used as the prefix of the error message.
    """
    if os.path.exists(path):
        return
    raise FileNotFoundError(f"{name} not found: {path}")
|
|
|
|
def apply_parsing_white_mask_to_person_cv2(
    person_pil: Image.Image,
    parsing_img: Image.Image
) -> np.ndarray:
    """Keep person pixels where the parsing mask is pure white.

    Every pixel whose grayscale parsing value is exactly 255 is copied from
    the person image; all other pixels become white (255, 255, 255).

    Args:
        person_pil: Person image; converted to RGB internally.
        parsing_img: Parsing mask; converted to 8-bit grayscale internally.

    Returns:
        uint8 array in BGR channel order (OpenCV convention).
    """
    rgb_pixels = np.array(person_pil.convert("RGB"), dtype=np.uint8)
    keep = np.array(parsing_img.convert("L"), dtype=np.uint8) == 255

    # Start from an all-white canvas and paste only the selected pixels.
    canvas = np.full(rgb_pixels.shape, 255, dtype=np.uint8)
    canvas[keep] = rgb_pixels[keep]

    return cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
|
|
|
|
def _imread_or_raise(path: str, flags=None):
    """Load an image with ``cv2.imread`` and fail loudly when it returns None.

    Args:
        path: Image file to read.
        flags: Optional ``cv2.imread`` flag; when None, OpenCV's default is
            used. A flag value of 0 is honoured because the check is
            ``is None``, not truthiness.

    Returns:
        The array returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns None. The message also
            reports whether the path exists.
    """
    if flags is None:
        loaded = cv2.imread(path)
    else:
        loaded = cv2.imread(path, flags)

    if loaded is not None:
        return loaded
    raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
|
|
|
|
def invert_sketch_area(sketch_area: Image.Image) -> Image.Image:
    """Return the negative of ``sketch_area`` as an 8-bit grayscale image.

    The input is converted to mode "L" and every value ``v`` becomes
    ``255 - v``.
    """
    levels = np.array(sketch_area.convert("L"), dtype=np.uint8)
    negative = 255 - levels
    return Image.fromarray(negative, mode="L")
|
|
|
|
def merge_white_regions_or(
    parsing_img: Image.Image,
    sketch_area: Image.Image
) -> np.ndarray:
    """Union the pure-white regions of two masks.

    Both inputs are converted to 8-bit grayscale. A pixel of the result is
    255 when it is exactly 255 in either input and 0 otherwise.

    Args:
        parsing_img: Parsing mask. If its size differs from ``sketch_area``
            it is resampled (nearest neighbour) onto the sketch grid first,
            so the two masks can be combined pixel by pixel.
        sketch_area: Sketch mask; its size defines the size of the result.

    Returns:
        2-D uint8 array (not a PIL image) containing only 0 and 255.
    """
    p = np.array(parsing_img.convert("L"), dtype=np.uint8)
    s = np.array(sketch_area.convert("L"), dtype=np.uint8)

    if p.shape != s.shape:
        # Nearest-neighbour keeps mask values intact; an averaging filter
        # would blur 255 into intermediate values that fail the == 255 test.
        p = cv2.resize(p, (s.shape[1], s.shape[0]), interpolation=cv2.INTER_NEAREST)

    merged = np.where(
        (p == 255) | (s == 255),
        255,
        0
    ).astype(np.uint8)

    return merged
|
|
|
|
def preprocess_mask(
    mask: np.ndarray,
    target_width: int = 1024,
    dilate_size: int = 17,
) -> Image.Image:
    """Centre a mask on a fixed-width canvas and dilate it.

    Args:
        mask: Mask array; only its width (``shape[1]``) is inspected.
        target_width: Width of the padded canvas. Defaults to 1024, the value
            that was previously hard-coded.
        dilate_size: Side length of the square all-ones dilation kernel.
            Defaults to 17, the value that was previously hard-coded.

    Returns:
        PIL image built from the padded, dilated mask.

    When ``DEBUG_SAVE`` is set, the padded and the dilated masks are also
    written to ``padded_mask.png`` and ``padded_mask_dilated.png``.
    """
    total_padding = target_width - mask.shape[1]
    # Any odd leftover column goes to the right-hand side.
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding

    padded_mask = cv2.copyMakeBorder(
        mask, 0, 0, left_padding, right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )

    kernel = np.ones((dilate_size, dilate_size), np.uint8)
    dilated_mask = cv2.dilate(padded_mask, kernel, iterations=1)

    if DEBUG_SAVE:
        cv2.imwrite("padded_mask.png", padded_mask)
        cv2.imwrite("padded_mask_dilated.png", dilated_mask)

    return Image.fromarray(dilated_mask)
|
|
|
|
def make_depth(depth_path: str) -> Image.Image:
    """Build the depth control image from a sketch file.

    Steps: read the sketch as grayscale, invert it, fill every external
    contour with 255, resize to the global (W, H), pad left/right with 0 up
    to a width of 1024, invert again, and run the module-level
    ``depth_estimator`` on the result.

    Args:
        depth_path: Path of the sketch image.

    Returns:
        The ``"depth"`` entry of the estimator output.

    Raises:
        RuntimeError: If the global H/W have not been set by ``run_one``.
        FileNotFoundError: If the sketch cannot be read.
    """
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    # Flag 0 asks cv2.imread for a single-channel grayscale image.
    sketch_gray = _imread_or_raise(depth_path, 0)

    negative = cv2.bitwise_not(sketch_gray)
    outlines, _ = cv2.findContours(negative, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Paint the interior of every outer contour solid white on a copy.
    solid = negative.copy()
    cv2.drawContours(solid, outlines, -1, 255, thickness=cv2.FILLED)

    solid = cv2.resize(solid, (W, H), interpolation=cv2.INTER_AREA)

    # Centre horizontally on a 1024-wide canvas; an odd leftover column goes
    # to the right-hand side.
    _, cur_width = solid.shape
    slack = 1024 - cur_width
    pad_left = slack // 2
    canvas = cv2.copyMakeBorder(
        solid, 0, 0, pad_left, slack - pad_left,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )

    estimator_input = ImageOps.invert(Image.fromarray(canvas))

    with torch.inference_mode():
        depth_image = depth_estimator(estimator_input)["depth"]

    if DEBUG_SAVE:
        depth_image.save("depth.png")

    return depth_image
|
|
|
|
def fill_sketch_from_image_path_to_pil(
    image_path: str,
    threshold: int = 127,
) -> Image.Image:
    """Turn a line sketch into a solid black silhouette on a white canvas.

    The sketch is read as grayscale, resized to the global (W, H) with
    nearest-neighbour sampling, binarised with an inverted threshold, and the
    interior of every external contour is painted black (0) on an all-white
    (255) image.

    Args:
        image_path: Path of the sketch image.
        threshold: Threshold passed to ``cv2.threshold`` with
            ``THRESH_BINARY_INV``.

    Returns:
        Mode "L" PIL image of size (W, H).

    Raises:
        RuntimeError: If the global H/W have not been set by ``run_one``.
        ValueError: If the sketch cannot be read.
    """
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")

    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # The exception type stays ValueError for existing callers; the text
        # follows the read-failure message used by _imread_or_raise.
        raise ValueError(
            f"cv2.imread failed: {image_path} (exists={os.path.exists(image_path)})"
        )

    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)

    _, binary = cv2.threshold(
        img,
        threshold,
        255,
        cv2.THRESH_BINARY_INV
    )

    contours, _ = cv2.findContours(
        binary,
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE
    )

    # White background; thickness=-1 fills each contour instead of outlining.
    result = np.full_like(img, 255, dtype=np.uint8)
    cv2.drawContours(
        result,
        contours,
        contourIdx=-1,
        color=0,
        thickness=-1
    )

    return Image.fromarray(result, mode="L")
|
|
|
|
def center_crop_lr_to_768x1024(
    arr: np.ndarray,
    target_w: int = 700,
    target_h: int = 1024,
) -> np.ndarray:
    """Force an image to ``target_h`` rows and centre-crop it to ``target_w`` columns.

    NOTE(review): despite the ``768`` in the function name, the default crop
    width is 700. That default is kept unchanged for backward compatibility;
    pass ``target_w=768`` explicitly if the name reflects the intended size.

    Args:
        arr: Image array with rows on axis 0 and columns on axis 1.
        target_w: Output width in pixels.
        target_h: Output height in pixels.

    Returns:
        View or array with ``target_h`` rows and ``target_w`` columns.

    Behaviour:
        * If the height differs from ``target_h`` the image is resized to
          ``(w, target_h)``; the width is left alone, so the aspect ratio
          is not preserved in that case.
        * If the image is narrower than ``target_w`` it is padded left/right
          by reflection (``BORDER_REFLECT_101``) up to ``target_w``.
        * Otherwise the central ``target_w`` columns are returned.
    """
    h, w = arr.shape[:2]
    if h != target_h:
        arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
        w = arr.shape[1]
    if w < target_w:
        pad = (target_w - w) // 2
        arr = cv2.copyMakeBorder(arr, 0, 0, pad, target_w - w - pad, cv2.BORDER_REFLECT_101)
        w = arr.shape[1]
    left = (w - target_w) // 2
    return arr[:, left:left + target_w]
|
|
|
|
def save_cropped(imgs, out_path: str):
    """Crop each image, join them side by side, and write one file.

    Args:
        imgs: Iterable of images convertible with ``np.asarray``.
        out_path: Destination file. Its parent directory is created when the
            path has one; a bare filename is written to the working directory.
    """
    panels = [center_crop_lr_to_768x1024(np.asarray(im)) for im in imgs]
    out = np.concatenate(panels, axis=1)

    # os.path.dirname("out.png") is "", and os.makedirs("") raises, so only
    # create the directory when the path actually names one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    imageio.imsave(out_path, out)
|
|
|
|
def _pad_to_width(img: np.ndarray, target_width: int, value) -> np.ndarray:
    """Centre ``img`` horizontally on a ``target_width``-wide canvas filled with ``value``."""
    total = target_width - img.shape[1]
    # Same split as preprocess_mask() and make_depth(): any odd leftover
    # column goes to the right, so every conditioning image is exactly
    # target_width wide and pixel-aligned with the mask and depth map.
    left = total // 2
    return cv2.copyMakeBorder(
        img,
        top=0, bottom=0,
        left=left, right=total - left,
        borderType=cv2.BORDER_CONSTANT,
        value=value,
    )


def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
    """Run the full pipeline for one sample and write the result image.

    Sets the module globals ``H`` and ``W`` from the person image, runs the
    human parser, builds the mask / person / garment / garment-mask / depth
    inputs, generates with the IP-Adapter and saves the cropped output to
    ``paths.output_path``.

    Args:
        paths: Input and output file locations.
        prompt: Text prompt; ``None`` or empty is passed on as ``""``.
        steps: Number of inference steps.

    Raises:
        FileNotFoundError: If the person, sketch or style file is missing or
            an image cannot be read.
        ValueError: If the scaled person image is wider than the 1024-pixel
            canvas and therefore cannot be padded.
        RuntimeError: If the human parser returns no images.

    Intermediate images are written to the working directory only when
    ``DEBUG_SAVE`` is true.
    """
    global H, W

    # Validate every input before any expensive work touches the files.
    _ensure_exists(paths.person_path, "person_path")
    _ensure_exists(paths.depth_path, "depth_path")
    _ensure_exists(paths.style_path, "style_path")

    category = 'Upper-clothes'
    target_width = 1024
    PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

    # NOTE(review): the parser receives a path resolved against this file's
    # directory, while every other read uses paths.person_path relative to
    # the working directory. They agree only when the script is launched
    # from PROJECT_ROOT or the path is absolute; confirm which is intended.
    person_path_abs = os.path.abspath(
        os.path.join(PROJECT_ROOT, paths.person_path)
    )

    H, W = compute_hw_from_person(paths.person_path)
    if W > target_width:
        # Padding would be negative; fail here with a readable message
        # instead of an OpenCV assertion after the parser has already run.
        raise ValueError(
            f"person image scaled to height {H} is {W}px wide; "
            f"the pipeline needs width <= {target_width}"
        )

    print('person_path_abs: ', person_path_abs)
    print(f'[global] H={H}, W={W} (from person scaled to height=1024)')

    res = run_simple_extractor(
        category=category,
        input_path=person_path_abs,
        model_restore="./preprocess/ckpts/exp-schp-201908301523-atr.pth"
    )

    parsing_img = res["images"][0] if res["images"] else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")

    # Mask = parsing region OR filled sketch region, padded and dilated.
    sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    sketch_area_inv = invert_sketch_area(sketch_area)
    merged_img = merge_white_regions_or(parsing_img, sketch_area_inv)
    mask_pil = preprocess_mask(merged_img)

    # Person image: resize to (W, H), pad with white, convert to RGB.
    person_bgr = _imread_or_raise(paths.person_path)
    person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
    if DEBUG_SAVE:
        cv2.imwrite("person.png", person_bgr)

    padded_person = _pad_to_width(person_bgr, target_width, [255, 255, 255])
    person_pil = Image.fromarray(cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB))

    depth_map = make_depth(paths.depth_path)

    # Garment image: person pixels inside the parsing mask on white.
    # The helper converts the image to an array internally, so the file
    # handle can be released as soon as it returns.
    with Image.open(paths.person_path) as person_src:
        garment_bgr = apply_parsing_white_mask_to_person_cv2(
            person_src,
            parsing_img
        )

    garment_rgb = cv2.cvtColor(garment_bgr, cv2.COLOR_BGR2RGB)
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
    garment_rgb = _pad_to_width(garment_rgb, target_width, [255, 255, 255])
    garment_pil = Image.fromarray(garment_rgb)

    # Garment mask: the parsing mask as a 3-channel image, padded with black.
    garment_mask_gray = np.array(parsing_img.convert("L"), dtype=np.uint8)
    garment_mask_gray = cv2.resize(garment_mask_gray, (W, H), interpolation=cv2.INTER_AREA)
    garment_mask_rgb = cv2.cvtColor(garment_mask_gray, cv2.COLOR_GRAY2RGB)
    garment_mask_rgb = _pad_to_width(garment_mask_rgb, target_width, [0, 0, 0])
    garment_mask_pil = Image.fromarray(garment_mask_rgb)

    if DEBUG_SAVE:
        # Same file names the script has always produced.
        person_pil.save('./person_pil.png')
        mask_pil.save('./mask_pil.png')
        garment_pil.save('./garment_pil.png')
        garment_mask_pil.save("garment_mask.png")
        garment_mask_pil.save('./garment_mask_pil.png')

    ip_model = IPAdapterXL(
        pipe,
        image_encoder_path,
        ip_ckpt,
        device,
        mask_pil,
        person_pil,
        content_scale=0.3,
        style_scale=0.5,
        garment_images=garment_pil,
        garment_mask=garment_mask_pil,
    )

    # copy() forces the pixel data to load so the file can be closed here.
    with Image.open(paths.style_path) as style_src:
        style_img = style_src.copy()

    with torch.inference_mode():
        images = ip_model.generate(
            pil_image=style_img,
            image=person_pil,
            control_image=depth_map,
            strength=1.0,
            num_samples=1,
            num_inference_steps=int(steps),
            shape_prompt="",
            prompt=prompt or "",
            num=0,
            scale=None,
            controlnet_conditioning_scale=0.7,
            guidance_scale=7.5,
        )

    save_cropped(images, paths.output_path)
    print(f"Saved: {paths.output_path}")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the sample paths and options, then run
    # the pipeline once.
    cli = argparse.ArgumentParser(description="FEAT pipeline (single sample, file paths)")

    # The four path options share the same shape, so register them from a table.
    path_options = (
        ("--person-path", "./DATA_input/Garment/person/1_048392_0.jpg"),
        ("--depth-path", "./DATA_input/Garment/sketch/1_048392_0.png"),
        ("--style-path", "./DATA_input/Garment/style/1_00.jpg"),
        ("--output-path", "./00.png"),
    )
    for flag, fallback in path_options:
        cli.add_argument(flag, type=str, default=fallback)

    cli.add_argument("--prompt", type=str, default="upper garment", help="single prompt string (optional)")
    cli.add_argument("--steps", type=int, default=DEFAULT_STEPS)
    cli.add_argument("--debug-save", action="store_true", help="save debug intermediate images (slow)")
    opts = cli.parse_args()

    # Module-level rebind: the helpers read this global when deciding whether
    # to write intermediate images.
    DEBUG_SAVE = bool(opts.debug_save)

    run_one(
        Paths(
            person_path=opts.person_path,
            depth_path=opts.depth_path,
            style_path=opts.style_path,
            output_path=opts.output_path,
        ),
        prompt=opts.prompt,
        steps=opts.steps,
    )
|
|
|
|
|
|