import os
import argparse
from dataclasses import dataclass
from typing import Optional

import torch
from diffusers import UniPCMultistepScheduler
from diffusers3.models.controlnet import ControlNetModel
from diffusers3.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img import (
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from ip_adapter import IPAdapterXL

import cv2
import numpy as np
import imageio
from PIL import Image, ImageOps
from transformers import pipeline

from preprocess.simple_extractor import run as run_simple_extractor

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

DEBUG_SAVE = False
DEFAULT_STEPS = 40

# =========================
# Global resize params (per the requirements)
# - Scale the person image to height=1024 and use the resulting (H, W) globally.
# =========================
H: Optional[int] = None  # always 1024
W: Optional[int] = None  # width computed while preserving the aspect ratio


def compute_hw_from_person(person_path: str):
    """
    Based on the original person image:
    - scale so the height is exactly 1024
    - preserve the aspect ratio
    => H=1024, W=round(orig_w * (1024 / orig_h))
    e.g. a 1000x1333 (WxH) input yields H=1024, W=768.
    """
    img = cv2.imread(person_path)
    if img is None:
        raise FileNotFoundError(
            f"cv2.imread failed: {person_path} (exists={os.path.exists(person_path)})"
        )
    orig_h, orig_w = img.shape[:2]
    scale = 1024.0 / float(orig_h)
    new_h = 1024
    new_w = int(round(orig_w * scale))
    return new_h, new_w


controlnet = ControlNetModel.from_pretrained(
    controlnet_path,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=dtype,
).to(device)

pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=dtype,
    add_watermarker=False,
).to(device)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_attention_slicing()
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception:
    pass

depth_estimator = pipeline("depth-estimation")


@dataclass
class Paths:
    person_path: str
    depth_path: str
    style_path: str
    output_path: str


def _ensure_exists(path: str, name: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"{name} not found: {path}")


def apply_parsing_white_mask_to_person_cv2(
    person_pil: Image.Image, parsing_img: Image.Image
) -> np.ndarray:
    """Keep the person pixels where the parsing mask is white; fill the rest with white."""
    person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
    mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
    white_mask = mask == 255
    result_rgb = np.full_like(person_rgb, 255, dtype=np.uint8)
    result_rgb[white_mask] = person_rgb[white_mask]
    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    return result_bgr


def _imread_or_raise(path: str, flags=None):
    img = cv2.imread(path, flags) if flags is not None else cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {path} (exists={os.path.exists(path)})")
    return img


def invert_sketch_area(sketch_area: Image.Image) -> Image.Image:
    gray = sketch_area.convert("L")
    arr = np.array(gray, dtype=np.uint8)
    inverted = 255 - arr
    return Image.fromarray(inverted, mode="L")


def merge_white_regions_or(
    parsing_img: Image.Image, sketch_area: Image.Image
) -> np.ndarray:
    """OR-merge: a pixel is white iff it is white in either input."""
    p = np.array(parsing_img.convert("L"), dtype=np.uint8)
    s = np.array(sketch_area.convert("L"), dtype=np.uint8)
    merged = np.where((p == 255) | (s == 255), 255, 0).astype(np.uint8)
    return merged
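
# --- Usage sketch (illustrative only; never called by the pipeline) -------
# Shows how the two mask helpers above compose: the sketch is inverted so
# its strokes become white, then OR-merged with the parsing mask. The paths
# passed in are hypothetical placeholders.
def _demo_merge_masks(parsing_path: str, sketch_path: str) -> np.ndarray:
    parsing = Image.open(parsing_path)
    sketch = Image.open(sketch_path)
    merged = merge_white_regions_or(parsing, invert_sketch_area(sketch))
    # The result is a binary uint8 array containing only 0 and 255.
    return merged
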
def preprocess_mask(mask: np.ndarray) -> Image.Image:
    # The padding target width is fixed at 1024, per the requirements.
    height, width = mask.shape
    total_padding = 1024 - width
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding
    padded_mask = cv2.copyMakeBorder(
        mask,
        0,
        0,
        left_padding,
        right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )
    # Dilate so the edit region extends slightly past the merged mask.
    kernel = np.ones((17, 17), np.uint8)
    dilated_mask = cv2.dilate(padded_mask, kernel, iterations=1)
    if DEBUG_SAVE:
        cv2.imwrite("padded_mask.png", padded_mask)
        cv2.imwrite("padded_mask_dilated.png", dilated_mask)
    return Image.fromarray(dilated_mask)


def make_depth(depth_path: str) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")
    depth_img = _imread_or_raise(depth_path, 0)
    inverted_depth = cv2.bitwise_not(depth_img)
    contours, _ = cv2.findContours(inverted_depth, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled_depth = inverted_depth.copy()
    cv2.drawContours(filled_depth, contours, -1, 255, thickness=cv2.FILLED)

    # ✅ resize uses the global (W, H)
    filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)

    height, width = filled_depth.shape
    total_padding = 1024 - width
    left_padding = total_padding // 2
    right_padding = total_padding - left_padding
    padded_depth = cv2.copyMakeBorder(
        filled_depth,
        0,
        0,
        left_padding,
        right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )
    inverted_image = ImageOps.invert(Image.fromarray(padded_depth))
    with torch.inference_mode():
        image_depth = depth_estimator(inverted_image)["depth"]
    if DEBUG_SAVE:
        image_depth.save("depth.png")
    return image_depth


def fill_sketch_from_image_path_to_pil(
    image_path: str,
    threshold: int = 127,
) -> Image.Image:
    global H, W
    if H is None or W is None:
        raise RuntimeError("Global H/W not set. Call run_one() first.")
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Failed to load image: {image_path}")

    # ✅ resize uses the global (W, H)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)

    _, binary = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Draw the filled outer contours in black on a white canvas.
    result = np.full_like(img, 255, dtype=np.uint8)
    cv2.drawContours(result, contours, contourIdx=-1, color=0, thickness=-1)
    return Image.fromarray(result, mode="L")


def center_crop_lr_to_700x1024(arr: np.ndarray) -> np.ndarray:
    """Center-crop (or reflect-pad) left/right to a 700x1024 output."""
    h, w = arr.shape[:2]
    target_w, target_h = 700, 1024
    if h != target_h:
        arr = cv2.resize(arr, (w, target_h), interpolation=cv2.INTER_AREA)
        h, w = arr.shape[:2]
    if w < target_w:
        pad = (target_w - w) // 2
        arr = cv2.copyMakeBorder(arr, 0, 0, pad, target_w - w - pad, cv2.BORDER_REFLECT_101)
        h, w = arr.shape[:2]
    left = (w - target_w) // 2
    right = left + target_w
    return arr[:, left:right]


def save_cropped(imgs, out_path: str):
    np_imgs = [np.asarray(im) for im in imgs]
    cropped = [center_crop_lr_to_700x1024(x) for x in np_imgs]
    out = np.concatenate(cropped, axis=1)
    # Fall back to "." when out_path has no directory component.
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    imageio.imsave(out_path, out)
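
# --- Padding-math sketch (illustrative only) ------------------------------
# The same left/right split to a fixed 1024-px width recurs in
# preprocess_mask(), make_depth(), and the person/garment padding inside
# run_one() below.
def _demo_pad_split_to_1024(width: int):
    total = 1024 - width
    left = total // 2
    right = total - left
    # e.g. width=700 -> (162, 162); width=701 -> (161, 162).
    # Either way, left + width + right == 1024 exactly.
    return left, right
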
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
    global H, W

    # Validate inputs before any of them are read.
    _ensure_exists(paths.person_path, "person_path")
    _ensure_exists(paths.depth_path, "depth_path")
    _ensure_exists(paths.style_path, "style_path")

    category = "Upper-clothes"
    PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
    person_path_abs = os.path.abspath(os.path.join(PROJECT_ROOT, paths.person_path))

    # ✅ Set the global H/W: the (H, W) of the person image scaled to height=1024.
    H, W = compute_hw_from_person(paths.person_path)
    print("person_path_abs:", person_path_abs)
    print(f"[global] H={H}, W={W} (from person scaled to height=1024)")

    res = run_simple_extractor(
        category=category,
        input_path=person_path_abs,
        model_restore="./preprocess/ckpts/exp-schp-201908301523-atr.pth",
    )
    parsing_img = res["images"][0] if res["images"] else None
    if parsing_img is None:
        raise RuntimeError("run_simple_extractor returned no parsing images.")

    sketch_area = fill_sketch_from_image_path_to_pil(paths.depth_path)
    sketch_area_inv = invert_sketch_area(sketch_area)
    merged_img = merge_white_regions_or(parsing_img, sketch_area_inv)
    mask_pil = preprocess_mask(merged_img)

    # =========================
    # person: resize to (W, H); the padding target width of 1024 stays fixed.
    # =========================
    person_bgr = _imread_or_raise(paths.person_path)
    person_bgr = cv2.resize(person_bgr, (W, H), interpolation=cv2.INTER_AREA)
    if DEBUG_SAVE:
        cv2.imwrite("person.png", person_bgr)

    target_width = 1024  # ✅ fixed
    total_padding = target_width - person_bgr.shape[1]
    left_padding = total_padding // 2
    # Give the remainder to the right side so the padded width is exactly
    # 1024 even when total_padding is odd (matches preprocess_mask/make_depth).
    right_padding = total_padding - left_padding
    padded_person = cv2.copyMakeBorder(
        person_bgr,
        top=0,
        bottom=0,
        left=left_padding,
        right=right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
    person_rgb = cv2.cvtColor(padded_person, cv2.COLOR_BGR2RGB)
    person_pil = Image.fromarray(person_rgb)

    depth_map = make_depth(paths.depth_path)

    # =========================
    # garment: original logic kept; resized below to the global (W, H) so it
    # matches the parsing/mask sizes.
    # =========================
    person_orig = Image.open(paths.person_path)
    garment_ = apply_parsing_white_mask_to_person_cv2(person_orig, parsing_img)
    garment_rgb = cv2.cvtColor(garment_, cv2.COLOR_BGR2RGB)
    # ✅ (important) garment_ may still be at the original person size, so
    # resize to the global (W, H) before padding.
    garment_rgb = cv2.resize(garment_rgb, (W, H), interpolation=cv2.INTER_AREA)
    garment_rgb = cv2.copyMakeBorder(
        garment_rgb,
        top=0,
        bottom=0,
        left=left_padding,
        right=right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
    garment_pil = Image.fromarray(garment_rgb)

    # =========================
    # garment mask: resize to (W, H); padding target width fixed at 1024.
    # =========================
    garment_mask_gray = np.array(parsing_img.convert("L"), dtype=np.uint8)
    garment_mask_gray = cv2.resize(garment_mask_gray, (W, H), interpolation=cv2.INTER_AREA)
    # The original code used BGR2RGB here, which can raise on a 2D input;
    # GRAY2RGB is the safe conversion for a single-channel mask.
    garment_mask_rgb = cv2.cvtColor(garment_mask_gray, cv2.COLOR_GRAY2RGB)
    garment_mask_rgb = cv2.copyMakeBorder(
        garment_mask_rgb,
        top=0,
        bottom=0,
        left=left_padding,
        right=right_padding,
        borderType=cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
    garment_mask_pil = Image.fromarray(garment_mask_rgb)

    # =========================
    # IPAdapterXL construction/call: kept exactly as in the original
    # (must not error here).
    # =========================
    ip_model = IPAdapterXL(
        pipe,
        image_encoder_path,
        ip_ckpt,
        device,
        mask_pil,
        person_pil,
        content_scale=0.3,
        style_scale=0.5,
        garment_images=garment_pil,
        garment_mask=garment_mask_pil,
    )

    style_img = Image.open(paths.style_path)

    if DEBUG_SAVE:
        person_pil.save("./person_pil.png")
        mask_pil.save("./mask_pil.png")
        garment_pil.save("./garment_pil.png")
        garment_mask_pil.save("./garment_mask_pil.png")

    with torch.inference_mode():
        images = ip_model.generate(
            pil_image=style_img,
            image=person_pil,
            control_image=depth_map,
            strength=1.0,
            num_samples=1,
            num_inference_steps=int(steps),
            shape_prompt="",
            prompt=prompt or "",
            num=0,
            scale=None,  # ✅ kept as in the original (key to avoiding set_scale errors)
            controlnet_conditioning_scale=0.7,
            guidance_scale=7.5,
        )

    save_cropped(images, paths.output_path)
    print(f"Saved: {paths.output_path}")
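
# --- Programmatic usage sketch (illustrative only) -------------------------
# Mirrors the CLI defaults below; call run_one() directly when driving the
# pipeline from another script instead of through argparse.
def _demo_run_one_programmatic():
    demo_paths = Paths(
        person_path="./DATA_input/Garment/person/1_048392_0.jpg",
        depth_path="./DATA_input/Garment/sketch/1_048392_0.png",
        style_path="./DATA_input/Garment/style/1_00.jpg",
        output_path="./00.png",
    )
    run_one(demo_paths, prompt="upper garment", steps=DEFAULT_STEPS)
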
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FEAT pipeline (single sample, file paths)")
    parser.add_argument(
        "--person-path",
        type=str,
        default="./DATA_input/Garment/person/1_048392_0.jpg",
    )
    parser.add_argument(
        "--depth-path",
        type=str,
        default="./DATA_input/Garment/sketch/1_048392_0.png",
    )
    parser.add_argument(
        "--style-path",
        type=str,
        default="./DATA_input/Garment/style/1_00.jpg",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="./00.png",
    )
    parser.add_argument("--prompt", type=str, default="upper garment", help="single prompt string (optional)")
    parser.add_argument("--steps", type=int, default=DEFAULT_STEPS)
    parser.add_argument("--debug-save", action="store_true", help="save debug intermediate images (slow)")
    args = parser.parse_args()

    DEBUG_SAVE = bool(args.debug_save)

    paths = Paths(
        person_path=args.person_path,
        depth_path=args.depth_path,
        style_path=args.style_path,
        output_path=args.output_path,
    )
    run_one(paths, prompt=args.prompt, steps=args.steps)
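
# Example invocation (values shown are the argparse defaults; the script
# filename "feat_pipeline.py" is an assumption -- use this file's real name):
#
#   python feat_pipeline.py \
#       --person-path ./DATA_input/Garment/person/1_048392_0.jpg \
#       --depth-path ./DATA_input/Garment/sketch/1_048392_0.png \
#       --style-path ./DATA_input/Garment/style/1_00.jpg \
#       --output-path ./00.png \
#       --prompt "upper garment" --steps 40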