ArmanRV commited on
Commit
6a3ad1d
·
verified ·
1 Parent(s): 0b7f5fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -66
app.py CHANGED
@@ -5,170 +5,290 @@ import tempfile
5
  from typing import Optional, Tuple, List
6
 
7
  import gradio as gr
8
- import spaces
9
  from PIL import Image
10
  from gradio_client import Client, handle_file
11
  from huggingface_hub import login
12
 
 
 
 
13
  SPACE = "yisol/IDM-VTON"
14
  API_NAME = "/tryon"
15
 
16
  # ----------------------------
17
- # Auth (from HF Secrets)
 
 
 
18
  # ----------------------------
19
  DEMO_USER = os.getenv("DEMO_USER", "").strip()
20
  DEMO_PASS = os.getenv("DEMO_PASS", "").strip()
21
  APP_AUTH = (DEMO_USER, DEMO_PASS) if (DEMO_USER and DEMO_PASS) else None
22
 
23
  # ----------------------------
24
- # Garment catalog
25
  # ----------------------------
26
  GARMENT_DIR = "garments"
27
  ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
28
 
 
29
  def list_garments() -> List[str]:
30
  try:
31
- return sorted(
32
- f for f in os.listdir(GARMENT_DIR)
33
- if f.lower().endswith(ALLOWED_EXTS) and not f.startswith(".")
34
- )
 
 
35
  except Exception:
36
  return []
37
 
 
38
  def garment_path(filename: str) -> str:
39
  return os.path.join(GARMENT_DIR, filename)
40
 
 
41
  def load_garment_pil(filename: str) -> Optional[Image.Image]:
 
 
42
  path = garment_path(filename)
43
- if os.path.exists(path):
 
 
44
  return Image.open(path).convert("RGB")
45
- return None
 
 
46
 
47
  def build_gallery_items(files: List[str]):
 
48
  return [(garment_path(f), "") for f in files]
49
 
 
50
  # ----------------------------
51
- # HF login (optional)
52
  # ----------------------------
53
  HF_TOKEN = os.getenv("HF_TOKEN", "")
 
 
54
  if HF_TOKEN:
55
  try:
56
  login(token=HF_TOKEN, add_to_git_credential=False)
57
- except:
58
- pass
 
 
 
 
59
 
 
 
 
60
  _client: Optional[Client] = None
61
 
 
62
  def reset_client():
63
  global _client
64
  _client = None
65
 
 
66
  def get_client() -> Client:
 
 
 
 
67
  global _client
68
  if _client is None:
69
  try:
70
- _client = Client(SPACE)
 
 
 
71
  except TypeError:
72
  _client = Client(SPACE)
73
  return _client
74
 
75
- def save_pil_temp(pil_img: Image.Image) -> str:
76
- f = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
 
 
 
 
 
 
 
 
 
 
 
 
77
  path = f.name
78
  f.close()
79
- pil_img.save(path, format="PNG")
80
  return path
81
 
 
82
  # ----------------------------
83
- # Rate limit
 
84
  # ----------------------------
85
  _last_call_ts = 0.0
86
 
 
87
  def allow_call(min_interval_sec: float = 3.0) -> Tuple[bool, str]:
88
  global _last_call_ts
89
  now = time.time()
90
  if now - _last_call_ts < min_interval_sec:
91
- wait = min_interval_sec - (now - _last_call_ts)
92
- return False, f"⏳ Подождите {wait:.1f} сек."
93
  _last_call_ts = now
94
  return True, ""
95
 
 
96
  # ----------------------------
97
- # Try-on
98
  # ----------------------------
99
- @spaces.GPU(duration=20)
100
- def tryon_remote(person_pil, selected_filename):
101
  ok, msg = allow_call(3.0)
102
  if not ok:
103
  return None, msg
104
 
105
  if person_pil is None:
106
  return None, "❌ Загрузите фото человека"
107
- if not selected_filename:
108
- return None, "❌ Выберите одежду"
109
 
110
- garment_pil = load_garment_pil(selected_filename)
 
 
 
111
  if garment_pil is None:
112
- return None, "❌ Ошибка загрузки одежды"
 
 
 
 
 
 
 
113
 
114
  p_path = save_pil_temp(person_pil)
115
  g_path = save_pil_temp(garment_pil)
116
 
117
  try:
