SwimmingLiu commited on
Commit
d91c189
·
1 Parent(s): a48a9ba
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +163 -0
  2. CONTRIBUTING.md +20 -0
  3. README.md +38 -14
  4. caption_images.py +52 -0
  5. demo/__init__.py +3 -0
  6. demo/extract_garment/README.md +14 -0
  7. demo/extract_garment/__init__.py +1 -0
  8. demo/extract_garment/app.py +76 -0
  9. demo/extract_garment/requirements.txt +3 -0
  10. demo/model_swap/.gitignore +1 -0
  11. demo/model_swap/README.md +14 -0
  12. demo/model_swap/__init__.py +1 -0
  13. demo/model_swap/app.py +321 -0
  14. demo/model_swap/requirements.txt +2 -0
  15. demo/outfit_generator/README.md +86 -0
  16. demo/outfit_generator/__init__.py +1 -0
  17. demo/outfit_generator/app.py +164 -0
  18. demo/outfit_generator/images/sample1.jpeg +0 -0
  19. demo/outfit_generator/images/sample2.jpeg +0 -0
  20. demo/outfit_generator/images/sample3.jpeg +0 -0
  21. demo/outfit_generator/images/sample4.jpeg +0 -0
  22. demo/outfit_generator/requirements.txt +10 -0
  23. environment.yml +179 -0
  24. main.py +44 -0
  25. requirements.txt +15 -0
  26. run_demo.py +18 -0
  27. run_ootd.py +37 -0
  28. scripts/install_conda.sh +10 -0
  29. scripts/install_sam2.sh +11 -0
  30. setup.py +31 -0
  31. tryon/README.md +34 -0
  32. tryon/__init__.py +0 -0
  33. tryon/models/__init__.py +0 -0
  34. tryon/models/ootdiffusion/setup.sh +30 -0
  35. tryon/preprocessing/__init__.py +3 -0
  36. tryon/preprocessing/captioning/__init__.py +2 -0
  37. tryon/preprocessing/captioning/generate_caption.py +108 -0
  38. tryon/preprocessing/extract_garment_new.py +91 -0
  39. tryon/preprocessing/preprocess_garment.py +107 -0
  40. tryon/preprocessing/preprocess_human.py +86 -0
  41. tryon/preprocessing/sam2/__init__.py +23 -0
  42. tryon/preprocessing/u2net/__init__.py +3 -0
  43. tryon/preprocessing/u2net/data_loader.py +277 -0
  44. tryon/preprocessing/u2net/load_u2net.py +47 -0
  45. tryon/preprocessing/u2net/u2net_cloth_segm.py +550 -0
  46. tryon/preprocessing/u2net/u2net_human_segm.py +520 -0
  47. tryon/preprocessing/u2net/utils.py +10 -0
  48. tryon/preprocessing/utils.py +91 -0
  49. tryondiffusion/__init__.py +0 -0
  50. tryondiffusion/diffusion.py +275 -0
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ u2net_cloth_segm.pth
163
+ u2net_segm.pth
CONTRIBUTING.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## How to contribute to tryondiffusion
2
+
3
+ ### 1. Open an issue
4
+ We recommend opening an issue (if one doesn't already exist) and discussing your intended changes before making any changes.
5
+ We'll be able to provide you feedback and confirm the planned modifications this way.
6
+
7
+ ### 2. Make changes in the code
8
+ Start with forking the repository, set up the environment, install the dependencies, and make changes in the code appropriately.
9
+
10
+ ### 3. Create pull request
11
+ Create a pull request to the main branch from your fork's branch.
12
+
13
+ ### 4. Merging pull request
14
+ Once the pull request is created, we will review the code changes and merge the pull request as soon as possible.
15
+
16
+
17
+ ### Writing documentation
18
+
19
+ If you are interested in writing the documentation, you can add it to README.md and create a pull request.
20
+ For now, we are maintaining our documentation in a single file and we will add more files as it grows.
README.md CHANGED
@@ -1,14 +1,38 @@
1
- ---
2
- title: Tryondiffusion
3
- emoji: 📈
4
- colorFrom: green
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.7.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: 'TryOnDiffusion: A Tale of Two UNets Implementation'
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Try On Diffusion: A Tale of Two UNets Implementation
2
+ ### [Paper Link](https://arxiv.org/abs/2306.08276)
3
+
4
+ ### [Click here](https://discord.gg/T5mPpZHxkY) to join our discord channel
5
+
6
+ ## Roadmap
7
+
8
+ 1. ~~Prepare initial implementation~~
9
+ 1. Test initial implementation with small dataset (VITON-HD)
10
+ 1. Gather sufficient data and compute resources
11
+ 1. Prepare and train final implementation
12
+ 1. Publicly release parameters
13
+
14
+ ## How to contribute to tryondiffusion
15
+
16
+ ### 1. Open an issue
17
+ We recommend opening an issue (if one doesn't already exist) and discussing your intended changes before making any changes.
18
+ We'll be able to provide you feedback and confirm the planned modifications this way.
19
+
20
+ ### 2. Make changes in the code
21
+ Start with forking the repository, set up the environment, install the dependencies, and make changes in the code appropriately.
22
+
23
+ ### 3. Create pull request
24
+ Create a pull request to the main branch from your fork's branch.
25
+
26
+ ### 4. Merging pull request
27
+ Once the pull request is created, we will review the code changes and merge the pull request as soon as possible.
28
+
29
+
30
+ ### Writing documentation
31
+
32
+ If you are interested in writing the documentation, you can add it to README.md and create a pull request.
33
+ For now, we are maintaining our documentation in a single file and we will add more files as it grows.
34
+
35
+
36
+ ## License
37
+
38
+ All material is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). You can **use** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our original [github repo](https://github.com/kailashahirwar/tryondiffusion)** and **indicate any changes** that you've made to the code.
caption_images.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import os
4
+
5
+ from PIL import Image
6
+
7
+ from tryon.preprocessing.captioning import caption_image, create_llava_next_pipeline
8
+
9
+ INPUT_IMAGES_DIR = os.path.join("fashion_datatset", "*")
10
+ OUTPUT_CAPTIONS_DIR = "fashion_datatset_captions"
11
+
12
+ os.makedirs(OUTPUT_CAPTIONS_DIR, exist_ok=True)
13
+
14
+
15
+ def change_extension(filename, new_extension):
16
+ base_name, _ = os.path.splitext(filename)
17
+ return f"{base_name}.{new_extension}"
18
+
19
+
20
+ if __name__ == '__main__':
21
+ model, processor = create_llava_next_pipeline()
22
+
23
+ images_path = sorted(glob.glob(INPUT_IMAGES_DIR))
24
+
25
+ for index, image_path in enumerate(images_path):
26
+ print(f"index: {index}, total images: {len(images_path)}, {image_path}")
27
+ image = Image.open(image_path)
28
+
29
+ prompt = """
30
+ You're a fashion expert. The list of clothing properties includes [color, pattern, style, fit, type, hemline,
31
+ material, sleeve-length, fabric-elasticity, neckline, waistline]. Please provide the following information in
32
+ JSON format for the outfit shown in the image. Question: What are the color, pattern, style, fit, type,
33
+ hemline, material, sleeve length, fabric elasticity, neckline, and waistline of the outfit in the image?
34
+ Answer:
35
+ """
36
+
37
+ json_file_path = os.path.join(OUTPUT_CAPTIONS_DIR,
38
+ change_extension(os.path.basename(image_path), "json"))
39
+ caption_file_path = os.path.join(OUTPUT_CAPTIONS_DIR,
40
+ change_extension(os.path.basename(image_path), "txt"))
41
+
42
+ if os.path.exists(caption_file_path) and os.path.exists(json_file_path):
43
+ print(f"caption already exists for {image_path}")
44
+ continue
45
+
46
+ json_data, generated_caption = caption_image(image, prompt, model, processor, json_only=False)
47
+
48
+ with open(json_file_path, "w") as f:
49
+ json.dump(json_data, f)
50
+
51
+ with open(caption_file_path, "w") as f:
52
+ f.write(generated_caption)
demo/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .extract_garment import demo as extract_garment_demo
2
+ from .model_swap import demo as model_swap_demo
3
+ from .outfit_generator import demo as outfit_generator_demo
demo/extract_garment/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Extract Garment AI
3
+ emoji: 📊
4
+ colorFrom: indigo
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Gradio Demo of Extract Garment AI by TryOn Labs.
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
demo/extract_garment/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .app import demo
demo/extract_garment/app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from PIL import Image
4
+
5
+ import gradio as gr
6
+ from tryon import preprocessing
7
+
8
+
9
+ def extract_garment(input_img, cls):
10
+ print(input_img, type(input_img), cls)
11
+
12
+ input_dir = "input_image"
13
+ output_dir = "output_image"
14
+
15
+ os.makedirs(input_dir, exist_ok=True)
16
+ os.makedirs(output_dir, exist_ok=True)
17
+
18
+ for f in glob.glob(input_dir + "/*.*"):
19
+ os.remove(f)
20
+
21
+ for f in glob.glob(output_dir + "/*.*"):
22
+ os.remove(f)
23
+
24
+ for f in glob.glob("cloth-mask/*.*"):
25
+ os.remove(f)
26
+
27
+ input_img.save(os.path.join(input_dir, "img.jpg"))
28
+
29
+ preprocessing.extract_garment(inputs_dir=input_dir, outputs_dir=output_dir, cls=cls)
30
+
31
+ return Image.open(glob.glob(output_dir + "/*.*")[0])
32
+
33
+
34
+ css = """
35
+ #col-container {
36
+ margin: 0 auto;
37
+ max-width: 720px;
38
+ }
39
+ """
40
+
41
+ with gr.Blocks(css=css) as demo:
42
+ with gr.Column(elem_id="col-container"):
43
+ gr.Markdown(f"""
44
+ # Clothes Extraction using U2Net
45
+ Pull out clothes like tops, bottoms, and dresses from a photo. This implementation is based on the [U2Net](https://github.com/xuebinqin/U-2-Net) model.
46
+ """)
47
+
48
+ with gr.Row():
49
+ with gr.Column():
50
+ input_image = gr.Image(label="Input Image", type='pil', height="400px", show_label=True)
51
+ dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Extract garment",
52
+ info="Select the garment type you wish to extract!")
53
+
54
+ output_image = gr.Image(label="Extracted garment", type='pil', height="400px", show_label=True,
55
+ show_download_button=True)
56
+
57
+ with gr.Row():
58
+ submit_button = gr.Button("Submit", variant='primary', scale=1)
59
+ reset_button = gr.ClearButton(value="Reset", scale=1)
60
+
61
+ gr.on(
62
+ triggers=[submit_button.click],
63
+ fn=extract_garment,
64
+ inputs=[input_image, dropdown],
65
+ outputs=[output_image]
66
+ )
67
+
68
+ reset_button.click(
69
+ fn=lambda: (None, "upper", None),
70
+ inputs=[],
71
+ outputs=[input_image, dropdown, output_image],
72
+ concurrency_limit=1,
73
+ )
74
+
75
+ if __name__ == '__main__':
76
+ demo.launch()
demo/extract_garment/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==4.44.1
2
+ pillow
3
+ tryondiffusion
demo/model_swap/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .token
demo/model_swap/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Model Swap AI
3
+ emoji: 📊
4
+ colorFrom: indigo
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Gradio Demo of Model Swap AI by TryOn Labs.
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
demo/model_swap/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .app import demo
demo/model_swap/app.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ import gradio as gr
4
+ import json
5
+ import requests
6
+ import time
7
+ from gradio_modal import Modal
8
+ from io import BytesIO
9
+
10
+ TRYON_SERVER_HOST = "https://prod.server.tryonlabs.ai"
11
+ TRYON_SERVER_PORT = "80"
12
+ if TRYON_SERVER_PORT == "80":
13
+ TRYON_SERVER_URL = f"{TRYON_SERVER_HOST}"
14
+ else:
15
+ TRYON_SERVER_URL = f"{TRYON_SERVER_HOST}:{TRYON_SERVER_PORT}"
16
+
17
+ TRYON_SERVER_API_URL = f"{TRYON_SERVER_URL}/api/v1/"
18
+
19
+
20
+ def start_model_swap(input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps):
21
+ # make a request to TryOn Server
22
+ # 1. create an experiment image
23
+ print("inputs:", input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps)
24
+
25
+ if input_image is None:
26
+ raise gr.Error("Select an image!")
27
+
28
+ if prompt is None or prompt == "":
29
+ raise gr.Error("Enter a prompt!")
30
+
31
+ token = load_token()
32
+ if token is None or token == "":
33
+ raise gr.Error("You need to login first!")
34
+ else:
35
+ login(token)
36
+
37
+ byte_io = BytesIO()
38
+ input_image.save(byte_io, 'png')
39
+ byte_io.seek(0)
40
+
41
+ r = requests.post(f"{TRYON_SERVER_API_URL}experiment_image/",
42
+ files={"image": (
43
+ 'ei_image.png',
44
+ byte_io,
45
+ 'image/png'
46
+ )},
47
+ data={
48
+ "type": "model",
49
+ "preprocess": "false"},
50
+ headers={
51
+ "Authorization": f"Bearer {token}"
52
+ })
53
+ # print(r.json())
54
+ if r.status_code == 200 or r.status_code == 201:
55
+ print("Experiment image created successfully", r.json())
56
+ res = r.json()
57
+ # 2 create an experiment
58
+ r2 = requests.post(f"{TRYON_SERVER_API_URL}experiment/",
59
+ data={
60
+ "model_id": res['id'],
61
+ "action": "model_swap",
62
+ "params": json.dumps({"prompt": prompt,
63
+ "guidance_scale": guidance_scale,
64
+ "strength": strength,
65
+ "num_inference_steps": inference_steps,
66
+ "seed": seed,
67
+ "garment_class": f"{cls} garment",
68
+ "negative_prompt": "(hands:1.15), disfigured, ugly, bad, immature"
69
+ ", cartoon, anime, 3d, painting, b&w, (ugly),"
70
+ " (pixelated), watermark, glossy, smooth, "
71
+ "earrings, necklace",
72
+ "num_results": num_results})
73
+ },
74
+ headers={
75
+ "Authorization": f"Bearer {token}"
76
+ })
77
+ if r2.status_code == 200 or r2.status_code == 201:
78
+ # 3. keep checking the status of the experiment
79
+ res2 = r2.json()
80
+ print("Experiment created successfully", res2)
81
+ time.sleep(10)
82
+
83
+ experiment = res2['experiment']
84
+ status = fetch_experiment_status(experiment_id=experiment['id'], token=token)
85
+ status_status = status['status']
86
+ while status_status == "running":
87
+ time.sleep(10)
88
+ status = fetch_experiment_status(experiment_id=experiment['id'], token=token)
89
+ status_status = status['status']
90
+ print(f"Current status: {status_status}")
91
+
92
+ if status['status'] == "success":
93
+ print("Experiment successful")
94
+ print(f"Results:{status['result_images']}")
95
+ return status['result_images']
96
+ elif status['status'] == "failed":
97
+ print("Experiment failed")
98
+ raise gr.Error("Experiment failed")
99
+ else:
100
+ print(f"Error: {r2.text}")
101
+ raise gr.Error(f"Failure: {r2.text}")
102
+ else:
103
+ print(f"Error: {r.text}")
104
+ raise gr.Error(f"Failure: {r.text}")
105
+
106
+
107
+ def fetch_experiment_status(experiment_id, token):
108
+ print(f"experiment id:{experiment_id}")
109
+
110
+ r3 = requests.get(f"{TRYON_SERVER_API_URL}experiment/{experiment_id}/",
111
+ headers={
112
+ "Authorization": f"Bearer {token}"
113
+ })
114
+ if r3.status_code == 200:
115
+ res = r3.json()
116
+ if res['status'] == "running":
117
+ return {"status": "running"}
118
+ elif res['status'] == "success":
119
+ experiment = r3.json()['experiment']
120
+ result_images = [f"{TRYON_SERVER_URL}/{experiment['result']['image_url']}"]
121
+ if len(experiment['results']) > 0:
122
+ for result in experiment['results']:
123
+ result_images.append(f"{TRYON_SERVER_URL}/{result['image_url']}")
124
+ return {"status": "success", "result_images": result_images}
125
+ elif res['status'] == "failed":
126
+ return {"status": "failed"}
127
+ else:
128
+ print(f"Error: {r3.text}")
129
+ return {"status": "failed"}
130
+
131
+
132
+ def get_user_credits(token):
133
+ if token == "":
134
+ return None
135
+
136
+ r = requests.get(f"{TRYON_SERVER_API_URL}user/get/", headers={
137
+ "Authorization": f"Bearer {token}"
138
+ })
139
+ if r.status_code == 200:
140
+ res = r.json()
141
+ return res['credits']
142
+ else:
143
+ print(f"Error: {r.text}")
144
+ return None
145
+
146
+
147
+ def load_token():
148
+ if os.path.exists(".token"):
149
+ with open(".token", "r") as f:
150
+ return json.load(f)['token']
151
+ else:
152
+ return None
153
+
154
+
155
+ def save_token(access_token):
156
+ if access_token != "":
157
+ with open(".token", "w") as f:
158
+ json.dump({"token": access_token}, f)
159
+ else:
160
+ raise gr.Error("No token provided!")
161
+
162
+
163
+ def is_logged_in():
164
+ loaded_token = load_token()
165
+ if loaded_token is None or loaded_token == "":
166
+ return False
167
+ else:
168
+ return True
169
+
170
+
171
+ def login(token):
172
+ print("logging in...")
173
+ # validate token
174
+ r = requests.post(f"{TRYON_SERVER_URL}/api/token/verify/", data={"token": token})
175
+ if r.status_code == 200:
176
+ save_token(token)
177
+ return True
178
+ else:
179
+ raise gr.Error("Login failed")
180
+
181
+
182
+ def logout():
183
+ print("logged out")
184
+ with open(".token", "w") as f:
185
+ json.dump({"token": ""}, f)
186
+ return [False, ""]
187
+
188
+
189
+ css = """
190
+ #col-container {
191
+ margin: 0 auto;
192
+ max-width: 1024px;
193
+ }
194
+ #credits-col-container{
195
+ display:flex;
196
+ justify-content: right;
197
+ align-items: center;
198
+ font-size: 24px;
199
+ margin-right: 1rem;
200
+ }
201
+ #login-modal{
202
+ max-width: 728px;
203
+ margin: 0 auto;
204
+ margin-top: 1rem;
205
+ margin-bottom: 1rem;
206
+ }
207
+ #login-logout-btn{
208
+ display:inline;
209
+ max-width: 124px;
210
+ }
211
+ """
212
+
213
+ with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
214
+ print("is logged in:", is_logged_in())
215
+ logged_in = gr.State(is_logged_in())
216
+ if os.path.exists(".token"):
217
+ with open(".token", "r") as f:
218
+ user_token = gr.State(json.load(f)["token"])
219
+ else:
220
+ user_token = gr.State("")
221
+
222
+ with Modal(visible=False) as modal:
223
+ @gr.render(inputs=user_token)
224
+ def rerender1(user_token1):
225
+ with gr.Column(elem_id="login-modal"):
226
+ access_token = gr.Textbox(
227
+ label="Token",
228
+ lines=1,
229
+ value=user_token1,
230
+ type="password",
231
+ placeholder="Enter your access token here!",
232
+ info="Visit https://playground.tryonlabs.ai to retrieve your access token."
233
+ )
234
+
235
+ login_submit_btn = gr.Button("Login", scale=1, variant='primary')
236
+ login_submit_btn.click(
237
+ fn=lambda access_token: (login(access_token), Modal(visible=False), access_token),
238
+ inputs=[access_token], outputs=[logged_in, modal, user_token],
239
+ concurrency_limit=1)
240
+
241
+ with gr.Row(elem_id="col-container"):
242
+ with gr.Column():
243
+ gr.Markdown(f"""
244
+ # Model Swap AI
245
+ ## by TryOn Labs (https://www.tryonlabs.ai)
246
+ Swap a human model with a artificial model generated by Artificial Model while keeping the garment intact.
247
+ """)
248
+
249
+
250
+ @gr.render(inputs=logged_in)
251
+ def rerender(is_logged_in):
252
+ with gr.Column():
253
+ if not is_logged_in:
254
+ with gr.Row(elem_id="credits-col-container"):
255
+ login_btn = gr.Button(value="Login", variant='primary', elem_id="login-logout-btn", size="sm")
256
+ login_btn.click(lambda: Modal(visible=True), None, modal)
257
+ else:
258
+ user_credits = get_user_credits(load_token())
259
+ print("user_credits", user_credits)
260
+ gr.HTML(f"""<div><p id="credits-col-container">Your Credits:
261
+ {user_credits if user_credits is not None else "0"}</p>
262
+ <p style="text-align: right;">Visit <a href="https://playground.tryonlabs.ai">
263
+ TryOn AI Playground</a> to acquire more credits</p></div>""")
264
+ with gr.Row(elem_id="credits-col-container"):
265
+ logout_btn = gr.Button(value="Logout", scale=1, variant='primary', size="sm",
266
+ elem_id="login-logout-btn")
267
+ logout_btn.click(fn=logout, inputs=None, outputs=[logged_in, user_token], concurrency_limit=1)
268
+
269
+ with gr.Column(elem_id="col-container"):
270
+ with gr.Row():
271
+ with gr.Column():
272
+ input_image = gr.Image(label="Original image", type='pil', height="400px", show_label=True)
273
+ prompt = gr.Textbox(
274
+ label="Prompt",
275
+ lines=3,
276
+ placeholder="Enter your prompt here!",
277
+ )
278
+ dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Retain garment",
279
+ info="Select the garment type you want to retain in the generated image!")
280
+
281
+ gallery = gr.Gallery(
282
+ label="Generated images", show_label=True, elem_id="gallery"
283
+ , columns=[3], rows=[1], object_fit="contain", height="auto")
284
+
285
+ # output_image = gr.Image(label="Swapped model", type='pil', height="400px", show_label=True,
286
+ # show_download_button=True)
287
+
288
+ with gr.Accordion("Advanced Settings", open=False):
289
+ with gr.Row():
290
+ seed = gr.Number(label="Seed", value=-1, interactive=True, minimum=-1)
291
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5, interactive=True, minimum=0.0,
292
+ maximum=10.0,
293
+ step=0.1)
294
+ num_results = gr.Number(label="Number of results", value=2, minimum=1, maximum=5)
295
+
296
+ with gr.Row():
297
+ strength = gr.Slider(0.00, 1.00, value=0.99, label="Strength",
298
+ info="Choose between 0.00 and 1.00", step=0.01, interactive=True)
299
+ inference_steps = gr.Number(label="Inference Steps", value=20, interactive=True, minimum=1, step=1)
300
+
301
+ with gr.Row():
302
+ submit_button = gr.Button("Submit", variant='primary', scale=1)
303
+ reset_button = gr.ClearButton(value="Reset", scale=1)
304
+
305
+ gr.on(
306
+ triggers=[submit_button.click],
307
+ fn=start_model_swap,
308
+ inputs=[input_image, prompt, dropdown, seed, guidance_scale, num_results, strength, inference_steps],
309
+ outputs=[gallery]
310
+ )
311
+
312
+ reset_button.click(
313
+ fn=lambda: (None, None, "upper", None, -1, 7.5, 2, 0.99, 20),
314
+ inputs=[],
315
+ outputs=[input_image, prompt, dropdown, gallery, seed, guidance_scale,
316
+ num_results, strength, inference_steps],
317
+ concurrency_limit=1,
318
+ )
319
+
320
+ if __name__ == '__main__':
321
+ demo.launch()
demo/model_swap/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio==4.44.1
2
+ gradio_modal==0.0.3
demo/outfit_generator/README.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FLUX.1-dev LoRA Outfit Generator Gradio Demo
2
+ ## by TryOn Labs (https://www.tryonlabs.ai)
3
+ Generate an outfit by describing the color, pattern, fit, style, material, type, etc.
4
+
5
+ ## Model description
6
+
7
+ FLUX.1-dev LoRA Outfit Generator can create an outfit by detailing the color, pattern, fit, style, material, and type.
8
+
9
+ ## Inference
10
+
11
+ ```
12
+ import random
13
+
14
+ from diffusers import FluxPipeline
15
+ import torch
16
+
17
+ seed=42
18
+ prompt = "denim dark blue 5-pocket ankle-length jeans in washed stretch denim slightly looser fit with a wide waist panel for best fit over the tummy and tapered legs with raw-edge frayed hems"
19
+ PRE_TRAINED_MODEL = "black-forest-labs/FLUX.1-dev"
20
+ FINE_TUNED_MODEL = "tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator"
21
+
22
+ # Load Flux
23
+ pipe = FluxPipeline.from_pretrained(PRE_TRAINED_MODEL, torch_dtype=torch.float16).to("cuda")
24
+
25
+ # Load fine-tuned model
26
+ pipe.load_lora_weights(FINE_TUNED_MODEL, adapter_name="default", weight_name="outfit-generator.safetensors")
27
+
28
+ seed = random.randint(0, MAX_SEED)
29
+
30
+ generator = torch.Generator().manual_seed(seed)
31
+
32
+ image = pipe(prompt, height=1024, width=1024, num_images_per_prompt=1, generator=generator,
33
+ guidance_scale=4.5, num_inference_steps=40).images[0]
34
+
35
+ image.save("gen_image.jpg")
36
+ ```
37
+
38
+ ## Dataset used
39
+
40
+ H&M Fashion Captions Dataset - 20.5k samples
41
+ https://huggingface.co/datasets/tomytjandra/h-and-m-fashion-caption
42
+
43
+ ## Repository used
44
+
45
+ AI Toolkit by Ostris
46
+ https://github.com/ostris/ai-toolkit
47
+
48
+ ## Download model
49
+
50
+ Weights for this model are available in Safetensors format.
51
+
52
+ [Download](https://huggingface.co/tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator/tree/main) them in the Files & versions tab.
53
+
54
+ ## Install dependencies
55
+
56
+ ```
57
+ git clone https://github.com/tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator.git
58
+ cd FLUX.1-dev-LoRA-Outfit-Generator
59
+ conda create -n demo python=3.12
60
+ pip install -r requirements.txt
61
+ conda install pytorch pytorch-cuda=12.4 -c pytorch -c nvidia
62
+ ```
63
+
64
+ ## Run demo
65
+
66
+ ```
67
+ gradio app.py
68
+ ```
69
+
70
+ ## Generated images
71
+
72
+ ![alt](images/sample1.jpeg "sample1")
73
+ #### A dress with Color: Black, Department: Dresses, Detail: High Low,Fabric-Elasticity: No Sretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: Collared, Pattern: Solid, Sleeve-Length: Sleeveless, Style: Casual, Type: Tunic, Waistline: Regular
74
+ ***
75
+ ![alt](images/sample2.jpeg "sample2")
76
+ #### A dress with Color: Red, Department: Dresses, Detail: Belted, Fabric-Elasticity: High Stretch, Fit: Fitted, Hemline: Flared, Material: Gabardine, Neckline: Off The Shoulder, Pattern: Floral, Sleeve-Length: Sleeveless, Style: Elegant, Type: Fit and Flare, Waistline: High
77
+ ***
78
+ ![alt](images/sample3.jpeg "sample3")
79
+ #### A dress with Color: Multicolored, Department: Dresses, Detail: Split, Fabric-Elasticity: High Stretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: V Neck, Pattern: Leopard, Sleeve-Length: Sleeveless, Style: Casual, Type: T Shirt, Waistline: Regular
80
+ ***
81
+ ![alt](images/sample4.jpeg "sample4")
82
+ #### A dress with Color: Brown, Department: Dresses, Detail: Zipper, Fabric-Elasticity: No Sretch, Fit: Fitted, Hemline: Asymmetrical, Material: Satin, Neckline: Spaghetti Straps, Pattern: Floral, Sleeve-Length: Sleeveless, Style: Boho, Type: Cami Top, Waistline: High
83
+ ***
84
+
85
+ ## License
86
+ MIT [License](LICENSE)
demo/outfit_generator/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .app import demo
demo/outfit_generator/app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os.path
import random
import time

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
PRE_TRAINED_MODEL = "black-forest-labs/FLUX.1-dev"
FINE_TUNED_MODEL = "tryonlabs/FLUX.1-dev-LoRA-Outfit-Generator"
# Bug fix: os.makedirs does not expand "~", so the original literal "~/results"
# created a directory named "~" in the current working directory.
RESULTS_DIR = os.path.expanduser("~/results")
os.makedirs(RESULTS_DIR, exist_ok=True)

if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32

# Load the base FLUX.1-dev pipeline.
# NOTE(review): torch.float16 is hardcoded and the `torch_dtype` selected above
# is unused; .to("cuda") will fail on CPU-only machines — confirm GPU-only deployment.
pipe = FluxPipeline.from_pretrained(PRE_TRAINED_MODEL, torch_dtype=torch.float16).to("cuda")

# Attach the fine-tuned outfit-generator LoRA weights on top of the base model.
pipe.load_lora_weights(FINE_TUNED_MODEL, adapter_name="default", weight_name="outfit-generator.safetensors")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
32
+
33
+ @spaces.GPU(duration=65)
34
+ def infer(
35
+ prompt,
36
+ seed=42,
37
+ randomize_seed=False,
38
+ width=1024,
39
+ height=1024,
40
+ guidance_scale=4.5,
41
+ num_inference_steps=40,
42
+ progress=gr.Progress(track_tqdm=True),
43
+ ):
44
+ if randomize_seed:
45
+ seed = random.randint(0, MAX_SEED)
46
+
47
+ generator = torch.Generator().manual_seed(seed)
48
+
49
+ image = pipe(prompt, height=width, width=height, num_images_per_prompt=1, generator=generator,
50
+ guidance_scale=guidance_scale,
51
+ num_inference_steps=num_inference_steps).images[0]
52
+
53
+ try:
54
+ # save image
55
+ current_time = int(time.time() * 1000)
56
+ image.save(os.path.join(RESULTS_DIR, f"gen_img_{current_time}.png"))
57
+ with open(os.path.join(RESULTS_DIR, f"gen_img_{current_time}.json"), "w") as f:
58
+ json.dump({"prompt": prompt, "height": height, "width": width, "guidance_scale": guidance_scale,
59
+ "num_inference_steps": num_inference_steps, "seed": seed}, f)
60
+ except Exception as e:
61
+ print(str(e))
62
+
63
+ return image, seed
64
+
65
+
66
# Example prompts shown beneath the input box; selecting one runs `infer`
# with lazily-cached results.
examples = [
    "stripe red striped jersey top in a soft cotton and modal blend with short sleeves a chest pocket and rounded hem",
    "A dress with Color: Orange, Department: Dresses, Detail: Split Thigh, Fabric-Elasticity: No Sretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: Gathered, Pattern: Tropical, Sleeve-Length: Sleeveless, Style: Boho, Type: A Line Skirt, Waistline: High",
    "treatment dark pink knee-length skirt in crocodile-patterned imitation leather high waist with belt loops and press-studs a zip fly diagonal side pockets and a slit at the front the polyester content of the skirt is partly recycled",
    "A dress with Color: Maroon, Department: Dresses, Detail: Ruched Bust, Fabric-Elasticity: Slight Stretch, Fit: Fitted, Hemline: Slit, Material: Gabardine, Neckline: Spaghetti Straps, Pattern: Floral, Sleeve-Length: Sleeveless, Style: Boho, Type: Cami Top, Waistline: Regular",
    "denim dark blue 5-pocket ankle-length jeans in washed stretch denim slightly looser fit with a wide waist panel for best fit over the tummy and tapered legs with raw-edge frayed hems"
]

# Center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 768px;
}
"""

# Build the Gradio UI: prompt box + run button, result image, and an
# "Advanced Settings" accordion mirroring infer()'s parameters.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# FLUX.1-dev LoRA Outfit Generator
## by TryOn Labs (https://www.tryonlabs.ai)
Generate an outfit by describing the color, pattern, fit, style, material, type, etc.
""")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        # Clicking an example runs infer with default advanced settings.
        gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True,
                    cache_mode="lazy")
    # Both the Run button and pressing Enter in the prompt box trigger inference.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch(share=True)
