Benrise committed
Commit 2826a7d · Parent: 9b63413

Update model loading

Files changed (4)
  1. .env.example +1 -0
  2. .gitignore +2 -1
  3. app.py +150 -108
  4. requirements.txt +2 -1
.env.example ADDED
@@ -0,0 +1 @@
+HF_TOKEN=...
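
The template is consumed by `python-dotenv` at app start-up. A minimal sketch of the intended flow, assuming a filled-in `.env` next to `app.py` (the `whoami` check is illustrative, not part of this commit):

```python
import os

from dotenv import load_dotenv
from huggingface_hub import whoami

load_dotenv()                  # reads .env from the working directory
token = os.getenv("HF_TOKEN")  # same variable as in .env.example

# Illustrative sanity check: confirms the token can reach the Hub
# before any gated/private download is attempted.
print(whoami(token=token)["name"])
```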
.gitignore CHANGED
@@ -1,2 +1,3 @@
 __pycache__
-checkpoints
+checkpoints
+.env
app.py CHANGED
@@ -2,7 +2,9 @@ import os
 import torch
 import gradio as gr
 import tempfile
-from huggingface_hub import hf_hub_download
+import gc
+from dotenv import load_dotenv
+from huggingface_hub import hf_hub_download, login
 from diffusers import AutoencoderKL, DDPMScheduler
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 
@@ -13,138 +15,178 @@ from lib.caption import generate_caption
 from lib.mask import generate_clothing_mask
 from lib.pose import generate_openpose
 
+load_dotenv()
+TOKEN = os.getenv("HF_TOKEN")
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+torch.set_grad_enabled(False)
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+os.environ["CUDA_MODULE_LOADING"] = "LAZY"
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 weight_dtype = torch.float16 if device == "cuda" else torch.float32
 
+CHECKPOINT_DIR = "./checkpoints"
+os.makedirs(CHECKPOINT_DIR, exist_ok=True)
+
 def load_models():
+    """Loads all required models."""
     print("⚙️ Loading models...")
 
-    noise_scheduler = DDPMScheduler.from_pretrained(
-        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
-        subfolder="scheduler"
-    )
-    tokenizer = CLIPTokenizer.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="tokenizer")
-    text_encoder = CLIPTextModel.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="text_encoder")
-    tokenizer_2 = CLIPTokenizer.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="tokenizer_2")
-    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="text_encoder_2")
-    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
-    unet = UNet2DConditionModel.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="unet")
-    cloth_encoder = ClothEncoder.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
-
-    unet_checkpoint_path = hf_hub_download(
-        repo_id="Benrise/VITON-HD",
-        filename="VITONHD/model/pytorch_model.bin",
-        cache_dir="checkpoints"
-    )
-    unet.load_state_dict(torch.load(unet_checkpoint_path))
-
-    models = {
-        "unet": unet.to(device, dtype=weight_dtype),
-        "vae": vae.to(device, dtype=weight_dtype),
-        "text_encoder": text_encoder.to(device, dtype=weight_dtype),
-        "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
-        "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
-        "noise_scheduler": noise_scheduler,
-        "tokenizer": tokenizer,
-        "tokenizer_2": tokenizer_2
-    }
-
-    pipeline = PromptDresser(
-        vae=models["vae"],
-        text_encoder=models["text_encoder"],
-        text_encoder_2=models["text_encoder_2"],
-        tokenizer=models["tokenizer"],
-        tokenizer_2=models["tokenizer_2"],
-        unet=models["unet"],
-        scheduler=models["noise_scheduler"],
-    ).to(device, dtype=weight_dtype)
-
-    return {**models, "pipeline": pipeline}
-
-models = load_models()
-pipeline = models["pipeline"]
-
-def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt=""):
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        person_path = os.path.join(tmp_dir, "person.png")
-        cloth_path = os.path.join(tmp_dir, "cloth.png")
-
-        person_image.save(person_path)
-        cloth_image.save(cloth_path)
-
-        mask_path = os.path.join(tmp_dir, "mask.png")
-        pose_path = os.path.join(tmp_dir, "pose.png")
-
-        mask_image = generate_clothing_mask(person_path, label=4, output_path=mask_path, show_result=False)
-        pose_image = generate_openpose(person_path, output_image_path=pose_path, show_result=False)
-
-        auto_outfit_prompt = generate_caption(person_path, device)
-        auto_clothing_prompt = generate_caption(cloth_path, device)
-
-        final_outfit_prompt = outfit_prompt or auto_outfit_prompt
-        final_clothing_prompt = clothing_prompt or auto_clothing_prompt
-
-        with torch.autocast(device):
-            result = pipeline(
-                image=person_image,
-                mask_image=mask_image,
-                pose_image=pose_image,
-                cloth_encoder=models["cloth_encoder"],
-                cloth_encoder_image=cloth_image,
-                prompt=final_outfit_prompt,
-                prompt_clothing=final_clothing_prompt,
-                height=1024,
-                width=768,
-                guidance_scale=2.0,
-                guidance_scale_img=4.5,
-                guidance_scale_text=7.5,
-                num_inference_steps=30,
-                strength=1,
-                interm_cloth_start_ratio=0.5,
-                generator=None,
-            ).images[0]
-
-        return result
-
-with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
+    try:
+        noise_scheduler = DDPMScheduler.from_pretrained(
+            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+            subfolder="scheduler"
+        )
+
+        tokenizer = CLIPTokenizer.from_pretrained(
+            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+            subfolder="tokenizer"
+        )
+
+        text_encoder = CLIPTextModel.from_pretrained(
+            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+            subfolder="text_encoder"
+        )
+
+        tokenizer_2 = CLIPTokenizer.from_pretrained(
+            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+            subfolder="tokenizer_2"
+        )
+
+        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+            subfolder="text_encoder_2"
+        )
+
+        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
+        unet = UNet2DConditionModel.from_pretrained(
+            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+            subfolder="unet"
+        )
+
+        # hf_hub_download with local_dir keeps the repo-relative layout, so the
+        # file lands at CHECKPOINT_DIR/VITONHD/model/pytorch_model.bin.
+        checkpoint_path = os.path.join(CHECKPOINT_DIR, "VITONHD", "model", "pytorch_model.bin")
+        if not os.path.exists(checkpoint_path):
+            print("⏳ Downloading model checkpoint...")
+            checkpoint_path = hf_hub_download(
+                repo_id="Benrise/VITON-HD",
+                filename="VITONHD/model/pytorch_model.bin",
+                token=TOKEN,
+                local_dir=CHECKPOINT_DIR
+            )
+
+        unet.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
+
+        cloth_encoder = ClothEncoder.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            subfolder="unet"
+        )
+
+        models = {
+            "unet": unet.to(device, dtype=weight_dtype),
+            "vae": vae.to(device, dtype=weight_dtype),
+            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
+            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
+            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
+            "noise_scheduler": noise_scheduler,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2
+        }
+
+        pipeline = PromptDresser(
+            vae=models["vae"],
+            text_encoder=models["text_encoder"],
+            text_encoder_2=models["text_encoder_2"],
+            tokenizer=models["tokenizer"],
+            tokenizer_2=models["tokenizer_2"],
+            unet=models["unet"],
+            scheduler=models["noise_scheduler"],
+        ).to(device, dtype=weight_dtype)
+
+        print("✅ Models loaded successfully")
+        return {**models, "pipeline": pipeline}
+
+    except Exception as e:
+        print(f"❌ Failed to load models: {e}")
+        raise
+
+def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt=""):
+    """Runs a virtual try-on, releasing GPU memory before and after."""
+    try:
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            person_path = os.path.join(tmp_dir, "person.png")
+            cloth_path = os.path.join(tmp_dir, "cloth.png")
+            person_image.save(person_path)
+            cloth_image.save(cloth_path)
+
+            mask_image = generate_clothing_mask(person_path)
+            pose_image = generate_openpose(person_path)
+
+            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
+            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)
+
+            with torch.autocast(device):
+                result = pipeline(
+                    image=person_image,
+                    mask_image=mask_image,
+                    pose_image=pose_image,
+                    cloth_encoder=models["cloth_encoder"],
+                    cloth_encoder_image=cloth_image,
+                    prompt=final_outfit_prompt,
+                    prompt_clothing=final_clothing_prompt,
+                    height=1024,
+                    width=768,
+                    guidance_scale=2.0,
+                    guidance_scale_img=4.5,
+                    guidance_scale_text=7.5,
+                    num_inference_steps=30,
+                    strength=1,
+                    interm_cloth_start_ratio=0.5,
+                    generator=None,
+                ).images[0]
+
+        return result
+
+    except Exception as e:
+        print(f"❌ Generation failed: {e}")
+        return None
+    finally:
+        torch.cuda.empty_cache()
+        gc.collect()
+
+print("🔍 Initializing models...")
+models = load_models()
+pipeline = models["pipeline"]
+
+with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px}") as demo:
     gr.Markdown("# 🧥 Virtual Try-On")
-    gr.Markdown("Upload a photo of a person and a garment for a virtual try-on")
 
     with gr.Row():
         with gr.Column():
-            person_input = gr.Image(label="Person photo", type="pil", sources=["upload"])
-            cloth_input = gr.Image(label="Garment photo", type="pil", sources=["upload"])
-            outfit_prompt = gr.Textbox(label="Outfit description (optional)", placeholder="e.g. man in casual outfit")
-            clothing_prompt = gr.Textbox(label="Clothing description (optional)", placeholder="e.g. red t-shirt with print")
-            generate_btn = gr.Button("Generate try-on", variant="primary")
-
-            gr.Examples(
-                examples=[
-                    ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve"]
-                ],
-                inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
-                label="Examples for quick testing"
-            )
-
+            gr.Markdown("### Inputs")
+            person_input = gr.Image(label="Person photo", type="pil")
+            cloth_input = gr.Image(label="Garment photo", type="pil")
+            outfit_prompt = gr.Textbox(label="Outfit description (optional)")
+            generate_btn = gr.Button("Generate", variant="primary")
+
         with gr.Column():
-            output_image = gr.Image(label="Try-on result", interactive=False)
-
+            gr.Markdown("### Result")
+            output_image = gr.Image(label="Try-on result")
+            gr.Markdown("Generation takes about 1-2 minutes")
+
     generate_btn.click(
         fn=generate_vton,
-        inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
+        inputs=[person_input, cloth_input, outfit_prompt],
         outputs=output_image
     )
-
-    gr.Markdown("### Instructions:")
-    gr.Markdown("1. Upload a clear, full-length photo of a person\n"
-                "2. Upload a garment photo on a white background\n"
-                "3. Optionally refine the outfit or clothing descriptions\n"
-                "4. Click 'Generate try-on'")
 
 if __name__ == "__main__":
-    demo.queue(max_size=3).launch(
+    demo.queue(default_concurrency_limit=1, max_size=2).launch(
         server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
-        share=os.getenv("GRADIO_SHARE") == "True",
-        debug=True
+        share=os.getenv("GRADIO_SHARE") == "True"
     )
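
The core of this change is that the UNet checkpoint is now downloaded once into a local `checkpoints/` directory and reused on restart, instead of being re-fetched through `cache_dir`. A minimal sketch of the pattern under the same assumptions as the commit (`ensure_checkpoint` is an illustrative helper, not code from the repo):

```python
import os

from huggingface_hub import hf_hub_download

CHECKPOINT_DIR = "./checkpoints"
REL_PATH = "VITONHD/model/pytorch_model.bin"

def ensure_checkpoint() -> str:
    """Download the UNet checkpoint on first run; reuse it afterwards."""
    # With local_dir, hf_hub_download preserves the repo-relative path, so the
    # cached file sits at CHECKPOINT_DIR/VITONHD/model/pytorch_model.bin.
    cached = os.path.join(CHECKPOINT_DIR, REL_PATH)
    if os.path.exists(cached):
        return cached  # warm start: no network round-trip
    return hf_hub_download(
        repo_id="Benrise/VITON-HD",
        filename=REL_PATH,
        token=os.getenv("HF_TOKEN"),
        local_dir=CHECKPOINT_DIR,
    )
```

Note that the existence check has to use the full repo-relative path: checking `CHECKPOINT_DIR/pytorch_model.bin` alone would never match the downloaded location, forcing a re-download on every start.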
requirements.txt CHANGED
@@ -15,4 +15,5 @@ controlnet-aux==0.0.10
 accelerate==1.8.1
 mediapipe==0.10.21
 gradio==5.34.2
-huggingface-hub==0.33.0
+huggingface-hub==0.33.0
+python-dotenv==1.1.0