update UI
Browse files
app.py
CHANGED
|
@@ -715,6 +715,7 @@
|
|
| 715 |
# if __name__ == "__main__":
|
| 716 |
# demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 717 |
|
|
|
|
| 718 |
import os
|
| 719 |
import sys
|
| 720 |
|
|
@@ -743,7 +744,6 @@ from transformers import pipeline
|
|
| 743 |
|
| 744 |
from huggingface_hub import hf_hub_download
|
| 745 |
|
| 746 |
-
# Show where diffusers3 is imported from (helps diagnose import collisions on Spaces)
|
| 747 |
import diffusers3
|
| 748 |
print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
|
| 749 |
|
|
@@ -823,10 +823,6 @@ def apply_parsing_white_mask_to_person_cv2(
|
|
| 823 |
person_pil: Image.Image,
|
| 824 |
parsing_img: Image.Image
|
| 825 |
) -> np.ndarray:
|
| 826 |
-
"""
|
| 827 |
-
Fits the parsing_img (L) mask to the size of person_pil (RGB).
|
| 828 |
-
Keeps the person pixels only inside the white (255) mask region and turns everything else into a white background.
|
| 829 |
-
"""
|
| 830 |
person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
|
| 831 |
|
| 832 |
mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
|
|
@@ -843,11 +839,6 @@ def apply_parsing_white_mask_to_person_cv2(
|
|
| 843 |
return result_bgr
|
| 844 |
|
| 845 |
|
| 846 |
-
from typing import Optional, Tuple
|
| 847 |
-
import numpy as np
|
| 848 |
-
from PIL import Image
|
| 849 |
-
|
| 850 |
-
|
| 851 |
def clean_and_smooth_parsing_mask(
|
| 852 |
parsing_img: Image.Image,
|
| 853 |
*,
|
|
@@ -858,10 +849,6 @@ def clean_and_smooth_parsing_mask(
|
|
| 858 |
morph_iters: int = 1,
|
| 859 |
blur_ksize: int = 0,
|
| 860 |
) -> Image.Image:
|
| 861 |
-
"""
|
| 862 |
-
Clean small white blobs and smooth boundaries on a grayscale (0/255) PIL mask.
|
| 863 |
-
White=foreground, Black=background.
|
| 864 |
-
"""
|
| 865 |
if not isinstance(parsing_img, Image.Image):
|
| 866 |
raise TypeError("parsing_img must be a PIL.Image.Image")
|
| 867 |
|
|
@@ -870,11 +857,6 @@ def clean_and_smooth_parsing_mask(
|
|
| 870 |
|
| 871 |
mask = np.where(arr >= white_threshold, 255, 0).astype(np.uint8)
|
| 872 |
|
| 873 |
-
try:
|
| 874 |
-
import cv2
|
| 875 |
-
except ImportError as e:
|
| 876 |
-
raise ImportError("This function requires opencv-python (cv2).") from e
|
| 877 |
-
|
| 878 |
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
|
| 879 |
|
| 880 |
keep = np.zeros_like(mask)
|
|
@@ -1031,10 +1013,10 @@ def make_depth_from_parsing_edges(parsing_img: Image.Image) -> Image.Image:
|
|
| 1031 |
|
| 1032 |
depth_img = _edges_from_parsing(parsing_img)
|
| 1033 |
|
| 1034 |
-
|
| 1035 |
-
contours, _ = cv2.findContours(
|
| 1036 |
|
| 1037 |
-
filled_depth =
|
| 1038 |
cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
|
| 1039 |
|
| 1040 |
filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
|
|
@@ -1138,10 +1120,17 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
|
|
| 1138 |
return pipe, device, dtype
|
| 1139 |
|
| 1140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1141 |
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
|
| 1142 |
"""
|
| 1143 |
-
|
| 1144 |
-
images(list[PIL]), mask_pil(PIL), depth_map(PIL), person_pil(PIL), garment_pil(PIL), garment_mask_pil(PIL)
|
| 1145 |
"""
|
| 1146 |
global H, W
|
| 1147 |
pipe, device, _dtype = get_pipe_and_device()
|
|
@@ -1149,9 +1138,11 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
|
|
| 1149 |
|
| 1150 |
H, W = compute_hw_from_person(paths.person_path)
|
| 1151 |
|
| 1152 |
-
# โ
UI
|
|
|
|
|
|
|
| 1153 |
res = run_simple_extractor(
|
| 1154 |
-
category=
|
| 1155 |
input_path=os.path.abspath(paths.person_path),
|
| 1156 |
model_restore=schp_ckpt,
|
| 1157 |
)
|
|
@@ -1246,7 +1237,8 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str
|
|
| 1246 |
"depth:", depth_map.size,
|
| 1247 |
"garment:", garment_pil.size,
|
| 1248 |
"gmask:", garment_mask_pil.size,
|
| 1249 |
-
"
|
|
|
|
| 1250 |
flush=True
|
| 1251 |
)
|
| 1252 |
|
|
@@ -1309,7 +1301,6 @@ def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
|
|
| 1309 |
if person_fp is None or style_fp is None:
|
| 1310 |
raise gr.Error("person / style ์ด๋ฏธ์ง๋ ํ์์
๋๋ค. (sketch๋ ์ ํ)")
|
| 1311 |
|
| 1312 |
-
# ✅ category comes from the UI radio; its default value is "Dress"
|
| 1313 |
if category not in ("Upper-body", "Lower-body", "Dress"):
|
| 1314 |
raise gr.Error(f"Invalid category: {category}")
|
| 1315 |
|
|
@@ -1336,23 +1327,22 @@ def infer_web(person_fp, sketch_fp, style_fp, prompt, steps, seed, category):
|
|
| 1336 |
with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
|
| 1337 |
gr.Markdown("## VISTA Demo\nperson / style ํ์, sketch(guide)๋ ์ ํ์
๋๋ค.")
|
| 1338 |
|
| 1339 |
-
# โ
|
| 1340 |
category_toggle = gr.Radio(
|
| 1341 |
-
choices=["Upper-body", "Lower-body"
|
| 1342 |
value="Dress",
|
| 1343 |
label="Category",
|
| 1344 |
interactive=True,
|
| 1345 |
)
|
| 1346 |
|
|
|
|
| 1347 |
with gr.Row():
|
| 1348 |
person_in = gr.Image(label="Person Image (required)", type="filepath")
|
| 1349 |
style_in = gr.Image(label="Style Image (required)", type="filepath")
|
|
|
|
| 1350 |
|
| 1351 |
with gr.Accordion("Sketch / Guide (optional)", open=False):
|
| 1352 |
-
sketch_in = gr.Image(
|
| 1353 |
-
label="Sketch / Guide",
|
| 1354 |
-
type="filepath"
|
| 1355 |
-
)
|
| 1356 |
|
| 1357 |
with gr.Row():
|
| 1358 |
prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2)
|
|
@@ -1361,7 +1351,7 @@ with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
|
|
| 1361 |
|
| 1362 |
run_btn = gr.Button("Run")
|
| 1363 |
|
| 1364 |
-
|
| 1365 |
out_file = gr.File(label="Download result.png")
|
| 1366 |
|
| 1367 |
gr.Markdown("### Debug Visualizations (mask/depth/etc)")
|
|
@@ -1384,3 +1374,5 @@ demo.queue()
|
|
| 1384 |
if __name__ == "__main__":
|
| 1385 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 1386 |
|
|
|
|
|
|
|
|
|
| 715 |
# if __name__ == "__main__":
|
| 716 |
# demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 717 |
|
| 718 |
+
|
| 719 |
import os
|
| 720 |
import sys
|
| 721 |
|
|
|
|
| 744 |
|
| 745 |
from huggingface_hub import hf_hub_download
|
| 746 |
|
|
|
|
| 747 |
import diffusers3
|
| 748 |
print("[BOOT] diffusers3 loaded from:", getattr(diffusers3, "__file__", "<?>"), flush=True)
|
| 749 |
|
|
|
|
| 823 |
person_pil: Image.Image,
|
| 824 |
parsing_img: Image.Image
|
| 825 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
person_rgb = np.array(person_pil.convert("RGB"), dtype=np.uint8)
|
| 827 |
|
| 828 |
mask = np.array(parsing_img.convert("L"), dtype=np.uint8)
|
|
|
|
| 839 |
return result_bgr
|
| 840 |
|
| 841 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
def clean_and_smooth_parsing_mask(
|
| 843 |
parsing_img: Image.Image,
|
| 844 |
*,
|
|
|
|
| 849 |
morph_iters: int = 1,
|
| 850 |
blur_ksize: int = 0,
|
| 851 |
) -> Image.Image:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 852 |
if not isinstance(parsing_img, Image.Image):
|
| 853 |
raise TypeError("parsing_img must be a PIL.Image.Image")
|
| 854 |
|
|
|
|
| 857 |
|
| 858 |
mask = np.where(arr >= white_threshold, 255, 0).astype(np.uint8)
|
| 859 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 860 |
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
|
| 861 |
|
| 862 |
keep = np.zeros_like(mask)
|
|
|
|
| 1013 |
|
| 1014 |
depth_img = _edges_from_parsing(parsing_img)
|
| 1015 |
|
| 1016 |
+
# inverted_depth = cv2.bitwise_not(depth_img)
|
| 1017 |
+
contours, _ = cv2.findContours(depth_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 1018 |
|
| 1019 |
+
filled_depth = depth_img.copy()
|
| 1020 |
cv2.drawContours(filled_depth, contours, -1, (255), thickness=cv2.FILLED)
|
| 1021 |
|
| 1022 |
filled_depth = cv2.resize(filled_depth, (W, H), interpolation=cv2.INTER_AREA)
|
|
|
|
| 1120 |
return pipe, device, dtype
|
| 1121 |
|
| 1122 |
|
| 1123 |
+
# ✅ Mapping from UI label → internal extractor category string
|
| 1124 |
+
_UI_TO_EXTRACTOR_CATEGORY = {
|
| 1125 |
+
"Upper-body": "Upper-cloth",
|
| 1126 |
+
"Lower-body": "Bottom",
|
| 1127 |
+
"Dress": "Dress",
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS, category: str = "Dress"):
|
| 1132 |
"""
|
| 1133 |
+
category: value passed in from the UI (Upper-body/Lower-body/Dress)
|
|
|
|
| 1134 |
"""
|
| 1135 |
global H, W
|
| 1136 |
pipe, device, _dtype = get_pipe_and_device()
|
|
|
|
| 1138 |
|
| 1139 |
H, W = compute_hw_from_person(paths.person_path)
|
| 1140 |
|
| 1141 |
+
# ✅ Convert the UI category into the string the extractor expects
|
| 1142 |
+
extractor_category = _UI_TO_EXTRACTOR_CATEGORY.get(category, "Dress")
|
| 1143 |
+
|
| 1144 |
res = run_simple_extractor(
|
| 1145 |
+
category=extractor_category,
|
| 1146 |
input_path=os.path.abspath(paths.person_path),
|
| 1147 |
model_restore=schp_ckpt,
|
| 1148 |
)
|
|
|
|
| 1237 |
"depth:", depth_map.size,
|
| 1238 |
"garment:", garment_pil.size,
|
| 1239 |
"gmask:", garment_mask_pil.size,
|
| 1240 |
+
"ui_category:", category,
|
| 1241 |
+
"extractor_category:", extractor_category,
|
| 1242 |
flush=True
|
| 1243 |
)
|
| 1244 |
|
|
|
|
| 1301 |
if person_fp is None or style_fp is None:
|
| 1302 |
raise gr.Error("person / style ์ด๋ฏธ์ง๋ ํ์์
๋๋ค. (sketch๋ ์ ํ)")
|
| 1303 |
|
|
|
|
| 1304 |
if category not in ("Upper-body", "Lower-body", "Dress"):
|
| 1305 |
raise gr.Error(f"Invalid category: {category}")
|
| 1306 |
|
|
|
|
| 1327 |
with gr.Blocks(title="VISTA Demo (HF Spaces)") as demo:
|
| 1328 |
gr.Markdown("## VISTA Demo\nperson / style ํ์, sketch(guide)๋ ์ ํ์
๋๋ค.")
|
| 1329 |
|
| 1330 |
+
# ✅ UI labels stay Upper-body/Lower-body/Dress (default: Dress)
|
| 1331 |
category_toggle = gr.Radio(
|
| 1332 |
+
choices=["Dress", "Upper-body", "Lower-body"],
|
| 1333 |
value="Dress",
|
| 1334 |
label="Category",
|
| 1335 |
interactive=True,
|
| 1336 |
)
|
| 1337 |
|
| 1338 |
+
# ✅ Place Person / Style / Output in a single row
|
| 1339 |
with gr.Row():
|
| 1340 |
person_in = gr.Image(label="Person Image (required)", type="filepath")
|
| 1341 |
style_in = gr.Image(label="Style Image (required)", type="filepath")
|
| 1342 |
+
out_img = gr.Image(label="Output", type="pil")
|
| 1343 |
|
| 1344 |
with gr.Accordion("Sketch / Guide (optional)", open=False):
|
| 1345 |
+
sketch_in = gr.Image(label="Sketch / Guide", type="filepath")
|
|
|
|
|
|
|
|
|
|
| 1346 |
|
| 1347 |
with gr.Row():
|
| 1348 |
prompt_in = gr.Textbox(label="Prompt", value="upper garment", lines=2)
|
|
|
|
| 1351 |
|
| 1352 |
run_btn = gr.Button("Run")
|
| 1353 |
|
| 1354 |
+
# The file download generally looks better placed below Output (on the next row)
|
| 1355 |
out_file = gr.File(label="Download result.png")
|
| 1356 |
|
| 1357 |
gr.Markdown("### Debug Visualizations (mask/depth/etc)")
|
|
|
|
| 1374 |
if __name__ == "__main__":
|
| 1375 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 1376 |
|
| 1377 |
+
|
| 1378 |
+
|