venbab committed on
Commit
09d2508
·
verified ·
1 Parent(s): 16d8232

Update dci_vton_infer.py

Browse files
Files changed (1) hide show
  1. dci_vton_infer.py +60 -46
dci_vton_infer.py CHANGED
@@ -1,17 +1,12 @@
1
  # dci_vton_infer.py
2
- # Try real DCI-VTON via test.py; otherwise fall back to a neat preview overlay.
3
-
4
  from __future__ import annotations
5
- import os, glob, subprocess, tempfile
6
  from pathlib import Path
7
- from typing import Optional, Dict, Tuple
8
-
9
  from PIL import Image, ImageFilter
10
  from huggingface_hub import hf_hub_download
11
 
12
-
13
- # ---------- tiny preview helpers ----------
14
- def _to_rgba(im: Image.Image) -> Image.Image:
15
  return im if im.mode == "RGBA" else im.convert("RGBA")
16
 
17
  def _auto_mask_torso(human: Image.Image, top_rel: float, bot_rel: float, feather: int) -> Image.Image:
@@ -31,19 +26,18 @@ def _quick_blend(
31
  mask_img: Optional[Image.Image],
32
  fit_width: str,
33
  blend_strength: float,
34
- torso: Tuple[float, float]
35
  ) -> Image.Image:
36
- human = _to_rgba(person)
37
- cloth = _to_rgba(garment)
38
 
39
  hw, hh = human.size
40
- gw, gh = cloth.size
41
-
42
  fit_ratio = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}.get(fit_width, 0.85)
43
  target_w = int(hw * fit_ratio)
44
  scale = target_w / max(1, gw)
45
  target_h = int(gh * scale)
46
- cloth = cloth.resize((target_w, target_h), Image.BICUBIC)
47
 
48
  top_rel, bot_rel = torso
49
  y_top_full = int(hh * top_rel)
@@ -54,9 +48,9 @@ def _quick_blend(
54
  y_top = y_top_full + (torso_h - target_h) // 2
55
 
56
  overlay = Image.new("RGBA", (hw, hh), (0, 0, 0, 0))
57
- if cloth.mode != "RGBA":
58
- cloth = cloth.convert("RGBA")
59
- overlay.paste(cloth, (x_left, y_top), cloth)
60
 
61
  if mask_img is None:
62
  mask_img = Image.new("L", (hw, hh), 255)
@@ -66,21 +60,28 @@ def _quick_blend(
66
  out = Image.blend(human, mixed, alpha)
67
  return out.convert("RGB")
68
 
69
-
70
- # ---------- main predictor ----------
71
  class DciVtonPredictor:
72
  def __init__(self, device: str = "cuda"):
73
  self.device = device
74
  self.ready = False
75
 
76
- # download weights (for when you call real test.py)
 
 
 
 
 
 
 
 
 
77
  print("[DCI] downloading viton512_v2.ckpt …")
78
- self.viton_ckpt = hf_hub_download(repo_id="venbab/dci-vton-weights", filename="viton512_v2.ckpt")
79
  print("[DCI] downloading warp_viton.pth …")
80
- self.warp_pth = hf_hub_download(repo_id="venbab/dci-vton-weights", filename="warp_viton.pth")
81
 
82
- self.ready = True
83
  print(f"[DCI] backend ready (device={self.device})")
 
84
 
85
  def predict(
86
  self,
@@ -89,43 +90,56 @@ class DciVtonPredictor:
89
  mask_img: Optional[Image.Image] = None,
90
  cfg: Optional[Dict] = None
91
  ) -> Image.Image:
92
- """
93
- cfg keys supported:
94
- - dataroot: str | None → if set, we try test.py; else do preview blend
95
- - fit: str ("Slim (75%)" | "Relaxed (85%)" | "Wide (95%)")
96
- - blend: float (0..1+)
97
- - torso: (top_rel, bot_rel)
98
- """
99
  cfg = cfg or {}
100
- fit = cfg.get("fit", "Relaxed (85%)")
101
  blend = float(cfg.get("blend", 0.9))
102
  torso = tuple(cfg.get("torso", (0.30, 0.68)))
103
  dataroot = cfg.get("dataroot")
104
 
105
- # If no dataset was prepared, just show preview overlay.
106
  if not dataroot:
107
  return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
108
 
109
- # Try to run the repo's test.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  try:
111
- outdir = Path(tempfile.mkdtemp(prefix="dci_out_"))
112
- cmd = [
113
- "python", "test.py",
114
- "--config", "dci_vton/configs/viton512_v2.yaml",
115
- "--ckpt", self.viton_ckpt,
116
- "--dataroot", str(dataroot),
117
- "--H", "512", "--W", "512",
118
- "--n_samples", "1",
119
- "--ddim_steps", "30",
120
- "--outdir", str(outdir),
121
- ]
122
  print("[DCI] running:", " ".join(cmd))
123
- subprocess.run(cmd, check=True)
124
 
125
  res_dir = outdir / "result"
126
  pngs = sorted(glob.glob(str(res_dir / "*.png")))
127
  if not pngs:
128
- raise RuntimeError("No result produced by test.py")
129
  return Image.open(pngs[0]).convert("RGB")
130
 
131
  except Exception as e:
 
1
  # dci_vton_infer.py
 
 
2
  from __future__ import annotations
3
+ import os, glob, subprocess, tempfile, sys
4
  from pathlib import Path
5
+ from typing import Optional, Dict
 
6
  from PIL import Image, ImageFilter
7
  from huggingface_hub import hf_hub_download
8
 
9
+ def _pil_to_rgba(im: Image.Image) -> Image.Image:
 
 
10
  return im if im.mode == "RGBA" else im.convert("RGBA")
11
 
12
  def _auto_mask_torso(human: Image.Image, top_rel: float, bot_rel: float, feather: int) -> Image.Image:
 
26
  mask_img: Optional[Image.Image],
27
  fit_width: str,
28
  blend_strength: float,
29
+ torso: tuple[float, float]
30
  ) -> Image.Image:
31
+ human = _pil_to_rgba(person)
32
+ garment = _pil_to_rgba(garment)
33
 
34
  hw, hh = human.size
35
+ gw, gh = garment.size
 
36
  fit_ratio = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}.get(fit_width, 0.85)
