Spaces:
Runtime error
Runtime error
Commit ·
d91c189
1
Parent(s): a48a9ba
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +163 -0
- CONTRIBUTING.md +20 -0
- README.md +38 -14
- caption_images.py +52 -0
- demo/__init__.py +3 -0
- demo/extract_garment/README.md +14 -0
- demo/extract_garment/__init__.py +1 -0
- demo/extract_garment/app.py +76 -0
- demo/extract_garment/requirements.txt +3 -0
- demo/model_swap/.gitignore +1 -0
- demo/model_swap/README.md +14 -0
- demo/model_swap/__init__.py +1 -0
- demo/model_swap/app.py +321 -0
- demo/model_swap/requirements.txt +2 -0
- demo/outfit_generator/README.md +86 -0
- demo/outfit_generator/__init__.py +1 -0
- demo/outfit_generator/app.py +164 -0
- demo/outfit_generator/images/sample1.jpeg +0 -0
- demo/outfit_generator/images/sample2.jpeg +0 -0
- demo/outfit_generator/images/sample3.jpeg +0 -0
- demo/outfit_generator/images/sample4.jpeg +0 -0
- demo/outfit_generator/requirements.txt +10 -0
- environment.yml +179 -0
- main.py +44 -0
- requirements.txt +15 -0
- run_demo.py +18 -0
- run_ootd.py +37 -0
- scripts/install_conda.sh +10 -0
- scripts/install_sam2.sh +11 -0
- setup.py +31 -0
- tryon/README.md +34 -0
- tryon/__init__.py +0 -0
- tryon/models/__init__.py +0 -0
- tryon/models/ootdiffusion/setup.sh +30 -0
- tryon/preprocessing/__init__.py +3 -0
- tryon/preprocessing/captioning/__init__.py +2 -0
- tryon/preprocessing/captioning/generate_caption.py +108 -0
- tryon/preprocessing/extract_garment_new.py +91 -0
- tryon/preprocessing/preprocess_garment.py +107 -0
- tryon/preprocessing/preprocess_human.py +86 -0
- tryon/preprocessing/sam2/__init__.py +23 -0
- tryon/preprocessing/u2net/__init__.py +3 -0
- tryon/preprocessing/u2net/data_loader.py +277 -0
- tryon/preprocessing/u2net/load_u2net.py +47 -0
- tryon/preprocessing/u2net/u2net_cloth_segm.py +550 -0
- tryon/preprocessing/u2net/u2net_human_segm.py +520 -0
- tryon/preprocessing/u2net/utils.py +10 -0
- tryon/preprocessing/utils.py +91 -0
- tryondiffusion/__init__.py +0 -0
- tryondiffusion/diffusion.py +275 -0
.gitignore
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
| 161 |
+
|
| 162 |
+
u2net_cloth_segm.pth
|
| 163 |
+
u2net_segm.pth
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## How to contribute to tryondiffusion
|
| 2 |
+
|
| 3 |
+
### 1. Open an issue
|
| 4 |
+
We recommend opening an issue (if one doesn't already exist) and discussing your intended changes before making any changes.
|
| 5 |
+
We'll be able to provide you feedback and confirm the planned modifications this way.
|
| 6 |
+
|
| 7 |
+
### 2. Make changes in the code
|
| 8 |
+
Start with forking the repository, set up the environment, install the dependencies, and make changes in the code appropriately.
|
| 9 |
+
|
| 10 |
+
### 3. Create pull request
|
| 11 |
+
Create a pull request to the main branch from your fork's branch.
|
| 12 |
+
|
| 13 |
+
### 4. Merging pull request
|
| 14 |
+
Once the pull request is created, we will review the code changes and merge the pull request as soon as possible.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
### Writing documentation
|
| 18 |
+
|
| 19 |
+
If you are interested in writing the documentation, you can add it to README.md and create a pull request.
|
| 20 |
+
For now, we are maintaining our documentation in a single file and we will add more files as it grows.
|
README.md
CHANGED
|
@@ -1,14 +1,38 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Try On Diffusion: A Tale of Two UNets Implementation
|
| 2 |
+
### [Paper Link](https://arxiv.org/abs/2306.08276)
|
| 3 |
+
|
| 4 |
+
### [Click here](https://discord.gg/T5mPpZHxkY) to join our discord channel
|
| 5 |
+
|
| 6 |
+
## Roadmap
|
| 7 |
+
|
| 8 |
+
1. ~~Prepare initial implementation~~
|
| 9 |
+
1. Test initial implementation with small dataset (VITON-HD)
|
| 10 |
+
1. Gather sufficient data and compute resources
|
| 11 |
+
1. Prepare and train final implementation
|
| 12 |
+
1. Publicly release parameters
|
| 13 |
+
|
| 14 |
+
## How to contribute to tryondiffusion
|
| 15 |
+
|
| 16 |
+
### 1. Open an issue
|
| 17 |
+
We recommend opening an issue (if one doesn't already exist) and discussing your intended changes before making any changes.
|
| 18 |
+
We'll be able to provide you feedback and confirm the planned modifications this way.
|
| 19 |
+
|
| 20 |
+
### 2. Make changes in the code
|
| 21 |
+
Start with forking the repository, set up the environment, install the dependencies, and make changes in the code appropriately.
|
| 22 |
+
|
| 23 |
+
### 3. Create pull request
|
| 24 |
+
Create a pull request to the main branch from your fork's branch.
|
| 25 |
+
|
| 26 |
+
### 4. Merging pull request
|
| 27 |
+
Once the pull request is created, we will review the code changes and merge the pull request as soon as possible.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
### Writing documentation
|
| 31 |
+
|
| 32 |
+
If you are interested in writing the documentation, you can add it to README.md and create a pull request.
|
| 33 |
+
For now, we are maintaining our documentation in a single file and we will add more files as it grows.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
## License
|
| 37 |
+
|
| 38 |
+
All material is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). You can **use** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our original [github repo](https://github.com/kailashahirwar/tryondiffusion)** and **indicate any changes** that you've made to the code.
|
caption_images.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
from tryon.preprocessing.captioning import caption_image, create_llava_next_pipeline
|
| 8 |
+
|
| 9 |
+
# Glob pattern covering every file in the dataset directory.
# NOTE(review): "datatset" is a typo but is the on-disk directory name used
# by the rest of the pipeline — preserved deliberately; confirm before renaming.
INPUT_IMAGES_DIR = os.path.join("fashion_datatset", "*")
OUTPUT_CAPTIONS_DIR = "fashion_datatset_captions"

# Ensure the captions directory exists before the main loop writes into it.
os.makedirs(OUTPUT_CAPTIONS_DIR, exist_ok=True)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def change_extension(filename, new_extension):
    """Return *filename* with its extension replaced by *new_extension*.

    The new extension may be given with or without a leading dot: both
    ``"json"`` and ``".json"`` turn ``"photo.png"`` into ``"photo.json"``.
    (The original implementation produced ``"photo..json"`` for a dotted
    extension.)
    """
    base_name, _ = os.path.splitext(filename)
    # lstrip('.') makes the function tolerant of ".ext"-style arguments.
    return f"{base_name}.{new_extension.lstrip('.')}"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if __name__ == '__main__':
    # Build the LLaVA-NeXT captioning model/processor pair once up front.
    model, processor = create_llava_next_pipeline()

    # Deterministic ordering so interrupted runs resume predictably.
    image_paths = sorted(glob.glob(INPUT_IMAGES_DIR))
    total = len(image_paths)

    for index, image_path in enumerate(image_paths):
        print(f"index: {index}, total images: {total}, {image_path}")
        image = Image.open(image_path)

        prompt = """
        You're a fashion expert. The list of clothing properties includes [color, pattern, style, fit, type, hemline,
        material, sleeve-length, fabric-elasticity, neckline, waistline]. Please provide the following information in
        JSON format for the outfit shown in the image. Question: What are the color, pattern, style, fit, type,
        hemline, material, sleeve length, fabric elasticity, neckline, and waistline of the outfit in the image?
        Answer:
        """

        base_name = os.path.basename(image_path)
        json_file_path = os.path.join(OUTPUT_CAPTIONS_DIR, change_extension(base_name, "json"))
        caption_file_path = os.path.join(OUTPUT_CAPTIONS_DIR, change_extension(base_name, "txt"))

        # Skip images whose caption pair (.json + .txt) already exists.
        if os.path.exists(caption_file_path) and os.path.exists(json_file_path):
            print(f"caption already exists for {image_path}")
            continue

        json_data, generated_caption = caption_image(image, prompt, model, processor, json_only=False)

        with open(json_file_path, "w") as f:
            json.dump(json_data, f)

        with open(caption_file_path, "w") as f:
            f.write(generated_caption)
|
demo/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .extract_garment import demo as extract_garment_demo
|
| 2 |
+
from .model_swap import demo as model_swap_demo
|
| 3 |
+
from .outfit_generator import demo as outfit_generator_demo
|
demo/extract_garment/README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Extract Garment AI
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Gradio Demo of Extract Garment AI by TryOn Labs.
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
demo/extract_garment/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .app import demo
|
demo/extract_garment/app.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from tryon import preprocessing
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def extract_garment(input_img, cls):
    """Run U2Net garment extraction on *input_img* for garment class *cls*.

    The preprocessing pipeline operates on directories, so the image is
    staged into a scratch input directory and the first file produced in
    the output directory is returned as a PIL image.
    """
    print(input_img, type(input_img), cls)

    input_dir = "input_image"
    output_dir = "output_image"

    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    # Clear stale artifacts from previous runs, including intermediate
    # cloth masks written by the pipeline.
    stale_files = (glob.glob(input_dir + "/*.*")
                   + glob.glob(output_dir + "/*.*")
                   + glob.glob("cloth-mask/*.*"))
    for stale in stale_files:
        os.remove(stale)

    input_img.save(os.path.join(input_dir, "img.jpg"))

    preprocessing.extract_garment(inputs_dir=input_dir, outputs_dir=output_dir, cls=cls)

    # The pipeline writes its result image(s) into output_dir; return the first.
    result_path = glob.glob(output_dir + "/*.*")[0]
    return Image.open(result_path)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
css = """
|
| 35 |
+
#col-container {
|
| 36 |
+
margin: 0 auto;
|
| 37 |
+
max-width: 720px;
|
| 38 |
+
}
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
with gr.Blocks(css=css) as demo:
|
| 42 |
+
with gr.Column(elem_id="col-container"):
|
| 43 |
+
gr.Markdown(f"""
|
| 44 |
+
# Clothes Extraction using U2Net
|
| 45 |
+
Pull out clothes like tops, bottoms, and dresses from a photo. This implementation is based on the [U2Net](https://github.com/xuebinqin/U-2-Net) model.
|
| 46 |
+
""")
|
| 47 |
+
|
| 48 |
+
with gr.Row():
|
| 49 |
+
with gr.Column():
|
| 50 |
+
input_image = gr.Image(label="Input Image", type='pil', height="400px", show_label=True)
|
| 51 |
+
dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Extract garment",
|
| 52 |
+
info="Select the garment type you wish to extract!")
|
| 53 |
+
|
| 54 |
+
output_image = gr.Image(label="Extracted garment", type='pil', height="400px", show_label=True,
|
| 55 |
+
show_download_button=True)
|
| 56 |
+
|
| 57 |
+
with gr.Row():
|
| 58 |
+
submit_button = gr.Button("Submit", variant='primary', scale=1)
|
| 59 |
+
reset_button = gr.ClearButton(value="Reset", scale=1)
|
| 60 |
+
|
| 61 |
+
gr.on(
|
| 62 |
+
triggers=[submit_button.click],
|
| 63 |
+
fn=extract_garment,
|
| 64 |
+
inputs=[input_image, dropdown],
|
| 65 |
+
outputs=[output_image]
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
reset_button.click(
|
| 69 |
+
fn=lambda: (None, "upper", None),
|
| 70 |
+
inputs=[],
|
| 71 |
+
outputs=[input_image, dropdown, output_image],
|
| 72 |
+
concurrency_limit=1,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if __name__ == '__main__':
|
| 76 |
+
demo.launch()
|
demo/extract_garment/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.1
|
| 2 |
+
pillow
|
| 3 |
+
tryondiffusion
|
demo/model_swap/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.token
|
demo/model_swap/README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Model Swap AI
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Gradio Demo of Model Swap AI by TryOn Labs.
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
demo/model_swap/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .app import demo
|
demo/model_swap/app.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import json
|
| 5 |
+
import requests
|
| 6 |
+
import time
|
| 7 |
+
from gradio_modal import Modal
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
|
| 10 |
+
# Base location of the TryOn Labs production API server.
TRYON_SERVER_HOST = "https://prod.server.tryonlabs.ai"
TRYON_SERVER_PORT = "80"

# Port 80 is implied by the URL scheme, so it is omitted from the base URL.
TRYON_SERVER_URL = (TRYON_SERVER_HOST if TRYON_SERVER_PORT == "80"
                    else f"{TRYON_SERVER_HOST}:{TRYON_SERVER_PORT}")

# All REST endpoints below hang off this versioned prefix.
TRYON_SERVER_API_URL = f"{TRYON_SERVER_URL}/api/v1/"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def start_model_swap(input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps):
    """Run a model-swap experiment on the TryOn server and return result URLs.

    Flow: validate inputs and token, upload the image as an experiment image,
    create a "model_swap" experiment for it, then poll every 10 s until the
    experiment leaves the "running" state. Raises gr.Error on validation or
    server failure.
    """
    print("inputs:", input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps)

    if input_image is None:
        raise gr.Error("Select an image!")

    if prompt is None or prompt == "":
        raise gr.Error("Enter a prompt!")

    token = load_token()
    if token is None or token == "":
        raise gr.Error("You need to login first!")
    else:
        # Re-verify the cached token against the server before using it.
        login(token)

    # Serialize the PIL image to an in-memory PNG for the multipart upload.
    byte_io = BytesIO()
    input_image.save(byte_io, 'png')
    byte_io.seek(0)

    auth_header = {"Authorization": f"Bearer {token}"}

    # 1. create an experiment image
    r = requests.post(f"{TRYON_SERVER_API_URL}experiment_image/",
                      files={"image": ('ei_image.png', byte_io, 'image/png')},
                      data={"type": "model", "preprocess": "false"},
                      headers=auth_header)
    if r.status_code not in (200, 201):
        print(f"Error: {r.text}")
        raise gr.Error(f"Failure: {r.text}")

    print("Experiment image created successfully", r.json())
    res = r.json()

    # 2. create an experiment referencing the uploaded image
    experiment_params = {"prompt": prompt,
                         "guidance_scale": guidance_scale,
                         "strength": strength,
                         "num_inference_steps": inference_steps,
                         "seed": seed,
                         "garment_class": f"{cls} garment",
                         "negative_prompt": "(hands:1.15), disfigured, ugly, bad, immature"
                                            ", cartoon, anime, 3d, painting, b&w, (ugly),"
                                            " (pixelated), watermark, glossy, smooth, "
                                            "earrings, necklace",
                         "num_results": num_results}
    r2 = requests.post(f"{TRYON_SERVER_API_URL}experiment/",
                       data={"model_id": res['id'],
                             "action": "model_swap",
                             "params": json.dumps(experiment_params)},
                       headers=auth_header)
    if r2.status_code not in (200, 201):
        print(f"Error: {r2.text}")
        raise gr.Error(f"Failure: {r2.text}")

    res2 = r2.json()
    print("Experiment created successfully", res2)
    time.sleep(10)

    # 3. keep checking the status of the experiment until it stops running
    experiment = res2['experiment']
    status = fetch_experiment_status(experiment_id=experiment['id'], token=token)
    while status['status'] == "running":
        time.sleep(10)
        status = fetch_experiment_status(experiment_id=experiment['id'], token=token)
        print(f"Current status: {status['status']}")

    if status['status'] == "success":
        print("Experiment successful")
        print(f"Results:{status['result_images']}")
        return status['result_images']
    elif status['status'] == "failed":
        print("Experiment failed")
        raise gr.Error("Experiment failed")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def fetch_experiment_status(experiment_id, token):
    """Query the TryOn server for the state of one experiment.

    Returns a dict that always carries a "status" key ("running", "success"
    or "failed"); on success it additionally carries "result_images", a list
    of absolute result-image URLs. Any unexpected response is reported as
    failed so callers can rely on "status" being present.
    """
    print(f"experiment id:{experiment_id}")

    r3 = requests.get(f"{TRYON_SERVER_API_URL}experiment/{experiment_id}/",
                      headers={
                          "Authorization": f"Bearer {token}"
                      })
    if r3.status_code == 200:
        res = r3.json()
        if res['status'] == "running":
            return {"status": "running"}
        elif res['status'] == "success":
            experiment = r3.json()['experiment']
            # Primary result first, then any additional result images.
            result_images = [f"{TRYON_SERVER_URL}/{experiment['result']['image_url']}"]
            if len(experiment['results']) > 0:
                for result in experiment['results']:
                    result_images.append(f"{TRYON_SERVER_URL}/{result['image_url']}")
            return {"status": "success", "result_images": result_images}
        elif res['status'] == "failed":
            return {"status": "failed"}
        else:
            # Bug fix: an unrecognized status previously fell through and
            # returned None, crashing callers that index status['status'].
            print(f"Unexpected experiment status: {res['status']}")
            return {"status": "failed"}
    else:
        print(f"Error: {r3.text}")
        return {"status": "failed"}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_user_credits(token):
    """Return the credit balance for the user owning *token*.

    Returns None for an empty token or any non-200 server response.
    """
    if token == "":
        return None

    response = requests.get(f"{TRYON_SERVER_API_URL}user/get/", headers={
        "Authorization": f"Bearer {token}"
    })
    if response.status_code != 200:
        print(f"Error: {response.text}")
        return None
    return response.json()['credits']
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def load_token():
    """Read the cached access token from the .token file.

    Returns the token string, or None when no cache file exists.
    """
    if not os.path.exists(".token"):
        return None
    with open(".token", "r") as f:
        return json.load(f)['token']
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def save_token(access_token):
    """Persist *access_token* into the .token cache file as JSON.

    Raises gr.Error when given an empty token.
    """
    if access_token == "":
        raise gr.Error("No token provided!")
    with open(".token", "w") as f:
        json.dump({"token": access_token}, f)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def is_logged_in():
    """Return True when a non-empty access token is cached on disk."""
    token = load_token()
    return not (token is None or token == "")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def login(token):
    """Verify *token* with the TryOn server and cache it on success.

    Returns True when the server accepts the token; raises gr.Error otherwise.
    """
    print("logging in...")
    # validate token against the server's verification endpoint
    response = requests.post(f"{TRYON_SERVER_URL}/api/token/verify/", data={"token": token})
    if response.status_code != 200:
        raise gr.Error("Login failed")
    save_token(token)
    return True
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def logout():
    """Blank the cached token and return reset values for the UI state.

    Returns [logged_in, token] == [False, ""] for the gradio outputs.
    """
    print("logged out")
    with open(".token", "w") as f:
        json.dump({"token": ""}, f)
    return [False, ""]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
css = """
|
| 190 |
+
#col-container {
|
| 191 |
+
margin: 0 auto;
|
| 192 |
+
max-width: 1024px;
|
| 193 |
+
}
|
| 194 |
+
#credits-col-container{
|
| 195 |
+
display:flex;
|
| 196 |
+
justify-content: right;
|
| 197 |
+
align-items: center;
|
| 198 |
+
font-size: 24px;
|
| 199 |
+
margin-right: 1rem;
|
| 200 |
+
}
|
| 201 |
+
#login-modal{
|
| 202 |
+
max-width: 728px;
|
| 203 |
+
margin: 0 auto;
|
| 204 |
+
margin-top: 1rem;
|
| 205 |
+
margin-bottom: 1rem;
|
| 206 |
+
}
|
| 207 |
+
#login-logout-btn{
|
| 208 |
+
display:inline;
|
| 209 |
+
max-width: 124px;
|
| 210 |
+
}
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
# Gradio UI for the Model Swap demo. Wires together:
#   * a login modal backed by a locally cached access token (".token" file),
#   * a credits/login/logout strip that re-renders on login-state changes,
#   * the main image/prompt controls and the generation gallery.
# NOTE(review): helpers such as is_logged_in, login, logout, get_user_credits,
# load_token, start_model_swap and Modal are defined earlier in this file.
with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
    print("is logged in:", is_logged_in())
    # Login state lives in gr.State so event handlers can update it.
    logged_in = gr.State(is_logged_in())
    # Restore a previously saved access token, if one was cached on disk.
    if os.path.exists(".token"):
        with open(".token", "r") as f:
            user_token = gr.State(json.load(f)["token"])
    else:
        user_token = gr.State("")

    # Login dialog; hidden until the "Login" button opens it.
    with Modal(visible=False) as modal:
        # Re-render the modal body whenever the stored token changes so the
        # textbox is pre-filled with the current token.
        @gr.render(inputs=user_token)
        def rerender1(user_token1):
            with gr.Column(elem_id="login-modal"):
                access_token = gr.Textbox(
                    label="Token",
                    lines=1,
                    value=user_token1,
                    type="password",
                    placeholder="Enter your access token here!",
                    info="Visit https://playground.tryonlabs.ai to retrieve your access token."
                )

                login_submit_btn = gr.Button("Login", scale=1, variant='primary')
                # On submit: authenticate, close the modal, persist the token.
                login_submit_btn.click(
                    fn=lambda access_token: (login(access_token), Modal(visible=False), access_token),
                    inputs=[access_token], outputs=[logged_in, modal, user_token],
                    concurrency_limit=1)

    # Page header.
    with gr.Row(elem_id="col-container"):
        with gr.Column():
            gr.Markdown(f"""
            # Model Swap AI
            ## by TryOn Labs (https://www.tryonlabs.ai)
            Swap a human model with a artificial model generated by Artificial Model while keeping the garment intact.
            """)

    # Login/credits strip; re-rendered whenever the login state flips.
    @gr.render(inputs=logged_in)
    def rerender(is_logged_in):
        with gr.Column():
            if not is_logged_in:
                with gr.Row(elem_id="credits-col-container"):
                    login_btn = gr.Button(value="Login", variant='primary', elem_id="login-logout-btn", size="sm")
                    # Opening the modal is the only action while logged out.
                    login_btn.click(lambda: Modal(visible=True), None, modal)
            else:
                # Logged in: show remaining credits plus a logout button.
                user_credits = get_user_credits(load_token())
                print("user_credits", user_credits)
                gr.HTML(f"""<div><p id="credits-col-container">Your Credits:
                {user_credits if user_credits is not None else "0"}</p>
                <p style="text-align: right;">Visit <a href="https://playground.tryonlabs.ai">
                TryOn AI Playground</a> to acquire more credits</p></div>""")
                with gr.Row(elem_id="credits-col-container"):
                    logout_btn = gr.Button(value="Logout", scale=1, variant='primary', size="sm",
                                           elem_id="login-logout-btn")
                    logout_btn.click(fn=logout, inputs=None, outputs=[logged_in, user_token], concurrency_limit=1)

    # Main work area: inputs on the left, generated-results gallery alongside.
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Original image", type='pil', height="400px", show_label=True)
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Enter your prompt here!",
                )
                dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Retain garment",
                                       info="Select the garment type you want to retain in the generated image!")

            gallery = gr.Gallery(
                label="Generated images", show_label=True, elem_id="gallery"
                , columns=[3], rows=[1], object_fit="contain", height="auto")

            # output_image = gr.Image(label="Swapped model", type='pil', height="400px", show_label=True,
            #                         show_download_button=True)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                # seed=-1 means "pick a random seed" downstream.
                seed = gr.Number(label="Seed", value=-1, interactive=True, minimum=-1)
                guidance_scale = gr.Number(label="Guidance Scale", value=7.5, interactive=True, minimum=0.0,
                                           maximum=10.0,
                                           step=0.1)
                num_results = gr.Number(label="Number of results", value=2, minimum=1, maximum=5)

            with gr.Row():
                strength = gr.Slider(0.00, 1.00, value=0.99, label="Strength",
                                     info="Choose between 0.00 and 1.00", step=0.01, interactive=True)
                inference_steps = gr.Number(label="Inference Steps", value=20, interactive=True, minimum=1, step=1)

        with gr.Row():
            submit_button = gr.Button("Submit", variant='primary', scale=1)
            reset_button = gr.ClearButton(value="Reset", scale=1)

        # Kick off the swap with all current control values.
        gr.on(
            triggers=[submit_button.click],
            fn=start_model_swap,
            inputs=[input_image, prompt, dropdown, seed, guidance_scale, num_results, strength, inference_steps],
            outputs=[gallery]
        )

        # Restore every control (and clear the gallery) to the defaults above.
        reset_button.click(
            fn=lambda: (None, None, "upper", None, -1, 7.5, 2, 0.99, 20),
            inputs=[],
            outputs=[input_image, prompt, dropdown, gallery, seed, guidance_scale,
                     num_results, strength, inference_steps],
            concurrency_limit=1,
        )

