# -*- coding: utf-8 -*-
"""Gradio demo for IDM-VTON virtual try-on with a fixed garment catalog.

Flow: the user uploads a photo, picks a garment from a gendered catalog
(local folder or HF dataset), and the baseline IDM-VTON pipeline inpaints
the garment onto the person. Ratings/comments are appended to a local
JSONL file and optionally synced to a HF dataset repo.
"""
import contextlib
import json
import os
import time
from datetime import datetime
from typing import List, Optional, Tuple

import spaces
import gradio as gr
from PIL import Image


# =========================
# FIX: gradio 4.24 / gradio_client crashes on boolean JSON Schemas in /api_info
# =========================
def _patch_gradio_client_bool_schema():
    """Monkey-patch gradio_client.utils so boolean JSON Schemas don't crash /api_info.

    JSON Schema allows a schema to be the literal ``True``/``False``; older
    gradio_client versions assume a dict and raise. Each patched helper
    short-circuits boolean schemas and otherwise defers to the original.
    Best-effort: any failure is logged and ignored.
    """
    try:
        import gradio_client.utils as gcu

        patched_any = False

        if hasattr(gcu, "get_type"):
            _orig_get_type = gcu.get_type

            def _get_type_patched(schema):
                if isinstance(schema, bool):
                    return "any"
                return _orig_get_type(schema)

            gcu.get_type = _get_type_patched
            patched_any = True

        if hasattr(gcu, "get_desc"):
            _orig_get_desc = gcu.get_desc

            def _get_desc_patched(schema):
                if isinstance(schema, bool):
                    return ""
                return _orig_get_desc(schema)

            gcu.get_desc = _get_desc_patched
            patched_any = True

        if hasattr(gcu, "_json_schema_to_python_type"):
            _orig_json2py = gcu._json_schema_to_python_type

            def _json_schema_to_python_type_patched(schema, defs=None):
                if isinstance(schema, bool):
                    return "any"
                return _orig_json2py(schema, defs)

            gcu._json_schema_to_python_type = _json_schema_to_python_type_patched
            patched_any = True

        if patched_any:
            print("Patched gradio_client.utils for boolean JSON Schemas (/api_info)", flush=True)
        else:
            print("gradio_client patch: nothing to patch (unexpected utils layout)", flush=True)
    except Exception as e:
        print("gradio_client patch failed:", repr(e), flush=True)


# Must run before gradio builds its API info.
_patch_gradio_client_bool_schema()

import torch
from torchvision import transforms
from huggingface_hub import login, snapshot_download, HfApi, hf_hub_download

from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    AutoTokenizer,
)
from diffusers import DDPMScheduler, AutoencoderKL

import apply_net
from utils_mask import get_mask_location
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation

# =========================
# Garments catalog (only)
# =========================
GARMENT_DIR = "garments"
ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".webp")
GARMENTS_DATASET = os.getenv("GARMENTS_DATASET", "").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()


def ensure_garments_available() -> None:
    """Ensure the garment catalog exists locally.

    If GARMENTS_DATASET is not set, the local ./garments folder is used as-is.
    Otherwise the HF dataset is downloaded into ./garments (logging in first
    when HF_TOKEN is available). All failures are logged, never raised.
    """
    os.makedirs(GARMENT_DIR, exist_ok=True)
    if not GARMENTS_DATASET:
        print("GARMENTS_DATASET not set. Using local ./garments (if any).", flush=True)
        return
    if HF_TOKEN:
        try:
            login(token=HF_TOKEN, add_to_git_credential=False)
            print("HF login: OK", flush=True)
        except Exception as e:
            print("HF login: FAILED:", str(e)[:200], flush=True)
    try:
        snapshot_download(
            repo_id=GARMENTS_DATASET,
            repo_type="dataset",
            local_dir=GARMENT_DIR,
            local_dir_use_symlinks=False,
            token=HF_TOKEN if HF_TOKEN else None,
        )
        print(f"Garments dataset downloaded: {GARMENTS_DATASET} -> {GARMENT_DIR}/", flush=True)
    except Exception as e:
        print("Garments download FAILED:", str(e)[:300], flush=True)


