ArmanRV commited on
Commit
bf8aada
·
verified ·
1 Parent(s): 6a3ad1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -56
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import os
3
  import time
4
  import tempfile
5
- from typing import Optional, Tuple, List
6
 
7
  import gradio as gr
8
  from PIL import Image
@@ -17,7 +17,7 @@ API_NAME = "/tryon"
17
 
18
  # ----------------------------
19
  # Auth for company demo (no HF accounts needed)
20
- # Set these in HF Space Secrets:
21
  # DEMO_USER=RVtest
22
  # DEMO_PASS=rv2026
23
  # ----------------------------
@@ -34,10 +34,10 @@ ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
34
 
35
  def list_garments() -> List[str]:
36
  try:
37
- files = []
38
- for f in os.listdir(GARMENT_DIR):
39
- if f.lower().endswith(ALLOWED_EXTS) and not f.startswith("."):
40
- files.append(f)
41
  files.sort()
42
  return files
43
  except Exception:
@@ -61,7 +61,6 @@ def load_garment_pil(filename: str) -> Optional[Image.Image]:
61
 
62
 
63
  def build_gallery_items(files: List[str]):
64
- # (image, caption). Caption empty = cleaner UI
65
  return [(garment_path(f), "") for f in files]
66
 
67
 
@@ -70,7 +69,6 @@ def build_gallery_items(files: List[str]):
70
  # ----------------------------
71
  HF_TOKEN = os.getenv("HF_TOKEN", "")
72
  print("HF_TOKEN set:", bool(HF_TOKEN), "len:", len(HF_TOKEN) if HF_TOKEN else 0)
73
-
74
  if HF_TOKEN:
75
  try:
76
  login(token=HF_TOKEN, add_to_git_credential=False)
@@ -81,50 +79,14 @@ else:
81
  print("HF login: skipped (no token in env)")
82
 
83
 
84
- # ----------------------------
85
- # Client caching
86
- # ----------------------------
87
- _client: Optional[Client] = None
88
-
89
-
90
- def reset_client():
91
- global _client
92
- _client = None
93
-
94
-
95
- def get_client() -> Client:
96
- """
97
- gradio_client differs by version. Newer versions support hf_token=...
98
- Older versions don't. We fallback gracefully.
99
- """
100
- global _client
101
- if _client is None:
102
- try:
103
- if HF_TOKEN:
104
- _client = Client(SPACE, hf_token=HF_TOKEN) # may raise TypeError on older versions
105
- else:
106
- _client = Client(SPACE)
107
- except TypeError:
108
- _client = Client(SPACE)
109
- return _client
110
-
111
-
112
  # ----------------------------
113
  # Helpers
114
  # ----------------------------
115
- def clamp_int(x, lo, hi):
116
- try:
117
- x = int(x)
118
- except Exception:
119
- x = lo
120
- return max(lo, min(hi, x))
121
-
122
-
123
  def save_pil_temp(pil_img: Image.Image, suffix: str = ".png") -> str:
124
  f = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
125
  path = f.name
126
  f.close()
127
- pil_img.save(path, format="PNG") # no resize/compress
128
  return path
129
 
130
 
@@ -145,10 +107,33 @@ def allow_call(min_interval_sec: float = 3.0) -> Tuple[bool, str]:
145
  return True, ""
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # ----------------------------
149
  # Core inference (remote call)
150
  # ----------------------------
151
- def tryon_remote(person_pil, garment_filename):
152
  ok, msg = allow_call(3.0)
153
  if not ok:
154
  return None, msg
@@ -178,7 +163,7 @@ def tryon_remote(person_pil, garment_filename):
178
 
179
  for attempt in range(1, 7):
180
  try:
181
- client = get_client()
182
 
183
  result = client.predict(
184
  dict={"background": handle_file(p_path), "layers": [], "composite": None},
@@ -206,7 +191,6 @@ def tryon_remote(person_pil, garment_filename):
206
  or "read operation timed out" in msg_l
207
  or "timed out" in msg_l
208
  )
209
-
210
  is_busy = (
211
  "too many requests" in msg_l
212
  or "queue" in msg_l
@@ -214,14 +198,19 @@ def tryon_remote(person_pil, garment_filename):
214
  or "overloaded" in msg_l
215
  or "capacity" in msg_l
216
  )
 
217
 
218
- is_expired = "expired zerogpu proxy token" in msg_l or "zerogpu proxy token" in msg_l
219
-
220
- if is_timeout or is_busy or is_expired:
221
- reset_client()
222
  time.sleep(4.0 * attempt)
223
  continue
224
 
 
 
 
 
 
 
225
  time.sleep(1.2 * attempt)
226
 
227
  tail = str(last_err)[:240] if last_err else "unknown error"
@@ -297,21 +286,18 @@ with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
297
  with gr.Column():
298
  out = gr.Image(label="Результат", type="pil", height=760)
299
 
300
- # Update selection on click
301
  garment_gallery.select(
302
  fn=on_gallery_select,
303
  inputs=[garment_files_state],
304
  outputs=[selected_garment_state, selected_label],
305
  )
306
 
307
- # Refresh catalog after uploading new garments
308
  refresh_btn.click(
309
  fn=refresh_catalog,
310
  inputs=[],
311
  outputs=[garment_gallery, garment_files_state, selected_garment_state, status],
312
  )
