# DCI-VTON / app.py
# venbab's picture
# Update app.py
# f128013 verified
import os
from pathlib import Path
import gradio as gr
import numpy as np
from rembg import remove
from PIL import Image, ImageFilter
import shutil
# Report at import time whether the ENABLE_DCI environment switch is set to "1".
print("[DCI] ENABLE_DCI=" + str(os.getenv("ENABLE_DCI", "0") == "1"))

# Root directory of the temporary one-pair dataset.
ROOT = "/tmp/dci_onepair"

# Flat-layout paths derived from ROOT. Note that prepare_onepair_dataset()
# writes images under ROOT/test/... rather than into IMG_DIR / CLOTH_DIR;
# only the pairs file location matches PAIRS.
IMG_DIR = ROOT + "/test_img"
CLOTH_DIR = ROOT + "/test_color"
PAIRS = ROOT + "/test_pairs.txt"
def prepare_onepair_dataset(user_img_path: str, cloth_img_path: str, root=None):
    """Build a fresh single-pair dataset directory and return its root.

    The dataset root is deleted (if present) and recreated on every call,
    then populated as follows::

        <root>/test_pairs.txt               one line: "cloth.jpg person.jpg"
        <root>/test/image/person.jpg        copy of ``user_img_path``
        <root>/test/image/cloth.jpg         copy of ``user_img_path`` (shim)
        <root>/test/cloth/cloth.jpg         copy of ``cloth_img_path``
        <root>/test/cloth-mask/cloth.jpg    copy of ``cloth_img_path``

    Args:
        user_img_path: Path of the person image file to copy in.
        cloth_img_path: Path of the garment image file to copy in.
        root: Directory to (re)create the dataset in. Defaults to the
            module-level ``ROOT`` when omitted or ``None``.

    Returns:
        The dataset root directory that was populated.
    """
    dataset_root = ROOT if root is None else root

    # Reset temp dataset so nothing from an earlier call is left behind.
    if os.path.exists(dataset_root):
        shutil.rmtree(dataset_root)

    # Expected DCI-VTON structure
    img_dir = os.path.join(dataset_root, "test", "image")
    cloth_dir = os.path.join(dataset_root, "test", "cloth")
    cm_dir = os.path.join(dataset_root, "test", "cloth-mask")
    for directory in (img_dir, cloth_dir, cm_dir):
        os.makedirs(directory, exist_ok=True)

    user_name = "person.jpg"  # person image filename
    cloth_name = "cloth.jpg"  # garment image filename

    # Write pairs: CLOTH FIRST, then PERSON (the order the original author
    # notes the dataset loader expects).
    pairs_path = os.path.join(dataset_root, "test_pairs.txt")
    with open(pairs_path, "w") as f:
        f.write(f"{cloth_name} {user_name}\n")

    # Save images in expected places
    shutil.copy(user_img_path, os.path.join(img_dir, user_name))
    shutil.copy(cloth_img_path, os.path.join(cloth_dir, cloth_name))

    # --- Compatibility shim ---
    # Some codepaths also try to read image/cloth.jpg; make a duplicate.
    shutil.copy(user_img_path, os.path.join(img_dir, cloth_name))

    # Reuse the garment as a coarse cloth-mask (can be replaced with a real
    # mask later).
    shutil.copy(cloth_img_path, os.path.join(cm_dir, cloth_name))

    # Debug print: read the pairs file back inside a context manager so the
    # file handle is closed deterministically.
    with open(pairs_path) as f:
        print("PAIR LINE:", f.read().strip())
    for p in ["test/image", "test/cloth", "test/cloth-mask"]:
        full = os.path.join(dataset_root, p)
        print(p, "=>", os.listdir(full) if os.path.exists(full) else "MISSING")

    return dataset_root
def _lazy_imports():
    """Import the project modules on demand instead of at module import time.

    Returns:
        A ``(DciVtonPredictor, build_temp_dataset)`` tuple, in that order.
    """
    from dci_vton_infer import DciVtonPredictor as predictor_cls
    from preprocess_onepair import build_temp_dataset as dataset_builder

    return (predictor_cls, dataset_builder)
try:
    from spaces import GPU
except Exception:
    def GPU(*args, **kwargs):
        """No-op stand-in for ``spaces.GPU`` when that package cannot be imported.

        Supports both decorator forms:

        * bare ``@GPU`` -- called with the function itself, which is
          returned unchanged;
        * parameterised ``@GPU(...)`` -- called with options, returns a
          decorator that hands the function back unchanged.
        """
        # Bare usage: the only positional argument is the decorated function.
        # Without this branch the function would be replaced by `deco`.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def deco(fn):
            return fn

        return deco
