import gc
import os
import shutil
import tempfile

import gradio as gr
import torch
from diffusers import AutoencoderKL, DDPMScheduler
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

from lib.caption import generate_caption
from lib.mask import generate_clothing_mask
from lib.pose import generate_openpose
from promptdresser.models.cloth_encoder import ClothEncoder
from promptdresser.models.unet import UNet2DConditionModel
from promptdresser.pipelines.sdxl import PromptDresser
| load_dotenv() | |
| TOKEN = os.getenv("HF_TOKEN") | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.benchmark = True | |
| torch.set_grad_enabled(False) | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" | |
| os.environ["CUDA_MODULE_LOADING"] = "LAZY" | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| weight_dtype = torch.float16 if device == "cuda" else torch.float32 | |
| CHECKPOINT_DIR = "./checkpoints/VITONHD/model" | |
| os.makedirs(CHECKPOINT_DIR, exist_ok=True) | |
| def load_models(): | |
| """Загружает все необходимые модели""" | |
| print("⚙️ Загрузка моделей...") | |
| try: | |
| noise_scheduler = DDPMScheduler.from_pretrained( | |
| "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", | |
| subfolder="scheduler" | |
| ) | |
| tokenizer = CLIPTokenizer.from_pretrained( | |
| "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", | |
| subfolder="tokenizer" | |
| ) | |
| text_encoder = CLIPTextModel.from_pretrained( | |
| "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", | |
| subfolder="text_encoder" | |
| ) | |
| tokenizer_2 = CLIPTokenizer.from_pretrained( | |
| "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", | |
| subfolder="tokenizer_2" | |
| ) | |
| text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( | |
| "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", | |
| subfolder="text_encoder_2" | |
| ) | |
| vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix") | |
| unet = UNet2DConditionModel.from_pretrained( | |
| "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", | |
| subfolder="unet" | |
| ) | |
| checkpoint_path = os.path.join(CHECKPOINT_DIR, "pytorch_model.bin") | |
| if not os.path.exists(checkpoint_path): | |
| print("⏳ Загрузка чекпоинта модели...") | |
| hf_hub_download( | |
| repo_id="Benrise/VITON-HD", | |
| filename="VITONHD/model/pytorch_model.bin", | |
| token=TOKEN, | |
| local_dir=CHECKPOINT_DIR, | |
| force_filename="pytorch_model.bin" | |
| ) | |
| unet.load_state_dict(torch.load(checkpoint_path)) | |
| cloth_encoder = ClothEncoder.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| subfolder="unet" | |
| ) | |
| models = { | |
| "unet": unet.to(device, dtype=weight_dtype), | |
| "vae": vae.to(device, dtype=weight_dtype), | |
| "text_encoder": text_encoder.to(device, dtype=weight_dtype), | |
| "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype), | |
| "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype), | |
| "noise_scheduler": noise_scheduler, | |
| "tokenizer": tokenizer, | |
| "tokenizer_2": tokenizer_2 | |
| } | |
| pipeline = PromptDresser( | |
| vae=models["vae"], | |
| text_encoder=models["text_encoder"], | |
| text_encoder_2=models["text_encoder_2"], | |
| tokenizer=models["tokenizer"], | |
| tokenizer_2=models["tokenizer_2"], | |
| unet=models["unet"], | |
| scheduler=models["noise_scheduler"], | |
| ).to(device, dtype=weight_dtype) | |
| print("✅ Модели успешно загружены") | |
| return {**models, "pipeline": pipeline} | |
| except Exception as e: | |
| print(f"❌ Ошибка загрузки моделей: {e}") | |
| raise | |
| def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt=""): | |
| """Генерация виртуальной примерки с очисткой памяти""" | |
| try: | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| with tempfile.TemporaryDirectory() as tmp_dir: | |
| person_path = os.path.join(tmp_dir, "person.png") | |
| cloth_path = os.path.join(tmp_dir, "cloth.png") | |
| person_image.save(person_path) | |
| cloth_image.save(cloth_path) | |
| mask_image = generate_clothing_mask(person_path) | |
| pose_image = generate_openpose(person_path) | |
| final_outfit_prompt = outfit_prompt or generate_caption(person_path, device) | |
| final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device) | |
| with torch.autocast(device): | |
| result = pipeline( | |
| image=person_image, | |
| mask_image=mask_image, | |
| pose_image=pose_image, | |
| cloth_encoder=models["cloth_encoder"], | |
| cloth_encoder_image=cloth_image, | |
| prompt=final_outfit_prompt, | |
| prompt_clothing=final_clothing_prompt, | |
| height=1024, | |
| width=768, | |
| guidance_scale=2.0, | |
| guidance_scale_img=4.5, | |
| guidance_scale_text=7.5, | |
| num_inference_steps=30, | |
| strength=1, | |
| interm_cloth_start_ratio=0.5, | |
| generator=None, | |
| ).images[0] | |
| return result | |
| except Exception as e: | |
| print(f"❌ Ошибка генерации: {e}") | |
| return None | |
| finally: | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| print("🔍 Инициализация моделей...") | |
| models = load_models() | |
| pipeline = models["pipeline"] | |
| with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px}") as demo: | |
| gr.Markdown("# 🧥 Virtual Try-On") | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("### Входные данные") | |
| person_input = gr.Image(label="Фото человека", type="pil") | |
| cloth_input = gr.Image(label="Фото одежды", type="pil") | |
| outfit_prompt = gr.Textbox(label="Описание образа (необязательно)") | |
| generate_btn = gr.Button("Сгенерировать", variant="primary") | |
| with gr.Column(): | |
| gr.Markdown("### Результат") | |
| output_image = gr.Image(label="Результат примерки") | |
| gr.Markdown("Подождите 1-2 минуты для генерации") | |
| generate_btn.click( | |
| fn=generate_vton, | |
| inputs=[person_input, cloth_input, outfit_prompt], | |
| outputs=output_image | |
| ) | |
| if __name__ == "__main__": | |
| demo.queue(concurrency_count=1, max_size=2).launch( | |
| server_name="0.0.0.0" if os.getenv("SPACE_ID") else None, | |
| share=os.getenv("GRADIO_SHARE") == "True" | |
| ) |