Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- CatVTON/.gitattributes +36 -0
- CatVTON/.gitignore +2 -0
- CatVTON/README.md +13 -0
- CatVTON/__pycache__/utils.cpython-39.pyc +0 -0
- CatVTON/app.py +778 -0
- CatVTON/densepose/__init__.py +22 -0
- CatVTON/densepose/__pycache__/__init__.cpython-39.pyc +0 -0
- CatVTON/densepose/__pycache__/config.cpython-39.pyc +0 -0
- CatVTON/densepose/config.py +277 -0
- CatVTON/densepose/converters/__init__.py +17 -0
- CatVTON/densepose/converters/__pycache__/__init__.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/base.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/builtin.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/chart_output_hflip.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/chart_output_to_chart_result.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/hflip.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/segm_to_mask.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/to_chart_result.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/to_mask.cpython-39.pyc +0 -0
- CatVTON/densepose/converters/base.py +95 -0
- CatVTON/densepose/converters/builtin.py +33 -0
- CatVTON/densepose/converters/chart_output_hflip.py +73 -0
- CatVTON/densepose/converters/chart_output_to_chart_result.py +190 -0
- CatVTON/densepose/converters/hflip.py +36 -0
- CatVTON/densepose/converters/segm_to_mask.py +152 -0
- CatVTON/densepose/converters/to_chart_result.py +72 -0
- CatVTON/densepose/converters/to_mask.py +51 -0
- CatVTON/densepose/engine/__init__.py +5 -0
- CatVTON/densepose/engine/trainer.py +260 -0
- CatVTON/densepose/modeling/__init__.py +15 -0
- CatVTON/densepose/modeling/build.py +89 -0
- CatVTON/densepose/modeling/confidence.py +75 -0
- CatVTON/densepose/modeling/densepose_checkpoint.py +37 -0
- CatVTON/densepose/modeling/filter.py +96 -0
- CatVTON/densepose/modeling/hrfpn.py +184 -0
- CatVTON/densepose/modeling/hrnet.py +476 -0
- CatVTON/densepose/modeling/inference.py +46 -0
- CatVTON/densepose/modeling/test_time_augmentation.py +209 -0
- CatVTON/densepose/modeling/utils.py +13 -0
- CatVTON/densepose/utils/__init__.py +0 -0
- CatVTON/densepose/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- CatVTON/densepose/utils/__pycache__/transform.cpython-39.pyc +0 -0
- CatVTON/densepose/utils/dbhelper.py +149 -0
- CatVTON/densepose/utils/logger.py +15 -0
- CatVTON/densepose/utils/transform.py +17 -0
- CatVTON/model/DensePose/__init__.py +158 -0
- CatVTON/model/DensePose/__pycache__/__init__.cpython-310.pyc +0 -0
- CatVTON/model/DensePose/__pycache__/__init__.cpython-312.pyc +0 -0
- CatVTON/model/DensePose/__pycache__/__init__.cpython-39.pyc +0 -0
- CatVTON/model/SCHP/__init__.py +179 -0
CatVTON/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
detectron2/_C.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
CatVTON/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
playground.py
|
| 2 |
+
__pycache__
|
CatVTON/README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CatVTON
|
| 3 |
+
emoji: 🐈
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.40.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: cc-by-nc-sa-4.0
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
CatVTON/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
CatVTON/app.py
ADDED
|
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
os.environ['CUDA_HOME'] = '/usr/local/cuda'
|
| 4 |
+
os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import spaces
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
+
from PIL import Image
|
| 14 |
+
torch.jit.script = lambda f: f
|
| 15 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
| 16 |
+
from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
|
| 17 |
+
from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
|
| 18 |
+
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
|
| 19 |
+
|
| 20 |
+
|
def parse_args():
    """Parse command-line options for the CatVTON Gradio demo.

    Returns:
        argparse.Namespace with model checkpoint paths, output directory,
        target resolution, TF32/mixed-precision flags, and ``local_rank``.

    Bug fix: the original parser never declared ``--local_rank`` yet read
    ``args.local_rank`` below, so any launch with the ``LOCAL_RANK``
    environment variable set crashed with AttributeError.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--p2p_base_model_path",
        type=str,
        default="timbrooks/instruct-pix2pix",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path to the checkpoint of trained tryon model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
    ),
    )
    # Fix: declare --local_rank so args.local_rank always exists; the env
    # var (set by distributed launchers) overrides the CLI value below.
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed launchers: local rank. Overridden by the LOCAL_RANK env var when set.",
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
| 104 |
+
|
def image_grid(imgs, rows, cols):
    """Tile equally-sized PIL images into a single rows x cols grid image."""
    assert len(imgs) == rows * cols

    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for idx, img in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(img, box=(col * cell_w, row * cell_h))
    return canvas
| 114 |
+
|
| 115 |
+
|
# ---------------------------------------------------------------------------
# Module-level model setup — runs once at import time. Statement order
# matters: every pipeline below reuses `args` and the downloaded `repo_path`.
# ---------------------------------------------------------------------------
args = parse_args()

# Mask-based CatVTON
# Downloads the CatVTON attention weights once; `repo_path` is the local
# snapshot directory reused by the pipelines and the AutoMasker below.
catvton_repo = "zhengchong/CatVTON"
repo_path = snapshot_download(repo_id=catvton_repo)
# Pipeline
pipeline = CatVTONPipeline(
    base_ckpt=args.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)
# AutoMasker
# mask_processor binarizes and blurs user/auto masks; AutoMasker builds
# try-on masks from DensePose + SCHP checkpoints inside the snapshot.
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda',
)


# Flux-based CatVTON
# FLUX.1-Fill is a gated repo, so an HF token is required; the CatVTON
# try-on LoRA from the snapshot is loaded on top of the base pipeline.
access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo, use_auth_token=access_token)
pipeline_flux.load_lora_weights(
    os.path.join(repo_path, "flux-lora"),
    weight_name='pytorch_lora_weights.safetensors'
)
pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))


# Mask-free CatVTON
# Separate snapshot: an instruct-pix2pix backbone that needs no mask input.
catvton_mf_repo = "zhengchong/CatVTON-MaskFree"
repo_path_mf = snapshot_download(repo_id=catvton_mf_repo, use_auth_token=access_token)
pipeline_p2p = CatVTONPix2PixPipeline(
    base_ckpt=args.p2p_base_model_path,
    attn_ckpt=repo_path_mf,
    attn_ckpt_version="mix-48k-1024",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)
@spaces.GPU(duration=120)
def submit_function(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    """Run mask-based try-on (SD1.5 CatVTON pipeline) for the Gradio demo.

    Args:
        person_image: Gradio ImageEditor payload; ``["background"]`` is the
            person photo path, ``["layers"][0]`` the user-drawn mask layer path.
        cloth_image: File path of the garment image.
        cloth_type: "upper" / "lower" / "overall", used by the auto-masker
            when no mask was drawn.
        num_inference_steps: Diffusion denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: RNG seed; -1 means non-deterministic (no fixed generator).
        show_type: "result only", "input & result", or anything else for
            input + mask + result side by side.

    Returns:
        PIL.Image: the try-on result, optionally tiled with the inputs.
    """
    # Unpack the editor payload: background photo + optional drawn mask layer.
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    # A layer with a single gray level means the user drew nothing.
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255  # binarize drawn strokes
        mask = Image.fromarray(mask)

    # Results are saved under <output_dir>/YYYYMMDD/HHMMSS.png.
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    day_folder = os.path.join(args.output_dir, date_str[:8])
    # Fix: exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(day_folder, exist_ok=True)
    result_save_path = os.path.join(day_folder, date_str[8:] + ".png")

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Use the drawn mask when present; otherwise generate one automatically.
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference (dead commented-out try/except removed; errors propagate).
    result_image = pipeline(
        image=person_image,
        condition_image=cloth_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    )[0]

    # Post-process: persist a contact sheet, then build the display image.
    masked_person = vis_mask(person_image, mask)
    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
    save_result_image.save(result_save_path)
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image
| 241 |
+
|
@spaces.GPU(duration=120)
def submit_function_p2p(
    person_image,
    cloth_image,
    num_inference_steps,
    guidance_scale,
    seed):
    """Run mask-free try-on (instruct-pix2pix CatVTON pipeline).

    Args:
        person_image: Gradio ImageEditor payload; only ``["background"]``
            (the person photo path) is used — no mask is needed.
        cloth_image: File path of the garment image.
        num_inference_steps: Diffusion denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: RNG seed; -1 means non-deterministic.

    Returns:
        PIL.Image: the try-on result.

    Raises:
        gr.Error: if the pipeline fails, surfaced to the Gradio UI.
    """
    person_image = person_image["background"]

    # Results are saved under <output_dir>/YYYYMMDD/HHMMSS.png.
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    day_folder = os.path.join(args.output_dir, date_str[:8])
    # Fix: exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(day_folder, exist_ok=True)
    result_save_path = os.path.join(day_folder, date_str[8:] + ".png")

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Inference: wrap so UI users see a friendly error instead of a traceback.
    try:
        result_image = pipeline_p2p(
            image=person_image,
            condition_image=cloth_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        )[0]
    except Exception as e:
        raise gr.Error(
            "An error occurred. Please try again later: {}".format(e)
        )

    # Post-process: persist a contact sheet of inputs + result.
    save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
    save_result_image.save(result_save_path)
    return result_image
| 284 |
+
|
@spaces.GPU(duration=120)
def submit_function_flux(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    """Run FLUX-based try-on for the Gradio demo.

    Same contract as ``submit_function`` but drives the FLUX.1-Fill pipeline
    and does not persist the result to disk.
    """
    # Unpack the editor payload: background photo + optional drawn mask layer.
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    # A single gray level in the layer means the user drew nothing.
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask_arr = np.array(mask)
        mask_arr[mask_arr > 0] = 255  # binarize drawn strokes
        mask = Image.fromarray(mask_arr)

    # Fixed seed -> deterministic generator; -1 leaves sampling random.
    generator = torch.Generator(device='cuda').manual_seed(seed) if seed != -1 else None

    # Load and normalize both inputs to the working resolution.
    person_image = resize_and_crop(
        Image.open(person_image).convert("RGB"), (args.width, args.height)
    )
    cloth_image = resize_and_padding(
        Image.open(cloth_image).convert("RGB"), (args.width, args.height)
    )

    # Drawn mask wins; otherwise derive one from the requested cloth type.
    if mask is None:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    else:
        mask = resize_and_crop(mask, (args.width, args.height))
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline_flux(
        image=person_image,
        condition_image=cloth_image,
        mask_image=mask,
        width=args.width,
        height=args.height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]

    # Post-processing: visualization of the mask over the person photo.
    masked_person = vis_mask(person_image, mask)

    # Compose the display image according to the requested show type.
    if show_type == "result only":
        return result_image

    width, height = person_image.size
    if show_type == "input & result":
        condition_width = width // 2
        conditions = image_grid([person_image, cloth_image], 2, 1)
    else:
        condition_width = width // 3
        conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)

    conditions = conditions.resize((condition_width, height), Image.NEAREST)
    canvas = Image.new("RGB", (width + condition_width + 5, height))
    canvas.paste(conditions, (0, 0))
    canvas.paste(result_image, (condition_width + 5, 0))
    return canvas
| 361 |
+
|
| 362 |
+
|
def person_example_fn(image_path):
    """Forward the selected example's path into the hidden image component."""
    return image_path
