Update dci_vton_infer.py
Browse files- dci_vton_infer.py +90 -56
dci_vton_infer.py
CHANGED
|
@@ -1,99 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import os, glob, subprocess, tempfile
|
| 3 |
from pathlib import Path
|
| 4 |
-
from typing import Optional, Dict
|
|
|
|
| 5 |
from PIL import Image, ImageFilter
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
-
import numpy as np
|
| 8 |
|
| 9 |
-
def _pil_to_rgba(im: Image.Image): return im if im.mode=="RGBA" else im.convert("RGBA")
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
return mask
|
| 16 |
|
| 17 |
-
def _quick_blend(
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
class DciVtonPredictor:
|
| 31 |
-
def __init__(self, device="cuda"):
|
| 32 |
-
self.device
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
print("[DCI] downloading viton512_v2.ckpt …")
|
| 35 |
-
self.viton_ckpt=hf_hub_download(
|
| 36 |
print("[DCI] downloading warp_viton.pth …")
|
| 37 |
-
self.warp_pth=hf_hub_download(
|
| 38 |
-
self.ready=True; print(f"[DCI] backend ready (device={device})")
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
self,
|
| 42 |
person_img: Image.Image,
|
| 43 |
garment_img: Image.Image,
|
| 44 |
mask_img: Optional[Image.Image] = None,
|
| 45 |
cfg: Optional[Dict] = None
|
| 46 |
) -> Image.Image:
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
cfg = cfg or {}
|
| 49 |
-
fit
|
| 50 |
blend = float(cfg.get("blend", 0.9))
|
| 51 |
torso = tuple(cfg.get("torso", (0.30, 0.68)))
|
| 52 |
dataroot = cfg.get("dataroot")
|
| 53 |
|
| 54 |
-
# If no dataset
|
| 55 |
if not dataroot:
|
| 56 |
-
print("[DCI] PREVIEW: no dataroot provided.")
|
| 57 |
return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
|
| 58 |
|
| 59 |
-
# Try
|
| 60 |
-
print(f"[DCI] REAL DCI: dataroot={dataroot}")
|
| 61 |
try:
|
| 62 |
outdir = Path(tempfile.mkdtemp(prefix="dci_out_"))
|
| 63 |
cmd = [
|
| 64 |
-
"python",
|
| 65 |
-
"test.py",
|
| 66 |
"--config", "dci_vton/configs/viton512_v2.yaml",
|
| 67 |
"--ckpt", self.viton_ckpt,
|
| 68 |
"--dataroot", str(dataroot),
|
| 69 |
-
"--H", "512",
|
| 70 |
-
"--W", "512",
|
| 71 |
"--n_samples", "1",
|
| 72 |
"--ddim_steps", "30",
|
| 73 |
"--outdir", str(outdir),
|
| 74 |
]
|
| 75 |
print("[DCI] running:", " ".join(cmd))
|
| 76 |
-
|
| 77 |
-
print("[DCI] test.py exit code:", p.returncode)
|
| 78 |
-
if p.stdout:
|
| 79 |
-
print("[DCI][stdout]\n", p.stdout)
|
| 80 |
-
if p.stderr:
|
| 81 |
-
print("[DCI][stderr]\n", p.stderr)
|
| 82 |
-
|
| 83 |
-
if p.returncode != 0:
|
| 84 |
-
raise RuntimeError(f"test.py failed with code {p.returncode}")
|
| 85 |
-
|
| 86 |
-
res = sorted(glob.glob(str(outdir / "result" / "*.png")))
|
| 87 |
-
if not res:
|
| 88 |
-
raise RuntimeError("No result image produced by test.py.")
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
except Exception as e:
|
| 94 |
-
|
| 95 |
-
print("[DCI] ERROR running test.py:", repr(e))
|
| 96 |
-
if strict:
|
| 97 |
-
raise
|
| 98 |
-
print("[DCI] FALLBACK to preview blend.")
|
| 99 |
return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
|
|
|
|
| 1 |
+
# dci_vton_infer.py
|
| 2 |
+
# Try real DCI-VTON via test.py; otherwise fall back to a neat preview overlay.
|
| 3 |
+
|
| 4 |
from __future__ import annotations
|
| 5 |
import os, glob, subprocess, tempfile
|
| 6 |
from pathlib import Path
|
| 7 |
+
from typing import Optional, Dict, Tuple
|
| 8 |
+
|
| 9 |
from PIL import Image, ImageFilter
|
| 10 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 11 |
|
|
|
|
| 12 |
|
| 13 |
+
# ---------- tiny preview helpers ----------
|
| 14 |
+
def _to_rgba(im: Image.Image) -> Image.Image:
|
| 15 |
+
return im if im.mode == "RGBA" else im.convert("RGBA")
|
| 16 |
+
|
| 17 |
+
def _auto_mask_torso(human: Image.Image, top_rel: float, bot_rel: float, feather: int) -> Image.Image:
    """Return an "L"-mode mask: white over the torso band, black elsewhere.

    The band spans rows ``[height*top_rel, height*bot_rel)``; both fractions
    are clamped to [0, 1]. A positive ``feather`` softens the band edges with
    a Gaussian blur of that radius.
    """
    width, height = human.size

    def _row(rel: float) -> int:
        # Clamp the relative position before converting to a pixel row.
        return int(height * max(0.0, min(1.0, float(rel))))

    band_top = _row(top_rel)
    band_bottom = _row(bot_rel)

    mask = Image.new("L", (width, height), 0)
    # Band is at least 1 px tall so the paste never degenerates.
    band = Image.new("L", (width, max(1, band_bottom - band_top)), 255)
    mask.paste(band, (0, band_top))

    if feather > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(radius=feather))
    return mask
|
| 27 |
|
| 28 |
+
def _quick_blend(
    person: Image.Image,
    garment: Image.Image,
    mask_img: Optional[Image.Image],
    fit_width: str,
    blend_strength: float,
    torso: Tuple[float, float]
) -> Image.Image:
    """Cheap preview: scale the garment, paste it over the torso band, blend.

    Args:
        person: photo of the person (any PIL mode).
        garment: garment image; its own alpha channel is used as the paste
            mask, so a transparent garment background never hides the person.
        mask_img: optional region mask; the garment is kept only where the
            mask is bright. ``None`` → garment shows wherever it was pasted.
        fit_width: "Slim (75%)" | "Relaxed (85%)" | "Wide (95%)"; any other
            value falls back to the 0.85 ratio.
        blend_strength: opacity of the dressed preview, clamped to [0, 1].
        torso: (top_rel, bot_rel) relative vertical extent of the torso band.

    Returns:
        RGB preview image, same size as *person*.
    """
    human = _to_rgba(person)
    cloth = _to_rgba(garment)

    hw, hh = human.size
    gw, gh = cloth.size

    # Scale the garment to a fraction of the person's width, keeping aspect.
    fit_ratio = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}.get(fit_width, 0.85)
    target_w = max(1, int(hw * fit_ratio))
    scale = target_w / max(1, gw)
    target_h = max(1, int(gh * scale))
    cloth = cloth.resize((target_w, target_h), Image.BICUBIC)

    # Centre the garment horizontally, and vertically within the torso band.
    top_rel, bot_rel = torso
    y_top_full = int(hh * top_rel)
    y_bot_full = int(hh * bot_rel)
    torso_h = max(1, y_bot_full - y_top_full)
    x_left = (hw - target_w) // 2
    y_top = y_top_full + (torso_h - target_h) // 2

    # BUG FIX: the previous version composited a mostly-transparent RGBA
    # overlay over the person with an all-white default mask, which replaced
    # the entire photo with transparent pixels (black after .convert("RGB")).
    # Paste the garment directly onto a copy of the person, using the
    # garment's own alpha, so the background always survives.
    mixed = human.copy()
    mixed.paste(cloth, (x_left, y_top), cloth)

    # Optional region mask: keep the garment only where the mask is bright.
    if mask_img is not None:
        region = mask_img.convert("L")
        if region.size != (hw, hh):
            # Image.composite requires matching sizes; resize defensively.
            region = region.resize((hw, hh))
        mixed = Image.composite(mixed, human, region)

    # Final opacity blend between the untouched person and the dressed preview.
    alpha = max(0.0, min(1.0, float(blend_strength)))
    out = Image.blend(human, mixed, alpha)
    return out.convert("RGB")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ---------- main predictor ----------
