# ArmanRV's picture
# Update app.py
# 602036b verified
# -*- coding: utf-8 -*-
import json
import os
import threading
import time
from datetime import datetime
from typing import List, Optional, Tuple

import spaces
import gradio as gr
from PIL import Image
# =========================
# FIX: gradio 4.24 / gradio_client crashes on boolean JSON Schemas in /api_info
# =========================
def _patch_gradio_client_bool_schema():
    """Make gradio_client's schema helpers tolerate boolean JSON Schemas.

    JSON Schema allows a bare ``true``/``false`` as a schema; per the note
    above, gradio_client's helpers crash on them while building /api_info.
    Each helper that exists in ``gradio_client.utils`` is replaced with a
    wrapper that returns a neutral value for a ``bool`` schema (``"any"`` for
    the type helpers, ``""`` for the description helper) and otherwise
    delegates to the original. Any failure is printed, never raised.
    """
    try:
        import gradio_client.utils as gcu

        def _guard_single(original, fallback):
            # Wrap a one-argument schema helper so bool schemas bypass it.
            def wrapper(schema):
                return fallback if isinstance(schema, bool) else original(schema)

            return wrapper

        patched = []
        for attr, fallback in (("get_type", "any"), ("get_desc", "")):
            if hasattr(gcu, attr):
                setattr(gcu, attr, _guard_single(getattr(gcu, attr), fallback))
                patched.append(attr)

        if hasattr(gcu, "_json_schema_to_python_type"):
            original_json2py = gcu._json_schema_to_python_type

            def json2py_wrapper(schema, defs=None):
                # `defs` is always forwarded positionally (None when omitted).
                if isinstance(schema, bool):
                    return "any"
                return original_json2py(schema, defs)

            gcu._json_schema_to_python_type = json2py_wrapper
            patched.append("_json_schema_to_python_type")

        if patched:
            print("Patched gradio_client.utils for boolean JSON Schemas (/api_info)", flush=True)
        else:
            print("gradio_client patch: nothing to patch (unexpected utils layout)", flush=True)
    except Exception as e:
        print("gradio_client patch failed:", repr(e), flush=True)


_patch_gradio_client_bool_schema()
import torch
from torchvision import transforms
from huggingface_hub import login, snapshot_download, HfApi, hf_hub_download
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
CLIPTextModel,
CLIPTextModelWithProjection,
AutoTokenizer,
)
from diffusers import DDPMScheduler, AutoencoderKL
import apply_net
from utils_mask import get_mask_location
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
# =========================
# Garments catalog (only)
# =========================
# Local folder holding the garment catalog (optionally filled from an HF dataset).
GARMENT_DIR = "garments"
# Accepted garment image suffixes (matched against lower-cased file names).
ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
# Optional HF dataset repo id to download garments from; "" disables downloading.
GARMENTS_DATASET = (os.getenv("GARMENTS_DATASET") or "").strip()
# Optional HF access token; "" means anonymous access.
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip()
def ensure_garments_available() -> None:
    """Ensure the garment catalog folder exists and, if configured, is populated.

    ``GARMENT_DIR`` is always created. If ``GARMENTS_DATASET`` is empty, the
    local ``./garments`` folder is used as-is. Otherwise the HF dataset is
    downloaded into ``GARMENT_DIR`` with ``snapshot_download``, after an
    optional ``login`` when ``HF_TOKEN`` is set. Login and download failures
    are printed (truncated), never raised.
    """
    os.makedirs(GARMENT_DIR, exist_ok=True)
    if not GARMENTS_DATASET:
        print("GARMENTS_DATASET not set. Using local ./garments (if any).", flush=True)
        return

    token = HF_TOKEN or None
    if token:
        try:
            login(token=token, add_to_git_credential=False)
        except Exception as err:
            print("HF login: FAILED:", str(err)[:200], flush=True)
        else:
            print("HF login: OK", flush=True)

    try:
        snapshot_download(
            repo_id=GARMENTS_DATASET,
            repo_type="dataset",
            local_dir=GARMENT_DIR,
            local_dir_use_symlinks=False,
            token=token,
        )
    except Exception as err:
        print("Garments download FAILED:", str(err)[:300], flush=True)
    else:
        print(f"Garments dataset downloaded: {GARMENTS_DATASET} -> {GARMENT_DIR}/", flush=True)
