ArmanRV's picture
Update app.py
7e4f185 verified
raw
history blame
8.78 kB
# -*- coding: utf-8 -*-
import os
import time
import tempfile
from typing import Optional, Tuple
import gradio as gr
import spaces # ZeroGPU runtime requires at least one @spaces.GPU-decorated function
from PIL import Image
from gradio_client import Client, handle_file
from huggingface_hub import login
# ----------------------------
# Remote Space (IDM-VTON)
# ----------------------------
SPACE = "yisol/IDM-VTON"
API_NAME = "/tryon"
# ----------------------------
# Auth for company demo (no HF accounts needed)
# Set these in HF Space Secrets:
# DEMO_USER=companydemo
# DEMO_PASS=your-strong-password
# ----------------------------
DEMO_USER = os.getenv("DEMO_USER", "").strip()
DEMO_PASS = os.getenv("DEMO_PASS", "").strip()
# (user, password) tuple passed to demo.launch(auth=...); None means no login gate.
APP_AUTH = (DEMO_USER, DEMO_PASS) if (DEMO_USER and DEMO_PASS) else None
if APP_AUTH is None:
    # Make a disabled gate visible in the logs instead of silently serving a
    # public app (e.g. when only one of the two secrets was configured).
    if DEMO_USER or DEMO_PASS:
        print("AUTH: WARNING - only one of DEMO_USER/DEMO_PASS is set; login gate is DISABLED")
    else:
        print("AUTH: disabled (DEMO_USER/DEMO_PASS not set) - app is public")
# ----------------------------
# HF token (optional)
# ----------------------------
# Stripped like DEMO_USER/DEMO_PASS: secrets pasted with a trailing newline or
# space would otherwise be sent verbatim and fail authentication.
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
print("HF_TOKEN set:", bool(HF_TOKEN), "len:", len(HF_TOKEN) if HF_TOKEN else 0)
if HF_TOKEN:
    try:
        login(token=HF_TOKEN, add_to_git_credential=False)
        print("HF login: OK")
    except Exception as e:
        # Best effort: the app can still run without a logged-in session.
        print("HF login: FAILED:", str(e)[:200])
else:
    print("HF login: skipped (no token in env)")
# ----------------------------
# Client caching
# ----------------------------
# Lazily created gradio_client connection to SPACE, reused across calls.
# reset_client() clears it so the next get_client() builds a fresh one.
_client: Optional[Client] = None


def reset_client():
    """Forget the cached client; the next get_client() call reconnects."""
    global _client
    _client = None


def get_client() -> Client:
    """
    Return the cached gradio_client Client for SPACE, creating it on first use.

    When HF_TOKEN is set it is passed as hf_token=...; if the installed
    gradio_client rejects that keyword (TypeError), fall back to a client
    without a token and log the fallback instead of dropping the token silently.
    """
    global _client
    if _client is not None:
        return _client
    if HF_TOKEN:
        try:
            _client = Client(SPACE, hf_token=HF_TOKEN)
        except TypeError as e:
            print("gradio_client: hf_token not accepted, connecting without token:", str(e)[:200])
            _client = Client(SPACE)
    else:
        _client = Client(SPACE)
    return _client
# ----------------------------
# Helpers
# ----------------------------
def clamp_int(x, lo, hi):
    """
    Coerce x to int and bound it to [lo, hi].

    Anything int() cannot convert falls back to lo. When lo > hi, lo wins.
    """
    try:
        value = int(x)
    except Exception:
        value = lo
    if value >= hi:
        value = hi
    if value <= lo:
        value = lo
    return value
