Benrise commited on
Commit
2835d45
·
1 Parent(s): 2a52992
Files changed (2) hide show
  1. app.py +26 -18
  2. lib/mask.py +1 -2
app.py CHANGED
@@ -19,12 +19,6 @@ load_dotenv()
19
  TOKEN = os.getenv("HF_TOKEN")
20
  login(token=TOKEN)
21
 
22
- torch.backends.cuda.matmul.allow_tf32 = True
23
- torch.backends.cudnn.benchmark = True
24
- torch.set_grad_enabled(False)
25
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
26
- os.environ["CUDA_MODULE_LOADING"] = "LAZY"
27
-
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  weight_dtype = torch.float16 if device == "cuda" else torch.float32
30
 
@@ -104,7 +98,7 @@ def load_models():
104
  print(f"❌ Ошибка загрузки моделей: {e}")
105
  raise
106
 
107
- def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt=""):
108
  """Генерация виртуальной примерки с очисткой памяти"""
109
  try:
110
  torch.cuda.empty_cache()
@@ -116,7 +110,7 @@ def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="
116
  person_image.save(person_path)
117
  cloth_image.save(cloth_path)
118
 
119
- mask_image = generate_clothing_mask(person_path)
120
  pose_image = generate_openpose(person_path)
121
 
122
  final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
@@ -158,37 +152,51 @@ pipeline = models["pipeline"]
158
  with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
159
  gr.Markdown("# 🧥 Virtual Try-On")
160
  gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")
161
-
 
 
 
 
 
 
162
  with gr.Row():
163
  with gr.Column():
164
  person_input = gr.Image(label="Фото человека", type="pil", sources=["upload"])
165
  cloth_input = gr.Image(label="Фото одежды", type="pil", sources=["upload"])
 
 
 
 
 
 
166
  outfit_prompt = gr.Textbox(label="Описание образа (опционально)", placeholder="Например: man in casual outfit")
167
  clothing_prompt = gr.Textbox(label="Описание одежды (опционально)", placeholder="Например: red t-shirt with print")
 
168
  generate_btn = gr.Button("Сгенерировать примерку", variant="primary")
169
-
170
  gr.Examples(
171
  examples=[
172
- ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve"]
173
  ],
174
- inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
175
  label="Примеры для быстрого тестирования"
176
  )
177
-
178
  with gr.Column():
179
  output_image = gr.Image(label="Результат примерки", interactive=False)
180
-
181
  generate_btn.click(
182
  fn=generate_vton,
183
- inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
184
  outputs=output_image
185
  )
186
-
187
  gr.Markdown("### Инструкция:")
188
  gr.Markdown("1. Загрузите четкое фото человека в полный рост\n"
189
  "2. Загрузите фото одежды на белом фоне\n"
190
- "3. При необходимости уточните описание образа или одежды\n"
191
- "4. Нажмите кнопку 'Сгенерировать примерку'")
 
192
 
193
  if __name__ == "__main__":
194
  demo.queue(max_size=1).launch(
 
19
  TOKEN = os.getenv("HF_TOKEN")
20
  login(token=TOKEN)
21
 
 
 
 
 
 
 
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  weight_dtype = torch.float16 if device == "cuda" else torch.float32
24
 
 
98
  print(f"❌ Ошибка загрузки моделей: {e}")
99
  raise
100
 
101
+ def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
102
  """Генерация виртуальной примерки с очисткой памяти"""
103
  try:
104
  torch.cuda.empty_cache()
 
110
  person_image.save(person_path)
111
  cloth_image.save(cloth_path)
112
 
113
+ mask_image = generate_clothing_mask(person_path, label=label)
114
  pose_image = generate_openpose(person_path)
115
 
116
  final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
 
152
  with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
153
  gr.Markdown("# 🧥 Virtual Try-On")
154
  gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")
155
+
156
+ clothing_classes = [
157
+ "фон", "шляпа", "волосы", "сумка", "рюкзак", "верхняя одежда",
158
+ "рубашка", "футболка", "брюки", "короткие штаны", "платье",
159
+ "юбка", "носки", "обувь", "водолазка", "аксессуары", "руки", "ноги"
160
+ ]
161
+
162
  with gr.Row():
163
  with gr.Column():
164
  person_input = gr.Image(label="Фото человека", type="pil", sources=["upload"])
165
  cloth_input = gr.Image(label="Фото одежды", type="pil", sources=["upload"])
166
+ clothing_label = gr.Dropdown(
167
+ choices=[(f"{i}: {desc}", i) for i, desc in enumerate(clothing_classes)],
168
+ label="Класс одежды для маски",
169
+ value=7
170
+ )
171
+
172
  outfit_prompt = gr.Textbox(label="Описание образа (опционально)", placeholder="Например: man in casual outfit")
173
  clothing_prompt = gr.Textbox(label="Описание одежды (опционально)", placeholder="Например: red t-shirt with print")
174
+
175
  generate_btn = gr.Button("Сгенерировать примерку", variant="primary")
176
+
177
  gr.Examples(
178
  examples=[
179
+ ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 7]
180
  ],
181
+ inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
182
  label="Примеры для быстрого тестирования"
183
  )
184
+
185
  with gr.Column():
186
  output_image = gr.Image(label="Результат примерки", interactive=False)
187
+
188
  generate_btn.click(
189
  fn=generate_vton,
190
+ inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
191
  outputs=output_image
192
  )
193
+
194
  gr.Markdown("### Инструкция:")
195
  gr.Markdown("1. Загрузите четкое фото человека в полный рост\n"
196
  "2. Загрузите фото одежды на белом фоне\n"
197
+ "3. Выберите тип одежды из выпадающего списка\n"
198
+ "4. При необходимости уточните описание образа или одежды\n"
199
+ "5. Нажмите кнопку 'Сгенерировать примерку'")
200
 
201
  if __name__ == "__main__":
202
  demo.queue(max_size=1).launch(
lib/mask.py CHANGED
@@ -9,9 +9,8 @@ import os
9
  def generate_clothing_mask(
10
  image_path: str,
11
  label: int,
12
- output_path: str = "./output_mask.png",
13
  model_name: str = "mattmdjaga/segformer_b2_clothes",
14
- show_result: bool = False
15
  ) -> Image.Image:
16
  """
17
  Генерирует бинарную маску для указанного класса одежды и сохраняет её
 
9
  def generate_clothing_mask(
10
  image_path: str,
11
  label: int,
12
+ output_path: str = "./test/output_mask.png",
13
  model_name: str = "mattmdjaga/segformer_b2_clothes",
 
14
  ) -> Image.Image:
15
  """
16
  Генерирует бинарную маску для указанного класса одежды и сохраняет её