# VITON-HD / app.py
# Source: Hugging Face Space "Benrise/VITON-HD"
# Last commit: "Update model loading" (2826a7d), file size ~7.37 kB
import os
import torch
import gradio as gr
import tempfile
import gc
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from promptdresser.models.unet import UNet2DConditionModel
from promptdresser.models.cloth_encoder import ClothEncoder
from promptdresser.pipelines.sdxl import PromptDresser
from lib.caption import generate_caption
from lib.mask import generate_clothing_mask
from lib.pose import generate_openpose
# Read HF_TOKEN (and any other secrets) from a local .env file if present.
load_dotenv()
TOKEN = os.getenv("HF_TOKEN")

# Inference-only process: allow TF32 matmuls and cuDNN autotuning for speed,
# and disable autograd globally.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_grad_enabled(False)

# NOTE(review): these env vars are set after `import torch`; this assumes the
# CUDA allocator / module loader reads them lazily (before the first CUDA
# call below) — confirm for the deployed torch version.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves GPU memory; CPU inference stays in fp32.
weight_dtype = torch.float16 if device == "cuda" else torch.float32

# Local cache location for the fine-tuned VITON-HD UNet checkpoint.
CHECKPOINT_DIR = "./checkpoints/VITONHD/model"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
def load_models():
    """Load every model component and assemble the try-on pipeline.

    Downloads the fine-tuned VITON-HD UNet checkpoint from the HF Hub on
    first run, moves all weight-bearing modules to ``device`` in
    ``weight_dtype``, and builds the PromptDresser SDXL inpainting pipeline.

    Returns:
        dict: the individual sub-models plus the assembled "pipeline".

    Raises:
        Exception: re-raised after logging when any download/load step fails.
    """
    print("⚙️ Загрузка моделей...")
    try:
        noise_scheduler = DDPMScheduler.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            subfolder="scheduler"
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            subfolder="text_encoder"
        )
        tokenizer_2 = CLIPTokenizer.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            subfolder="tokenizer_2"
        )
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            subfolder="text_encoder_2"
        )
        # fp16-safe VAE variant (the stock SDXL VAE overflows in float16).
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        unet = UNet2DConditionModel.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            subfolder="unet"
        )
        checkpoint_path = os.path.join(CHECKPOINT_DIR, "pytorch_model.bin")
        if not os.path.exists(checkpoint_path):
            print("⏳ Загрузка чекпоинта модели...")
            # FIX: `force_filename` is deprecated and ignored by modern
            # huggingface_hub, so the file would NOT land at the path we
            # assumed. Use the path hf_hub_download actually returns.
            checkpoint_path = hf_hub_download(
                repo_id="Benrise/VITON-HD",
                filename="VITONHD/model/pytorch_model.bin",
                token=TOKEN,
                local_dir=CHECKPOINT_DIR,
            )
        # FIX: map_location="cpu" keeps this working on CPU-only hosts when
        # the checkpoint stores CUDA tensors; the UNet is moved to `device`
        # right below anyway.
        unet.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        cloth_encoder = ClothEncoder.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet"
        )
        models = {
            "unet": unet.to(device, dtype=weight_dtype),
            "vae": vae.to(device, dtype=weight_dtype),
            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
            "noise_scheduler": noise_scheduler,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2
        }
        pipeline = PromptDresser(
            vae=models["vae"],
            text_encoder=models["text_encoder"],
            text_encoder_2=models["text_encoder_2"],
            tokenizer=models["tokenizer"],
            tokenizer_2=models["tokenizer_2"],
            unet=models["unet"],
            scheduler=models["noise_scheduler"],
        ).to(device, dtype=weight_dtype)
        print("✅ Модели успешно загружены")
        return {**models, "pipeline": pipeline}
    except Exception as e:
        print(f"❌ Ошибка загрузки моделей: {e}")
        raise
def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt=""):
    """Run one virtual try-on generation, releasing GPU memory afterwards.

    Args:
        person_image: PIL image of the person (from the Gradio input).
        cloth_image: PIL image of the garment.
        outfit_prompt: optional outfit description; when empty a caption is
            auto-generated from the person image.
        clothing_prompt: optional garment description; when empty a caption
            is auto-generated from the cloth image.

    Returns:
        The generated PIL image, or None when input is missing or
        generation fails.
    """
    # Robustness: Gradio passes None when the button is clicked without an
    # uploaded image; `.save()` would raise AttributeError otherwise.
    if person_image is None or cloth_image is None:
        return None
    try:
        torch.cuda.empty_cache()  # no-op when CUDA is uninitialized
        gc.collect()
        with tempfile.TemporaryDirectory() as tmp_dir:
            person_path = os.path.join(tmp_dir, "person.png")
            cloth_path = os.path.join(tmp_dir, "cloth.png")
            person_image.save(person_path)
            cloth_image.save(cloth_path)
            # Auxiliary conditioning: clothing segmentation mask + OpenPose.
            mask_image = generate_clothing_mask(person_path)
            pose_image = generate_openpose(person_path)
            # Fall back to auto-captioning when the user gave no prompt.
            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)
            with torch.autocast(device):
                result = pipeline(
                    image=person_image,
                    mask_image=mask_image,
                    pose_image=pose_image,
                    cloth_encoder=models["cloth_encoder"],
                    cloth_encoder_image=cloth_image,
                    prompt=final_outfit_prompt,
                    prompt_clothing=final_clothing_prompt,
                    height=1024,
                    width=768,
                    guidance_scale=2.0,
                    guidance_scale_img=4.5,
                    guidance_scale_text=7.5,
                    num_inference_steps=30,
                    strength=1,
                    interm_cloth_start_ratio=0.5,
                    generator=None,
                ).images[0]
            return result
    except Exception as e:
        print(f"❌ Ошибка генерации: {e}")
        return None
    finally:
        # Always release VRAM, success or failure.
        torch.cuda.empty_cache()
        gc.collect()
# Eagerly load all models at import time so the Space is ready (or fails
# fast) before the first request reaches the UI.
print("🔍 Инициализация моделей...")
models = load_models()
# Module-level globals read by generate_vton at call time.
pipeline = models["pipeline"]
# --- Gradio UI: two-column layout, inputs on the left, result on the right.
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px}") as demo:
    gr.Markdown("# 🧥 Virtual Try-On")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Входные данные")
            # type="pil" matches generate_vton, which calls `.save(...)`.
            person_input = gr.Image(label="Фото человека", type="pil")
            cloth_input = gr.Image(label="Фото одежды", type="pil")
            outfit_prompt = gr.Textbox(label="Описание образа (необязательно)")
            generate_btn = gr.Button("Сгенерировать", variant="primary")
        with gr.Column():
            gr.Markdown("### Результат")
            output_image = gr.Image(label="Результат примерки")
            gr.Markdown("Подождите 1-2 минуты для генерации")
    # Only three inputs are wired: generate_vton's `clothing_prompt` keeps its
    # default "" and is auto-captioned from the cloth image.
    generate_btn.click(
        fn=generate_vton,
        inputs=[person_input, cloth_input, outfit_prompt],
        outputs=output_image
    )
if __name__ == "__main__":
    # One worker + a queue of 2 keeps the single model instance from OOMing
    # under concurrent requests.
    # NOTE(review): `concurrency_count` is the Gradio 3.x queue API (removed
    # in 4.x) — confirm the pinned gradio version for this Space.
    demo.queue(concurrency_count=1, max_size=2).launch(
        # Bind all interfaces when running inside a HF Space container.
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
        share=os.getenv("GRADIO_SHARE") == "True"
    )