def save_pil_temp(pil_img: Image.Image, suffix: str = ".png") -> str:
    """
    Write pil_img to a new temporary file and return its path.

    The content is always PNG (lossless, no resize); suffix only names the file.
    The caller owns the file and must delete it. If saving fails, the temp file
    is removed before the exception propagates, so nothing is leaked.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    path = tmp.name
    tmp.close()  # close our handle so PIL can reopen the path for writing
    try:
        pil_img.save(path, format="PNG")  # no resize/compress
    except Exception:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path
# ----------------------------
# Simple global rate limit (anti spam)
# NOTE: this is a global limiter across all users for this Space.
# For internal demo it's usually enough. Adjust interval as needed.
# ----------------------------
# time.monotonic() value of the last accepted call; -inf means "never called".
_last_call_ts = float("-inf")


def allow_call(min_interval_sec: float = 3.0) -> Tuple[bool, str]:
    """
    Global throttle: accept at most one call per min_interval_sec.

    Returns (True, "") and records the call when enough time has passed since
    the last accepted one; otherwise returns (False, message with the remaining
    wait) without recording anything.

    Uses a monotonic clock so wall-clock adjustments cannot block calls or
    allow bursts.
    NOTE(review): check-then-set is not locked; fine if handlers run one at a
    time - confirm the Gradio concurrency settings.
    """
    global _last_call_ts
    now = time.monotonic()
    elapsed = now - _last_call_ts
    if elapsed < min_interval_sec:
        wait = max(0.0, min_interval_sec - elapsed)
        return False, f"⏳ Слишком часто. Подождите {wait:.1f} сек."
    _last_call_ts = now
    return True, ""
# ----------------------------
# Core inference (remote call)
# ZeroGPU: keep duration LOW to avoid burning quota
# ----------------------------
# Number of predict() attempts before reporting failure to the user.
_MAX_ATTEMPTS = 6


def _classify_remote_error(exc: Exception) -> Tuple[bool, str]:
    """Return (is_transient, human-readable reason) for a failed remote call."""
    msg_l = str(exc).lower()
    is_expired = "zerogpu proxy token" in msg_l
    is_busy = any(
        marker in msg_l
        for marker in ("too many requests", "queue", "too busy", "overloaded", "capacity", "zerogpu")
    )
    # "timed out" covers "read/write operation timed out"; the isinstance check
    # catches bare TimeoutError()s whose message may be empty.
    is_timeout = "timed out" in msg_l or isinstance(exc, TimeoutError)
    if is_expired:
        reason = "истёк токен ZeroGPU"
    elif is_busy:
        reason = "очередь/перегрузка"
    elif is_timeout:
        reason = "таймаут сети"
    else:
        reason = "ошибка"
    return (is_expired or is_busy or is_timeout), reason


@spaces.GPU(duration=20)
def tryon_remote(person_pil, garment_pil, garment_desc, auto_mask, crop_center, denoise_steps, seed):
    """
    Run virtual try-on by calling the remote IDM-VTON Space.

    Args:
        person_pil: person image (None if nothing uploaded).
        garment_pil: garment image (None if nothing uploaded).
        garment_desc: garment text; blank falls back to a generic description.
        auto_mask: forwarded to the remote API as is_checked.
        crop_center: forwarded to the remote API as is_checked_crop.
        denoise_steps: clamped to [10, 40].
        seed: clamped to [0, 999999].

    Returns:
        (result_image or None, status message), matching the (out, status) outputs.

    Transient failures (timeouts, queue/overload, ZeroGPU token) reset the
    cached client and back off 4 s * attempt; other errors back off
    1.2 s * attempt. There is no sleep after the last attempt.
    NOTE(review): the backoff alone can reach 60 s, far above duration=20;
    confirm how ZeroGPU treats a call that outlives its declared duration.
    """
    # Validate first so an incomplete submit does not consume the throttle window.
    if person_pil is None:
        return None, "❌ Загрузите фото человека"
    if garment_pil is None:
        return None, "❌ Загрузите одежду"
    # anti spam
    ok, msg = allow_call(3.0)
    if not ok:
        return None, msg

    denoise_steps = clamp_int(denoise_steps, 10, 40)
    seed = clamp_int(seed, 0, 999999)
    garment_desc = (garment_desc or "").strip() or "a photo of a garment"

    # Created inside the try so that a failure while writing the second file
    # still deletes the first one in the finally block.
    p_path = g_path = None
    try:
        p_path = save_pil_temp(person_pil)
        g_path = save_pil_temp(garment_pil)
        last_err = None
        reason = "ошибка"
        for attempt in range(1, _MAX_ATTEMPTS + 1):
            try:
                client = get_client()
                result = client.predict(
                    dict={"background": handle_file(p_path), "layers": [], "composite": None},
                    garm_img=handle_file(g_path),
                    garment_des=garment_desc,
                    is_checked=bool(auto_mask),
                    is_checked_crop=bool(crop_center),
                    denoise_steps=int(denoise_steps),
                    seed=int(seed),
                    api_name=API_NAME,
                )
                # The API may return several outputs; the first is the try-on image.
                if isinstance(result, (list, tuple)):
                    result = result[0]
                with Image.open(result) as img:
                    out = img.convert("RGB")
                return out, f"✅ Готово (steps={denoise_steps}, seed={seed}, crop={crop_center})"
            except Exception as e:
                last_err = e
                transient, reason = _classify_remote_error(e)
                if transient:
                    # A stale connection or expired token: reconnect next time.
                    reset_client()
                if attempt < _MAX_ATTEMPTS:
                    time.sleep((4.0 if transient else 1.2) * attempt)
        # Final fail
        if last_err is not None:
            tail = (str(last_err) or type(last_err).__name__)[:240]
        else:
            tail = "unknown error"
        return None, f"❌ Ошибка Space после {_MAX_ATTEMPTS} попыток ({reason}): {tail}"
    finally:
        for path in (p_path, g_path):
            if path is None:
                continue
            try:
                os.remove(path)
            except OSError:
                pass
def reset_ui():
    """Blank values for (person, garment, garment_desc, status, out)."""
    no_person = no_garment = no_result = None
    empty_desc = ""
    idle_status = "Ожидание..."
    return no_person, no_garment, empty_desc, idle_status, no_result
# Extra CSS passed to gr.Blocks(css=...). It hides the footer, the API-info
# element and the Settings button in the browser. This is cosmetic only:
# CSS does not disable anything server-side.
CUSTOM_CSS = """
footer {display:none !important;}
#api-info {display:none !important;}
div[class*="footer"] {display:none !important;}
button[aria-label="Settings"] {display:none !important;}
"""
with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
    gr.Markdown("# Virtual Try-On Rendez-vous")
    with gr.Row():
        # Left column: inputs, settings and controls.
        with gr.Column():
            person = gr.Image(label="Фото человека", type="pil", height=420)
            garment = gr.Image(label="Одежда", type="pil", height=320)
            garment_desc = gr.Textbox(label="Описание одежды", value="")
            with gr.Accordion("Настройки", open=False):
                auto_mask = gr.Checkbox(label="Auto-mask (Space)", value=True)
                crop_center = gr.Checkbox(label="Crop по центру", value=True)
                # Slider ranges match the clamping done in tryon_remote.
                denoise_steps = gr.Slider(10, 40, value=25, step=1, label="Denoise steps")
                seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            with gr.Row():
                run = gr.Button("Примерить", variant="primary")
                reset = gr.Button("Сбросить", variant="secondary")
            status = gr.Textbox(value="Ожидание...", interactive=False)
        # Right column: result image.
        with gr.Column():
            out = gr.Image(label="Результат", type="pil", height=760)
    # Disable the button while a request is in flight, which prevents double
    # submits. Gradio shows its own progress indicator on the outputs. The
    # .then() steps run even if the previous step raised, so the button is
    # always re-enabled.
    run.click(
        fn=lambda: gr.Button(interactive=False),
        inputs=None,
        outputs=run,
        queue=False,
    ).then(
        fn=tryon_remote,
        inputs=[person, garment, garment_desc, auto_mask, crop_center, denoise_steps, seed],
        outputs=[out, status],
    ).then(
        fn=lambda: gr.Button(interactive=True),
        inputs=None,
        outputs=run,
        queue=False,
    )
    # Clear every input plus the status line and the result image.
    reset.click(
        fn=reset_ui,
        inputs=[],
        outputs=[person, garment, garment_desc, status, out],
    )
if __name__ == "__main__":
    # Serve on all interfaces at port 7860. No public share link, SSR off.
    # APP_AUTH is a (user, password) tuple, or None when no credentials are configured.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
        "ssr_mode": False,
        "auth": APP_AUTH,  # ✅ login/password gate
    }
    demo.launch(**launch_options)