def _gender_subdir(gender: str) -> str:
    """Map a catalog section label to its sub-folder name.

    Only the exact label ``"Мужская"`` selects ``"male"``; every other value
    (``"Женская"``, ``""``, anything unexpected) falls back to ``"female"``.
    """
    if gender != "Мужская":
        return "female"
    return "male"
def list_garments(gender: Optional[str] = None) -> List[str]:
    """Return the sorted catalog image paths, relative to ``GARMENT_DIR``.

    Args:
        gender: Catalog section label. When truthy, only the matching
            sub-folder (see ``_gender_subdir``) is scanned, and a missing
            sub-folder yields an empty list. When falsy, the whole catalog
            is scanned.

    Returns:
        Paths of files whose lower-cased name ends with one of
        ``ALLOWED_EXTS`` and that do not start with ``"."``, sorted.
    """
    if not os.path.isdir(GARMENT_DIR):
        return []
    scan_root = os.path.join(GARMENT_DIR, _gender_subdir(gender)) if gender else GARMENT_DIR
    if not os.path.isdir(scan_root):
        return []
    found = [
        os.path.relpath(os.path.join(dirpath, name), GARMENT_DIR)
        for dirpath, _dirnames, names in os.walk(scan_root)
        for name in names
        if name.lower().endswith(ALLOWED_EXTS) and not name.startswith(".")
    ]
    return sorted(found)
def garment_path(relpath: str) -> str:
    """Resolve a catalog-relative path to a filesystem path under ``GARMENT_DIR``."""
    full_path = os.path.join(GARMENT_DIR, relpath)
    return full_path
def load_garment_pil(relpath: str) -> Optional[Image.Image]:
    """Open a catalog garment as an RGB image.

    Returns ``None`` for an empty ``relpath``, a missing file, or any error
    raised while opening or decoding the image.
    """
    path = garment_path(relpath) if relpath else None
    if path is None or not os.path.exists(path):
        return None
    try:
        # convert() yields an independent image, so the file can be closed here.
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception:
        return None
def build_gallery_items(files: List[str]):
    """Turn catalog-relative paths into ``(path, caption)`` pairs for ``gr.Gallery``.

    Captions are always the empty string.
    """
    return [(path, "") for path in map(garment_path, files)]
# =========================
# Rate limit
# =========================
# Monotonic timestamp of the last accepted call. -inf means "never called", so
# the first call is always accepted whatever the monotonic clock's origin is.
_last_call_ts = float("-inf")


def allow_call(min_interval_sec: float = 2.5) -> Tuple[bool, str]:
    """Process-wide throttle: accept at most one call per ``min_interval_sec``.

    ``time.monotonic()`` is used instead of wall-clock time. With a wall clock,
    a backwards jump (NTP sync, manual change) would reject every call for the
    size of the jump.

    Args:
        min_interval_sec: Minimum spacing between accepted calls, in seconds.

    Returns:
        ``(True, "")`` when the call is accepted (its time is recorded), or
        ``(False, message)`` with a user-facing "please wait" hint otherwise.
    """
    global _last_call_ts
    now = time.monotonic()
    elapsed = now - _last_call_ts
    if elapsed < min_interval_sec:
        return False, f"⏳ Подождите {min_interval_sec - elapsed:.1f} сек."
    _last_call_ts = now
    return True, ""
