venbab committed on
Commit
16d8232
·
verified ·
1 Parent(s): 26fbdc0

Update dci_vton_infer.py

Browse files
Files changed (1) hide show
  1. dci_vton_infer.py +90 -56
dci_vton_infer.py CHANGED
@@ -1,99 +1,133 @@
 
 
 
1
  from __future__ import annotations
2
  import os, glob, subprocess, tempfile
3
  from pathlib import Path
4
- from typing import Optional, Dict
 
5
  from PIL import Image, ImageFilter
6
  from huggingface_hub import hf_hub_download
7
- import numpy as np
8
 
9
def _pil_to_rgba(im: Image.Image):
    """Return *im* untouched when it is already RGBA, else an RGBA copy."""
    if im.mode == "RGBA":
        return im
    return im.convert("RGBA")
10
 
11
def _auto_mask_torso(human: Image.Image, top: float, bot: float, feather: int):
    """Build a grayscale torso mask: a white horizontal band spanning the
    [top, bot] relative-height range of *human*, optionally feathered."""
    width, height = human.size
    band_top = int(height * top)
    band_bottom = int(height * bot)
    # Start from an all-black canvas and paste a full-width white band
    # over the torso region (band height is forced to at least 1 px).
    mask = Image.new("L", (width, height), 0)
    band = Image.new("L", (width, max(1, band_bottom - band_top)), 255)
    mask.paste(band, (0, band_top))
    # A positive feather radius softens the hard band edges.
    if feather > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(radius=feather))
    return mask
16
 
17
def _quick_blend(person, garment, mask, fit, blend, torso):
    """Paste a rescaled garment over the person's torso band and blend.

    Pure geometric preview: the garment is width-fitted per the *fit*
    label, vertically centered in the torso band given by *torso*
    (top_rel, bot_rel), composited through *mask* (full-white when None),
    and mixed back with the original at strength *blend* (clamped 0..1).
    Returns an RGB image the size of *person*.
    """
    base = _pil_to_rgba(person)
    cloth = _pil_to_rgba(garment)
    base_w, base_h = base.size
    cloth_w, cloth_h = cloth.size

    ratios = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}
    ratio = ratios.get(fit, 0.85)

    new_w = int(base_w * ratio)
    new_h = int(cloth_h * (new_w / max(1, cloth_w)))
    cloth = cloth.resize((new_w, new_h), Image.BICUBIC)

    band_top = int(base_h * torso[0])
    band_bottom = int(base_h * torso[1])
    band_h = max(1, band_bottom - band_top)

    # Center horizontally on the person, vertically inside the torso band.
    paste_x = (base_w - new_w) // 2
    paste_y = band_top + (band_h - new_h) // 2

    layer = Image.new("RGBA", (base_w, base_h), (0, 0, 0, 0))
    layer.paste(cloth, (paste_x, paste_y), cloth)

    if mask is None:
        mask = Image.new("L", (base_w, base_h), 255)

    strength = max(0, min(1, float(blend)))
    composed = Image.composite(layer, base, mask)
    return Image.blend(base, composed, strength).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
 
 
 
 
 
 
 
 
 
 
