Spaces:
Running on Zero
Running on Zero
Fix status updates: use gr.Progress instead of yield generator
Browse files
app.py
CHANGED
|
@@ -19,7 +19,6 @@ DATA_DIR = "/data" if os.path.exists("/data") else "/tmp"
|
|
| 19 |
OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
|
| 20 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 21 |
|
| 22 |
-
# Point HF cache to persistent storage so model downloads survive restarts
|
| 23 |
os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
|
| 24 |
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
|
| 25 |
|
|
@@ -42,12 +41,12 @@ def _make_mask(size: int, cloth_type: str) -> Image.Image:
|
|
| 42 |
d.rectangle([int(size*.10), int(size*.18), int(size*.90), int(size*.65)], fill=255)
|
| 43 |
elif cloth_type == "lower":
|
| 44 |
d.rectangle([int(size*.05), int(size*.55), int(size*.95), int(size*1.0)], fill=255)
|
| 45 |
-
else:
|
| 46 |
d.rectangle([int(size*.05), int(size*.15), int(size*.95), int(size*1.0)], fill=255)
|
| 47 |
return mask
|
| 48 |
|
| 49 |
# ---------------------------------------------------------------------------
|
| 50 |
-
# GPU-decorated inference
|
| 51 |
# ---------------------------------------------------------------------------
|
| 52 |
_pipe = None
|
| 53 |
|
|
@@ -57,13 +56,13 @@ def _run_inference(person: Image.Image, garment: Image.Image, mask: Image.Image,
|
|
| 57 |
global _pipe
|
| 58 |
if _pipe is None:
|
| 59 |
from diffusers import PaintByExamplePipeline
|
| 60 |
-
print("Loading Paint-by-Example
|
| 61 |
_pipe = PaintByExamplePipeline.from_pretrained(
|
| 62 |
"Fantasy-Studio/Paint-by-Example",
|
| 63 |
torch_dtype=torch.float16,
|
| 64 |
).to("cuda")
|
| 65 |
_pipe.set_progress_bar_config(disable=True)
|
| 66 |
-
print("Pipeline ready
|
| 67 |
|
| 68 |
rng = torch.Generator(device="cuda")
|
| 69 |
rng.manual_seed(int(seed) if seed != -1 else torch.randint(0, 2**32, (1,)).item())
|
|
@@ -79,7 +78,7 @@ def _run_inference(person: Image.Image, garment: Image.Image, mask: Image.Image,
|
|
| 79 |
return result.images
|
| 80 |
|
| 81 |
# ---------------------------------------------------------------------------
|
| 82 |
-
#
|
| 83 |
# ---------------------------------------------------------------------------
|
| 84 |
def run_tryon(
|
| 85 |
person_image: Image.Image,
|
|
@@ -88,11 +87,12 @@ def run_tryon(
|
|
| 88 |
num_steps: int,
|
| 89 |
guidance_scale: float,
|
| 90 |
seed: int,
|
|
|
|
| 91 |
):
|
| 92 |
if person_image is None or garment_image is None:
|
| 93 |
raise gr.Error("Please upload both a person photo and a garment image.")
|
| 94 |
|
| 95 |
-
|
| 96 |
|
| 97 |
person = _fit_to_square(person_image)
|
| 98 |
garment = _fit_to_square(garment_image)
|
|
@@ -100,6 +100,8 @@ def run_tryon(
|
|
| 100 |
|
| 101 |
output_images = _run_inference(person, garment, mask, num_steps, guidance_scale, seed)
|
| 102 |
|
|
|
|
|
|
|
| 103 |
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 104 |
saved_paths = []
|
| 105 |
for i, img in enumerate(output_images):
|
|
@@ -107,7 +109,8 @@ def run_tryon(
|
|
| 107 |
img.save(path, format="PNG")
|
| 108 |
saved_paths.append(path)
|
| 109 |
|
| 110 |
-
|
|
|
|
| 111 |
|
| 112 |
# ---------------------------------------------------------------------------
|
| 113 |
# Gradio UI
|
|
@@ -116,9 +119,8 @@ with gr.Blocks(title="Virtual Try-On", theme=gr.themes.Soft()) as demo:
|
|
| 116 |
gr.Markdown(
|
| 117 |
"# 👗 Virtual Try-On\n"
|
| 118 |
"Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
|
| 119 |
-
"> Runs
|
| 120 |
-
">
|
| 121 |
-
"> **First run:** ~2-3 min (model download). **Subsequent runs:** ~15-30s."
|
| 122 |
)
|
| 123 |
|
| 124 |
with gr.Row():
|
|
|
|
| 19 |
OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
|
| 20 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 21 |
|
|
|
|
| 22 |
os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
|
| 23 |
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
|
| 24 |
|
|
|
|
| 41 |
d.rectangle([int(size*.10), int(size*.18), int(size*.90), int(size*.65)], fill=255)
|
| 42 |
elif cloth_type == "lower":
|
| 43 |
d.rectangle([int(size*.05), int(size*.55), int(size*.95), int(size*1.0)], fill=255)
|
| 44 |
+
else:
|
| 45 |
d.rectangle([int(size*.05), int(size*.15), int(size*.95), int(size*1.0)], fill=255)
|
| 46 |
return mask
|
| 47 |
|
| 48 |
# ---------------------------------------------------------------------------
|
| 49 |
+
# GPU-decorated inference
|
| 50 |
# ---------------------------------------------------------------------------
|
| 51 |
_pipe = None
|
| 52 |
|
|
|
|
| 56 |
global _pipe
|
| 57 |
if _pipe is None:
|
| 58 |
from diffusers import PaintByExamplePipeline
|
| 59 |
+
print("Loading Paint-by-Example (~5 GB, first run only)…")
|
| 60 |
_pipe = PaintByExamplePipeline.from_pretrained(
|
| 61 |
"Fantasy-Studio/Paint-by-Example",
|
| 62 |
torch_dtype=torch.float16,
|
| 63 |
).to("cuda")
|
| 64 |
_pipe.set_progress_bar_config(disable=True)
|
| 65 |
+
print("Pipeline ready.")
|
| 66 |
|
| 67 |
rng = torch.Generator(device="cuda")
|
| 68 |
rng.manual_seed(int(seed) if seed != -1 else torch.randint(0, 2**32, (1,)).item())
|
|
|
|
| 78 |
return result.images
|
| 79 |
|
| 80 |
# ---------------------------------------------------------------------------
|
| 81 |
+
# Gradio inference — uses gr.Progress for live status updates
|
| 82 |
# ---------------------------------------------------------------------------
|
| 83 |
def run_tryon(
|
| 84 |
person_image: Image.Image,
|
|
|
|
| 87 |
num_steps: int,
|
| 88 |
guidance_scale: float,
|
| 89 |
seed: int,
|
| 90 |
+
progress=gr.Progress(track_tqdm=True),
|
| 91 |
):
|
| 92 |
if person_image is None or garment_image is None:
|
| 93 |
raise gr.Error("Please upload both a person photo and a garment image.")
|
| 94 |
|
| 95 |
+
progress(0, desc="⏳ Requesting GPU + loading model (first run ~3 min)…")
|
| 96 |
|
| 97 |
person = _fit_to_square(person_image)
|
| 98 |
garment = _fit_to_square(garment_image)
|
|
|
|
| 100 |
|
| 101 |
output_images = _run_inference(person, garment, mask, num_steps, guidance_scale, seed)
|
| 102 |
|
| 103 |
+
progress(0.9, desc="💾 Saving result…")
|
| 104 |
+
|
| 105 |
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 106 |
saved_paths = []
|
| 107 |
for i, img in enumerate(output_images):
|
|
|
|
| 109 |
img.save(path, format="PNG")
|
| 110 |
saved_paths.append(path)
|
| 111 |
|
| 112 |
+
progress(1.0, desc="✅ Done!")
|
| 113 |
+
return output_images, saved_paths, "✅ Done! Download your result below."
|
| 114 |
|
| 115 |
# ---------------------------------------------------------------------------
|
| 116 |
# Gradio UI
|
|
|
|
| 119 |
gr.Markdown(
|
| 120 |
"# 👗 Virtual Try-On\n"
|
| 121 |
"Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
|
| 122 |
+
"> Runs on **Hugging Face ZeroGPU** (free A10G) — no local GPU needed. \n"
|
| 123 |
+
"> **First run:** ~2-3 min (model download ~5 GB). **Subsequent runs:** ~15-30s."
|
|
|
|
| 124 |
)
|
| 125 |
|
| 126 |
with gr.Row():
|