# =========================
# Model init (BASELINE)
# =========================
# Every IDM-VTON component below is loaded from this Hugging Face repo.
base_path = "yisol/IDM-VTON"
# Import-time device choice: float16 weights on CUDA, float32 on CPU.
# start_tryon re-checks CUDA on every call and moves pipe / unet_encoder there.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
print("DEVICE:", DEVICE, "DTYPE:", DTYPE, flush=True)
# PIL image -> tensor (ToTensor), then Normalize(mean=0.5, std=0.5) maps [0, 1] to [-1, 1].
# The misspelled name is the one start_tryon uses; renaming it here alone breaks it.
tensor_transfrom = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# Try-on UNet (src.unet_hacked_tryon) handed to the pipeline; inference only.
unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=DTYPE)
unet.requires_grad_(False)
# Slow (use_fast=False) tokenizers for the two text encoders.
tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", revision=None, use_fast=False)
tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", revision=None, use_fast=False)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=DTYPE)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=DTYPE)
# CLIP vision encoder passed as image_encoder; presumably encodes the
# ip_adapter_image (the garment) given in start_tryon — TODO confirm.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=DTYPE)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=DTYPE)
# Reference UNet from src.unet_hacked_garmnet; presumably extracts garment
# features for the try-on UNet — TODO confirm.
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=DTYPE)
UNet_Encoder.requires_grad_(False)
# Preprocessors used by start_tryon for the mask; the 0 argument is presumably
# a GPU index — TODO confirm.
parsing_model = Parsing(0)
openpose_model = OpenPose(0)
# Freeze every network (redundant for unet / UNet_Encoder, already frozen above).
for m in [UNet_Encoder, image_encoder, vae, unet, text_encoder_one, text_encoder_two]:
    m.requires_grad_(False)
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=DTYPE,
)
# Attached as a plain attribute (not a from_pretrained argument); start_tryon
# moves it to the GPU via pipe.unet_encoder.
pipe.unet_encoder = UNet_Encoder
# =========================
# Inference (BASELINE params)
# =========================
def _center_crop_box_3x4(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return the integer ``(left, top, right, bottom)`` box of the largest centered 3:4 crop.

    Integer coordinates make ``Image.crop`` and the later ``paste`` use exactly
    the same origin. ``crop`` rounds float boxes while the paste offset was
    truncated with ``int()``, so an odd margin shifted the result by one pixel.
    """
    target_width = int(min(width, height * (3 / 4)))
    target_height = int(min(height, width * (4 / 3)))
    left = (width - target_width) // 2
    top = (height - target_height) // 2
    return left, top, left + target_width, top + target_height


def _densepose_image(human_img: Image.Image, device: str) -> Image.Image:
    """Render the DensePose ``dp_segm`` visualization of ``human_img``, resized to 768x1024.

    Runs detectron2 DensePose through ``apply_net`` on a 384x512 BGR copy of
    the image. The channel axis of the result is then reversed (presumably
    BGR -> RGB) before converting back to PIL.
    """
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
    args = apply_net.create_argument_parser().parse_args(
        (
            "show",
            "./configs/densepose_rcnn_R_50_FPN_s1x.yaml",
            "./ckpt/densepose/model_final_162be9.pkl",
            "dp_segm",
            "-v",
            "--opts",
            "MODEL.DEVICE",
            "cuda" if device == "cuda" else "cpu",
        )
    )
    pose_arr = args.func(args, human_img_arg)
    pose_arr = pose_arr[:, :, ::-1]
    return Image.fromarray(pose_arr).resize((768, 1024))


@spaces.GPU
def start_tryon(
    human_pil: Image.Image,
    garm_img: Image.Image,
) -> Image.Image:
    """Run IDM-VTON upper-body try-on of ``garm_img`` on ``human_pil``.

    Steps:
    1. Center-crop the person photo to 3:4 and resize it to 768x1024.
    2. Build the upper-body inpainting mask from OpenPose keypoints and a
       human parse.
    3. Render a DensePose image.
    4. Run the SDXL try-on pipeline with fixed baseline parameters
       (30 steps, guidance 2.0, strength 1.0, seed 42).
    5. Resize the generated image back to the crop size and paste it into
       an RGB copy of the original photo.

    Returns:
        An RGB image with the same size as ``human_pil``.
    """
    import contextlib  # local: only needed for the CPU no-op autocast context

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    if device == "cuda":
        openpose_model.preprocessor.body_estimation.model.to(device)
        pipe.to(device)
        pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize((768, 1024))
    # convert() returns a copy, so pasting below never modifies the caller's image.
    human_img_orig = human_pil.convert("RGB")

    left, top, right, bottom = _center_crop_box_3x4(*human_img_orig.size)
    cropped_img = human_img_orig.crop((left, top, right, bottom))
    crop_size = cropped_img.size
    human_img = cropped_img.resize((768, 1024))

    # Keypoints + human parsing at 384x512 define the upper-body mask.
    keypoints = openpose_model(human_img.resize((384, 512)))
    model_parse, _ = parsing_model(human_img.resize((384, 512)))
    mask, _ = get_mask_location("hd", "upper_body", model_parse, keypoints)
    mask = mask.resize((768, 1024))

    pose_img = _densepose_image(human_img, device)

    garment_des = "a garment"
    prompt_main = "model is wearing " + garment_des
    prompt_cloth = "a photo of " + garment_des
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
    denoise_steps = 30
    guidance_scale = 2.0
    strength = 1.0
    seed = 42

    # Mixed precision on CUDA only; on CPU everything stays float32.
    autocast_ctx = (
        torch.autocast(device_type="cuda") if device == "cuda" else contextlib.nullcontext()
    )
    with torch.no_grad(), autocast_ctx:
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt_main,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )
        # Garment-branch text embedding (no classifier-free guidance).
        (prompt_embeds_c, _, _, _) = pipe.encode_prompt(
            [prompt_cloth],
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=[negative_prompt],
        )
        pose_t = tensor_transfrom(pose_img).unsqueeze(0).to(device=device, dtype=dtype)
        garm_t = tensor_transfrom(garm_img).unsqueeze(0).to(device=device, dtype=dtype)
        generator = torch.Generator(device).manual_seed(seed)
        images = pipe(
            prompt_embeds=prompt_embeds.to(device=device, dtype=dtype),
            negative_prompt_embeds=negative_prompt_embeds.to(device=device, dtype=dtype),
            pooled_prompt_embeds=pooled_prompt_embeds.to(device=device, dtype=dtype),
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device=device, dtype=dtype),
            num_inference_steps=denoise_steps,
            generator=generator,
            strength=strength,
            pose_img=pose_t,
            text_embeds_cloth=prompt_embeds_c.to(device=device, dtype=dtype),
            cloth=garm_t,
            mask_image=mask,
            image=human_img,
            height=1024,
            width=768,
            ip_adapter_image=garm_img.resize((768, 1024)),
            guidance_scale=guidance_scale,
        )[0]

    out_img_rs = images[0].resize(crop_size)
    human_img_orig.paste(out_img_rs, (left, top))
    return human_img_orig
# =========================
# UI / CSS
# =========================
# Stylesheet passed to gr.Blocks(css=...). It hides the footer, #api-info and
# the Settings button, and defines the feedback styles: .feedback-box
# (panel container), .feedback-ok (success notice) and .feedback-idle (hint
# shown before the user acts).
CUSTOM_CSS = """
footer {display:none !important;}
#api-info {display:none !important;}
div[class*="footer"] {display:none !important;}
button[aria-label="Settings"] {display:none !important;}
.feedback-box {
border: 1px solid #e5e7eb;
border-radius: 14px;
padding: 12px 14px;
background: #fafafa;
margin-top: 10px;
}
.feedback-ok {
border: 1px solid #b7ebc6;
background: #f0fff4;
color: #166534;
border-radius: 12px;
padding: 10px 12px;
font-size: 14px;
margin-top: 8px;
}
.feedback-idle {
border: 1px dashed #d1d5db;
background: #ffffff;
color: #6b7280;
border-radius: 12px;
padding: 10px 12px;
font-size: 14px;
margin-top: 8px;
}
"""
# =========================
# UX example image
# =========================
# Example photo shown under the upload box; if it cannot be loaded the UI hides it.
UX_EXAMPLE_IMG_PATH = "assets/photo_2026-02-26_14-56-24.jpg"


def _load_ux_example_pil() -> Optional[Image.Image]:
    """Load the UX example photo as RGB, or return ``None`` if it is unavailable.

    An empty path or a missing file yields ``None`` silently; open or decode
    errors are printed and also yield ``None``.
    """
    path = UX_EXAMPLE_IMG_PATH
    if not path or not os.path.exists(path):
        return None
    try:
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception as e:
        print("UX example image load failed:", repr(e), flush=True)
        return None


_UX_EXAMPLE_PIL = _load_ux_example_pil()
def refresh_catalog(gender: str):
    """Re-sync the catalog and rebuild the gallery for the chosen section.

    Returns a 5-tuple for the ``gender.change`` outputs:
    1. gallery items;
    2. file list (state);
    3. cleared selection (``None``);
    4. status text;
    5. the selection prompt.
    """
    ensure_garments_available()
    files = list_garments(gender=gender)
    if files:
        status = f"✅ Каталог: {gender} ({len(files)})"
    else:
        status = f"⚠️ Каталог пуст: {gender}"
    return build_gallery_items(files), files, None, status, "👕 Выберите одежду ниже"
def on_gallery_select(files_list: List[str], evt: gr.SelectData):
    """Remember which catalog file was clicked in the gallery.

    ``evt.index`` is clamped into the valid range; ``None`` counts as 0.
    Returns ``(relative_path_or_None, label_markdown)``.
    """
    if not files_list:
        return None, "⚠️ Каталог пуст"
    raw_index = evt.index
    idx = 0 if raw_index is None else int(raw_index)
    last = len(files_list) - 1
    idx = min(max(idx, 0), last)
    return files_list[idx], "👕 Одежда выбрана"
# =========================
# Feedback storage
# =========================
# Local JSONL log of feedback events.
FEEDBACK_DIR = "./feedback"
FEEDBACK_PATH = os.path.join(FEEDBACK_DIR, "feedback.jsonl")
# Optional HF repo mirroring the log; "" disables remote sync.
FEEDBACK_REPO_ID = (os.getenv("FEEDBACK_REPO_ID") or "").strip()
FEEDBACK_REPO_TYPE = "dataset"
# Location of the mirrored log inside FEEDBACK_REPO_ID.
FEEDBACK_REPO_FILEPATH = "feedback/feedback.jsonl"
def _read_local_feedback_text() -> str:
    """Return the full local feedback log, or ``""`` when it does not exist yet."""
    if os.path.exists(FEEDBACK_PATH):
        with open(FEEDBACK_PATH, encoding="utf-8") as fh:
            return fh.read()
    return ""
def _write_local_feedback_text(text: str) -> None:
    """Overwrite the local feedback log with ``text``, creating its folder if needed."""
    os.makedirs(FEEDBACK_DIR, exist_ok=True)
    with open(FEEDBACK_PATH, mode="w", encoding="utf-8") as fh:
        fh.write(text)
def _download_repo_feedback(api: HfApi, token: str) -> str:
    """Fetch the current feedback log from the HF repo.

    Returns ``""`` when ``FEEDBACK_REPO_ID`` is unset, or when the repo does
    not contain the log file yet (first feedback ever). Every other failure
    (network, auth, Hub unreachable) is re-raised. The caller uploads
    "downloaded text + new line" as the complete file, so treating a
    transient error as "empty" would overwrite the remote history with a
    single record.

    ``api`` is accepted for interface compatibility and is not used.
    """
    if not FEEDBACK_REPO_ID:
        return ""
    from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError

    try:
        local_path = hf_hub_download(
            repo_id=FEEDBACK_REPO_ID,
            repo_type=FEEDBACK_REPO_TYPE,
            filename=FEEDBACK_REPO_FILEPATH,
            token=token,
        )
    except LocalEntryNotFoundError:
        # Raised when the Hub cannot be reached. The file may well exist
        # remotely, so this must not be treated as "empty".
        raise
    except EntryNotFoundError:
        # The repo exists but has no log file yet.
        return ""
    with open(local_path, "r", encoding="utf-8") as f:
        return f.read()
def _upload_repo_feedback(api: HfApi, token: str, text: str) -> None:
    """Upload ``text`` as the complete feedback log to the HF repo.

    The content is sent as in-memory UTF-8 bytes. Writing it to a single
    shared temporary file first let overlapping feedback events overwrite
    each other's file before upload, so one event could ship another
    event's (or a half-written) log.

    Raises:
        RuntimeError: if ``FEEDBACK_REPO_ID`` is not set.
    """
    if not FEEDBACK_REPO_ID:
        raise RuntimeError("FEEDBACK_REPO_ID not set")
    api.upload_file(
        path_or_fileobj=text.encode("utf-8"),
        path_in_repo=FEEDBACK_REPO_FILEPATH,
        repo_id=FEEDBACK_REPO_ID,
        repo_type=FEEDBACK_REPO_TYPE,
        token=token,
        commit_message="Add try-on feedback",
    )
# Serializes feedback writes. The like, dislike and comment buttons are
# separate events (each with its own concurrency limit), so they can run at the
# same time and would otherwise interleave the append / download-append-upload
# sequence below and lose records.
_FEEDBACK_LOCK = threading.Lock()


def _append_feedback_record(record: dict) -> None:
    """Append one feedback event to the local JSONL log and mirror it to the HF repo.

    Best effort: failures in either step are printed, never raised (only
    ``json.dumps`` errors propagate). The local step appends a single line
    instead of rewriting the whole file. The remote step re-uploads the repo's
    current log plus the new line. It is skipped when no token
    (``HF_TOKEN`` / ``HUGGINGFACEHUB_API_TOKEN``) or no ``FEEDBACK_REPO_ID`` is
    configured.
    """
    line = json.dumps(record, ensure_ascii=False) + "\n"
    with _FEEDBACK_LOCK:
        try:
            os.makedirs(FEEDBACK_DIR, exist_ok=True)
            with open(FEEDBACK_PATH, "a", encoding="utf-8") as f:
                f.write(line)
        except Exception as e:
            print("Feedback local write failed:", repr(e), flush=True)

        token = (os.getenv("HF_TOKEN", "").strip() or os.getenv("HUGGINGFACEHUB_API_TOKEN", "").strip())
        if not token:
            print("Feedback repo sync skipped: HF_TOKEN not set", flush=True)
            return
        if not FEEDBACK_REPO_ID:
            print("Feedback repo sync skipped: FEEDBACK_REPO_ID not set", flush=True)
            return
        try:
            api = HfApi()
            repo_text = _download_repo_feedback(api, token)
            _upload_repo_feedback(api, token, repo_text + line)
        except Exception as e:
            print("Feedback repo sync failed:", repr(e), flush=True)
def save_rating_feedback(is_like: bool, garment_name: str) -> None:
    """Record a like/dislike for ``garment_name`` as a ``"rating"`` event.

    The timestamp is naive UTC in ISO-8601 form; a missing garment name is
    stored as ``""``.
    """
    _append_feedback_record(
        dict(
            timestamp=datetime.utcnow().isoformat(),
            event="rating",
            like=bool(is_like),
            garment=garment_name or "",
        )
    )
def save_comment_feedback(garment_name: str, comment: str) -> None:
    """Record a free-text comment for ``garment_name`` as a ``"comment"`` event.

    The comment is stripped and truncated to its first 1500 characters.
    """
    max_len = 1500
    text = (comment or "").strip()[:max_len]
    _append_feedback_record(
        {
            "timestamp": datetime.utcnow().isoformat(),
            "event": "comment",
            "garment": garment_name or "",
            "comment": text,
        }
    )
# =========================
# Feedback UI helpers
# =========================
def _rating_notice_idle_html():
    """Return the neutral hint shown in the rating area before the user votes."""
    body = "Оцените результат: нажмите лайк или дизлайк."
    return '\n<div class="feedback-idle">\n' + body + "\n</div>\n"
def _comment_notice_idle_html():
    """Return the neutral hint shown under the comment box before sending."""
    lines = [
        "",
        '<div class="feedback-idle">',
        "При желании напишите комментарий и нажмите кнопку отправки.",
        "</div>",
        "",
    ]
    return "\n".join(lines)
def _rating_notice_ok_html(action_text: str):
    """Return the success notice confirming that the rating ``action_text`` was saved.

    ``action_text`` is inserted verbatim (not HTML-escaped).
    """
    template = """
<div class="feedback-ok">
✅ Ваша оценка <b style="color:#166534;">{}</b> сохранена.
</div>
"""
    return template.format(action_text)
def _comment_notice_ok_html():
    """Return the success notice shown after a comment has been saved."""
    return "".join(
        (
            "\n",
            '<div class="feedback-ok">\n',
            "✅ Комментарий отправлен. Спасибо за обратную связь.\n",
            "</div>\n",
        )
    )
# =========================
# Feedback actions
# =========================
def _submit_rating(is_like: bool, garment_name: str, label: str):
    """Shared body of the like/dislike handlers.

    Returns ``(status_text, rating_notice_html)``, matching the ``status`` and
    ``rating_notice`` outputs. ``label`` is the quoted rating name shown to
    the user.
    """
    if not garment_name:
        return "⚠️ Сначала выполните примерку и выберите одежду", _rating_notice_idle_html()
    try:
        save_rating_feedback(is_like, garment_name)
        return f"✅ Оценка сохранена: {label}", _rating_notice_ok_html(label)
    except Exception as e:
        return (
            f"❌ Ошибка сохранения оценки: {type(e).__name__}: {str(e)[:200]}",
            _rating_notice_idle_html(),
        )


def submit_like_feedback(garment_name: str):
    """Handle the 👍 button: store a positive rating for ``garment_name``."""
    return _submit_rating(True, garment_name, "«Нравится»")


def submit_dislike_feedback(garment_name: str):
    """Handle the 👎 button: store a negative rating for ``garment_name``."""
    return _submit_rating(False, garment_name, "«Не нравится»")
def submit_comment_feedback(garment_name: str, comment: str):
    """Handle the "send comment" button.

    Returns ``(status_text, comment_notice_html, textbox_update)``. The
    textbox is cleared only after a successful save. On a validation error or
    a failed save, it keeps the user's original text.
    """
    def keep_text(status_text):
        # Error / validation outcome: idle notice, textbox left untouched.
        return status_text, _comment_notice_idle_html(), gr.update(value=comment)

    if not garment_name:
        return keep_text("⚠️ Сначала выполните примерку и выберите одежду")
    text = (comment or "").strip()
    if not text:
        return keep_text("⚠️ Напишите комментарий перед отправкой")
    try:
        save_comment_feedback(garment_name, text)
        return "✅ Комментарий отправлен", _comment_notice_ok_html(), gr.update(value="")
    except Exception as e:
        return keep_text(f"❌ Ошибка отправки комментария: {type(e).__name__}: {str(e)[:200]}")
# =========================
# Try-on UI
# =========================
# Cosmetic progress captions streamed before inference starts, one per delay.
TRYON_PROGRESS_MESSAGES = (
    "🧵 Анализируем посадку ткани…",
    "📏 Подбираем пропорции одежды…",
    "🧍 Определяем позу и положение тела…",
    "📐 Выравниваем геометрию одежды…",
    "🎨 Сохраняем цвет и фактуру ткани…",
    "🪡 Прорисовываем швы и детали…",
    "🧶 Адаптируем складки ткани…",
    "🧥 Корректируем посадку на фигуре…",
    "✨ Улучшаем освещение и тени…",
    "🔬 Уточняем текстуру материала…",
    "🔍 Проверяем мелкие детали…",
    "🎯 Финальная корректировка результата…",
    "🌟 Последние штрихи…",
    "🪄 Почти готово…",
)
# Pause after each progress caption, in seconds.
TRYON_PROGRESS_DELAY_SEC = 2.3


def _tryon_outputs(image, message, show_feedback=False):
    """Build one ``tryon_ui`` update.

    The 6-tuple holds: result image, status text, feedback-panel visibility,
    cleared comment box, idle rating notice and idle comment notice.
    """
    return (
        image,
        message,
        gr.update(visible=show_feedback),
        gr.update(value=""),
        gr.update(value=_rating_notice_idle_html()),
        gr.update(value=_comment_notice_idle_html()),
    )


def tryon_ui(person_pil, selected_filename):
    """Streaming handler of the "Примерить" button.

    Inputs are validated first, so problems are reported immediately instead
    of after the ~30 s progress animation. The checks, in order: person photo,
    selected garment, the process-wide rate limit, then loading the garment
    file. Then the cosmetic progress captions are streamed, and finally
    ``start_tryon`` runs. Every yield is the 6-tuple from ``_tryon_outputs``;
    the feedback panel is made visible only alongside a successful result.
    """
    if person_pil is None:
        yield _tryon_outputs(None, "❌ Загрузите фото человека")
        return
    if not selected_filename:
        yield _tryon_outputs(None, "❌ Выберите одежду (клик по превью)")
        return
    ok, msg = allow_call(2.5)
    if not ok:
        yield _tryon_outputs(None, msg)
        return
    garm = load_garment_pil(selected_filename)
    if garm is None:
        yield _tryon_outputs(None, "❌ Не удалось загрузить выбранную одежду")
        return

    for caption in TRYON_PROGRESS_MESSAGES:
        yield _tryon_outputs(None, caption)
        time.sleep(TRYON_PROGRESS_DELAY_SEC)

    try:
        out_img = start_tryon(human_pil=person_pil, garm_img=garm)
    except Exception as e:
        yield _tryon_outputs(None, f"❌ Ошибка: {type(e).__name__}: {str(e)[:220]}")
        return
    yield _tryon_outputs(
        out_img,
        "✅ Готово — оцените результат и при желании оставьте комментарий",
        show_feedback=True,
    )
# =========================
# Boot
# =========================
ensure_garments_available()
_default_gender = "Женская"
_initial_files = list_garments(gender=_default_gender)
_initial_items = build_gallery_items(_initial_files)

with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
    gr.Markdown("# Virtual Try-On Rendez-vous")
    # Catalog files of the current section, index-aligned with the gallery items.
    garment_files_state = gr.State(_initial_files)
    # Garment currently selected in the gallery (None until a click).
    selected_garment_state = gr.State(None)
    # Garment captured when "Примерить" is clicked. Feedback is attributed to
    # this one, so selecting another gallery item after a try-on does not make
    # a like/dislike/comment refer to a garment that was never tried on.
    tried_garment_state = gr.State(None)
    with gr.Row():
        with gr.Column():
            person = gr.Image(label="Фото человека", type="pil", height=420)
            gr.Markdown("""
### Какое фото подойдёт
✔ В полный рост или по пояс
✔ Одежда по фигуре
✔ Стоите прямо, смотрите в камеру
✔ Руки и предметы не закрывают тело
✔ Хороший свет, без резких теней
✔ В кадре только вы
""")
            gr.Image(
                value=_UX_EXAMPLE_PIL,
                label="",
                show_label=False,
                interactive=False,
                height=340,
                visible=bool(_UX_EXAMPLE_PIL),
            )
            gender = gr.Radio(
                choices=["Женская", "Мужская"],
                value=_default_gender,
                label="Раздел каталога",
            )
            selected_label = gr.Markdown("👕 Выберите одежду ниже")
            garment_gallery = gr.Gallery(
                label="Одежда для примерки",
                value=_initial_items,
                columns=4,
                height=340,
                allow_preview=True,
            )
            run = gr.Button("Примерить", variant="primary")
            status = gr.Textbox(value="Ожидание...", interactive=False, show_label=False)
        with gr.Column():
            out = gr.Image(label="Результат", type="pil", height=760)
            # Hidden until tryon_ui reports a successful result. The .feedback-box
            # style is attached via elem_classes. Two separate gr.HTML components
            # cannot open and close one wrapping <div>, because each renders as
            # its own element.
            with gr.Column(visible=False, elem_classes=["feedback-box"]) as feedback_box:
                gr.Markdown("### Оценка результата")
                with gr.Row():
                    like_btn = gr.Button("👍 Нравится")
                    dislike_btn = gr.Button("👎 Не нравится")
                rating_notice = gr.HTML(_rating_notice_idle_html())
                gr.Markdown("### Комментарий")
                feedback_comment = gr.Textbox(
                    label="",
                    placeholder="Напишите, что понравилось или что стоит улучшить...",
                    lines=3,
                    max_lines=6,
                    show_label=False,
                )
                submit_comment_btn = gr.Button("Отправить комментарий")
                comment_notice = gr.HTML(_comment_notice_idle_html())

    gender.change(
        fn=refresh_catalog,
        inputs=[gender],
        outputs=[garment_gallery, garment_files_state, selected_garment_state, status, selected_label],
    )
    garment_gallery.select(
        fn=on_gallery_select,
        inputs=[garment_files_state],
        outputs=[selected_garment_state, selected_label],
    )
    # Capture the selection at click time (instant, not queued), then run the
    # streamed try-on one request at a time.
    run.click(
        fn=lambda name: name,
        inputs=[selected_garment_state],
        outputs=[tried_garment_state],
        queue=False,
    ).then(
        fn=tryon_ui,
        inputs=[person, selected_garment_state],
        outputs=[out, status, feedback_box, feedback_comment, rating_notice, comment_notice],
        concurrency_limit=1,
    )
    like_btn.click(
        fn=submit_like_feedback,
        inputs=[tried_garment_state],
        outputs=[status, rating_notice],
        concurrency_limit=1,
    )
    dislike_btn.click(
        fn=submit_dislike_feedback,
        inputs=[tried_garment_state],
        outputs=[status, rating_notice],
        concurrency_limit=1,
    )
    submit_comment_btn.click(
        fn=submit_comment_feedback,
        inputs=[tried_garment_state, feedback_comment],
        outputs=[status, comment_notice, feedback_comment],
        concurrency_limit=1,
    )
# The event queue holds at most 20 pending requests.
demo.queue(max_size=20)

if __name__ == "__main__":
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_threads=4,
        show_error=True,
        show_api=False,
    )
    demo.launch(**launch_options)