# Page title, also rendered as the top-level heading of the UI.
TITLE = "💙 venbab DCI-VTON Try-On (Preview)"

# One-paragraph description shown under the heading.
DESC = " ".join(
    (
        "Upload a person image and a garment image.",
        "• CPU = preview blend",
        "• GPU = build one-pair dataset & run real DCI-VTON via test.py.",
    )
)
def _pil_to_rgba(im): return im if im.mode=="RGBA" else im.convert("RGBA")
def auto_mask_torso(human, top, bot, feather):
    """Create an "L" mask the size of ``human`` with a white horizontal band.

    Args:
        human: Image whose ``size`` determines the mask dimensions.
        top: Band start as a fraction of the image height.
        bot: Band end as a fraction of the image height.
        feather: Gaussian blur radius applied to the mask when greater
            than zero.

    Returns:
        The mask image: 0 outside the band, 255 inside it (before blurring).
    """
    width, height = human.size
    band_start = int(height * top)
    band_end = int(height * bot)
    # The band is always at least one pixel tall.
    band_height = max(1, band_end - band_start)

    mask = Image.new("L", (width, height), 0)
    mask.paste(Image.new("L", (width, band_height), 255), (0, band_start))

    return mask.filter(ImageFilter.GaussianBlur(radius=feather)) if feather > 0 else mask
def blend_preview(human, garment, use_mask, feather, top, bot, fit, blend):
    """Produce a simple preview by overlaying the garment on the person image.

    The garment is scaled to a fraction of the person's width (chosen by
    ``fit``), centred horizontally and vertically on the torso band
    ``[top, bot]`` (fractions of the image height), layered over the person
    using the garment's own alpha, optionally restricted to the torso mask,
    and finally mixed with the untouched person image.

    Args:
        human: Person image, or ``None``.
        garment: Garment image, or ``None``.
        use_mask: When true, the overlay is limited to the mask from
            ``auto_mask_torso``; otherwise it applies to the whole image.
        feather: Blur radius passed to ``auto_mask_torso``.
        top: Torso band start as a fraction of the image height.
        bot: Torso band end as a fraction of the image height.
        fit: One of "Slim (75%)", "Relaxed (85%)", "Wide (95%)"; any other
            value falls back to 85%.
        blend: Mix strength; clamped to the range [0, 1].

    Returns:
        An RGB image, or ``None`` if either input image is ``None``.
    """
    if human is None or garment is None:
        return None

    person = _pil_to_rgba(human)
    cloth = _pil_to_rgba(garment)
    person_w, person_h = person.size
    cloth_w, cloth_h = cloth.size

    width_ratio = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}.get(fit, 0.85)
    # Clamp both target dimensions to >= 1: for tiny person images or
    # extreme garment aspect ratios the truncated size can reach 0, which
    # resize() rejects.
    target_w = max(1, int(person_w * width_ratio))
    scale = target_w / max(1, cloth_w)
    target_h = max(1, int(cloth_h * scale))
    cloth = cloth.resize((target_w, target_h), Image.BICUBIC)

    # Centre the garment horizontally and on the middle of the torso band.
    band_top = int(person_h * top)
    band_bottom = int(person_h * bot)
    band_h = max(1, band_bottom - band_top)
    left = (person_w - target_w) // 2
    upper = band_top + (band_h - target_h) // 2

    # Full-size transparent layer holding the garment pixels (RGBA copied
    # as-is, so its alpha is applied exactly once by alpha_composite below).
    overlay = Image.new("RGBA", (person_w, person_h), (0, 0, 0, 0))
    overlay.paste(cloth, (left, upper))

    # Layer the garment over the person by alpha. Feeding the overlay
    # straight into Image.composite would copy its transparent-black
    # background over the person wherever the mask is set, darkening
    # everything in the band that the garment does not cover.
    dressed = Image.alpha_composite(person, overlay)

    if use_mask:
        torso_mask = auto_mask_torso(person, top, bot, feather)
        dressed = Image.composite(dressed, person, torso_mask)

    strength = max(0, min(1, blend))
    return Image.blend(person, dressed, strength).convert("RGB")
# Predictor instance created on the first tryon_gpu() call and reused afterwards.
_PREDICTOR = None