37
  target_w = int(hw * fit_ratio)
38
  scale = target_w / max(1, gw)
39
  target_h = int(gh * scale)
40
+ garment_resized = garment.resize((target_w, target_h), Image.BICUBIC)
41
 
42
  top_rel, bot_rel = torso
43
  y_top_full = int(hh * top_rel)
 
48
  y_top = y_top_full + (torso_h - target_h) // 2
49
 
50
  overlay = Image.new("RGBA", (hw, hh), (0, 0, 0, 0))
51
+ if garment_resized.mode != "RGBA":
52
+ garment_resized = garment_resized.convert("RGBA")
53
+ overlay.paste(garment_resized, (x_left, y_top), garment_resized)
54
 
55
  if mask_img is None:
56
  mask_img = Image.new("L", (hw, hh), 255)
 
60
  out = Image.blend(human, mixed, alpha)
61
  return out.convert("RGB")
62
 
 
 
63
  class DciVtonPredictor:
64
  def __init__(self, device: str = "cuda"):
65
  self.device = device
66
  self.ready = False
67
 
68
+ # Resolve repo paths
69
+ self.repo_root = Path(__file__).parent.resolve()
70
+ # Prefer dci_vton/test.py; fallback to root if needed
71
+ self.test_py = (self.repo_root / "dci_vton" / "test.py")
72
+ if not self.test_py.exists():
73
+ self.test_py = self.repo_root / "test.py"
74
+ self.config_yaml = (self.repo_root / "dci_vton" / "configs" / "viton512_v2.yaml")
75
+
76
+ # Download weights
77
+ repo_id = "venbab/dci-vton-weights"
78
  print("[DCI] downloading viton512_v2.ckpt …")
79
+ self.viton_ckpt = hf_hub_download(repo_id=repo_id, filename="viton512_v2.ckpt")
80
  print("[DCI] downloading warp_viton.pth …")
81
+ self.warp_pth = hf_hub_download(repo_id=repo_id, filename="warp_viton.pth")
82
 
 
83
  print(f"[DCI] backend ready (device={self.device})")
84
+ self.ready = True
85
 
86
  def predict(
87
  self,
 
90
  mask_img: Optional[Image.Image] = None,
91
  cfg: Optional[Dict] = None
92
  ) -> Image.Image:
93
+
 
 
 
 
 
 
94
  cfg = cfg or {}
95
+ fit = cfg.get("fit", "Relaxed (85%)")
96
  blend = float(cfg.get("blend", 0.9))
97
  torso = tuple(cfg.get("torso", (0.30, 0.68)))
98
  dataroot = cfg.get("dataroot")
99
 
100
+ # If we don't have a dataset root, return preview
101
  if not dataroot:
102
  return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
103
 
104
+ # Ensure paths exist
105
+ if not self.test_py.exists():
106
+ print(f"[DCI] test.py not found at: {self.test_py}")
107
+ return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
108
+ if not self.config_yaml.exists():
109
+ print(f"[DCI] config yaml not found at: {self.config_yaml}")
110
+ return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
111
+
112
+ # Build env with proper PYTHONPATH so `ldm/...` imports work
113
+ py_path = os.pathsep.join({
114
+ str(self.repo_root),
115
+ str(self.repo_root / "dci_vton"),
116
+ })
117
+ env = dict(os.environ)
118
+ env["PYTHONPATH"] = py_path + (os.pathsep + env["PYTHONPATH"] if "PYTHONPATH" in env else "")
119
+
120
+ outdir = Path(tempfile.mkdtemp(prefix="dci_out_"))
121
+ cmd = [
122
+ sys.executable,
123
+ str(self.test_py),
124
+ "--config", str(self.config_yaml),
125
+ "--ckpt", str(self.viton_ckpt),
126
+ "--dataroot", str(dataroot),
127
+ "--H", "512",
128
+ "--W", "512",
129
+ "--n_samples", "1",
130
+ "--ddim_steps", "30",
131
+ "--outdir", str(outdir),
132
+ ]
133
+
134
  try:
135
+ print("[DCI] REAL DCI: dataroot=", dataroot)
 
 
 
 
 
 
 
 
 
 
136
  print("[DCI] running:", " ".join(cmd))
137
+ subprocess.run(cmd, check=True, env=env, cwd=str(self.repo_root))
138
 
139
  res_dir = outdir / "result"
140
  pngs = sorted(glob.glob(str(res_dir / "*.png")))
141
  if not pngs:
142
+ raise RuntimeError("No result image produced by test.py")
143
  return Image.open(pngs[0]).convert("RGB")
144
 
145
  except Exception as e: