ArmanRV commited on
Commit
0b7f5fb
·
verified ·
1 Parent(s): 90f2241

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -203
app.py CHANGED
@@ -5,63 +5,58 @@ import tempfile
5
  from typing import Optional, Tuple, List
6
 
7
  import gradio as gr
8
- import spaces # ZeroGPU runtime requires at least one @spaces.GPU-decorated function
9
  from PIL import Image
10
  from gradio_client import Client, handle_file
11
  from huggingface_hub import login
12
 
13
- # ----------------------------
14
- # Remote Space (IDM-VTON)
15
- # ----------------------------
16
  SPACE = "yisol/IDM-VTON"
17
  API_NAME = "/tryon"
18
 
19
  # ----------------------------
20
- # Demo auth (no HF accounts needed)
21
- # Set these in HF Space Secrets:
22
- # DEMO_USER=companydemo
23
- # DEMO_PASS=your-strong-password
24
  # ----------------------------
25
  DEMO_USER = os.getenv("DEMO_USER", "").strip()
26
  DEMO_PASS = os.getenv("DEMO_PASS", "").strip()
27
  APP_AUTH = (DEMO_USER, DEMO_PASS) if (DEMO_USER and DEMO_PASS) else None
28
 
29
  # ----------------------------
30
- # Garment catalog folder in repo
31
  # ----------------------------
32
  GARMENT_DIR = "garments"
33
  ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
34
 
35
  def list_garments() -> List[str]:
36
  try:
37
- files = []
38
- for f in os.listdir(GARMENT_DIR):
39
- if f.lower().endswith(ALLOWED_EXTS) and not f.startswith("."):
40
- files.append(f)
41
- files.sort()
42
- return files
43
  except Exception:
44
  return []
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # ----------------------------
47
- # HF token (optional)
48
  # ----------------------------
49
  HF_TOKEN = os.getenv("HF_TOKEN", "")
50
-
51
- print("HF_TOKEN set:", bool(HF_TOKEN), "len:", len(HF_TOKEN) if HF_TOKEN else 0)
52
-
53
  if HF_TOKEN:
54
  try:
55
  login(token=HF_TOKEN, add_to_git_credential=False)
56
- print("HF login: OK")
57
- except Exception as e:
58
- print("HF login: FAILED:", str(e)[:200])
59
- else:
60
- print("HF login: skipped (no token in env)")
61
 
62
- # ----------------------------
63
- # Client caching
64
- # ----------------------------
65
  _client: Optional[Client] = None
66
 
67
  def reset_client():
@@ -69,53 +64,23 @@ def reset_client():
69
  _client = None
70
 
71
  def get_client() -> Client:
72
- """
73
- gradio_client differs by version. Newer versions support hf_token=...
74
- Older versions don't. We fallback gracefully.
75
- """
76
  global _client
77
  if _client is None:
78
  try:
79
- if HF_TOKEN:
80
- _client = Client(SPACE, hf_token=HF_TOKEN) # may raise TypeError on older versions
81
- else:
82
- _client = Client(SPACE)
83
  except TypeError:
84
  _client = Client(SPACE)
85
  return _client
86
 
87
- # ----------------------------
88
- # Helpers
89
- # ----------------------------
90
- def clamp_int(x, lo, hi):
91
- try:
92
- x = int(x)
93
- except Exception:
94
- x = lo
95
- return max(lo, min(hi, x))
96
-
97
- def save_pil_temp(pil_img: Image.Image, suffix: str = ".png") -> str:
98
- f = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
99
  path = f.name
100
  f.close()
101
- pil_img.save(path, format="PNG") # no resize/compress
102
  return path
103
 
104
- def load_garment_pil(filename: str) -> Optional[Image.Image]:
105
- if not filename:
106
- return None
107
- path = os.path.join(GARMENT_DIR, filename)
108
- if not os.path.exists(path):
109
- return None
110
- try:
111
- return Image.open(path).convert("RGB")
112
- except Exception:
113
- return None
114
-
115
  # ----------------------------
116
- # Simple global rate limit (anti spam)
117
- # NOTE: this is a global limiter across all users for this Space.
118
- # For internal demo it's usually enough. Adjust interval as needed.
119
  # ----------------------------
120
  _last_call_ts = 0.0
121
 
@@ -123,204 +88,116 @@ def allow_call(min_interval_sec: float = 3.0) -> Tuple[bool, str]:
123
  global _last_call_ts
124
  now = time.time()
125
  if now - _last_call_ts < min_interval_sec:
126
- wait = max(0.0, min_interval_sec - (now - _last_call_ts))
127
- return False, f"⏳ Слишком часто. Подождите {wait:.1f} сек."
128
  _last_call_ts = now
129
  return True, ""
130
 
131
  # ----------------------------
132
- # Core inference (remote call)
133
- # ZeroGPU: keep duration LOW to avoid burning quota
134
  # ----------------------------
135
  @spaces.GPU(duration=20)
136
- def tryon_remote(person_pil, garment_filename, garment_desc, auto_mask, crop_center, denoise_steps, seed):
137
  ok, msg = allow_call(3.0)
138
  if not ok:
139
  return None, msg
140
 
141
  if person_pil is None:
142
  return None, "❌ Загрузите фото человека"
 
 
143
 
144
- if not garment_filename:
145
- return None, "❌ Выберите одежду из списка"
146
-
147
- garment_pil = load_garment_pil(garment_filename)
148
  if garment_pil is None:
149
- return None, "❌ Не удалось загрузить выбранную одежду (проверьте файл в папке garments/)"
150
-
151
- denoise_steps = clamp_int(denoise_steps, 10, 40)
152
- seed = clamp_int(seed, 0, 999999)
153
-
154
- garment_desc = (garment_desc or "").strip()
155
- if not garment_desc:
156
- garment_desc = "a photo of a garment"
157
 
158
  p_path = save_pil_temp(person_pil)
159
  g_path = save_pil_temp(garment_pil)
160
 
161
  try:
162
- last_err = None
163
 
164
- for attempt in range(1, 7):
165
- try:
166
- client = get_client()
167
-
168
- result = client.predict(
169
- dict={"background": handle_file(p_path), "layers": [], "composite": None},
170
- garm_img=handle_file(g_path),
171
- garment_des=garment_desc,
172
- is_checked=bool(auto_mask),
173
- is_checked_crop=bool(crop_center),
174
- denoise_steps=int(denoise_steps),
175
- seed=int(seed),
176
- api_name=API_NAME,
177
- )
178
-
179
- if isinstance(result, (list, tuple)):
180
- result = result[0]
181
-
182
- out = Image.open(result).convert("RGB")
183
- return out, f" Готово (одежда: {garment_filename}, steps={denoise_steps}, seed={seed})"
184
-
185
- except Exception as e:
186
- last_err = e
187
- msg_l = str(e).lower()
188
-
189
- is_timeout = (
190
- "write operation timed out" in msg_l
191
- or "read operation timed out" in msg_l
192
- or "timed out" in msg_l
193
- )
194
-
195
- is_busy = (
196
- "too many requests" in msg_l
197
- or "queue" in msg_l
198
- or "too busy" in msg_l
199
- or "overloaded" in msg_l
200
- or "capacity" in msg_l
201
- or "zerogpu" in msg_l
202
- )
203
-
204
- is_expired = "expired zerogpu proxy token" in msg_l or "zerogpu proxy token" in msg_l
205
-
206
- if is_timeout or is_busy or is_expired:
207
- reset_client()
208
- time.sleep(4.0 * attempt)
209
- continue
210
-
211
- time.sleep(1.2 * attempt)
212
-
213
- tail = str(last_err)[:240] if last_err else "unknown error"
214
- return None, f"❌ Ошибка Space после 6 попыток: {tail}"
215
 
216
  finally:
217
  for path in (p_path, g_path):
218
  try:
219
  os.remove(path)
220
- except Exception:
221
  pass
222
 
223
- # ----------------------------
224
- # UI helpers
225
- # ----------------------------
226
- def refresh_garments():
227
- files = list_garments()
228
- value = files[0] if files else None
229
- preview = load_garment_pil(value) if value else None
230
- status = "✅ Каталог обновлён" if files else "⚠️ В папке garments/ пока нет изображений"
231
- return gr.Dropdown(choices=files, value=value), preview, status
232
-
233
- def on_select_garment(filename: str):
234
- return load_garment_pil(filename)
235
-
236
- def reset_ui():
237
- files = list_garments()
238
- value = files[0] if files else None
239
- preview = load_garment_pil(value) if value else None
240
- return None, value, preview, "", "Ожидание...", None
241
-
242
  # ----------------------------
243
  # UI
244
  # ----------------------------
245
  CUSTOM_CSS = """
246
  footer {display:none !important;}
247
- #api-info {display:none !important;}
248
- div[class*="footer"] {display:none !important;}
249
- button[aria-label="Settings"] {display:none !important;}
250
  """
251
 
252
- initial_files = list_garments()
253
- initial_value = initial_files[0] if initial_files else None
254
- initial_preview = load_garment_pil(initial_value) if initial_value else None
255
 
256
  with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
257
  gr.Markdown("# Virtual Try-On Rendez-vous")
258
 
 
 
259
  with gr.Row():
260
  with gr.Column():
261
  person = gr.Image(label="Фото человека", type="pil", height=420)
262
 
263
- with gr.Row():
264
- garment_selector = gr.Dropdown(
265
- choices=initial_files,
266
- value=initial_value,
267
- label="Выберите одежду из каталога (garments/)"
268
- )
269
- refresh_btn = gr.Button("🔄 Обновить каталог", variant="secondary")
270
-
271
- garment_preview = gr.Image(label="Предпросмотр одежды", type="pil", height=320)
272
-
273
- garment_desc = gr.Textbox(label="Описание одежды (опционально)", value="")
274
-
275
- with gr.Accordion("Настройки", open=False):
276
- auto_mask = gr.Checkbox(label="Auto-mask (Space)", value=True)
277
- crop_center = gr.Checkbox(label="Crop по центру", value=True)
278
- denoise_steps = gr.Slider(10, 40, value=25, step=1, label="Denoise steps")
279
- seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
280
-
281
- with gr.Row():
282
- run = gr.Button("Примерить", variant="primary")
283
- reset = gr.Button("Сбросить", variant="secondary")
284
 
 
285
  status = gr.Textbox(value="Ожидание...", interactive=False)
286
 
287
  with gr.Column():
288
  out = gr.Image(label="Результат", type="pil", height=760)
289
 
290
- # Preview updates
291
- garment_selector.change(
292
- fn=on_select_garment,
293
- inputs=[garment_selector],
294
- outputs=[garment_preview]
295
- )
296
 
297
- # Refresh catalog (after uploading new files)
298
- refresh_btn.click(
299
- fn=refresh_garments,
300
- inputs=[],
301
- outputs=[garment_selector, garment_preview, status]
302
  )
303
 
304
- # Run try-on
305
  run.click(
306
  fn=tryon_remote,
307
- inputs=[person, garment_selector, garment_desc, auto_mask, crop_center, denoise_steps, seed],
308
- outputs=[out, status],
309
- )
310
-
311
- # Reset UI
312
- reset.click(
313
- fn=reset_ui,
314
- inputs=[],
315
- outputs=[person, garment_selector, garment_preview, garment_desc, status, out],
316
  )
317
 
318
  if __name__ == "__main__":
319
  demo.launch(
320
  server_name="0.0.0.0",
321
  server_port=7860,
322
- share=False,
323
- debug=False,
324
- ssr_mode=False,
325
- auth=APP_AUTH, # ✅ login/password gate
326
  )
 
5
  from typing import Optional, Tuple, List
6
 
7
  import gradio as gr
8
+ import spaces
9
  from PIL import Image
10
  from gradio_client import Client, handle_file
11
  from huggingface_hub import login
12
 
 
 
 
13
  SPACE = "yisol/IDM-VTON"
14
  API_NAME = "/tryon"
15
 
16
  # ----------------------------
17
+ # Auth (from HF Secrets)
 
 
 
18
  # ----------------------------
19
  DEMO_USER = os.getenv("DEMO_USER", "").strip()
20
  DEMO_PASS = os.getenv("DEMO_PASS", "").strip()
21
  APP_AUTH = (DEMO_USER, DEMO_PASS) if (DEMO_USER and DEMO_PASS) else None
22
 
23
  # ----------------------------
24
+ # Garment catalog
25
  # ----------------------------
26
  GARMENT_DIR = "garments"
27
  ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
28
 
29
  def list_garments() -> List[str]:
30
  try:
31
+ return sorted(
32
+ f for f in os.listdir(GARMENT_DIR)
33
+ if f.lower().endswith(ALLOWED_EXTS) and not f.startswith(".")
34
+ )
 
 
35
  except Exception:
36
  return []
37
 
38
+ def garment_path(filename: str) -> str:
39
+ return os.path.join(GARMENT_DIR, filename)
40
+
41
+ def load_garment_pil(filename: str) -> Optional[Image.Image]:
42
+ path = garment_path(filename)
43
+ if os.path.exists(path):
44
+ return Image.open(path).convert("RGB")
45
+ return None
46
+
47
+ def build_gallery_items(files: List[str]):
48
+ return [(garment_path(f), "") for f in files]
49
+
50
  # ----------------------------
51
+ # HF login (optional)
52
  # ----------------------------
53
  HF_TOKEN = os.getenv("HF_TOKEN", "")
 
 
 
54
  if HF_TOKEN:
55
  try:
56
  login(token=HF_TOKEN, add_to_git_credential=False)
57
+ except:
58
+ pass
 
 
 
59
 
 
 
 
60
  _client: Optional[Client] = None
61
 
62
  def reset_client():
 
64
  _client = None
65
 
66
  def get_client() -> Client:
 
 
 
 
67
  global _client
68
  if _client is None:
69
  try:
70
+ _client = Client(SPACE)
 
 
 
71
  except TypeError:
72
  _client = Client(SPACE)
73
  return _client
74
 
75
+ def save_pil_temp(pil_img: Image.Image) -> str:
76
+ f = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
 
 
 
 
 
 
 
 
 
 
77
  path = f.name
78
  f.close()
79
+ pil_img.save(path, format="PNG")
80
  return path
81
 
 
 
 
 
 
 
 
 
 
 
 
82
  # ----------------------------
83
+ # Rate limit
 
 
84
  # ----------------------------
85
  _last_call_ts = 0.0
86
 
 
88
  global _last_call_ts
89
  now = time.time()
90
  if now - _last_call_ts < min_interval_sec:
91
+ wait = min_interval_sec - (now - _last_call_ts)
92
+ return False, f"⏳ Подождите {wait:.1f} сек."
93
  _last_call_ts = now
94
  return True, ""
95
 
96
  # ----------------------------
97
+ # Try-on
 
98
  # ----------------------------
99
  @spaces.GPU(duration=20)
100
+ def tryon_remote(person_pil, selected_filename):
101
  ok, msg = allow_call(3.0)
102
  if not ok:
103
  return None, msg
104
 
105
  if person_pil is None:
106
  return None, "❌ Загрузите фото человека"
107
+ if not selected_filename:
108
+ return None, "❌ Выберите одежду"
109
 
110
+ garment_pil = load_garment_pil(selected_filename)
 
 
 
111
  if garment_pil is None:
112
+ return None, "❌ Ошибка загрузки одежды"
 
 
 
 
 
 
 
113
 
114
  p_path = save_pil_temp(person_pil)
115
  g_path = save_pil_temp(garment_pil)
116
 
117
  try:
118
+ client = get_client()
119
 
120
+ result = client.predict(
121
+ dict={"background": handle_file(p_path), "layers": [], "composite": None},
122
+ garm_img=handle_file(g_path),
123
+ garment_des="a photo of a garment",
124
+ is_checked=True,
125
+ is_checked_crop=True,
126
+ denoise_steps=25,
127
+ seed=42,
128
+ api_name=API_NAME,
129
+ )
130
+
131
+ if isinstance(result, (list, tuple)):
132
+ result = result[0]
133
+
134
+ out = Image.open(result).convert("RGB")
135
+ return out, "✅ Готово"
136
+
137
+ except Exception as e:
138
+ reset_client()
139
+ return None, f" Ошибка: {str(e)[:200]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  finally:
142
  for path in (p_path, g_path):
143
  try:
144
  os.remove(path)
145
+ except:
146
  pass
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # ----------------------------
149
  # UI
150
  # ----------------------------
151
  CUSTOM_CSS = """
152
  footer {display:none !important;}
 
 
 
153
  """
154
 
155
+ files = list_garments()
156
+ items = build_gallery_items(files)
 
157
 
158
  with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
159
  gr.Markdown("# Virtual Try-On Rendez-vous")
160
 
161
+ selected_state = gr.State(None)
162
+
163
  with gr.Row():
164
  with gr.Column():
165
  person = gr.Image(label="Фото человека", type="pil", height=420)
166
 
167
+ garment_gallery = gr.Gallery(
168
+ label="Выберите одежду",
169
+ value=items,
170
+ columns=4,
171
+ height=300
172
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ run = gr.Button("Примерить", variant="primary")
175
  status = gr.Textbox(value="Ожидание...", interactive=False)
176
 
177
  with gr.Column():
178
  out = gr.Image(label="Результат", type="pil", height=760)
179
 
180
+ def on_select(files, evt: gr.SelectData):
181
+ if not files:
182
+ return None
183
+ return files[evt.index]
 
 
184
 
185
+ garment_gallery.select(
186
+ fn=on_select,
187
+ inputs=[gr.State(files)],
188
+ outputs=[selected_state]
 
189
  )
190
 
 
191
  run.click(
192
  fn=tryon_remote,
193
+ inputs=[person, selected_state],
194
+ outputs=[out, status]
 
 
 
 
 
 
 
195
  )
196
 
197
  if __name__ == "__main__":
198
  demo.launch(
199
  server_name="0.0.0.0",
200
  server_port=7860,
201
+ auth=APP_AUTH,
202
+ share=False
 
 
203
  )