def _gender_subdir(gender: str) -> str:
    """Map the UI radio value ('Мужская'/'Женская') to a catalog subfolder."""
    return "male" if gender == "Мужская" else "female"


def list_garments(gender: Optional[str] = None) -> List[str]:
    """List garment image files, sorted, as paths relative to GARMENT_DIR.

    With a gender, only that subfolder is scanned (empty list if it does not
    exist); otherwise the whole catalog is walked. Hidden files and
    non-image extensions are skipped.
    """
    if not os.path.isdir(GARMENT_DIR):
        return []
    base = GARMENT_DIR
    if gender:
        base = os.path.join(GARMENT_DIR, _gender_subdir(gender))
        if not os.path.isdir(base):
            return []
    files: List[str] = []
    for root, _, fnames in os.walk(base):
        for f in fnames:
            if f.lower().endswith(ALLOWED_EXTS) and not f.startswith("."):
                # Relative to GARMENT_DIR (not base) so paths round-trip
                # through garment_path() for both branches.
                files.append(os.path.relpath(os.path.join(root, f), GARMENT_DIR))
    files.sort()
    return files


def garment_path(relpath: str) -> str:
    """Absolute-ish path of a catalog-relative garment file."""
    return os.path.join(GARMENT_DIR, relpath)


def load_garment_pil(relpath: str) -> Optional[Image.Image]:
    """Load a garment image as RGB; return None on any failure."""
    if not relpath:
        return None
    path = garment_path(relpath)
    if not os.path.exists(path):
        return None
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return None


def build_gallery_items(files: List[str]):
    """Convert relative paths into (path, caption) pairs for gr.Gallery."""
    return [(garment_path(f), "") for f in files]


# =========================
# Rate limit
# =========================
_last_call_ts = 0.0


def allow_call(min_interval_sec: float = 2.5) -> Tuple[bool, str]:
    """Simple global rate limit: allow at most one call per interval.

    Returns (True, "") when allowed (and records the call time), otherwise
    (False, user-facing wait message). Not thread-safe by design; the UI
    serializes calls via concurrency_limit=1.
    """
    global _last_call_ts
    now = time.time()
    if now - _last_call_ts < min_interval_sec:
        wait = min_interval_sec - (now - _last_call_ts)
        return False, f"⏳ Подождите {wait:.1f} сек."
    _last_call_ts = now
    return True, ""


# =========================
# Model init (BASELINE)
# =========================
base_path = "yisol/IDM-VTON"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
print("DEVICE:", DEVICE, "DTYPE:", DTYPE, flush=True)

# Normalize PIL images to [-1, 1] tensors for the pipeline.
tensor_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=DTYPE)
unet.requires_grad_(False)
tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", revision=None, use_fast=False)
tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", revision=None, use_fast=False)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=DTYPE)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=DTYPE)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=DTYPE)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=DTYPE)
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=DTYPE)
UNet_Encoder.requires_grad_(False)

parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# Inference only: freeze everything.
for m in [UNet_Encoder, image_encoder, vae, unet, text_encoder_one, text_encoder_two]:
    m.requires_grad_(False)

pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=DTYPE,
)
pipe.unet_encoder = UNet_Encoder