118
- client = get_client()
119
-
120
- result = client.predict(
121
- dict={"background": handle_file(p_path), "layers": [], "composite": None},
122
- garm_img=handle_file(g_path),
123
- garment_des="a photo of a garment",
124
- is_checked=True,
125
- is_checked_crop=True,
126
- denoise_steps=25,
127
- seed=42,
128
- api_name=API_NAME,
129
- )
130
 
131
- if isinstance(result, (list, tuple)):
132
- result = result[0]
133
-
134
- out = Image.open(result).convert("RGB")
135
- return out, "✅ Готово"
136
-
137
- except Exception as e:
138
- reset_client()
139
- return None, f"❌ Ошибка: {str(e)[:200]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  finally:
142
  for path in (p_path, g_path):
143
  try:
144
  os.remove(path)
145
- except:
146
  pass
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # ----------------------------
149
  # UI
150
  # ----------------------------
151
  CUSTOM_CSS = """
152
  footer {display:none !important;}
 
 
 
153
  """
154
 
155
- files = list_garments()
156
- items = build_gallery_items(files)
157
 
158
  with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
159
  gr.Markdown("# Virtual Try-On Rendez-vous")
160
 
161
- selected_state = gr.State(None)
 
162
 
163
  with gr.Row():
164
  with gr.Column():
165
  person = gr.Image(label="Фото человека", type="pil", height=420)
166
 
 
 
 
 
167
  garment_gallery = gr.Gallery(
168
- label="Выберите одежду",
169
- value=items,
170
  columns=4,
171
- height=300
 
172
  )
173
 
174
  run = gr.Button("Примерить", variant="primary")
@@ -177,27 +297,33 @@ with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
177
  with gr.Column():
178
  out = gr.Image(label="Результат", type="pil", height=760)
179
 
180
- def on_select(files, evt: gr.SelectData):
181
- if not files:
182
- return None
183
- return files[evt.index]
184
-
185
  garment_gallery.select(
186
- fn=on_select,
187
- inputs=[gr.State(files)],
188
- outputs=[selected_state]
 
 
 
 
 
 
 
189
  )
190
 
 
191
  run.click(
192
  fn=tryon_remote,
193
- inputs=[person, selected_state],
194
- outputs=[out, status]
195
  )
196
 
197
  if __name__ == "__main__":
198
  demo.launch(
199
  server_name="0.0.0.0",
200
  server_port=7860,
201
- auth=APP_AUTH,
202
- share=False
 
 
203
  )
 
5
  from typing import Optional, Tuple, List
6
 
7
  import gradio as gr
 
8
  from PIL import Image
9
  from gradio_client import Client, handle_file
10
  from huggingface_hub import login
11
 
12
+ # ----------------------------
13
+ # Remote Space (IDM-VTON)
14
+ # ----------------------------
15
  SPACE = "yisol/IDM-VTON"
16
  API_NAME = "/tryon"
17
 
18
  # ----------------------------
19
+ # Auth for company demo (no HF accounts needed)
20
+ # Set these in HF Space Secrets:
21
+ # DEMO_USER=RVtest
22
+ # DEMO_PASS=rv2026
23
  # ----------------------------
24
  DEMO_USER = os.getenv("DEMO_USER", "").strip()
25
  DEMO_PASS = os.getenv("DEMO_PASS", "").strip()
26
  APP_AUTH = (DEMO_USER, DEMO_PASS) if (DEMO_USER and DEMO_PASS) else None
27
 
28
  # ----------------------------
29
+ # Garment catalog folder in repo
30
  # ----------------------------
31
  GARMENT_DIR = "garments"
32
  ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
33
 
34
+
35
  def list_garments() -> List[str]:
36
  try:
37
+ files = []
38
+ for f in os.listdir(GARMENT_DIR):
39
+ if f.lower().endswith(ALLOWED_EXTS) and not f.startswith("."):
40
+ files.append(f)
41
+ files.sort()
42
+ return files
43
  except Exception:
44
  return []
45
 
46
+
47
  def garment_path(filename: str) -> str:
48
  return os.path.join(GARMENT_DIR, filename)
49
 
50
+
51
  def load_garment_pil(filename: str) -> Optional[Image.Image]:
52
+ if not filename:
53
+ return None
54
  path = garment_path(filename)
55
+ if not os.path.exists(path):
56
+ return None
57
+ try:
58
  return Image.open(path).convert("RGB")
59
+ except Exception:
60
+ return None
61
+
62
 
63
  def build_gallery_items(files: List[str]):
64
+ # (image, caption). Caption empty = cleaner UI
65
  return [(garment_path(f), "") for f in files]
66
 
67
+
68
  # ----------------------------