313
 
314
- # Run try-on
315
  run.click(
316
  fn=tryon_remote,
317
  inputs=[person, selected_garment_state],
@@ -325,5 +311,5 @@ if __name__ == "__main__":
325
  share=False,
326
  debug=False,
327
  ssr_mode=False,
328
- auth=APP_AUTH, # ✅ login/password gate
329
  )
 
2
  import os
3
  import time
4
  import tempfile
5
+ from typing import Optional, Tuple, List, Dict
6
 
7
  import gradio as gr
8
  from PIL import Image
 
17
 
18
  # ----------------------------
19
  # Auth for company demo (no HF accounts needed)
20
+ # Secrets:
21
  # DEMO_USER=RVtest
22
  # DEMO_PASS=rv2026
23
  # ----------------------------
 
34
 
35
  def list_garments() -> List[str]:
36
  try:
37
+ files = [
38
+ f for f in os.listdir(GARMENT_DIR)
39
+ if f.lower().endswith(ALLOWED_EXTS) and not f.startswith(".")
40
+ ]
41
  files.sort()
42
  return files
43
  except Exception:
 
61
 
62
 
63
  def build_gallery_items(files: List[str]):
 
64
  return [(garment_path(f), "") for f in files]
65
 
66
 
 
69
  # ----------------------------
70
  HF_TOKEN = os.getenv("HF_TOKEN", "")
71
  print("HF_TOKEN set:", bool(HF_TOKEN), "len:", len(HF_TOKEN) if HF_TOKEN else 0)
 
72
  if HF_TOKEN:
73
  try:
74
  login(token=HF_TOKEN, add_to_git_credential=False)
 
79
  print("HF login: skipped (no token in env)")
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # ----------------------------
83
  # Helpers
84
  # ----------------------------
 
 
 
 
 
 
 
 
85
  def save_pil_temp(pil_img: Image.Image, suffix: str = ".png") -> str:
86
  f = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
87
  path = f.name
88
  f.close()
89
+ pil_img.save(path, format="PNG")
90
  return path
91
 
92
 
 
107
  return True, ""
108
 
109
 
110
+ def make_client_from_request(request: gr.Request) -> Client:
111
+ """
112
+ IMPORTANT for ZeroGPU Spaces:
113
+ Forward X-IP-Token so the downstream ZeroGPU Space applies rate limits/quota
114
+ per user correctly (instead of treating calls as unauthenticated).
115
+ """
116
+ headers: Dict[str, str] = {}
117
+ try:
118
+ # Gradio normalizes headers to lowercase keys
119
+ x_ip_token = request.headers.get("x-ip-token")
120
+ if x_ip_token:
121
+ headers["x-ip-token"] = x_ip_token
122
+ except Exception:
123
+ pass
124
+
125
+ # Some gradio_client versions accept headers=..., some may not. Fallback safely.
126
+ try:
127
+ return Client(SPACE, headers=headers) if headers else Client(SPACE)
128
+ except TypeError:
129
+ # older client: no headers kwarg
130
+ return Client(SPACE)
131
+
132
+
133
  # ----------------------------
134
  # Core inference (remote call)
135
  # ----------------------------
136
+ def tryon_remote(person_pil, garment_filename, request: gr.Request):
137
  ok, msg = allow_call(3.0)
138
  if not ok:
139
  return None, msg
 
163
 
164
  for attempt in range(1, 7):
165
  try:
166
+ client = make_client_from_request(request)
167
 
168
  result = client.predict(
169
  dict={"background": handle_file(p_path), "layers": [], "composite": None},
 
191
  or "read operation timed out" in msg_l
192
  or "timed out" in msg_l
193
  )
 
194
  is_busy = (
195
  "too many requests" in msg_l
196
  or "queue" in msg_l
 
198
  or "overloaded" in msg_l
199
  or "capacity" in msg_l
200
  )
201
+ is_quota = "quota" in msg_l and "zerogpu" in msg_l
202
 
203
+ # Retry on transient issues; quota will likely not improve immediately
204
+ if is_timeout or is_busy:
 
 
205
  time.sleep(4.0 * attempt)
206
  continue
207
 
208
+ if is_quota:
209
+ return None, (
210
+ "⚠️ Лимит ZeroGPU на стороне модели исчерпан для текущего пользователя.\n"
211
+ "Попробуйте позже или используйте меньше попыток подряд."
212
+ )
213
+
214
  time.sleep(1.2 * attempt)
215
 
216
  tail = str(last_err)[:240] if last_err else "unknown error"
 
286
  with gr.Column():
287
  out = gr.Image(label="Результат", type="pil", height=760)
288
 
 
289
  garment_gallery.select(
290
  fn=on_gallery_select,
291
  inputs=[garment_files_state],
292
  outputs=[selected_garment_state, selected_label],
293
  )
294
 
 
295
  refresh_btn.click(
296
  fn=refresh_catalog,
297
  inputs=[],
298
  outputs=[garment_gallery, garment_files_state, selected_garment_state, status],
299
  )
300
 
 
301
  run.click(
302
  fn=tryon_remote,
303
  inputs=[person, selected_garment_state],
 
311
  share=False,
312
  debug=False,
313
  ssr_mode=False,
314
+ auth=APP_AUTH,
315
  )