Fix
Browse files- app.py +26 -18
- lib/mask.py +1 -2
app.py
CHANGED
|
@@ -19,12 +19,6 @@ load_dotenv()
|
|
| 19 |
TOKEN = os.getenv("HF_TOKEN")
|
| 20 |
login(token=TOKEN)
|
| 21 |
|
| 22 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 23 |
-
torch.backends.cudnn.benchmark = True
|
| 24 |
-
torch.set_grad_enabled(False)
|
| 25 |
-
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
|
| 26 |
-
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
|
| 27 |
-
|
| 28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
weight_dtype = torch.float16 if device == "cuda" else torch.float32
|
| 30 |
|
|
@@ -104,7 +98,7 @@ def load_models():
|
|
| 104 |
print(f"❌ Ошибка загрузки моделей: {e}")
|
| 105 |
raise
|
| 106 |
|
| 107 |
-
def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt=""):
|
| 108 |
"""Генерация виртуальной примерки с очисткой памяти"""
|
| 109 |
try:
|
| 110 |
torch.cuda.empty_cache()
|
|
@@ -116,7 +110,7 @@ def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="
|
|
| 116 |
person_image.save(person_path)
|
| 117 |
cloth_image.save(cloth_path)
|
| 118 |
|
| 119 |
-
mask_image = generate_clothing_mask(person_path)
|
| 120 |
pose_image = generate_openpose(person_path)
|
| 121 |
|
| 122 |
final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
|
|
@@ -158,37 +152,51 @@ pipeline = models["pipeline"]
|
|
| 158 |
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
|
| 159 |
gr.Markdown("# 🧥 Virtual Try-On")
|
| 160 |
gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
with gr.Row():
|
| 163 |
with gr.Column():
|
| 164 |
person_input = gr.Image(label="Фото человека", type="pil", sources=["upload"])
|
| 165 |
cloth_input = gr.Image(label="Фото одежды", type="pil", sources=["upload"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
outfit_prompt = gr.Textbox(label="Описание образа (опционально)", placeholder="Например: man in casual outfit")
|
| 167 |
clothing_prompt = gr.Textbox(label="Описание одежды (опционально)", placeholder="Например: red t-shirt with print")
|
|
|
|
| 168 |
generate_btn = gr.Button("Сгенерировать примерку", variant="primary")
|
| 169 |
-
|
| 170 |
gr.Examples(
|
| 171 |
examples=[
|
| 172 |
-
["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve"]
|
| 173 |
],
|
| 174 |
-
inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
|
| 175 |
label="Примеры для быстрого тестирования"
|
| 176 |
)
|
| 177 |
-
|
| 178 |
with gr.Column():
|
| 179 |
output_image = gr.Image(label="Результат примерки", interactive=False)
|
| 180 |
-
|
| 181 |
generate_btn.click(
|
| 182 |
fn=generate_vton,
|
| 183 |
-
inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
|
| 184 |
outputs=output_image
|
| 185 |
)
|
| 186 |
-
|
| 187 |
gr.Markdown("### Инструкция:")
|
| 188 |
gr.Markdown("1. Загрузите четкое фото человека в полный рост\n"
|
| 189 |
"2. Загрузите фото одежды на белом фоне\n"
|
| 190 |
-
"3.
|
| 191 |
-
"4.
|
|
|
|
| 192 |
|
| 193 |
if __name__ == "__main__":
|
| 194 |
demo.queue(max_size=1).launch(
|
|
|
|
| 19 |
TOKEN = os.getenv("HF_TOKEN")
|
| 20 |
login(token=TOKEN)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 23 |
weight_dtype = torch.float16 if device == "cuda" else torch.float32
|
| 24 |
|
|
|
|
| 98 |
print(f"❌ Ошибка загрузки моделей: {e}")
|
| 99 |
raise
|
| 100 |
|
| 101 |
+
def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
|
| 102 |
"""Генерация виртуальной примерки с очисткой памяти"""
|
| 103 |
try:
|
| 104 |
torch.cuda.empty_cache()
|
|
|
|
| 110 |
person_image.save(person_path)
|
| 111 |
cloth_image.save(cloth_path)
|
| 112 |
|
| 113 |
+
mask_image = generate_clothing_mask(person_path, label=label)
|
| 114 |
pose_image = generate_openpose(person_path)
|
| 115 |
|
| 116 |
final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
|
|
|
|
| 152 |
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
|
| 153 |
gr.Markdown("# 🧥 Virtual Try-On")
|
| 154 |
gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")
|
| 155 |
+
|
| 156 |
+
clothing_classes = [
|
| 157 |
+
"фон", "шляпа", "волосы", "сумка", "рюкзак", "верхняя одежда",
|
| 158 |
+
"рубашка", "футболка", "брюки", "короткие штаны", "платье",
|
| 159 |
+
"юбка", "носки", "обувь", "водолазка", "аксессуары", "руки", "ноги"
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
with gr.Row():
|
| 163 |
with gr.Column():
|
| 164 |
person_input = gr.Image(label="Фото человека", type="pil", sources=["upload"])
|
| 165 |
cloth_input = gr.Image(label="Фото одежды", type="pil", sources=["upload"])
|
| 166 |
+
clothing_label = gr.Dropdown(
|
| 167 |
+
choices=[(f"{i}: {desc}", i) for i, desc in enumerate(clothing_classes)],
|
| 168 |
+
label="Класс одежды для маски",
|
| 169 |
+
value=7
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
outfit_prompt = gr.Textbox(label="Описание образа (опционально)", placeholder="Например: man in casual outfit")
|
| 173 |
clothing_prompt = gr.Textbox(label="Описание одежды (опционально)", placeholder="Например: red t-shirt with print")
|
| 174 |
+
|
| 175 |
generate_btn = gr.Button("Сгенерировать примерку", variant="primary")
|
| 176 |
+
|
| 177 |
gr.Examples(
|
| 178 |
examples=[
|
| 179 |
+
["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 7]
|
| 180 |
],
|
| 181 |
+
inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
|
| 182 |
label="Примеры для быстрого тестирования"
|
| 183 |
)
|
| 184 |
+
|
| 185 |
with gr.Column():
|
| 186 |
output_image = gr.Image(label="Результат примерки", interactive=False)
|
| 187 |
+
|
| 188 |
generate_btn.click(
|
| 189 |
fn=generate_vton,
|
| 190 |
+
inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
|
| 191 |
outputs=output_image
|
| 192 |
)
|
| 193 |
+
|
| 194 |
gr.Markdown("### Инструкция:")
|
| 195 |
gr.Markdown("1. Загрузите четкое фото человека в полный рост\n"
|
| 196 |
"2. Загрузите фото одежды на белом фоне\n"
|
| 197 |
+
"3. Выберите тип одежды из выпадающего списка\n"
|
| 198 |
+
"4. При необходимости уточните описание образа или одежды\n"
|
| 199 |
+
"5. Нажмите кнопку 'Сгенерировать примерку'")
|
| 200 |
|
| 201 |
if __name__ == "__main__":
|
| 202 |
demo.queue(max_size=1).launch(
|
lib/mask.py
CHANGED
|
@@ -9,9 +9,8 @@ import os
|
|
| 9 |
def generate_clothing_mask(
|
| 10 |
image_path: str,
|
| 11 |
label: int,
|
| 12 |
-
output_path: str = "./output_mask.png",
|
| 13 |
model_name: str = "mattmdjaga/segformer_b2_clothes",
|
| 14 |
-
show_result: bool = False
|
| 15 |
) -> Image.Image:
|
| 16 |
"""
|
| 17 |
Генерирует бинарную маску для указанного класса одежды и сохраняет её
|
|
|
|
| 9 |
def generate_clothing_mask(
|
| 10 |
image_path: str,
|
| 11 |
label: int,
|
| 12 |
+
output_path: str = "./test/output_mask.png",
|
| 13 |
model_name: str = "mattmdjaga/segformer_b2_clothes",
|
|
|
|
| 14 |
) -> Image.Image:
|
| 15 |
"""
|
| 16 |
Генерирует бинарную маску для указанного класса одежды и сохраняет её
|