| import os | |
| import torch | |
| import gradio as gr | |
| import tempfile | |
| import gc | |
| from dotenv import load_dotenv | |
| from huggingface_hub import hf_hub_download, login | |
| from diffusers import AutoencoderKL, DDPMScheduler | |
| from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection | |
| from promptdresser.models.unet import UNet2DConditionModel | |
| from promptdresser.models.cloth_encoder import ClothEncoder | |
| from promptdresser.pipelines.sdxl import PromptDresser | |
| from lib.caption import generate_caption | |
| from lib.mask import generate_clothing_mask | |
| from lib.pose import generate_openpose | |
# Read variables from a local .env file (if any) into the process environment.
load_dotenv()
TOKEN = os.getenv("HF_TOKEN")

# Authenticate only when a token is configured: login(token=None) falls back to
# an interactive prompt, which blocks or fails in a headless deployment.
if TOKEN:
    login(token=TOKEN)
else:
    print("⚠️ HF_TOKEN не задан — вход в Hugging Face пропущен")

# Run on the GPU when one is available; half precision is used only on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32
def load_models():
    """Load every model component and assemble the try-on pipeline.

    The scheduler, both tokenizers, both text encoders and the UNet come from
    the SDXL inpainting repository; the VAE comes from
    ``madebyollin/sdxl-vae-fp16-fix``; the UNet weights are then replaced by
    the checkpoint downloaded from ``Benrise/VITON-HD``; the cloth encoder is
    loaded from the ``unet`` subfolder of the SDXL base repository. All torch
    modules are moved to the module-level ``device`` / ``weight_dtype``.

    Returns:
        dict: the components under the keys ``"unet"``, ``"vae"``,
        ``"text_encoder"``, ``"text_encoder_2"``, ``"cloth_encoder"``,
        ``"noise_scheduler"``, ``"tokenizer"``, ``"tokenizer_2"`` and the
        assembled ``"pipeline"``.

    Raises:
        Exception: any failure while loading is logged and re-raised as is.
    """
    print("⚙️ Загрузка моделей...")
    base_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    try:
        noise_scheduler = DDPMScheduler.from_pretrained(
            base_repo,
            subfolder="scheduler",
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            base_repo,
            subfolder="tokenizer",
        )
        text_encoder = CLIPTextModel.from_pretrained(
            base_repo,
            subfolder="text_encoder",
        )
        tokenizer_2 = CLIPTokenizer.from_pretrained(
            base_repo,
            subfolder="tokenizer_2",
        )
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            base_repo,
            subfolder="text_encoder_2",
        )
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        unet = UNet2DConditionModel.from_pretrained(
            base_repo,
            subfolder="unet",
        )

        unet_checkpoint_path = hf_hub_download(
            repo_id="Benrise/VITON-HD",
            filename="VITONHD/model/pytorch_model.bin",
            token=TOKEN,
        )
        # NOTE(security): torch.load unpickles the downloaded file; only use
        # checkpoints from a trusted repository.
        # The checkpoint is mapped to CPU because the UNet has not been moved
        # to the target device yet; mapping it straight to the GPU would keep
        # a second full copy of the weights there while they are copied in.
        unet_state = torch.load(unet_checkpoint_path, map_location="cpu")
        unet.load_state_dict(unet_state)
        del unet_state  # release the duplicate weights before moving to device

        cloth_encoder = ClothEncoder.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet",
        )

        models = {
            "unet": unet.to(device, dtype=weight_dtype),
            "vae": vae.to(device, dtype=weight_dtype),
            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
            "noise_scheduler": noise_scheduler,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
        }

        pipeline = PromptDresser(
            vae=models["vae"],
            text_encoder=models["text_encoder"],
            text_encoder_2=models["text_encoder_2"],
            tokenizer=models["tokenizer"],
            tokenizer_2=models["tokenizer_2"],
            unet=models["unet"],
            scheduler=models["noise_scheduler"],
        ).to(device, dtype=weight_dtype)

        print("✅ Модели успешно загружены")
        return {**models, "pipeline": pipeline}
    except Exception as e:
        print(f"❌ Ошибка загрузки моделей: {e}")
        raise
