# Author: ArmanRV — commit bf8aada (verified): "Update app.py"
# -*- coding: utf-8 -*-
import os
import time
import tempfile
from typing import Optional, Tuple, List, Dict
import gradio as gr
from PIL import Image
from gradio_client import Client, handle_file
from huggingface_hub import login
# ----------------------------
# Remote Space (IDM-VTON)
# ----------------------------
SPACE = "yisol/IDM-VTON"
API_NAME = "/tryon"
# ----------------------------
# Auth for company demo (no HF accounts needed)
# Credentials are read from the DEMO_USER / DEMO_PASS secrets (env vars).
# Never write the actual values in this file: anyone who can read the source
# would see them. Auth is enabled only when BOTH are set.
# ----------------------------
DEMO_USER = os.getenv("DEMO_USER", "").strip()
DEMO_PASS = os.getenv("DEMO_PASS", "").strip()
APP_AUTH = (DEMO_USER, DEMO_PASS) if (DEMO_USER and DEMO_PASS) else None
if bool(DEMO_USER) != bool(DEMO_PASS):
    # A half-configured pair silently disables the login screen; make it visible.
    print("WARNING: only one of DEMO_USER / DEMO_PASS is set; auth is DISABLED")
# ----------------------------
# Garment catalog folder in repo
# ----------------------------
# Folder (relative to the current working directory) holding catalog images.
GARMENT_DIR = "garments"
# Accepted image extensions; list_garments matches them case-insensitively.
ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
def list_garments() -> List[str]:
    """Return the sorted image filenames found in ``GARMENT_DIR``.

    Only regular files whose lower-cased name ends with one of
    ``ALLOWED_EXTS`` are listed. Hidden files (leading ".") are skipped.
    A missing or unreadable folder yields an empty list, so the UI can report
    an empty catalog instead of crashing.
    """
    try:
        names = os.listdir(GARMENT_DIR)
    except OSError:
        # Folder missing / not a directory / permission denied -> empty catalog.
        return []
    return sorted(
        name
        for name in names
        if not name.startswith(".")
        and name.lower().endswith(ALLOWED_EXTS)
        # Skip sub-directories that merely look like images (e.g. "set.png/").
        and os.path.isfile(os.path.join(GARMENT_DIR, name))
    )
def garment_path(filename: str) -> str:
    """Map a bare catalog filename to its path inside the garment folder."""
    catalog_root = GARMENT_DIR
    full_path = os.path.join(catalog_root, filename)
    return full_path
def load_garment_pil(filename: str) -> Optional[Image.Image]:
    """Load a catalog garment image as an RGB PIL image.

    Returns None when ``filename`` is empty, the file is missing, or it
    cannot be opened or decoded as an image. This is best-effort: callers
    turn None into a user-facing error message.
    """
    if not filename:
        return None
    path = garment_path(filename)
    try:
        # ``with`` releases the file handle deterministically; convert()
        # forces a full decode before the handle is closed.
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception:
        # Missing file (FileNotFoundError), corrupt or unsupported image, ...
        return None
def build_gallery_items(files: List[str]):
    """Convert catalog filenames into ``(path, caption)`` pairs for the gallery.

    Captions are left empty.
    """
    items = []
    for name in files:
        items.append((garment_path(name), ""))
    return items
# ----------------------------
# HF token (optional)
# ----------------------------
# Stripped like DEMO_USER / DEMO_PASS: stray whitespace or a trailing newline
# pasted into the secret would otherwise make login() reject the token.
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
print("HF_TOKEN set:", bool(HF_TOKEN), "len:", len(HF_TOKEN) if HF_TOKEN else 0)
if HF_TOKEN:
    try:
        login(token=HF_TOKEN, add_to_git_credential=False)
        print("HF login: OK")
    except Exception as e:
        # Best-effort: the token is optional, so a failed login must not stop
        # the app from starting. Only a truncated message is logged.
        print("HF login: FAILED:", str(e)[:200])
else:
    print("HF login: skipped (no token in env)")
# ----------------------------
# Helpers
# ----------------------------
def save_pil_temp(pil_img: Image.Image, suffix: str = ".png") -> str:
    """Write ``pil_img`` to a new temporary file and return its path.

    The file is created with ``delete=False``, so it outlives this call; the
    caller must remove it. The image format is taken from ``suffix``
    (e.g. ".jpg" -> JPEG) so the content matches the extension. Unknown
    suffixes fall back to PNG, which is also what the default ".png" gives.

    Raises whatever ``pil_img.save`` raises. In that case the temp file is
    removed first, so nothing leaks.
    """
    fmt = Image.registered_extensions().get(suffix.lower(), "PNG")
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    path = handle.name
    try:
        # Save through the open handle: no close/reopen window between
        # creating the file and writing it.
        with handle:
            pil_img.save(handle, format=fmt)
    except Exception:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path
