Update dci_vton_infer.py
Browse files- dci_vton_infer.py +60 -46
dci_vton_infer.py
CHANGED
|
@@ -1,17 +1,12 @@
|
|
| 1 |
# dci_vton_infer.py
|
| 2 |
-
# Try real DCI-VTON via test.py; otherwise fall back to a neat preview overlay.
|
| 3 |
-
|
| 4 |
from __future__ import annotations
|
| 5 |
-
import os, glob, subprocess, tempfile
|
| 6 |
from pathlib import Path
|
| 7 |
-
from typing import Optional, Dict
|
| 8 |
-
|
| 9 |
from PIL import Image, ImageFilter
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
|
| 12 |
-
|
| 13 |
-
# ---------- tiny preview helpers ----------
|
| 14 |
-
def _to_rgba(im: Image.Image) -> Image.Image:
|
| 15 |
return im if im.mode == "RGBA" else im.convert("RGBA")
|
| 16 |
|
| 17 |
def _auto_mask_torso(human: Image.Image, top_rel: float, bot_rel: float, feather: int) -> Image.Image:
|
|
@@ -31,19 +26,18 @@ def _quick_blend(
|
|
| 31 |
mask_img: Optional[Image.Image],
|
| 32 |
fit_width: str,
|
| 33 |
blend_strength: float,
|
| 34 |
-
torso:
|
| 35 |
) -> Image.Image:
|
| 36 |
-
human =
|
| 37 |
-
|
| 38 |
|
| 39 |
hw, hh = human.size
|
| 40 |
-
gw, gh =
|
| 41 |
-
|
| 42 |
fit_ratio = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}.get(fit_width, 0.85)
|
| 43 |
target_w = int(hw * fit_ratio)
|
| 44 |
scale = target_w / max(1, gw)
|
| 45 |
target_h = int(gh * scale)
|
| 46 |
-
|
| 47 |
|
| 48 |
top_rel, bot_rel = torso
|
| 49 |
y_top_full = int(hh * top_rel)
|
|
@@ -54,9 +48,9 @@ def _quick_blend(
|
|
| 54 |
y_top = y_top_full + (torso_h - target_h) // 2
|
| 55 |
|
| 56 |
overlay = Image.new("RGBA", (hw, hh), (0, 0, 0, 0))
|
| 57 |
-
if
|
| 58 |
-
|
| 59 |
-
overlay.paste(
|
| 60 |
|
| 61 |
if mask_img is None:
|
| 62 |
mask_img = Image.new("L", (hw, hh), 255)
|
|
@@ -66,21 +60,28 @@ def _quick_blend(
|
|
| 66 |
out = Image.blend(human, mixed, alpha)
|
| 67 |
return out.convert("RGB")
|
| 68 |
|
| 69 |
-
|
| 70 |
-
# ---------- main predictor ----------
|
| 71 |
class DciVtonPredictor:
|
| 72 |
def __init__(self, device: str = "cuda"):
|
| 73 |
self.device = device
|
| 74 |
self.ready = False
|
| 75 |
|
| 76 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
print("[DCI] downloading viton512_v2.ckpt …")
|
| 78 |
-
self.viton_ckpt = hf_hub_download(repo_id=
|
| 79 |
print("[DCI] downloading warp_viton.pth …")
|
| 80 |
-
self.warp_pth = hf_hub_download(repo_id=
|
| 81 |
|
| 82 |
-
self.ready = True
|
| 83 |
print(f"[DCI] backend ready (device={self.device})")
|
|
|
|
| 84 |
|
| 85 |
def predict(
|
| 86 |
self,
|
|
@@ -89,43 +90,56 @@ class DciVtonPredictor:
|
|
| 89 |
mask_img: Optional[Image.Image] = None,
|
| 90 |
cfg: Optional[Dict] = None
|
| 91 |
) -> Image.Image:
|
| 92 |
-
|
| 93 |
-
cfg keys supported:
|
| 94 |
-
- dataroot: str | None → if set, we try test.py; else do preview blend
|
| 95 |
-
- fit: str ("Slim (75%)" | "Relaxed (85%)" | "Wide (95%)")
|
| 96 |
-
- blend: float (0..1+)
|
| 97 |
-
- torso: (top_rel, bot_rel)
|
| 98 |
-
"""
|
| 99 |
cfg = cfg or {}
|
| 100 |
-
fit
|
| 101 |
blend = float(cfg.get("blend", 0.9))
|
| 102 |
torso = tuple(cfg.get("torso", (0.30, 0.68)))
|
| 103 |
dataroot = cfg.get("dataroot")
|
| 104 |
|
| 105 |
-
# If
|
| 106 |
if not dataroot:
|
| 107 |
return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
|
| 108 |
|
| 109 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
try:
|
| 111 |
-
|
| 112 |
-
cmd = [
|
| 113 |
-
"python", "test.py",
|
| 114 |
-
"--config", "dci_vton/configs/viton512_v2.yaml",
|
| 115 |
-
"--ckpt", self.viton_ckpt,
|
| 116 |
-
"--dataroot", str(dataroot),
|
| 117 |
-
"--H", "512", "--W", "512",
|
| 118 |
-
"--n_samples", "1",
|
| 119 |
-
"--ddim_steps", "30",
|
| 120 |
-
"--outdir", str(outdir),
|
| 121 |
-
]
|
| 122 |
print("[DCI] running:", " ".join(cmd))
|
| 123 |
-
subprocess.run(cmd, check=True)
|
| 124 |
|
| 125 |
res_dir = outdir / "result"
|
| 126 |
pngs = sorted(glob.glob(str(res_dir / "*.png")))
|
| 127 |
if not pngs:
|
| 128 |
-
raise RuntimeError("No result produced by test.py")
|
| 129 |
return Image.open(pngs[0]).convert("RGB")
|
| 130 |
|
| 131 |
except Exception as e:
|
|
|
|
| 1 |
# dci_vton_infer.py
|
|
|
|
|
|
|
| 2 |
from __future__ import annotations
|
| 3 |
+
import os, glob, subprocess, tempfile, sys
|
| 4 |
from pathlib import Path
|
| 5 |
+
from typing import Optional, Dict
|
|
|
|
| 6 |
from PIL import Image, ImageFilter
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
|
| 9 |
+
def _pil_to_rgba(im: Image.Image) -> Image.Image:
|
|
|
|
|
|
|
| 10 |
return im if im.mode == "RGBA" else im.convert("RGBA")
|
| 11 |
|
| 12 |
def _auto_mask_torso(human: Image.Image, top_rel: float, bot_rel: float, feather: int) -> Image.Image:
|
|
|
|
| 26 |
mask_img: Optional[Image.Image],
|
| 27 |
fit_width: str,
|
| 28 |
blend_strength: float,
|
| 29 |
+
torso: tuple[float, float]
|
| 30 |
) -> Image.Image:
|
| 31 |
+
human = _pil_to_rgba(person)
|
| 32 |
+
garment = _pil_to_rgba(garment)
|
| 33 |
|
| 34 |
hw, hh = human.size
|
| 35 |
+
gw, gh = garment.size
|
|
|
|
| 36 |
fit_ratio = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}.get(fit_width, 0.85)
|
| 37 |
target_w = int(hw * fit_ratio)
|
| 38 |
scale = target_w / max(1, gw)
|
| 39 |
target_h = int(gh * scale)
|
| 40 |
+
garment_resized = garment.resize((target_w, target_h), Image.BICUBIC)
|
| 41 |
|
| 42 |
top_rel, bot_rel = torso
|
| 43 |
y_top_full = int(hh * top_rel)
|
|
|
|
| 48 |
y_top = y_top_full + (torso_h - target_h) // 2
|
| 49 |
|
| 50 |
overlay = Image.new("RGBA", (hw, hh), (0, 0, 0, 0))
|
| 51 |
+
if garment_resized.mode != "RGBA":
|
| 52 |
+
garment_resized = garment_resized.convert("RGBA")
|
| 53 |
+
overlay.paste(garment_resized, (x_left, y_top), garment_resized)
|
| 54 |
|
| 55 |
if mask_img is None:
|
| 56 |
mask_img = Image.new("L", (hw, hh), 255)
|
|
|
|
| 60 |
out = Image.blend(human, mixed, alpha)
|
| 61 |
return out.convert("RGB")
|
| 62 |
|
|
|
|
|
|
|
| 63 |
class DciVtonPredictor:
|
| 64 |
    def __init__(self, device: str = "cuda"):
        """Resolve the DCI-VTON script/config paths and download weights.

        Args:
            device: Device string to run inference on (e.g. "cuda");
                stored as-is and only echoed in the ready message here.
        """
        self.device = device
        # Flipped to True only after both weight files have been fetched.
        self.ready = False

        # Resolve repo paths
        self.repo_root = Path(__file__).parent.resolve()
        # Prefer dci_vton/test.py; fallback to root if needed
        self.test_py = (self.repo_root / "dci_vton" / "test.py")
        if not self.test_py.exists():
            self.test_py = self.repo_root / "test.py"
        self.config_yaml = (self.repo_root / "dci_vton" / "configs" / "viton512_v2.yaml")

        # Download weights
        # NOTE: hf_hub_download caches locally, so repeated construction is
        # cheap after the first run; first use requires network access and
        # raises if the Hub is unreachable.
        repo_id = "venbab/dci-vton-weights"
        print("[DCI] downloading viton512_v2.ckpt …")
        self.viton_ckpt = hf_hub_download(repo_id=repo_id, filename="viton512_v2.ckpt")
        print("[DCI] downloading warp_viton.pth …")
        self.warp_pth = hf_hub_download(repo_id=repo_id, filename="warp_viton.pth")

        print(f"[DCI] backend ready (device={self.device})")
        self.ready = True
|
| 85 |
|
| 86 |
def predict(
|
| 87 |
self,
|
|
|
|
| 90 |
mask_img: Optional[Image.Image] = None,
|
| 91 |
cfg: Optional[Dict] = None
|
| 92 |
) -> Image.Image:
|
| 93 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
cfg = cfg or {}
|
| 95 |
+
fit = cfg.get("fit", "Relaxed (85%)")
|
| 96 |
blend = float(cfg.get("blend", 0.9))
|
| 97 |
torso = tuple(cfg.get("torso", (0.30, 0.68)))
|
| 98 |
dataroot = cfg.get("dataroot")
|
| 99 |
|
| 100 |
+
# If we don't have a dataset root, return preview
|
| 101 |
if not dataroot:
|
| 102 |
return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
|
| 103 |
|
| 104 |
+
# Ensure paths exist
|
| 105 |
+
if not self.test_py.exists():
|
| 106 |
+
print(f"[DCI] test.py not found at: {self.test_py}")
|
| 107 |
+
return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
|
| 108 |
+
if not self.config_yaml.exists():
|
| 109 |
+
print(f"[DCI] config yaml not found at: {self.config_yaml}")
|
| 110 |
+
return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
|
| 111 |
+
|
| 112 |
+
# Build env with proper PYTHONPATH so `ldm/...` imports work
|
| 113 |
+
py_path = os.pathsep.join({
|
| 114 |
+
str(self.repo_root),
|
| 115 |
+
str(self.repo_root / "dci_vton"),
|
| 116 |
+
})
|
| 117 |
+
env = dict(os.environ)
|
| 118 |
+
env["PYTHONPATH"] = py_path + (os.pathsep + env["PYTHONPATH"] if "PYTHONPATH" in env else "")
|
| 119 |
+
|
| 120 |
+
outdir = Path(tempfile.mkdtemp(prefix="dci_out_"))
|
| 121 |
+
cmd = [
|
| 122 |
+
sys.executable,
|
| 123 |
+
str(self.test_py),
|
| 124 |
+
"--config", str(self.config_yaml),
|
| 125 |
+
"--ckpt", str(self.viton_ckpt),
|
| 126 |
+
"--dataroot", str(dataroot),
|
| 127 |
+
"--H", "512",
|
| 128 |
+
"--W", "512",
|
| 129 |
+
"--n_samples", "1",
|
| 130 |
+
"--ddim_steps", "30",
|
| 131 |
+
"--outdir", str(outdir),
|
| 132 |
+
]
|
| 133 |
+
|
| 134 |
try:
|
| 135 |
+
print("[DCI] REAL DCI: dataroot=", dataroot)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
print("[DCI] running:", " ".join(cmd))
|
| 137 |
+
subprocess.run(cmd, check=True, env=env, cwd=str(self.repo_root))
|
| 138 |
|
| 139 |
res_dir = outdir / "result"
|
| 140 |
pngs = sorted(glob.glob(str(res_dir / "*.png")))
|
| 141 |
if not pngs:
|
| 142 |
+
raise RuntimeError("No result image produced by test.py")
|
| 143 |
return Image.open(pngs[0]).convert("RGB")
|
| 144 |
|
| 145 |
except Exception as e:
|