demo/outfit_generator/images/sample1.jpeg ADDED
demo/outfit_generator/images/sample2.jpeg ADDED
demo/outfit_generator/images/sample3.jpeg ADDED
demo/outfit_generator/images/sample4.jpeg ADDED
demo/outfit_generator/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ spaces
2
+ gradio
3
+ diffusers
4
+ torch
5
+ numpy
6
+ transformers
7
+ accelerate
8
+ protobuf
9
+ sentencepiece
10
+ peft==0.13.2
environment.yml ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: tryondiffusion
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - blas=1.0=mkl
6
+ - bottleneck=1.3.5=py310h4e76f89_0
7
+ - bzip2=1.0.8=h1de35cc_0
8
+ - ca-certificates=2023.08.22=hecd8cb5_0
9
+ - cffi=1.15.1=py310h6c40b1e_3
10
+ - gmp=6.2.1=he9d5cce_3
11
+ - gmpy2=2.1.2=py310hd5de756_0
12
+ - intel-openmp=2023.1.0=ha357a0b_43547
13
+ - jinja2=3.1.2=py310hecd8cb5_0
14
+ - libcxx=14.0.6=h9765a3e_0
15
+ - libffi=3.4.4=hecd8cb5_0
16
+ - libprotobuf=3.20.3=hfff2838_0
17
+ - libuv=1.44.2=h6c40b1e_0
18
+ - mkl=2023.1.0=h8e150cf_43559
19
+ - mkl-service=2.4.0=py310h6c40b1e_1
20
+ - mkl_fft=1.3.8=py310h6c40b1e_0
21
+ - mkl_random=1.2.4=py310ha357a0b_0
22
+ - mpc=1.1.0=h6ef4df4_1
23
+ - mpfr=4.0.2=h9066e36_1
24
+ - mpmath=1.3.0=py310hecd8cb5_0
25
+ - ncurses=6.4=hcec6c5f_0
26
+ - networkx=3.1=py310hecd8cb5_0
27
+ - ninja=1.10.2=hecd8cb5_5
28
+ - ninja-base=1.10.2=haf03e11_5
29
+ - numexpr=2.8.7=py310h827a554_0
30
+ - openssl=3.0.11=hca72f7f_2
31
+ - pandas=2.1.1=py310h3ea8b11_0
32
+ - pip=23.2.1=py310hecd8cb5_0
33
+ - pycparser=2.21=pyhd3eb1b0_0
34
+ - python=3.10.13=h5ee71fb_0
35
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
36
+ - python-tzdata=2023.3=pyhd3eb1b0_0
37
+ - pytz=2023.3.post1=py310hecd8cb5_0
38
+ - readline=8.2=hca72f7f_0
39
+ - six=1.16.0=pyhd3eb1b0_1
40
+ - sqlite=3.41.2=h6c40b1e_0
41
+ - tbb=2021.8.0=ha357a0b_0
42
+ - tk=8.6.12=h5d9f67b_0
43
+ - tzdata=2023c=h04d1e81_0
44
+ - wheel=0.38.4=py310hecd8cb5_0
45
+ - xz=5.4.2=h6c40b1e_0
46
+ - zlib=1.2.13=h4dc903c_0
47
+ - pip:
48
+ - absl-py==2.0.0
49
+ - aiofiles==23.2.1
50
+ - annotated-types==0.6.0
51
+ - anyio==4.3.0
52
+ - appnope==0.1.3
53
+ - asttokens==2.4.0
54
+ - astunparse==1.6.3
55
+ - backcall==0.2.0
56
+ - cachetools==5.3.1
57
+ - carvekit==4.1.1
58
+ - certifi==2023.7.22
59
+ - charset-normalizer==3.2.0
60
+ - click==8.1.7
61
+ - comm==0.1.4
62
+ - contourpy==1.1.1
63
+ - cycler==0.11.0
64
+ - debugpy==1.8.0
65
+ - decorator==5.1.1
66
+ - diffusers==0.29.2
67
+ - einops==0.7.0
68
+ - exceptiongroup==1.1.3
69
+ - executing==1.2.0
70
+ - fastapi==0.108.0
71
+ - ffmpy==0.3.3
72
+ - filelock==3.12.4
73
+ - flatbuffers==23.5.26
74
+ - fonttools==4.42.1
75
+ - fsspec==2024.3.1
76
+ - gast==0.5.4
77
+ - google-auth==2.23.3
78
+ - google-auth-oauthlib==1.0.0
79
+ - google-pasta==0.2.0
80
+ - gradio==4.39.0
81
+ - gradio-client==1.1.1
82
+ - gradio-modal==0.0.3
83
+ - grpcio==1.59.0
84
+ - h11==0.14.0
85
+ - h5py==3.10.0
86
+ - httpcore==1.0.5
87
+ - httpx==0.27.0
88
+ - huggingface-hub==0.23.4
89
+ - idna==3.4
90
+ - imageio==2.34.0
91
+ - importlib-metadata==8.0.0
92
+ - importlib-resources==6.4.0
93
+ - ipykernel==6.25.2
94
+ - ipython==8.15.0
95
+ - jedi==0.19.0
96
+ - jupyter-client==8.3.1
97
+ - jupyter-core==5.3.1
98
+ - keras==2.14.0
99
+ - kiwisolver==1.4.5
100
+ - lazy-loader==0.3
101
+ - libclang==16.0.6
102
+ - loguru==0.7.2
103
+ - markdown==3.5
104
+ - markdown-it-py==3.0.0
105
+ - markupsafe==2.1.3
106
+ - matplotlib==3.8.0
107
+ - matplotlib-inline==0.1.6
108
+ - mdurl==0.1.2
109
+ - ml-dtypes==0.2.0
110
+ - nest-asyncio==1.5.8
111
+ - numpy==1.26.4
112
+ - oauthlib==3.2.2
113
+ - opencv-python==4.8.1.78
114
+ - opt-einsum==3.3.0
115
+ - orjson==3.10.6
116
+ - packaging==23.1
117
+ - parso==0.8.3
118
+ - pexpect==4.8.0
119
+ - pickleshare==0.7.5
120
+ - pillow==10.1.0
121
+ - platformdirs==3.10.0
122
+ - prompt-toolkit==3.0.39
123
+ - protobuf==4.24.4
124
+ - psutil==5.9.5
125
+ - ptyprocess==0.7.0
126
+ - pure-eval==0.2.2
127
+ - pyasn1==0.5.0
128
+ - pyasn1-modules==0.3.0
129
+ - pydantic==2.5.3
130
+ - pydantic-core==2.14.6
131
+ - pydub==0.25.1
132
+ - pygments==2.16.1
133
+ - pyparsing==3.1.1
134
+ - python-dotenv==1.0.1
135
+ - python-multipart==0.0.9
136
+ - pyyaml==6.0.1
137
+ - pyzmq==25.1.1
138
+ - regex==2024.5.15
139
+ - requests==2.31.0
140
+ - requests-oauthlib==1.3.1
141
+ - rich==13.7.1
142
+ - rsa==4.9
143
+ - ruff==0.5.5
144
+ - safetensors==0.4.3
145
+ - scikit-image==0.22.0
146
+ - scipy==1.11.4
147
+ - semantic-version==2.10.0
148
+ - setuptools==69.0.3
149
+ - shellingham==1.5.4
150
+ - sniffio==1.3.1
151
+ - stack-data==0.6.2
152
+ - starlette==0.32.0.post1
153
+ - sympy==1.12
154
+ - tensorboard==2.14.1
155
+ - tensorboard-data-server==0.7.1
156
+ - tensorflow==2.14.0
157
+ - tensorflow-estimator==2.14.0
158
+ - tensorflow-io-gcs-filesystem==0.34.0
159
+ - termcolor==2.3.0
160
+ - tifffile==2024.2.12
161
+ - tokenizers==0.19.1
162
+ - tomlkit==0.12.0
163
+ - torch==2.1.2
164
+ - torchvision==0.16.2
165
+ - tornado==6.3.3
166
+ - tqdm==4.66.1
167
+ - traitlets==5.10.0
168
+ - transformers==4.42.4
169
+ - typer==0.12.3
170
+ - typing==3.7.4.3
171
+ - typing-extensions==4.8.0
172
+ - urllib3==2.0.5
173
+ - uvicorn==0.25.0
174
+ - wcwidth==0.2.6
175
+ - websockets==11.0.3
176
+ - werkzeug==3.0.0
177
+ - wrapt==1.14.1
178
+ - zipp==3.19.2
179
+ prefix: /Users/apple/miniconda3/envs/tryondiffusion
main.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dotenv import load_dotenv

# Load environment variables (e.g. the U2NET checkpoint path) before importing
# the tryon modules that read them.
load_dotenv()

import time
import os
import argparse

from tryon.preprocessing import segment_human, segment_garment, extract_garment

if __name__ == '__main__':
    # CLI entry point dispatching the three preprocessing actions.
    argp = argparse.ArgumentParser(description="Tryon preprocessing")
    argp.add_argument('-d',
                      '--dataset',
                      type=str, default="data", help='Path of the dataset dir')
    argp.add_argument('-a',
                      '--action',
                      type=str, default="", help='Ex. segment_garment, extract_garment, segment_human')
    argp.add_argument('-c',
                      '--cls',
                      type=str, default="upper", help='Ex. upper, lower, all')
    args = argp.parse_args()

    if args.action == "segment_garment":
        # 1. segment garment: write per-class garment masks
        print('Start time:', int(time.time()))
        segment_garment(inputs_dir=os.path.join(args.dataset, "original_cloth"),
                        outputs_dir=os.path.join(args.dataset, "garment_segmented"), cls=args.cls)
        print("End time:", int(time.time()))

    elif args.action == "extract_garment":
        # 2. extract garment: cut garments out and composite them on a white canvas
        print('Start time:', int(time.time()))
        extract_garment(inputs_dir=os.path.join(args.dataset, "original_cloth"),
                        outputs_dir=os.path.join(args.dataset, "cloth"), cls=args.cls, resize_to_width=400)
        print("End time:", int(time.time()))

    elif args.action == "segment_human":
        # 3. segment human: person mask for a single model image
        print('Start time:', int(time.time()))
        image_path = os.path.join(args.dataset, "original_human", "model.jpg")
        output_dir = os.path.join(args.dataset, "human_segmented")
        segment_human(image_path=image_path, output_dir=output_dir)
        print("End time:", int(time.time()))
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ opencv-python
4
+ pillow
5
+ matplotlib
6
+ tqdm
7
+ torchvision
8
+ einops
9
+ python-dotenv
10
+ scikit-image
11
+ diffusers
12
+ transformers
13
+ gradio==4.44.1
14
+ gradio_modal==0.0.3
15
+ python-dotenv
run_demo.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse

if __name__ == '__main__':
    # Select and launch one of the bundled Gradio demos by name.
    cli = argparse.ArgumentParser(description="Gradio demo")
    cli.add_argument('-n',
                     '--name',
                     type=str, default="data", help='Name of the gradio demo to launch')
    options = cli.parse_args()

    selected = None
    if options.name == "extract_garment":
        from demo import extract_garment_demo as selected
    elif options.name == "model_swap":
        from demo import model_swap_demo as selected
    elif options.name == "outfit_generator":
        from demo import outfit_generator_demo as selected

    # Unknown names fall through silently, matching the original behaviour.
    if selected is not None:
        selected.launch()
run_ootd.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import subprocess
import os
import pathlib

# CLI wrapper that runs OOTDiffusion's run.py inside its own conda environment.
parser = argparse.ArgumentParser(description='run ootd')
parser.add_argument('--gpu_id', '-g', type=int, default=0, required=False)
parser.add_argument('--model_path', type=str, default="", required=True)
parser.add_argument('--cloth_path', type=str, default="", required=True)
parser.add_argument('--output_path', type=str, default="", required=True)
parser.add_argument('--model_type', type=str, default="hd", required=False)
parser.add_argument('--category', '-c', type=int, default=0, required=False)
parser.add_argument('--scale', type=float, default=2.0, required=False)
parser.add_argument('--step', type=int, default=20, required=False)
parser.add_argument('--sample', type=int, default=4, required=False)
parser.add_argument('--seed', type=int, default=-1, required=False)
args = parser.parse_args()

print(args)

if __name__ == '__main__':
    # Consistency fix: derive the OOTDiffusion checkout location from the current
    # user's home directory (the interpreter path below already does), instead of
    # hardcoding /home/ubuntu.
    ootdiffusion_dir = os.path.join(str(pathlib.Path.home()), "ootdiffusion")

    command = (f"{os.path.join(str(pathlib.Path.home()), 'miniconda3/envs/ootdiffusion/bin/python')} "
               f"run.py --model_path {args.model_path} --cloth_path {args.cloth_path} "
               f"--output_path {args.output_path} --model_type {args.model_type} --category {args.category} "
               f"--image_scale {args.scale} --gpu_id {args.gpu_id} --n_samples {args.sample} --seed {args.seed} "
               f"--n_steps {args.step}")

    print("command:", command, command.split(" "))

    # Run inside the OOTDiffusion checkout so run.py resolves its relative paths.
    p = subprocess.Popen(command.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                         cwd=ootdiffusion_dir)
    out, err = p.communicate()
    print(out, err)
36
+
37
+
scripts/install_conda.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Setup Ubuntu
2
+ sudo apt update --yes
3
+ sudo apt upgrade --yes
4
+
5
+ # Get Miniconda and make it the main Python interpreter
6
+ wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
7
+ bash ~/miniconda.sh -b -p ~/miniconda
8
+ rm ~/miniconda.sh
9
+
10
+ export PATH=~/miniconda/bin:$PATH
scripts/install_sam2.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
ENV_NAME="sam2"
# System packages needed to build SAM2's video/codec dependencies.
sudo apt-get -y update && sudo apt-get install -y --no-install-recommends ffmpeg libavutil-dev libavcodec-dev libavformat-dev libswscale-dev pkg-config build-essential libffi-dev
git clone https://github.com/facebookresearch/sam2.git ~/$ENV_NAME
# Fix: pass -y so conda create does not stall on an interactive prompt
# (the conda install line below already used -y).
conda create -y -n $ENV_NAME python=3.10
conda install -y -n $ENV_NAME pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
# Fix: pip has no -y option; the original "pip install -y -e" failed with
# "no such option: -y".
~/miniconda3/envs/$ENV_NAME/bin/pip install -e ~/$ENV_NAME
sh ~/$ENV_NAME/checkpoints/download_ckpts.sh
# Move the downloaded checkpoints into the repo's checkpoints folder
# (presumably download_ckpts.sh writes into the current directory — confirm).
mv sam2.1_hiera_base_plus.pt ~/$ENV_NAME/checkpoints/
mv sam2.1_hiera_large.pt ~/$ENV_NAME/checkpoints/
mv sam2.1_hiera_small.pt ~/$ENV_NAME/checkpoints/
mv sam2.1_hiera_tiny.pt ~/$ENV_NAME/checkpoints/
setup.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pathlib import Path

from setuptools import setup, find_packages

# Use the repository README as the long description.
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

setup(
    name="tryondiffusion",
    version="0.1.0",
    license='Creative Commons BY-NC 4.0',
    packages=find_packages(),
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/kailashahirwar/tryondiffusion',
    # NOTE(review): `keywords` holds a sentence-long description — presumably
    # intended for the `description` field; confirm before publishing.
    keywords='Unofficial implementation of TryOnDiffusion: A Tale Of Two UNets',
    install_requires=[
        "torch",
        "numpy",
        "opencv-python",
        "pillow",
        "matplotlib",
        "tqdm",
        "torchvision",
        "einops",
        "scipy",
        "scikit-image",
        "gradio==4.44.1",
        "gradio_modal==0.0.3"
    ]
)
tryon/README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Try-On Preprocessing
2
+
3
+ Before you start, make a .env file in your project's main folder. Put these environment variables inside it.
4
+ ```
5
+ U2NET_CLOTH_SEGM_CHECKPOINT_PATH=cloth_segm.pth
6
+ ```
7
+
8
+ #### Remember to load environment variables before you start running scripts.
9
+
10
+ ```
11
+ from dotenv import load_dotenv
12
+
13
+ load_dotenv()
14
+ ```
15
+
16
+ ### segment garment
17
+
18
+ ```
19
+ from tryon.preprocessing import segment_garment
20
+
21
+ segment_garment(inputs_dir=<inputs_dir>,
22
+ outputs_dir=<outputs_dir>, cls=<cls>)
23
+ ```
24
+
25
+ possible values for cls: lower, upper, all
26
+
27
+ ### extract garment
28
+
29
+ ```
30
+ from tryon.preprocessing import extract_garment
31
+
32
+ extract_garment(inputs_dir=<inputs_dir>,
33
+ outputs_dir=<outputs_dir>, cls=<cls>)
34
+ ```
tryon/__init__.py ADDED
File without changes
tryon/models/__init__.py ADDED
File without changes
tryon/models/ootdiffusion/setup.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
ENV_NAME="ootdiffusion"
PROJECT_DIR="/home/ubuntu/ootdiffusion"

# Create the conda environment only if it does not exist yet.
if [ ! -d ~/miniconda3/envs/$ENV_NAME ]; then
echo "creating conda environment"
conda create -y -n $ENV_NAME python==3.10
fi

# clone repository
if [ ! -d $PROJECT_DIR ]; then
echo "cloning OOTDiffusion repository"
git clone https://github.com/tryonlabs/OOTDiffusion.git $PROJECT_DIR
fi

# Install dependencies with the environment's own pip (no activation needed).
~/miniconda3/envs/$ENV_NAME/bin/pip install -r $PROJECT_DIR/requirements.txt

# Fetch model checkpoints only when missing.
if [ ! -d $PROJECT_DIR/checkpoints/ootd ]; then
echo "downloading checkpoints"

# download checkpoints
git clone https://huggingface.co/levihsu/OOTDiffusion ~/ootd-checkpoints
git clone https://huggingface.co/openai/clip-vit-large-patch14 ~/clip-vit-large-patch14