# ----------------------------
# Simple global rate limit (anti spam)
# NOTE: global across all users. Good enough for internal demo.
# ----------------------------
# time.monotonic() value of the last accepted call. -inf guarantees that the
# very first call is allowed whatever the monotonic clock's starting point is.
_last_call_ts = float("-inf")


def allow_call(min_interval_sec: float = 3.0) -> Tuple[bool, str]:
    """Allow at most one call per ``min_interval_sec`` seconds, globally.

    Returns ``(True, "")`` and records the call time when enough time has
    passed since the last accepted call. Otherwise returns
    ``(False, message)`` with a user-facing message giving the remaining wait;
    rejected calls do not reset the timer.

    Uses ``time.monotonic()`` so wall-clock changes (NTP sync, manual edits)
    cannot lock users out or bypass the limit.
    NOTE(review): no lock is taken, so two truly concurrent handlers could both
    pass. Acceptable if the event runs with concurrency 1 — confirm.
    """
    global _last_call_ts
    now = time.monotonic()
    elapsed = now - _last_call_ts
    if elapsed < min_interval_sec:
        wait = max(0.0, min_interval_sec - elapsed)
        return False, f"⏳ Слишком часто. Подождите {wait:.1f} сек."
    _last_call_ts = now
    return True, ""
def make_client_from_request(request: gr.Request) -> Client:
    """Create a gradio_client ``Client`` for ``SPACE`` on behalf of the caller.

    When the incoming request carries an ``x-ip-token`` header, it is
    forwarded so the downstream ZeroGPU Space can apply its quota to the end
    user rather than treating the call as unauthenticated. Any problem while
    reading the header just means no header is forwarded. If the installed
    gradio_client does not accept ``headers=`` (TypeError), a plain client is
    created instead.
    """
    forwarded: Dict[str, str] = {}
    try:
        # Header keys are looked up in lowercase, as Gradio normalizes them.
        token = request.headers.get("x-ip-token")
        if token:
            forwarded["x-ip-token"] = token
    except Exception:
        pass

    try:
        if forwarded:
            return Client(SPACE, headers=forwarded)
        return Client(SPACE)
    except TypeError:
        # Older gradio_client without the ``headers`` keyword.
        return Client(SPACE)
# ----------------------------
# Core inference (remote call)
# ----------------------------
# Total number of attempts against the remote Space before giving up.
_TRYON_MAX_ATTEMPTS = 6

# Lower-cased substrings marking an error as transient (timeout / busy Space),
# which earns a longer back-off. "timed out" also covers
# "read/write operation timed out".
_TRANSIENT_MARKERS = (
    "timed out",
    "too many requests",
    "queue",
    "too busy",
    "overloaded",
    "capacity",
)


def _classify_space_error(err: Exception) -> str:
    """Classify a remote-call failure as "quota", "transient" or "other"."""
    msg_l = str(err).lower()
    # Checked first: quota exhaustion is unlikely to clear within the retry
    # window, even if the message also mentions "queue"/"capacity".
    if "quota" in msg_l and "zerogpu" in msg_l:
        return "quota"
    if any(marker in msg_l for marker in _TRANSIENT_MARKERS):
        return "transient"
    return "other"


def tryon_remote(person_pil, garment_filename, request: gr.Request):
    """Run IDM-VTON try-on on the remote Space; return ``(image, status)``.

    Args:
        person_pil: Uploaded person photo (PIL image) or None.
        garment_filename: Catalog filename selected in the gallery, or None.
        request: Incoming Gradio request. Its ``x-ip-token`` header is
            forwarded by ``make_client_from_request``.

    Returns:
        ``(PIL.Image, "✅ Готово")`` on success. Otherwise ``(None, message)``
        with a user-facing status message. Exceptions from the remote call
        are caught and retried up to ``_TRYON_MAX_ATTEMPTS`` times.
    """
    # Validate inputs before rate limiting, so a click without a photo or
    # garment does not consume the user's rate-limit slot.
    if person_pil is None:
        return None, "❌ Загрузите фото человека"
    if not garment_filename:
        return None, "❌ Выберите одежду (кликните на превью)"
    garment_pil = load_garment_pil(garment_filename)
    if garment_pil is None:
        return None, "❌ Не удалось загрузить выбранную одежду (проверьте garments/)"

    ok, msg = allow_call(3.0)
    if not ok:
        return None, msg

    # Fixed params for simple demo
    garment_desc = "a photo of a garment"
    auto_mask = True
    crop_center = True
    denoise_steps = 25
    seed = 42

    temp_paths: List[str] = []
    try:
        # Both temp files are created inside ``try`` so ``finally`` removes
        # whichever ones exist, even if the second save fails.
        p_path = save_pil_temp(person_pil)
        temp_paths.append(p_path)
        g_path = save_pil_temp(garment_pil)
        temp_paths.append(g_path)

        last_err: Optional[Exception] = None
        for attempt in range(1, _TRYON_MAX_ATTEMPTS + 1):
            try:
                # A fresh client per attempt, presumably so a session broken
                # by a failed attempt is not reused.
                client = make_client_from_request(request)
                result = client.predict(
                    dict={"background": handle_file(p_path), "layers": [], "composite": None},
                    garm_img=handle_file(g_path),
                    garment_des=garment_desc,
                    is_checked=bool(auto_mask),
                    is_checked_crop=bool(crop_center),
                    denoise_steps=int(denoise_steps),
                    seed=int(seed),
                    api_name=API_NAME,
                )
                if isinstance(result, (list, tuple)):
                    # Only the first output is used as the result image.
                    result = result[0]
                with Image.open(result) as img:
                    out = img.convert("RGB")
                return out, "✅ Готово"
            except Exception as e:
                last_err = e
                kind = _classify_space_error(e)
                if kind == "quota":
                    return None, (
                        "⚠️ Лимит ZeroGPU на стороне модели исчерпан для текущего пользователя.\n"
                        "Попробуйте позже или используйте меньше попыток подряд."
                    )
                # Linear back-off, longer for timeouts / busy Space. No sleep
                # after the final attempt: nothing is left to wait for.
                if attempt < _TRYON_MAX_ATTEMPTS:
                    delay = 4.0 if kind == "transient" else 1.2
                    time.sleep(delay * attempt)

        tail = str(last_err)[:240] if last_err is not None else "unknown error"
        return None, f"❌ Ошибка Space после {_TRYON_MAX_ATTEMPTS} попыток: {tail}"
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass
# ----------------------------
# UI helpers
# ----------------------------
def refresh_catalog():
    """Re-scan the garment folder and clear the current selection.

    Returns a 4-tuple in the order of the refresh button's outputs:
    (gallery items, filename list for state, cleared selection, status text).
    """
    catalog = list_garments()
    gallery_items = build_gallery_items(catalog)
    if catalog:
        message = "✅ Каталог обновлён"
    else:
        message = "⚠️ В папке garments/ пока нет изображений"
    return gallery_items, catalog, None, message
