Nandha2017 commited on
Commit
1cfca86
·
verified ·
1 Parent(s): 0aa8990

Fix status updates: use gr.Progress instead of yield generator

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -19,7 +19,6 @@ DATA_DIR = "/data" if os.path.exists("/data") else "/tmp"
19
  OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
20
  os.makedirs(OUTPUT_DIR, exist_ok=True)
21
 
22
- # Point HF cache to persistent storage so model downloads survive restarts
23
  os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
24
  os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
25
 
@@ -42,12 +41,12 @@ def _make_mask(size: int, cloth_type: str) -> Image.Image:
42
  d.rectangle([int(size*.10), int(size*.18), int(size*.90), int(size*.65)], fill=255)
43
  elif cloth_type == "lower":
44
  d.rectangle([int(size*.05), int(size*.55), int(size*.95), int(size*1.0)], fill=255)
45
- else: # overall / dress
46
  d.rectangle([int(size*.05), int(size*.15), int(size*.95), int(size*1.0)], fill=255)
47
  return mask
48
 
49
  # ---------------------------------------------------------------------------
50
- # GPU-decorated inference — plain return, no yield (spaces.GPU doesn't stream)
51
  # ---------------------------------------------------------------------------
52
  _pipe = None
53
 
@@ -57,13 +56,13 @@ def _run_inference(person: Image.Image, garment: Image.Image, mask: Image.Image,
57
  global _pipe
58
  if _pipe is None:
59
  from diffusers import PaintByExamplePipeline
60
- print("Loading Paint-by-Example pipeline (~5 GB, first run only)…")
61
  _pipe = PaintByExamplePipeline.from_pretrained(
62
  "Fantasy-Studio/Paint-by-Example",
63
  torch_dtype=torch.float16,
64
  ).to("cuda")
65
  _pipe.set_progress_bar_config(disable=True)
66
- print("Pipeline ready on CUDA.")
67
 
68
  rng = torch.Generator(device="cuda")
69
  rng.manual_seed(int(seed) if seed != -1 else torch.randint(0, 2**32, (1,)).item())
@@ -79,7 +78,7 @@ def _run_inference(person: Image.Image, garment: Image.Image, mask: Image.Image,
79
  return result.images
80
 
81
  # ---------------------------------------------------------------------------
82
- # Outer generatorhandles status streaming to Gradio UI
83
  # ---------------------------------------------------------------------------
84
  def run_tryon(
85
  person_image: Image.Image,
@@ -88,11 +87,12 @@ def run_tryon(
88
  num_steps: int,
89
  guidance_scale: float,
90
  seed: int,
 
91
  ):
92
  if person_image is None or garment_image is None:
93
  raise gr.Error("Please upload both a person photo and a garment image.")
94
 
95
- yield [], [], "⏳ Requesting GPU + loading model (first run ~3 min, then ~30s)"
96
 
97
  person = _fit_to_square(person_image)
98
  garment = _fit_to_square(garment_image)
@@ -100,6 +100,8 @@ def run_tryon(
100
 
101
  output_images = _run_inference(person, garment, mask, num_steps, guidance_scale, seed)
102
 
 
 
103
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
104
  saved_paths = []
105
  for i, img in enumerate(output_images):
@@ -107,7 +109,8 @@ def run_tryon(
107
  img.save(path, format="PNG")
108
  saved_paths.append(path)
109
 
110
- yield output_images, saved_paths, "✅ Done! Download your result below."
 
111
 
112
  # ---------------------------------------------------------------------------
113
  # Gradio UI
@@ -116,9 +119,8 @@ with gr.Blocks(title="Virtual Try-On", theme=gr.themes.Soft()) as demo:
116
  gr.Markdown(
117
  "# 👗 Virtual Try-On\n"
118
  "Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
119
- "> Runs entirely on **Hugging Face ZeroGPU** (free A10G) — no local GPU needed. \n"
120
- "> Models download once to HF persistent storage. Images save to your device via the Download button.\n\n"
121
- "> **First run:** ~2-3 min (model download). **Subsequent runs:** ~15-30s."
122
  )
123
 
124
  with gr.Row():
 
19
  OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
20
  os.makedirs(OUTPUT_DIR, exist_ok=True)
21
 
 
22
  os.environ["HF_HOME"] = os.path.join(DATA_DIR, "hf_cache")
23
  os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(DATA_DIR, "hf_cache", "hub")
24
 
 
41
  d.rectangle([int(size*.10), int(size*.18), int(size*.90), int(size*.65)], fill=255)
42
  elif cloth_type == "lower":
43
  d.rectangle([int(size*.05), int(size*.55), int(size*.95), int(size*1.0)], fill=255)
44
+ else:
45
  d.rectangle([int(size*.05), int(size*.15), int(size*.95), int(size*1.0)], fill=255)
46
  return mask
47
 
48
  # ---------------------------------------------------------------------------
49
+ # GPU-decorated inference
50
  # ---------------------------------------------------------------------------
51
  _pipe = None
52
 
 
56
  global _pipe
57
  if _pipe is None:
58
  from diffusers import PaintByExamplePipeline
59
+ print("Loading Paint-by-Example (~5 GB, first run only)…")
60
  _pipe = PaintByExamplePipeline.from_pretrained(
61
  "Fantasy-Studio/Paint-by-Example",
62
  torch_dtype=torch.float16,
63
  ).to("cuda")
64
  _pipe.set_progress_bar_config(disable=True)
65
+ print("Pipeline ready.")
66
 
67
  rng = torch.Generator(device="cuda")
68
  rng.manual_seed(int(seed) if seed != -1 else torch.randint(0, 2**32, (1,)).item())
 
78
  return result.images
79
 
80
  # ---------------------------------------------------------------------------
81
+ # Gradio inferenceuses gr.Progress for live status updates
82
  # ---------------------------------------------------------------------------
83
  def run_tryon(
84
  person_image: Image.Image,
 
87
  num_steps: int,
88
  guidance_scale: float,
89
  seed: int,
90
+ progress=gr.Progress(track_tqdm=True),
91
  ):
92
  if person_image is None or garment_image is None:
93
  raise gr.Error("Please upload both a person photo and a garment image.")
94
 
95
+ progress(0, desc="⏳ Requesting GPU + loading model (first run ~3 min)")
96
 
97
  person = _fit_to_square(person_image)
98
  garment = _fit_to_square(garment_image)
 
100
 
101
  output_images = _run_inference(person, garment, mask, num_steps, guidance_scale, seed)
102
 
103
+ progress(0.9, desc="💾 Saving result…")
104
+
105
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
106
  saved_paths = []
107
  for i, img in enumerate(output_images):
 
109
  img.save(path, format="PNG")
110
  saved_paths.append(path)
111
 
112
+ progress(1.0, desc="✅ Done!")
113
+ return output_images, saved_paths, "✅ Done! Download your result below."
114
 
115
  # ---------------------------------------------------------------------------
116
  # Gradio UI
 
119
  gr.Markdown(
120
  "# 👗 Virtual Try-On\n"
121
  "Upload a **person photo** and a **garment image**, select the type, then click **Try On**.\n\n"
122
+ "> Runs on **Hugging Face ZeroGPU** (free A10G) — no local GPU needed. \n"
123
+ "> **First run:** ~2-3 min (model download ~5 GB). **Subsequent runs:** ~15-30s."
 
124
  )
125
 
126
  with gr.Row():