def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
    """Run one virtual try-on generation and free cached memory around it.

    Both images are written to a temporary directory so that the mask, pose
    and caption helpers can read them from disk; the module-level ``pipeline``
    is then called with the original image objects.

    Args:
        person_image: image of the person; must provide ``save(path)``.
        cloth_image: image of the garment; must provide ``save(path)``.
        outfit_prompt: description of the whole outfit. When empty, a caption
            is generated from the person image.
        clothing_prompt: description of the garment. When empty, a caption is
            generated from the garment image.
        label: class index forwarded to ``generate_clothing_mask``.

    Returns:
        The first image produced by the pipeline, or ``None`` when an input
        image is missing or generation fails (the error is printed).
    """
    # Reject missing uploads up front with a clear message instead of failing
    # later with an opaque AttributeError on ``None.save``.
    if person_image is None or cloth_image is None:
        print("❌ Ошибка генерации: не загружено фото человека или одежды")
        return None

    try:
        # Start from a clean slate so leftovers of a previous request do not
        # add to the memory needed for this one.
        torch.cuda.empty_cache()
        gc.collect()

        with tempfile.TemporaryDirectory() as tmp_dir:
            person_path = os.path.join(tmp_dir, "person.png")
            cloth_path = os.path.join(tmp_dir, "cloth.png")
            person_image.save(person_path)
            cloth_image.save(cloth_path)

            mask_image = generate_clothing_mask(person_path, label=label)
            pose_image = generate_openpose(person_path)

            # Captions are generated only for prompts the caller left empty.
            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)

            with torch.autocast(device):
                result = pipeline(
                    image=person_image,
                    mask_image=mask_image,
                    pose_image=pose_image,
                    cloth_encoder=models["cloth_encoder"],
                    cloth_encoder_image=cloth_image,
                    prompt=final_outfit_prompt,
                    prompt_clothing=final_clothing_prompt,
                    height=1024,
                    width=768,
                    guidance_scale=2.0,
                    guidance_scale_img=4.5,
                    guidance_scale_text=7.5,
                    num_inference_steps=30,
                    strength=1,
                    interm_cloth_start_ratio=0.5,
                    generator=None,
                ).images[0]

            return result
    except Exception as e:
        # Best-effort handler: the failure is logged and reported as None.
        print(f"❌ Ошибка генерации: {e}")
        return None
    finally:
        torch.cuda.empty_cache()
        gc.collect()
# Models are loaded once at import time; `models` and `pipeline` are the
# module-level objects that generate_vton reads on every request.
print("🔍 Инициализация моделей...")
models = load_models()
pipeline = models["pipeline"]
# Gradio UI: inputs and examples in the left column, the result on the right,
# followed by a short usage guide.
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
    gr.Markdown("# 🧥 Virtual Try-On")
    gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")

    # Names shown in the dropdown; the list position is the value handed to
    # generate_vton as `label`.
    clothing_classes = [
        "фон", "шляпа", "волосы", "очки", "верхняя одежда", "юбка", "брюки", "платье",
        "ремень", "левая обувь", "правая обувь", "лицо", "левая нога", "правая нога",
        "левая рука", "правая рука", "сумка", "шарф"
    ]
    # (display text, value) pairs for the dropdown.
    label_choices = [
        (f"{index}: {name}", index) for index, name in enumerate(clothing_classes)
    ]
    instruction_steps = (
        "1. Загрузите четкое фото человека в полный рост",
        "2. Загрузите фото одежды на белом фоне",
        "3. Выберите тип одежды из выпадающего списка",
        "4. При необходимости уточните описание образа или одежды",
        "5. Нажмите кнопку 'Сгенерировать примерку'",
    )

    with gr.Row():
        with gr.Column():
            person_input = gr.Image(label="Фото человека", type="pil", sources=["upload"])
            cloth_input = gr.Image(label="Фото одежды", type="pil", sources=["upload"])
            clothing_label = gr.Dropdown(
                choices=label_choices,
                label="Класс одежды для маски",
                value=4,
            )
            outfit_prompt = gr.Textbox(
                label="Описание образа (опционально)",
                placeholder="Например: man in casual outfit",
            )
            clothing_prompt = gr.Textbox(
                label="Описание одежды (опционально)",
                placeholder="Например: red t-shirt with print",
            )
            generate_btn = gr.Button("Сгенерировать примерку", variant="primary")

            gr.Examples(
                examples=[
                    ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 4],
                ],
                inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
                label="Примеры для быстрого тестирования",
            )

        with gr.Column():
            output_image = gr.Image(label="Результат примерки", interactive=False)

    # The button passes the five inputs to generate_vton in its parameter order.
    generate_btn.click(
        fn=generate_vton,
        inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
        outputs=output_image,
    )

    gr.Markdown("### Инструкция:")
    gr.Markdown("\n".join(instruction_steps))
if __name__ == "__main__":
    # Listen on all interfaces only when SPACE_ID is set (presumably a
    # Hugging Face Space — confirm); otherwise keep Gradio's default host.
    host = "0.0.0.0" if os.getenv("SPACE_ID") else None
    # A public share link is opt-in through GRADIO_SHARE=True.
    share_link = os.getenv("GRADIO_SHARE") == "True"

    # The request queue is capped at one entry.
    app = demo.queue(max_size=1)
    app.launch(server_name=host, share=share_link)