|
| 71 |
class DciVtonPredictor:
    """DCI-VTON virtual try-on wrapper.

    Construction eagerly downloads the model checkpoints from the Hugging
    Face Hub. ``predict`` attempts the real pipeline by shelling out to the
    upstream ``test.py`` when a prepared ``dataroot`` is supplied; on any
    failure (or with no dataroot) it degrades to the `_quick_blend` preview
    overlay so the caller always gets an image back.
    """

    # Single source of truth for the weights repo (was duplicated inline).
    _WEIGHTS_REPO = "venbab/dci-vton-weights"

    def __init__(self, device: str = "cuda"):
        self.device = device
        self.ready = False

        # Fetch checkpoints up front so the first predict() does not stall.
        print("[DCI] downloading viton512_v2.ckpt …")
        self.viton_ckpt = hf_hub_download(repo_id=self._WEIGHTS_REPO, filename="viton512_v2.ckpt")
        print("[DCI] downloading warp_viton.pth …")
        self.warp_pth = hf_hub_download(repo_id=self._WEIGHTS_REPO, filename="warp_viton.pth")

        self.ready = True
        print(f"[DCI] backend ready (device={self.device})")

    def predict(
        self,
        person_img: Image.Image,
        garment_img: Image.Image,
        mask_img: Optional[Image.Image] = None,
        cfg: Optional[Dict] = None
    ) -> Image.Image:
        """Run try-on and return an RGB result image.

        cfg keys supported:
          - dataroot: str | None → if set, we try test.py; else preview blend
          - fit: "Slim (75%)" | "Relaxed (85%)" | "Wide (95%)"
          - blend: float (0..1+)
          - torso: (top_rel, bot_rel)
        """
        import sys  # local: only needed to locate the current interpreter

        cfg = cfg or {}
        fit = cfg.get("fit", "Relaxed (85%)")
        blend = float(cfg.get("blend", 0.9))
        torso = tuple(cfg.get("torso", (0.30, 0.68)))
        dataroot = cfg.get("dataroot")

        # If no dataset was prepared, just show the preview overlay.
        if not dataroot:
            print("[DCI] PREVIEW: no dataroot provided.")
            return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)

        # Try the repo's test.py; any failure degrades to the preview blend.
        try:
            outdir = Path(tempfile.mkdtemp(prefix="dci_out_"))
            cmd = [
                # BUG FIX: use the running interpreter rather than whatever
                # "python" happens to be on PATH (may be absent or a
                # different environment without the DCI dependencies).
                sys.executable, "test.py",
                "--config", "dci_vton/configs/viton512_v2.yaml",
                "--ckpt", self.viton_ckpt,
                "--dataroot", str(dataroot),
                "--H", "512", "--W", "512",
                "--n_samples", "1",
                "--ddim_steps", "30",
                "--outdir", str(outdir),
            ]
            print("[DCI] running:", " ".join(cmd))
            # Capture and echo output so failures are diagnosable from logs
            # (check=True alone would discard the subprocess's stderr).
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.stdout:
                print("[DCI][stdout]\n", proc.stdout)
            if proc.stderr:
                print("[DCI][stderr]\n", proc.stderr)
            if proc.returncode != 0:
                raise RuntimeError(f"test.py failed with code {proc.returncode}")

            res_dir = outdir / "result"
            pngs = sorted(glob.glob(str(res_dir / "*.png")))
            if not pngs:
                raise RuntimeError("No result produced by test.py")
            return Image.open(pngs[0]).convert("RGB")

        except Exception as e:
            print("[DCI] test.py failed → preview fallback. Reason:", repr(e))
            return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
|