30
class DciVtonPredictor:
    """DCI-VTON wrapper: runs the repo's ``test.py`` when a dataroot is
    supplied, otherwise (or on failure, unless DCI_STRICT=1) falls back
    to the quick preview blend."""

    def __init__(self, device="cuda"):
        self.device = device
        self.ready = False
        repo = "venbab/dci-vton-weights"
        # Weights are fetched eagerly so predict() never blocks on a download.
        print("[DCI] downloading viton512_v2.ckpt …")
        self.viton_ckpt = hf_hub_download(repo, filename="viton512_v2.ckpt")
        print("[DCI] downloading warp_viton.pth …")
        self.warp_pth = hf_hub_download(repo, filename="warp_viton.pth")
        self.ready = True
        print(f"[DCI] backend ready (device={device})")

    def predict(
        self,
        person_img: Image.Image,
        garment_img: Image.Image,
        mask_img: Optional[Image.Image] = None,
        cfg: Optional[Dict] = None
    ) -> Image.Image:
        """Return a try-on result image (RGB).

        cfg keys: dataroot (triggers real test.py when set), fit, blend,
        torso. Without a dataroot — or when test.py fails and DCI_STRICT
        is not "1" — a preview blend is returned instead.
        """
        cfg = cfg or {}
        fit = cfg.get("fit", "Relaxed (85%)")
        blend = float(cfg.get("blend", 0.9))
        torso = tuple(cfg.get("torso", (0.30, 0.68)))
        dataroot = cfg.get("dataroot")

        # If no dataset path, we must preview.
        if not dataroot:
            print("[DCI] PREVIEW: no dataroot provided.")
            return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)

        # Try REAL DCI (test.py)
        print(f"[DCI] REAL DCI: dataroot={dataroot}")
        try:
            outdir = Path(tempfile.mkdtemp(prefix="dci_out_"))
            cmd = [
                "python", "test.py",
                "--config", "dci_vton/configs/viton512_v2.yaml",
                "--ckpt", self.viton_ckpt,
                "--dataroot", str(dataroot),
                "--H", "512", "--W", "512",
                "--n_samples", "1",
                "--ddim_steps", "30",
                "--outdir", str(outdir),
            ]
            print("[DCI] running:", " ".join(cmd))
            completed = subprocess.run(cmd, capture_output=True, text=True)
            print("[DCI] test.py exit code:", completed.returncode)
            if completed.stdout:
                print("[DCI][stdout]\n", completed.stdout)
            if completed.stderr:
                print("[DCI][stderr]\n", completed.stderr)

            if completed.returncode != 0:
                raise RuntimeError(f"test.py failed with code {completed.returncode}")

            result_paths = sorted(glob.glob(str(outdir / "result" / "*.png")))
            if not result_paths:
                raise RuntimeError("No result image produced by test.py.")

            print("[DCI] SUCCESS: returning test.py result", result_paths[0])
            return Image.open(result_paths[0]).convert("RGB")

        except Exception as e:
            # DCI_STRICT=1 re-raises instead of silently degrading to preview.
            strict = os.getenv("DCI_STRICT", "0") == "1"
            print("[DCI] ERROR running test.py:", repr(e))
            if strict:
                raise
            print("[DCI] FALLBACK to preview blend.")
            return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)
 
1
+ # dci_vton_infer.py
2
+ # Try real DCI-VTON via test.py; otherwise fall back to a neat preview overlay.
3
+
4
  from __future__ import annotations
5
  import os, glob, subprocess, tempfile
6
  from pathlib import Path
7
+ from typing import Optional, Dict, Tuple
8
+
9
  from PIL import Image, ImageFilter
10
  from huggingface_hub import hf_hub_download
 
11
 
 
12
 
13
+ # ---------- tiny preview helpers ----------
14
def _to_rgba(im: Image.Image) -> Image.Image:
    """Ensure an RGBA image: convert when needed, pass through otherwise."""
    if im.mode != "RGBA":
        return im.convert("RGBA")
    return im
16
+
17
def _auto_mask_torso(human: Image.Image, top_rel: float, bot_rel: float, feather: int) -> Image.Image:
    """Build a torso-band mask for *human*.

    The band spans [top_rel, bot_rel] of the image height (both clamped
    to [0, 1]); it is white on a black background and optionally
    Gaussian-feathered by *feather* pixels.
    """
    width, height = human.size
    top_px = int(height * min(1.0, max(0.0, float(top_rel))))
    bot_px = int(height * min(1.0, max(0.0, float(bot_rel))))

    # Black canvas + white band over the torso region (>= 1 px tall).
    mask = Image.new("L", (width, height), 0)
    band = Image.new("L", (width, max(1, bot_px - top_px)), 255)
    mask.paste(band, (0, top_px))

    if feather > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(radius=feather))
    return mask
27
 
28
def _quick_blend(
    person: Image.Image,
    garment: Image.Image,
    mask_img: Optional[Image.Image],
    fit_width: str,
    blend_strength: float,
    torso: Tuple[float, float]
) -> Image.Image:
    """Paste a rescaled garment over the person's torso band and blend.

    Cheap geometric preview used when the real DCI-VTON pipeline is not
    available — no pose estimation or warping is performed.

    Args:
        person: Person photo (converted to RGBA internally).
        garment: Garment image (converted to RGBA internally).
        mask_img: Optional compositing mask; ``None`` means everywhere
            (full-white). Masks with the wrong mode or size are
            normalized rather than letting ``Image.composite`` raise.
        fit_width: "Slim (75%)", "Relaxed (85%)" or "Wide (95%)";
            unknown labels fall back to 0.85.
        blend_strength: Blend factor, clamped to [0, 1].
        torso: (top_rel, bot_rel) relative vertical extent of the torso.

    Returns:
        RGB preview image the same size as *person*.
    """
    human = _to_rgba(person)
    cloth = _to_rgba(garment)

    hw, hh = human.size
    gw, gh = cloth.size

    fit_ratio = {"Slim (75%)": 0.75, "Relaxed (85%)": 0.85, "Wide (95%)": 0.95}.get(fit_width, 0.85)
    # Clamp both dimensions to >= 1 px so Image.resize never receives a
    # zero dimension (possible for tiny inputs or extreme aspect ratios).
    target_w = max(1, int(hw * fit_ratio))
    scale = target_w / max(1, gw)
    target_h = max(1, int(gh * scale))
    cloth = cloth.resize((target_w, target_h), Image.BICUBIC)

    top_rel, bot_rel = torso
    y_top_full = int(hh * top_rel)
    y_bot_full = int(hh * bot_rel)
    torso_h = max(1, y_bot_full - y_top_full)

    # Center horizontally; center vertically inside the torso band.
    x_left = (hw - target_w) // 2
    y_top = y_top_full + (torso_h - target_h) // 2

    overlay = Image.new("RGBA", (hw, hh), (0, 0, 0, 0))
    # cloth is already RGBA via _to_rgba, so it can serve as its own mask.
    overlay.paste(cloth, (x_left, y_top), cloth)

    if mask_img is None:
        mask_img = Image.new("L", (hw, hh), 255)
    else:
        # Normalize caller-supplied masks: composite() requires an "L"/"1"
        # mask whose size matches the base image.
        if mask_img.mode not in ("L", "1"):
            mask_img = mask_img.convert("L")
        if mask_img.size != (hw, hh):
            mask_img = mask_img.resize((hw, hh))

    alpha = max(0.0, min(1.0, float(blend_strength)))
    mixed = Image.composite(overlay, human, mask_img)
    out = Image.blend(human, mixed, alpha)
    return out.convert("RGB")
