# DCI-VTON / preprocess_onepair.py
# (Hugging Face upload metadata: user "venbab", commit message
#  "Update preprocess_onepair.py", revision 40bf3f2, verified.)
# preprocess_onepair.py
from __future__ import annotations
import io, os, shutil
from pathlib import Path
from typing import Tuple
import numpy as np
from PIL import Image
import cv2
from rembg import remove
import mediapipe as mp
_mp_seg = mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=1)
def _to_pil(img) -> Image.Image:
    """Coerce *img* into a PIL Image.

    Accepts:
      * an existing ``PIL.Image.Image`` (returned unchanged),
      * a filesystem path (``str`` or ``os.PathLike``), opened and
        converted to RGB,
      * raw encoded bytes (``bytes`` or ``bytearray``), decoded via an
        in-memory buffer and converted to RGB,
      * a ``numpy.ndarray`` (e.g. HxWx3 uint8), wrapped and converted
        to RGB.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, (str, os.PathLike)):
        return Image.open(img).convert("RGB")
    # bytearray added alongside bytes: BytesIO accepts both.
    if isinstance(img, (bytes, bytearray)):
        return Image.open(io.BytesIO(img)).convert("RGB")
    # ndarray support generalizes the helper to in-memory decoded images.
    if isinstance(img, np.ndarray):
        return Image.fromarray(img).convert("RGB")
    raise TypeError("Unsupported image type")
def _resize_pad(im: Image.Image, size: Tuple[int, int] = (512, 512)) -> Image.Image:
    """Fit *im* inside *size* (width, height), preserving aspect ratio.

    The image is scaled with bicubic resampling to the largest size that
    fits, then centered on a white RGB canvas of exactly *size*.
    """
    target_w, target_h = size
    src_w, src_h = im.size
    ratio = min(target_w / src_w, target_h / src_h)
    new_w = int(src_w * ratio)
    new_h = int(src_h * ratio)
    scaled = im.resize((new_w, new_h), Image.BICUBIC)
    canvas = Image.new("RGB", (target_w, target_h), (255, 255, 255))
    offset = ((target_w - new_w) // 2, (target_h - new_h) // 2)
    canvas.paste(scaled, offset)
    return canvas
def _cloth_edge(garment_rgb: Image.Image) -> Image.Image:
    """Binary garment silhouette via rembg background removal.

    Returns an 8-bit mask image where foreground (alpha > 10) pixels are
    255 and everything else is 0.
    """
    arr = np.array(garment_rgb)
    # np.asarray is defensive: rembg's return type can vary with input type.
    cut = np.asarray(remove(arr))
    # Original code indexed cut.shape[2] unconditionally, which raises
    # IndexError when the result has no channel axis (2-D output). Guard on
    # ndim and fall back to "everything is foreground", matching the
    # original fallback for non-RGBA results.
    if cut.ndim == 3 and cut.shape[2] == 4:
        alpha = cut[:, :, 3]
    else:
        alpha = np.full(arr.shape[:2], 255, dtype=np.uint8)
    edge = np.where(alpha > 10, 255, 0).astype(np.uint8)
    return Image.fromarray(edge)
def _human_mask(human_rgb: Image.Image) -> Image.Image:
    """Foreground-person mask from MediaPipe selfie segmentation.

    Pixels whose soft segmentation score exceeds 0.5 become 255; all
    others 0. Returned as an 8-bit PIL image.
    """
    # MediaPipe's process() is fed the BGR layout produced by cv2.
    bgr = cv2.cvtColor(np.array(human_rgb), cv2.COLOR_RGB2BGR)
    result = _mp_seg.process(bgr)
    binary = (result.segmentation_mask > 0.5).astype(np.uint8) * 255
    return Image.fromarray(binary)
def build_temp_dataset(person_img, garment_img, root: Path | str) -> str:
    """Materialize a single-pair test dataset under *root*/test.

    Produces the directory layout this script's consumer presumably reads
    (verify against the DCI-VTON test loader):

        test/image/user_0001.jpg                 resized+padded person photo
        test/cloth/cloth_0001.jpg                resized+padded garment photo
        test/edge/cloth_0001.png                 garment silhouette mask
        test/image-parse/user_0001.png           person segmentation mask
        test/pose/user_0001_keypoints.png        all-zero placeholder
        test/warp_feat/user_0001_cloth_0001.png  all-zero placeholder
        test/pairs.txt                           one "person cloth" line

    Returns ``str(root)``.
    """
    root = Path(root)
    test_root = root / "test"
    # Wipe any previous run so stale files never leak into this pair.
    if test_root.exists():
        shutil.rmtree(test_root)
    for folder in ("image", "cloth", "edge", "image-parse", "pose", "warp_feat"):
        (test_root / folder).mkdir(parents=True, exist_ok=True)

    person_name = "user_0001.jpg"
    cloth_name = "cloth_0001.jpg"

    person_pil = _resize_pad(_to_pil(person_img))
    garment_pil = _resize_pad(_to_pil(garment_img))
    person_pil.save(test_root / "image" / person_name, quality=95)
    garment_pil.save(test_root / "cloth" / cloth_name, quality=95)

    edge_pil = _cloth_edge(garment_pil).convert("L").resize((512, 512), Image.NEAREST)
    edge_pil.save(test_root / "edge" / cloth_name.replace(".jpg", ".png"))

    parse_pil = _human_mask(person_pil).convert("L")
    parse_pil.save(test_root / "image-parse" / person_name.replace(".jpg", ".png"))

    # Zero-filled stand-ins so the expected files exist on disk.
    pose_stub = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
    pose_stub.save(test_root / "pose" / person_name.replace(".jpg", "_keypoints.png"))
    warp_stub = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
    warp_stub.save(test_root / "warp_feat" / f"{person_name[:-4]}_{cloth_name[:-4]}.png")

    (test_root / "pairs.txt").write_text(f"{person_name} {cloth_name}\n")
    return str(root)