def on_gallery_select(files: List[str], evt: gr.SelectData):
    """Translate a gallery click into ``(selected filename, label text)``.

    ``files`` is the catalog list kept in state, in the same order the gallery
    was built from. A missing index means 0, and out-of-range indices are
    clamped. Any failure yields ``(None, warning)``.
    """
    if not files:
        return None, "⚠️ Каталог пуст"
    try:
        raw_index = evt.index
        pos = 0 if raw_index is None else int(raw_index)
        pos = min(max(pos, 0), len(files) - 1)
        chosen = files[pos]
        return chosen, f"👕 Выбрано: {chosen}"
    except Exception:
        return None, "⚠️ Не удалось выбрать одежду"
# ----------------------------
# UI
# ----------------------------
# Hides Gradio's footer, API info link and settings button.
CUSTOM_CSS = """
footer {display:none !important;}
#api-info {display:none !important;}
div[class*="footer"] {display:none !important;}
button[aria-label="Settings"] {display:none !important;}
"""
# Prompt shown when no garment is selected (initially and after a refresh).
_SELECT_PROMPT = "👕 Выберите одежду, кликнув по превью ниже"

initial_files = list_garments()
initial_items = build_gallery_items(initial_files)

with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
    gr.Markdown("# Virtual Try-On Rendez-vous")
    # Catalog filenames in gallery order, and the currently selected filename.
    garment_files_state = gr.State(initial_files)
    selected_garment_state = gr.State(None)
    with gr.Row():
        with gr.Column():
            person = gr.Image(label="Фото человека", type="pil", height=420)
            with gr.Row():
                refresh_btn = gr.Button("🔄 Обновить каталог", variant="secondary")
            selected_label = gr.Markdown(_SELECT_PROMPT)
            garment_gallery = gr.Gallery(
                label="Каталог одежды (кликните на превью)",
                value=initial_items,
                columns=4,
                height=340,
                allow_preview=True,
            )
            run = gr.Button("Примерить", variant="primary")
            status = gr.Textbox(value="Ожидание...", interactive=False)
        with gr.Column():
            out = gr.Image(label="Результат", type="pil", height=760)

    garment_gallery.select(
        fn=on_gallery_select,
        inputs=[garment_files_state],
        outputs=[selected_garment_state, selected_label],
    )
    refresh_btn.click(
        fn=refresh_catalog,
        inputs=[],
        outputs=[garment_gallery, garment_files_state, selected_garment_state, status],
    ).then(
        # refresh_catalog clears the selected garment. Reset the label too so
        # it does not keep naming a garment that is no longer selected.
        fn=lambda: _SELECT_PROMPT,
        inputs=None,
        outputs=[selected_label],
    )
    run.click(
        fn=tryon_remote,
        inputs=[person, selected_garment_state],
        outputs=[out, status],
    )
if __name__ == "__main__":
    # Listen on all interfaces, port 7860. Login is required only when
    # APP_AUTH is not None (both DEMO_USER and DEMO_PASS set).
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        ssr_mode=False,
        auth=APP_AUTH,
    )
    demo.launch(**launch_options)