69
+ # HF token (optional)
70
  # ----------------------------
71
  HF_TOKEN = os.getenv("HF_TOKEN", "")
72
+ print("HF_TOKEN set:", bool(HF_TOKEN), "len:", len(HF_TOKEN) if HF_TOKEN else 0)
73
+
74
  if HF_TOKEN:
75
  try:
76
  login(token=HF_TOKEN, add_to_git_credential=False)
77
+ print("HF login: OK")
78
+ except Exception as e:
79
+ print("HF login: FAILED:", str(e)[:200])
80
+ else:
81
+ print("HF login: skipped (no token in env)")
82
+
83
 
84
+ # ----------------------------
85
+ # Client caching
86
+ # ----------------------------
87
  _client: Optional[Client] = None
88
 
89
+
90
  def reset_client():
91
  global _client
92
  _client = None
93
 
94
+
95
  def get_client() -> Client:
96
+ """
97
+ gradio_client differs by version. Newer versions support hf_token=...
98
+ Older versions don't. We fallback gracefully.
99
+ """
100
  global _client
101
  if _client is None:
102
  try:
103
+ if HF_TOKEN:
104
+ _client = Client(SPACE, hf_token=HF_TOKEN) # may raise TypeError on older versions
105
+ else:
106
+ _client = Client(SPACE)
107
  except TypeError:
108
  _client = Client(SPACE)
109
  return _client
110
 
111
+
112
+ # ----------------------------
113
+ # Helpers
114
+ # ----------------------------
115
+ def clamp_int(x, lo, hi):
116
+ try:
117
+ x = int(x)
118
+ except Exception:
119
+ x = lo
120
+ return max(lo, min(hi, x))
121
+
122
+
123
+ def save_pil_temp(pil_img: Image.Image, suffix: str = ".png") -> str:
124
+ f = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
125
  path = f.name
126
  f.close()
127
+ pil_img.save(path, format="PNG") # no resize/compress
128
  return path
129
 
130
+
131
  # ----------------------------
132
+ # Simple global rate limit (anti spam)
133
+ # NOTE: global across all users. Good enough for internal demo.
134
  # ----------------------------
135
  _last_call_ts = 0.0
136
 
137
+
138
  def allow_call(min_interval_sec: float = 3.0) -> Tuple[bool, str]:
139
  global _last_call_ts
140
  now = time.time()
141
  if now - _last_call_ts < min_interval_sec:
142
+ wait = max(0.0, min_interval_sec - (now - _last_call_ts))
143
+ return False, f"⏳ Слишком часто. Подождите {wait:.1f} сек."
144
  _last_call_ts = now
145
  return True, ""
146
 
147
+
148
  # ----------------------------
149
+ # Core inference (remote call)
150
  # ----------------------------
151
+ def tryon_remote(person_pil, garment_filename):
 
152
  ok, msg = allow_call(3.0)
153
  if not ok:
154
  return None, msg
155
 
156
  if person_pil is None:
157
  return None, "❌ Загрузите фото человека"
 
 
158
 
159
+ if not garment_filename:
160
+ return None, "❌ Выберите одежду (кликните на превью)"
161
+
162
+ garment_pil = load_garment_pil(garment_filename)
163
  if garment_pil is None:
164
+ return None, "❌ Не удалось загрузить выбранную одежду (проверьте garments/)"
165
+
166
+ # Fixed params for simple demo
167
+ garment_desc = "a photo of a garment"
168
+ auto_mask = True
169
+ crop_center = True
170
+ denoise_steps = 25
171
+ seed = 42
172
 
173
  p_path = save_pil_temp(person_pil)
174
  g_path = save_pil_temp(garment_pil)
175
 
176
  try:
177
+ last_err = None
 
 
 
 
 
 
 
 
 
 
 
178
 