68
+
69
+
70
+ # ---------- main predictor ----------
71
class DciVtonPredictor:
    """Runs real DCI-VTON via the repo's ``test.py`` when a prepared
    dataset is available, falling back to a quick preview overlay
    otherwise or on any failure."""

    # Hub repository hosting the pretrained checkpoints.
    WEIGHTS_REPO = "venbab/dci-vton-weights"

    def __init__(self, device: str = "cuda"):
        self.device = device
        self.ready = False

        # Download weights eagerly (for when you call real test.py) so
        # the first predict() call is not blocked by a large download.
        print("[DCI] downloading viton512_v2.ckpt …")
        self.viton_ckpt = hf_hub_download(repo_id=self.WEIGHTS_REPO, filename="viton512_v2.ckpt")
        print("[DCI] downloading warp_viton.pth …")
        self.warp_pth = hf_hub_download(repo_id=self.WEIGHTS_REPO, filename="warp_viton.pth")

        self.ready = True
        print(f"[DCI] backend ready (device={self.device})")

    def predict(
        self,
        person_img: Image.Image,
        garment_img: Image.Image,
        mask_img: Optional[Image.Image] = None,
        cfg: Optional[Dict] = None
    ) -> Image.Image:
        """
        cfg keys supported:
          - dataroot: str | None → if set, we try test.py; else do preview blend
          - fit: str ("Slim (75%)" | "Relaxed (85%)" | "Wide (95%)")
          - blend: float (0..1+)
          - torso: (top_rel, bot_rel)
        """
        import sys  # local import: only needed to locate this interpreter

        cfg = cfg or {}
        fit = cfg.get("fit", "Relaxed (85%)")
        blend = float(cfg.get("blend", 0.9))
        torso = tuple(cfg.get("torso", (0.30, 0.68)))
        dataroot = cfg.get("dataroot")

        # If no dataset was prepared, just show preview overlay.
        if not dataroot:
            return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)

        # Try to run the repo's test.py
        try:
            outdir = Path(tempfile.mkdtemp(prefix="dci_out_"))
            cmd = [
                # Use the current interpreter, not whatever "python" is on
                # PATH, so test.py runs inside the same environment.
                sys.executable, "test.py",
                "--config", "dci_vton/configs/viton512_v2.yaml",
                "--ckpt", self.viton_ckpt,
                "--dataroot", str(dataroot),
                "--H", "512", "--W", "512",
                "--n_samples", "1",
                "--ddim_steps", "30",
                "--outdir", str(outdir),
            ]
            print("[DCI] running:", " ".join(cmd))
            # Capture output so a failure carries its stderr into the
            # fallback log instead of being silently discarded.
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.returncode != 0:
                stderr_tail = (proc.stderr or "")[-2000:]
                raise RuntimeError(
                    f"test.py exited with code {proc.returncode}: {stderr_tail}"
                )

            res_dir = outdir / "result"
            pngs = sorted(glob.glob(str(res_dir / "*.png")))
            if not pngs:
                raise RuntimeError("No result produced by test.py")
            return Image.open(pngs[0]).convert("RGB")

        except Exception as e:
            print("[DCI] test.py failed → preview fallback. Reason:", repr(e))
            return _quick_blend(person_img, garment_img, mask_img, fit, blend, torso)