if __name__ == '__main__':
    demo.launch()
|
demo/model_swap/requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.1
|
| 2 |
+
gradio_modal==0.0.3
|
demo/outfit_generator/README.md
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FLUX.1-dev LoRA Outfit Generator Gradio Demo
|
| 2 |
+
## by TryOn Labs (https://www.tryonlabs.ai)
|
| 3 |
+
Generate an outfit by describing the color, pattern, fit, style, material, type, etc.
|
| 4 |
+
|
| 5 |
+
## Model description
|
| 6 |
+
|
| 7 |
+
FLUX.1-dev LoRA Outfit Generator can create an outfit by detailing the color, pattern, fit, style, material, and type.
|
| 8 |
+
|
| 9 |
+
## Inference
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
from diffusers import FluxPipeline
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
seed=42
|
| 18 |
+
prompt = "denim dark blue 5-pocket ankle-length jeans in washed stretch denim slightly looser fit with a wide waist panel for best fit over the tummy and tapered legs with raw-edge frayed hems"
|
| 19 |
+
PRE_TRAINED_MODEL = "black-forest-labs/FLUX.1-dev"
|
| 20 |
+
FINE_TUNED_MODEL = "tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator"
|
| 21 |
+
|
| 22 |
+
# Load Flux
|
| 23 |
+
pipe = FluxPipeline.from_pretrained(PRE_TRAINED_MODEL, torch_dtype=torch.float16).to("cuda")
|
| 24 |
+
|
| 25 |
+
# Load fine-tuned model
|
| 26 |
+
pipe.load_lora_weights(FINE_TUNED_MODEL, adapter_name="default", weight_name="outfit-generator.safetensors")
|
| 27 |
+
|
| 28 |
+
MAX_SEED = 2 ** 32 - 1  # was undefined in the original snippet
seed = random.randint(0, MAX_SEED)
|
| 29 |
+
|
| 30 |
+
generator = torch.Generator().manual_seed(seed)
|
| 31 |
+
|
| 32 |
+
image = pipe(prompt, height=1024, width=1024, num_images_per_prompt=1, generator=generator,
|
| 33 |
+
guidance_scale=4.5, num_inference_steps=40).images[0]
|
| 34 |
+
|
| 35 |
+
image.save("gen_image.jpg")
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Dataset used
|
| 39 |
+
|
| 40 |
+
H&M Fashion Captions Dataset - 20.5k samples
|
| 41 |
+
https://huggingface.co/datasets/tomytjandra/h-and-m-fashion-caption
|
| 42 |
+
|
| 43 |
+
## Repository used
|
| 44 |
+
|
| 45 |
+
AI Toolkit by Ostris
|
| 46 |
+
https://github.com/ostris/ai-toolkit
|
| 47 |
+
|
| 48 |
+
## Download model
|
| 49 |
+
|
| 50 |
+
Weights for this model are available in Safetensors format.
|
| 51 |
+
|
| 52 |
+
[Download](https://huggingface.co/tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator/tree/main) them in the Files & versions tab.
|
| 53 |
+
|
| 54 |
+
## Install dependencies
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
git clone https://github.com/tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator.git
|
| 58 |
+
cd FLUX.1-dev-LoRA-Outfit-Generator
|
| 59 |
+
conda create -n demo python=3.12
|
| 60 |
+
pip install -r requirements.txt
|
| 61 |
+
conda install pytorch pytorch-cuda=12.4 -c pytorch -c nvidia
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Run demo
|
| 65 |
+
|
| 66 |
+
```
|
| 67 |
+
gradio app.py
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Generated images
|
| 71 |
+
|
| 72 |
+

|
| 73 |
+
#### A dress with Color: Black, Department: Dresses, Detail: High Low,Fabric-Elasticity: No Sretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: Collared, Pattern: Solid, Sleeve-Length: Sleeveless, Style: Casual, Type: Tunic, Waistline: Regular
|
| 74 |
+
***
|
| 75 |
+

|
| 76 |
+
#### A dress with Color: Red, Department: Dresses, Detail: Belted, Fabric-Elasticity: High Stretch, Fit: Fitted, Hemline: Flared, Material: Gabardine, Neckline: Off The Shoulder, Pattern: Floral, Sleeve-Length: Sleeveless, Style: Elegant, Type: Fit and Flare, Waistline: High
|
| 77 |
+
***
|
| 78 |
+

|
| 79 |
+
#### A dress with Color: Multicolored, Department: Dresses, Detail: Split, Fabric-Elasticity: High Stretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: V Neck, Pattern: Leopard, Sleeve-Length: Sleeveless, Style: Casual, Type: T Shirt, Waistline: Regular
|
| 80 |
+
***
|
| 81 |
+

|
| 82 |
+
#### A dress with Color: Brown, Department: Dresses, Detail: Zipper, Fabric-Elasticity: No Sretch, Fit: Fitted, Hemline: Asymmetrical, Material: Satin, Neckline: Spaghetti Straps, Pattern: Floral, Sleeve-Length: Sleeveless, Style: Boho, Type: Cami Top, Waistline: High
|
| 83 |
+
***
|
| 84 |
+
|
| 85 |
+
## License
|
| 86 |
+
MIT [License](LICENSE)
|
demo/outfit_generator/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .app import demo
|
demo/outfit_generator/app.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import os.path
import random
import time

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxPipeline

# Run on GPU when available; torch_dtype below is chosen to match.
device = "cuda" if torch.cuda.is_available() else "cpu"
PRE_TRAINED_MODEL = "black-forest-labs/FLUX.1-dev"
FINE_TUNED_MODEL = "tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator"
# BUG FIX: expand "~" explicitly -- os.makedirs("~/results") previously
# created a literal directory named "~" in the working directory.
RESULTS_DIR = os.path.expanduser("~/results")
os.makedirs(RESULTS_DIR, exist_ok=True)

if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32

# Load Flux.
# BUG FIX: the computed `torch_dtype` and `device` were previously ignored --
# the pipeline was hard-coded to float16 on "cuda", which crashes on CPU-only
# hosts and made the dtype selection above dead code.
pipe = FluxPipeline.from_pretrained(PRE_TRAINED_MODEL, torch_dtype=torch_dtype).to(device)

# Load the fine-tuned LoRA adapter on top of the base model.
pipe.load_lora_weights(FINE_TUNED_MODEL, adapter_name="default", weight_name="outfit-generator.safetensors")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
|
| 32 |
+
|
| 33 |
+
@spaces.GPU(duration=65)
def infer(
        prompt,
        seed=42,
        randomize_seed=False,
        width=1024,
        height=1024,
        guidance_scale=4.5,
        num_inference_steps=40,
        progress=gr.Progress(track_tqdm=True),
):
    """Generate one outfit image from a text prompt.

    Args:
        prompt: text description of the outfit.
        seed: RNG seed; ignored when randomize_seed is True.
        randomize_seed: pick a fresh random seed in [0, MAX_SEED].
        width/height: output image size in pixels.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: diffusion steps.
        progress: Gradio progress tracker (forwards tqdm updates to the UI).

    Returns:
        (PIL.Image, int): the generated image and the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    # BUG FIX: the original passed height=width, width=height (swapped),
    # so any non-square request produced transposed dimensions.
    image = pipe(prompt, height=height, width=width, num_images_per_prompt=1, generator=generator,
                 guidance_scale=guidance_scale,
                 num_inference_steps=num_inference_steps).images[0]

    try:
        # Persist the image plus its generation parameters for reproducibility;
        # a failure here must not break the UI response, hence the broad catch.
        current_time = int(time.time() * 1000)
        image.save(os.path.join(RESULTS_DIR, f"gen_img_{current_time}.png"))
        with open(os.path.join(RESULTS_DIR, f"gen_img_{current_time}.json"), "w") as f:
            json.dump({"prompt": prompt, "height": height, "width": width, "guidance_scale": guidance_scale,
                       "num_inference_steps": num_inference_steps, "seed": seed}, f)
    except Exception as e:
        print(str(e))

    return image, seed
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Example prompts shown as clickable chips under the input box.
examples = [
    "stripe red striped jersey top in a soft cotton and modal blend with short sleeves a chest pocket and rounded hem",
    "A dress with Color: Orange, Department: Dresses, Detail: Split Thigh, Fabric-Elasticity: No Sretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: Gathered, Pattern: Tropical, Sleeve-Length: Sleeveless, Style: Boho, Type: A Line Skirt, Waistline: High",
    "treatment dark pink knee-length skirt in crocodile-patterned imitation leather high waist with belt loops and press-studs a zip fly diagonal side pockets and a slit at the front the polyester content of the skirt is partly recycled",
    "A dress with Color: Maroon, Department: Dresses, Detail: Ruched Bust, Fabric-Elasticity: Slight Stretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: Spaghetti Straps, Pattern: Floral, Sleeve-Length: Sleeveless, Style: Boho, Type: Cami Top, Waistline: Regular",
    "denim dark blue 5-pocket ankle-length jeans in washed stretch denim slightly looser fit with a wide waist panel for best fit over the tummy and tapered legs with raw-edge frayed hems"
]

# Narrow, centered main column.
css = """
#col-container {
    margin: 0 auto;
    max-width: 768px;
}
"""

# Gradio UI: prompt box + run button on top, result image below, and
# generation knobs tucked into an "Advanced Settings" accordion.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # FLUX.1-dev LoRA Outfit Generator
        ## by TryOn Labs (https://www.tryonlabs.ai)
        Generate an outfit by describing the color, pattern, fit, style, material, type, etc.
        """)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        # Example results are computed and cached lazily on first click.
        gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True,
                    cache_mode="lazy")
    # Run on button click or when the user presses Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch(share=True)
|
demo/outfit_generator/images/sample1.jpeg
ADDED
|
demo/outfit_generator/images/sample2.jpeg
ADDED
|
demo/outfit_generator/images/sample3.jpeg
ADDED
|
demo/outfit_generator/images/sample4.jpeg
ADDED
|
demo/outfit_generator/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spaces
|
| 2 |
+
gradio
|
| 3 |
+
diffusers
|
| 4 |
+
torch
|
| 5 |
+
numpy
|
| 6 |
+
transformers
|
| 7 |
+
accelerate
|
| 8 |
+
protobuf
|
| 9 |
+
sentencepiece
|
| 10 |
+
peft==0.13.2
|
environment.yml
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: tryondiffusion
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- blas=1.0=mkl
|
| 6 |
+
- bottleneck=1.3.5=py310h4e76f89_0
|
| 7 |
+
- bzip2=1.0.8=h1de35cc_0
|
| 8 |
+
- ca-certificates=2023.08.22=hecd8cb5_0
|
| 9 |
+
- cffi=1.15.1=py310h6c40b1e_3
|
| 10 |
+
- gmp=6.2.1=he9d5cce_3
|
| 11 |
+
- gmpy2=2.1.2=py310hd5de756_0
|
| 12 |
+
- intel-openmp=2023.1.0=ha357a0b_43547
|
| 13 |
+
- jinja2=3.1.2=py310hecd8cb5_0
|
| 14 |
+
- libcxx=14.0.6=h9765a3e_0
|
| 15 |
+
- libffi=3.4.4=hecd8cb5_0
|
| 16 |
+
- libprotobuf=3.20.3=hfff2838_0
|
| 17 |
+
- libuv=1.44.2=h6c40b1e_0
|
| 18 |
+
- mkl=2023.1.0=h8e150cf_43559
|
| 19 |
+
- mkl-service=2.4.0=py310h6c40b1e_1
|
| 20 |
+
- mkl_fft=1.3.8=py310h6c40b1e_0
|
| 21 |
+
- mkl_random=1.2.4=py310ha357a0b_0
|
| 22 |
+
- mpc=1.1.0=h6ef4df4_1
|
| 23 |
+
- mpfr=4.0.2=h9066e36_1
|
| 24 |
+
- mpmath=1.3.0=py310hecd8cb5_0
|
| 25 |
+
- ncurses=6.4=hcec6c5f_0
|
| 26 |
+
- networkx=3.1=py310hecd8cb5_0
|
| 27 |
+
- ninja=1.10.2=hecd8cb5_5
|
| 28 |
+
- ninja-base=1.10.2=haf03e11_5
|
| 29 |
+
- numexpr=2.8.7=py310h827a554_0
|
| 30 |
+
- openssl=3.0.11=hca72f7f_2
|
| 31 |
+
- pandas=2.1.1=py310h3ea8b11_0
|
| 32 |
+
- pip=23.2.1=py310hecd8cb5_0
|
| 33 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
| 34 |
+
- python=3.10.13=h5ee71fb_0
|
| 35 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
| 36 |
+
- python-tzdata=2023.3=pyhd3eb1b0_0
|
| 37 |
+
- pytz=2023.3.post1=py310hecd8cb5_0
|
| 38 |
+
- readline=8.2=hca72f7f_0
|
| 39 |
+
- six=1.16.0=pyhd3eb1b0_1
|
| 40 |
+
- sqlite=3.41.2=h6c40b1e_0
|
| 41 |
+
- tbb=2021.8.0=ha357a0b_0
|
| 42 |
+
- tk=8.6.12=h5d9f67b_0
|
| 43 |
+
- tzdata=2023c=h04d1e81_0
|
| 44 |
+
- wheel=0.38.4=py310hecd8cb5_0
|
| 45 |
+
- xz=5.4.2=h6c40b1e_0
|
| 46 |
+
- zlib=1.2.13=h4dc903c_0
|
| 47 |
+
- pip:
|
| 48 |
+
- absl-py==2.0.0
|
| 49 |
+
- aiofiles==23.2.1
|
| 50 |
+
- annotated-types==0.6.0
|
| 51 |
+
- anyio==4.3.0
|
| 52 |
+
- appnope==0.1.3
|
| 53 |
+
- asttokens==2.4.0
|
| 54 |
+
- astunparse==1.6.3
|
| 55 |
+
- backcall==0.2.0
|
| 56 |
+
- cachetools==5.3.1
|
| 57 |
+
- carvekit==4.1.1
|
| 58 |
+
- certifi==2023.7.22
|
| 59 |
+
- charset-normalizer==3.2.0
|
| 60 |
+
- click==8.1.7
|
| 61 |
+
- comm==0.1.4
|
| 62 |
+
- contourpy==1.1.1
|
| 63 |
+
- cycler==0.11.0
|
| 64 |
+
- debugpy==1.8.0
|
| 65 |
+
- decorator==5.1.1
|
| 66 |
+
- diffusers==0.29.2
|
| 67 |
+
- einops==0.7.0
|
| 68 |
+
- exceptiongroup==1.1.3
|
| 69 |
+
- executing==1.2.0
|
| 70 |
+
- fastapi==0.108.0
|
| 71 |
+
- ffmpy==0.3.3
|
| 72 |
+
- filelock==3.12.4
|
| 73 |
+
- flatbuffers==23.5.26
|
| 74 |
+
- fonttools==4.42.1
|
| 75 |
+
- fsspec==2024.3.1
|
| 76 |
+
- gast==0.5.4
|
| 77 |
+
- google-auth==2.23.3
|
| 78 |
+
- google-auth-oauthlib==1.0.0
|
| 79 |
+
- google-pasta==0.2.0
|
| 80 |
+
- gradio==4.39.0
|
| 81 |
+
- gradio-client==1.1.1
|
| 82 |
+
- gradio-modal==0.0.3
|
| 83 |
+
- grpcio==1.59.0
|
| 84 |
+
- h11==0.14.0
|
| 85 |
+
- h5py==3.10.0
|
| 86 |
+
- httpcore==1.0.5
|
| 87 |
+
- httpx==0.27.0
|
| 88 |
+
- huggingface-hub==0.23.4
|
| 89 |
+
- idna==3.4
|
| 90 |
+
- imageio==2.34.0
|
| 91 |
+
- importlib-metadata==8.0.0
|
| 92 |
+
- importlib-resources==6.4.0
|
| 93 |
+
- ipykernel==6.25.2
|
| 94 |
+
- ipython==8.15.0
|
| 95 |
+
- jedi==0.19.0
|
| 96 |
+
- jupyter-client==8.3.1
|
| 97 |
+
- jupyter-core==5.3.1
|
| 98 |
+
- keras==2.14.0
|
| 99 |
+
- kiwisolver==1.4.5
|
| 100 |
+
- lazy-loader==0.3
|
| 101 |
+
- libclang==16.0.6
|
| 102 |
+
- loguru==0.7.2
|
| 103 |
+
- markdown==3.5
|
| 104 |
+
- markdown-it-py==3.0.0
|
| 105 |
+
- markupsafe==2.1.3
|
| 106 |
+
- matplotlib==3.8.0
|
| 107 |
+
- matplotlib-inline==0.1.6
|
| 108 |
+
- mdurl==0.1.2
|
| 109 |
+
- ml-dtypes==0.2.0
|
| 110 |
+
- nest-asyncio==1.5.8
|
| 111 |
+
- numpy==1.26.4
|
| 112 |
+
- oauthlib==3.2.2
|
| 113 |
+
- opencv-python==4.8.1.78
|
| 114 |
+
- opt-einsum==3.3.0
|
| 115 |
+
- orjson==3.10.6
|
| 116 |
+
- packaging==23.1
|
| 117 |
+
- parso==0.8.3
|
| 118 |
+
- pexpect==4.8.0
|
| 119 |
+
- pickleshare==0.7.5
|
| 120 |
+
- pillow==10.1.0
|
| 121 |
+
- platformdirs==3.10.0
|
| 122 |
+
- prompt-toolkit==3.0.39
|
| 123 |
+
- protobuf==4.24.4
|
| 124 |
+
- psutil==5.9.5
|
| 125 |
+
- ptyprocess==0.7.0
|
| 126 |
+
- pure-eval==0.2.2
|
| 127 |
+
- pyasn1==0.5.0
|
| 128 |
+
- pyasn1-modules==0.3.0
|
| 129 |
+
- pydantic==2.5.3
|
| 130 |
+
- pydantic-core==2.14.6
|
| 131 |
+
- pydub==0.25.1
|
| 132 |
+
- pygments==2.16.1
|
| 133 |
+
- pyparsing==3.1.1
|
| 134 |
+
- python-dotenv==1.0.1
|
| 135 |
+
- python-multipart==0.0.9
|
| 136 |
+
- pyyaml==6.0.1
|
| 137 |
+
- pyzmq==25.1.1
|
| 138 |
+
- regex==2024.5.15
|
| 139 |
+
- requests==2.31.0
|
| 140 |
+
- requests-oauthlib==1.3.1
|
| 141 |
+
- rich==13.7.1
|
| 142 |
+
- rsa==4.9
|
| 143 |
+
- ruff==0.5.5
|
| 144 |
+
- safetensors==0.4.3
|
| 145 |
+
- scikit-image==0.22.0
|
| 146 |
+
- scipy==1.11.4
|
| 147 |
+
- semantic-version==2.10.0
|
| 148 |
+
- setuptools==69.0.3
|
| 149 |
+
- shellingham==1.5.4
|
| 150 |
+
- sniffio==1.3.1
|
| 151 |
+
- stack-data==0.6.2
|
| 152 |
+
- starlette==0.32.0.post1
|
| 153 |
+
- sympy==1.12
|
| 154 |
+
- tensorboard==2.14.1
|
| 155 |
+
- tensorboard-data-server==0.7.1
|
| 156 |
+
- tensorflow==2.14.0
|
| 157 |
+
- tensorflow-estimator==2.14.0
|
| 158 |
+
- tensorflow-io-gcs-filesystem==0.34.0
|
| 159 |
+
- termcolor==2.3.0
|
| 160 |
+
- tifffile==2024.2.12
|
| 161 |
+
- tokenizers==0.19.1
|
| 162 |
+
- tomlkit==0.12.0
|
| 163 |
+
- torch==2.1.2
|
| 164 |
+
- torchvision==0.16.2
|
| 165 |
+
- tornado==6.3.3
|
| 166 |
+
- tqdm==4.66.1
|
| 167 |
+
- traitlets==5.10.0
|
| 168 |
+
- transformers==4.42.4
|
| 169 |
+
- typer==0.12.3
|
| 170 |
+
- typing==3.7.4.3
|
| 171 |
+
- typing-extensions==4.8.0
|
| 172 |
+
- urllib3==2.0.5
|
| 173 |
+
- uvicorn==0.25.0
|
| 174 |
+
- wcwidth==0.2.6
|
| 175 |
+
- websockets==11.0.3
|
| 176 |
+
- werkzeug==3.0.0
|
| 177 |
+
- wrapt==1.14.1
|
| 178 |
+
- zipp==3.19.2
|
| 179 |
+
prefix: /Users/apple/miniconda3/envs/tryondiffusion
|
main.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
|
| 3 |
+
load_dotenv()
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
import os
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
from tryon.preprocessing import segment_human, segment_garment, extract_garment
|
| 10 |
+
|
| 11 |
+
if __name__ == '__main__':
    # Command-line entry point for the try-on preprocessing steps.
    # Each action reads from a conventional sub-directory of the dataset dir
    # and writes its results to a sibling output directory.
    parser = argparse.ArgumentParser(description="Tryon preprocessing")
    parser.add_argument('-d', '--dataset', type=str, default="data",
                        help='Path of the dataset dir')
    parser.add_argument('-a', '--action', type=str, default="",
                        help='Ex. segment_garment, extract_garment, segment_human')
    parser.add_argument('-c', '--cls', type=str, default="upper",
                        help='Ex. upper, lower, all')
    args = parser.parse_args()

    if args.action == "segment_garment":
        # Produce segmentation masks for the raw garment photos.
        print('Start time:', int(time.time()))
        segment_garment(inputs_dir=os.path.join(args.dataset, "original_cloth"),
                        outputs_dir=os.path.join(args.dataset, "garment_segmented"),
                        cls=args.cls)
        print("End time:", int(time.time()))

    elif args.action == "extract_garment":
        # Cut the garment out of the raw photos and resize to a fixed width.
        print('Start time:', int(time.time()))
        extract_garment(inputs_dir=os.path.join(args.dataset, "original_cloth"),
                        outputs_dir=os.path.join(args.dataset, "cloth"),
                        cls=args.cls, resize_to_width=400)
        print("End time:", int(time.time()))

    elif args.action == "segment_human":
        # Segment the single reference human photo (model.jpg).
        print('Start time:', int(time.time()))
        image_path = os.path.join(args.dataset, "original_human", "model.jpg")
        output_dir = os.path.join(args.dataset, "human_segmented")
        segment_human(image_path=image_path, output_dir=output_dir)
        print("End time:", int(time.time()))
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
numpy
|
| 3 |
+
opencv-python
|
| 4 |
+
pillow
|
| 5 |
+
matplotlib
|
| 6 |
+
tqdm
|
| 7 |
+
torchvision
|
| 8 |
+
einops
|
| 9 |
+
python-dotenv
|
| 10 |
+
scikit-image
|
| 11 |
+
diffusers
|
| 12 |
+
transformers
|
| 13 |
+
gradio==4.44.1
|
| 14 |
+
gradio_modal==0.0.3
|
| 15 |
+
# (duplicate "python-dotenv" entry removed; already listed above)
|
run_demo.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
if __name__ == '__main__':
    # Launch one of the bundled Gradio demos by name.
    argp = argparse.ArgumentParser(description="Gradio demo")
    argp.add_argument('-n',
                      '--name',
                      type=str, default="data", help='Name of the gradio demo to launch')
    args = argp.parse_args()

    # Imports are done lazily per branch so only the selected demo's
    # (potentially heavy) dependencies are loaded.
    if args.name == "extract_garment":
        from demo import extract_garment_demo as demo
        demo.launch()
    elif args.name == "model_swap":
        from demo import model_swap_demo as demo
        demo.launch()
    elif args.name == "outfit_generator":
        from demo import outfit_generator_demo as demo
        demo.launch()
    else:
        # BUG FIX: unknown names (including the default "data") previously
        # exited silently with status 0; fail loudly instead.
        raise SystemExit(
            f"Unknown demo name: {args.name!r}. "
            "Choose one of: extract_garment, model_swap, outfit_generator.")
|
run_ootd.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import os
import pathlib
import subprocess

# Location of the OOTDiffusion checkout whose run.py we invoke.
OOTDIFFUSION_DIR = "/home/ubuntu/ootdiffusion"


def build_parser():
    """Return the argument parser for the OOTDiffusion runner."""
    parser = argparse.ArgumentParser(description='run ootd')
    parser.add_argument('--gpu_id', '-g', type=int, default=0, required=False)
    parser.add_argument('--model_path', type=str, default="", required=True)
    parser.add_argument('--cloth_path', type=str, default="", required=True)
    parser.add_argument('--output_path', type=str, default="", required=True)
    parser.add_argument('--model_type', type=str, default="hd", required=False)
    parser.add_argument('--category', '-c', type=int, default=0, required=False)
    parser.add_argument('--scale', type=float, default=2.0, required=False)
    parser.add_argument('--step', type=int, default=20, required=False)
    parser.add_argument('--sample', type=int, default=4, required=False)
    parser.add_argument('--seed', type=int, default=-1, required=False)
    return parser


def build_command(args, python_bin=None):
    """Build the argv list for the OOTDiffusion run.py subprocess.

    BUG FIX: the original formatted an f-string and then split on spaces,
    which silently broke any path containing a space. Passing an explicit
    argv list to Popen avoids that entirely.

    Args:
        args: parsed argparse.Namespace (see build_parser()).
        python_bin: interpreter to use; defaults to the ootdiffusion conda env.

    Returns:
        list[str]: the command ready for subprocess.Popen.
    """
    if python_bin is None:
        python_bin = os.path.join(str(pathlib.Path.home()),
                                  'miniconda3/envs/ootdiffusion/bin/python')
    return [
        python_bin, 'run.py',
        '--model_path', str(args.model_path),
        '--cloth_path', str(args.cloth_path),
        '--output_path', str(args.output_path),
        '--model_type', str(args.model_type),
        '--category', str(args.category),
        '--image_scale', str(args.scale),
        '--gpu_id', str(args.gpu_id),
        '--n_samples', str(args.sample),
        '--seed', str(args.seed),
        '--n_steps', str(args.step),
    ]


if __name__ == '__main__':
    # BUG FIX: parse arguments inside the main guard -- the original called
    # parse_args() at module level, so merely importing this file crashed on
    # foreign argv (e.g. under pytest) or exited with a usage error.
    args = build_parser().parse_args()
    print(args)

    command = build_command(args)
    print("command:", command)

    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                         cwd=OOTDIFFUSION_DIR)
    out, err = p.communicate()
    print(out, err)
|
| 36 |
+
|
| 37 |
+
|
scripts/install_conda.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Setup Ubuntu
sudo apt update --yes
sudo apt upgrade --yes

# Get Miniconda and make it the main Python interpreter
# -O: save the installer under a fixed name in $HOME
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
# -b: batch (silent) install with no prompts; -p: install prefix
bash ~/miniconda.sh -b -p ~/miniconda
rm ~/miniconda.sh

# NOTE: this PATH change only lasts for the current shell session;
# append it to ~/.bashrc (or similar) to make it permanent.
export PATH=~/miniconda/bin:$PATH
|
scripts/install_sam2.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Install SAM 2 (Segment Anything 2) into its own conda environment and
# fetch the pretrained checkpoints.
ENV_NAME="sam2"

# System packages needed by SAM 2's video/image tooling.
sudo apt-get -y update && sudo apt-get install -y --no-install-recommends ffmpeg libavutil-dev libavcodec-dev libavformat-dev libswscale-dev pkg-config build-essential libffi-dev

git clone https://github.com/facebookresearch/sam2.git ~/$ENV_NAME

# BUG FIX: added -y so the script runs unattended (the sibling `conda
# install` below already used it).
conda create -y -n $ENV_NAME python=3.10
conda install -y -n $ENV_NAME pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia

# BUG FIX: pip has no -y flag -- the original `pip install -y -e` always
# failed with "no such option".
~/miniconda3/envs/$ENV_NAME/bin/pip install -e ~/$ENV_NAME

# The download script writes the checkpoints into the current directory;
# move them into the repository's checkpoints folder afterwards.
sh ~/$ENV_NAME/checkpoints/download_ckpts.sh
mv sam2.1_hiera_base_plus.pt ~/$ENV_NAME/checkpoints/
mv sam2.1_hiera_large.pt ~/$ENV_NAME/checkpoints/
mv sam2.1_hiera_small.pt ~/$ENV_NAME/checkpoints/
mv sam2.1_hiera_tiny.pt ~/$ENV_NAME/checkpoints/
|
setup.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path

from setuptools import setup, find_packages

# Use the repository README as the long description shown on package indexes.
this_directory = Path(__file__).parent
# BUG FIX: read explicitly as UTF-8 so installation does not depend on the
# host's locale encoding (read_text() defaults to the platform encoding).
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

setup(
    name="tryondiffusion",
    version="0.1.0",
    license='Creative Commons BY-NC 4.0',
    packages=find_packages(),
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/kailashahirwar/tryondiffusion',
    keywords='Unofficial implementation of TryOnDiffusion: A Tale Of Two UNets',
    install_requires=[
        "torch",
        "numpy",
        "opencv-python",
        "pillow",
        "matplotlib",
        "tqdm",
        "torchvision",
        "einops",
        "scipy",
        "scikit-image",
        "gradio==4.44.1",
        "gradio_modal==0.0.3"
    ]
)
|
tryon/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Try-On Preprocessing
|
| 2 |
+
|
| 3 |
+
Before you start, make a .env file in your project's main folder. Put these environment variables inside it.
|
| 4 |
+
```
|
| 5 |
+
U2NET_CLOTH_SEG_CHECKPOINT_PATH=cloth_segm.pth
|
| 6 |
+
```
|
| 7 |
+
|
| 8 |
+
#### Remember to load environment variables before you start running scripts.
|
| 9 |
+
|
| 10 |
+
```
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
### segment garment
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
from tryon.preprocessing import segment_garment
|
| 20 |
+
|
| 21 |
+
segment_garment(inputs_dir=<inputs_dir>,
|
| 22 |
+
outputs_dir=<outputs_dir>, cls=<cls>)
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
possible values for cls: upper, lower, dress, all
|
| 26 |
+
|
| 27 |
+
### extract garment
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
from tryon.preprocessing import extract_garment
|
| 31 |
+
|
| 32 |
+
extract_garment(inputs_dir=<inputs_dir>,
|
| 33 |
+
outputs_dir=<outputs_dir>, cls=<cls>)
|
| 34 |
+
```
|
tryon/__init__.py
ADDED
|
File without changes
|
tryon/models/__init__.py
ADDED
|
File without changes
|
tryon/models/ootdiffusion/setup.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Set up the OOTDiffusion conda environment, source tree and model checkpoints.
# Fail fast on errors, unset variables and broken pipes.
set -euo pipefail

ENV_NAME="ootdiffusion"
PROJECT_DIR="/home/ubuntu/ootdiffusion"

# Create the conda environment only once.
if [ ! -d "$HOME/miniconda3/envs/$ENV_NAME" ]; then
    echo "creating conda environment"
    conda create -y -n "$ENV_NAME" python==3.10
fi

# clone repository (only once)
if [ ! -d "$PROJECT_DIR" ]; then
    echo "cloning OOTDiffusion repository"
    git clone https://github.com/tryonlabs/OOTDiffusion.git "$PROJECT_DIR"
fi

"$HOME/miniconda3/envs/$ENV_NAME/bin/pip" install -r "$PROJECT_DIR/requirements.txt"

if [ ! -d "$PROJECT_DIR/checkpoints/ootd" ]; then
    echo "downloading checkpoints"

    # download checkpoints
    git clone https://huggingface.co/levihsu/OOTDiffusion "$HOME/ootd-checkpoints"
    git clone https://huggingface.co/openai/clip-vit-large-patch14 "$HOME/clip-vit-large-patch14"

    mv "$HOME/ootd-checkpoints/checkpoints/"* "$PROJECT_DIR/checkpoints/"
    rm -rf "$HOME/ootd-checkpoints"

    mv "$HOME/clip-vit-large-patch14" "$PROJECT_DIR/checkpoints/"
    rm -rf "$HOME/clip-vit-large-patch14"

fi
|
tryon/preprocessing/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .preprocess_garment import segment_garment, extract_garment
|
| 2 |
+
from .utils import convert_to_jpg
|
| 3 |
+
from .preprocess_human import segment_human
|
tryon/preprocessing/captioning/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .generate_caption import (caption_image, create_llava_next_pipeline,
|
| 2 |
+
create_phi35mini_pipeline, convert_outfit_json_to_caption)
|
tryon/preprocessing/captioning/generate_caption.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 5 |
+
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def caption_image(image, question, model=None, processor=None, json_only=False):
    """
    Extract outfit details from an image using a LLaVA-NeXT image-to-text model.

    :param image: input PIL image
    :param question: question/prompt to ask about the image
    :param model: optional preloaded LLaVA-NeXT model
    :param processor: optional matching processor
    :param json_only: when True, skip generating the natural-language caption
    :return: tuple (parsed JSON data, generated caption or None)
    :raises json.JSONDecodeError: if the model output is not valid JSON
    """
    # Build the default pipeline when either piece is missing; the previous
    # `and` check crashed when only one of (model, processor) was supplied.
    if model is None or processor is None:
        model, processor = create_llava_next_pipeline()

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        },
    ]

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # Move inputs to the device the model actually lives on instead of a
    # hard-coded "cuda:0", so CPU-only hosts also work.
    inputs = processor(image, prompt, return_tensors="pt").to(model.device)

    output = model.generate(**inputs, max_new_tokens=300)
    # The answer follows the last [/INST] marker of the decoded conversation.
    output = processor.decode(output[0], skip_special_tokens=True).split("[/INST]")[-1]
    # Strip the markdown code fences the model wraps its JSON in.
    json_data = json.loads(output.replace("```json", "").replace("```", "").strip())

    if not json_only:
        generated_caption = convert_outfit_json_to_caption(json_data)
    else:
        generated_caption = None

    return json_data, generated_caption
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def create_phi35mini_pipeline():
    """
    Build a text-generation pipeline around microsoft/Phi-3.5-mini-instruct.

    :return: a transformers text-generation pipeline running on CUDA
    """
    # Fixed seed keeps generations reproducible across calls.
    torch.random.manual_seed(0)

    model_id = "microsoft/Phi-3.5-mini-instruct"
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cuda",
        torch_dtype="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    return pipeline("text-generation", model=language_model, tokenizer=tokenizer)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def create_llava_next_pipeline():
    """
    Build the LLaVA-NeXT (v1.6, Mistral-7B) model and processor pair.

    :return: (model, processor) with the model moved to GPU when available
    """
    repo_id = "llava-hf/llava-v1.6-mistral-7b-hf"
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor = LlavaNextProcessor.from_pretrained(repo_id)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.to(target_device)

    return model, processor
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def convert_outfit_json_to_caption(json_data, pipe=None):
    """
    Turn an outfit-description JSON object into a natural-language caption.

    :param json_data: outfit attributes as a JSON-serialisable dict
    :param pipe: optional text-generation pipeline; built on demand when omitted
    :return: generated caption string
    """
    if pipe is None:
        pipe = create_phi35mini_pipeline()

    # Deterministic decoding: sampling disabled, generation capped at 300 tokens.
    generation_args = {
        "max_new_tokens": 300,
        "return_full_text": False,
        "temperature": 0.0,
        "do_sample": False,
    }

    request = (f'Convert the {json.dumps(json_data)} JSON data into a natural '
               f'language paragraph beginning with "An outfit with"')
    messages = [{"role": "user", "content": request}]

    result = pipe(messages, **generation_args)
    output = result[0]['generated_text'].strip()
    print(f"Output: {output}")
    return output