179
+ for attempt in range(1, 7):
180
+ try:
181
+ client = get_client()
182
+
183
+ result = client.predict(
184
+ dict={"background": handle_file(p_path), "layers": [], "composite": None},
185
+ garm_img=handle_file(g_path),
186
+ garment_des=garment_desc,
187
+ is_checked=bool(auto_mask),
188
+ is_checked_crop=bool(crop_center),
189
+ denoise_steps=int(denoise_steps),
190
+ seed=int(seed),
191
+ api_name=API_NAME,
192
+ )
193
+
194
+ if isinstance(result, (list, tuple)):
195
+ result = result[0]
196
+
197
+ out = Image.open(result).convert("RGB")
198
+ return out, "✅ Готово"
199
+
200
+ except Exception as e:
201
+ last_err = e
202
+ msg_l = str(e).lower()
203
+
204
+ is_timeout = (
205
+ "write operation timed out" in msg_l
206
+ or "read operation timed out" in msg_l
207
+ or "timed out" in msg_l
208
+ )
209
+
210
+ is_busy = (
211
+ "too many requests" in msg_l
212
+ or "queue" in msg_l
213
+ or "too busy" in msg_l
214
+ or "overloaded" in msg_l
215
+ or "capacity" in msg_l
216
+ )
217
+
218
+ is_expired = "expired zerogpu proxy token" in msg_l or "zerogpu proxy token" in msg_l
219
+
220
+ if is_timeout or is_busy or is_expired:
221
+ reset_client()
222
+ time.sleep(4.0 * attempt)
223
+ continue
224
+
225
+ time.sleep(1.2 * attempt)
226
+
227
+ tail = str(last_err)[:240] if last_err else "unknown error"
228
+ return None, f"❌ Ошибка Space после 6 попыток: {tail}"
229
 
230
  finally:
231
  for path in (p_path, g_path):
232
  try:
233
  os.remove(path)
234
+ except Exception:
235
  pass
236
 
237
+
238
+ # ----------------------------
239
+ # UI helpers
240
+ # ----------------------------
241
+ def refresh_catalog():
242
+ files = list_garments()
243
+ items = build_gallery_items(files)
244
+ status = "✅ Каталог обновлён" if files else "⚠️ В папке garments/ пока нет изображений"
245
+ return items, files, None, status
246
+
247
+
248
+ def on_gallery_select(files: List[str], evt: gr.SelectData):
249
+ if not files:
250
+ return None, "⚠️ Каталог пуст"
251
+ try:
252
+ idx = int(evt.index) if evt.index is not None else 0
253
+ idx = max(0, min(idx, len(files) - 1))
254
+ return files[idx], f"👕 Выбрано: {files[idx]}"
255
+ except Exception:
256
+ return None, "⚠️ Не удалось выбрать одежду"
257
+
258
+
259
  # ----------------------------
260
  # UI
261
  # ----------------------------
262
  CUSTOM_CSS = """
263
  footer {display:none !important;}
264
+ #api-info {display:none !important;}
265
+ div[class*="footer"] {display:none !important;}
266
+ button[aria-label="Settings"] {display:none !important;}
267
  """
268
 
269
+ initial_files = list_garments()
270
+ initial_items = build_gallery_items(initial_files)
271
 
272
  with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
273
  gr.Markdown("# Virtual Try-On Rendez-vous")
274
 
275
+ garment_files_state = gr.State(initial_files)
276
+ selected_garment_state = gr.State(None)
277
 
278
  with gr.Row():
279
  with gr.Column():
280
  person = gr.Image(label="Фото человека", type="pil", height=420)
281
 
282
+ with gr.Row():
283
+ refresh_btn = gr.Button("🔄 Обновить каталог", variant="secondary")
284
+ selected_label = gr.Markdown("👕 Выберите одежду, кликнув по превью ниже")
285
+
286
  garment_gallery = gr.Gallery(
287
+ label="Каталог одежды (кликните на превью)",
288
+ value=initial_items,
289
  columns=4,
290
+ height=340,
291
+ allow_preview=True,
292
  )
293
 
294
  run = gr.Button("Примерить", variant="primary")
 
297
  with gr.Column():
298
  out = gr.Image(label="Результат", type="pil", height=760)
299
 
300
+ # Update selection on click
 
 
 
 
301
  garment_gallery.select(
302
+ fn=on_gallery_select,
303
+ inputs=[garment_files_state],
304
+ outputs=[selected_garment_state, selected_label],
305
+ )
306
+
307
+ # Refresh catalog after uploading new garments
308
+ refresh_btn.click(
309
+ fn=refresh_catalog,
310
+ inputs=[],
311
+ outputs=[garment_gallery, garment_files_state, selected_garment_state, status],
312
  )
313
 
314
+ # Run try-on
315
  run.click(
316
  fn=tryon_remote,
317
+ inputs=[person, selected_garment_state],
318
+ outputs=[out, status],
319
  )
320
 
321
  if __name__ == "__main__":
322
  demo.launch(
323
  server_name="0.0.0.0",
324
  server_port=7860,
325
+ share=False,
326
+ debug=False,
327
+ ssr_mode=False,
328
+ auth=APP_AUTH, # ✅ login/password gate
329
  )