# =========================
# Inference (BASELINE params)
# =========================
@spaces.GPU
def start_tryon(
    human_pil: Image.Image,
    garm_img: Image.Image,
) -> Image.Image:
    """Run the baseline IDM-VTON try-on and return the composited photo.

    The person photo is center-cropped to 3:4, resized to 768x1024, masked
    (upper body) via parsing + OpenPose, conditioned on a DensePose render,
    and inpainted with the garment. The generated crop is pasted back into
    the original full-size photo, which is returned (mutated in place).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    if device == "cuda":
        openpose_model.preprocessor.body_estimation.model.to(device)
        pipe.to(device)
        pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = human_pil.convert("RGB")

    # Center-crop to 3:4 aspect; remember geometry to paste the result back.
    width, height = human_img_orig.size
    target_width = int(min(width, height * (3 / 4)))
    target_height = int(min(height, width * (4 / 3)))
    left = (width - target_width) / 2
    top = (height - target_height) / 2
    right = (width + target_width) / 2
    bottom = (height + target_height) / 2
    cropped_img = human_img_orig.crop((left, top, right, bottom))
    crop_size = cropped_img.size
    human_img = cropped_img.resize((768, 1024))

    # Pose keypoints and human parsing run at 384x512; mask upscaled to 768x1024.
    keypoints = openpose_model(human_img.resize((384, 512)))
    model_parse, _ = parsing_model(human_img.resize((384, 512)))
    mask, _ = get_mask_location("hd", "upper_body", model_parse, keypoints)
    mask = mask.resize((768, 1024))

    # DensePose conditioning image via detectron2's apply_net CLI entry point.
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
    args = apply_net.create_argument_parser().parse_args(
        (
            "show",
            "./configs/densepose_rcnn_R_50_FPN_s1x.yaml",
            "./ckpt/densepose/model_final_162be9.pkl",
            "dp_segm",
            "-v",
            "--opts",
            "MODEL.DEVICE",
            "cuda" if device == "cuda" else "cpu",
        )
    )
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # BGR -> RGB
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    # Baseline prompts and sampling parameters (fixed seed for reproducibility).
    garment_des = "a garment"
    prompt_main = "model is wearing " + garment_des
    prompt_cloth = "a photo of " + garment_des
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
    denoise_steps = 30
    guidance_scale = 2.0
    strength = 1.0
    seed = 42

    with torch.no_grad():
        # fp16 autocast on GPU; no-op context on CPU.
        autocast_ctx = torch.cuda.amp.autocast() if device == "cuda" else contextlib.nullcontext()
        with autocast_ctx:
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                prompt_main,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Cloth-branch embedding: no CFG, so only the positive embeds are used.
            (prompt_embeds_c, _, _, _) = pipe.encode_prompt(
                [prompt_cloth],
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=[negative_prompt],
            )

            pose_t = tensor_transform(pose_img).unsqueeze(0).to(device=device, dtype=dtype)
            garm_t = tensor_transform(garm_img).unsqueeze(0).to(device=device, dtype=dtype)
            generator = torch.Generator(device).manual_seed(seed)

            images = pipe(
                prompt_embeds=prompt_embeds.to(device=device, dtype=dtype),
                negative_prompt_embeds=negative_prompt_embeds.to(device=device, dtype=dtype),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device=device, dtype=dtype),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device=device, dtype=dtype),
                num_inference_steps=denoise_steps,
                generator=generator,
                strength=strength,
                pose_img=pose_t,
                text_embeds_cloth=prompt_embeds_c.to(device=device, dtype=dtype),
                cloth=garm_t,
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=guidance_scale,
            )[0]

    # Paste the generated crop back into the original full-size photo.
    out_img = images[0]
    out_img_rs = out_img.resize(crop_size)
    human_img_orig.paste(out_img_rs, (int(left), int(top)))
    return human_img_orig


# =========================
# UI / CSS
# =========================
CUSTOM_CSS = """
footer {display:none !important;}
#api-info {display:none !important;}
div[class*="footer"] {display:none !important;}
button[aria-label="Settings"] {display:none !important;}
.feedback-box { border: 1px solid #e5e7eb; border-radius: 14px; padding: 12px 14px; background: #fafafa; margin-top: 10px; }
.feedback-ok { border: 1px solid #b7ebc6; background: #f0fff4; color: #166534; border-radius: 12px; padding: 10px 12px; font-size: 14px; margin-top: 8px; }
.feedback-idle { border: 1px dashed #d1d5db; background: #ffffff; color: #6b7280; border-radius: 12px; padding: 10px 12px; font-size: 14px; margin-top: 8px; }
"""

# =========================
# UX example image
# =========================
UX_EXAMPLE_IMG_PATH = "assets/photo_2026-02-26_14-56-24.jpg"


def _load_ux_example_pil() -> Optional[Image.Image]:
    """Load the example person photo for the UI, or None if unavailable."""
    try:
        if UX_EXAMPLE_IMG_PATH and os.path.exists(UX_EXAMPLE_IMG_PATH):
            return Image.open(UX_EXAMPLE_IMG_PATH).convert("RGB")
    except Exception as e:
        print("UX example image load failed:", repr(e), flush=True)
    return None


_UX_EXAMPLE_PIL = _load_ux_example_pil()


def refresh_catalog(gender: str):
    """Re-sync and reload the catalog for the selected gender.

    Returns (gallery items, file list, cleared selection, status text,
    selection label) matching the gender.change() outputs.
    """
    ensure_garments_available()
    files = list_garments(gender=gender)
    items = build_gallery_items(files)
    status = f"✅ Каталог: {gender} ({len(files)})" if files else f"⚠️ Каталог пуст: {gender}"
    return items, files, None, status, "👕 Выберите одежду ниже"


def on_gallery_select(files_list: List[str], evt: gr.SelectData):
    """Map a gallery click to the selected garment's relative path."""
    if not files_list:
        return None, "⚠️ Каталог пуст"
    idx = int(evt.index) if evt.index is not None else 0
    idx = max(0, min(idx, len(files_list) - 1))  # clamp against stale indices
    return files_list[idx], "👕 Одежда выбрана"


# =========================
# Feedback storage
# =========================
FEEDBACK_DIR = "./feedback"
FEEDBACK_PATH = os.path.join(FEEDBACK_DIR, "feedback.jsonl")
FEEDBACK_REPO_ID = os.getenv("FEEDBACK_REPO_ID", "").strip()
FEEDBACK_REPO_TYPE = "dataset"
FEEDBACK_REPO_FILEPATH = "feedback/feedback.jsonl"


def _read_local_feedback_text() -> str:
    """Read the local JSONL feedback file; empty string if missing."""
    if not os.path.exists(FEEDBACK_PATH):
        return ""
    with open(FEEDBACK_PATH, "r", encoding="utf-8") as f:
        return f.read()


def _write_local_feedback_text(text: str) -> None:
    """Overwrite the local JSONL feedback file (creating the folder)."""
    os.makedirs(FEEDBACK_DIR, exist_ok=True)
    with open(FEEDBACK_PATH, "w", encoding="utf-8") as f:
        f.write(text)


def _download_repo_feedback(api: HfApi, token: str) -> str:
    """Fetch the remote feedback file contents; empty string on any failure."""
    if not FEEDBACK_REPO_ID:
        return ""
    try:
        local_path = hf_hub_download(
            repo_id=FEEDBACK_REPO_ID,
            repo_type=FEEDBACK_REPO_TYPE,
            filename=FEEDBACK_REPO_FILEPATH,
            token=token,
        )
        with open(local_path, "r", encoding="utf-8") as f:
            return f.read()
    except Exception:
        return ""


def _upload_repo_feedback(api: HfApi, token: str, text: str) -> None:
    """Upload the full feedback text to the HF dataset repo."""
    if not FEEDBACK_REPO_ID:
        raise RuntimeError("FEEDBACK_REPO_ID not set")
    os.makedirs(FEEDBACK_DIR, exist_ok=True)
    tmp_path = os.path.join(FEEDBACK_DIR, "_feedback_upload.jsonl")
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(text)
    api.upload_file(
        path_or_fileobj=tmp_path,
        path_in_repo=FEEDBACK_REPO_FILEPATH,
        repo_id=FEEDBACK_REPO_ID,
        repo_type=FEEDBACK_REPO_TYPE,
        token=token,
        commit_message="Add try-on feedback",
    )


def _append_feedback_record(record: dict) -> None:
    """Append one JSON record locally and best-effort sync the repo copy.

    NOTE(review): the repo sync is read-modify-write and not atomic — with
    concurrent writers, records can be lost. Acceptable while
    concurrency_limit=1 serializes the UI callbacks.
    """
    line = json.dumps(record, ensure_ascii=False) + "\n"
    try:
        local_text = _read_local_feedback_text()
        local_text += line
        _write_local_feedback_text(local_text)
    except Exception as e:
        print("Feedback local write failed:", repr(e), flush=True)
    try:
        token = (os.getenv("HF_TOKEN", "").strip()
                 or os.getenv("HUGGINGFACEHUB_API_TOKEN", "").strip())
        if not token:
            print("Feedback repo sync skipped: HF_TOKEN not set", flush=True)
            return
        if not FEEDBACK_REPO_ID:
            print("Feedback repo sync skipped: FEEDBACK_REPO_ID not set", flush=True)
            return
        api = HfApi()
        repo_text = _download_repo_feedback(api, token)
        repo_text += line
        _upload_repo_feedback(api, token, repo_text)
    except Exception as e:
        print("Feedback repo sync failed:", repr(e), flush=True)


def save_rating_feedback(is_like: bool, garment_name: str) -> None:
    """Persist a like/dislike rating for the given garment."""
    # NOTE(review): datetime.utcnow() is deprecated in 3.12+; kept to preserve
    # the existing timestamp format in feedback.jsonl.
    record = {
        "timestamp": datetime.utcnow().isoformat(),
        "event": "rating",
        "like": bool(is_like),
        "garment": garment_name or "",
    }
    _append_feedback_record(record)


def save_comment_feedback(garment_name: str, comment: str) -> None:
    """Persist a free-text comment (trimmed, capped at 1500 chars)."""
    clean_comment = (comment or "").strip()
    if len(clean_comment) > 1500:
        clean_comment = clean_comment[:1500]
    record = {
        "timestamp": datetime.utcnow().isoformat(),
        "event": "comment",
        "garment": garment_name or "",
        "comment": clean_comment,
    }
    _append_feedback_record(record)


# =========================
# Feedback UI helpers
# =========================
# NOTE(review): the <div> wrappers below were reconstructed from the
# .feedback-idle/.feedback-ok CSS classes — confirm against the original markup.
def _rating_notice_idle_html():
    """Idle hint shown under the rating buttons."""
    return """<div class="feedback-idle">Оцените результат: нажмите лайк или дизлайк.</div>"""


def _comment_notice_idle_html():
    """Idle hint shown under the comment box."""
    return """<div class="feedback-idle">При желании напишите комментарий и нажмите кнопку отправки.</div>"""


def _rating_notice_ok_html(action_text: str):
    """Confirmation banner after a rating is stored."""
    return f"""<div class="feedback-ok">✅ Ваша оценка {action_text} сохранена.</div>"""


def _comment_notice_ok_html():
    """Confirmation banner after a comment is stored."""
    return """<div class="feedback-ok">✅ Комментарий отправлен. Спасибо за обратную связь.</div>"""


# =========================
# Feedback actions
# =========================
def submit_like_feedback(garment_name: str):
    """Handle the 👍 button; returns (status text, rating notice HTML)."""
    if not garment_name:
        return "⚠️ Сначала выполните примерку и выберите одежду", _rating_notice_idle_html()
    try:
        save_rating_feedback(True, garment_name)
        return "✅ Оценка сохранена: «Нравится»", _rating_notice_ok_html("«Нравится»")
    except Exception as e:
        return (
            f"❌ Ошибка сохранения оценки: {type(e).__name__}: {str(e)[:200]}",
            _rating_notice_idle_html(),
        )


def submit_dislike_feedback(garment_name: str):
    """Handle the 👎 button; returns (status text, rating notice HTML)."""
    if not garment_name:
        return "⚠️ Сначала выполните примерку и выберите одежду", _rating_notice_idle_html()
    try:
        save_rating_feedback(False, garment_name)
        return "✅ Оценка сохранена: «Не нравится»", _rating_notice_ok_html("«Не нравится»")
    except Exception as e:
        return (
            f"❌ Ошибка сохранения оценки: {type(e).__name__}: {str(e)[:200]}",
            _rating_notice_idle_html(),
        )


def submit_comment_feedback(garment_name: str, comment: str):
    """Handle the comment button; returns (status, notice HTML, textbox update).

    The textbox is cleared only on success so the user doesn't lose text
    when validation or storage fails.
    """
    clean_comment = (comment or "").strip()
    if not garment_name:
        return (
            "⚠️ Сначала выполните примерку и выберите одежду",
            _comment_notice_idle_html(),
            gr.update(value=comment),
        )
    if not clean_comment:
        return (
            "⚠️ Напишите комментарий перед отправкой",
            _comment_notice_idle_html(),
            gr.update(value=comment),
        )
    try:
        save_comment_feedback(garment_name, clean_comment)
        return (
            "✅ Комментарий отправлен",
            _comment_notice_ok_html(),
            gr.update(value=""),
        )
    except Exception as e:
        return (
            f"❌ Ошибка отправки комментария: {type(e).__name__}: {str(e)[:200]}",
            _comment_notice_idle_html(),
            gr.update(value=comment),
        )


# =========================
# Try-on UI
# =========================
def _tryon_update(img, msg, show_feedback=False):
    """One UI update tuple for tryon_ui's outputs:
    (result image, status, feedback box visibility, comment box,
    rating notice, comment notice).
    """
    return (
        img,
        msg,
        gr.update(visible=show_feedback),
        gr.update(value=""),
        gr.update(value=_rating_notice_idle_html()),
        gr.update(value=_comment_notice_idle_html()),
    )


def tryon_ui(person_pil, selected_filename):
    """Generator callback for the 'Примерить' button.

    Streams cosmetic progress messages first (the real work is synchronous
    and happens afterwards), then validates inputs, runs the try-on, and
    reveals the feedback box on success.
    """
    # FIX: the original list was missing a comma after the 🔬 message, so
    # two messages were silently concatenated into one.
    for msg in [
        "🧵 Анализируем посадку ткани…",
        "📏 Подбираем пропорции одежды…",
        "🧍 Определяем позу и положение тела…",
        "📐 Выравниваем геометрию одежды…",
        "🎨 Сохраняем цвет и фактуру ткани…",
        "🪡 Прорисовываем швы и детали…",
        "🧶 Адаптируем складки ткани…",
        "🧥 Корректируем посадку на фигуре…",
        "✨ Улучшаем освещение и тени…",
        "🔬 Уточняем текстуру материала…",
        "🔍 Проверяем мелкие детали…",
        "🎯 Финальная корректировка результата…",
        "🌟 Последние штрихи…",
        "🪄 Почти готово…",
    ]:
        yield _tryon_update(None, msg)
        time.sleep(2.3)

    ok, msg = allow_call(2.5)
    if not ok:
        yield _tryon_update(None, msg)
        return
    if person_pil is None:
        yield _tryon_update(None, "❌ Загрузите фото человека")
        return
    if not selected_filename:
        yield _tryon_update(None, "❌ Выберите одежду (клик по превью)")
        return

    garm = load_garment_pil(selected_filename)
    if garm is None:
        yield _tryon_update(None, "❌ Не удалось загрузить выбранную одежду")
        return

    try:
        out_img = start_tryon(human_pil=person_pil, garm_img=garm)
        yield _tryon_update(
            out_img,
            "✅ Готово — оцените результат и при желании оставьте комментарий",
            show_feedback=True,
        )
    except Exception as e:
        yield _tryon_update(None, f"❌ Ошибка: {type(e).__name__}: {str(e)[:220]}")


# =========================
# Boot
# =========================
ensure_garments_available()
_default_gender = "Женская"
_initial_files = list_garments(gender=_default_gender)
_initial_items = build_gallery_items(_initial_files)

with gr.Blocks(title="Virtual Try-On Rendez-vous", css=CUSTOM_CSS) as demo:
    gr.Markdown("# Virtual Try-On Rendez-vous")

    garment_files_state = gr.State(_initial_files)
    selected_garment_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            person = gr.Image(label="Фото человека", type="pil", height=420)
            gr.Markdown("""
### Какое фото подойдёт
✔ В полный рост или по пояс
✔ Одежда по фигуре
✔ Стоите прямо, смотрите в камеру
✔ Руки и предметы не закрывают тело
✔ Хороший свет, без резких теней
✔ В кадре только вы
""")
            gr.Image(
                value=_UX_EXAMPLE_PIL,
                label="",
                show_label=False,
                interactive=False,
                height=340,
                visible=bool(_UX_EXAMPLE_PIL),
            )
            gender = gr.Radio(
                choices=["Женская", "Мужская"],
                value=_default_gender,
                label="Раздел каталога",
            )
            selected_label = gr.Markdown("👕 Выберите одежду ниже")
            garment_gallery = gr.Gallery(
                label="Одежда для примерки",
                value=_initial_items,
                columns=4,
                height=340,
                allow_preview=True,
            )
            run = gr.Button("Примерить", variant="primary")
            status = gr.Textbox(value="Ожидание...", interactive=False, show_label=False)
        with gr.Column():
            out = gr.Image(label="Результат", type="pil", height=760)
            with gr.Column(visible=False) as feedback_box:
                # NOTE(review): wrapper markup reconstructed to match the
                # .feedback-box CSS class — confirm against the original.
                gr.HTML('<div class="feedback-box">')
                gr.Markdown("### Оценка результата")
                with gr.Row():
                    like_btn = gr.Button("👍 Нравится")
                    dislike_btn = gr.Button("👎 Не нравится")
                rating_notice = gr.HTML(_rating_notice_idle_html())
                gr.Markdown("### Комментарий")
                feedback_comment = gr.Textbox(
                    label="",
                    placeholder="Напишите, что понравилось или что стоит улучшить...",
                    lines=3,
                    max_lines=6,
                    show_label=False,
                )
                submit_comment_btn = gr.Button("Отправить комментарий")
                comment_notice = gr.HTML(_comment_notice_idle_html())
                gr.HTML("</div>")

    gender.change(
        fn=refresh_catalog,
        inputs=[gender],
        outputs=[garment_gallery, garment_files_state, selected_garment_state, status, selected_label],
    )
    garment_gallery.select(
        fn=on_gallery_select,
        inputs=[garment_files_state],
        outputs=[selected_garment_state, selected_label],
    )
    run.click(
        fn=tryon_ui,
        inputs=[person, selected_garment_state],
        outputs=[out, status, feedback_box, feedback_comment, rating_notice, comment_notice],
        concurrency_limit=1,
    )
    like_btn.click(
        fn=submit_like_feedback,
        inputs=[selected_garment_state],
        outputs=[status, rating_notice],
        concurrency_limit=1,
    )
    dislike_btn.click(
        fn=submit_dislike_feedback,
        inputs=[selected_garment_state],
        outputs=[status, comment_notice] if False else [status, rating_notice],
        concurrency_limit=1,
    )
    submit_comment_btn.click(
        fn=submit_comment_feedback,
        inputs=[selected_garment_state, feedback_comment],
        outputs=[status, comment_notice, feedback_comment],
        concurrency_limit=1,
    )

demo.queue(max_size=20)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_threads=4,
        show_error=True,
        show_api=False,
    )