@GPU
def tryon_gpu(person_img, garment_img, use_mask, feather, top, bot, fit, blend):
    """Run the DCI-VTON predictor on one person/garment pair.

    Steps: create the predictor on first use, write both images to fixed
    JPEG paths under /tmp, build the one-pair dataset from those files, and
    call the predictor with the RGB images, an optional torso mask and a
    config dict.

    Args:
        person_img: Person image (must support ``convert`` and ``size``).
        garment_img: Garment image (must support ``convert``).
        use_mask: When true, a torso mask from ``auto_mask_torso`` is passed
            as ``mask_img``; otherwise ``mask_img`` is ``None``.
        feather: Blur radius for the torso mask.
        top: Torso band start as a fraction of the image height.
        bot: Torso band end as a fraction of the image height.
        fit: Passed through to the predictor in ``cfg["fit"]``.
        blend: Passed through to the predictor in ``cfg["blend"]``.

    Returns:
        Whatever ``DciVtonPredictor.predict`` returns.
    """
    global _PREDICTOR
    # Only the predictor class is needed; the second item is unused here.
    DciVtonPredictor, _ = _lazy_imports()
    if _PREDICTOR is None:
        _PREDICTOR = DciVtonPredictor(device="cuda")

    # Convert once and reuse the result for both the saved files and predict().
    person_rgb = person_img.convert("RGB")
    garment_rgb = garment_img.convert("RGB")

    # 1) Save the incoming images to temp files.
    tmp_user = "/tmp/dci_user.jpg"
    tmp_cloth = "/tmp/dci_cloth.jpg"
    person_rgb.save(tmp_user, quality=95)
    garment_rgb.save(tmp_cloth, quality=95)

    # 2) Build the one-pair dataset (test/image, test/cloth, test/cloth-mask
    #    and test_pairs.txt under the returned root).
    dataroot = prepare_onepair_dataset(tmp_user, tmp_cloth)

    mask = auto_mask_torso(person_img, top, bot, feather) if use_mask else None
    return _PREDICTOR.predict(
        person_rgb,
        garment_rgb,
        mask_img=mask,
        cfg=dict(dataroot=dataroot, fit=fit, blend=blend, torso=(top, bot)),
    )
def tryon_dispatch(person, garment, use_mask, feather, top, bot, fit, blend):
    """Route a try-on request to the GPU pipeline or the preview blend.

    Returns ``None`` when either image is missing. Otherwise calls
    ``tryon_gpu`` if the ENABLE_DCI environment variable equals "1" at call
    time, and ``blend_preview`` if it does not, forwarding all arguments
    unchanged.
    """
    inputs_missing = person is None or garment is None
    if inputs_missing:
        return None

    call_args = (person, garment, use_mask, feather, top, bot, fit, blend)
    dci_enabled = os.getenv("ENABLE_DCI", "0") == "1"
    handler = tryon_gpu if dci_enabled else blend_preview
    return handler(*call_args)
# Custom stylesheet passed to gr.Blocks: caps the container width and
# gives the primary button a gradient background.
CSS = "\n".join(
    (
        "",
        ".gradio-container{max-width:1200px!important;margin:auto;}",
        ".gr-button.primary{background:linear-gradient(90deg,#6366f1,#8b5cf6);}",
        "",
    )
)
# Gradio UI: person inputs and mask controls, garment inputs and blend
# controls, and the result image with the Try-On button.
# NOTE(review): the original indentation was lost; the nesting below (three
# columns in one row, button inside the result column, click handler wired
# at Blocks level) is a reconstruction -- confirm against the deployed layout.
with gr.Blocks(title=TITLE,theme=gr.themes.Soft(),css=CSS) as demo:
    gr.Markdown(f"# {TITLE}"); gr.Markdown(DESC)
    with gr.Row():
        # Person image plus torso-mask controls.
        with gr.Column(scale=1):
            human=gr.Image(label="Human",type="pil",height=480)
            use=gr.Checkbox(value=True,label="Auto-mask torso")
            feather=gr.Slider(0,30,value=10,step=1,label="Feather (px)")
            ttop=gr.Slider(0.05,0.45,value=0.30,step=0.01,label="Torso top")
            tbot=gr.Slider(0.50,0.90,value=0.68,step=0.01,label="Torso bottom")
        # Garment image plus fit/blend controls.
        with gr.Column(scale=1):
            garment=gr.Image(label="Garment",type="pil",height=480)
            fit=gr.Radio(["Slim (75%)","Relaxed (85%)","Wide (95%)"],value="Relaxed (85%)",label="Fit width")
            # The slider allows values up to 1.5, but blend_preview clamps
            # the strength to at most 1.
            blend=gr.Slider(0.2,1.5,value=0.9,step=0.05,label="Blend strength")
        # Result image and trigger button.
        with gr.Column(scale=1):
            output=gr.Image(label="Result",type="pil",height=480)
            btn=gr.Button("Try-On",variant="primary")
    # Input order must match tryon_dispatch's positional parameters.
    btn.click(fn=tryon_dispatch,
    inputs=[human,garment,use,feather,ttop,tbot,fit,blend],
    outputs=output)
# Launch the app only when this file is run as a script.
if __name__=="__main__": demo.launch()