|
tryon/preprocessing/extract_garment_new.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
|
| 9 |
+
from .u2net import load_cloth_segm_model
|
| 10 |
+
from .utils import NormalizeImage, naive_cutout, resize_by_bigger_index, image_resize
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def extract_garment(image, cls="all", resize_to_width=None, net=None, device=None):
    """
    Extract garment regions from an image with the U2NET cloth-segmentation model.

    :param image: input PIL image (RGB)
    :param cls: garment classes to extract: "upper", "lower", "dress" or "all"
    :param resize_to_width: optionally resize each output image to this width
    :param net: optional preloaded segmentation network
    :param device: torch device the network lives on
    :return: dict mapping class name ("upper"/"lower"/"dress") to a PIL image
    :raises ValueError: if cls is not one of the supported values
    """
    # Resolve the device even when a preloaded net is supplied; previously
    # device stayed None in that case, breaking image_tensor.to(device).
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if net is None:
        net = load_cloth_segm_model(device, os.environ.get("U2NET_CLOTH_SEGM_CHECKPOINT_PATH"), in_ch=3, out_ch=4)

    transform_fn = transforms.Compose(
        [transforms.ToTensor(),
         NormalizeImage(0.5, 0.5)]
    )

    # The model expects 768x768 input; keep the original size to map masks back.
    img_size = image.size
    img = image.resize((768, 768), Image.BICUBIC)
    image_tensor = torch.unsqueeze(transform_fn(img), 0)

    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
        output_tensor = torch.squeeze(output_tensor, dim=0)
        output_arr = output_tensor.cpu().numpy()

    # Class ids predicted by the model; 0 is background.
    classes = {1: "upper", 2: "lower", 3: "dress"}

    if cls == "all":
        # Keep only the classes actually present in the prediction.
        classes_to_save = [c for c in classes if np.any(output_arr == c)]
    elif cls == "upper":
        classes_to_save = [1]
    elif cls == "lower":
        classes_to_save = [2]
    elif cls == "dress":
        classes_to_save = [3]
    else:
        raise ValueError(f"Unknown cls: {cls}")

    garments = dict()

    for cls_id in classes_to_save:
        # Binary alpha mask for this class, resized back to the input size.
        alpha_mask = (output_arr == cls_id).astype(np.uint8) * 255
        alpha_mask = alpha_mask[0]  # Selecting the first channel to make it 2D
        alpha_mask_img = Image.fromarray(alpha_mask, mode='L')
        alpha_mask_img = alpha_mask_img.resize(img_size, Image.BICUBIC)

        cutout = np.array(naive_cutout(image, alpha_mask_img))
        cutout = resize_by_bigger_index(cutout)

        # Alpha-composite the cutout onto a centered white 1024x768 canvas.
        canvas = np.ones((1024, 768, 3), np.uint8) * 255
        y1, y2 = (canvas.shape[0] - cutout.shape[0]) // 2, (canvas.shape[0] + cutout.shape[0]) // 2
        x1, x2 = (canvas.shape[1] - cutout.shape[1]) // 2, (canvas.shape[1] + cutout.shape[1]) // 2

        alpha_s = cutout[:, :, 3] / 255.0
        alpha_l = 1.0 - alpha_s

        for c in range(0, 3):
            canvas[y1:y2, x1:x2, c] = (alpha_s * cutout[:, :, c] +
                                       alpha_l * canvas[y1:y2, x1:x2, c])

        # resize image before returning
        if resize_to_width:
            canvas = image_resize(canvas, width=resize_to_width)

        garments[classes[cls_id]] = Image.fromarray(canvas)

    return garments