mv ~/ootd-checkpoints/checkpoints/* $PROJECT_DIR/checkpoints/
rm -rf ~/ootd-checkpoints

mv ~/clip-vit-large-patch14 $PROJECT_DIR/checkpoints/
# NOTE(review): the directory was just moved away, so this rm targets a path
# that no longer exists — presumably a harmless leftover; confirm.
rm -rf ~/clip-vit-large-patch14

fi
tryon/preprocessing/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .preprocess_garment import segment_garment, extract_garment
2
+ from .utils import convert_to_jpg
3
+ from .preprocess_human import segment_human
tryon/preprocessing/captioning/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .generate_caption import (caption_image, create_llava_next_pipeline,
2
+ create_phi35mini_pipeline, convert_outfit_json_to_caption)
tryon/preprocessing/captioning/generate_caption.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
5
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
6
+
7
+
8
def caption_image(image, question, model=None, processor=None, json_only=False):
    """
    Extract outfit details from an image using an image-to-text model.

    :param image: input image (PIL)
    :param question: question/prompt to ask about the image
    :param model: optional preloaded LLaVA-NeXT model
    :param processor: optional matching processor
    :param json_only: when True, skip the natural-language caption step
    :return: (json data, generated caption or None)
    """
    # Bug fix: the original tested `model is None and processor is None`, so
    # passing only one of the pair left the other as None and crashed below.
    if model is None or processor is None:
        model, processor = create_llava_next_pipeline()

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        },
    ]

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # Use the device the model actually lives on instead of hardcoding "cuda:0",
    # matching the device auto-detection in create_llava_next_pipeline.
    inputs = processor(image, prompt, return_tensors="pt").to(model.device)

    output = model.generate(**inputs, max_new_tokens=300)
    # Keep only the assistant reply that follows the final [/INST] marker.
    output = processor.decode(output[0], skip_special_tokens=True).split("[/INST]")[-1]
    # The model is expected to answer with a fenced ```json block; strip the fences.
    json_data = json.loads(output.replace("```json", "").replace("```", "").strip())

    if not json_only:
        generated_caption = convert_outfit_json_to_caption(json_data)
    else:
        generated_caption = None

    return json_data, generated_caption
44
+
45
+
46
def create_phi35mini_pipeline():
    """
    Create a text-generation pipeline backed by microsoft/Phi-3.5-mini-instruct.

    :return: transformers text-generation pipeline
    """
    # Fixed seed for reproducible generations.
    torch.random.manual_seed(0)

    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3.5-mini-instruct",
        device_map="cuda",  # NOTE(review): assumes a CUDA device is present — confirm GPU-only deployment
        torch_dtype="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_2"  # NOTE(review): requires the flash-attn package and a supported GPU
    )
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )
    return pipe
68
+
69
+
70
def create_llava_next_pipeline():
    """
    Load the LLaVA-NeXT (llava-v1.6-mistral-7b) model and its processor,
    placing the model on GPU when available, otherwise CPU.

    :return: (model, processor) tuple
    """
    model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor = LlavaNextProcessor.from_pretrained(model_id)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.to(target_device)

    return model, processor
83
+
84
+
85
def convert_outfit_json_to_caption(json_data, pipe=None):
    """
    Convert JSON data of an outfit into a natural language caption.

    :param json_data: dict of outfit attributes (JSON-serializable)
    :param pipe: optional preloaded text-generation pipeline; when None a
        Phi-3.5-mini-instruct pipeline is created on demand
    :return: generated caption string
    """
    if pipe is None:
        pipe = create_phi35mini_pipeline()

    # Deterministic decoding: sampling disabled, bounded output length.
    generation_args = {
        "max_new_tokens": 300,
        "return_full_text": False,
        "temperature": 0.0,
        "do_sample": False,
    }

    messages = [{"role": "user",
                 "content": f'Convert the {json.dumps(json_data)} JSON data into a natural '
                            f'language paragraph beginning with "An outfit with"'}]

    output = pipe(messages, **generation_args)[0]['generated_text'].strip()
    print(f"Output: {output}")
    return output
tryon/preprocessing/extract_garment_new.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+
9
+ from .u2net import load_cloth_segm_model
10
+ from .utils import NormalizeImage, naive_cutout, resize_by_bigger_index, image_resize
11
+
12
+
13
def extract_garment(image, cls="all", resize_to_width=None, net=None, device=None):
    """
    Extract garments from the given image using the U2Net cloth-segmentation model.

    :param image: input PIL image (RGB)
    :param cls: garment class to extract: "upper", "lower", "dress" or "all"
    :param resize_to_width: optional output width; aspect ratio is preserved
    :param net: optional preloaded segmentation network
    :param device: torch device the network runs on
    :return: dict mapping class name ("upper"/"lower"/"dress") to a PIL image
    """
    if net is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = load_cloth_segm_model(device, os.environ.get("U2NET_CLOTH_SEGM_CHECKPOINT_PATH"), in_ch=3, out_ch=4)
    elif device is None:
        # Bug fix: when a preloaded net was passed without a device, the original
        # left device=None; fall back to the same auto-detection used above.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform_fn = transforms.Compose(
        [transforms.ToTensor(),
         NormalizeImage(0.5, 0.5)]
    )

    img_size = image.size
    # The segmentation model expects a fixed 768x768 input; masks are resized back later.
    img = image.resize((768, 768), Image.BICUBIC)
    image_tensor = transform_fn(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)

    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        # Per-pixel argmax over the 4 classes (0=background, 1=upper, 2=lower, 3=dress).
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
        output_tensor = torch.squeeze(output_tensor, dim=0)
        output_arr = output_tensor.cpu().numpy()

    classes = {1: "upper", 2: "lower", 3: "dress"}

    if cls == "all":
        # Keep every non-background class present in the prediction. (The loop
        # variable is renamed: the original shadowed the `cls` parameter.)
        classes_to_save = [label for label in range(1, 4) if np.any(output_arr == label)]
    elif cls == "upper":
        classes_to_save = [1]
    elif cls == "lower":
        classes_to_save = [2]
    elif cls == "dress":
        classes_to_save = [3]
    else:
        raise ValueError(f"Unknown cls: {cls}")

    garments = dict()

    for cls1 in classes_to_save:
        alpha_mask = (output_arr == cls1).astype(np.uint8) * 255
        alpha_mask = alpha_mask[0]  # Selecting the first channel to make it 2D
        alpha_mask_img = Image.fromarray(alpha_mask, mode='L')
        alpha_mask_img = alpha_mask_img.resize(img_size, Image.BICUBIC)

        cutout = np.array(naive_cutout(image, alpha_mask_img))
        cutout = resize_by_bigger_index(cutout)

        # Paste the cutout centered on a white 1024x768 canvas.
        canvas = np.ones((1024, 768, 3), np.uint8) * 255
        y1, y2 = (canvas.shape[0] - cutout.shape[0]) // 2, (canvas.shape[0] + cutout.shape[0]) // 2
        x1, x2 = (canvas.shape[1] - cutout.shape[1]) // 2, (canvas.shape[1] + cutout.shape[1]) // 2

        # Alpha-blend the RGBA cutout onto the white canvas.
        alpha_s = cutout[:, :, 3] / 255.0
        alpha_l = 1.0 - alpha_s

        for c in range(0, 3):
            canvas[y1:y2, x1:x2, c] = (alpha_s * cutout[:, :, c] +
                                       alpha_l * canvas[y1:y2, x1:x2, c])

        # resize image before saving
        if resize_to_width:
            canvas = image_resize(canvas, width=resize_to_width)

        canvas = Image.fromarray(canvas)

        garments[classes[cls1]] = canvas

    return garments
tryon/preprocessing/preprocess_garment.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from tqdm import tqdm
11
+
12
+ from .u2net import load_cloth_segm_model
13
+ from .utils import NormalizeImage, naive_cutout, resize_by_bigger_index, image_resize
14
+
15
+
16
def segment_garment(inputs_dir, outputs_dir, cls="all"):
    """
    Segment garments in every image under inputs_dir and write one grayscale
    mask per detected class to outputs_dir, named "<image stem>_<class id>.jpg"
    (class ids: 1=upper, 2=lower, 3=dress).

    :param inputs_dir: directory of input RGB images
    :param outputs_dir: directory for the masks (created if missing)
    :param cls: "upper", "lower", "dress" or "all"
    """
    os.makedirs(outputs_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform_fn = transforms.Compose(
        [transforms.ToTensor(),
         NormalizeImage(0.5, 0.5)]
    )

    # U2Net cloth-segmentation model: 3 input channels, 4 output classes.
    net = load_cloth_segm_model(device, os.environ.get("U2NET_CLOTH_SEGM_CHECKPOINT_PATH"), in_ch=3, out_ch=4)

    images_list = sorted(os.listdir(inputs_dir))
    pbar = tqdm(total=len(images_list))

    for image_name in images_list:
        img = Image.open(os.path.join(inputs_dir, image_name)).convert('RGB')
        img_size = img.size
        # The model expects a fixed 768x768 input; masks are resized back below.
        img = img.resize((768, 768), Image.BICUBIC)
        image_tensor = transform_fn(img)
        image_tensor = torch.unsqueeze(image_tensor, 0)

        with torch.no_grad():
            output_tensor = net(image_tensor.to(device))
            output_tensor = F.log_softmax(output_tensor[0], dim=1)
            # Per-pixel argmax over the 4 classes (0 = background).
            output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
            output_tensor = torch.squeeze(output_tensor, dim=0)
            output_arr = output_tensor.cpu().numpy()

        if cls == "all":
            # Bug fix: the original reused `cls` as the loop variable here
            # ("for cls in range(1, 4)"), so after the first image `cls` was an
            # int and every subsequent image raised ValueError("Unknown cls: 3").
            # Keep only the non-background classes actually present.
            classes_to_save = [label for label in range(1, 4) if np.any(output_arr == label)]
        elif cls == "upper":
            classes_to_save = [1]
        elif cls == "lower":
            classes_to_save = [2]
        elif cls == "dress":
            classes_to_save = [3]
        else:
            raise ValueError(f"Unknown cls: {cls}")

        for cls1 in classes_to_save:
            alpha_mask = (output_arr == cls1).astype(np.uint8) * 255
            alpha_mask = alpha_mask[0]  # Selecting the first channel to make it 2D
            alpha_mask_img = Image.fromarray(alpha_mask, mode='L')
            alpha_mask_img = alpha_mask_img.resize(img_size, Image.BICUBIC)
            alpha_mask_img.save(os.path.join(outputs_dir, f'{image_name.split(".")[0]}_{cls1}.jpg'))

        pbar.update(1)

    pbar.close()
71
+
72
+
73
def extract_garment(inputs_dir, outputs_dir, cls="all", resize_to_width=None):
    """
    Cut segmented garments out of the source images and paste each one centered
    on a white 1024x768 canvas saved to outputs_dir.

    :param inputs_dir: directory of original garment photos
    :param outputs_dir: directory for the composited outputs (created if missing)
    :param cls: garment class forwarded to segment_garment
    :param resize_to_width: optional output width; aspect ratio is preserved
    """
    os.makedirs(outputs_dir, exist_ok=True)
    cloth_mask_dir = os.path.join(Path(outputs_dir).parent.absolute(), "cloth-mask")
    os.makedirs(cloth_mask_dir, exist_ok=True)

    segment_garment(inputs_dir, cloth_mask_dir, cls=cls)

    images_path = sorted(glob.glob(os.path.join(inputs_dir, "*")))
    masks_path = sorted(glob.glob(os.path.join(cloth_mask_dir, "*")))

    # Bug fix: the original opened only images_path[0] and cut every mask out of
    # that single image. Masks are named "<image stem>_<class id>", so map each
    # mask back to the image it was segmented from.
    images_by_stem = {os.path.basename(p).split(".")[0]: p for p in images_path}

    for mask_path in masks_path:
        mask = Image.open(mask_path)

        mask_stem = os.path.basename(mask_path).split(".")[0]
        image_stem = mask_stem.rsplit("_", 1)[0]
        # Fall back to the first image (the old behaviour) if a mask does not
        # match any source image.
        img = Image.open(images_by_stem.get(image_stem, images_path[0]))

        cutout = np.array(naive_cutout(img, mask))
        cutout = resize_by_bigger_index(cutout)

        # Paste the cutout centered on a white 1024x768 canvas.
        canvas = np.ones((1024, 768, 3), np.uint8) * 255
        y1, y2 = (canvas.shape[0] - cutout.shape[0]) // 2, (canvas.shape[0] + cutout.shape[0]) // 2
        x1, x2 = (canvas.shape[1] - cutout.shape[1]) // 2, (canvas.shape[1] + cutout.shape[1]) // 2

        # Alpha-blend the RGBA cutout onto the white canvas.
        alpha_s = cutout[:, :, 3] / 255.0
        alpha_l = 1.0 - alpha_s

        for c in range(0, 3):
            canvas[y1:y2, x1:x2, c] = (alpha_s * cutout[:, :, c] +
                                       alpha_l * canvas[y1:y2, x1:x2, c])

        # resize image before saving
        if resize_to_width:
            canvas = image_resize(canvas, width=resize_to_width)

        canvas = Image.fromarray(canvas)

        canvas.save(os.path.join(outputs_dir, f"{os.path.basename(mask_path).split('.')[0]}.jpg"), format='JPEG')
tryon/preprocessing/preprocess_human.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from skimage import io
8
+ from torch.autograd import Variable
9
+ from torch.utils.data import DataLoader
10
+ from torchvision import transforms
11
+
12
+ from .u2net import RescaleT, ToTensorLab, SalObjDataset, normPRED, load_human_segm_model
13
+
14
+
15
def pred_to_image(predictions, image_path):
    """Turn a prediction tensor into an RGB PIL image resized to the
    dimensions of the image stored at ``image_path``."""
    pred_np = predictions.squeeze().cpu().data.numpy()
    mask = Image.fromarray(pred_np * 255).convert('RGB')
    src = io.imread(image_path)
    # PIL's resize takes (width, height), hence the swapped shape indices.
    return mask.resize((src.shape[1], src.shape[0]), resample=Image.BILINEAR)
20
+
21
+
22
def segment_human(image_path, output_dir):
    """
    Segment human using U-2-Net and save the person composited over a
    light-grey (231, 231, 231) background as a PNG in ``output_dir``.

    :param image_path: image path
    :param output_dir: output directory
    """
    model_name = "u2net"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = [image_path]

    # 1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=images,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    net = load_human_segm_model(device, model_name)

    # 2. inference
    for i_test, data_test in enumerate(test_salobj_dataloader):
        print("inferencing:", images[i_test].split(os.sep)[-1])

        inputs_test = data_test['image'].type(torch.FloatTensor).to(device)

        # torch.autograd.Variable is a deprecated no-op wrapper; no_grad()
        # additionally avoids building the autograd graph during inference.
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization of the first (fused) side output
        pred = normPRED(d1[:, 0, :, :])

        mask = pred_to_image(pred, images[i_test])
        mask_cv2 = cv2.cvtColor(np.array(mask), cv2.COLOR_RGB2BGR)

        subimage = cv2.subtract(mask_cv2, cv2.imread(images[i_test]))
        original = Image.open(images[i_test])
        subimage = Image.fromarray(cv2.cvtColor(subimage, cv2.COLOR_BGR2RGB))

        subimage = subimage.convert("RGBA")
        original = original.convert("RGBA")

        # Vectorised replacement of the original per-pixel Python loop:
        # pixels that are pure black in the subtracted image become the
        # grey background, every other pixel keeps the original value.
        sub_arr = np.array(subimage)
        og_arr = np.array(original)
        background = np.all(sub_arr[:, :, :3] == 0, axis=-1)
        out_arr = og_arr.copy()
        out_arr[background] = (231, 231, 231, 231)
        result = Image.fromarray(out_arr, mode="RGBA")

        result.save(os.path.join(output_dir, f"{images[i_test].split(os.sep)[-1].split('.')[0]}.png"))

        del d1, d2, d3, d4, d5, d6, d7
tryon/preprocessing/sam2/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from PIL import Image
from pathlib import Path
import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Location of the SAM2 checkout / checkpoints (see scripts/install_sam2.sh).
SAM2_DIR = os.path.join(str(Path.home()), 'sam2')

checkpoint = os.path.join(SAM2_DIR, "checkpoints/sam2.1_hiera_large.pt")
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

if __name__ == "__main__":
    # Demo: segment a sample image with a single positive point prompt.
    # Previously this ran unconditionally at import time, which made any
    # `import tryon.preprocessing.sam2` fail when img000.webp was absent.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        predictor.set_image(Image.open("img000.webp"))
        input_point = np.array([[500, 375]])
        input_label = np.array([1])
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True)
        print(masks)
tryon/preprocessing/u2net/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .load_u2net import load_cloth_segm_model, load_human_segm_model
2
+ from .data_loader import SalObjDataset, RescaleT, ToTensorLab, ToTensor
3
+ from .utils import normPRED
tryon/preprocessing/u2net/data_loader.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, division
2
+
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ from skimage import io, transform, color
8
+ from torch.utils.data import Dataset
9
+
10
+
11
class RescaleT(object):
    """Resize a sample's image and label to a fixed ``output_size``.

    An int produces a square (size, size) output; a tuple is used as the
    (h, w) target directly.  The label is resized with nearest-neighbour
    interpolation (order=0) so class values are not blended; the image is
    converted from [0, 255] to [0, 1] by ``transform.resize``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # The original computed an aspect-preserving (new_h, new_w) and then
        # ignored it (dead code), and passed a nested tuple to resize for
        # tuple output sizes, which raises.  Build the real target instead.
        if isinstance(self.output_size, int):
            target = (self.output_size, self.output_size)
        else:
            target = self.output_size

        img = transform.resize(image, target, mode='constant')
        lbl = transform.resize(label, target, mode='constant', order=0,
                               preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}
41
+
42
+
43
class Rescale(object):
    """Randomly flip the sample along its first axis, then rescale it.

    An int ``output_size`` rescales the shorter side to that size while
    keeping the aspect ratio; a tuple is used as the (h, w) target directly.
    The label is resized with nearest-neighbour interpolation (order=0).
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # 50% chance of a flip along the first (vertical) axis.
        if random.random() >= 0.5:
            image, label = image[::-1], label[::-1]

        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        target = (int(new_h), int(new_w))

        # resize converts the image from range [0, 255] to [0, 1]
        img = transform.resize(image, target, mode='constant')
        lbl = transform.resize(label, target, mode='constant', order=0,
                               preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}
73
+
74
+
75
class RandomCrop(object):
    """Randomly flip the sample along its first axis, then crop a random
    ``output_size`` window from it (same window for image and label)."""

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # np.random.randint(0, 0) raises ValueError, so use offset 0 when
        # an image side already equals the crop size.
        top = 0 if h == new_h else np.random.randint(0, h - new_h)
        left = 0 if w == new_w else np.random.randint(0, w - new_w)

        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]

        return {'imidx': imidx, 'image': image, 'label': label}
102
+
103
+
104
class ToTensor(object):
    """Convert ndarrays in sample to Tensors.

    The image is scaled by its max, normalised with ImageNet mean/std, and
    transposed to CHW; the label is scaled by its max (unless all-zero) and
    transposed to CHW.
    """

    def __call__(self, sample):

        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

        image = image / np.max(image)
        # Leave an all-zero label untouched to avoid dividing by zero.
        if np.max(label) >= 1e-6:
            label = label / np.max(label)

        if image.shape[2] == 1:
            # Greyscale: replicate the single channel before normalising.
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
        else:
            # ImageNet per-channel mean/std.
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        # The original also filled a tmpLbl buffer that was then discarded
        # by the transpose below; that dead code was removed.
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
135
+
136
+
137
class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    ``flag`` selects the image colour representation:
      0 -- RGB normalised with ImageNet mean/std (3 channels, default),
      1 -- Lab, min/max scaled then standardised per channel (3 channels),
      2 -- RGB + Lab concatenated, same scaling per channel (6 channels).
    The label is scaled by its max (unless all-zero); both outputs are CHW.
    """

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):

        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Leave an all-zero label untouched to avoid dividing by zero.
        if np.max(label) >= 1e-6:
            label = label / np.max(label)

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImgt[:, :, 0] = image[:, :, 0]
                tmpImgt[:, :, 1] = image[:, :, 0]
                tmpImgt[:, :, 2] = image[:, :, 0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)

            # normalize image to range [0,1]
            # NOTE(review): constant-valued channels make these divisions
            # 0/0 -- confirm inputs always have per-channel variation.
            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
                    np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
                    np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
                    np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]))
            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
                    np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]))
            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
                    np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]))
            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
                    np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]))

            # standardise each channel
            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])
            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(tmpImg[:, :, 3])
            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(tmpImg[:, :, 4])
            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(tmpImg[:, :, 5])

        elif self.flag == 1:  # with Lab color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

            if image.shape[2] == 1:
                tmpImg[:, :, 0] = image[:, :, 0]
                tmpImg[:, :, 1] = image[:, :, 0]
                tmpImg[:, :, 2] = image[:, :, 0]
            else:
                tmpImg = image

            tmpImg = color.rgb2lab(tmpImg)

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
                    np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]))
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
                    np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]))
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
                    np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(tmpImg[:, :, 0])
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(tmpImg[:, :, 1])
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(tmpImg[:, :, 2])

        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image = image / np.max(image)
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        # The original filled a separate tmpLbl buffer and then overwrote it
        # with the transpose below; that dead code was removed.
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
232
+
233
+
234
class SalObjDataset(Dataset):
    """Salient-object dataset: reads images (and optional label maps) from
    explicit file lists and applies an optional transform to each sample."""

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])

        # Without labels, fall back to an all-zero map of the image's shape.
        if len(self.label_name_list) == 0:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Reduce a colour label to its first channel.
        label = np.zeros(label_3.shape[0:2])
        if label_3.ndim == 3:
            label = label_3[:, :, 0]
        elif label_3.ndim == 2:
            label = label_3

        # Ensure both arrays are HxWxC.
        if label.ndim == 2:
            label = label[:, :, np.newaxis]
            if image.ndim == 2:
                image = image[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}
        return self.transform(sample) if self.transform else sample
tryon/preprocessing/u2net/load_u2net.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+
4
+ import torch
5
+
6
+ from tryon.preprocessing.u2net import u2net_cloth_segm, u2net_human_segm
7
+
8
+
9
def load_cloth_segm_model(device, checkpoint_path, in_ch=3, out_ch=1):
    """Load the U2NET cloth-segmentation checkpoint onto ``device``.

    :param device: torch device the model is moved to
    :param checkpoint_path: path to the state-dict file
    :param in_ch: model input channels
    :param out_ch: model output channels
    :return: the loaded model, or ``None`` when ``checkpoint_path`` does not
             exist (kept for backward compatibility with existing callers)
    """
    if not os.path.exists(checkpoint_path):
        print("Invalid path")
        return

    model = u2net_cloth_segm.U2NET(in_ch=in_ch, out_ch=out_ch)

    model_state_dict = torch.load(checkpoint_path, map_location=device)
    new_state_dict = OrderedDict()
    for k, v in model_state_dict.items():
        # Checkpoints saved from nn.DataParallel prefix keys with "module.";
        # strip it only when present so plain checkpoints also load.
        name = k[7:] if k.startswith("module.") else k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model = model.to(device=device)

    print("Checkpoints loaded from path: {}".format(checkpoint_path))

    return model
28
+
29
+
30
def load_human_segm_model(device, model_name):
    """Load the U2NET/U2NETP human-segmentation model in eval mode.

    The checkpoint location is read from the ``U2NET_SEGM_CHECKPOINT_PATH``
    environment variable.

    :param device: torch device used as ``map_location`` on CPU-only hosts
    :param model_name: 'u2net' (173.6 MB) or 'u2netp' (4.7 MB)
    :raises ValueError: if ``model_name`` is unknown or the checkpoint
                        environment variable is not set
    """
    if model_name == 'u2net':
        print("loading U2NET(173.6 MB)...")
        net = u2net_human_segm.U2NET(3, 1)
    elif model_name == 'u2netp':
        print("loading U2NEP(4.7 MB)...")
        net = u2net_human_segm.U2NETP(3, 1)
    else:
        # Previously fell through with net=None and crashed below with an
        # opaque AttributeError.
        raise ValueError(f"Unknown model name: {model_name!r}")

    checkpoint_path = os.environ.get("U2NET_SEGM_CHECKPOINT_PATH")
    if not checkpoint_path:
        raise ValueError("U2NET_SEGM_CHECKPOINT_PATH environment variable is not set")

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(checkpoint_path))
        net.cuda()
    else:
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Always switch to eval mode, regardless of the device branch taken.
    net.eval()

    return net
tryon/preprocessing/u2net/u2net_cloth_segm.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class REBNCONV(nn.Module):
    """3x3 convolution (optionally dilated) -> BatchNorm -> ReLU block.

    ``dirate`` sets both the dilation and the padding, so the spatial size
    of the input is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv_s1(x)
        out = self.bn_s1(out)
        return self.relu_s1(out)
21
+
22
+
23
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
24
+ def _upsample_like(src, tar):
25
+ src = F.upsample(src, size=tar.shape[2:], mode="bilinear")
26
+
27
+ return src
28
+
29
+
30
+ ### RSU-7 ###
31
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """Residual U-block of depth 7: a five-pooling encoder/decoder with a
    dilated bottleneck.  Output has ``out_ch`` channels at the input's
    spatial size, plus a residual connection from the stem convolution."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        # encoder
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        # dilated bottleneck
        hx7 = self.rebnconv7(hx6)

        # decoder with encoder skip connections
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # (A dead triple-quoted string listing `del` statements was removed.)
        return hx1d + hxin
110
+
111
+
112
+ ### RSU-6 ###
113
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """Residual U-block of depth 6 (four poolings + dilated bottleneck);
    same contract as RSU7 with one fewer encoder/decoder level."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        # dilated bottleneck
        hx6 = self.rebnconv6(hx5)

        # decoder with encoder skip connections
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # (A dead triple-quoted string listing `del` statements was removed.)
        return hx1d + hxin
183
+
184
+
185
+ ### RSU-5 ###
186
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """Residual U-block of depth 5 (three poolings + dilated bottleneck);
    same contract as RSU7 with two fewer encoder/decoder levels."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        # dilated bottleneck
        hx5 = self.rebnconv5(hx4)

        # decoder with encoder skip connections
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # (A dead triple-quoted string listing `del` statements was removed.)
        return hx1d + hxin
246
+
247
+
248
+ ### RSU-4 ###
249
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """Residual U-block of depth 4 (two poolings + dilated bottleneck);
    same contract as RSU7 with three fewer encoder/decoder levels."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        # dilated bottleneck
        hx4 = self.rebnconv4(hx3)

        # decoder with encoder skip connections
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # (A dead triple-quoted string listing `del` statements was removed.)
        return hx1d + hxin
299
+
300
+
301
+ ### RSU-4F ###
302
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """Dilation-only variant of RSU4: no pooling; receptive field grows via
    dilation rates 1/2/4/8, so all features stay at the input resolution."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder (increasing dilation)
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # decoder (decreasing dilation, with skip connections)
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # (A dead triple-quoted string listing `del` statements was removed.)
        return hx1d + hxin
339
+
340
+
341
+ ##### U^2-Net ####
342
class U2NET(nn.Module):
    """Full U^2-Net: six RSU encoder stages, five RSU decoder stages, and
    six side outputs fused by a 1x1 convolution.

    forward returns (d0, d1..d6): the fused saliency map followed by the
    per-stage side outputs, all upsampled to the input's resolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # per-stage side-output heads
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # (A dead triple-quoted string listing `del` statements was removed.)
        return d0, d1, d2, d3, d4, d5, d6
448
+
449
+
450
+ ### U^2-Net small ###
451
class U2NETP(nn.Module):
    """Lightweight U^2-Net (~4.7 MB): six RSU encoder stages and five RSU
    decoder stages, every stage with 64 output channels (16 mid channels),
    plus six side outputs fused by a 1x1 convolution.

    forward returns (d0, d1..d6): the fused saliency map followed by the
    per-stage side outputs, all upsampled to the input's resolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # per-stage side-output heads
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # fuses the six side outputs into the final map d0
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder: each stage consumes the upsampled deeper feature
        # concatenated with the matching encoder feature (skip connection)
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output (each upsampled to the d1/input resolution)
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return d0, d1, d2, d3, d4, d5, d6
tryon/preprocessing/u2net/u2net_human_segm.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class REBNCONV(nn.Module):
    """Conv2d(3x3, dilated) -> BatchNorm2d -> ReLU building block of U^2-Net.

    Padding equals the dilation rate, so a 3x3 kernel preserves spatial size.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        # Attribute names are kept for state_dict compatibility.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
19
+
20
+
21
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
22
+ def _upsample_like(src, tar):
23
+ src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """Residual U-block of depth 7 (RSU-7).

    A small encoder-decoder whose output is summed with the input
    projection: forward(x) = decoder(encoder(x)) + rebnconvin(x).
    Spatial size is preserved; channels go in_ch -> out_ch with mid_ch
    used internally.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        # input projection, reused as the residual branch
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder: six conv+pool levels plus one dilated bottom conv
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck (no pooling at this depth)
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: each level takes the concatenation of the upsampled
        # deeper feature and the matching encoder feature (hence 2*mid_ch)
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        # encoder path
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # decoder path with skip connections
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # residual connection
        return hx1d + hxin
103
+
104
+
105
+ ### RSU-6 ###
106
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """Residual U-block of depth 6 (RSU-6); see RSU7 for the pattern."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # input projection, reused as the residual branch
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder: five conv levels (four pooled) plus a dilated bottleneck
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder levels consume concatenated skip features (2*mid_ch)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder path
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # decoder path with skip connections
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # residual connection
        return hx1d + hxin
171
+
172
+
173
+ ### RSU-5 ###
174
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """Residual U-block of depth 5 (RSU-5); see RSU7 for the pattern."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # input projection, reused as the residual branch
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder: four conv levels (three pooled) plus a dilated bottleneck
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder levels consume concatenated skip features (2*mid_ch)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder path
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # decoder path with skip connections
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # residual connection
        return hx1d + hxin
229
+
230
+
231
+ ### RSU-4 ###
232
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """Residual U-block of depth 4 (RSU-4); see RSU7 for the pattern."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # input projection, reused as the residual branch
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder: three conv levels (two pooled) plus a dilated bottleneck
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder levels consume concatenated skip features (2*mid_ch)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder path
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # decoder path with skip connections
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # residual connection
        return hx1d + hxin
277
+
278
+
279
+ ### RSU-4F ###
280
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """Dilation-only residual U-block (RSU-4F).

    Unlike RSU4, no pooling/upsampling is used: the receptive field grows
    via dilation rates 1/2/4/8 while spatial size stays constant, so skip
    features can be concatenated directly.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # input projection, reused as the residual branch
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder: increasing dilation instead of pooling
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # decoder mirrors the dilation rates
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # decoder with direct (same-resolution) skip concatenations
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # residual connection
        return hx1d + hxin
313
+
314
+
315
+ ##### U^2-Net ####
316
class U2NET(nn.Module):
    """Full-size U^2-Net for human segmentation.

    Six RSU encoder stages with max-pooling between them, a symmetric RSU
    decoder, and one 3x3 side head per decoder level. Returns seven
    sigmoid-activated maps: the fused output d0 followed by the side
    outputs d1..d6, all at input resolution.

    Fix: uses torch.sigmoid instead of the deprecated F.sigmoid.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # ---- encoder ----
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # ---- decoder (inputs are concatenated skip features) ----
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # one side head per decoder level; in-channels match stage outputs
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # fuses the six side maps into the final prediction
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6 (bottleneck)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, each resized to the finest (d1) resolution
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return (torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2),
                torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5),
                torch.sigmoid(d6))
417
+
418
+
419
+ ### U^2-Net small ###
420
class U2NETP(nn.Module):
    """Lightweight U^2-Net (U2NETP): every stage uses 64 output channels
    and 16 mid channels, so all side heads take 64-channel input.

    Returns seven sigmoid-activated maps (fused d0 plus side outputs
    d1..d6) at input resolution.

    Fix: uses torch.sigmoid instead of the deprecated F.sigmoid.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # ---- encoder ----
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # ---- decoder (inputs are concatenated 64+64 skip features) ----
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # fuses the six side maps into the final prediction
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6 (bottleneck)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, each resized to the finest (d1) resolution
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return (torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2),
                torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5),
                torch.sigmoid(d6))
tryon/preprocessing/u2net/utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def normPRED(d):
    """Min-max normalize a prediction map to [0, 1].

    :param d: tensor of arbitrary shape.
    :return: (d - min) / (max - min); an all-zero tensor when the input is
        constant (the previous version divided by zero in that case).
    """
    ma = torch.max(d)
    mi = torch.min(d)

    denom = ma - mi
    if denom == 0:
        # Constant input: any value maps to 0 rather than NaN/inf.
        return torch.zeros_like(d)

    return (d - mi) / denom
tryon/preprocessing/utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+
9
class NormalizeImage(object):
    """Normalize a tensor with a single scalar mean/std, for 1-, 3- or
    18-channel input (channel-first layout assumed — TODO confirm).

    Args:
        mean (float): mean to subtract from every channel.
        std (float): std to divide every channel by.

    Raises:
        TypeError: if mean or std is not a float.
        ValueError: on an unsupported channel count (the previous version
            had a no-op ``assert "<message>"`` there and silently returned
            None).
    """

    def __init__(self, mean, std):
        # The original only asserted on `mean`; validate both explicitly.
        if not isinstance(mean, float) or not isinstance(std, float):
            raise TypeError("mean and std must be floats")
        self.mean = mean
        self.std = std

        # Pre-build one Normalize transform per supported channel count.
        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        channels = image_tensor.shape[0]
        if channels == 1:
            return self.normalize_1(image_tensor)
        if channels == 3:
            return self.normalize_3(image_tensor)
        if channels == 18:
            return self.normalize_18(image_tensor)
        raise ValueError(
            "Please set proper channels! Normalization implemented only for 1, 3 and 18")
41
+
42
+
43
def naive_cutout(img, mask):
    """Composite *img* onto a fully transparent canvas using *mask* as the
    alpha matte; the mask is resized to the image size with Lanczos."""
    transparent_bg = Image.new("RGBA", img.size, 0)
    scaled_mask = mask.resize(img.size, Image.LANCZOS)
    return Image.composite(img, transparent_bg, scaled_mask)
47
+
48
+
49
def resize_by_bigger_index(crop):
    """Resize *crop* (an HxWxC array) while preserving its aspect ratio.

    Crops with height/width <= 1.33 are pinned to width 768; taller crops
    are pinned to height 1024.
    """
    height, width = crop.shape[0], crop.shape[1]
    if height / width <= 1.33:
        return image_resize(crop, width=768)
    return image_resize(crop, height=1024)
57
+
58
+
59
def image_resize(image, width=None, height=None):
    """Resize an image, deriving the missing dimension from the aspect ratio.

    :param image: H x W (x C) numpy array.
    :param width: target width in pixels, or None.
    :param height: target height in pixels, or None.
    :return: the resized image. With both targets given, the image is
        resized to exactly (width, height); with one target, the other is
        computed from the aspect ratio; with neither, the input is
        returned unchanged.
    """
    (h, w) = image.shape[:2]

    if width is None and height is None:
        return image

    if width is not None and height is not None:
        # Bug fix: honour both targets exactly. The previous version
        # silently ignored `height` here, breaking callers (e.g.
        # convert_to_jpg) that pass an explicit (w, h) size.
        dim = (width, height)
    elif width is None:
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        r = width / float(w)
        dim = (width, int(h * r))

    return cv2.resize(image, dim)
77
+
78
+
79
def convert_to_jpg(image_path, output_dir, size=None):
    """
    Convert an image to jpg format, optionally resizing it first.

    :param image_path: image path
    :param output_dir: output directory
    :param size: desired size of the image (w, h)
    :raises FileNotFoundError: if the image cannot be read
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising on unreadable paths;
        # fail loudly rather than crashing later inside cv2.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    if size is not None:
        img = image_resize(img, width=size[0], height=size[1])

    # Path.stem keeps dotted filenames intact ("a.b.png" -> "a.b"),
    # unlike the previous filename.split(".")[0].
    stem = Path(image_path).stem
    cv2.imwrite(os.path.join(output_dir, stem + ".jpg"), img)
tryondiffusion/__init__.py ADDED
File without changes
tryondiffusion/diffusion.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ from torch import optim
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ import cv2
11
+ import numpy as np
12
+
13
+ from .network import UNet64, UNet128
14
+ from .utils import mk_folders, GaussianSmoothing, UNetDataset
15
+ from .ema import EMA
16
+
17
+
18
def smoothen_image(img, sigma):
    """Gaussian-blur noise augmentation for 3-channel inputs.

    As suggested in
    https://jmlr.csail.mit.edu/papers/volume23/21-0635/21-0635.pdf Section 4.4.
    """
    blur = GaussianSmoothing(channels=3,
                             kernel_size=3,
                             sigma=sigma,
                             conv_dim=2)

    # Reflect-pad one pixel per side so the 3x3 kernel preserves the
    # spatial size.
    padded = F.pad(img, (1, 1, 1, 1), mode='reflect')
    return blur(padded)
31
+
32
+
33
def schedule_lr(total_steps, start_lr=0.0, stop_lr=0.0001, pct_increasing_lr=0.02):
    """Build a per-step learning-rate list: linear warmup, then constant.

    The first round(total_steps * pct_increasing_lr) steps ramp linearly
    from start_lr to stop_lr; every remaining step stays at stop_lr.
    """
    warmup_steps = round(total_steps * pct_increasing_lr)
    warmup = list(np.linspace(start_lr, stop_lr, warmup_steps))
    plateau = [stop_lr] * (total_steps - warmup_steps)
    return warmup + plateau
40
+
41
+
42
class Diffusion:
    """DDPM-style diffusion wrapper around the try-on UNet.

    Owns the linear beta schedule, forward noising, reverse-process
    sampling (with an EMA copy of the network), and the training loop.
    """

    def __init__(self,
                 device,
                 pose_embed_dim,
                 time_steps=256,
                 beta_start=1e-4,
                 beta_end=0.02,
                 unet_dim=64,
                 noise_input_channel=3,
                 beta_ema=0.995):
        self.time_steps = time_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # Linear beta schedule and the derived alpha / cumulative products.
        self.beta = self.linear_beta_scheduler().to(device)
        self.alpha = 1 - self.beta
        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)

        self.noise_input_channel = noise_input_channel
        self.unet_dim = unet_dim
        if unet_dim == 128:
            self.net = UNet128(pose_embed_dim, device, time_steps).to(device)
        elif unet_dim == 64:
            self.net = UNet64(pose_embed_dim, device, time_steps).to(device)

        # Frozen EMA copy of the network, used for evaluation-time sampling.
        self.ema_net = copy.deepcopy(self.net).eval().requires_grad_(False)
        self.beta_ema = beta_ema

        self.device = device

    def linear_beta_scheduler(self):
        """Linearly spaced betas over the diffusion timesteps."""
        return torch.linspace(self.beta_start, self.beta_end, self.time_steps)

    def sample_time_steps(self, batch_size):
        """Draw one random timestep in [1, time_steps) per batch element."""
        return torch.randint(low=1, high=self.time_steps, size=(batch_size,))

    def add_noise_to_img(self, img, t):
        """Forward-diffuse `img` to timestep `t`.

        Returns (noisy_img, epsilon) with
        noisy_img = sqrt(alpha_bar_t) * img + sqrt(1 - alpha_bar_t) * epsilon.
        """
        sqrt_alpha_timestep = torch.sqrt(self.alpha_cumprod[t])[:, None, None, None]
        sqrt_one_minus_alpha_timestep = torch.sqrt(1 - self.alpha_cumprod[t])[:, None, None, None]
        epsilon = torch.randn_like(img)
        # Bug fix: scale the clean image, not the noise. The previous version
        # returned (sqrt_a * epsilon) + (sqrt_1ma * epsilon) and never used
        # `img`, so the model was trained on pure noise.
        return (sqrt_alpha_timestep * img) + (sqrt_one_minus_alpha_timestep * epsilon), epsilon

    @torch.inference_mode()
    def sample(self, use_ema, conditional_inputs):
        """Run the full reverse process and return uint8 images in [0, 255].

        conditional_inputs is (ic, jp, jg, ia): garment image, person pose,
        garment pose and clothing-agnostic image (presumed from the dataset
        code — confirm naming against UNetDataset).
        """
        model = self.ema_net if use_ema else self.net
        ic, jp, jg, ia = conditional_inputs
        ic = ic.to(self.device)
        jp = jp.to(self.device)
        jg = jg.to(self.device)
        ia = ia.to(self.device)
        batch_size = len(ic)
        logging.info(f"Running inference for {batch_size} images")

        model.eval()
        with torch.inference_mode():

            # noise augmentation during testing as suggested in paper
            sigma = float(torch.FloatTensor(1).uniform_(0.4, 0.6))
            ia = smoothen_image(ia, sigma)
            ic = smoothen_image(ic, sigma)

            inp_network_noise = torch.randn(batch_size, self.noise_input_channel, self.unet_dim, self.unet_dim).to(self.device)

            # paper says to add noise augmentation to input noise during inference
            inp_network_noise = smoothen_image(inp_network_noise, sigma)

            for i in reversed(range(1, self.time_steps)):
                t = (torch.ones(batch_size) * i).long().to(self.device)

                # Bug fix: re-concatenate the *current* denoised estimate with
                # the agnostic image every step. Previously the concatenation
                # was built once before the loop, so the model saw the initial
                # noise at every timestep.
                x = torch.cat((inp_network_noise, ia), dim=1)
                predicted_noise = model(x, ic, jp, jg, t, sigma)
                # ToDo: Add Classifier-Free Guidance with guidance weight 2
                alpha = self.alpha[t][:, None, None, None]
                alpha_cumprod = self.alpha_cumprod[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(inp_network_noise)
                else:
                    # No extra noise on the final denoising step.
                    noise = torch.zeros_like(inp_network_noise)

                inp_network_noise = 1 / torch.sqrt(alpha) * (inp_network_noise - ((1 - alpha) / (torch.sqrt(1 - alpha_cumprod))) * predicted_noise) + torch.sqrt(beta) * noise

            # Map from [-1, 1] to uint8 [0, 255].
            inp_network_noise = (inp_network_noise.clamp(-1, 1) + 1) / 2
            inp_network_noise = (inp_network_noise * 255).type(torch.uint8)
        return inp_network_noise

    def prepare(self, args):
        """Build dataloaders, optimizer, LR schedule, loss, EMA and scaler."""
        mk_folders(args.run_name)
        train_dataset = UNetDataset(ip_dir=args.train_ip_folder,
                                    jp_dir=args.train_jp_folder,
                                    jg_dir=args.train_jg_folder,
                                    ia_dir=args.train_ia_folder,
                                    ic_dir=args.train_ic_folder,
                                    unet_size=self.unet_dim)

        validation_dataset = UNetDataset(ip_dir=args.validation_ip_folder,
                                         jp_dir=args.validation_jp_folder,
                                         jg_dir=args.validation_jg_folder,
                                         ia_dir=args.validation_ia_folder,
                                         ic_dir=args.validation_ic_folder,
                                         unet_size=self.unet_dim)

        self.train_dataloader = DataLoader(train_dataset, args.batch_size_train, shuffle=True)
        # give args.batch_size_validation 1 while training
        self.val_dataloader = DataLoader(validation_dataset, args.batch_size_validation, shuffle=True)

        self.optimizer = optim.AdamW(self.net.parameters(), lr=args.lr, eps=1e-4)
        self.scheduler = schedule_lr(total_steps=args.total_steps, start_lr=args.start_lr,
                                     stop_lr=args.stop_lr, pct_increasing_lr=args.pct_increasing_lr)
        self.mse = nn.MSELoss()
        self.ema = EMA(self.beta_ema)
        self.scaler = torch.cuda.amp.GradScaler()

    def train_step(self, loss, running_step):
        """One optimizer step with AMP scaling, EMA update and LR schedule."""
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.ema.step_ema(self.ema_net, self.net)
        # Apply the per-step scheduled learning rate.
        for g in self.optimizer.param_groups:
            g['lr'] = self.scheduler[running_step]

    def single_epoch(self, train=True):
        """Run one pass over the data and return the mean MSE loss.

        NOTE(review): both the train and eval paths iterate
        `self.train_dataloader`; the validation loader is only used by
        `logging_images`. Confirm whether eval should use `val_dataloader`.
        """
        total_loss = 0.0
        num_batches = 0
        if train:
            self.net.train()
        else:
            self.net.eval()

        for ip, jp, jg, ia, ic in self.train_dataloader:

            # noise augmentation
            sigma = float(torch.FloatTensor(1).uniform_(0.4, 0.6))
            ia = smoothen_image(ia, sigma)
            ic = smoothen_image(ic, sigma)

            # Bug fix: `with A and B:` only enters B (the result of the
            # `and`), so autocast was silently skipped before. Enter both
            # context managers.
            with torch.autocast(self.device), \
                    (torch.inference_mode() if not train else torch.enable_grad()):
                ip = ip.to(self.device)
                jp = jp.to(self.device)
                jg = jg.to(self.device)
                ia = ia.to(self.device)
                ic = ic.to(self.device)
                t = self.sample_time_steps(ip.shape[0]).to(self.device)

                # corrupt -> concatenate -> predict
                zt, noise_epsilon = self.add_noise_to_img(ip, t)

                zt = torch.cat((zt, ia), dim=1)

                # ToDO: Make conditional inputs null, at 10% of the training time,
                # ToDo: for classifier-free guidance(GitHub Issue #21), with guidance weight 2.

                predicted_noise = self.net(zt, ic, jp, jg, t, sigma)
                loss = self.mse(noise_epsilon, predicted_noise)
                total_loss += loss.item()
                num_batches += 1

            if train:
                self.train_step(loss, self.running_train_steps)
                # ToDo: Add logs to tensorboard as well
                logging.info(
                    f"train_mse_loss: {loss.item():2.3f}, learning_rate: {self.scheduler[self.running_train_steps]}")
                self.running_train_steps += 1

        # Bug fix: the previous version returned `.mean()` of a 0-dim summed
        # tensor — i.e. the *sum* of batch losses — despite being logged as
        # an average. Return the actual mean (0.0 for an empty loader).
        return total_loss / max(num_batches, 1)

    def logging_images(self, epoch, run_name):
        """Sample with both nets on the validation set and save PNGs to
        results/<run_name>/images/<idx>_E<epoch>/."""

        for idx, (ip, jp, jg, ia, ic) in enumerate(self.val_dataloader):
            # sampled image
            sampled_image = self.sample(use_ema=False, conditional_inputs=(ic, jp, jg, ia))
            sampled_image = sampled_image[0].permute(1, 2, 0).squeeze().cpu().numpy()

            # ema sampled image
            ema_sampled_image = self.sample(use_ema=True, conditional_inputs=(ic, jp, jg, ia))
            ema_sampled_image = ema_sampled_image[0].permute(1, 2, 0).squeeze().cpu().numpy()

            # base images
            ip_np = ip[0].permute(1, 2, 0).squeeze().cpu().numpy()
            ic_np = ic[0].permute(1, 2, 0).squeeze().cpu().numpy()
            ia_np = ia[0].permute(1, 2, 0).squeeze().cpu().numpy()

            # make the output folder
            os.makedirs(os.path.join("results", run_name, "images", f"{idx}_E{epoch}"), exist_ok=True)

            # define folder paths
            images_folder = os.path.join("results", run_name, "images", f"{idx}_E{epoch}")

            # save base images
            cv2.imwrite(os.path.join(images_folder, "ground_truth.png"), ip_np)
            cv2.imwrite(os.path.join(images_folder, "segmented_garment.png"), ic_np)
            cv2.imwrite(os.path.join(images_folder, "cloth_agnostic_rgb.png"), ia_np)

            # save sampled image
            cv2.imwrite(os.path.join(images_folder, "sampled_image.png"), sampled_image)

            # save ema sampled image
            cv2.imwrite(os.path.join(images_folder, "ema_sampled_image.png"), ema_sampled_image)

    def save_models(self, run_name, epoch=-1):
        """Checkpoint the network, its EMA copy and the optimizer state."""

        torch.save(self.net.state_dict(), os.path.join("models", run_name, f"ckpt_{epoch}.pt"))
        torch.save(self.ema_net.state_dict(), os.path.join("models", run_name, f"ema_ckpt_{epoch}.pt"))
        torch.save(self.optimizer.state_dict(), os.path.join("models", run_name, f"optim_{epoch}.pt"))

    def fit(self, args):
        """Full training loop with periodic eval, image logging and saving."""

        logging.info(f"Starting training")

        data_len = len(self.train_dataloader)

        epochs = round((args.total_steps * args.batch_size_train) / data_len)

        # Bug fix: the old guard was `epochs < 0`, which can never trigger
        # for positive inputs; guard against a rounded value of 0 so at
        # least one epoch runs.
        if epochs < 1:
            epochs = 1

        self.running_train_steps = 0

        for epoch in range(epochs):
            logging.info(f"Starting Epoch: {epoch + 1}")
            _ = self.single_epoch(train=True)

            if (epoch + 1) % args.calculate_loss_frequency == 0:
                avg_loss = self.single_epoch(train=False)
                logging.info(f"Average Loss: {avg_loss}")

            if (epoch + 1) % args.image_logging_frequency == 0:
                self.logging_images(epoch, args.run_name)

            if (epoch + 1) % args.model_saving_frequency == 0:
                self.save_models(args.run_name, epoch)

        logging.info(f"Training Done Successfully! Yayyy! Now let's hope for good results")
+ logging.info(f"Training Done Successfully! Yayyy! Now let's hope for good results")