| 365 |
+
|
| 366 |
+
|
# HTML/Markdown banner rendered at the top of the Gradio demo page:
# title, badge links (paper, weights, repo, demos, project page, license)
# and usage notes. Rendered via gr.Markdown in app_gradio().
HEADER = """
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
<div style="display: flex; justify-content: center; align-items: center;">
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
</a>
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
</a>
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
</a>
<a href="http://120.76.142.206:8888" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
</a>
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
</a>
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
</a>
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
</a>
</div>
<br>
· This demo and our weights are only for Non-commercial Use. <br>
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
"""
| 397 |
+
|
| 398 |
+
def app_gradio():
|
| 399 |
+
with gr.Blocks(title="CatVTON") as demo:
|
| 400 |
+
gr.Markdown(HEADER)
|
| 401 |
+
with gr.Tab("Mask-based & SD1.5"):
|
| 402 |
+
with gr.Row():
|
| 403 |
+
with gr.Column(scale=1, min_width=350):
|
| 404 |
+
with gr.Row():
|
| 405 |
+
image_path = gr.Image(
|
| 406 |
+
type="filepath",
|
| 407 |
+
interactive=True,
|
| 408 |
+
visible=False,
|
| 409 |
+
)
|
| 410 |
+
person_image = gr.ImageEditor(
|
| 411 |
+
interactive=True, label="Person Image", type="filepath"
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
with gr.Row():
|
| 415 |
+
with gr.Column(scale=1, min_width=230):
|
| 416 |
+
cloth_image = gr.Image(
|
| 417 |
+
interactive=True, label="Condition Image", type="filepath"
|
| 418 |
+
)
|
| 419 |
+
with gr.Column(scale=1, min_width=120):
|
| 420 |
+
gr.Markdown(
|
| 421 |
+
'<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
|
| 422 |
+
)
|
| 423 |
+
cloth_type = gr.Radio(
|
| 424 |
+
label="Try-On Cloth Type",
|
| 425 |
+
choices=["upper", "lower", "overall"],
|
| 426 |
+
value="upper",
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
submit = gr.Button("Submit")
|
| 431 |
+
gr.Markdown(
|
| 432 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
gr.Markdown(
|
| 436 |
+
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
|
| 437 |
+
)
|
| 438 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 439 |
+
num_inference_steps = gr.Slider(
|
| 440 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
| 441 |
+
)
|
| 442 |
+
# Guidence Scale
|
| 443 |
+
guidance_scale = gr.Slider(
|
| 444 |
+
label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
|
| 445 |
+
)
|
| 446 |
+
# Random Seed
|
| 447 |
+
seed = gr.Slider(
|
| 448 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
| 449 |
+
)
|
| 450 |
+
show_type = gr.Radio(
|
| 451 |
+
label="Show Type",
|
| 452 |
+
choices=["result only", "input & result", "input & mask & result"],
|
| 453 |
+
value="input & mask & result",
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
with gr.Column(scale=2, min_width=500):
|
| 457 |
+
result_image = gr.Image(interactive=False, label="Result")
|
| 458 |
+
with gr.Row():
|
| 459 |
+
# Photo Examples
|
| 460 |
+
root_path = "resource/demo/example"
|
| 461 |
+
with gr.Column():
|
| 462 |
+
men_exm = gr.Examples(
|
| 463 |
+
examples=[
|
| 464 |
+
os.path.join(root_path, "person", "men", _)
|
| 465 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
| 466 |
+
],
|
| 467 |
+
examples_per_page=4,
|
| 468 |
+
inputs=image_path,
|
| 469 |
+
label="Person Examples ①",
|
| 470 |
+
)
|
| 471 |
+
women_exm = gr.Examples(
|
| 472 |
+
examples=[
|
| 473 |
+
os.path.join(root_path, "person", "women", _)
|
| 474 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
| 475 |
+
],
|
| 476 |
+
examples_per_page=4,
|
| 477 |
+
inputs=image_path,
|
| 478 |
+
label="Person Examples ②",
|
| 479 |
+
)
|
| 480 |
+
gr.Markdown(
|
| 481 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
| 482 |
+
)
|
| 483 |
+
with gr.Column():
|
| 484 |
+
condition_upper_exm = gr.Examples(
|
| 485 |
+
examples=[
|
| 486 |
+
os.path.join(root_path, "condition", "upper", _)
|
| 487 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
| 488 |
+
],
|
| 489 |
+
examples_per_page=4,
|
| 490 |
+
inputs=cloth_image,
|
| 491 |
+
label="Condition Upper Examples",
|
| 492 |
+
)
|
| 493 |
+
condition_overall_exm = gr.Examples(
|
| 494 |
+
examples=[
|
| 495 |
+
os.path.join(root_path, "condition", "overall", _)
|
| 496 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
| 497 |
+
],
|
| 498 |
+
examples_per_page=4,
|
| 499 |
+
inputs=cloth_image,
|
| 500 |
+
label="Condition Overall Examples",
|
| 501 |
+
)
|
| 502 |
+
condition_person_exm = gr.Examples(
|
| 503 |
+
examples=[
|
| 504 |
+
os.path.join(root_path, "condition", "person", _)
|
| 505 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
| 506 |
+
],
|
| 507 |
+
examples_per_page=4,
|
| 508 |
+
inputs=cloth_image,
|
| 509 |
+
label="Condition Reference Person Examples",
|
| 510 |
+
)
|
| 511 |
+
gr.Markdown(
|
| 512 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
image_path.change(
|
| 516 |
+
person_example_fn, inputs=image_path, outputs=person_image
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
submit.click(
|
| 520 |
+
submit_function,
|
| 521 |
+
[
|
| 522 |
+
person_image,
|
| 523 |
+
cloth_image,
|
| 524 |
+
cloth_type,
|
| 525 |
+
num_inference_steps,
|
| 526 |
+
guidance_scale,
|
| 527 |
+
seed,
|
| 528 |
+
show_type,
|
| 529 |
+
],
|
| 530 |
+
result_image,
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
with gr.Tab("Mask-based & Flux.1 Fill Dev"):
|
| 534 |
+
with gr.Row():
|
| 535 |
+
with gr.Column(scale=1, min_width=350):
|
| 536 |
+
with gr.Row():
|
| 537 |
+
image_path_flux = gr.Image(
|
| 538 |
+
type="filepath",
|
| 539 |
+
interactive=True,
|
| 540 |
+
visible=False,
|
| 541 |
+
)
|
| 542 |
+
person_image_flux = gr.ImageEditor(
|
| 543 |
+
interactive=True, label="Person Image", type="filepath"
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
with gr.Row():
|
| 547 |
+
with gr.Column(scale=1, min_width=230):
|
| 548 |
+
cloth_image_flux = gr.Image(
|
| 549 |
+
interactive=True, label="Condition Image", type="filepath"
|
| 550 |
+
)
|
| 551 |
+
with gr.Column(scale=1, min_width=120):
|
| 552 |
+
gr.Markdown(
|
| 553 |
+
'<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
|
| 554 |
+
)
|
| 555 |
+
cloth_type = gr.Radio(
|
| 556 |
+
label="Try-On Cloth Type",
|
| 557 |
+
choices=["upper", "lower", "overall"],
|
| 558 |
+
value="upper",
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
submit_flux = gr.Button("Submit")
|
| 562 |
+
gr.Markdown(
|
| 563 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 567 |
+
num_inference_steps_flux = gr.Slider(
|
| 568 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
| 569 |
+
)
|
| 570 |
+
# Guidence Scale
|
| 571 |
+
guidance_scale_flux = gr.Slider(
|
| 572 |
+
label="CFG Strenth", minimum=0.0, maximum=50, step=0.5, value=30
|
| 573 |
+
)
|
| 574 |
+
# Random Seed
|
| 575 |
+
seed_flux = gr.Slider(
|
| 576 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
| 577 |
+
)
|
| 578 |
+
show_type = gr.Radio(
|
| 579 |
+
label="Show Type",
|
| 580 |
+
choices=["result only", "input & result", "input & mask & result"],
|
| 581 |
+
value="input & mask & result",
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
with gr.Column(scale=2, min_width=500):
|
| 585 |
+
result_image_flux = gr.Image(interactive=False, label="Result")
|
| 586 |
+
with gr.Row():
|
| 587 |
+
# Photo Examples
|
| 588 |
+
root_path = "resource/demo/example"
|
| 589 |
+
with gr.Column():
|
| 590 |
+
gr.Examples(
|
| 591 |
+
examples=[
|
| 592 |
+
os.path.join(root_path, "person", "men", _)
|
| 593 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
| 594 |
+
],
|
| 595 |
+
examples_per_page=4,
|
| 596 |
+
inputs=image_path_flux,
|
| 597 |
+
label="Person Examples ①",
|
| 598 |
+
)
|
| 599 |
+
gr.Examples(
|
| 600 |
+
examples=[
|
| 601 |
+
os.path.join(root_path, "person", "women", _)
|
| 602 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
| 603 |
+
],
|
| 604 |
+
examples_per_page=4,
|
| 605 |
+
inputs=image_path_flux,
|
| 606 |
+
label="Person Examples ②",
|
| 607 |
+
)
|
| 608 |
+
gr.Markdown(
|
| 609 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
| 610 |
+
)
|
| 611 |
+
with gr.Column():
|
| 612 |
+
gr.Examples(
|
| 613 |
+
examples=[
|
| 614 |
+
os.path.join(root_path, "condition", "upper", _)
|
| 615 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
| 616 |
+
],
|
| 617 |
+
examples_per_page=4,
|
| 618 |
+
inputs=cloth_image_flux,
|
| 619 |
+
label="Condition Upper Examples",
|
| 620 |
+
)
|
| 621 |
+
gr.Examples(
|
| 622 |
+
examples=[
|
| 623 |
+
os.path.join(root_path, "condition", "overall", _)
|
| 624 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
| 625 |
+
],
|
| 626 |
+
examples_per_page=4,
|
| 627 |
+
inputs=cloth_image_flux,
|
| 628 |
+
label="Condition Overall Examples",
|
| 629 |
+
)
|
| 630 |
+
condition_person_exm = gr.Examples(
|
| 631 |
+
examples=[
|
| 632 |
+
os.path.join(root_path, "condition", "person", _)
|
| 633 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
| 634 |
+
],
|
| 635 |
+
examples_per_page=4,
|
| 636 |
+
inputs=cloth_image_flux,
|
| 637 |
+
label="Condition Reference Person Examples",
|
| 638 |
+
)
|
| 639 |
+
gr.Markdown(
|
| 640 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
image_path_flux.change(
|
| 645 |
+
person_example_fn, inputs=image_path_flux, outputs=person_image_flux
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
submit_flux.click(
|
| 649 |
+
submit_function_flux,
|
| 650 |
+
[person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
|
| 651 |
+
result_image_flux,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
with gr.Tab("Mask-free & SD1.5"):
|
| 656 |
+
with gr.Row():
|
| 657 |
+
with gr.Column(scale=1, min_width=350):
|
| 658 |
+
with gr.Row():
|
| 659 |
+
image_path_p2p = gr.Image(
|
| 660 |
+
type="filepath",
|
| 661 |
+
interactive=True,
|
| 662 |
+
visible=False,
|
| 663 |
+
)
|
| 664 |
+
person_image_p2p = gr.ImageEditor(
|
| 665 |
+
interactive=True, label="Person Image", type="filepath"
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
with gr.Row():
|
| 669 |
+
with gr.Column(scale=1, min_width=230):
|
| 670 |
+
cloth_image_p2p = gr.Image(
|
| 671 |
+
interactive=True, label="Condition Image", type="filepath"
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
submit_p2p = gr.Button("Submit")
|
| 675 |
+
gr.Markdown(
|
| 676 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
gr.Markdown(
|
| 680 |
+
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
|
| 681 |
+
)
|
| 682 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 683 |
+
num_inference_steps_p2p = gr.Slider(
|
| 684 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
| 685 |
+
)
|
| 686 |
+
# Guidence Scale
|
| 687 |
+
guidance_scale_p2p = gr.Slider(
|
| 688 |
+
label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
|
| 689 |
+
)
|
| 690 |
+
# Random Seed
|
| 691 |
+
seed_p2p = gr.Slider(
|
| 692 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
| 693 |
+
)
|
| 694 |
+
# show_type = gr.Radio(
|
| 695 |
+
# label="Show Type",
|
| 696 |
+
# choices=["result only", "input & result", "input & mask & result"],
|
| 697 |
+
# value="input & mask & result",
|
| 698 |
+
# )
|
| 699 |
+
|
| 700 |
+
with gr.Column(scale=2, min_width=500):
|
| 701 |
+
result_image_p2p = gr.Image(interactive=False, label="Result")
|
| 702 |
+
with gr.Row():
|
| 703 |
+
# Photo Examples
|
| 704 |
+
root_path = "resource/demo/example"
|
| 705 |
+
with gr.Column():
|
| 706 |
+
gr.Examples(
|
| 707 |
+
examples=[
|
| 708 |
+
os.path.join(root_path, "person", "men", _)
|
| 709 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
| 710 |
+
],
|
| 711 |
+
examples_per_page=4,
|
| 712 |
+
inputs=image_path_p2p,
|
| 713 |
+
label="Person Examples ①",
|
| 714 |
+
)
|
| 715 |
+
gr.Examples(
|
| 716 |
+
examples=[
|
| 717 |
+
os.path.join(root_path, "person", "women", _)
|
| 718 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
| 719 |
+
],
|
| 720 |
+
examples_per_page=4,
|
| 721 |
+
inputs=image_path_p2p,
|
| 722 |
+
label="Person Examples ②",
|
| 723 |
+
)
|
| 724 |
+
gr.Markdown(
|
| 725 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
| 726 |
+
)
|
| 727 |
+
with gr.Column():
|
| 728 |
+
gr.Examples(
|
| 729 |
+
examples=[
|
| 730 |
+
os.path.join(root_path, "condition", "upper", _)
|
| 731 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
| 732 |
+
],
|
| 733 |
+
examples_per_page=4,
|
| 734 |
+
inputs=cloth_image_p2p,
|
| 735 |
+
label="Condition Upper Examples",
|
| 736 |
+
)
|
| 737 |
+
gr.Examples(
|
| 738 |
+
examples=[
|
| 739 |
+
os.path.join(root_path, "condition", "overall", _)
|
| 740 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
| 741 |
+
],
|
| 742 |
+
examples_per_page=4,
|
| 743 |
+
inputs=cloth_image_p2p,
|
| 744 |
+
label="Condition Overall Examples",
|
| 745 |
+
)
|
| 746 |
+
condition_person_exm = gr.Examples(
|
| 747 |
+
examples=[
|
| 748 |
+
os.path.join(root_path, "condition", "person", _)
|
| 749 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
| 750 |
+
],
|
| 751 |
+
examples_per_page=4,
|
| 752 |
+
inputs=cloth_image_p2p,
|
| 753 |
+
label="Condition Reference Person Examples",
|
| 754 |
+
)
|
| 755 |
+
gr.Markdown(
|
| 756 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
image_path_p2p.change(
|
| 760 |
+
person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
submit_p2p.click(
|
| 764 |
+
submit_function_p2p,
|
| 765 |
+
[
|
| 766 |
+
person_image_p2p,
|
| 767 |
+
cloth_image_p2p,
|
| 768 |
+
num_inference_steps_p2p,
|
| 769 |
+
guidance_scale_p2p,
|
| 770 |
+
seed_p2p],
|
| 771 |
+
result_image_p2p,
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
demo.queue().launch(share=True, show_error=True)
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
if __name__ == "__main__":
|
| 778 |
+
app_gradio()
|
CatVTON/densepose/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from .data.datasets import builtin # just to register data
|
| 5 |
+
from .converters import builtin as builtin_converters # register converters
|
| 6 |
+
from .config import (
|
| 7 |
+
add_densepose_config,
|
| 8 |
+
add_densepose_head_config,
|
| 9 |
+
add_hrnet_config,
|
| 10 |
+
add_dataset_category_config,
|
| 11 |
+
add_bootstrap_config,
|
| 12 |
+
load_bootstrap_config,
|
| 13 |
+
)
|
| 14 |
+
from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
|
| 15 |
+
from .evaluation import DensePoseCOCOEvaluator
|
| 16 |
+
from .modeling.roi_heads import DensePoseROIHeads
|
| 17 |
+
from .modeling.test_time_augmentation import (
|
| 18 |
+
DensePoseGeneralizedRCNNWithTTA,
|
| 19 |
+
DensePoseDatasetMapperTTA,
|
| 20 |
+
)
|
| 21 |
+
from .utils.transform import load_from_cfg
|
| 22 |
+
from .modeling.hrfpn import build_hrfpn_backbone
|
CatVTON/densepose/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (925 Bytes). View file
|
|
|
CatVTON/densepose/__pycache__/config.cpython-39.pyc
ADDED
|
Binary file (5.82 kB). View file
|
|
|
CatVTON/densepose/config.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding = utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# pyre-ignore-all-errors
|
| 4 |
+
|
| 5 |
+
from detectron2.config import CfgNode as CN
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def add_dataset_category_config(cfg: CN) -> None:
|
| 9 |
+
"""
|
| 10 |
+
Add config for additional category-related dataset options
|
| 11 |
+
- category whitelisting
|
| 12 |
+
- category mapping
|
| 13 |
+
"""
|
| 14 |
+
_C = cfg
|
| 15 |
+
_C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
|
| 16 |
+
_C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
|
| 17 |
+
# class to mesh mapping
|
| 18 |
+
_C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def add_evaluation_config(cfg: CN) -> None:
|
| 22 |
+
_C = cfg
|
| 23 |
+
_C.DENSEPOSE_EVALUATION = CN()
|
| 24 |
+
# evaluator type, possible values:
|
| 25 |
+
# - "iou": evaluator for models that produce iou data
|
| 26 |
+
# - "cse": evaluator for models that produce cse data
|
| 27 |
+
_C.DENSEPOSE_EVALUATION.TYPE = "iou"
|
| 28 |
+
# storage for DensePose results, possible values:
|
| 29 |
+
# - "none": no explicit storage, all the results are stored in the
|
| 30 |
+
# dictionary with predictions, memory intensive;
|
| 31 |
+
# historically the default storage type
|
| 32 |
+
# - "ram": RAM storage, uses per-process RAM storage, which is
|
| 33 |
+
# reduced to a single process storage on later stages,
|
| 34 |
+
# less memory intensive
|
| 35 |
+
# - "file": file storage, uses per-process file-based storage,
|
| 36 |
+
# the least memory intensive, but may create bottlenecks
|
| 37 |
+
# on file system accesses
|
| 38 |
+
_C.DENSEPOSE_EVALUATION.STORAGE = "none"
|
| 39 |
+
# minimum threshold for IOU values: the lower its values is,
|
| 40 |
+
# the more matches are produced (and the higher the AP score)
|
| 41 |
+
_C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
|
| 42 |
+
# Non-distributed inference is slower (at inference time) but can avoid RAM OOM
|
| 43 |
+
_C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
|
| 44 |
+
# evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
|
| 45 |
+
_C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
|
| 46 |
+
# meshes to compute mesh alignment for
|
| 47 |
+
_C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def add_bootstrap_config(cfg: CN) -> None:
|
| 51 |
+
""" """
|
| 52 |
+
_C = cfg
|
| 53 |
+
_C.BOOTSTRAP_DATASETS = []
|
| 54 |
+
_C.BOOTSTRAP_MODEL = CN()
|
| 55 |
+
_C.BOOTSTRAP_MODEL.WEIGHTS = ""
|
| 56 |
+
_C.BOOTSTRAP_MODEL.DEVICE = "cuda"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_bootstrap_dataset_config() -> CN:
|
| 60 |
+
_C = CN()
|
| 61 |
+
_C.DATASET = ""
|
| 62 |
+
# ratio used to mix data loaders
|
| 63 |
+
_C.RATIO = 0.1
|
| 64 |
+
# image loader
|
| 65 |
+
_C.IMAGE_LOADER = CN(new_allowed=True)
|
| 66 |
+
_C.IMAGE_LOADER.TYPE = ""
|
| 67 |
+
_C.IMAGE_LOADER.BATCH_SIZE = 4
|
| 68 |
+
_C.IMAGE_LOADER.NUM_WORKERS = 4
|
| 69 |
+
_C.IMAGE_LOADER.CATEGORIES = []
|
| 70 |
+
_C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
|
| 71 |
+
_C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
|
| 72 |
+
# inference
|
| 73 |
+
_C.INFERENCE = CN()
|
| 74 |
+
# batch size for model inputs
|
| 75 |
+
_C.INFERENCE.INPUT_BATCH_SIZE = 4
|
| 76 |
+
# batch size to group model outputs
|
| 77 |
+
_C.INFERENCE.OUTPUT_BATCH_SIZE = 2
|
| 78 |
+
# sampled data
|
| 79 |
+
_C.DATA_SAMPLER = CN(new_allowed=True)
|
| 80 |
+
_C.DATA_SAMPLER.TYPE = ""
|
| 81 |
+
_C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
|
| 82 |
+
# filter
|
| 83 |
+
_C.FILTER = CN(new_allowed=True)
|
| 84 |
+
_C.FILTER.TYPE = ""
|
| 85 |
+
return _C
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def load_bootstrap_config(cfg: CN) -> None:
|
| 89 |
+
"""
|
| 90 |
+
Bootstrap datasets are given as a list of `dict` that are not automatically
|
| 91 |
+
converted into CfgNode. This method processes all bootstrap dataset entries
|
| 92 |
+
and ensures that they are in CfgNode format and comply with the specification
|
| 93 |
+
"""
|
| 94 |
+
if not cfg.BOOTSTRAP_DATASETS:
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
bootstrap_datasets_cfgnodes = []
|
| 98 |
+
for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
|
| 99 |
+
_C = get_bootstrap_dataset_config().clone()
|
| 100 |
+
_C.merge_from_other_cfg(CN(dataset_cfg))
|
| 101 |
+
bootstrap_datasets_cfgnodes.append(_C)
|
| 102 |
+
cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def add_densepose_head_cse_config(cfg: CN) -> None:
|
| 106 |
+
"""
|
| 107 |
+
Add configuration options for Continuous Surface Embeddings (CSE)
|
| 108 |
+
"""
|
| 109 |
+
_C = cfg
|
| 110 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
|
| 111 |
+
# Dimensionality D of the embedding space
|
| 112 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
|
| 113 |
+
# Embedder specifications for various mesh IDs
|
| 114 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
|
| 115 |
+
# normalization coefficient for embedding distances
|
| 116 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
|
| 117 |
+
# normalization coefficient for geodesic distances
|
| 118 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
|
| 119 |
+
# embedding loss weight
|
| 120 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
|
| 121 |
+
# embedding loss name, currently the following options are supported:
|
| 122 |
+
# - EmbeddingLoss: cross-entropy on vertex labels
|
| 123 |
+
# - SoftEmbeddingLoss: cross-entropy on vertex label combined with
|
| 124 |
+
# Gaussian penalty on distance between vertices
|
| 125 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
|
| 126 |
+
# optimizer hyperparameters
|
| 127 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
|
| 128 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
|
| 129 |
+
# Shape to shape cycle consistency loss parameters:
|
| 130 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
|
| 131 |
+
# shape to shape cycle consistency loss weight
|
| 132 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
|
| 133 |
+
# norm type used for loss computation
|
| 134 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
|
| 135 |
+
# normalization term for embedding similarity matrices
|
| 136 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
|
| 137 |
+
# maximum number of vertices to include into shape to shape cycle loss
|
| 138 |
+
# if negative or zero, all vertices are considered
|
| 139 |
+
# if positive, random subset of vertices of given size is considered
|
| 140 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
|
| 141 |
+
# Pixel to shape cycle consistency loss parameters:
|
| 142 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
|
| 143 |
+
# pixel to shape cycle consistency loss weight
|
| 144 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
|
| 145 |
+
# norm type used for loss computation
|
| 146 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
|
| 147 |
+
# map images to all meshes and back (if false, use only gt meshes from the batch)
|
| 148 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
|
| 149 |
+
# Randomly select at most this number of pixels from every instance
|
| 150 |
+
# if negative or zero, all vertices are considered
|
| 151 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
|
| 152 |
+
# normalization factor for pixel to pixel distances (higher value = smoother distribution)
|
| 153 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
|
| 154 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
|
| 155 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def add_densepose_head_config(cfg: CN) -> None:
|
| 159 |
+
"""
|
| 160 |
+
Add config for densepose head.
|
| 161 |
+
"""
|
| 162 |
+
_C = cfg
|
| 163 |
+
|
| 164 |
+
_C.MODEL.DENSEPOSE_ON = True
|
| 165 |
+
|
| 166 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD = CN()
|
| 167 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
|
| 168 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
|
| 169 |
+
# Number of parts used for point labels
|
| 170 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
|
| 171 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
|
| 172 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
|
| 173 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
|
| 174 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
|
| 175 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
|
| 176 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
|
| 177 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
|
| 178 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
|
| 179 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2
|
| 180 |
+
# Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
|
| 181 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
|
| 182 |
+
# Loss weights for annotation masks.(14 Parts)
|
| 183 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
|
| 184 |
+
# Loss weights for surface parts. (24 Parts)
|
| 185 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
|
| 186 |
+
# Loss weights for UV regression.
|
| 187 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
|
| 188 |
+
# Coarse segmentation is trained using instance segmentation task data
|
| 189 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
|
| 190 |
+
# For Decoder
|
| 191 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
|
| 192 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
|
| 193 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
|
| 194 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
|
| 195 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
|
| 196 |
+
# For DeepLab head
|
| 197 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
|
| 198 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
|
| 199 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
|
| 200 |
+
# Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
|
| 201 |
+
# Some registered predictors:
|
| 202 |
+
# "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
|
| 203 |
+
# "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
|
| 204 |
+
# and associated confidences for predefined charts (default)
|
| 205 |
+
# "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
|
| 206 |
+
# and associated confidences for CSE
|
| 207 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
|
| 208 |
+
# Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
|
| 209 |
+
# Some registered losses:
|
| 210 |
+
# "DensePoseChartLoss": loss for chart-based models that estimate
|
| 211 |
+
# segmentation and UV coordinates
|
| 212 |
+
# "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
|
| 213 |
+
# segmentation, UV coordinates and the corresponding confidences (default)
|
| 214 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
|
| 215 |
+
# Confidences
|
| 216 |
+
# Enable learning UV confidences (variances) along with the actual values
|
| 217 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
|
| 218 |
+
# UV confidence lower bound
|
| 219 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
|
| 220 |
+
# Enable learning segmentation confidences (variances) along with the actual values
|
| 221 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
|
| 222 |
+
# Segmentation confidence lower bound
|
| 223 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
|
| 224 |
+
# Statistical model type for confidence learning, possible values:
|
| 225 |
+
# - "iid_iso": statistically independent identically distributed residuals
|
| 226 |
+
# with isotropic covariance
|
| 227 |
+
# - "indep_aniso": statistically independent residuals with anisotropic
|
| 228 |
+
# covariances
|
| 229 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
|
| 230 |
+
# List of angles for rotation in data augmentation during training
|
| 231 |
+
_C.INPUT.ROTATION_ANGLES = [0]
|
| 232 |
+
_C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA
|
| 233 |
+
|
| 234 |
+
add_densepose_head_cse_config(cfg)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def add_hrnet_config(cfg: CN) -> None:
|
| 238 |
+
"""
|
| 239 |
+
Add config for HRNet backbone.
|
| 240 |
+
"""
|
| 241 |
+
_C = cfg
|
| 242 |
+
|
| 243 |
+
# For HigherHRNet w32
|
| 244 |
+
_C.MODEL.HRNET = CN()
|
| 245 |
+
_C.MODEL.HRNET.STEM_INPLANES = 64
|
| 246 |
+
_C.MODEL.HRNET.STAGE2 = CN()
|
| 247 |
+
_C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
|
| 248 |
+
_C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
|
| 249 |
+
_C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
|
| 250 |
+
_C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
|
| 251 |
+
_C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
|
| 252 |
+
_C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
|
| 253 |
+
_C.MODEL.HRNET.STAGE3 = CN()
|
| 254 |
+
_C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
|
| 255 |
+
_C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
|
| 256 |
+
_C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
|
| 257 |
+
_C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
|
| 258 |
+
_C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
|
| 259 |
+
_C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
|
| 260 |
+
_C.MODEL.HRNET.STAGE4 = CN()
|
| 261 |
+
_C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
|
| 262 |
+
_C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
|
| 263 |
+
_C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
|
| 264 |
+
_C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
|
| 265 |
+
_C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
|
| 266 |
+
_C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
|
| 267 |
+
|
| 268 |
+
_C.MODEL.HRNET.HRFPN = CN()
|
| 269 |
+
_C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def add_densepose_config(cfg: CN) -> None:
|
| 273 |
+
add_densepose_head_config(cfg)
|
| 274 |
+
add_hrnet_config(cfg)
|
| 275 |
+
add_bootstrap_config(cfg)
|
| 276 |
+
add_dataset_category_config(cfg)
|
| 277 |
+
add_evaluation_config(cfg)
|
CatVTON/densepose/converters/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .hflip import HFlipConverter
|
| 6 |
+
from .to_mask import ToMaskConverter
|
| 7 |
+
from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
|
| 8 |
+
from .segm_to_mask import (
|
| 9 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
| 10 |
+
predictor_output_with_coarse_segm_to_mask,
|
| 11 |
+
resample_fine_and_coarse_segm_to_bbox,
|
| 12 |
+
)
|
| 13 |
+
from .chart_output_to_chart_result import (
|
| 14 |
+
densepose_chart_predictor_output_to_result,
|
| 15 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
| 16 |
+
)
|
| 17 |
+
from .chart_output_hflip import densepose_chart_predictor_output_hflip
|
CatVTON/densepose/converters/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (799 Bytes). View file
|
|
|
CatVTON/densepose/converters/__pycache__/base.cpython-39.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/builtin.cpython-39.pyc
ADDED
|
Binary file (804 Bytes). View file
|
|
|
CatVTON/densepose/converters/__pycache__/chart_output_hflip.cpython-39.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/chart_output_to_chart_result.cpython-39.pyc
ADDED
|
Binary file (6.03 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/hflip.cpython-39.pyc
ADDED
|
Binary file (1.35 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/segm_to_mask.cpython-39.pyc
ADDED
|
Binary file (5.75 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/to_chart_result.cpython-39.pyc
ADDED
|
Binary file (2.74 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/to_mask.cpython-39.pyc
ADDED
|
Binary file (1.76 kB). View file
|
|
|
CatVTON/densepose/converters/base.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Tuple, Type
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseConverter:
|
| 10 |
+
"""
|
| 11 |
+
Converter base class to be reused by various converters.
|
| 12 |
+
Converter allows one to convert data from various source types to a particular
|
| 13 |
+
destination type. Each source type needs to register its converter. The
|
| 14 |
+
registration for each source type is valid for all descendants of that type.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
@classmethod
|
| 18 |
+
def register(cls, from_type: Type, converter: Any = None):
|
| 19 |
+
"""
|
| 20 |
+
Registers a converter for the specified type.
|
| 21 |
+
Can be used as a decorator (if converter is None), or called as a method.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
from_type (type): type to register the converter for;
|
| 25 |
+
all instances of this type will use the same converter
|
| 26 |
+
converter (callable): converter to be registered for the given
|
| 27 |
+
type; if None, this method is assumed to be a decorator for the converter
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
if converter is not None:
|
| 31 |
+
cls._do_register(from_type, converter)
|
| 32 |
+
|
| 33 |
+
def wrapper(converter: Any) -> Any:
|
| 34 |
+
cls._do_register(from_type, converter)
|
| 35 |
+
return converter
|
| 36 |
+
|
| 37 |
+
return wrapper
|
| 38 |
+
|
| 39 |
+
@classmethod
|
| 40 |
+
def _do_register(cls, from_type: Type, converter: Any):
|
| 41 |
+
cls.registry[from_type] = converter # pyre-ignore[16]
|
| 42 |
+
|
| 43 |
+
@classmethod
|
| 44 |
+
def _lookup_converter(cls, from_type: Type) -> Any:
|
| 45 |
+
"""
|
| 46 |
+
Perform recursive lookup for the given type
|
| 47 |
+
to find registered converter. If a converter was found for some base
|
| 48 |
+
class, it gets registered for this class to save on further lookups.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
from_type: type for which to find a converter
|
| 52 |
+
Return:
|
| 53 |
+
callable or None - registered converter or None
|
| 54 |
+
if no suitable entry was found in the registry
|
| 55 |
+
"""
|
| 56 |
+
if from_type in cls.registry: # pyre-ignore[16]
|
| 57 |
+
return cls.registry[from_type]
|
| 58 |
+
for base in from_type.__bases__:
|
| 59 |
+
converter = cls._lookup_converter(base)
|
| 60 |
+
if converter is not None:
|
| 61 |
+
cls._do_register(from_type, converter)
|
| 62 |
+
return converter
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def convert(cls, instance: Any, *args, **kwargs):
|
| 67 |
+
"""
|
| 68 |
+
Convert an instance to the destination type using some registered
|
| 69 |
+
converter. Does recursive lookup for base classes, so there's no need
|
| 70 |
+
for explicit registration for derived classes.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
instance: source instance to convert to the destination type
|
| 74 |
+
Return:
|
| 75 |
+
An instance of the destination type obtained from the source instance
|
| 76 |
+
Raises KeyError, if no suitable converter found
|
| 77 |
+
"""
|
| 78 |
+
instance_type = type(instance)
|
| 79 |
+
converter = cls._lookup_converter(instance_type)
|
| 80 |
+
if converter is None:
|
| 81 |
+
if cls.dst_type is None: # pyre-ignore[16]
|
| 82 |
+
output_type_str = "itself"
|
| 83 |
+
else:
|
| 84 |
+
output_type_str = cls.dst_type
|
| 85 |
+
raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
|
| 86 |
+
return converter(instance, *args, **kwargs)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
IntTupleBox = Tuple[int, int, int, int]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def make_int_box(box: torch.Tensor) -> IntTupleBox:
|
| 93 |
+
int_box = [0, 0, 0, 0]
|
| 94 |
+
int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
|
| 95 |
+
return int_box[0], int_box[1], int_box[2], int_box[3]
|
CatVTON/densepose/converters/builtin.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
|
| 6 |
+
from . import (
|
| 7 |
+
HFlipConverter,
|
| 8 |
+
ToChartResultConverter,
|
| 9 |
+
ToChartResultConverterWithConfidences,
|
| 10 |
+
ToMaskConverter,
|
| 11 |
+
densepose_chart_predictor_output_hflip,
|
| 12 |
+
densepose_chart_predictor_output_to_result,
|
| 13 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
| 14 |
+
predictor_output_with_coarse_segm_to_mask,
|
| 15 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
ToMaskConverter.register(
|
| 19 |
+
DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
|
| 20 |
+
)
|
| 21 |
+
ToMaskConverter.register(
|
| 22 |
+
DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
ToChartResultConverter.register(
|
| 26 |
+
DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
ToChartResultConverterWithConfidences.register(
|
| 30 |
+
DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
|
CatVTON/densepose/converters/chart_output_hflip.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from dataclasses import fields
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def densepose_chart_predictor_output_hflip(
|
| 11 |
+
densepose_predictor_output: DensePoseChartPredictorOutput,
|
| 12 |
+
transform_data: DensePoseTransformData,
|
| 13 |
+
) -> DensePoseChartPredictorOutput:
|
| 14 |
+
"""
|
| 15 |
+
Change to take into account a Horizontal flip.
|
| 16 |
+
"""
|
| 17 |
+
if len(densepose_predictor_output) > 0:
|
| 18 |
+
|
| 19 |
+
PredictorOutput = type(densepose_predictor_output)
|
| 20 |
+
output_dict = {}
|
| 21 |
+
|
| 22 |
+
for field in fields(densepose_predictor_output):
|
| 23 |
+
field_value = getattr(densepose_predictor_output, field.name)
|
| 24 |
+
# flip tensors
|
| 25 |
+
if isinstance(field_value, torch.Tensor):
|
| 26 |
+
setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
|
| 27 |
+
|
| 28 |
+
densepose_predictor_output = _flip_iuv_semantics_tensor(
|
| 29 |
+
densepose_predictor_output, transform_data
|
| 30 |
+
)
|
| 31 |
+
densepose_predictor_output = _flip_segm_semantics_tensor(
|
| 32 |
+
densepose_predictor_output, transform_data
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
for field in fields(densepose_predictor_output):
|
| 36 |
+
output_dict[field.name] = getattr(densepose_predictor_output, field.name)
|
| 37 |
+
|
| 38 |
+
return PredictorOutput(**output_dict)
|
| 39 |
+
else:
|
| 40 |
+
return densepose_predictor_output
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _flip_iuv_semantics_tensor(
|
| 44 |
+
densepose_predictor_output: DensePoseChartPredictorOutput,
|
| 45 |
+
dp_transform_data: DensePoseTransformData,
|
| 46 |
+
) -> DensePoseChartPredictorOutput:
|
| 47 |
+
point_label_symmetries = dp_transform_data.point_label_symmetries
|
| 48 |
+
uv_symmetries = dp_transform_data.uv_symmetries
|
| 49 |
+
|
| 50 |
+
N, C, H, W = densepose_predictor_output.u.shape
|
| 51 |
+
u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
|
| 52 |
+
v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
|
| 53 |
+
Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
|
| 54 |
+
None, :, None, None
|
| 55 |
+
].expand(N, C - 1, H, W)
|
| 56 |
+
densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
|
| 57 |
+
densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
|
| 58 |
+
|
| 59 |
+
for el in ["fine_segm", "u", "v"]:
|
| 60 |
+
densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
|
| 61 |
+
:, point_label_symmetries, :, :
|
| 62 |
+
]
|
| 63 |
+
return densepose_predictor_output
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _flip_segm_semantics_tensor(
|
| 67 |
+
densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
|
| 68 |
+
):
|
| 69 |
+
if densepose_predictor_output.coarse_segm.shape[1] > 2:
|
| 70 |
+
densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
|
| 71 |
+
:, dp_transform_data.mask_label_symmetries, :, :
|
| 72 |
+
]
|
| 73 |
+
return densepose_predictor_output
|
CatVTON/densepose/converters/chart_output_to_chart_result.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Dict
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.structures.boxes import Boxes, BoxMode
|
| 10 |
+
|
| 11 |
+
from ..structures import (
|
| 12 |
+
DensePoseChartPredictorOutput,
|
| 13 |
+
DensePoseChartResult,
|
| 14 |
+
DensePoseChartResultWithConfidences,
|
| 15 |
+
)
|
| 16 |
+
from . import resample_fine_and_coarse_segm_to_bbox
|
| 17 |
+
from .base import IntTupleBox, make_int_box
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def resample_uv_tensors_to_bbox(
|
| 21 |
+
u: torch.Tensor,
|
| 22 |
+
v: torch.Tensor,
|
| 23 |
+
labels: torch.Tensor,
|
| 24 |
+
box_xywh_abs: IntTupleBox,
|
| 25 |
+
) -> torch.Tensor:
|
| 26 |
+
"""
|
| 27 |
+
Resamples U and V coordinate estimates for the given bounding box
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
u (tensor [1, C, H, W] of float): U coordinates
|
| 31 |
+
v (tensor [1, C, H, W] of float): V coordinates
|
| 32 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
| 33 |
+
outputs for the given bounding box
|
| 34 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
| 35 |
+
Return:
|
| 36 |
+
Resampled U and V coordinates - a tensor [2, H, W] of float
|
| 37 |
+
"""
|
| 38 |
+
x, y, w, h = box_xywh_abs
|
| 39 |
+
w = max(int(w), 1)
|
| 40 |
+
h = max(int(h), 1)
|
| 41 |
+
u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
|
| 42 |
+
v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
|
| 43 |
+
uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
|
| 44 |
+
for part_id in range(1, u_bbox.size(1)):
|
| 45 |
+
uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
|
| 46 |
+
uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
|
| 47 |
+
return uv
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def resample_uv_to_bbox(
|
| 51 |
+
predictor_output: DensePoseChartPredictorOutput,
|
| 52 |
+
labels: torch.Tensor,
|
| 53 |
+
box_xywh_abs: IntTupleBox,
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
"""
|
| 56 |
+
Resamples U and V coordinate estimates for the given bounding box
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 60 |
+
output to be resampled
|
| 61 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
| 62 |
+
outputs for the given bounding box
|
| 63 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
| 64 |
+
Return:
|
| 65 |
+
Resampled U and V coordinates - a tensor [2, H, W] of float
|
| 66 |
+
"""
|
| 67 |
+
return resample_uv_tensors_to_bbox(
|
| 68 |
+
predictor_output.u,
|
| 69 |
+
predictor_output.v,
|
| 70 |
+
labels,
|
| 71 |
+
box_xywh_abs,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def densepose_chart_predictor_output_to_result(
|
| 76 |
+
predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
|
| 77 |
+
) -> DensePoseChartResult:
|
| 78 |
+
"""
|
| 79 |
+
Convert densepose chart predictor outputs to results
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 83 |
+
output to be converted to results, must contain only 1 output
|
| 84 |
+
boxes (Boxes): bounding box that corresponds to the predictor output,
|
| 85 |
+
must contain only 1 bounding box
|
| 86 |
+
Return:
|
| 87 |
+
DensePose chart-based result (DensePoseChartResult)
|
| 88 |
+
"""
|
| 89 |
+
assert len(predictor_output) == 1 and len(boxes) == 1, (
|
| 90 |
+
f"Predictor output to result conversion can operate only single outputs"
|
| 91 |
+
f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 95 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 96 |
+
box_xywh = make_int_box(boxes_xywh_abs[0])
|
| 97 |
+
|
| 98 |
+
labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
|
| 99 |
+
uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
|
| 100 |
+
return DensePoseChartResult(labels=labels, uv=uv)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def resample_confidences_to_bbox(
|
| 104 |
+
predictor_output: DensePoseChartPredictorOutput,
|
| 105 |
+
labels: torch.Tensor,
|
| 106 |
+
box_xywh_abs: IntTupleBox,
|
| 107 |
+
) -> Dict[str, torch.Tensor]:
|
| 108 |
+
"""
|
| 109 |
+
Resamples confidences for the given bounding box
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 113 |
+
output to be resampled
|
| 114 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
| 115 |
+
outputs for the given bounding box
|
| 116 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
| 117 |
+
Return:
|
| 118 |
+
Resampled confidences - a dict of [H, W] tensors of float
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
x, y, w, h = box_xywh_abs
|
| 122 |
+
w = max(int(w), 1)
|
| 123 |
+
h = max(int(h), 1)
|
| 124 |
+
|
| 125 |
+
confidence_names = [
|
| 126 |
+
"sigma_1",
|
| 127 |
+
"sigma_2",
|
| 128 |
+
"kappa_u",
|
| 129 |
+
"kappa_v",
|
| 130 |
+
"fine_segm_confidence",
|
| 131 |
+
"coarse_segm_confidence",
|
| 132 |
+
]
|
| 133 |
+
confidence_results = {key: None for key in confidence_names}
|
| 134 |
+
confidence_names = [
|
| 135 |
+
key for key in confidence_names if getattr(predictor_output, key) is not None
|
| 136 |
+
]
|
| 137 |
+
confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
|
| 138 |
+
|
| 139 |
+
# assign data from channels that correspond to the labels
|
| 140 |
+
for key in confidence_names:
|
| 141 |
+
resampled_confidence = F.interpolate(
|
| 142 |
+
getattr(predictor_output, key),
|
| 143 |
+
(h, w),
|
| 144 |
+
mode="bilinear",
|
| 145 |
+
align_corners=False,
|
| 146 |
+
)
|
| 147 |
+
result = confidence_base.clone()
|
| 148 |
+
for part_id in range(1, predictor_output.u.size(1)):
|
| 149 |
+
if resampled_confidence.size(1) != predictor_output.u.size(1):
|
| 150 |
+
# confidence is not part-based, don't try to fill it part by part
|
| 151 |
+
continue
|
| 152 |
+
result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
|
| 153 |
+
|
| 154 |
+
if resampled_confidence.size(1) != predictor_output.u.size(1):
|
| 155 |
+
# confidence is not part-based, fill the data with the first channel
|
| 156 |
+
# (targeted for segmentation confidences that have only 1 channel)
|
| 157 |
+
result = resampled_confidence[0, 0]
|
| 158 |
+
|
| 159 |
+
confidence_results[key] = result
|
| 160 |
+
|
| 161 |
+
return confidence_results # pyre-ignore[7]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def densepose_chart_predictor_output_to_result_with_confidences(
|
| 165 |
+
predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
|
| 166 |
+
) -> DensePoseChartResultWithConfidences:
|
| 167 |
+
"""
|
| 168 |
+
Convert densepose chart predictor outputs to results
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 172 |
+
output with confidences to be converted to results, must contain only 1 output
|
| 173 |
+
boxes (Boxes): bounding box that corresponds to the predictor output,
|
| 174 |
+
must contain only 1 bounding box
|
| 175 |
+
Return:
|
| 176 |
+
DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
|
| 177 |
+
"""
|
| 178 |
+
assert len(predictor_output) == 1 and len(boxes) == 1, (
|
| 179 |
+
f"Predictor output to result conversion can operate only single outputs"
|
| 180 |
+
f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 184 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 185 |
+
box_xywh = make_int_box(boxes_xywh_abs[0])
|
| 186 |
+
|
| 187 |
+
labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
|
| 188 |
+
uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
|
| 189 |
+
confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
|
| 190 |
+
return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
|
CatVTON/densepose/converters/hflip.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from .base import BaseConverter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HFlipConverter(BaseConverter):
|
| 11 |
+
"""
|
| 12 |
+
Converts various DensePose predictor outputs to DensePose results.
|
| 13 |
+
Each DensePose predictor output type has to register its convertion strategy.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
registry = {}
|
| 17 |
+
dst_type = None
|
| 18 |
+
|
| 19 |
+
@classmethod
|
| 20 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
| 21 |
+
# inconsistently.
|
| 22 |
+
def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
|
| 23 |
+
"""
|
| 24 |
+
Performs an horizontal flip on DensePose predictor outputs.
|
| 25 |
+
Does recursive lookup for base classes, so there's no need
|
| 26 |
+
for explicit registration for derived classes.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
predictor_outputs: DensePose predictor output to be converted to BitMasks
|
| 30 |
+
transform_data: Anything useful for the flip
|
| 31 |
+
Return:
|
| 32 |
+
An instance of the same type as predictor_outputs
|
| 33 |
+
"""
|
| 34 |
+
return super(HFlipConverter, cls).convert(
|
| 35 |
+
predictor_outputs, transform_data, *args, **kwargs
|
| 36 |
+
)
|
CatVTON/densepose/converters/segm_to_mask.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.structures import BitMasks, Boxes, BoxMode
|
| 10 |
+
|
| 11 |
+
from .base import IntTupleBox, make_int_box
|
| 12 |
+
from .to_mask import ImageSizeType
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
|
| 16 |
+
"""
|
| 17 |
+
Resample coarse segmentation tensor to the given
|
| 18 |
+
bounding box and derive labels for each pixel of the bounding box
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
coarse_segm: float tensor of shape [1, K, Hout, Wout]
|
| 22 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
| 23 |
+
corner coordinates, width (W) and height (H)
|
| 24 |
+
Return:
|
| 25 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
| 26 |
+
"""
|
| 27 |
+
x, y, w, h = box_xywh_abs
|
| 28 |
+
w = max(int(w), 1)
|
| 29 |
+
h = max(int(h), 1)
|
| 30 |
+
labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
|
| 31 |
+
return labels
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def resample_fine_and_coarse_segm_tensors_to_bbox(
|
| 35 |
+
fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Resample fine and coarse segmentation tensors to the given
|
| 39 |
+
bounding box and derive labels for each pixel of the bounding box
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
fine_segm: float tensor of shape [1, C, Hout, Wout]
|
| 43 |
+
coarse_segm: float tensor of shape [1, K, Hout, Wout]
|
| 44 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
| 45 |
+
corner coordinates, width (W) and height (H)
|
| 46 |
+
Return:
|
| 47 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
| 48 |
+
"""
|
| 49 |
+
x, y, w, h = box_xywh_abs
|
| 50 |
+
w = max(int(w), 1)
|
| 51 |
+
h = max(int(h), 1)
|
| 52 |
+
# coarse segmentation
|
| 53 |
+
coarse_segm_bbox = F.interpolate(
|
| 54 |
+
coarse_segm,
|
| 55 |
+
(h, w),
|
| 56 |
+
mode="bilinear",
|
| 57 |
+
align_corners=False,
|
| 58 |
+
).argmax(dim=1)
|
| 59 |
+
# combined coarse and fine segmentation
|
| 60 |
+
labels = (
|
| 61 |
+
F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
|
| 62 |
+
* (coarse_segm_bbox > 0).long()
|
| 63 |
+
)
|
| 64 |
+
return labels
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
|
| 68 |
+
"""
|
| 69 |
+
Resample fine and coarse segmentation outputs from a predictor to the given
|
| 70 |
+
bounding box and derive labels for each pixel of the bounding box
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
predictor_output: DensePose predictor output that contains segmentation
|
| 74 |
+
results to be resampled
|
| 75 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
| 76 |
+
corner coordinates, width (W) and height (H)
|
| 77 |
+
Return:
|
| 78 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
| 79 |
+
"""
|
| 80 |
+
return resample_fine_and_coarse_segm_tensors_to_bbox(
|
| 81 |
+
predictor_output.fine_segm,
|
| 82 |
+
predictor_output.coarse_segm,
|
| 83 |
+
box_xywh_abs,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def predictor_output_with_coarse_segm_to_mask(
|
| 88 |
+
predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
|
| 89 |
+
) -> BitMasks:
|
| 90 |
+
"""
|
| 91 |
+
Convert predictor output with coarse and fine segmentation to a mask.
|
| 92 |
+
Assumes that predictor output has the following attributes:
|
| 93 |
+
- coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
|
| 94 |
+
unnormalized scores for N instances; D is the number of coarse
|
| 95 |
+
segmentation labels, H and W is the resolution of the estimate
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
predictor_output: DensePose predictor output to be converted to mask
|
| 99 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
| 100 |
+
predictor outputs
|
| 101 |
+
image_size_hw (tuple [int, int]): image height Himg and width Wimg
|
| 102 |
+
Return:
|
| 103 |
+
BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
|
| 104 |
+
a mask of the size of the image for each instance
|
| 105 |
+
"""
|
| 106 |
+
H, W = image_size_hw
|
| 107 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 108 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 109 |
+
N = len(boxes_xywh_abs)
|
| 110 |
+
masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
|
| 111 |
+
for i in range(len(boxes_xywh_abs)):
|
| 112 |
+
box_xywh = make_int_box(boxes_xywh_abs[i])
|
| 113 |
+
box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
|
| 114 |
+
x, y, w, h = box_xywh
|
| 115 |
+
masks[i, y : y + h, x : x + w] = box_mask
|
| 116 |
+
|
| 117 |
+
return BitMasks(masks)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def predictor_output_with_fine_and_coarse_segm_to_mask(
|
| 121 |
+
predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
|
| 122 |
+
) -> BitMasks:
|
| 123 |
+
"""
|
| 124 |
+
Convert predictor output with coarse and fine segmentation to a mask.
|
| 125 |
+
Assumes that predictor output has the following attributes:
|
| 126 |
+
- coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
|
| 127 |
+
unnormalized scores for N instances; D is the number of coarse
|
| 128 |
+
segmentation labels, H and W is the resolution of the estimate
|
| 129 |
+
- fine_segm (tensor of size [N, C, H, W]): fine segmentation
|
| 130 |
+
unnormalized scores for N instances; C is the number of fine
|
| 131 |
+
segmentation labels, H and W is the resolution of the estimate
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
predictor_output: DensePose predictor output to be converted to mask
|
| 135 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
| 136 |
+
predictor outputs
|
| 137 |
+
image_size_hw (tuple [int, int]): image height Himg and width Wimg
|
| 138 |
+
Return:
|
| 139 |
+
BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
|
| 140 |
+
a mask of the size of the image for each instance
|
| 141 |
+
"""
|
| 142 |
+
H, W = image_size_hw
|
| 143 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 144 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 145 |
+
N = len(boxes_xywh_abs)
|
| 146 |
+
masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
|
| 147 |
+
for i in range(len(boxes_xywh_abs)):
|
| 148 |
+
box_xywh = make_int_box(boxes_xywh_abs[i])
|
| 149 |
+
labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
|
| 150 |
+
x, y, w, h = box_xywh
|
| 151 |
+
masks[i, y : y + h, x : x + w] = labels_i > 0
|
| 152 |
+
return BitMasks(masks)
|
CatVTON/densepose/converters/to_chart_result.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from detectron2.structures import Boxes
|
| 8 |
+
|
| 9 |
+
from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
|
| 10 |
+
from .base import BaseConverter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ToChartResultConverter(BaseConverter):
    """
    Dispatches conversion of DensePose predictor outputs to
    `DensePoseChartResult`. Each predictor output type registers its own
    conversion strategy in `registry`.
    """

    registry = {}
    dst_type = DensePoseChartResult

    @classmethod
    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
    #  inconsistently.
    def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
        """
        Convert DensePose predictor outputs to a `DensePoseChartResult` via a
        registered converter. Lookup walks the output type's base classes, so
        derived types need no explicit registration.

        Args:
            predictor_outputs: DensePose predictor output to be converted
            boxes (Boxes): bounding boxes that correspond to the DensePose
                predictor outputs
        Return:
            An instance of DensePoseChartResult. If no suitable converter was
            found, raises KeyError
        """
        return super().convert(predictor_outputs, boxes, *args, **kwargs)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ToChartResultConverterWithConfidences(BaseConverter):
    """
    Dispatches conversion of DensePose predictor outputs to
    `DensePoseChartResultWithConfidences`. Each predictor output type
    registers its own conversion strategy in `registry`.
    """

    registry = {}
    dst_type = DensePoseChartResultWithConfidences

    @classmethod
    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
    #  inconsistently.
    def convert(
        cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
    ) -> DensePoseChartResultWithConfidences:
        """
        Convert DensePose predictor outputs to a
        `DensePoseChartResultWithConfidences` via a registered converter.
        Lookup walks the output type's base classes, so derived types need no
        explicit registration.

        Args:
            predictor_outputs: DensePose predictor output with confidences
                to be converted
            boxes (Boxes): bounding boxes that correspond to the DensePose
                predictor outputs
        Return:
            An instance of DensePoseChartResultWithConfidences. If no suitable
            converter was found, raises KeyError
        """
        return super().convert(predictor_outputs, boxes, *args, **kwargs)
|
CatVTON/densepose/converters/to_mask.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Tuple
|
| 6 |
+
|
| 7 |
+
from detectron2.structures import BitMasks, Boxes
|
| 8 |
+
|
| 9 |
+
from .base import BaseConverter
|
| 10 |
+
|
| 11 |
+
ImageSizeType = Tuple[int, int]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ToMaskConverter(BaseConverter):
    """
    Dispatches conversion of DensePose predictor outputs to bit masks
    (see `BitMasks`). Each predictor output type registers its own
    conversion strategy in `registry`.
    """

    registry = {}
    dst_type = BitMasks

    @classmethod
    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
    #  inconsistently.
    def convert(
        cls,
        densepose_predictor_outputs: Any,
        boxes: Boxes,
        image_size_hw: ImageSizeType,
        *args,
        **kwargs
    ) -> BitMasks:
        """
        Convert DensePose predictor outputs to `BitMasks` via a registered
        converter. Lookup walks the output type's base classes, so derived
        types need no explicit registration.

        Args:
            densepose_predictor_outputs: DensePose predictor output to be
                converted to BitMasks
            boxes (Boxes): bounding boxes that correspond to the DensePose
                predictor outputs
            image_size_hw (tuple [int, int]): image height and width
        Return:
            An instance of `BitMasks`. If no suitable converter was found,
            raises KeyError
        """
        return super().convert(
            densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
        )
|
CatVTON/densepose/engine/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .trainer import Trainer
|
CatVTON/densepose/engine/trainer.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from typing import List, Optional, Union
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
| 13 |
+
from detectron2.config import CfgNode
|
| 14 |
+
from detectron2.engine import DefaultTrainer
|
| 15 |
+
from detectron2.evaluation import (
|
| 16 |
+
DatasetEvaluator,
|
| 17 |
+
DatasetEvaluators,
|
| 18 |
+
inference_on_dataset,
|
| 19 |
+
print_csv_format,
|
| 20 |
+
)
|
| 21 |
+
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
|
| 22 |
+
from detectron2.utils import comm
|
| 23 |
+
from detectron2.utils.events import EventWriter, get_event_storage
|
| 24 |
+
|
| 25 |
+
from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
|
| 26 |
+
from densepose.data import (
|
| 27 |
+
DatasetMapper,
|
| 28 |
+
build_combined_loader,
|
| 29 |
+
build_detection_test_loader,
|
| 30 |
+
build_detection_train_loader,
|
| 31 |
+
build_inference_based_loaders,
|
| 32 |
+
has_inference_based_loaders,
|
| 33 |
+
)
|
| 34 |
+
from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
|
| 35 |
+
from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
|
| 36 |
+
from densepose.modeling.cse import Embedder
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SampleCountingLoader:
    """
    Wraps a data loader; while iterating, logs the number of instances seen
    per dataset in each batch to the detectron2 event storage under
    "batch/<dataset_name>".
    """

    def __init__(self, loader):
        self.loader = loader

    def __iter__(self):
        storage = get_event_storage()
        for batch in self.loader:
            # tally instances per dataset for this batch
            counts = {}
            for sample in batch:
                dataset_name = sample["dataset"]
                counts[dataset_name] = counts.get(dataset_name, 0) + len(sample["instances"])
            for dataset_name, count in counts.items():
                storage.put_scalar(f"batch/{dataset_name}", count)
            yield batch
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class SampleCountMetricPrinter(EventWriter):
    """
    Event writer that logs rolling averages (window of 20) of the
    "batch/<dataset_name>" sample-count scalars recorded by
    `SampleCountingLoader`.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def write(self):
        storage = get_event_storage()
        batch_stats_strs = [
            f"{key} {buf.avg(20)}"
            for key, buf in storage.histories().items()
            if key.startswith("batch/")
        ]
        self.logger.info(", ".join(batch_stats_strs))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Trainer(DefaultTrainer):
    """
    DensePose trainer. Extends detectron2's `DefaultTrainer` with:
    - DensePose-specific evaluators (COCO adapter + DensePose COCO evaluator),
    - passing the CSE embedder from the model to the evaluator,
    - bootstrapped (inference-based) training data loaders with per-dataset
      sample counting,
    - test-time-augmentation evaluation.
    """

    @classmethod
    def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
        """
        Return the CSE embedder attached to the model's ROI heads, if any.
        Unwraps `DistributedDataParallel` before looking for
        `roi_heads.embedder`; returns None when the model has no embedder.
        """
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
            return model.roi_heads.embedder
        return None

    # TODO: the only reason to copy the base class code here is to pass the embedder from
    # the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting
    @classmethod
    def test(
        cls,
        cfg: CfgNode,
        model: nn.Module,
        evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
    ):
        """
        Run evaluation on all datasets in ``cfg.DATASETS.TEST``.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    # pass the embedder along so CSE metrics can be computed
                    embedder = cls.extract_embedder_from_model(model)
                    evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
                except NotImplementedError:
                    # fix: `Logger.warn` is a deprecated alias of `Logger.warning`
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            # non-main processes skip inference unless distributed inference is on
            if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
                results_i = inference_on_dataset(model, data_loader, evaluator)
            else:
                results_i = {}
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @classmethod
    def build_evaluator(
        cls,
        cfg: CfgNode,
        dataset_name: str,
        output_folder: Optional[str] = None,
        embedder: Optional[Embedder] = None,
    ) -> DatasetEvaluators:
        """
        Build the evaluator stack for ``dataset_name``: a COCO bbox evaluator
        adapter plus, when DensePose is enabled, a `DensePoseCOCOEvaluator`.

        Args:
            cfg (CfgNode): configuration options
            dataset_name (str): registered dataset name
            output_folder (str or None): where to store evaluation artifacts;
                defaults to ``<OUTPUT_DIR>/inference``
            embedder (Embedder or None): CSE embedder used for mesh-alignment metrics
        Return:
            A `DatasetEvaluators` combining all applicable evaluators
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluators = []
        distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
        # Note: we currently use COCO evaluator for both COCO and LVIS datasets
        # to have compatible metrics. LVIS bbox evaluator could also be used
        # with an adapter to properly handle filtered / mapped categories
        # evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        # if evaluator_type == "coco":
        #     evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        # elif evaluator_type == "lvis":
        #     evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder))
        evaluators.append(
            Detectron2COCOEvaluatorAdapter(
                dataset_name, output_dir=output_folder, distributed=distributed
            )
        )
        if cfg.MODEL.DENSEPOSE_ON:
            storage = build_densepose_evaluator_storage(cfg, output_folder)
            evaluators.append(
                DensePoseCOCOEvaluator(
                    dataset_name,
                    distributed,
                    output_folder,
                    evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
                    min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
                    storage=storage,
                    embedder=embedder,
                    should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
                    mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
                )
            )
        return DatasetEvaluators(evaluators)

    @classmethod
    def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
        """
        Build an SGD optimizer with per-parameter-group learning-rate overrides
        for CSE "features" and "embeddings" parameters, with optional gradient
        clipping per ``cfg.SOLVER``.
        """
        params = get_default_optimizer_params(
            model,
            base_lr=cfg.SOLVER.BASE_LR,
            weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
            bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
            weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
            overrides={
                "features": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
                },
                "embeddings": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
                },
            },
        )
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
        # pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`.
        return maybe_add_gradient_clipping(cfg, optimizer)

    @classmethod
    def build_test_loader(cls, cfg: CfgNode, dataset_name):
        """Build a test loader with the DensePose `DatasetMapper` (inference mode)."""
        return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))

    @classmethod
    def build_train_loader(cls, cfg: CfgNode):
        """
        Build the training loader. When inference-based (bootstrap) loaders are
        configured, combine them with the regular loader using the configured
        mixing ratios and wrap the result in a `SampleCountingLoader` so that
        per-dataset sample counts get logged.
        """
        data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
        if not has_inference_based_loaders(cfg):
            return data_loader
        # a separate (frozen) model generates pseudo-labels for bootstrap loaders
        model = cls.build_model(cfg)
        model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
        DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
        inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
        loaders = [data_loader] + inference_based_loaders
        ratios = [1.0] + ratios
        combined_data_loader = build_combined_loader(cfg, loaders, ratios)
        sample_counting_loader = SampleCountingLoader(combined_data_loader)
        return sample_counting_loader

    def build_writers(self):
        """Extend the default writers with the per-dataset sample-count printer."""
        writers = super().build_writers()
        writers.append(SampleCountMetricPrinter())
        return writers

    @classmethod
    def test_with_TTA(cls, cfg: CfgNode, model):
        """
        Evaluate ``model`` with test-time augmentation; result keys get a
        ``_TTA`` suffix. Only some R-CNN models are supported.
        """
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA
        # Only support some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        transform_data = load_from_cfg(cfg)
        model = DensePoseGeneralizedRCNNWithTTA(
            cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
        )
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)  # pyre-ignore[6]
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res
|
CatVTON/densepose/modeling/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .confidence import DensePoseConfidenceModelConfig, DensePoseUVConfidenceType
|
| 6 |
+
from .filter import DensePoseDataFilter
|
| 7 |
+
from .inference import densepose_inference
|
| 8 |
+
from .utils import initialize_module_params
|
| 9 |
+
from .build import (
|
| 10 |
+
build_densepose_data_filter,
|
| 11 |
+
build_densepose_embedder,
|
| 12 |
+
build_densepose_head,
|
| 13 |
+
build_densepose_losses,
|
| 14 |
+
build_densepose_predictor,
|
| 15 |
+
)
|
CatVTON/densepose/modeling/build.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from detectron2.config import CfgNode
|
| 9 |
+
|
| 10 |
+
from .cse.embedder import Embedder
|
| 11 |
+
from .filter import DensePoseDataFilter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_densepose_predictor(cfg: CfgNode, input_channels: int):
    """
    Create an instance of DensePose predictor based on configuration options.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of DensePose predictor
    """
    # imported lazily to avoid a circular dependency at module import time
    from .predictors import DENSEPOSE_PREDICTOR_REGISTRY

    predictor_cls = DENSEPOSE_PREDICTOR_REGISTRY.get(cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME)
    return predictor_cls(cfg, input_channels)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def build_densepose_data_filter(cfg: CfgNode):
    """
    Build DensePose data filter which selects data for training

    Args:
        cfg (CfgNode): configuration options

    Return:
        Callable: list(Tensor), list(Instances) -> list(Tensor), list(Instances)
        An instance of DensePose filter, which takes feature tensors and proposals
        as an input and returns filtered features and proposals
    """
    return DensePoseDataFilter(cfg)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_densepose_head(cfg: CfgNode, input_channels: int):
    """
    Build DensePose head based on configuration options.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of DensePose head
    """
    # imported lazily to avoid a circular dependency at module import time
    from .roi_heads.registry import ROI_DENSEPOSE_HEAD_REGISTRY

    head_cls = ROI_DENSEPOSE_HEAD_REGISTRY.get(cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME)
    return head_cls(cfg, input_channels)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def build_densepose_losses(cfg: CfgNode):
    """
    Build DensePose loss based on configuration options.

    Args:
        cfg (CfgNode): configuration options
    Return:
        An instance of DensePose loss
    """
    # imported lazily to avoid a circular dependency at module import time
    from .losses import DENSEPOSE_LOSS_REGISTRY

    loss_cls = DENSEPOSE_LOSS_REGISTRY.get(cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME)
    return loss_cls(cfg)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_densepose_embedder(cfg: CfgNode) -> Optional[nn.Module]:
    """
    Build embedder used to embed mesh vertices into an embedding space.
    Embedder contains sub-embedders, one for each mesh ID.

    Args:
        cfg (cfgNode): configuration options
    Return:
        Embedding module, or None when no CSE embedders are configured
    """
    return Embedder(cfg) if cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS else None
|
CatVTON/densepose/modeling/confidence.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from enum import Enum
|
| 7 |
+
|
| 8 |
+
from detectron2.config import CfgNode
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DensePoseUVConfidenceType(Enum):
    """
    Statistical model type used for UV confidence learning. Possible values:
    - "iid_iso": statistically independent identically distributed residuals
    - "indep_aniso": statistically independent residuals with anisotropic
      covariances
    For details, see:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
    Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    """

    IID_ISO = "iid_iso"
    INDEP_ANISO = "indep_aniso"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class DensePoseUVConfidenceConfig:
    """
    Configuration options for confidence on UV data.
    """

    # whether UV confidence estimation is enabled
    enabled: bool = False
    # lower bound on UV confidences
    epsilon: float = 0.01
    # statistical model used for the UV confidences
    type: DensePoseUVConfidenceType = DensePoseUVConfidenceType.IID_ISO
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class DensePoseSegmConfidenceConfig:
    """
    Configuration options for confidence on segmentation.
    """

    # whether segmentation confidence estimation is enabled
    enabled: bool = False
    # lower bound on confidence values
    epsilon: float = 0.01
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
class DensePoseConfidenceModelConfig:
    """
    Configuration options for confidence models.
    """

    # confidence for U and V values
    uv_confidence: DensePoseUVConfidenceConfig
    # segmentation confidence
    segm_confidence: DensePoseSegmConfidenceConfig

    @staticmethod
    def from_cfg(cfg: CfgNode) -> "DensePoseConfidenceModelConfig":
        """Assemble the confidence configuration from detectron2 config options."""
        head = cfg.MODEL.ROI_DENSEPOSE_HEAD
        uv = DensePoseUVConfidenceConfig(
            enabled=head.UV_CONFIDENCE.ENABLED,
            epsilon=head.UV_CONFIDENCE.EPSILON,
            type=DensePoseUVConfidenceType(head.UV_CONFIDENCE.TYPE),
        )
        segm = DensePoseSegmConfidenceConfig(
            enabled=head.SEGM_CONFIDENCE.ENABLED,
            epsilon=head.SEGM_CONFIDENCE.EPSILON,
        )
        return DensePoseConfidenceModelConfig(uv_confidence=uv, segm_confidence=segm)
|
CatVTON/densepose/modeling/densepose_checkpoint.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _rename_HRNet_weights(weights):
|
| 10 |
+
# We detect and rename HRNet weights for DensePose. 1956 and 1716 are values that are
|
| 11 |
+
# common to all HRNet pretrained weights, and should be enough to accurately identify them
|
| 12 |
+
if (
|
| 13 |
+
len(weights["model"].keys()) == 1956
|
| 14 |
+
and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716
|
| 15 |
+
):
|
| 16 |
+
hrnet_weights = OrderedDict()
|
| 17 |
+
for k in weights["model"].keys():
|
| 18 |
+
hrnet_weights["backbone.bottom_up." + str(k)] = weights["model"][k]
|
| 19 |
+
return {"model": hrnet_weights}
|
| 20 |
+
else:
|
| 21 |
+
return weights
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DensePoseCheckpointer(DetectionCheckpointer):
    """
    Same as :class:`DetectionCheckpointer`, but is able to handle HRNet weights:
    checkpoint files loaded through :meth:`_load_file` are passed through
    `_rename_HRNet_weights`, which prefixes HRNet keys so they match the
    DensePose backbone layout.
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)

    def _load_file(self, filename: str) -> object:
        """
        Load a checkpoint via the base class, then rename HRNet keys if the
        file is detected to contain HRNet-pretrained weights.
        """
        weights = super()._load_file(filename)
        return _rename_HRNet_weights(weights)
|
CatVTON/densepose/modeling/filter.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import List
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from detectron2.config import CfgNode
|
| 9 |
+
from detectron2.structures import Instances
|
| 10 |
+
from detectron2.structures.boxes import matched_pairwise_iou
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DensePoseDataFilter:
    """
    Selects proposals suitable for DensePose training: keeps only proposals
    whose IoU with the matched GT box exceeds a configured threshold and which
    carry DensePose (or, optionally, mask) ground-truth annotations.
    """

    def __init__(self, cfg: CfgNode):
        # minimum IoU between a proposal and its GT box for the proposal to be kept
        self.iou_threshold = cfg.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD
        # when True, proposals with only mask GT (no DensePose GT) are also kept
        self.keep_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS

    @torch.no_grad()
    def __call__(self, features: List[torch.Tensor], proposals_with_targets: List[Instances]):
        """
        Filters proposals with targets to keep only the ones relevant for
        DensePose training

        Args:
            features (list[Tensor]): input data as a list of features,
                each feature is a tensor. Axis 0 represents the number of
                images `N` in the input data; axes 1-3 are channels,
                height, and width, which may vary between features
                (e.g., if a feature pyramid is used).
            proposals_with_targets (list[Instances]): length `N` list of
                `Instances`. The i-th `Instances` contains instances
                (proposals, GT) for the i-th input image,
        Returns:
            list[Tensor]: filtered features (currently returned unchanged,
                see the TODO note below)
            list[Instances]: filtered proposals
        """
        proposals_filtered = []
        # TODO: the commented out code was supposed to correctly deal with situations
        # where no valid DensePose GT is available for certain images. The corresponding
        # image features were sliced and proposals were filtered. This led to performance
        # deterioration, both in terms of runtime and in terms of evaluation results.
        #
        # feature_mask = torch.ones(
        #    len(proposals_with_targets),
        #    dtype=torch.bool,
        #    device=features[0].device if len(features) > 0 else torch.device("cpu"),
        # )
        for i, proposals_per_image in enumerate(proposals_with_targets):
            # skip images that carry neither DensePose GT nor (if enabled) mask GT
            if not proposals_per_image.has("gt_densepose") and (
                not proposals_per_image.has("gt_masks") or not self.keep_masks
            ):
                # feature_mask[i] = 0
                continue
            gt_boxes = proposals_per_image.gt_boxes
            est_boxes = proposals_per_image.proposal_boxes
            # apply match threshold for densepose head
            iou = matched_pairwise_iou(gt_boxes, est_boxes)
            iou_select = iou > self.iou_threshold
            proposals_per_image = proposals_per_image[iou_select]  # pyre-ignore[6]

            N_gt_boxes = len(proposals_per_image.gt_boxes)
            assert N_gt_boxes == len(proposals_per_image.proposal_boxes), (
                f"The number of GT boxes {N_gt_boxes} is different from the "
                f"number of proposal boxes {len(proposals_per_image.proposal_boxes)}"
            )
            # filter out any target without suitable annotation
            if self.keep_masks:
                gt_masks = (
                    proposals_per_image.gt_masks
                    if hasattr(proposals_per_image, "gt_masks")
                    else [None] * N_gt_boxes
                )
            else:
                gt_masks = [None] * N_gt_boxes
            gt_densepose = (
                proposals_per_image.gt_densepose
                if hasattr(proposals_per_image, "gt_densepose")
                else [None] * N_gt_boxes
            )
            assert len(gt_masks) == N_gt_boxes
            assert len(gt_densepose) == N_gt_boxes
            # keep proposals that have at least one of: DensePose GT, mask GT
            selected_indices = [
                i
                for i, (dp_target, mask_target) in enumerate(zip(gt_densepose, gt_masks))
                if (dp_target is not None) or (mask_target is not None)
            ]
            # if not len(selected_indices):
            #     feature_mask[i] = 0
            #     continue
            if len(selected_indices) != N_gt_boxes:
                proposals_per_image = proposals_per_image[selected_indices]  # pyre-ignore[6]
            assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.proposal_boxes)
            proposals_filtered.append(proposals_per_image)
        # features_filtered = [feature[feature_mask] for feature in features]
        # return features_filtered, proposals_filtered
        return features, proposals_filtered
|
CatVTON/densepose/modeling/hrfpn.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
"""
|
| 5 |
+
MIT License
|
| 6 |
+
Copyright (c) 2019 Microsoft
|
| 7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 8 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 9 |
+
in the Software without restriction, including without limitation the rights
|
| 10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 11 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 12 |
+
furnished to do so, subject to the following conditions:
|
| 13 |
+
The above copyright notice and this permission notice shall be included in all
|
| 14 |
+
copies or substantial portions of the Software.
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
from detectron2.layers import ShapeSpec
|
| 29 |
+
from detectron2.modeling.backbone import BACKBONE_REGISTRY
|
| 30 |
+
from detectron2.modeling.backbone.backbone import Backbone
|
| 31 |
+
|
| 32 |
+
from .hrnet import build_pose_hrnet_backbone
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class HRFPN(Backbone):
    """HRFPN (High Resolution Feature Pyramids).

    Transforms outputs of an HRNet backbone so they are suitable for the ROI heads.
    arXiv: https://arxiv.org/abs/1904.04514
    Adapted from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/hrfpn.py

    Args:
        bottom_up: HRNet backbone module; its forward returns a dict of feature maps
        in_features (list): names of the input features (outputs of HRNet)
        n_out_features (int): number of output pyramid stages
        in_channels (list): number of channels for each branch
        out_channels (int): output channels of the feature pyramids
        pooling (str): pooling mode for feature pyramids, "MAX" or "AVG".
            NOTE(review): `self.pooling` is assigned but never used in `forward`
            (pooling was replaced by strided reduction convs) — kept for
            interface compatibility; confirm before removing.
        share_conv (bool): if True, a single output conv is shared by all
            pyramid levels; otherwise one conv per level
    """

    def __init__(
        self,
        bottom_up,
        in_features,
        n_out_features,
        in_channels,
        out_channels,
        pooling="AVG",
        share_conv=False,
    ):
        super(HRFPN, self).__init__()
        assert isinstance(in_channels, list)
        self.bottom_up = bottom_up
        self.in_features = in_features
        self.n_out_features = n_out_features
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.share_conv = share_conv

        if self.share_conv:
            # One 3x3 conv applied to every pyramid level.
            self.fpn_conv = nn.Conv2d(
                in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1
            )
        else:
            # One 3x3 conv per pyramid level.
            self.fpn_conv = nn.ModuleList()
            for _ in range(self.n_out_features):
                self.fpn_conv.append(
                    nn.Conv2d(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=3,
                        padding=1,
                    )
                )

        # Custom change: learned upsampling (deconv + BN + ReLU) replaces a
        # simple bilinear interpolation; branch i is upsampled by stride 2**i.
        self.interp_conv = nn.ModuleList()
        for i in range(len(self.in_features)):
            self.interp_conv.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        in_channels=in_channels[i],
                        out_channels=in_channels[i],
                        kernel_size=4,
                        stride=2**i,
                        padding=0,
                        output_padding=0,
                        bias=False,
                    ),
                    nn.BatchNorm2d(in_channels[i], momentum=0.1),
                    nn.ReLU(inplace=True),
                )
            )

        # Custom change: a single strided conv replaces the original
        # (reduction conv + pooling) pair; output i is downsampled by 2**i.
        self.reduction_pooling_conv = nn.ModuleList()
        for i in range(self.n_out_features):
            self.reduction_pooling_conv.append(
                nn.Sequential(
                    nn.Conv2d(sum(in_channels), out_channels, kernel_size=2**i, stride=2**i),
                    nn.BatchNorm2d(out_channels, momentum=0.1),
                    nn.ReLU(inplace=True),
                )
            )

        if pooling == "MAX":
            self.pooling = F.max_pool2d
        else:
            self.pooling = F.avg_pool2d

        # Detectron2 Backbone metadata describing the outputs ("p1", "p2", ...).
        self._out_features = []
        self._out_feature_channels = {}
        self._out_feature_strides = {}

        for i in range(self.n_out_features):
            self._out_features.append("p%d" % (i + 1))
            self._out_feature_channels.update({self._out_features[-1]: self.out_channels})
            self._out_feature_strides.update({self._out_features[-1]: 2 ** (i + 2)})

    # default init_weights for conv (msra) and norm in ConvModule
    def init_weights(self):
        """Kaiming-initialize all Conv2d weights; zero their biases when present."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                # Fix: the original called nn.init.constant_(m.bias, 0)
                # unconditionally, which raises if a bias-free Conv2d is present.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, inputs):
        """Run HRNet, fuse its branches, and return a dict of pyramid features.

        Args:
            inputs: image tensor fed to the bottom-up HRNet backbone.
        Returns:
            dict: maps "p1".."pK" to feature tensors of `out_channels` channels.
        """
        bottom_up_features = self.bottom_up(inputs)
        assert len(bottom_up_features) == len(self.in_features)
        inputs = [bottom_up_features[f] for f in self.in_features]

        # Upsample every branch towards the highest resolution.
        outs = []
        for i in range(len(inputs)):
            outs.append(self.interp_conv[i](inputs[i]))
        # Deconv output sizes can differ by a few pixels; crop to the smallest.
        shape_2 = min(o.shape[2] for o in outs)
        shape_3 = min(o.shape[3] for o in outs)
        out = torch.cat([o[:, :, :shape_2, :shape_3] for o in outs], dim=1)
        # Build the pyramid with strided reduction convs.
        outs = []
        for i in range(self.n_out_features):
            outs.append(self.reduction_pooling_conv[i](out))
        for i in range(len(outs)):  # Make shapes consistent
            outs[-1 - i] = outs[-1 - i][
                :, :, : outs[-1].shape[2] * 2**i, : outs[-1].shape[3] * 2**i
            ]
        outputs = []
        for i in range(len(outs)):
            if self.share_conv:
                outputs.append(self.fpn_conv(outs[i]))
            else:
                outputs.append(self.fpn_conv[i](outs[i]))

        assert len(self._out_features) == len(outputs)
        return dict(zip(self._out_features, outputs))
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@BACKBONE_REGISTRY.register()
def build_hrfpn_backbone(cfg, input_shape: ShapeSpec) -> HRFPN:
    """Build an HRNet bottom-up backbone wrapped in an HRFPN neck from config."""
    stage4 = cfg.MODEL.HRNET.STAGE4
    bottom_up = build_pose_hrnet_backbone(cfg, input_shape)
    return HRFPN(
        bottom_up,
        ["p%d" % (i + 1) for i in range(stage4.NUM_BRANCHES)],
        len(cfg.MODEL.ROI_HEADS.IN_FEATURES),
        stage4.NUM_CHANNELS,
        cfg.MODEL.HRNET.HRFPN.OUT_CHANNELS,
        pooling="AVG",
        share_conv=False,
    )
|
CatVTON/densepose/modeling/hrnet.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# ------------------------------------------------------------------------------
|
| 3 |
+
# Copyright (c) Microsoft
|
| 4 |
+
# Licensed under the MIT License.
|
| 5 |
+
# Written by Bin Xiao (leoxiaobin@gmail.com)
|
| 6 |
+
# Modified by Bowen Cheng (bcheng9@illinois.edu)
|
| 7 |
+
# Adapted from https://github.com/HRNet/Higher-HRNet-Human-Pose-Estimation/blob/master/lib/models/pose_higher_hrnet.py # noqa
|
| 8 |
+
# ------------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
# pyre-unsafe
|
| 11 |
+
|
| 12 |
+
from __future__ import absolute_import, division, print_function
|
| 13 |
+
import logging
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
from detectron2.layers import ShapeSpec
|
| 17 |
+
from detectron2.modeling.backbone import BACKBONE_REGISTRY
|
| 18 |
+
from detectron2.modeling.backbone.backbone import Backbone
|
| 19 |
+
|
| 20 |
+
# Default BatchNorm momentum shared by all HRNet blocks in this module.
BN_MOMENTUM = 0.1
# Module-level logger; used to report config mismatches before raising.
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ["build_pose_hrnet_backbone", "PoseHigherResolutionNet"]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding=1 (BatchNorm follows it)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 conv+BN stages plus a skip connection."""

    expansion = 1  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: number of input channels
            planes: number of internal/output channels
            stride: stride of the first conv (spatial downsampling)
            downsample: optional module projecting the identity branch to the
                output shape; required whenever shape or channels change
        """
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """conv-bn-relu, conv-bn, add identity, relu."""
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = out + identity
        return self.relu(out)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, 3x3, 1x1 expand (x4), plus a skip."""

    expansion = 4  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: number of input channels
            planes: internal (reduced) channel count; output is planes * 4
            stride: stride of the middle 3x3 conv
            downsample: optional module projecting the identity branch to the
                output shape; required whenever shape or channels change
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """1x1-bn-relu, 3x3-bn-relu, 1x1-bn, add identity, relu."""
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out = out + identity
        return self.relu(out)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class HighResolutionModule(nn.Module):
    """HighResolutionModule.

    Building block of PoseHigherResolutionNet (arXiv: https://arxiv.org/abs/1908.10357):
    runs one residual branch per resolution, then fuses every branch into every
    requested output resolution.

    Args:
        num_branches (int): number of parallel resolution branches
        blocks: residual block class used by every branch
        num_blocks (list): number of blocks per branch
        num_inchannels (list): input channels per branch (mutated in place to
            reflect block expansion)
        num_channels (list): internal channels per branch
        multi_scale_output (bool): when False only the highest-resolution
            output is produced (used by the last module of the network)
    """

    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        multi_scale_output=True,
    ):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        """Log and raise when a per-branch config list disagrees with num_branches."""
        for label, values in (
            ("NUM_BLOCKS", num_blocks),
            ("NUM_CHANNELS", num_channels),
            ("NUM_INCHANNELS", num_inchannels),
        ):
            if num_branches != len(values):
                error_msg = "NUM_BRANCHES({}) <> {}({})".format(num_branches, label, len(values))
                logger.error(error_msg)
                raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """Build one branch: a chain of residual blocks at a fixed resolution."""
        out_channels = num_channels[branch_index] * block.expansion
        downsample = None
        if stride != 1 or self.num_inchannels[branch_index] != out_channels:
            # Project the identity path when shape or channels change.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
            )

        layers = [
            block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)
        ]
        # From here on the branch carries the expanded channel count.
        self.num_inchannels[branch_index] = out_channels
        for _ in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build every resolution branch."""
        return nn.ModuleList(
            [self._make_one_branch(i, block, num_blocks, num_channels) for i in range(num_branches)]
        )

    def _make_fuse_layers(self):
        """Build the cross-resolution fusion layers.

        Entry [i][j] maps branch j's output to branch i's resolution/channels:
        a 1x1 conv + nearest upsample for j > i, identity (None) for j == i,
        and a chain of stride-2 3x3 convs for j < i.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        n_outputs = num_branches if self.multi_scale_output else 1
        for i in range(n_outputs):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower resolution -> higher: project channels, then upsample.
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher resolution -> lower: repeated stride-2 3x3 convs.
                    # Only the final step switches to branch i's channel count,
                    # and it omits the ReLU (applied after summation instead).
                    steps = []
                    for k in range(i - j):
                        last = k == i - j - 1
                        step_channels = num_inchannels[i] if last else num_inchannels[j]
                        modules = [
                            nn.Conv2d(
                                num_inchannels[j],
                                step_channels,
                                3,
                                2,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(step_channels),
                        ]
                        if not last:
                            modules.append(nn.ReLU(True))
                        steps.append(nn.Sequential(*modules))
                    fuse_layer.append(nn.Sequential(*steps))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the (expansion-adjusted) per-branch channel counts."""
        return self.num_inchannels

    def forward(self, x):
        """Run each branch on its input, then fuse across resolutions.

        Args:
            x (list): one tensor per branch, highest resolution first.
        Returns:
            list: fused per-branch outputs (a single-element list when there is
            only one branch or multi_scale_output is False).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i, branch in enumerate(self.branches):
            x[i] = branch(x[i])

        x_fuse = []

        for i, fuse_layer in enumerate(self.fuse_layers):
            y = x[0] if i == 0 else fuse_layer[0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    # Crop to y's spatial size: upsampled maps may overshoot by a pixel.
                    z = fuse_layer[j](x[j])[:, :, : y.shape[2], : y.shape[3]]
                    y = y + z
            x_fuse.append(self.relu(y))

        return x_fuse
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# Map config BLOCK names to residual block classes.
blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class PoseHigherResolutionNet(Backbone):
    """PoseHigherResolutionNet.

    Composed of several HighResolutionModule instances tied together with ConvNets.
    Adapted from the GitHub version to fit with HRFPN and the Detectron2 infrastructure.
    arXiv: https://arxiv.org/abs/1908.10357
    """

    def __init__(self, cfg, **kwargs):
        # `inplanes` must exist before _make_layer runs below (it is mutated there).
        self.inplanes = cfg.MODEL.HRNET.STEM_INPLANES
        super(PoseHigherResolutionNet, self).__init__()

        # Stem: two stride-2 3x3 convs (overall stride 4), then a bottleneck layer.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        # Stages 2-4: each transition layer grows a lower-resolution branch.
        self.stage2_cfg = cfg.MODEL.HRNET.STAGE2
        num_channels = self.stage2_cfg.NUM_CHANNELS
        block = blocks_dict[self.stage2_cfg.BLOCK]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg.MODEL.HRNET.STAGE3
        num_channels = self.stage3_cfg.NUM_CHANNELS
        block = blocks_dict[self.stage3_cfg.BLOCK]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg.MODEL.HRNET.STAGE4
        num_channels = self.stage4_cfg.NUM_CHANNELS
        block = blocks_dict[self.stage4_cfg.BLOCK]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True
        )

        # Detectron2 Backbone metadata: one output per stage-4 branch.
        self._out_features = []
        self._out_feature_channels = {}
        self._out_feature_strides = {}

        for i in range(cfg.MODEL.HRNET.STAGE4.NUM_BRANCHES):
            self._out_features.append("p%d" % (i + 1))
            self._out_feature_channels.update(
                {self._out_features[-1]: cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS[i]}
            )
            # NOTE(review): stride is reported as 1 for every branch even though
            # the branches have different resolutions — confirm downstream
            # consumers (HRFPN) do not rely on real strides here.
            self._out_feature_strides.update({self._out_features[-1]: 1})

    def _get_deconv_cfg(self, deconv_kernel):
        """Return (kernel, padding, output_padding) for a supported deconv kernel.

        Fix: the original left `padding`/`output_padding` unbound for any kernel
        other than 2/3/4, which raised an opaque UnboundLocalError; raise a
        clear ValueError instead.

        Raises:
            ValueError: if deconv_kernel is not one of 2, 3, 4.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f"Unsupported deconv kernel size: {deconv_kernel}")

        return deconv_kernel, padding, output_padding

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """Build per-branch transition modules between two stages.

        Existing branches get a 3x3 conv only when their channel count changes
        (None otherwise); each new branch is produced from the last previous
        branch by a chain of stride-2 convs.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                # New (lower-resolution) branch: downsample from the last
                # previous branch with stride-2 convs.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = (
                        num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
                    )
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a chain of `blocks` residual blocks, updating self.inplanes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the identity path when shape or channels change.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
        """Build one stage as a sequence of HighResolutionModules.

        Returns:
            (nn.Sequential, list): the stage and its output channel counts.
        """
        num_modules = layer_config["NUM_MODULES"]
        num_branches = layer_config["NUM_BRANCHES"]
        num_blocks = layer_config["NUM_BLOCKS"]
        num_channels = layer_config["NUM_CHANNELS"]
        block = blocks_dict[layer_config["BLOCK"]]

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only relevant for the last module.
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    reset_multi_scale_output,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        """Run the stem and the three multi-resolution stages.

        Returns:
            dict: maps "p1".."pK" to the K stage-4 branch outputs
            (highest resolution first).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg.NUM_BRANCHES):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        # NOTE(review): non-None transitions are fed y_list[-1] (the lowest
        # resolution), not y_list[i]; with the usual configs only the newly
        # grown branch has a non-None transition, so this matches upstream —
        # confirm if configs with mid-stage channel changes are ever used.
        x_list = []
        for i in range(self.stage3_cfg.NUM_BRANCHES):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg.NUM_BRANCHES):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        assert len(self._out_features) == len(y_list)
        return dict(zip(self._out_features, y_list))  # final_outputs
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
@BACKBONE_REGISTRY.register()
def build_pose_hrnet_backbone(cfg, input_shape: ShapeSpec):
    """Construct a PoseHigherResolutionNet backbone from config (input_shape unused)."""
    return PoseHigherResolutionNet(cfg)
|
CatVTON/densepose/modeling/inference.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from dataclasses import fields
|
| 5 |
+
from typing import Any, List
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from detectron2.structures import Instances
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def densepose_inference(densepose_predictor_output: Any, detections: "List[Instances]") -> None:
    """
    Splits DensePose predictor outputs into chunks, each chunk corresponds to
    detections on one image. Predictor output chunks are stored in `pred_densepose`
    attribute of the corresponding `Instances` object.

    Args:
        densepose_predictor_output: a dataclass instance (can be of different types,
            depending on predictor used for inference). Each field can be `None`
            (if the corresponding output was not inferred) or a tensor of size
            [N, ...], where N = N_1 + N_2 + .. + N_k is a total number of
            detections on all images, N_1 is the number of detections on image 1,
            N_2 is the number of detections on image 2, etc.
        detections: a list of objects of type `Instance`, k-th object corresponds
            to detections on k-th image.
    """
    if densepose_predictor_output is None:
        # Nothing was inferred: leave `pred_densepose` unset on every Instances.
        # (Hoisted out of the loop — the original re-checked per detection.)
        return

    # All chunks share the predictor-output dataclass type.
    PredictorOutput = type(densepose_predictor_output)
    k = 0
    for detection_i in detections:
        n_i = len(detection_i)

        output_i_dict = {}
        # We assume here that `densepose_predictor_output` is a dataclass object.
        for field in fields(densepose_predictor_output):
            field_value = getattr(densepose_predictor_output, field.name)
            if isinstance(field_value, torch.Tensor):
                # Slice tensors: rows [k, k + n_i) belong to image i.
                output_i_dict[field.name] = field_value[k : k + n_i]
            else:
                # Leave non-tensor fields (including None) as is.
                output_i_dict[field.name] = field_value
        detection_i.pred_densepose = PredictorOutput(**output_i_dict)
        k += n_i
|
CatVTON/densepose/modeling/test_time_augmentation.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
import copy
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from fvcore.transforms import HFlipTransform, TransformList
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
|
| 10 |
+
from detectron2.data.transforms import RandomRotation, RotationTransform, apply_transform_gens
|
| 11 |
+
from detectron2.modeling.postprocessing import detector_postprocess
|
| 12 |
+
from detectron2.modeling.test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA
|
| 13 |
+
|
| 14 |
+
from ..converters import HFlipConverter
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DensePoseDatasetMapperTTA(DatasetMapperTTA):
    """
    Test-time-augmentation dataset mapper that extends the standard detectron2
    TTA mapper with additional rotated copies of the input image.
    """

    def __init__(self, cfg):
        super().__init__(cfg=cfg)
        # Extra rotation angles (degrees) to augment with, taken from the config.
        self.angles = cfg.TEST.AUG.ROTATION_ANGLES

    def __call__(self, dataset_dict):
        ret = super().__call__(dataset_dict=dataset_dict)
        hwc_image = dataset_dict["image"].permute(1, 2, 0).numpy()
        for rotation_angle in self.angles:
            rotation_gen = RandomRotation(angle=rotation_angle, expand=True)
            rotated_image, applied_tfms = apply_transform_gens(
                [rotation_gen], np.copy(hwc_image)
            )
            chw_tensor = torch.from_numpy(
                np.ascontiguousarray(rotated_image.transpose(2, 0, 1))
            )
            augmented_dict = copy.deepcopy(dataset_dict)
            # In DatasetMapperTTA, there is a pre_tfm transform (resize or no-op)
            # that is added at the beginning of each TransformList. That's
            # '.transforms[0]'.
            augmented_dict["transforms"] = TransformList(
                [ret[-1]["transforms"].transforms[0]] + applied_tfms.transforms
            )
            augmented_dict["image"] = chw_tensor
            ret.append(augmented_dict)
        return ret
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class DensePoseGeneralizedRCNNWithTTA(GeneralizedRCNNWithTTA):
    """
    Test-time-augmentation wrapper for a DensePose GeneralizedRCNN model.

    Extends the detectron2 TTA wrapper so that DensePose predictions from
    augmented views (rotation, horizontal flip) are mapped back to the
    original image space and averaged across views.
    """

    def __init__(self, cfg, model, transform_data, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode):
            model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on.
            transform_data (DensePoseTransformData): contains symmetry label
                transforms used for horizontal flip
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
        """
        # Move the flip symmetry data to the model's device before the parent
        # constructor takes over.
        self._transform_data = transform_data.to(model.device)
        super().__init__(cfg=cfg, model=model, tta_mapper=tta_mapper, batch_size=batch_size)

    # the implementation follows closely the one from detectron2/modeling
    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor

        Returns:
            dict: one output dict
        """
        orig_shape = (input["height"], input["width"])
        # For some reason, resize with uint8 slightly increases box AP but decreases densepose AP
        input["image"] = input["image"].to(torch.uint8)
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # Detect boxes from all augmented versions
        with self._turn_off_roi_heads(["mask_on", "keypoint_on", "densepose_on"]):
            # temporarily disable roi heads
            all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
        merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)

        if self.cfg.MODEL.MASK_ON or self.cfg.MODEL.DENSEPOSE_ON:
            # Use the detected boxes to obtain new fields
            augmented_instances = self._rescale_detected_boxes(
                augmented_inputs, merged_instances, tfms
            )
            # run forward on the detected boxes
            outputs = self._batch_inference(augmented_inputs, augmented_instances)
            # Delete now useless variables to avoid being out of memory
            del augmented_inputs, augmented_instances
            # average the predictions
            if self.cfg.MODEL.MASK_ON:
                merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms)
            if self.cfg.MODEL.DENSEPOSE_ON:
                merged_instances.pred_densepose = self._reduce_pred_densepose(outputs, tfms)
            # postprocess
            merged_instances = detector_postprocess(merged_instances, *orig_shape)
            return {"instances": merged_instances}
        else:
            # NOTE(review): no postprocess here — `_merge_detections` was given
            # `orig_shape`, so the merged boxes appear to already be in the
            # original image space; confirm against the detectron2 base class.
            return {"instances": merged_instances}

    def _get_augmented_boxes(self, augmented_inputs, tfms):
        # Heavily based on detectron2/modeling/test_time_augmentation.py
        # Only difference is that RotationTransform is excluded from bbox computation
        # 1: forward with all augmented images
        outputs = self._batch_inference(augmented_inputs)
        # 2: union the results
        all_boxes = []
        all_scores = []
        all_classes = []
        for output, tfm in zip(outputs, tfms):
            # Need to inverse the transforms on boxes, to obtain results on original image
            if not any(isinstance(t, RotationTransform) for t in tfm.transforms):
                # Some transforms can't compute bbox correctly
                pred_boxes = output.pred_boxes.tensor
                original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy())
                all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device))
                all_scores.extend(output.scores)
                all_classes.extend(output.pred_classes)
        all_boxes = torch.cat(all_boxes, dim=0)
        return all_boxes, all_scores, all_classes

    def _reduce_pred_densepose(self, outputs, tfms):
        # Should apply inverse transforms on densepose preds.
        # We assume only rotation, resize & flip are used. pred_masks is a scale-invariant
        # representation, so we handle the other ones specially
        for idx, (output, tfm) in enumerate(zip(outputs, tfms)):
            # Undo every rotation applied to this view (`_inverse_rotation` is a
            # no-op for non-rotation transforms).
            for t in tfm.transforms:
                for attr in ["coarse_segm", "fine_segm", "u", "v"]:
                    setattr(
                        output.pred_densepose,
                        attr,
                        _inverse_rotation(
                            getattr(output.pred_densepose, attr), output.pred_boxes.tensor, t
                        ),
                    )
            if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                output.pred_densepose = HFlipConverter.convert(
                    output.pred_densepose, self._transform_data
                )
            # Fold this view's predictions into the running average kept on
            # outputs[0] (mutated in place by _incremental_avg_dp).
            self._incremental_avg_dp(outputs[0].pred_densepose, output.pred_densepose, idx)
        return outputs[0].pred_densepose

    # incrementally computed average: u_(n + 1) = u_n + (x_(n+1) - u_n) / (n + 1).
    def _incremental_avg_dp(self, avg, new_el, idx):
        """
        Update `avg` in place with the `idx`-th new element of a running mean.
        After use, tensor fields of `new_el` (for idx > 0) are set to None to
        release GPU memory.
        """
        for attr in ["coarse_segm", "fine_segm", "u", "v"]:
            setattr(avg, attr, (getattr(avg, attr) * idx + getattr(new_el, attr)) / (idx + 1))
            if idx:
                # Deletion of the > 0 index intermediary values to prevent GPU OOM
                setattr(new_el, attr, None)
        return avg
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _inverse_rotation(densepose_attrs, boxes, transform):
    """
    Map DensePose predictions made on a rotated image back to the original
    (unrotated) image space.

    Args:
        densepose_attrs: N x C x H x W tensor of per-detection predictions
            (the 4-D indexing, `grid_sample` and `[:, 0]` channel access
            require this layout).
        boxes: N x 4 tensor of boxes [x0, y0, x1, y1] in the rotated image.
        transform: the transform applied to this view; this function is a
            no-op unless it is a RotationTransform.

    Returns:
        `densepose_attrs`, with rows for valid boxes updated in place.
    """
    # resample outputs to image size and rotate back the densepose preds
    # on the rotated images to the space of the original image
    if len(boxes) == 0 or not isinstance(transform, RotationTransform):
        return densepose_attrs
    boxes = boxes.int().cpu().numpy()
    wh_boxes = boxes[:, 2:] - boxes[:, :2]  # bboxes in the rotated space
    inv_boxes = rotate_box_inverse(transform, boxes).astype(int)  # bboxes in original image
    wh_diff = (inv_boxes[:, 2:] - inv_boxes[:, :2] - wh_boxes) // 2  # diff between new/old bboxes
    # Rotation matrix of the image, with the translation column zeroed out so
    # grid_sample only rotates (the crop below handles the offset).
    rotation_matrix = torch.tensor([transform.rm_image]).to(device=densepose_attrs.device).float()
    rotation_matrix[:, :, -1] = 0
    # To apply grid_sample for rotation, we need to have enough space to fit the original and
    # rotated bboxes. l_bds and r_bds are the left/right bounds that will be used to
    # crop the difference once the rotation is done
    l_bds = np.maximum(0, -wh_diff)
    for i in range(len(densepose_attrs)):
        if min(wh_boxes[i]) <= 0:
            # degenerate (zero-area) box: leave this row untouched
            continue
        densepose_attr = densepose_attrs[[i]].clone()
        # 1. Interpolate densepose attribute to size of the rotated bbox
        densepose_attr = F.interpolate(densepose_attr, wh_boxes[i].tolist()[::-1], mode="bilinear")
        # 2. Pad the interpolated attribute so it has room for the original + rotated bbox
        densepose_attr = F.pad(densepose_attr, tuple(np.repeat(np.maximum(0, wh_diff[i]), 2)))
        # 3. Compute rotation grid and transform
        grid = F.affine_grid(rotation_matrix, size=densepose_attr.shape)
        densepose_attr = F.grid_sample(densepose_attr, grid)
        # 4. Compute right bounds and crop the densepose_attr to the size of the original bbox
        r_bds = densepose_attr.shape[2:][::-1] - l_bds[i]
        densepose_attr = densepose_attr[:, :, l_bds[i][1] : r_bds[1], l_bds[i][0] : r_bds[0]]
        if min(densepose_attr.shape) > 0:
            # Interpolate back to the original size of the densepose attribute
            densepose_attr = F.interpolate(
                densepose_attr, densepose_attrs.shape[-2:], mode="bilinear"
            )
            # Adding a very small probability to the background class to fill padded zones
            densepose_attr[:, 0] += 1e-10
            densepose_attrs[i] = densepose_attr
    return densepose_attrs
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def rotate_box_inverse(rot_tfm, rotated_box):
    """
    Invert a rotation transform on bounding boxes and shrink them back to
    their pre-rotation size.

    `rotated_box` is an N x 4 array of [x0, y0, x1, y1] boxes. Rotating a box
    enlarges it (the axis-aligned result must enclose the tilted box), so a
    plain inverse rotation yields boxes bigger than the originals; this
    function additionally rescales each inverse-rotated box to its original
    extent.
    """
    # 1. Inverse-rotate the boxes (still enlarged at this point).
    inverted = rot_tfm.inverse().apply_box(rotated_box)
    rot_h = rotated_box[:, 3] - rotated_box[:, 1]
    rot_w = rotated_box[:, 2] - rotated_box[:, 0]
    inv_h = inverted[:, 3] - inverted[:, 1]
    inv_w = inverted[:, 2] - inverted[:, 0]
    assert 2 * rot_tfm.abs_sin**2 != 1, "45 degrees angle can't be inverted"
    # 2. Invert the size computation performed by the rotation transform to
    #    recover the original height/width of the boxes.
    denom = 1 - 2 * rot_tfm.abs_sin**2
    orig_h = (rot_h * rot_tfm.abs_cos - rot_w * rot_tfm.abs_sin) / denom
    orig_w = (rot_w * rot_tfm.abs_cos - rot_h * rot_tfm.abs_sin) / denom
    # 3. Shrink each inverse-rotated box symmetrically down to its original size.
    inverted[:, 0] += (inv_w - orig_w) / 2
    inverted[:, 1] += (inv_h - orig_h) / 2
    inverted[:, 2] -= (inv_w - orig_w) / 2
    inverted[:, 3] -= (inv_h - orig_h) / 2

    return inverted
|
CatVTON/densepose/modeling/utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def initialize_module_params(module: nn.Module) -> None:
    """
    Initialize the parameters of `module` in place: parameters whose name
    contains "bias" are zeroed, and those whose name contains "weight" get
    Kaiming-normal initialization (fan_out, ReLU nonlinearity). Any other
    parameter is left untouched.
    """
    for param_name, tensor in module.named_parameters():
        if "bias" in param_name:
            nn.init.constant_(tensor, 0)
        elif "weight" in param_name:
            nn.init.kaiming_normal_(tensor, mode="fan_out", nonlinearity="relu")
|
CatVTON/densepose/utils/__init__.py
ADDED
|
File without changes
|
CatVTON/densepose/utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (160 Bytes). View file
|
|
|
CatVTON/densepose/utils/__pycache__/transform.cpython-39.pyc
ADDED
|
Binary file (733 Bytes). View file
|
|
|
CatVTON/densepose/utils/dbhelper.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from typing import Any, Dict, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EntrySelector:
    """
    Base class for entry selectors.

    Concrete selectors are callables taking an entry and returning whether
    it should be kept.
    """

    @staticmethod
    def from_string(spec: str) -> "EntrySelector":
        """
        Build a selector from its string specification: "*" selects every
        entry, anything else is parsed as a field specifier.
        """
        return AllEntrySelector() if spec == "*" else FieldEntrySelector(spec)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AllEntrySelector(EntrySelector):
    """
    Selector that accepts every entry unconditionally.
    """

    # spec string that denotes this selector (see EntrySelector.from_string)
    SPECIFIER = "*"

    def __call__(self, entry):
        # Every entry matches, regardless of its contents.
        return True
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class FieldEntrySelector(EntrySelector):
    """
    Selector that accepts only entries that match provided field
    specifier(s). Only a limited set of specifiers is supported for now:
      <specifiers>::=<specifier>[<comma><specifiers>]
      <specifier>::=<field_name>[<type_delim><type>]<equal><value_or_range>
      <field_name> is a valid identifier
      <type> ::= "int" | "str"
      <equal> ::= "="
      <comma> ::= ","
      <type_delim> ::= ":"
      <value_or_range> ::= <value> | <range>
      <range> ::= <value><range_delim><value>
      <range_delim> ::= "-"
      <value> is a string without spaces and special symbols
      (e.g. <comma>, <equal>, <type_delim>, <range_delim>)
    """

    _SPEC_DELIM = ","
    _TYPE_DELIM = ":"
    _RANGE_DELIM = "-"
    _EQUAL = "="
    _ERROR_PREFIX = "Invalid field selector specifier"

    class _FieldEntryValuePredicate:
        """
        Predicate that checks strict equality for the specified entry field
        """

        def __init__(self, name: str, typespec: Optional[str], value: str):
            import builtins

            self.name = name
            # Resolve "int"/"str" to the corresponding builtin; default to str.
            self.type = getattr(builtins, typespec) if typespec is not None else str
            self.value = value

        def __call__(self, entry):
            return entry[self.name] == self.type(self.value)

    class _FieldEntryRangePredicate:
        """
        Predicate that checks whether an entry field falls into the specified range
        (inclusive on both ends)
        """

        def __init__(self, name: str, typespec: Optional[str], vmin: str, vmax: str):
            import builtins

            self.name = name
            self.type = getattr(builtins, typespec) if typespec is not None else str
            self.vmin = vmin
            self.vmax = vmax

        def __call__(self, entry):
            return (entry[self.name] >= self.type(self.vmin)) and (
                entry[self.name] <= self.type(self.vmax)
            )

    def __init__(self, spec: str):
        self._predicates = self._parse_specifier_into_predicates(spec)

    def __call__(self, entry: Dict[str, Any]):
        """Return True iff the entry satisfies every parsed predicate."""
        for predicate in self._predicates:
            if not predicate(entry):
                return False
        return True

    def _parse_specifier_into_predicates(self, spec: str):
        """Split the comma-separated spec and build one predicate per field."""
        predicates = []
        specs = spec.split(self._SPEC_DELIM)
        for subspec in specs:
            eq_idx = subspec.find(self._EQUAL)
            if eq_idx > 0:
                field_name_with_type = subspec[:eq_idx]
                field_name, field_type = self._parse_field_name_type(field_name_with_type)
                field_value_or_range = subspec[eq_idx + 1 :]
                if self._is_range_spec(field_value_or_range):
                    vmin, vmax = self._get_range_spec(field_value_or_range)
                    predicate = FieldEntrySelector._FieldEntryRangePredicate(
                        field_name, field_type, vmin, vmax
                    )
                else:
                    predicate = FieldEntrySelector._FieldEntryValuePredicate(
                        field_name, field_type, field_value_or_range
                    )
                predicates.append(predicate)
            elif eq_idx == 0:
                self._parse_error(f'"{subspec}", field name is empty!')
            else:
                self._parse_error(f'"{subspec}", should have format ' "<field>=<value_or_range>!")
        return predicates

    def _parse_field_name_type(self, field_name_with_type: str) -> Tuple[str, Optional[str]]:
        """Split "name[:type]" into (name, type-or-None)."""
        type_delim_idx = field_name_with_type.find(self._TYPE_DELIM)
        if type_delim_idx > 0:
            field_name = field_name_with_type[:type_delim_idx]
            field_type = field_name_with_type[type_delim_idx + 1 :]
        elif type_delim_idx == 0:
            self._parse_error(f'"{field_name_with_type}", field name is empty!')
        else:
            field_name = field_name_with_type
            field_type = None
        # pyre-fixme[61]: `field_name` may not be initialized here.
        # pyre-fixme[61]: `field_type` may not be initialized here.
        return field_name, field_type

    def _is_range_spec(self, field_value_or_range):
        # A range delimiter strictly past position 0 marks a range spec
        # (a leading "-" is treated as part of a plain value).
        delim_idx = field_value_or_range.find(self._RANGE_DELIM)
        return delim_idx > 0

    def _get_range_spec(self, field_value_or_range):
        """Split "vmin-vmax" into its two bounds; raise if it is not a range."""
        if self._is_range_spec(field_value_or_range):
            delim_idx = field_value_or_range.find(self._RANGE_DELIM)
            vmin = field_value_or_range[:delim_idx]
            vmax = field_value_or_range[delim_idx + 1 :]
            return vmin, vmax
        else:
            # BUGFIX: this message previously lacked the f-prefix, so the
            # literal text "field_value_or_range" was reported instead of the
            # offending value (every sibling _parse_error call interpolates).
            self._parse_error(f'"{field_value_or_range}", range of values expected!')

    def _parse_error(self, msg):
        raise ValueError(f"{self._ERROR_PREFIX}: {msg}")
|
CatVTON/densepose/utils/logger.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def verbosity_to_level(verbosity) -> int:
    """
    Map a CLI-style verbosity count to a `logging` level.

    None or 0 (or any other unrecognized value) -> WARNING, 1 -> INFO,
    2 or more -> DEBUG.
    """
    if verbosity is None:
        return logging.WARNING
    if verbosity >= 2:
        return logging.DEBUG
    if verbosity == 1:
        return logging.INFO
    # 0, negative, or anything else: default level
    return logging.WARNING
|
CatVTON/densepose/utils/transform.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from detectron2.data import MetadataCatalog
|
| 5 |
+
from detectron2.utils.file_io import PathManager
|
| 6 |
+
|
| 7 |
+
from densepose import DensePoseTransformData
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def load_for_dataset(dataset_name):
    """
    Load the DensePose transform (symmetry) data registered for the given
    dataset in the MetadataCatalog.
    """
    transform_src = MetadataCatalog.get(dataset_name).densepose_transform_src
    # Resolve the (possibly remote) source to a local file before loading.
    local_fpath = PathManager.get_local_path(transform_src)
    return DensePoseTransformData.load(local_fpath)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_from_cfg(cfg):
    """Load DensePose transform data for the first test dataset in the config."""
    first_test_dataset = cfg.DATASETS.TEST[0]
    return load_for_dataset(first_test_dataset)
|
CatVTON/model/DensePose/__init__.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
from random import randint
|
| 5 |
+
import shutil
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from densepose import add_densepose_config
|
| 13 |
+
from densepose.vis.base import CompoundVisualizer
|
| 14 |
+
from densepose.vis.densepose_results import DensePoseResultsFineSegmentationVisualizer
|
| 15 |
+
from densepose.vis.extractor import create_extractor, CompoundExtractor
|
| 16 |
+
from detectron2.config import get_cfg
|
| 17 |
+
from detectron2.data.detection_utils import read_image
|
| 18 |
+
from detectron2.engine.defaults import DefaultPredictor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DensePose:
    """
    DensePose used in this project is from Detectron2 (https://github.com/facebookresearch/detectron2).
    These codes are modified from https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose.
    The checkpoint is downloaded from https://github.com/facebookresearch/detectron2/blob/main/projects/DensePose/doc/DENSEPOSE_IUV.md#ModelZoo.

    We use the model R_50_FPN_s1x with id 165712039, but other models should also work.
    The config file is downloaded from https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose/configs.
    Noted that the config file should match the model checkpoint and Base-DensePose-RCNN-FPN.yaml is also needed.
    """

    def __init__(self, model_path="./checkpoints/densepose_", device="cuda"):
        self.device = device
        # model_path is a directory containing both the yaml config and the weights
        self.config_path = os.path.join(model_path, 'densepose_rcnn_R_50_FPN_s1x.yaml')
        self.model_path = os.path.join(model_path, 'model_final_162be9.pkl')
        # only the fine-segmentation visualization is used to produce the output map
        self.visualizations = ["dp_segm"]
        self.VISUALIZERS = {"dp_segm": DensePoseResultsFineSegmentationVisualizer}
        # detection score threshold fed to ROI_HEADS.SCORE_THRESH_TEST in setup_config
        self.min_score = 0.8

        self.cfg = self.setup_config()
        self.predictor = DefaultPredictor(self.cfg)
        self.predictor.model.to(self.device)

    def setup_config(self):
        """Build the frozen detectron2 config: densepose additions, yaml file,
        score threshold, and checkpoint path."""
        opts = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", str(self.min_score)]
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(self.config_path)
        cfg.merge_from_list(opts)
        cfg.MODEL.WEIGHTS = self.model_path
        cfg.freeze()
        return cfg

    @staticmethod
    def _get_input_file_list(input_spec: str):
        """Expand a path spec (directory, single file, or glob pattern) into a
        list of input file paths."""
        if os.path.isdir(input_spec):
            file_list = [os.path.join(input_spec, fname) for fname in os.listdir(input_spec)
                         if os.path.isfile(os.path.join(input_spec, fname))]
        elif os.path.isfile(input_spec):
            file_list = [input_spec]
        else:
            file_list = glob.glob(input_spec)
        return file_list

    def create_context(self, cfg, output_path):
        """Build the visualizer/extractor pipeline and the shared context dict
        used by execute_on_outputs."""
        vis_specs = self.visualizations
        visualizers = []
        extractors = []
        for vis_spec in vis_specs:
            texture_atlas = texture_atlases_dict = None
            vis = self.VISUALIZERS[vis_spec](
                cfg=cfg,
                texture_atlas=texture_atlas,
                texture_atlases_dict=texture_atlases_dict,
                alpha=1.0
            )
            visualizers.append(vis)
            extractor = create_extractor(vis)
            extractors.append(extractor)
        visualizer = CompoundVisualizer(visualizers)
        extractor = CompoundExtractor(extractors)
        context = {
            "extractor": extractor,
            "visualizer": visualizer,
            "out_fname": output_path,
            "entry_idx": 0,
        }
        return context

    def execute_on_outputs(self, context, entry, outputs):
        """Extract the DensePose segmentation from predictor outputs and save
        it as a grayscale label image at context["out_fname"]."""
        extractor = context["extractor"]

        data = extractor(outputs)

        H, W, _ = entry["image"].shape
        result = np.zeros((H, W), dtype=np.uint8)

        # only the first extractor's result and the first detected box are used
        data, box = data[0]
        x, y, w, h = [int(_) for _ in box[0].cpu().numpy()]
        i_array = data[0].labels[None].cpu().numpy()[0]
        # paste per-part labels into the box region; background stays 0
        result[y:y + h, x:x + w] = i_array
        result = Image.fromarray(result)
        result.save(context["out_fname"])

    def __call__(self, image_or_path, resize=512) -> Image.Image:
        """
        :param image_or_path: Path of the input image.
        :param resize: Resize the input image if its max size is larger than this value.
        :return: Dense pose image.
        """
        # fixed tmp directory; uniqueness comes from the per-call file name below
        tmp_path = f"./densepose_/tmp/"
        if not os.path.exists(tmp_path):
            os.makedirs(tmp_path)

        # unique-ish name: timestamp + device + random suffix
        image_path = os.path.join(tmp_path, f"{int(time.time())}-{self.device}-{randint(0, 100000)}.png")
        if isinstance(image_or_path, str):
            assert image_or_path.split(".")[-1] in ["jpg", "png"], "Only support jpg and png images."
            shutil.copy(image_or_path, image_path)
        elif isinstance(image_or_path, Image.Image):
            image_or_path.save(image_path)
        else:
            shutil.rmtree(tmp_path)
            raise TypeError("image_path must be str or PIL.Image.Image")

        output_path = image_path.replace(".png", "_dense.png").replace(".jpg", "_dense.png")
        # original size, used to resize the result back at the end
        w, h = Image.open(image_path).size

        file_list = self._get_input_file_list(image_path)
        assert len(file_list), "No input images found!"
        context = self.create_context(self.cfg, output_path)
        for file_name in file_list:
            img = read_image(file_name, format="BGR")  # predictor expects BGR image.
            # downscale so the longest side is at most `resize`
            if (_ := max(img.shape)) > resize:
                scale = resize / _
                img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale)))

            with torch.no_grad():
                outputs = self.predictor(img)["instances"]
            try:
                self.execute_on_outputs(context, {"file_name": file_name, "image": img}, outputs)
            except Exception as e:
                # NOTE(review): extraction failures (e.g. no detections) are
                # swallowed and a 1x1 placeholder is written instead; consider
                # logging `e` so silent failures are diagnosable.
                null_gray = Image.new('L', (1, 1))
                null_gray.save(output_path)

        dense_gray = Image.open(output_path).convert("L")
        # NEAREST keeps the discrete part labels intact when scaling back up
        dense_gray = dense_gray.resize((w, h), Image.NEAREST)
        # remove the temporary input and output files
        os.remove(image_path)
        os.remove(output_path)


        return dense_gray
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == '__main__':
    # No standalone CLI behavior: this module is intended to be imported.
    pass
|
CatVTON/model/DensePose/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
CatVTON/model/DensePose/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (8.91 kB). View file
|
|
|
CatVTON/model/DensePose/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (5.83 kB). View file
|
|
|
CatVTON/model/SCHP/__init__.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from model.SCHP import networks
|
| 2 |
+
from model.SCHP.utils.transforms import get_affine_transform, transform_logits
|
| 3 |
+
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
|
| 11 |
+
def get_palette(num_cls):
    """Build a flat RGB palette for visualizing a segmentation mask.

    Uses the standard PASCAL-VOC scheme: each class index is expanded
    bit-by-bit, bit k of the index contributing to bit (7 - k//3) of the
    R, G or B channel.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list of length ``num_cls * 3`` with R, G, B values per class.
    """
    palette = [0] * (num_cls * 3)
    for idx in range(num_cls):
        value = idx
        shift = 7
        while value:
            # Distribute the three lowest bits of `value` over R, G, B.
            for channel in range(3):
                palette[idx * 3 + channel] |= ((value >> channel) & 1) << shift
            value >>= 3
            shift -= 1
    return palette
| 33 |
+
|
| 34 |
+
# Per-dataset parsing configuration: network input resolution, number of
# segmentation classes, and the human-readable label of each class index.
dataset_settings = {
    # LIP: Look Into Person, 20-class human parsing.
    'lip': {
        'input_size': [473, 473],
        'num_classes': 20,
        'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
                  'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
                  'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe']
    },
    # ATR: 18-class clothing/person parsing.
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
                  'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    },
    # Pascal-Person-Part: coarse 7-part body parsing.
    'pascal': {
        'input_size': [512, 512],
        'num_classes': 7,
        'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
    }
}
|
| 54 |
+
|
| 55 |
+
class SCHP:
    """Self-Correction Human Parsing (SCHP) inference wrapper.

    Infers the dataset type ('lip', 'atr' or 'pascal') from the checkpoint
    path, loads the matching resnet101 parsing network, and exposes a
    callable interface mapping an image (path or PIL.Image, or a list of
    them) to a paletted 'P'-mode PIL segmentation mask.
    """

    def __init__(self, ckpt_path, device):
        """
        Args:
            ckpt_path: Path to the SCHP checkpoint; must contain one of the
                substrings 'lip', 'atr' or 'pascal' so the dataset settings
                can be selected.
            device: Torch device the model runs on.
        """
        dataset_type = None
        if 'lip' in ckpt_path:
            dataset_type = 'lip'
        elif 'atr' in ckpt_path:
            dataset_type = 'atr'
        elif 'pascal' in ckpt_path:
            dataset_type = 'pascal'
        assert dataset_type is not None, 'Dataset type not found in checkpoint path'
        self.device = device
        self.num_classes = dataset_settings[dataset_type]['num_classes']
        self.input_size = dataset_settings[dataset_type]['input_size']
        # Width/height ratio used to pad boxes to the network aspect ratio.
        self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
        self.palette = get_palette(self.num_classes)

        self.label = dataset_settings[dataset_type]['label']
        self.model = networks.init_model('resnet101', num_classes=self.num_classes, pretrained=None).to(device)
        self.load_ckpt(ckpt_path)
        self.model.eval()

        # NOTE(review): mean/std are the ImageNet statistics in reversed
        # (BGR) channel order, which matches cv2.imread output; PIL inputs
        # are passed through as RGB in preprocess() — confirm the channel
        # order mismatch between the two input paths is intended.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
        ])
        self.upsample = torch.nn.Upsample(size=self.input_size, mode='bilinear', align_corners=True)

    def load_ckpt(self, ckpt_path):
        """Load a DataParallel-saved checkpoint, remapping a few layer keys.

        The stored layer indices differ from the current network definition
        for some decoder/fusion layers, so those keys are renamed before a
        non-strict load (unmatched keys are silently ignored).
        """
        rename_map = {
            "decoder.conv3.2.weight": "decoder.conv3.3.weight",
            "decoder.conv3.3.weight": "decoder.conv3.4.weight",
            "decoder.conv3.3.bias": "decoder.conv3.4.bias",
            "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
            "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
            "fushion.3.weight": "fushion.4.weight",
            "fushion.3.bias": "fushion.4.bias",
        }
        state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # strip the 'module.' prefix added by DataParallel
            new_state_dict[name] = v
        # Apply the rename map; keys not in the map are kept unchanged.
        new_state_dict_ = OrderedDict()
        for k, v in list(new_state_dict.items()):
            new_state_dict_[rename_map.get(k, k)] = v
        self.model.load_state_dict(new_state_dict_, strict=False)

    def _box2cs(self, box):
        """Convert an (x, y, w, h) box to a (center, scale) pair."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Compute the box center and a square-ish scale padded to the
        network aspect ratio (the smaller dimension is enlarged)."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def preprocess(self, image):
        """Warp an image to the network input size.

        Args:
            image: Filesystem path (read as BGR via cv2) or a PIL.Image
                (kept in PIL's native RGB order).

        Returns:
            (input, meta): a normalized 1xCxHxW tensor on ``self.device``
            and a dict of the affine parameters needed to map logits back
            to the original resolution.

        Raises:
            TypeError: If ``image`` is neither a path nor a PIL.Image.
        """
        if isinstance(image, str):
            img = cv2.imread(image, cv2.IMREAD_COLOR)
        elif isinstance(image, Image.Image):
            # to cv2 format (channel order stays RGB here)
            img = np.array(image)
        else:
            # Previously an unsupported type fell through to a NameError on
            # `img`; fail fast with an explicit error instead.
            raise TypeError(f'Unsupported image type: {type(image)!r}')

        h, w, _ = img.shape
        # Get person center and scale covering the whole frame.
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        input = self.transform(input).to(self.device).unsqueeze(0)
        meta = {
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }
        return input, meta

    def __call__(self, image_or_path):
        """Run parsing on one image or a batch (list) of images.

        Args:
            image_or_path: A path / PIL.Image, or a list of them.

        Returns:
            A paletted PIL.Image mask, or a list of masks for a multi-item
            list input (a single-element list returns the bare mask).
        """
        if isinstance(image_or_path, list):
            image_list = []
            meta_list = []
            for image in image_or_path:
                image, meta = self.preprocess(image)
                image_list.append(image)
                meta_list.append(meta)
            image = torch.cat(image_list, dim=0)
        else:
            image, meta = self.preprocess(image_or_path)
            meta_list = [meta]

        # Pure inference: disable autograd bookkeeping to save memory/time.
        with torch.no_grad():
            output = self.model(image)
            # upsample_outputs = self.upsample(output[0][-1])
            upsample_outputs = self.upsample(output)
        upsample_outputs = upsample_outputs.permute(0, 2, 3, 1)  # BCHW -> BHWC

        output_img_list = []
        for upsample_output, meta in zip(upsample_outputs, meta_list):
            c, s, w, h = meta['center'], meta['scale'], meta['width'], meta['height']
            # Undo the input affine warp so logits align with the original image.
            logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=self.input_size)
            parsing_result = np.argmax(logits_result, axis=2)
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(self.palette)
            output_img_list.append(output_img)

        return output_img_list[0] if len(output_img_list) == 1 else output_img_list
|