|
tryon/preprocessing/preprocess_garment.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from .u2net import load_cloth_segm_model
|
| 13 |
+
from .utils import NormalizeImage, naive_cutout, resize_by_bigger_index, image_resize
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def segment_garment(inputs_dir, outputs_dir, cls="all"):
    """
    Generate per-class garment segmentation masks for every image in a directory.

    Masks are saved as '<image stem>_<class id>.jpg' in outputs_dir.

    :param inputs_dir: directory containing the input garment images
    :param outputs_dir: directory where the mask images are written (created if absent)
    :param cls: "upper", "lower", "dress" or "all"
    :raises ValueError: if cls is not one of the supported values
    """
    os.makedirs(outputs_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform_fn = transforms.Compose(
        [transforms.ToTensor(),
         NormalizeImage(0.5, 0.5)]
    )

    net = load_cloth_segm_model(device, os.environ.get("U2NET_CLOTH_SEGM_CHECKPOINT_PATH"), in_ch=3, out_ch=4)

    images_list = sorted(os.listdir(inputs_dir))
    pbar = tqdm(total=len(images_list))

    for image_name in images_list:
        img = Image.open(os.path.join(inputs_dir, image_name)).convert('RGB')
        img_size = img.size
        # The model expects 768x768 input; masks are resized back afterwards.
        img = img.resize((768, 768), Image.BICUBIC)
        image_tensor = torch.unsqueeze(transform_fn(img), 0)

        with torch.no_grad():
            output_tensor = net(image_tensor.to(device))
            output_tensor = F.log_softmax(output_tensor[0], dim=1)
            output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
            output_tensor = torch.squeeze(output_tensor, dim=0)
            output_arr = output_tensor.cpu().numpy()

        # Class ids: 1=upper, 2=lower, 3=dress; 0 (background) is excluded.
        if cls == "all":
            classes_to_save = [c for c in range(1, 4) if np.any(output_arr == c)]
        elif cls == "upper":
            classes_to_save = [1]
        elif cls == "lower":
            classes_to_save = [2]
        elif cls == "dress":
            classes_to_save = [3]
        else:
            raise ValueError(f"Unknown cls: {cls}")

        # splitext is robust to filenames containing extra dots, unlike
        # the previous image_name.split(".")[0].
        stem = os.path.splitext(image_name)[0]
        for cls1 in classes_to_save:
            alpha_mask = (output_arr == cls1).astype(np.uint8) * 255
            alpha_mask = alpha_mask[0]  # Selecting the first channel to make it 2D
            alpha_mask_img = Image.fromarray(alpha_mask, mode='L')
            alpha_mask_img = alpha_mask_img.resize(img_size, Image.BICUBIC)
            alpha_mask_img.save(os.path.join(outputs_dir, f'{stem}_{cls1}.jpg'))

        pbar.update(1)

    pbar.close()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def extract_garment(inputs_dir, outputs_dir, cls="all", resize_to_width=None):
    """
    Extract garments from every image in a directory and save each one centered
    on a white 1024x768 canvas.

    Masks are first generated into a sibling "cloth-mask" directory by
    segment_garment, then each mask is composited with its source image.

    :param inputs_dir: directory containing the input garment images
    :param outputs_dir: directory where the extracted garments are written
    :param cls: "upper", "lower", "dress" or "all"
    :param resize_to_width: optionally resize each output image to this width
    """
    os.makedirs(outputs_dir, exist_ok=True)
    cloth_mask_dir = os.path.join(Path(outputs_dir).parent.absolute(), "cloth-mask")
    os.makedirs(cloth_mask_dir, exist_ok=True)

    segment_garment(inputs_dir, cloth_mask_dir, cls=cls)

    images_path = sorted(glob.glob(os.path.join(inputs_dir, "*")))
    masks_path = sorted(glob.glob(os.path.join(cloth_mask_dir, "*")))

    # Map image stems to paths so each mask ("<stem>_<class id>.jpg") is
    # composited with its OWN source image; the previous version always used
    # the first image, which is wrong whenever inputs_dir holds several files.
    images_by_stem = {os.path.splitext(os.path.basename(p))[0]: p for p in images_path}

    for mask_path in masks_path:
        mask = Image.open(mask_path)

        mask_stem = os.path.splitext(os.path.basename(mask_path))[0]
        image_stem = mask_stem.rsplit("_", 1)[0]  # drop the "_<class id>" suffix
        img = Image.open(images_by_stem.get(image_stem, images_path[0]))

        cutout = np.array(naive_cutout(img, mask))
        cutout = resize_by_bigger_index(cutout)

        # Alpha-composite the cutout onto a centered white canvas.
        canvas = np.ones((1024, 768, 3), np.uint8) * 255
        y1, y2 = (canvas.shape[0] - cutout.shape[0]) // 2, (canvas.shape[0] + cutout.shape[0]) // 2
        x1, x2 = (canvas.shape[1] - cutout.shape[1]) // 2, (canvas.shape[1] + cutout.shape[1]) // 2

        alpha_s = cutout[:, :, 3] / 255.0
        alpha_l = 1.0 - alpha_s

        for c in range(0, 3):
            canvas[y1:y2, x1:x2, c] = (alpha_s * cutout[:, :, c] +
                                       alpha_l * canvas[y1:y2, x1:x2, c])

        # resize image before saving
        if resize_to_width:
            canvas = image_resize(canvas, width=resize_to_width)

        Image.fromarray(canvas).save(
            os.path.join(outputs_dir, f"{mask_stem}.jpg"), format='JPEG')
|
tryon/preprocessing/preprocess_human.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from skimage import io
|
| 8 |
+
from torch.autograd import Variable
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
|
| 12 |
+
from .u2net import RescaleT, ToTensorLab, SalObjDataset, normPRED, load_human_segm_model
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def pred_to_image(predictions, image_path):
    """
    Convert a saliency prediction tensor into an RGB mask image resized to the
    dimensions of the original image at *image_path*.

    :param predictions: prediction tensor with values in [0, 1]
    :param image_path: path of the original image (used only for its size)
    :return: PIL RGB image of the mask at the original resolution
    """
    arr = predictions.squeeze().cpu().data.numpy() * 255
    mask = Image.fromarray(arr).convert('RGB')
    original = io.imread(image_path)
    return mask.resize((original.shape[1], original.shape[0]), resample=Image.BILINEAR)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def segment_human(image_path, output_dir):
    """
    Segment the person in an image with U-2-Net and save a cutout PNG where the
    background is replaced by a flat light-gray color (231, 231, 231).

    :param image_path: path of the input image
    :param output_dir: directory where the PNG cutout is written
    """
    model_name = "u2net"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = [image_path]

    # Be robust to a missing output directory.
    os.makedirs(output_dir, exist_ok=True)

    # 1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=images,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    net = load_human_segm_model(device, model_name)

    # 2. inference
    for i_test, data_test in enumerate(test_salobj_dataloader):
        print("inferencing:", images[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        # d1 is the fused (final) side output; d2..d7 are auxiliary outputs.
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization of the prediction to [0, 1]
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        mask = pred_to_image(pred, images[i_test])
        mask_cv2 = cv2.cvtColor(np.array(mask), cv2.COLOR_RGB2BGR)

        # NOTE(review): mask-minus-image composite inherited from upstream;
        # background pixels appear to come out black here — confirm.
        subimage = cv2.subtract(mask_cv2, cv2.imread(images[i_test]))
        original = Image.open(images[i_test])
        subimage = Image.fromarray(cv2.cvtColor(subimage, cv2.COLOR_BGR2RGB))

        subimage = subimage.convert("RGBA")
        original = original.convert("RGBA")

        subdata = subimage.getdata()
        ogdata = original.getdata()

        # Replace pure-black (background) pixels with light gray; keep every
        # other pixel from the original image.
        newdata = []
        for i in range(subdata.size[0] * subdata.size[1]):
            if subdata[i][0] == 0 and subdata[i][1] == 0 and subdata[i][2] == 0:
                newdata.append((231, 231, 231, 231))
            else:
                newdata.append(ogdata[i])
        subimage.putdata(newdata)

        subimage.save(os.path.join(output_dir, f"{images[i_test].split(os.sep)[-1].split('.')[0]}.png"))

        del d1, d2, d3, d4, d5, d6, d7
|
tryon/preprocessing/sam2/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from PIL import Image
from pathlib import Path
import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Location of the SAM2 checkout created by scripts/install_sam2.sh.
SAM2_DIR = os.path.join(str(Path.home()), 'sam2')

checkpoint = os.path.join(SAM2_DIR, "checkpoints/sam2.1_hiera_large.pt")
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"


def create_predictor():
    """Build a SAM2 image predictor from the local checkpoint."""
    return SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))


def predict_point(image, point=(500, 375), predictor=None):
    """
    Predict segmentation masks for a single foreground point prompt.

    :param image: PIL image to segment
    :param point: (x, y) pixel coordinates of the foreground point
    :param predictor: optional preloaded SAM2ImagePredictor
    :return: candidate masks returned by SAM2 (multimask output)
    """
    if predictor is None:
        predictor = create_predictor()
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        predictor.set_image(image)
        masks, _, _ = predictor.predict(
            point_coords=np.array([point]),
            point_labels=np.array([1]),
            multimask_output=True)
    return masks


if __name__ == "__main__":
    # Previously this module loaded the model and ran inference against a
    # hard-coded sample unconditionally at import time, which made importing
    # the package slow and crash when the sample file was absent.
    print(predict_point(Image.open("img000.webp"), (500, 375)))
|
tryon/preprocessing/u2net/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .load_u2net import load_cloth_segm_model, load_human_segm_model
|
| 2 |
+
from .data_loader import SalObjDataset, RescaleT, ToTensorLab, ToTensor
|
| 3 |
+
from .utils import normPRED
|
tryon/preprocessing/u2net/data_loader.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function, division
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from skimage import io, transform, color
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RescaleT(object):
    """Resize both image and label to a square ``output_size`` x ``output_size``.

    The aspect-ratio-preserving variant is intentionally disabled (as in the
    upstream U-2-Net data loader); the output is always square.
    """

    def __init__(self, output_size):
        # NOTE: despite accepting a tuple, only an int produces a valid
        # square target below — TODO confirm tuple support is ever used.
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Resize to output_size x output_size; skimage also rescales the image
        # from [0,255] to [0,1]. order=0 keeps label values unchanged.
        # (The previous version also computed aspect-preserving new_h/new_w
        # but never used them — dead code removed.)
        img = transform.resize(image, (self.output_size, self.output_size), mode='constant')
        lbl = transform.resize(label, (self.output_size, self.output_size), mode='constant', order=0,
                               preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Rescale(object):
    """Randomly flip (p=0.5 along axis 0) and rescale a sample.

    With an int ``output_size`` the shorter side is scaled to that size while
    preserving aspect ratio; a tuple is used as the (h, w) target directly.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Random flip along the first axis, applied to image and label together.
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            # Scale so the smaller side becomes output_size, keeping aspect ratio.
            if h > w:
                target = (self.output_size * h / w, self.output_size)
            else:
                target = (self.output_size, self.output_size * w / h)
        else:
            target = self.output_size

        new_h, new_w = int(target[0]), int(target[1])

        # skimage converts the image from [0,255] to [0,1]; order=0 preserves
        # the label values exactly.
        img = transform.resize(image, (new_h, new_w), mode='constant')
        lbl = transform.resize(label, (new_h, new_w), mode='constant', order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class RandomCrop(object):
    """Randomly flip (p=0.5 along axis 0) and crop a sample to ``output_size``."""

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Random flip along the first axis, applied consistently to both arrays.
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]
        crop_h, crop_w = self.output_size

        # Pick the top-left corner of the crop window uniformly at random.
        top = np.random.randint(0, h - crop_h)
        left = np.random.randint(0, w - crop_w)

        image = image[top:top + crop_h, left:left + crop_w]
        label = label[top:top + crop_h, left:left + crop_w]

        return {'imidx': imidx, 'image': image, 'label': label}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ToTensor(object):
    """Convert ndarrays in sample to Tensors (CHW layout, ImageNet-normalized)."""

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

        # Scale image to [0,1]; labels are scaled only when non-trivially valued.
        image = image / np.max(image)
        if np.max(label) >= 1e-6:
            label = label / np.max(label)

        # Normalize with ImageNet statistics. Grayscale inputs replicate
        # channel 0 using the channel-0 mean/std, as in upstream U-2-Net.
        if image.shape[2] == 1:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
        else:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        # HWC -> CHW. (The previous tmpLbl buffer and its channel-0 copy were
        # dead code — tmpLbl was immediately rebound to this transpose.)
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors (CHW layout).

    flag=0 (default): RGB normalized with ImageNet mean/std;
    flag=1: Lab color space, min-max then z-score normalized per channel;
    flag=2: six channels — min-max/z-score normalized RGB plus Lab.
    """

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Scale labels to [0,1] unless they are (near) all-zero.
        if np.max(label) >= 1e-6:
            label = label / np.max(label)

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                # Replicate a grayscale channel into RGB.
                tmpImgt[:, :, 0] = image[:, :, 0]
                tmpImgt[:, :, 1] = image[:, :, 0]
                tmpImgt[:, :, 2] = image[:, :, 0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)

            # normalize each RGB and Lab channel to [0,1]
            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
                    np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
                    np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
                    np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]))
            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
                    np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]))
            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
                    np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]))
            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
                    np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]))

            # then z-score each channel
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(tmpImg[:, :, 3])
            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(tmpImg[:, :, 4])
            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(tmpImg[:, :, 5])

        elif self.flag == 1:  # with Lab color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

            if image.shape[2] == 1:
                tmpImg[:, :, 0] = image[:, :, 0]
                tmpImg[:, :, 1] = image[:, :, 0]
                tmpImg[:, :, 2] = image[:, :, 0]
            else:
                tmpImg = image

            tmpImg = color.rgb2lab(tmpImg)

            # min-max normalization per channel
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
                    np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
                    np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
                    np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]))

            # then z-score per channel
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])

        else:  # with rgb color (flag == 0)
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image = image / np.max(image)
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        # HWC -> CHW. (The previous tmpLbl buffer and its channel-0 copy were
        # dead code — tmpLbl was immediately rebound to this transpose.)
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class SalObjDataset(Dataset):
    """Salient-object dataset yielding ``{'imidx', 'image', 'label'}`` samples.

    When no label paths are given, an all-zero mask with the image's shape
    is substituted. Images and labels are promoted to H x W x C arrays so
    downstream transforms always see 3-D inputs.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])

        # Inference mode: no ground-truth masks -> synthesize a zero label.
        if not self.label_name_list:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Collapse a multi-channel label to its first channel.
        label = np.zeros(label_3.shape[0:2])
        if label_3.ndim == 3:
            label = label_3[:, :, 0]
        elif label_3.ndim == 2:
            label = label_3

        # Guarantee both arrays carry an explicit channel axis.
        if image.ndim == 3 and label.ndim == 2:
            label = label[:, :, np.newaxis]
        elif image.ndim == 2 and label.ndim == 2:
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
|
tryon/preprocessing/u2net/load_u2net.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from tryon.preprocessing.u2net import u2net_cloth_segm, u2net_human_segm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_cloth_segm_model(device, checkpoint_path, in_ch=3, out_ch=1):
    """Build a cloth-segmentation U2NET and load weights from *checkpoint_path*.

    Args:
        device: torch device (or map_location string) the model is loaded onto.
        checkpoint_path: path to a saved state dict.
        in_ch / out_ch: channel configuration forwarded to U2NET.

    Returns:
        The model on *device*, or ``None`` when the checkpoint file does not
        exist (kept for backward compatibility with existing callers).
    """
    if not os.path.exists(checkpoint_path):
        # Include the offending path so the failure is diagnosable.
        print("Invalid path: {}".format(checkpoint_path))
        return None

    model = u2net_cloth_segm.U2NET(in_ch=in_ch, out_ch=out_ch)

    model_state_dict = torch.load(checkpoint_path, map_location=device)
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    # Strip it only when actually present, so plain (non-DataParallel)
    # checkpoints load too instead of having 7 characters chopped off.
    new_state_dict = OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in model_state_dict.items()
    )

    model.load_state_dict(new_state_dict)
    model = model.to(device=device)

    print("Checkpoints loaded from path: {}".format(checkpoint_path))

    return model
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_human_segm_model(device, model_name):
    """Load a pretrained human-segmentation network in eval mode.

    Args:
        device: map_location used when CUDA is unavailable.
        model_name: ``"u2net"`` (full, ~173.6 MB) or ``"u2netp"`` (small, ~4.7 MB).

    Returns:
        The network with weights loaded from the file named by the
        ``U2NET_SEGM_CHECKPOINT_PATH`` environment variable, moved to GPU
        when available, and switched to ``eval()`` mode.

    Raises:
        ValueError: for an unknown *model_name* (previously this fell through
            with ``net = None`` and crashed later with an AttributeError).
    """
    if model_name == 'u2net':
        print("loading U2NET(173.6 MB)...")
        net = u2net_human_segm.U2NET(3, 1)
    elif model_name == 'u2netp':
        print("loading U2NEP(4.7 MB)...")
        net = u2net_human_segm.U2NETP(3, 1)
    else:
        raise ValueError(
            "Unknown model name: {!r} (expected 'u2net' or 'u2netp')".format(model_name)
        )

    checkpoint_path = os.environ.get("U2NET_SEGM_CHECKPOINT_PATH")
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(checkpoint_path))
        net.cuda()
    else:
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    net.eval()

    return net
|
tryon/preprocessing/u2net/u2net_cloth_segm.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class REBNCONV(nn.Module):
    """3x3 (optionally dilated) convolution -> BatchNorm -> ReLU block.

    ``dirate`` scales both padding and dilation, so the spatial size of the
    input is always preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        # Single fused pipeline: conv, normalize, activate.
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
| 24 |
+
def _upsample_like(src, tar):
|
| 25 |
+
src = F.upsample(src, size=tar.shape[2:], mode="bilinear")
|
| 26 |
+
|
| 27 |
+
return src
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """Residual U-block of depth 7.

    An internal encoder/decoder (5 max-pool levels plus one dilated bottom
    layer) whose decoded output is added to an input projection:
    forward(x) = decoder(encoder(x)) + rebnconvin(x).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        # Input projection to out_ch so the final residual sum is shape-compatible.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv then 2x2 max-pool per level (ceil_mode keeps odd sizes).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer widens the receptive field via dilation instead of pooling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage consumes [upsampled deeper features, encoder skip].
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        # Encoder path; each hxN is kept as a skip connection.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder path: concat with the skip, then upsample to the next level.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4, hx5, hx6, hx7
        del hx6d, hx5d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
        """

        # Residual connection around the whole U-block.
        return hx1d + hxin
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """Residual U-block of depth 6: 4 pooling levels + dilated bottom layer,
    with a residual connection from the input projection to the output."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection for the residual sum.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels (2x2 max-pool between them, ceil_mode keeps odd sizes).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom layer (no further downsampling).
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage takes [deeper features, skip] concatenated.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder path; hxN retained as skips.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # Decoder path with upsampling back to each encoder resolution.
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4, hx5, hx6
        del hx5d, hx4d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup, hx5dup
        """

        return hx1d + hxin
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """Residual U-block of depth 5: 3 pooling levels + dilated bottom layer,
    with a residual connection from the input projection to the output."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection for the residual sum.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom layer.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder stages.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder path; hxN retained as skips.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # Decoder path with skip concatenation and upsampling.
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4, hx5
        del hx4d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup
        """

        return hx1d + hxin
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """Residual U-block of depth 4: 2 pooling levels + dilated bottom layer,
    with a residual connection from the input projection to the output."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection for the residual sum.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom layer.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder stages.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder path; hxN retained as skips.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # Decoder path with skip concatenation and upsampling.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4
        del hx3d, hx2d
        del hx2dup, hx3dup
        """

        return hx1d + hxin
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """Dilation-only ('F' = flat) residual U-block of depth 4.

    No pooling: the receptive field grows through dilation rates 1/2/4/8
    instead of downsampling, so all features stay at the input resolution
    and no upsampling is needed in the decoder.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection for the residual sum.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: increasing dilation rates, constant resolution.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder mirrors the encoder's dilation rates.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # Decoder: concat with skips; resolutions already match (no upsampling).
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        """
        del hx1, hx2, hx3, hx4
        del hx3d, hx2d
        """

        return hx1d + hxin
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
##### U^2-Net ####
class U2NET(nn.Module):
    """Full U^2-Net: an outer U-Net whose stages are themselves RSU blocks.

    forward(x) returns 7 unnormalized maps (d0, d1..d6): d1..d6 are the
    per-stage side outputs upsampled to the input resolution, and d0 fuses
    them through a 1x1 convolution. No activation is applied here.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # Encoder: RSU depth decreases as resolution shrinks; the two deepest
        # stages are dilation-only RSU4F blocks.
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # One side-output head per decoder stage (plus the bottom encoder stage).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # Fuses the 6 side outputs into the final map d0.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output: every map is brought up to d1's (input) resolution.
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        """
        del hx1, hx2, hx3, hx4, hx5, hx6
        del hx5d, hx4d, hx3d, hx2d, hx1d
        del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
        """

        return d0, d1, d2, d3, d4, d5, d6
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
### U^2-Net small ###
class U2NETP(nn.Module):
    """Lightweight U^2-Net variant ('P' = small): same topology as U2NET but
    every stage uses 64 output channels and 16 mid channels, shrinking the
    model to a few MB. Returns 7 unnormalized maps (d0 fused, d1..d6 sides).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # Encoder: same RSU depths as the full model, uniform 64-channel width.
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads (all 64-channel inputs here).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # Fuses the 6 side outputs into the final map d0.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output: all maps upsampled to d1's (input) resolution.
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return d0, d1, d2, d3, d4, d5, d6
|
tryon/preprocessing/u2net/u2net_human_segm.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class REBNCONV(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU unit; ``dirate`` sets matching padding
    and dilation so spatial dimensions are preserved."""

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        # Apply the three stages as one expression.
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
| 22 |
+
def _upsample_like(src, tar):
|
| 23 |
+
src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
|
| 24 |
+
|
| 25 |
+
return src
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """Residual U-block of depth 7 (human-segmentation variant): 5 pooling
    levels plus a dilated bottom layer, with the decoded output added to an
    input projection."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        # Input projection so the final residual sum is shape-compatible.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv + 2x2 max-pool per level (ceil_mode keeps odd sizes).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom layer instead of a further downsampling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage consumes [upsampled deeper features, skip].
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        # Encoder path; each hxN is kept as a skip connection.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder path: concat with the skip, then upsample to the next level.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection around the whole U-block.
        return hx1d + hxin
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
### RSU-6 ###
|
| 106 |
+
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """Residual U-block of depth 6 (RSU-6): five pooled encoder levels, a
    dilated bottleneck (dirate=2), a symmetric decoder, and a residual add
    with the stem output."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Stem projects input to out_ch for the final residual add.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # ---- encoder ----
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # ---- decoder (each stage takes concatenated [deep, skip]) ----
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder; hx1..hx5 serve as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # Decoder with skip concatenation and upsampling.
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection around the whole block.
        return hx1d + hxin
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
### RSU-5 ###
|
| 174 |
+
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """Residual U-block of depth 5 (RSU-5): four pooled encoder levels, a
    dilated bottleneck (dirate=2), a symmetric decoder, and a residual add
    with the stem output."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Stem projects input to out_ch for the final residual add.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # ---- encoder ----
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # ---- decoder ----
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder; hx1..hx4 serve as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # Decoder with skip concatenation and upsampling.
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection around the whole block.
        return hx1d + hxin
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
### RSU-4 ###
|
| 232 |
+
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """Residual U-block of depth 4 (RSU-4): three pooled encoder levels, a
    dilated bottleneck (dirate=2), a symmetric decoder, and a residual add
    with the stem output."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Stem projects input to out_ch for the final residual add.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # ---- encoder ----
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # ---- decoder ----
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder; hx1..hx3 serve as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # Decoder with skip concatenation and upsampling.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection around the whole block.
        return hx1d + hxin
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
### RSU-4F ###
|
| 280 |
+
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """Dilation-only residual U-block (RSU-4F).

    No pooling/upsampling at all: the encoder enlarges the receptive field
    purely via increasing dilation rates (1, 2, 4, 8), so every feature map
    keeps the input's spatial resolution. Used at the deepest U^2-Net stages
    where further downsampling would destroy spatial detail.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Stem projects input to out_ch for the final residual add.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with growing dilation instead of pooling.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder mirrors the dilation rates; no upsampling needed.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # All maps share the same spatial size, so skips concatenate directly.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # Residual connection around the whole block.
        return hx1d + hxin
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
##### U^2-Net ####
|
| 316 |
+
class U2NET(nn.Module):
    """Full-size U^2-Net.

    A two-level nested U-structure: each stage is itself a residual U-block
    (RSU-7 ... RSU-4F), connected by 2x max-pooling on the way down and
    bilinear upsampling with skip concatenation on the way up.

    forward(x) returns a 7-tuple of sigmoid probability maps, all at the
    input resolution: (d0, d1, ..., d6) where d0 is the fused output of the
    1x1 ``outconv`` over all six side outputs and d1..d6 are per-stage side
    outputs (d1 shallowest, d6 deepest).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # ---- encoder ----
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # ---- decoder (inputs are [upsampled deep feature, encoder skip]) ----
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Per-stage side-output heads (deep supervision).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # Fuses the six side outputs into the final map d0.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6 (bottom of the outer U)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # Side outputs, all upsampled to the resolution of d1 (input size).
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
### U^2-Net small ###
|
| 420 |
+
class U2NETP(nn.Module):
    """Small (lightweight) U^2-Net variant — same topology as :class:`U2NET`
    but every stage uses 64 output channels and a 16-channel mid width.

    forward(x) returns a 7-tuple of sigmoid probability maps at input
    resolution: (fused d0, side outputs d1..d6).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # ---- encoder ----
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # ---- decoder ----
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Per-stage side-output heads (deep supervision).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # Fuses the six side outputs into the final map d0.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6 (bottom of the outer U)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # Side outputs, all upsampled to the resolution of d1 (input size).
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
|
tryon/preprocessing/u2net/utils.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def normPRED(d):
    """Min-max normalize the prediction map *d* into the [0, 1] range."""
    d_min = torch.min(d)
    d_max = torch.max(d)
    return (d - d_min) / (d_max - d_min)
|
tryon/preprocessing/utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class NormalizeImage(object):
    """Normalize a channel-first tensor with a scalar mean/std broadcast to
    every channel. Supports tensors with 1, 3 or 18 channels.

    Args:
        mean (float): mean to subtract from every channel.
        std (float): std to divide every channel by.

    Raises:
        TypeError: if ``mean`` or ``std`` is not a float.
        ValueError: from ``__call__``, for an unsupported channel count.
    """

    def __init__(self, mean, std):
        # Original code only asserted `mean` and silently skipped assignment
        # for a non-float `std`; validate both explicitly instead.
        if not isinstance(mean, float):
            raise TypeError("mean must be a float")
        if not isinstance(std, float):
            raise TypeError("std must be a float")
        self.mean = mean
        self.std = std

        # Pre-build one Normalize transform per supported channel count.
        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        channels = image_tensor.shape[0]
        if channels == 1:
            return self.normalize_1(image_tensor)
        if channels == 3:
            return self.normalize_3(image_tensor)
        if channels == 18:
            return self.normalize_18(image_tensor)
        # The original ended with `assert "<message>"`, a no-op that always
        # passes and silently returned None; fail loudly instead.
        raise ValueError("Please set proper channels! Normalization implemented only for 1, 3 and 18")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def naive_cutout(img, mask):
    """Cut *img* out against a transparent background using *mask* as alpha.

    The mask is resized (LANCZOS) to match *img*; pixels where the mask is
    zero become fully transparent in the returned RGBA image.
    """
    background = Image.new("RGBA", img.size, 0)
    scaled_mask = mask.resize(img.size, Image.LANCZOS)
    return Image.composite(img, background, scaled_mask)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def resize_by_bigger_index(crop):
    """Resize *crop* (h x w x c array) while preserving its aspect ratio.

    Crops with height/width ratio <= 1.33 are scaled to a fixed width of
    768; taller crops are scaled to a fixed height of 1024.
    """
    height, width = crop.shape[:2]
    if height / width <= 1.33:
        return image_resize(crop, width=768)
    return image_resize(crop, height=1024)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def image_resize(image, width=None, height=None):
    """Resize *image* (numpy array, h x w [x c]) with cv2.

    - both ``width`` and ``height`` None: return the image unchanged;
    - only one given: scale to it, preserving the aspect ratio;
    - both given: resize to exactly ``(width, height)``.

    Bug fix: previously ``height`` was silently ignored whenever ``width``
    was supplied, so callers passing both (e.g. ``convert_to_jpg`` with its
    documented ``size=(w, h)``) never got the requested size.
    """
    (h, w) = image.shape[:2]

    if width is None and height is None:
        return image

    if width is not None and height is not None:
        # Exact target size requested.
        dim = (width, height)
    elif width is None:
        # Scale by height, keep aspect ratio.
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        # Scale by width, keep aspect ratio.
        r = width / float(w)
        dim = (width, int(h * r))

    return cv2.resize(image, dim)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def convert_to_jpg(image_path, output_dir, size=None):
    """
    Convert image to jpg format.

    :param image_path: image path
    :param output_dir: output directory
    :param size: desired size of the image (w, h)
    :raises FileNotFoundError: if the image cannot be read
    """
    img = cv2.imread(image_path)
    # cv2.imread returns None (no exception) on unreadable paths; fail early
    # instead of crashing inside cv2.resize/imwrite with a cryptic error.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    if size is not None:
        img = image_resize(img, width=size[0], height=size[1])

    # Path.stem drops only the final extension, unlike filename.split(".")[0]
    # which truncates names containing dots (e.g. "photo.v2.png" -> "photo").
    out_name = Path(image_path).stem + ".jpg"
    cv2.imwrite(os.path.join(output_dir, out_name), img)
|
tryondiffusion/__init__.py
ADDED
|
File without changes
|
tryondiffusion/diffusion.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from torch import optim
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from .network import UNet64, UNet128
|
| 14 |
+
from .utils import mk_folders, GaussianSmoothing, UNetDataset
|
| 15 |
+
from .ema import EMA
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def smoothen_image(img, sigma):
    """Blur *img* with a 3x3 Gaussian kernel of the given *sigma*.

    Noise-conditioning augmentation, as suggested in:
    https://jmlr.csail.mit.edu/papers/volume23/21-0635/21-0635.pdf Section 4.4
    """
    blur = GaussianSmoothing(channels=3, kernel_size=3, sigma=sigma, conv_dim=2)
    padded = F.pad(img, (1, 1, 1, 1), mode='reflect')
    return blur(padded)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def schedule_lr(total_steps, start_lr=0.0, stop_lr=0.0001, pct_increasing_lr=0.02):
    """Build a per-step learning-rate schedule: linear warmup, then constant.

    The first ``round(total_steps * pct_increasing_lr)`` steps ramp linearly
    from ``start_lr`` to ``stop_lr``; the remainder stay at ``stop_lr``.
    Returns a list of length ``total_steps``.
    """
    warmup_steps = round(total_steps * pct_increasing_lr)
    warmup = list(np.linspace(start_lr, stop_lr, warmup_steps))
    plateau = [stop_lr] * (total_steps - warmup_steps)
    return warmup + plateau
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Diffusion:
    """DDPM-style diffusion wrapper around a conditional U-Net for try-on.

    Conditioning inputs follow the TryOnDiffusion naming: ic = segmented
    garment, jp/jg = person/garment pose embeddings, ia = clothing-agnostic
    RGB image, ip = ground-truth person image.

    Bug fixes vs. the original:
      * ``add_noise_to_img`` now implements the standard DDPM forward process
        sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps (the original multiplied eps in
        BOTH terms and never used the image).
      * ``sample`` rebuilds the model input each denoising step (the original
        concatenated the noise with ia once and fed the stale tensor forever).
      * ``single_epoch`` enters autocast correctly (``with A and B:`` only
        entered B) and iterates the validation loader when ``train=False``.
      * ``fit`` guards against a zero epoch count (``epochs < 0`` could never
        trigger since ``round`` is non-negative here).
    """

    def __init__(self,
                 device,
                 pose_embed_dim,
                 time_steps=256,
                 beta_start=1e-4,
                 beta_end=0.02,
                 unet_dim=64,
                 noise_input_channel=3,
                 beta_ema=0.995):
        self.time_steps = time_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # Noise schedule and its (cumulative) products, kept on `device`.
        self.beta = self.linear_beta_scheduler().to(device)
        self.alpha = 1 - self.beta
        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)

        self.noise_input_channel = noise_input_channel
        self.unet_dim = unet_dim
        if unet_dim == 128:
            self.net = UNet128(pose_embed_dim, device, time_steps).to(device)
        elif unet_dim == 64:
            self.net = UNet64(pose_embed_dim, device, time_steps).to(device)

        # Frozen EMA copy of the network, updated after every optimizer step.
        self.ema_net = copy.deepcopy(self.net).eval().requires_grad_(False)
        self.beta_ema = beta_ema

        self.device = device

    def linear_beta_scheduler(self):
        """Linear beta schedule over ``time_steps`` steps."""
        return torch.linspace(self.beta_start, self.beta_end, self.time_steps)

    def sample_time_steps(self, batch_size):
        """Draw a random diffusion timestep in [1, time_steps) per sample."""
        return torch.randint(low=1, high=self.time_steps, size=(batch_size,))

    def add_noise_to_img(self, img, t):
        """Forward-diffuse ``img`` to timestep ``t``.

        Returns ``(x_t, eps)`` where
        ``x_t = sqrt(alpha_bar_t) * img + sqrt(1 - alpha_bar_t) * eps``.
        """
        sqrt_alpha_timestep = torch.sqrt(self.alpha_cumprod[t])[:, None, None, None]
        sqrt_one_minus_alpha_timestep = torch.sqrt(1 - self.alpha_cumprod[t])[:, None, None, None]
        epsilon = torch.randn_like(img)
        # BUG FIX: the first term previously multiplied `epsilon` instead of
        # `img`, so the image was never mixed into the noisy sample.
        return (sqrt_alpha_timestep * img) + (sqrt_one_minus_alpha_timestep * epsilon), epsilon

    @torch.inference_mode()
    def sample(self, use_ema, conditional_inputs):
        """Run the reverse (denoising) process and return uint8 images."""
        model = self.ema_net if use_ema else self.net
        ic, jp, jg, ia = conditional_inputs
        ic = ic.to(self.device)
        jp = jp.to(self.device)
        jg = jg.to(self.device)
        ia = ia.to(self.device)
        batch_size = len(ic)
        logging.info(f"Running inference for {batch_size} images")

        model.eval()
        with torch.inference_mode():

            # noise augmentation during testing as suggested in paper
            sigma = float(torch.FloatTensor(1).uniform_(0.4, 0.6))
            ia = smoothen_image(ia, sigma)
            ic = smoothen_image(ic, sigma)

            inp_network_noise = torch.randn(batch_size, self.noise_input_channel, self.unet_dim, self.unet_dim).to(self.device)

            # paper says to add noise augmentation to input noise during inference
            inp_network_noise = smoothen_image(inp_network_noise, sigma)

            for i in reversed(range(1, self.time_steps)):
                t = (torch.ones(batch_size) * i).long().to(self.device)
                # BUG FIX: concatenate the CURRENT noisy image with the
                # agnostic image each step (corrupt -> concatenate -> predict);
                # previously a stale pre-loop tensor was reused every step.
                x = torch.cat((inp_network_noise, ia), dim=1)
                predicted_noise = model(x, ic, jp, jg, t, sigma)
                # ToDo: Add Classifier-Free Guidance with guidance weight 2
                alpha = self.alpha[t][:, None, None, None]
                alpha_cumprod = self.alpha_cumprod[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(inp_network_noise)
                else:
                    # Last step is deterministic: no fresh noise injected.
                    noise = torch.zeros_like(inp_network_noise)

                inp_network_noise = 1 / torch.sqrt(alpha) * (inp_network_noise - ((1 - alpha) / (torch.sqrt(1 - alpha_cumprod))) * predicted_noise) + torch.sqrt(beta) * noise
        # Map from [-1, 1] to [0, 255] uint8 for image output.
        inp_network_noise = (inp_network_noise.clamp(-1, 1) + 1) / 2
        inp_network_noise = (inp_network_noise * 255).type(torch.uint8)
        return inp_network_noise

    def prepare(self, args):
        """Build datasets, dataloaders, optimizer, LR schedule, EMA and scaler."""
        mk_folders(args.run_name)
        train_dataset = UNetDataset(ip_dir=args.train_ip_folder,
                                    jp_dir=args.train_jp_folder,
                                    jg_dir=args.train_jg_folder,
                                    ia_dir=args.train_ia_folder,
                                    ic_dir=args.train_ic_folder,
                                    unet_size=self.unet_dim)

        validation_dataset = UNetDataset(ip_dir=args.validation_ip_folder,
                                         jp_dir=args.validation_jp_folder,
                                         jg_dir=args.validation_jg_folder,
                                         ia_dir=args.validation_ia_folder,
                                         ic_dir=args.validation_ic_folder,
                                         unet_size=self.unet_dim)

        self.train_dataloader = DataLoader(train_dataset, args.batch_size_train, shuffle=True)
        # give args.batch_size_validation 1 while training
        self.val_dataloader = DataLoader(validation_dataset, args.batch_size_validation, shuffle=True)

        self.optimizer = optim.AdamW(self.net.parameters(), lr=args.lr, eps=1e-4)
        self.scheduler = schedule_lr(total_steps=args.total_steps, start_lr=args.start_lr,
                                     stop_lr=args.stop_lr, pct_increasing_lr=args.pct_increasing_lr)
        self.mse = nn.MSELoss()
        self.ema = EMA(self.beta_ema)
        self.scaler = torch.cuda.amp.GradScaler()

    def train_step(self, loss, running_step):
        """One optimizer step (AMP-scaled), EMA update, and LR adjustment."""
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.ema.step_ema(self.ema_net, self.net)
        for g in self.optimizer.param_groups:
            g['lr'] = self.scheduler[running_step]

    def single_epoch(self, train=True):
        """Run one epoch; returns the mean MSE loss over the epoch's batches."""
        avg_loss = 0.
        num_batches = 0
        if train:
            self.net.train()
            dataloader = self.train_dataloader
        else:
            self.net.eval()
            # BUG FIX: validation previously iterated the TRAIN dataloader.
            dataloader = self.val_dataloader

        for ip, jp, jg, ia, ic in dataloader:

            # noise augmentation
            sigma = float(torch.FloatTensor(1).uniform_(0.4, 0.6))
            ia = smoothen_image(ia, sigma)
            ic = smoothen_image(ic, sigma)

            # BUG FIX: `with A and B:` evaluates to B only, so autocast was
            # never entered; use two context managers instead.
            with torch.autocast(self.device), (torch.inference_mode() if not train else torch.enable_grad()):
                ip = ip.to(self.device)
                jp = jp.to(self.device)
                jg = jg.to(self.device)
                ia = ia.to(self.device)
                ic = ic.to(self.device)
                t = self.sample_time_steps(ip.shape[0]).to(self.device)

                # corrupt -> concatenate -> predict
                zt, noise_epsilon = self.add_noise_to_img(ip, t)

                zt = torch.cat((zt, ia), dim=1)

                # ToDO: Make conditional inputs null, at 10% of the training time,
                # ToDo: for classifier-free guidance(GitHub Issue #21), with guidance weight 2.

                predicted_noise = self.net(zt, ic, jp, jg, t, sigma)
                loss = self.mse(noise_epsilon, predicted_noise)
                # Detach so the accumulator does not retain the autograd graph
                # (the original kept every batch's graph alive all epoch).
                avg_loss += loss.detach()
                num_batches += 1

            if train:
                self.train_step(loss, self.running_train_steps)
                # ToDo: Add logs to tensorboard as well
                logging.info(
                    f"train_mse_loss: {loss.item():2.3f}, learning_rate: {self.scheduler[self.running_train_steps]}")
                self.running_train_steps += 1

        # True per-batch average (the original returned the raw sum).
        return (avg_loss / max(num_batches, 1)).item() if num_batches else 0.0

    def logging_images(self, epoch, run_name):
        """Sample from both nets for every validation batch and save PNGs."""

        for idx, (ip, jp, jg, ia, ic) in enumerate(self.val_dataloader):
            # sampled image
            sampled_image = self.sample(use_ema=False, conditional_inputs=(ic, jp, jg, ia))
            sampled_image = sampled_image[0].permute(1, 2, 0).squeeze().cpu().numpy()

            # ema sampled image
            ema_sampled_image = self.sample(use_ema=True, conditional_inputs=(ic, jp, jg, ia))
            ema_sampled_image = ema_sampled_image[0].permute(1, 2, 0).squeeze().cpu().numpy()

            # base images
            ip_np = ip[0].permute(1, 2, 0).squeeze().cpu().numpy()
            ic_np = ic[0].permute(1, 2, 0).squeeze().cpu().numpy()
            ia_np = ia[0].permute(1, 2, 0).squeeze().cpu().numpy()

            # one folder per (batch, epoch)
            images_folder = os.path.join("results", run_name, "images", f"{idx}_E{epoch}")
            os.makedirs(images_folder, exist_ok=True)

            # save base images
            cv2.imwrite(os.path.join(images_folder, "ground_truth.png"), ip_np)
            cv2.imwrite(os.path.join(images_folder, "segmented_garment.png"), ic_np)
            cv2.imwrite(os.path.join(images_folder, "cloth_agnostic_rgb.png"), ia_np)

            # save sampled image
            cv2.imwrite(os.path.join(images_folder, "sampled_image.png"), sampled_image)

            # save ema sampled image
            cv2.imwrite(os.path.join(images_folder, "ema_sampled_image.png"), ema_sampled_image)

    def save_models(self, run_name, epoch=-1):
        """Checkpoint the net, its EMA copy, and the optimizer state."""

        torch.save(self.net.state_dict(), os.path.join("models", run_name, f"ckpt_{epoch}.pt"))
        torch.save(self.ema_net.state_dict(), os.path.join("models", run_name, f"ema_ckpt_{epoch}.pt"))
        torch.save(self.optimizer.state_dict(), os.path.join("models", run_name, f"optim_{epoch}.pt"))

    def fit(self, args):
        """Full training loop: periodic validation, image logging, checkpoints."""

        logging.info(f"Starting training")

        data_len = len(self.train_dataloader)

        # NOTE(review): this converts a step budget into epochs; verify the
        # intended formula against how total_steps is defined by the caller.
        epochs = round((args.total_steps * args.batch_size_train) / data_len)

        # BUG FIX: `round` here is never negative, so the old `< 0` guard was
        # dead and a zero epoch count skipped training entirely.
        if epochs < 1:
            epochs = 1

        self.running_train_steps = 0

        for epoch in range(epochs):
            logging.info(f"Starting Epoch: {epoch + 1}")
            _ = self.single_epoch(train=True)

            if (epoch + 1) % args.calculate_loss_frequency == 0:
                avg_loss = self.single_epoch(train=False)
                logging.info(f"Average Loss: {avg_loss}")

            if (epoch + 1) % args.image_logging_frequency == 0:
                self.logging_images(epoch, args.run_name)

            if (epoch + 1) % args.model_saving_frequency == 0:
                self.save_models(args.run_name, epoch)

        logging.info(f"Training Done Successfully! Yayyy! Now let's hope for good results")
|