Initial commit
Files changed:
- .gitattributes +1 -0
- .gitignore +18 -0
- LICENSE.md +25 -0
- README.md +22 -2
- app.py +25 -0
- example/color_matching/cutting_fruit.jpg +3 -0
- example/color_matching/field.jpg +3 -0
- example/color_matching/fruit_stand.jpg +3 -0
- example/color_matching/portrait.jpg +3 -0
- example/guidance/arch.jpg +3 -0
- example/guidance/blue-cast.jpg +3 -0
- example/guidance/boat-gray.jpg +3 -0
- example/guidance/building.jpg +3 -0
- example/guidance/canyon.jpg +3 -0
- example/guidance/colorful-buildings.jpg +3 -0
- example/guidance/dark-sky.jpg +3 -0
- example/guidance/fisher.jpg +3 -0
- example/guidance/food_stand.jpg +3 -0
- example/guidance/gray-power.jpg +3 -0
- example/guidance/greenhouse_2.jpg +3 -0
- example/guidance/industrial.jpg +3 -0
- example/guidance/lake.jpg +3 -0
- example/guidance/lake_green.jpg +3 -0
- example/guidance/lake_sunset.jpg +3 -0
- example/guidance/mountain.jpg +3 -0
- example/guidance/ornament.jpg +3 -0
- example/guidance/path.jpg +3 -0
- example/guidance/sky_pier.jpg +3 -0
- example/guidance/snow.jpg +3 -0
- example/guidance/sunken_boat.jpg +3 -0
- example/guidance/waterfall.jpg +3 -0
- requirements.txt +17 -0
- src/color_matcher.py +205 -0
- src/gradio_demo/color_matching.py +164 -0
- src/gradio_demo/sw_guidance.py +336 -0
- src/loss/__init__.py +4 -0
- src/loss/abstract_loss.py +20 -0
- src/loss/vector_swd.py +431 -0
- src/sw_sdthree_guidance.py +654 -0
- src/utils/asc_cdl.py +279 -0
- src/utils/color_space.py +80 -0
- src/utils/image.py +33 -0
- src/utils/math.py +39 -0
.gitattributes CHANGED

```diff
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
```
.gitignore ADDED

```text
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# Output
output*/
out*/
out/
.gradio/
/*.png
/*.csv
```
LICENSE.md ADDED

STABILITY AI COMMUNITY LICENSE AGREEMENT Last Updated: July 5, 2024

I. INTRODUCTION

This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.

This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).

By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.

II. RESEARCH & NON-COMMERCIAL USE LICENSE

Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.

III. COMMERCIAL USE LICENSE

Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations. If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise), which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.

IV. GENERAL TERMS

Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms. a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified. b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works). c. Intellectual Property. (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein. (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI. (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law. (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement. (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback. d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement. g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.

V. DEFINITIONS

"Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity. "Agreement" means this Stability AI Community License Agreement. "AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. "Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including "fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model. "Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models. "Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time. "Stability AI" or "we" means Stability AI Ltd. and its Affiliates. "Software" means Stability AI's proprietary software made available under this Agreement now or in the future. "Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement. "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED

````diff
@@ -1,12 +1,32 @@
 ---
 title: ReSWD
-emoji:
+emoji: π
 colorFrom: purple
 colorTo: green
 sdk: gradio
 sdk_version: 5.47.1
 app_file: app.py
 pinned: false
+models:
+- stabilityai/stable-diffusion-3.5-large-turbo
+license: other
+license_name: stabilityai-ai-community
+license_link: LICENSE.md
 ---
 
-
+# ReSWD: ReSTIR'd, not shaken. Combining Reservoir Sampling and Sliced Wasserstein Distance for Variance Reduction.
+
+<a href="https://reservoirswd.github.io/"><img src="https://img.shields.io/badge/Project%20Page-5CE1BC.svg"></a> <a href="https://reservoirswd.github.io/static/paper.pdf"><img src="https://img.shields.io/badge/Arxiv-2408.00653-B31B1B.svg"></a> <a href="https://huggingface.co/spaces/stabilityai/reswd"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a>
+
+This is the official codebase for **ReSWD**, a state-of-the-art algorithm for distribution matching with reduced variance. It has several applications (such as diffusion guidance or color matching).
+
+## Citation
+
+```BibTeX
+@article{boss2025reswd,
+  title={ReSWD: ReSTIR'd, not shaken. Combining Reservoir Sampling and Sliced Wasserstein Distance for Variance Reduction.},
+  author={Boss, Mark and Engelhardt, Andreas and Donné, Simon and Jampani, Varun},
+  journal={arXiv preprint},
+  year={2025}
+}
+```
````
app.py ADDED

```python
import gradio as gr
import lightning as L

from src.gradio_demo.color_matching import create_color_matching
from src.gradio_demo.sw_guidance import create_sw_guidance

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # ReSWD

        ReSTIR'd, not shaken. Combining Reservoir Sampling and Sliced Wasserstein
        Distance for Variance Reduction.

        ReSWD is a method for distribution matching with reduced variance.
        """
    )
    fabric = L.Fabric(devices=1, accelerator="auto", precision="16-mixed")

    with gr.Tab("SW Guidance (SD 3.5 Large Turbo)"):
        create_sw_guidance(fabric, "stabilityai/stable-diffusion-3.5-large-turbo")
    with gr.Tab("Color Matching"):
        create_color_matching(fabric)

demo.launch()
```
example/color_matching/*.jpg ADDED (4 images, stored via Git LFS: cutting_fruit, field, fruit_stand, portrait)

example/guidance/*.jpg ADDED (22 images, stored via Git LFS: arch, blue-cast, boat-gray, building, canyon, colorful-buildings, dark-sky, fisher, food_stand, gray-power, greenhouse_2, industrial, lake, lake_green, lake_sunset, mountain, ornament, path, sky_pier, snow, sunken_boat, waterfall)
requirements.txt ADDED

```text
accelerate>=1.7.1
diffusers>=0.33.1
imageio>=2.37.0
jaxtyping>=0.3.2
lightning>=2.5.1.post0
lxml>=5.4.0
matplotlib>=3.10.3
numpy>=2.2.5
opencv-python>=4.11.0.86
pillow>=11.2.1
protobuf>=6.32.1
scipy>=1.15.3
sentencepiece>=0.2.1
torch>=2.7.0
torchvision>=0.22.0
tqdm>=4.67.1
transformers>=4.52.4
```
src/color_matcher.py ADDED

```python
import gc
import os
from typing import List, Literal, Optional, Tuple

import imageio
import lightning as L
import torch
from jaxtyping import Float
from torchvision.transforms import Resize
from tqdm import tqdm

from src.loss import AbstractLoss
from src.loss.vector_swd import VectorSWDLoss
from src.utils.asc_cdl import asc_cdl_forward, save_asc_cdl
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, read_img, to_torch, write_img


class CDL(torch.nn.Module):
    def __init__(self, batch_size: int):
        super().__init__()
        self.cdl_slope = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_offset = torch.nn.Parameter(torch.zeros(batch_size, 3))
        self.cdl_power = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_saturation = torch.nn.Parameter(torch.ones(batch_size))

    def forward(
        self, x: Float[torch.Tensor, "*B C H W"]
    ) -> Float[torch.Tensor, "*B C H W"]:
        return asc_cdl_forward(
            x, self.cdl_slope, self.cdl_offset, self.cdl_power, self.cdl_saturation
        )

    def to_cdl_xml(self) -> List[str]:
        ret = []
        for b in range(self.cdl_slope.shape[0]):
            ret.append(
                save_asc_cdl(
                    {
                        "slope": self.cdl_slope[b],
                        "offset": self.cdl_offset[b],
                        "power": self.cdl_power[b],
                        "saturation": self.cdl_saturation[b],
                    },
                    None,
                )
            )
        return ret

    def save(self, path: str):
        for b in range(self.cdl_slope.shape[0]):
            save_asc_cdl(
                {
                    "slope": self.cdl_slope[b],
                    "offset": self.cdl_offset[b],
                    "power": self.cdl_power[b],
                    "saturation": self.cdl_saturation[b],
                },
                os.path.join(path, f"cdl_{b}.xml"),
            )


def train(
    fabric: L.Fabric,
    criteria: AbstractLoss,
    source_img: Float[torch.Tensor, "B C H W"],
    target_img: Float[torch.Tensor, "B C H W"],
    num_steps: int,
    lr: float,
    match_resolution: int,
    silent: bool = False,
    write_video_animation_path: Optional[str] = None,
) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
    criteria = fabric.setup(criteria)

    source_max_res = Resize(match_resolution, antialias=True)(source_img)
    target_max_res = Resize(match_resolution, antialias=True)(target_img)

    target_cielab = (
        fabric.to_device(rgb_to_lab(target_max_res).permute(0, 3, 1, 2))
        .permute(0, 2, 3, 1)
        .contiguous()
    )

    source_max_res = fabric.to_device(source_max_res)
    source_img = fabric.to_device(source_img)

    batch_size = source_img.shape[0]
    cdl = CDL(batch_size)

    optim = torch.optim.Adam(cdl.parameters(), lr=lr)
    cdl, optim = fabric.setup(cdl, optim)

    losses = []
    for i in tqdm(range(num_steps), disable=silent):
        optim.zero_grad(set_to_none=True)

        cdl_source = cdl(source_max_res)
        source_cielab = (
            rgb_to_lab(cdl_source.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
        )

        loss = criteria(
            source_cielab.view(source_cielab.shape[0], source_cielab.shape[1], -1),
            target_cielab.view(target_cielab.shape[0], target_cielab.shape[1], -1),
            i,
        )

        fabric.backward(loss)
        optim.step()

        losses.append(loss.item())

        if write_video_animation_path is not None:
            write_img(
                os.path.join(write_video_animation_path, f"{i:05d}.jpg"),
                from_torch(cdl(source_img).squeeze(0) * 2 - 1),
            )

    source_full_res_cdl = cdl(source_img)

    gc.collect()
    torch.cuda.empty_cache()

    return source_full_res_cdl, cdl, losses


def run(
    save_dir: str,
    source_img: List[str],
    target_img: List[str],
    matching_resolution: int,
    precision: Literal["32-true", "16-mixed"] = "16-mixed",
    num_projections: int = 64,
    lr: float = 0.01,
    steps: int = 300,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    refresh_projections_every_n_steps: int = 1,
    num_new_candidates: int = 32,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    write_video: bool = False,
    **kwargs,
):
    fabric = L.Fabric(devices=1, accelerator="auto", precision=precision)

    source_imgs = torch.stack(
        [to_torch(read_img(s)) * 0.5 + 0.5 for s in source_img], dim=0
    )
    target_imgs = torch.stack(
        [to_torch(read_img(t)) * 0.5 + 0.5 for t in target_img], dim=0
    )

    criteria = VectorSWDLoss(
        num_proj=num_projections,
        distance=distance,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        refresh_projections_every_n_steps=refresh_projections_every_n_steps,
        num_new_candidates=num_new_candidates,
        sampling_mode=sampling_mode,
    )

    os.makedirs(save_dir, exist_ok=True)
    animation_dir = os.path.join(save_dir, "animation")

    if write_video:
        os.makedirs(animation_dir, exist_ok=True)

    source_full_res_cdl, cdl, losses = train(
        fabric,
        criteria,
        source_imgs,
        target_imgs,
        steps,
        lr,
        matching_resolution,
        write_video_animation_path=animation_dir if write_video else None,
    )

    cdl.save(save_dir)

    for i, img in enumerate(source_full_res_cdl):
        write_img(
            os.path.join(save_dir, f"color_matched_{i}.png"),
            from_torch(img * 2 - 1),
        )

    if write_video:
        # Get the list of image files in the animation directory
        image_files = [f for f in os.listdir(animation_dir) if f.endswith(".jpg")]
        image_files.sort(
            key=lambda x: int(x.split(".")[0])
        )  # Ensure they are in the correct order

        # Create a video from the images
        with imageio.get_writer(
            os.path.join(save_dir, "animation.mp4"), fps=30, codec="libx264"
        ) as writer:
            for image_file in image_files:
                image = imageio.imread(os.path.join(animation_dir, image_file))
                writer.append_data(image)

    return source_full_res_cdl, cdl, losses
```
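A minimal sketch of driving this module from Python rather than through the demo UI, assuming the example images added in this commit; the output directory name is arbitrary:

```python
# Hedged usage sketch: fit ASC CDL parameters that match the palette of one
# example image to another, writing color_matched_0.png and cdl_0.xml into
# the (arbitrary) "out_cm" directory.
from src.color_matcher import run

run(
    save_dir="out_cm",
    source_img=["example/color_matching/portrait.jpg"],
    target_img=["example/color_matching/field.jpg"],
    matching_resolution=128,  # optimize at 128px, apply at full resolution
    steps=150,
    write_video=False,
)
```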
src/gradio_demo/color_matching.py ADDED

```python
import os

import gradio as gr
import lightning as L
import numpy as np
import spaces

from src.color_matcher import train
from src.loss import VectorSWDLoss
from src.utils.image import from_torch, to_torch


def create_color_matching(fabric: L.Fabric):
    """
    Creates the Gradio interface for color matching between source and target images.
    """
    gr.Markdown(
        """
        # Color Matching
        Matches the color of source images to target images using ASC CDL parameters.
        """
    )

    gr.Markdown("## Source Images")
    with gr.Row(variant="compact"):
        source_image = gr.Image(label="Source Image", height=512)
        target_image = gr.Image(label="Target Image", height=512)
    with gr.Row(variant="compact"):
        example_paths = os.path.join("example", "color_matching")
        # Find all images in the example directory
        example_images = [
            os.path.join(example_paths, f)
            for f in os.listdir(example_paths)
            if os.path.isfile(os.path.join(example_paths, f))
            and os.path.splitext(f)[-1] in [".png", ".jpg", ".jpeg"]
        ]

        gr.Examples(
            examples=example_images,
            inputs=[source_image],
        )
        gr.Examples(
            examples=example_images,
            inputs=[target_image],
        )

    gr.Markdown(
        """
        # Configuration
        Adjust the parameters for the color matching process.
        """
    )

    with gr.Accordion("Advanced Config", open=False):
        learning_rate_slider = gr.Slider(
            1e-4, 0.1, value=1e-3, step=1e-4, label="Learning rate"
        )
        match_resolution_slider = gr.Slider(
            64, 1024, value=128, step=64, label="Match resolution"
        )
        num_steps_slider = gr.Slider(
            50, 500, value=150, step=25, label="Number of optimization steps"
        )
        control_variates_dropdown = gr.Dropdown(
            choices=["None", "Lower", "Upper"],
            value="None",
            label="Control Variates",
            info="Select which control variates to use for optimization",
        )
        candidates_per_pass_slider = gr.Slider(
            0,
            64,
            value=16,
            step=1,
            label="Number of new candidates per pass.",
            info=(
                "The number of new candidates to generate per pass in the reservoir "
                "sampling."
            ),
        )
        num_projections_slider = gr.Slider(
            16, 1024, value=64, step=16, label="Number of projections"
        )
        sampling_mode_dropdown = gr.Dropdown(
            choices=["gaussian", "qmc"],
            value="gaussian",
            label="Sampling Mode",
            info="Select which sampling mode to use for projections",
        )

    run_button = gr.Button("Run Color Matching", variant="primary")

    with gr.Column(variant="compact"):
        output_image = gr.Image(label="Color Matched Image", height=512)
        output_cdl = gr.Textbox(label="CDL Parameters")

    @spaces.GPU
    def run_color_matching(
        match_resolution: int,
        num_steps: int,
        control_variates: str,
        num_projections: int,
        candidates_per_pass: int,
        sampling_mode: str,
        learning_rate: float,
        *images,
    ):
        """
        Runs the color matching process between source and target images.
        """
        # Split images into source and target pair
        source_img = images[0] / 255.0
        target_img = images[1] / 255.0

        # Convert images to tensors
        source_tensors = to_torch(source_img).float().unsqueeze(0)
        target_tensors = to_torch(target_img).float().unsqueeze(0)

        # Configure loss function
        criteria = VectorSWDLoss(
            num_proj=num_projections,
            use_ucv=control_variates == "Upper",
            use_lcv=control_variates == "Lower",
            num_new_candidates=candidates_per_pass,
            sampling_mode=sampling_mode,
        )

        # Run color matching
        source_matched, cdl, losses = train(
            fabric=fabric,
            criteria=criteria,
            source_img=source_tensors,
            target_img=target_tensors,
            num_steps=num_steps,
            lr=learning_rate,
            match_resolution=match_resolution,
        )

        return [
            (from_torch(source_matched.squeeze(0)) * 255).astype(np.uint8),
            cdl.to_cdl_xml()[0],
        ]

    run_button.click(
        run_color_matching,
        inputs=[
            match_resolution_slider,
            num_steps_slider,
            control_variates_dropdown,
            num_projections_slider,
            candidates_per_pass_slider,
            sampling_mode_dropdown,
            learning_rate_slider,
            source_image,
            target_image,
        ],
        outputs=[output_image, output_cdl],
    )

    clear_button = gr.ClearButton(
        [source_image, target_image, output_image, output_cdl]
    )

    clear_button.click(lambda: None)
```
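For local experimentation, a hedged sketch of serving only this tab, mirroring what app.py does above (assumes the Hugging Face `spaces` package is installed, since this module imports it but requirements.txt does not pin it):

```python
# Hypothetical standalone launcher for the color-matching tab only.
import gradio as gr
import lightning as L

from src.gradio_demo.color_matching import create_color_matching

fabric = L.Fabric(devices=1, accelerator="auto", precision="16-mixed")
with gr.Blocks() as demo:
    create_color_matching(fabric)
demo.launch()
```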
src/gradio_demo/sw_guidance.py ADDED

```python
import os
from typing import Optional

import gradio as gr
import lightning as L
import numpy as np
import spaces
from PIL import Image

from src.sw_sdthree_guidance import create_pipeline
from src.sw_sdthree_guidance import run as sw_guidance_run

preset_lrs = [1e-6, 1.0]

log_lrs = np.log10(preset_lrs)

models = {
    "stabilityai/stable-diffusion-3.5-large": {
        "num_inference_steps": 30,
        "guidance_scale": 7.0,
        "sw_u_lr": np.log10(3.2e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.7,
    },
    "stabilityai/stable-diffusion-3.5-medium": {
        "num_inference_steps": 30,
        "guidance_scale": 3.5,
        "sw_u_lr": np.log10(3.2e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.7,
    },
    "stabilityai/stable-diffusion-3.5-large-turbo": {
        "num_inference_steps": 4,
        "guidance_scale": 1.0,
        "sw_u_lr": np.log10(4e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.65,
    },
}


def log_slider_to_lr(log_lr):
    return float(f"{10**log_lr:.1e}")


def create_sw_guidance(
    fabric: L.Fabric, model_name: str = "stabilityai/stable-diffusion-3.5-large"
):
    """
    Creates the Gradio interface for SW guidance with SD3.5.

    Args:
        fabric: Lightning Fabric instance
        model_name: The model to use for guidance
    """
    gr.Markdown(
        f"""
        # SW Guidance with {model_name.split('/')[-1]}
        Generates images using SW Guidance with a reference image and text prompt.
        """
    )

    pipe = create_pipeline(
        model_name,
        device="cuda",
        compile=True,
    )

    model_config = models[model_name]

    @spaces.GPU
    def run_sw_guidance(
        num_inference_steps: int,
        num_guided_steps_perc: float,
        guidance_scale: float,
        sw_u_lr: float,
        sw_steps: int,
        num_projections: int,
        control_variates: str,
        distance: str,
        candidates_per_pass: int,
        subsampling_factor: int,
        sampling_mode: str,
        cfg_rescale_phi: float,
        prompt: str,
        reference_image: np.ndarray,
        seed: Optional[int] = None,
    ):
        """
        Runs the SW guidance process with the given parameters.
        """
        if reference_image is None:
            raise gr.Error("Please provide a reference image")
        if not prompt:
            raise gr.Error("Please provide a prompt")

        # Convert numpy array to PIL Image
        ref_img = Image.fromarray(reference_image)

        # Run SW guidance
        image = sw_guidance_run(
            prompt=prompt,
            reference_image=ref_img,
            model_path=model_name,
            num_inference_steps=num_inference_steps,
            num_guided_steps=int(num_guided_steps_perc * num_inference_steps),
            guidance_scale=guidance_scale,
            sw_u_lr=log_slider_to_lr(sw_u_lr),
            sw_steps=sw_steps,
            height=1024,
            width=1024,
            device="cuda",
            num_projections=num_projections,
            use_ucv=control_variates == "Upper",
            use_lcv=control_variates == "Lower",
            distance=distance,
            num_new_candidates=candidates_per_pass,
            subsampling_factor=subsampling_factor,
            sampling_mode=sampling_mode,
            pipe=pipe,
            compile=True,
            seed=seed,
            cfg_rescale_phi=cfg_rescale_phi,
        )

        return np.array(image)

    gr.Markdown("## Input")
    with gr.Row(equal_height=True):
        with gr.Column(variant="panel"):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=2,
            )
            reference_image = gr.Image(label="Reference Image", height=512)

        with gr.Column(variant="panel"):
            output_image = gr.Image(label="Generated Image", height=512)

    example_pairs = [
        ("A diver discovering an underwater city", "sunken_boat.jpg"),
        ("A raccoon in a forest", "waterfall.jpg"),
        ("A cat detective solving a mystery", "building.jpg"),
        ("A raccoon reading a book by candlelight.", "food_stand.jpg"),
        ("A lion meditating on a mountain", "mountain.jpg"),
        ("A squirrel kayaking", "lake_green.jpg"),
        ("A young dragon roasting marshmallows", "canyon.jpg"),
        ("A family picnic beneath floating lanterns", "lake_sunset.jpg"),
        ("Elephants holding umbrellas in a rainstorm", "fisher.jpg"),
        ("A boy and his dog exploring a crystal cave", "lake.jpg"),
        ("A snowman sharing cocoa with woodland animals", "snow.jpg"),
        ("A kitten exploring an antique library", "ornament.jpg"),
        ("Children riding flying bicycles over mountains", "sky_pier.jpg"),
        ("An ancient tree whispering stories to deer", "greenhouse_2.jpg"),
        ("Owls with monocles in treetops", "path.jpg"),
    ]

    default_config = {
        "num_guided_steps_perc": 0.95,
        "num_projections": 32,
        "control_variates": "None",
        "distance": "l1",
        "candidates_per_pass": 8,
        "subsampling_factor": 1,
        "sampling_mode": "gaussian",
    } | model_config

    def run_example(prompt: str, reference_image: np.ndarray):
        return run_sw_guidance(
            **default_config,
            prompt=prompt,
            reference_image=reference_image,
        )

    example_inputs = [
        [prompt, os.path.join("example", "guidance", img_file)]
        for prompt, img_file in example_pairs
    ]

    run_button = gr.Button("Generate Image", variant="primary")

    gr.Examples(
        examples=example_inputs,
        inputs=[prompt, reference_image],
        outputs=[output_image],
        fn=run_example,
        label="Prompt + Reference Image Examples",
        examples_per_page=5,
        cache_examples=True,
        cache_mode="lazy",
    )

    gr.Markdown(
        """
        # Configuration
        Adjust the parameters for the SW guidance process.
        """
    )

    with gr.Accordion("Basic Config", open=True):
        num_inference_steps_slider = gr.Slider(
            1,
            100,
            value=model_config["num_inference_steps"],
            step=1,
            label="Number of inference steps",
        )

        guidance_scale_slider = gr.Slider(
            0.0,
            20.0,
            value=model_config["guidance_scale"],
            step=0.1,
            label="Guidance scale",
        )

        cfg_rescale_phi_slider = gr.Slider(
            0.0,
            1.0,
            value=model_config["cfg_rescale_phi"],
            step=0.05,
            label="CFG Rescale Phi",
            info="Controls the rescaling of classifier-free guidance",
        )

        seed_input = gr.Number(
            value=lambda: None,
            label="Seed (leave empty for random)",
            precision=0,
            interactive=True,
        )

        with gr.Row(variant="panel", equal_height=True):
            sw_u_lr_slider = gr.Slider(
                minimum=log_lrs.min(),
                maximum=log_lrs.max(),
                value=model_config["sw_u_lr"],
                step=0.05,
                label="SW guidance learning rate (Log scale)",
                interactive=True,
                scale=4,
            )
            lr_display = gr.Textbox(
                label="Learning Rate",
                value=f"{log_slider_to_lr(model_config['sw_u_lr']):.1e}",
                interactive=False,
                scale=1,
            )
            sw_u_lr_slider.change(
                lambda x: gr.update(value=f"{log_slider_to_lr(x):.1e}"),
                inputs=sw_u_lr_slider,
                outputs=lr_display,
                show_progress=False,
            )

    with gr.Accordion("Advanced Config", open=False):
        num_guided_steps_perc_slider = gr.Slider(
            0.0,
            1.0,
            value=default_config["num_guided_steps_perc"],
            step=0.05,
            label="Percentage of steps to apply SW guidance",
        )
        sw_steps_slider = gr.Slider(
            0,
            32,
            value=model_config["sw_steps"],
            step=1,
            label="Number of SW guidance steps, 0 means no SW guidance",
        )
        num_projections_slider = gr.Slider(
            16,
            1024,
            value=default_config["num_projections"],
            step=16,
            label="Number of projections",
        )
        control_variates_dropdown = gr.Dropdown(
            choices=["None", "Lower", "Upper"],
            value=default_config["control_variates"],
            label="Control Variates",
            info="Select which control variates to use for optimization",
        )
        distance_dropdown = gr.Dropdown(
            choices=["l1", "l2"],
            value=default_config["distance"],
            label="Distance metric",
            info="Select which distance metric to use",
        )
        candidates_per_pass_slider = gr.Slider(
            0,
            64,
            value=default_config["candidates_per_pass"],
            step=1,
            label="Number of new candidates per pass. 0 means no reservoir sampling",
        )
        subsampling_factor_slider = gr.Slider(
            1,
            16,
            value=default_config["subsampling_factor"],
            step=1,
            label="Subsampling factor",
        )
        sampling_mode_dropdown = gr.Dropdown(
            choices=["gaussian", "qmc"],
            value="qmc",
            label="Sampling Mode",
            info="Select which sampling mode to use for projections",
        )

    run_button.click(
        run_sw_guidance,
        inputs=[
            num_inference_steps_slider,
            num_guided_steps_perc_slider,
            guidance_scale_slider,
            sw_u_lr_slider,
            sw_steps_slider,
            num_projections_slider,
            control_variates_dropdown,
            distance_dropdown,
            candidates_per_pass_slider,
            subsampling_factor_slider,
            sampling_mode_dropdown,
            cfg_rescale_phi_slider,
            prompt,
            reference_image,
            seed_input,
        ],
        outputs=[output_image],
    )

    clear_button = gr.ClearButton([prompt, reference_image, output_image, seed_input])

    clear_button.click(lambda: None)
```
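The learning-rate slider above works on a log scale: it stores log10(lr) and `log_slider_to_lr` converts back, rounding to one significant digit via the `"%.1e"` round-trip. A quick self-contained check of that mapping:

```python
# Sanity-check sketch of the log-scale slider mapping used above.
import numpy as np


def log_slider_to_lr(log_lr):
    return float(f"{10**log_lr:.1e}")


print(log_slider_to_lr(np.log10(4e-3)))    # 0.004  (large-turbo preset)
print(log_slider_to_lr(np.log10(3.2e-3)))  # 0.0032 (large/medium preset)
```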
src/loss/__init__.py ADDED

```python
from .abstract_loss import AbstractLoss
from .vector_swd import VectorSWDLoss

__all__ = ["AbstractLoss", "VectorSWDLoss"]
```
src/loss/abstract_loss.py ADDED

```python
import abc

import torch
import torch.nn as nn
from jaxtyping import Float


class AbstractLoss(nn.Module, abc.ABC):
    @abc.abstractmethod
    def forward(
        self,
        pred: Float[torch.Tensor, "B C H W"],
        gt: Float[torch.Tensor, "B C H W"],
        step: int,
        **kwargs,
    ) -> Float[torch.Tensor, ""]:
        pass

    def reset(self):
        pass
```
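The interface takes a `step` argument so that stateful losses can evolve over the course of an optimization. A minimal sketch of a conforming subclass, assuming only the definitions above; `SimpleL2Loss` is hypothetical and exists purely to illustrate the contract:

```python
# Hypothetical AbstractLoss subclass showing the forward(pred, gt, step)
# contract; not part of the repository.
import torch
from jaxtyping import Float

from src.loss import AbstractLoss


class SimpleL2Loss(AbstractLoss):
    def forward(
        self,
        pred: Float[torch.Tensor, "B C H W"],
        gt: Float[torch.Tensor, "B C H W"],
        step: int,
        **kwargs,
    ) -> Float[torch.Tensor, ""]:
        # `step` is unused here; VectorSWDLoss below uses it to decide when
        # to refresh its projection directions.
        return ((pred - gt) ** 2).mean()
```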
src/loss/vector_swd.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from jaxtyping import Float
|
| 6 |
+
|
| 7 |
+
from src.loss.abstract_loss import AbstractLoss
|
| 8 |
+
from src.utils.math import sobol_sphere
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def process_vector(
|
| 12 |
+
x: Float[torch.Tensor, "B D N"],
|
| 13 |
+
dirs: Float[torch.Tensor, "K D"],
|
| 14 |
+
) -> Float[torch.Tensor, "K B*N_valid"]:
|
| 15 |
+
"""
|
| 16 |
+
Project a 1-D sequence with a bank of linear directions.
|
| 17 |
+
|
| 18 |
+
Args
|
| 19 |
+
----
|
| 20 |
+
x : (B, D, N) tensor β predictions or ground truth
|
| 21 |
+
dirs : (K, D) tensor β unit-length projection directions
|
| 22 |
+
|
| 23 |
+
Returns
|
| 24 |
+
-------
|
| 25 |
+
proj : (K, B*N_valid) tensor of flattened projections
|
| 26 |
+
"""
|
| 27 |
+
B, D, N = x.shape
|
| 28 |
+
K, _ = dirs.shape
|
| 29 |
+
|
| 30 |
+
# linear projection: x (B,D,N) -> (B,N,K) -> (K,B*N)
|
| 31 |
+
proj = F.linear(x.transpose(1, 2).to(torch.float32), dirs.to(torch.float32))
|
| 32 |
+
proj = proj.permute(2, 0, 1).reshape(K, -1).to(x.dtype)
|
| 33 |
+
|
| 34 |
+
return proj
|
| 35 |
+


class VectorSWDLoss(AbstractLoss):
    """
    1-D Sliced-Wasserstein Distance on sequences.

    This loss computes the sliced Wasserstein distance between predicted and
    ground-truth sequences by projecting them onto random directions and computing
    the Wasserstein distance in 1D. It supports reservoir sampling for adaptive
    direction selection and several variance-reduction techniques.

    Parameters
    ----------
    num_proj : int, default=64
        Number of random projections to use per step (K).

    distance : {"l1", "l2"}, default="l1"
        Distance metric used for the 1-D Wasserstein distance.

    use_ucv : bool, default=False
        Whether to use upper-bound control variates for variance reduction.
        Mutually exclusive with use_lcv.

    use_lcv : bool, default=False
        Whether to use lower-bound control variates for variance reduction.
        Mutually exclusive with use_ucv.

    refresh_projections_every_n_steps : int, default=1
        How often to refresh the projection directions. A value of 1 means
        refresh every step; higher values reuse directions across steps.

    num_new_candidates : int, default=16
        Number of new candidate directions to generate per step (M).
        If 0, reservoir sampling is disabled. Must not exceed num_proj.

    ess_alpha : float, default=0.5
        Effective sample size threshold for resetting the reservoir.
        When the ESS drops below ess_alpha * reservoir_size, the reservoir is reset.

    time_decay_tau : float or None, default=30.0
        Time-decay parameter for reservoir weights. If None, no decay is applied.
        Weights decay exponentially with age: exp(-age / time_decay_tau).

    missing_value_method : {"random_replicate", "interpolate"}, default="random_replicate"
        Method for handling sequences of different lengths:
        - "random_replicate": randomly replicate entries of the shorter sequence
        - "interpolate": linearly interpolate the shorter sequence to match

    sampling_mode : {"gaussian", "qmc"}, default="qmc"
        Method for generating random projection directions:
        - "gaussian": standard Gaussian sampling
        - "qmc": quasi-Monte Carlo sampling using Sobol sequences

    Notes
    -----
    - Reservoir sampling is enabled when num_new_candidates > 0
    - Reservoir size = num_proj - num_new_candidates
    - When use_ucv or use_lcv is True, variance reduction is applied using
      control variates based on the difference between sample and population means
    - The loss automatically handles sequences of different lengths using the
      specified missing_value_method
    """

    def __init__(
        self,
        num_proj: int = 64,
        distance: typing.Literal["l1", "l2"] = "l1",
        use_ucv: bool = False,
        use_lcv: bool = False,
        refresh_projections_every_n_steps: int = 1,
        num_new_candidates: int = 16,
        ess_alpha: float = 0.5,
        time_decay_tau: float | None = 30.0,
        missing_value_method: typing.Literal[
            "random_replicate", "interpolate"
        ] = "random_replicate",
        sampling_mode: typing.Literal["gaussian", "qmc"] = "qmc",
    ):
        super().__init__()

        assert not (use_ucv and use_lcv), "use_ucv and use_lcv cannot both be True"

        self.num_proj = num_proj
        self.distance = distance
        self.use_ucv = use_ucv
        self.use_lcv = use_lcv

        self.refresh_projections_every_n_steps = refresh_projections_every_n_steps
        self.num_new_candidates = num_new_candidates  # M
        self.ess_alpha = ess_alpha
        self.time_decay_tau = time_decay_tau
        self.missing_value_method = missing_value_method

        if num_new_candidates > 0 and self.refresh_projections_every_n_steps != 1:
            # Warn that this combination is not recommended
            print(
                "WARNING: num_new_candidates > 0 (enabling reservoir sampling) and "
                "refresh_projections_every_n_steps != 1 is not recommended"
            )
        assert (
            num_new_candidates <= num_proj
        ), "`num_new_candidates` must not exceed `num_proj`"

        # internal state for reservoir sampling
        self.restir_enabled = self.num_new_candidates > 0
        self.reservoir_size = self.num_proj - self.num_new_candidates
        self.register_buffer("_reservoir_filters", torch.empty(0))
        self.register_buffer("_reservoir_weights", torch.empty(0))
        self.register_buffer("_reservoir_steps", torch.empty(0, dtype=torch.long))
        self.register_buffer("_reservoir_keys", torch.empty(0))
        self.register_buffer("_cumulative_weights", torch.tensor(0.0))
        self.register_buffer("_has_reservoir", torch.tensor(False, dtype=torch.bool))

        self._cached_dirs: typing.Optional[torch.Tensor] = None
        self.sampling_mode = sampling_mode
        self.sobol_engine = None

    def _gaussian_proposals(self, k: int, d: int, device: torch.device) -> torch.Tensor:
        """Generate Gaussian random projection directions."""
        w = torch.randn(k, d, device=device)
        return w / (w.norm(dim=1, keepdim=True) + 1e-8)  # unit length

    def _qmc_proposals(self, k: int, d: int, device: torch.device) -> torch.Tensor:
        """Generate quasi-Monte Carlo projection directions using Sobol sequences."""
        vecs, self.sobol_engine = sobol_sphere(k, d, device, self.sobol_engine)
        return vecs.view(k, d)

    def _draw_dirs(self, k: int, d: int, device: torch.device) -> torch.Tensor:
        """Draw projection directions using the specified sampling mode."""
        if self.sampling_mode == "gaussian":
            return self._gaussian_proposals(k, d, device)
        if self.sampling_mode == "qmc":
            return self._qmc_proposals(k, d, device)
        raise ValueError("bad sampling_mode")

    @staticmethod
    def _duplicate_to_match(a: torch.Tensor, b: torch.Tensor, method: str):
        """
        Make two tensors have the same length by duplicating the shorter one.

        Args
        ----
        a, b : (K, N1) and (K, N2) tensors
        method : "random_replicate" or "interpolate"

        Returns
        -------
        a, b : Tensors with matching second dimension. If a swap occurred, the
            returned order is (longer, extended shorter); callers rely on the
            distance being symmetric in its two arguments.
        """
        if a.shape[1] == b.shape[1]:
            return a, b
        if a.shape[1] < b.shape[1]:
            a, b = b, a  # swap so that `a` is the larger

        K, NA = a.shape
        NB = b.shape[1]

        # repeat / interpolate `b` until it matches `a`
        if method == "random_replicate":
            repeats = NA // NB
            b = torch.cat([b] * repeats, dim=1)
            if b.shape[1] < NA:
                idx = torch.randint(0, NB, (NA - b.shape[1],), device=b.device)
                b = torch.cat([b, b[:, idx]], dim=1)
        else:  # interpolate
            b = F.interpolate(
                b.unsqueeze(0), size=(NA,), mode="linear", align_corners=False
            ).squeeze(0)
        return a, b

    def reset(self):
        """Reset the reservoir sampling state."""
        if self.restir_enabled:
            self._reservoir_filters = torch.empty(0)
            self._reservoir_weights = torch.empty(0)
            self._cumulative_weights.data.fill_(0)
            self._has_reservoir.fill_(False)
            self._reservoir_steps = torch.empty(0, dtype=torch.long)
            self._reservoir_keys = torch.empty(0)

    def _wrs_multi(
        self, filters: torch.Tensor, weights: torch.Tensor, step: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Weighted reservoir sampling that keeps exactly self.reservoir_size samples
        and returns their indices inside the concatenated candidate set.

        Args
        ----
        filters : (K+M, D) tensor of candidate directions
        weights : (K+M,) tensor of importance weights
        step : Current training step

        Returns
        -------
        keep_idx : Indices of kept samples
        keep_w : Normalized weights of kept samples
        """
        R = self.reservoir_size
        device = weights.device

        u = torch.rand_like(weights)
        keys = u.pow(1.0 / weights.clamp_min(1e-9))

        if not self._has_reservoir.item():
            self._reservoir_filters = filters[:R]
            self._reservoir_weights = weights[:R]
            self._reservoir_keys = keys[:R]
            self._reservoir_steps = torch.full(
                (R,), step, dtype=torch.long, device=device
            )
            self._has_reservoir.fill_(True)

        new_filters = filters[R:]
        new_keys = keys[R:]
        new_weights = weights[R:]
        new_steps = torch.full(
            (new_filters.size(0),), step, dtype=torch.long, device=device
        )

        all_filters = torch.cat([self._reservoir_filters, new_filters], 0)
        all_keys = torch.cat([self._reservoir_keys, new_keys], 0)
        all_weights = torch.cat([self._reservoir_weights, new_weights], 0)
        all_steps = torch.cat([self._reservoir_steps, new_steps], 0)

        topk_keys, topk_idx = torch.topk(all_keys, R, largest=True)

        self._reservoir_filters = all_filters[topk_idx]
        self._reservoir_weights = all_weights[topk_idx]
        self._reservoir_keys = topk_keys
        self._reservoir_steps = all_steps[topk_idx]

        # indices w.r.t. current cand_dirs (old R first, then new M)
        keep_idx = torch.cat(
            [
                torch.arange(R, device=device),
                torch.arange(R, R + new_filters.size(0), device=device),
            ]
        )[topk_idx]
        keep_w = self._reservoir_weights / self._reservoir_weights.sum().clamp_min(
            1e-12
        )
        return keep_idx, keep_w

    def _apply_time_decay(self, step: int):
        """
        Apply exponential time decay to stored reservoir weights.

        Args
        ----
        step : Current training step
        """
        if self.time_decay_tau is None or not self._has_reservoir.item():
            return
        age = (step - self._reservoir_steps).to(torch.float32)
        decay = torch.exp(-age / self.time_decay_tau).to(self._reservoir_weights.dtype)
        self._reservoir_weights.mul_(decay)
        self._reservoir_keys.mul_(decay)  # preserve ordering consistency

    def forward(
        self,
        pred: Float[torch.Tensor, "B D N"],
        gt: Float[torch.Tensor, "B D N"],
        step: int,
    ):
        """
        Compute the sliced Wasserstein distance between predicted and ground-truth
        sequences.

        Args
        ----
        pred : (B, D, N) tensor of predicted sequences
        gt : (B, D, N) tensor of ground-truth sequences
        step : Current training step for reservoir sampling

        Returns
        -------
        loss : Scalar tensor containing the computed loss
        """
        B, D, N = pred.shape
        K = self.num_proj
        M = self.num_new_candidates
        R = self.reservoir_size
        device = pred.device
        gt = gt.detach()

        self._apply_time_decay(step)

        # Get candidate directions
        if step % self.refresh_projections_every_n_steps == 0:
            new_dirs = self._draw_dirs(
                M if self.restir_enabled and self._has_reservoir.item() else K,
                D,
                device,
            )
            self._cached_dirs = new_dirs
        else:
            new_dirs = self._cached_dirs

        if self.restir_enabled and self._has_reservoir.item():
            cand_dirs = torch.cat(
                [self._reservoir_filters, new_dirs], dim=0
            )  # (R+M, D)
        else:
            cand_dirs = new_dirs

        # Project sequences
        cand_pred = process_vector(pred, cand_dirs)
        cand_gt = process_vector(gt, cand_dirs)

        cand_pred, cand_gt = self._duplicate_to_match(
            cand_pred, cand_gt, self.missing_value_method
        )

        cand_pred = cand_pred.sort(dim=1).values
        cand_gt = cand_gt.sort(dim=1).values

        # Select K directions (reservoir) & importance weights
        if self.restir_enabled:
            with torch.no_grad():
                base = cand_pred - cand_gt
                base = base.abs() if self.distance == "l1" else base.square()
                ris_weights = base.mean(1)  # one weight per candidate direction
                keep_idx, keep_w = self._wrs_multi(cand_dirs, ris_weights, step)

            w = keep_w
            w_hat = keep_w

            dirs = cand_dirs[keep_idx]
            proj_pred = cand_pred[keep_idx]
            proj_gt = cand_gt[keep_idx]
        else:
            dirs = cand_dirs
            proj_pred = cand_pred
            proj_gt = cand_gt
            w = torch.full((dirs.shape[0],), 1.0 / K, device=device)

        # Compute SWD
        diff = proj_pred - proj_gt
        diff = diff.abs() if self.distance == "l1" else diff.square()
        per_slice = diff.mean(1)  # (L,)

        if self.use_ucv or self.use_lcv:
            X_vecs = pred.permute(0, 2, 1).reshape(-1, D)  # (B*N, D)
            Y_vecs = gt.permute(0, 2, 1).reshape(-1, D)  # (B*N, D)

            m1 = X_vecs.mean(0)  # (D,)
            m2 = Y_vecs.mean(0)
            diff_m = m1 - m2  # (D,)

            theta = dirs  # (L, D), already unit-norm

            if self.use_ucv:
                diff_X = X_vecs - m1
                diff_Y = Y_vecs - m2

                d = D
                trSigX = diff_X.pow(2).mean()
                trSigY = diff_Y.pow(2).mean()
                G_bar = (diff_m @ diff_m) / d + (trSigX + trSigY)

                delta2 = (theta @ diff_m) ** 2  # (L,)

                proj_X = diff_X @ theta.t()  # (B*N, L)
                proj_Y = diff_Y @ theta.t()
                varX = proj_X.pow(2).mean(0)  # (L,)
                varY = proj_Y.pow(2).mean(0)
                G_hat = delta2 + varX + varY
            else:  # LCV
                d = D
                G_bar = (diff_m @ diff_m) / d
                G_hat = (theta @ diff_m) ** 2

            diff_hat_G_mean_G = G_hat - G_bar

            hat_A = (w * per_slice).sum()
            var_G = (w * diff_hat_G_mean_G.pow(2)).sum()
            cov_AG = (w * (per_slice - hat_A) * diff_hat_G_mean_G).sum()
            hat_alpha = cov_AG / (var_G + 1e-12)
            loss = hat_A - hat_alpha * (w * diff_hat_G_mean_G).sum()
        else:
            loss = (w * per_slice).sum()

        # Reservoir update
        if self.restir_enabled and self.ess_alpha > 0:
            with torch.no_grad():
                ess = (w_hat.sum().square()) / (w_hat.square().sum() + 1e-12)
                ess = torch.nan_to_num(ess, nan=0.0, posinf=R, neginf=0.0).item()
                if ess < self.ess_alpha * R:
                    print(f"ESS: {ess} is less than {self.ess_alpha * R}, resetting")
                    self.reset()

        return loss
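
A minimal usage sketch for VectorSWDLoss (illustrative; the shapes, values, and the choice of the Gaussian sampler are assumptions, picked so the snippet does not depend on the Sobol helper):

    import torch

    swd = VectorSWDLoss(num_proj=64, num_new_candidates=16,
                        distance="l1", sampling_mode="gaussian")
    pred = torch.rand(1, 3, 1024, requires_grad=True)  # (B, D, N) samples
    gt = torch.rand(1, 3, 2048)  # N may differ; missing_value_method resolves it
    loss = swd(pred=pred, gt=gt, step=0)
    loss.backward()  # gradients reach `pred`; `gt` is detached inside forward()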
src/sw_sdthree_guidance.py
ADDED
@@ -0,0 +1,654 @@
# Implementation of the SW Guidance method with our enhanced SWD implementation
# See: https://github.com/alobashev/sw-guidance/ for the original implementation
#
# Alexander Lobashev, Maria Larchenko, Dmitry Guskov
# Color Conditional Generation with Sliced Wasserstein Guidance
# https://arxiv.org/abs/2503.19034


import gc
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import numpy as np
import PIL.Image
import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    StableDiffusion3Pipeline,
)
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
    XLA_AVAILABLE,
    StableDiffusion3PipelineOutput,
    calculate_shift,
    retrieve_timesteps,
)

from src.loss.vector_swd import VectorSWDLoss
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, write_img

if XLA_AVAILABLE:
    from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import xm


def _no_grad_noise(model, *args, **kw):
    """Forward pass with grad disabled; the result is returned detached."""
    with torch.no_grad():
        return model(*args, **kw, return_dict=False)[0].detach()


# ---------------- explicit pipeline forward call
class SWStableDiffusion3Pipeline(StableDiffusion3Pipeline):
    swd: VectorSWDLoss = None

    def setup_swd(
        self,
        num_projections: int = 64,
        use_ucv: bool = False,
        use_lcv: bool = False,
        distance: Literal["l1", "l2"] = "l1",
        num_new_candidates: int = 32,
        subsampling_factor: int = 1,
        sampling_mode: Literal["gaussian", "qmc"] = "qmc",
    ):
        self.swd = VectorSWDLoss(
            num_proj=num_projections,
            distance=distance,
            use_ucv=use_ucv,
            use_lcv=use_lcv,
            num_new_candidates=num_new_candidates,
            missing_value_method="interpolate",
            ess_alpha=-1,
            sampling_mode=sampling_mode,
        ).to(self.device)
        self.subsampling_factor = subsampling_factor

    def do_sw_guidance(
        self,
        sw_steps,
        sw_u_lr,
        latents,
        t,
        prompt_embeds,
        pooled_prompt_embeds,
        pixels_ref,
        cur_iter_step,
        write_video_animation_path: Optional[str] = None,
    ):
        if sw_steps == 0:
            return latents

        if latents.shape[0] != prompt_embeds.shape[0]:
            prompt_embeds = prompt_embeds[1].unsqueeze(0)
            pooled_prompt_embeds = pooled_prompt_embeds[1].unsqueeze(0)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0])

        pixels_ref = (
            rgb_to_lab(pixels_ref.unsqueeze(0).clamp(0, 1).permute(0, 3, 1, 2))
            .permute(0, 2, 3, 1)
            .contiguous()
        )

        csc_scaler = torch.tensor(
            [100, 2 * 128, 2 * 128], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)
        csc_bias = torch.tensor(
            [0, 0.5, 0.5], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)

        u = torch.nn.Parameter(
            torch.zeros_like(latents, dtype=torch.bfloat16, device=latents.device)
        )
        optimizer = torch.optim.Adam([u], lr=sw_u_lr)

        for tt in range(sw_steps):
            optimizer.zero_grad()

            x_hat_t = latents.detach() + u
            noise_pred = _no_grad_noise(
                self.transformer,
                hidden_states=x_hat_t,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
            )

            # ------------ Compute x_0
            sigma_t = self.scheduler.sigmas[
                self.scheduler.index_for_timestep(t)
            ]  # scalar
            while sigma_t.ndim < x_hat_t.ndim:
                sigma_t = sigma_t.unsqueeze(-1)
            sigma_t = sigma_t.to(x_hat_t.dtype).to(latents.device)

            x_0 = x_hat_t - sigma_t * noise_pred

            # ------------ Compute loss
            img_unscaled = self.vae.decode(
                (x_0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
                return_dict=False,
            )[0]
            image = (img_unscaled * 0.5 + 0.5).clamp(0, 1)
            image_matched = (
                rgb_to_lab(image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
            )
            # reshape to (B, D, N) where D=3, N = H*W
            pred_seq = image_matched.view(1, 3, -1) / csc_scaler + csc_bias
            ref_seq = pixels_ref.view(1, 3, -1) / csc_scaler + csc_bias

            # Apply subsampling if enabled
            if self.subsampling_factor > 1:
                pred_seq = pred_seq[..., :: self.subsampling_factor]
                ref_seq = ref_seq[..., :: self.subsampling_factor]

            loss = self.swd(pred=pred_seq, gt=ref_seq, step=tt)

            loss.backward()
            optimizer.step()

            if write_video_animation_path is not None:
                frame_idx = cur_iter_step * sw_steps + tt
                write_img(
                    os.path.join(write_video_animation_path, f"{frame_idx:05d}.jpg"),
                    from_torch(img_unscaled.squeeze(0)),
                )

        latents = latents.detach() + u.detach()

        gc.collect()
        torch.cuda.empty_cache()
        return latents

    def __call__(
        self,
        sw_reference: Optional[PIL.Image.Image] = None,
        sw_steps: int = 8,
        sw_u_lr: float = 0.05 * 10**3,
        num_guided_steps: Optional[int] = None,
        # -----------------------------------
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        cfg_rescale_phi: float = 0.7,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        skip_guidance_layers: Optional[List[int]] = None,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_start: float = 0.01,
        mu: Optional[float] = None,
        write_video_animation_path: Optional[str] = None,
    ):
        assert self.swd is not None, "SWD not initialized"

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._skip_layer_guidance_scale = skip_layer_guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None)
            if self.joint_attention_kwargs is not None
            else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            if skip_guidance_layers is not None:
                original_prompt_embeds = prompt_embeds
                original_pooled_prompt_embeds = pooled_prompt_embeds
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            _, _, height, width = latents.shape
            image_seq_len = (height // self.transformer.config.patch_size) * (
                width // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare image embeddings
        if (
            ip_adapter_image is not None and self.is_ip_adapter_active
        ) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {
                    "ip_adapter_image_embeds": ip_adapter_image_embeds
                }
            else:
                self._joint_attention_kwargs.update(
                    ip_adapter_image_embeds=ip_adapter_image_embeds
                )

        if sw_reference is not None:
            # Resize the reference so its longer side matches the longer side
            # of the output image
            target_max_size = max(height, width)
            reference_max_size = max(sw_reference.width, sw_reference.height)
            scale_factor = target_max_size / reference_max_size

            sw_reference = sw_reference.resize(
                (
                    int(sw_reference.width * scale_factor),
                    int(sw_reference.height * scale_factor),
                )
            )
            pixels_ref = (
                torch.Tensor(np.array(sw_reference).astype(np.float32) / 255)
                .permute(2, 0, 1)
                .to(device)
                .to(torch.bfloat16)
            )

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible
                # with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

                # SW Guidance
                if sw_reference is not None:
                    if num_guided_steps is None or i < num_guided_steps:
                        latents = self.do_sw_guidance(
                            sw_steps,
                            sw_u_lr,
                            latents,
                            t,
                            prompt_embeds,
                            pooled_prompt_embeds,
                            pixels_ref,
                            cur_iter_step=i,
                            write_video_animation_path=write_video_animation_path,
                        )
                    if num_guided_steps is not None and i == num_guided_steps // 2:
                        self.swd.reset()

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )

                with torch.no_grad():
                    timestep = t.expand(latent_model_input.shape[0])

                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        pooled_projections=pooled_prompt_embeds,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )

                        should_skip_layers = (
                            True
                            if i > num_inference_steps * skip_layer_guidance_start
                            and i < num_inference_steps * skip_layer_guidance_stop
                            else False
                        )
                        if skip_guidance_layers is not None and should_skip_layers:
                            timestep = t.expand(latents.shape[0])
                            latent_model_input = latents
                            noise_pred_skip_layers = self.transformer(
                                hidden_states=latent_model_input,
                                timestep=timestep,
                                encoder_hidden_states=original_prompt_embeds,
                                pooled_projections=original_pooled_prompt_embeds,
                                joint_attention_kwargs=self.joint_attention_kwargs,
                                return_dict=False,
                                skip_layers=skip_guidance_layers,
                            )[0]
                            noise_pred = (
                                noise_pred
                                + (noise_pred_text - noise_pred_skip_layers)
                                * self._skip_layer_guidance_scale
                            )

                        # Based on Sec. 3.4 of Lin, Liu, Li, Yang -
                        # Common Diffusion Noise Schedules and Sample Steps are Flawed
                        # https://arxiv.org/abs/2305.08891
                        # While flow matching is free of most of those issues, a high
                        # CFG scale can still cause over-exposure, as discussed there.
                        if cfg_rescale_phi is not None and cfg_rescale_phi > 0:
                            # σ_pos and σ_cfg are per-sample (B×1×1×1) stdevs
                            sigma_pos = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
                            sigma_cfg = noise_pred.std(dim=(1, 2, 3), keepdim=True)

                            # Linear blend between the raw ratio and 1,
                            # cf. Eq. (15-16) in the paper
                            factor = torch.lerp(
                                sigma_pos / (sigma_cfg + 1e-8),  # avoid div-by-zero
                                torch.ones_like(sigma_cfg),
                                1.0 - cfg_rescale_phi,
                            )
                            noise_pred = noise_pred * factor

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_dtype = latents.dtype
                    latents = self.scheduler.step(
                        noise_pred, t, latents, return_dict=False
                    )[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a
                        # PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(
                        self, i, t, callback_kwargs
                    )

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop(
                        "prompt_embeds", prompt_embeds
                    )
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds",
                        negative_pooled_prompt_embeds,
                    )

                if (
                    write_video_animation_path is not None
                    and num_guided_steps is not None
                    and i >= num_guided_steps
                ):
                    with torch.no_grad():
                        image = self.vae.decode(
                            (latents / self.vae.config.scaling_factor)
                            + self.vae.config.shift_factor,
                            return_dict=False,
                        )[0]
                    cur_frame_idx = i * sw_steps
                    write_img(
                        os.path.join(
                            write_video_animation_path,
                            f"{cur_frame_idx:05d}.jpg",
                        ),
                        from_torch(image.squeeze(0)),
                    )

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps
                    and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            latents = (
                latents / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(
                image.detach(), output_type=output_type
            )

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)


def run(
    prompt: str,
    reference_image: PIL.Image.Image,
    model_path: str,
    num_inference_steps: int = 30,
    num_guided_steps: int = 28,
    guidance_scale: float = 5.0,
    cfg_rescale_phi: float = 0.7,
    sw_u_lr: float = 3e-3,
    sw_steps: int = 8,
    height: int = 768,
    width: int = 768,
    device: str = "cuda",
    seed: Optional[int] = None,
    # SW-related parameters
    num_projections: int = 64,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    num_new_candidates: int = 32,
    subsampling_factor: int = 1,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    pipe: Optional[SWStableDiffusion3Pipeline] = None,
    compile: bool = False,
    video_animation_path: Optional[str] = None,
) -> PIL.Image.Image:
    """
    Generate an image using SW Guidance with a given prompt and reference image.

    Args:
        prompt (str): Text prompt to guide the generation
        reference_image (PIL.Image.Image): Reference image to guide the generation
        model_path (str): Path to the model weights
        num_inference_steps (int): Number of denoising steps
        num_guided_steps (int): Number of denoising steps to which SW guidance
            is applied
        guidance_scale (float): Scale for classifier-free guidance
        cfg_rescale_phi (float): Rescale factor for classifier-free guidance
        sw_u_lr (float): Learning rate for SW guidance
        sw_steps (int): Number of SW guidance optimization steps per denoising step
        height (int): Output image height
        width (int): Output image width
        device (str): Device to run the model on
        seed (Optional[int]): Random seed for reproducible generation
        num_projections (int): Number of random projections for VectorSWDLoss
        use_ucv (bool): Use UCV variant of VectorSWDLoss
        use_lcv (bool): Use LCV variant of VectorSWDLoss
        distance (str): Distance metric for VectorSWDLoss ("l1" or "l2")
        num_new_candidates (int): Number of new candidates for the reservoir
        subsampling_factor (int): Factor to subsample points for SW computation.
            Higher values reduce memory usage but may affect quality.
        sampling_mode (str): Sampling mode for VectorSWDLoss.
        pipe (SWStableDiffusion3Pipeline): Pipeline to use for generation.
            If None, a new pipeline is created.
        compile (bool): Whether to compile the pipeline.
        video_animation_path (Optional[str]): If set, intermediate frames are
            written there for a video animation.

    Returns:
        PIL.Image.Image: Generated image
    """
    # Normalize device to torch.device for robustness
    device = torch.device(device) if not isinstance(device, torch.device) else device
    if pipe is None:
        pipe = create_pipeline(model_path, device, compile=compile)

    pipe.setup_swd(
        num_projections=num_projections,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        distance=distance,
        num_new_candidates=num_new_candidates,
        subsampling_factor=subsampling_factor,
        sampling_mode=sampling_mode,
    )

    if seed is not None:
        print(f"Using seed: {seed}")
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = None

    image = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        num_guided_steps=num_guided_steps,
        guidance_scale=guidance_scale,
        cfg_rescale_phi=cfg_rescale_phi,
        sw_u_lr=sw_u_lr,
        sw_steps=sw_steps,
        height=height,
        width=width,
        sw_reference=reference_image,
        generator=generator,
        write_video_animation_path=video_animation_path,
    ).images[0]

    return image


def create_pipeline(model_path, device: str = "cuda", compile: bool = False):
    pipe = SWStableDiffusion3Pipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )
    pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)
    if compile:
        pipe.transformer = torch.compile(pipe.transformer)
        pipe.vae.decoder = torch.compile(pipe.vae.decoder)
    return pipe
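
An illustrative driver for run() (the checkpoint id is an assumption: any SD3-family path accepted by from_pretrained should work; the reference image is one of the bundled example files):

    import PIL.Image

    reference = PIL.Image.open("example/guidance/lake_sunset.jpg").convert("RGB")
    result = run(
        prompt="a wooden pier over a calm lake at sunset",
        reference_image=reference,
        model_path="stabilityai/stable-diffusion-3.5-medium",  # assumed checkpoint
        seed=42,
    )
    result.save("out.png")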
src/utils/asc_cdl.py
ADDED
@@ -0,0 +1,279 @@
| 1 |
+
import warnings
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from jaxtyping import Float
|
| 6 |
+
from lxml import etree
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_asc_cdl(cdl_path: str, device: torch.device = torch.device("cpu")) -> dict:
|
| 10 |
+
"""
|
| 11 |
+
Loads ASC CDL parameters from an XML file.
|
| 12 |
+
|
| 13 |
+
Parameters:
|
| 14 |
+
cdl_path (str): Path to the ASC CDL XML file
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Dict:
|
| 18 |
+
slope, offset, power, and saturation values as torch tensors
|
| 19 |
+
"""
|
| 20 |
+
try:
|
| 21 |
+
tree = etree.parse(cdl_path)
|
| 22 |
+
root = tree.getroot()
|
| 23 |
+
except Exception as e:
|
| 24 |
+
raise ValueError(f"Error loading ASC CDL from {cdl_path}: {e}")
|
| 25 |
+
|
| 26 |
+
# Extract SOP values
|
| 27 |
+
sop_node = root.find(".//SOPNode")
|
| 28 |
+
slope = torch.tensor(
|
| 29 |
+
[float(x) for x in sop_node.find("Slope").text.split()], device=device
|
| 30 |
+
)
|
| 31 |
+
offset = torch.tensor(
|
| 32 |
+
[float(x) for x in sop_node.find("Offset").text.split()], device=device
|
| 33 |
+
)
|
| 34 |
+
power = torch.tensor(
|
| 35 |
+
[float(x) for x in sop_node.find("Power").text.split()], device=device
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Extract Saturation value
|
| 39 |
+
sat_node = root.find(".//SatNode")
|
| 40 |
+
saturation = torch.tensor(float(sat_node.find("Saturation").text), device=device)
|
| 41 |
+
|
| 42 |
+
return {"slope": slope, "offset": offset, "power": power, "saturation": saturation}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def save_asc_cdl(cdl_dict: dict, cdl_path: Optional[str]):
|
| 46 |
+
"""
|
| 47 |
+
Saves ASC CDL parameters to an XML file.
|
| 48 |
+
|
| 49 |
+
Parameters:
|
| 50 |
+
cdl_dict (dict): Dictionary containing slope, offset, power, and
|
| 51 |
+
saturation values
|
| 52 |
+
"""
|
| 53 |
+
root = etree.Element("ASC_CDL")
|
| 54 |
+
sop_node = etree.SubElement(root, "SOPNode")
|
| 55 |
+
etree.SubElement(sop_node, "Slope").text = " ".join(
|
| 56 |
+
str(x) for x in cdl_dict["slope"].detach().cpu().numpy()
|
| 57 |
+
)
|
| 58 |
+
etree.SubElement(sop_node, "Offset").text = " ".join(
|
| 59 |
+
str(x) for x in cdl_dict["offset"].detach().cpu().numpy()
|
| 60 |
+
)
|
| 61 |
+
etree.SubElement(sop_node, "Power").text = " ".join(
|
| 62 |
+
str(x) for x in cdl_dict["power"].detach().cpu().numpy()
|
| 63 |
+
)
|
| 64 |
+
sat_node = etree.SubElement(root, "SatNode")
|
| 65 |
+
etree.SubElement(sat_node, "Saturation").text = str(
|
| 66 |
+
cdl_dict["saturation"].detach().cpu().numpy()
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
tree = etree.ElementTree(root)
|
| 70 |
+
if cdl_path is not None:
|
| 71 |
+
try:
|
| 72 |
+
tree.write(
|
| 73 |
+
cdl_path, pretty_print=True, xml_declaration=True, encoding="utf-8"
|
| 74 |
+
)
|
| 75 |
+
except Exception as e:
|
| 76 |
+
raise ValueError(f"Error saving ASC CDL to {cdl_path}: {e}")
|
| 77 |
+
else:
|
| 78 |
+
return etree.tostring(
|
| 79 |
+
root, pretty_print=True, xml_declaration=True, encoding="utf-8"
|
| 80 |
+
).decode("utf-8")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def apply_sop(
|
| 84 |
+
img: Float[torch.Tensor, "*B C H W"],
|
| 85 |
+
slope: Float[torch.Tensor, "*B C"],
|
| 86 |
+
offset: Float[torch.Tensor, "*B C"],
|
| 87 |
+
power: Float[torch.Tensor, "*B C"],
|
| 88 |
+
clamp: bool = True,
|
| 89 |
+
) -> Float[torch.Tensor, "*B C H W"]:
|
| 90 |
+
"""
|
| 91 |
+
Applies Slope, Offset, and Power adjustments.
|
| 92 |
+
|
| 93 |
+
Parameters:
|
| 94 |
+
img (torch.Tensor): Input image tensor (*B, C, H, W)
|
| 95 |
+
slope (torch.Tensor): Slope per channel (*B, C)
|
| 96 |
+
offset (torch.Tensor): Offset per channel (*B, C)
|
| 97 |
+
power (torch.Tensor): Power per channel (*B, C)
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
torch.Tensor: Image after SOP adjustments.
|
| 101 |
+
"""
|
| 102 |
+
so = img * slope.unsqueeze(-1).unsqueeze(-1) + offset.unsqueeze(-1).unsqueeze(-1)
|
| 103 |
+
if clamp:
|
| 104 |
+
so = torch.clamp(so, min=0.0, max=1.0)
|
| 105 |
+
return torch.where(
|
| 106 |
+
so > 1e-7, torch.pow(so.clamp(min=1e-7), power.unsqueeze(-1).unsqueeze(-1)), so
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_saturation(
|
| 111 |
+
img: Float[torch.Tensor, "*B C H W"],
|
| 112 |
+
saturation: Float[torch.Tensor, "*B"],
|
| 113 |
+
) -> Float[torch.Tensor, "*B C H W"]:
|
| 114 |
+
"""
|
| 115 |
+
Applies saturation adjustment.
|
| 116 |
+
|
| 117 |
+
Parameters:
|
| 118 |
+
img (torch.Tensor): Image tensor (*B, C, H, W)
|
| 119 |
+
saturation (torch.Tensor): Saturation factor (*B)
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
torch.Tensor: Image after saturation adjustment.
|
| 123 |
+
"""
|
| 124 |
+
# Calculate luminance using Rec. 709 coefficients
|
| 125 |
+
lum = (
|
| 126 |
+
0.2126 * img[..., 0, :, :]
|
| 127 |
+
+ 0.7152 * img[..., 1, :, :]
|
| 128 |
+
+ 0.0722 * img[..., 2, :, :]
|
| 129 |
+
)
|
| 130 |
+
lum = lum.unsqueeze(-3) # Add channel dimension
|
| 131 |
+
return lum + (img - lum) * saturation.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def asc_cdl_forward(
|
| 135 |
+
img: Float[torch.Tensor, "*B C H W"],
|
| 136 |
+
slope: Float[torch.Tensor, "*B C"],
|
| 137 |
+
offset: Float[torch.Tensor, "*B C"],
|
| 138 |
+
power: Float[torch.Tensor, "*B C"],
|
| 139 |
+
saturation: Float[torch.Tensor, "*B"],
|
| 140 |
+
clamp: bool = True,
|
| 141 |
+
) -> Float[torch.Tensor, "*B C H W"]:
|
| 142 |
+
"""
|
| 143 |
+
Applies ASC CDL transformation in Fwd or FwdNoClamp mode.
|
| 144 |
+
|
| 145 |
+
Parameters:
|
| 146 |
+
img (torch.Tensor): Input image tensor (*B, C, H, W)
|
| 147 |
+
slope (torch.Tensor): Slope per channel (*B, C)
|
| 148 |
+
offset (torch.Tensor): Offset per channel (*B, C)
|
| 149 |
+
power (torch.Tensor): Power per channel (*B, C)
|
| 150 |
+
saturation (torch.Tensor): Saturation factor (*B)
|
| 151 |
+
clamp (bool): If True, clamps output to [0, 1] (Fwd mode).
|
| 152 |
+
If False, no clamping (FwdNoClamp mode).
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
torch.Tensor: Transformed image tensor.
|
| 156 |
+
"""
|
| 157 |
+
# Add warning if saturation, slope, power are below 0
|
| 158 |
+
if (saturation < 0).any():
|
| 159 |
+
warnings.warn("Saturation is below 0, this will result in a color shift.")
|
| 160 |
+
if (slope < 0).any():
|
| 161 |
+
warnings.warn("Slope is below 0, this will result in a color shift.")
|
| 162 |
+
if (power < 0).any():
|
| 163 |
+
warnings.warn("Power is below 0, this will result in a color shift.")
|
| 164 |
+
|
| 165 |
+
img_batch_dim = img.shape[:-3]
|
| 166 |
+
# Check if slope, offset, power, saturation have the same batch dimension
|
| 167 |
+
# If they do not have any batch dimensions, add a single batch dimensions
|
| 168 |
+
if slope.ndim == 1:
|
| 169 |
+
slope = slope.view(*[1] * len(img_batch_dim), *slope.shape)
|
| 170 |
+
if offset.ndim == 1:
|
| 171 |
+
offset = offset.view(*[1] * len(img_batch_dim), *offset.shape)
|
| 172 |
+
if power.ndim == 1:
|
| 173 |
+
power = power.view(*[1] * len(img_batch_dim), *power.shape)
|
| 174 |
+
if saturation.ndim == 0:
|
| 175 |
+
saturation = saturation.view(*[1] * len(img_batch_dim), *saturation.shape)
|
| 176 |
+
|
| 177 |
+
# Now check that the lengths are matching
|
| 178 |
+
assert slope.ndim == len(img_batch_dim) + 1
|
| 179 |
+
assert offset.ndim == len(img_batch_dim) + 1
|
| 180 |
+
assert power.ndim == len(img_batch_dim) + 1
|
| 181 |
+
assert saturation.ndim == len(img_batch_dim)
|
| 182 |
+
|
| 183 |
+
# Apply Slope, Offset, and Power adjustments
|
| 184 |
+
img = apply_sop(img, slope, offset, power, clamp=clamp)
|
| 185 |
+
# print("img after sop", img.min(), img.max())
|
| 186 |
+
# Apply Saturation adjustment
|
| 187 |
+
img = apply_saturation(img, saturation)
|
| 188 |
+
# print("img after saturation", img.min(), img.max())
|
| 189 |
+
# Clamp if in Fwd mode
|
| 190 |
+
if clamp:
|
| 191 |
+
img = torch.clamp(img, 0.0, 1.0)
|
| 192 |
+
return img
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def inverse_saturation(
|
| 196 |
+
img: Float[torch.Tensor, "*B C H W"],
|
| 197 |
+
saturation: Float[torch.Tensor, "*B"],
|
| 198 |
+
) -> Float[torch.Tensor, "*B C H W"]:
|
| 199 |
+
"""
|
| 200 |
+
Reverts saturation adjustment.
|
| 201 |
+
|
| 202 |
+
Parameters:
|
| 203 |
+
img (torch.Tensor): Image tensor (*B, C, H, W)
|
| 204 |
+
saturation (torch.Tensor): Saturation factor (*B)
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
torch.Tensor: Image after reversing saturation adjustment.
|
| 208 |
+
"""
|
| 209 |
+
# Calculate luminance using Rec. 709 coefficients
|
| 210 |
+
lum = (
|
| 211 |
+
0.2126 * img[..., 0, :, :]
|
| 212 |
+
+ 0.7152 * img[..., 1, :, :]
|
| 213 |
+
+ 0.0722 * img[..., 2, :, :]
|
| 214 |
+
)
|
| 215 |
+
lum = lum.unsqueeze(-3) # Add channel dimension
|
| 216 |
+
return lum + (img - lum) / saturation.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
| 217 |
+
|
| 218 |
+
|
+def asc_cdl_reverse(
+    img: Float[torch.Tensor, "*B C H W"],
+    slope: Float[torch.Tensor, "*B C"],
+    offset: Float[torch.Tensor, "*B C"],
+    power: Float[torch.Tensor, "*B C"],
+    saturation: Float[torch.Tensor, "*B"],
+    clamp: bool = True,
+) -> Float[torch.Tensor, "*B C H W"]:
+    """
+    Applies the reverse ASC CDL transformation.
+
+    Parameters:
+        img (torch.Tensor): Transformed image tensor (*B, C, H, W)
+        slope (torch.Tensor): Slope per channel (*B, C)
+        offset (torch.Tensor): Offset per channel (*B, C)
+        power (torch.Tensor): Power per channel (*B, C)
+        saturation (torch.Tensor): Saturation factor (*B)
+        clamp (bool): If True, clamps the output to [0, 1].
+
+    Returns:
+        torch.Tensor: Recovered input image tensor.
+    """
+    # Warn if saturation, slope, or power is below 0
+    if (saturation < 0).any():
+        warnings.warn("Saturation is below 0, this will result in a color shift.")
+    if (slope < 0).any():
+        warnings.warn("Slope is below 0, this will result in a color shift.")
+    if (power < 0).any():
+        warnings.warn("Power is below 0, this will result in a color shift.")
+
+    img_batch_dim = img.shape[:-3]
+    # Check that slope, offset, power, and saturation share the image's batch
+    # dimensions; if they carry no batch dimensions, add singleton ones
+    if slope.ndim == 1:
+        slope = slope.view(*[1] * len(img_batch_dim), *slope.shape)
+    if offset.ndim == 1:
+        offset = offset.view(*[1] * len(img_batch_dim), *offset.shape)
+    if power.ndim == 1:
+        power = power.view(*[1] * len(img_batch_dim), *power.shape)
+    if saturation.ndim == 0:
+        saturation = saturation.view(*[1] * len(img_batch_dim), *saturation.shape)
+
+    # Now check that the batch ranks match
+    assert slope.ndim == len(img_batch_dim) + 1
+    assert offset.ndim == len(img_batch_dim) + 1
+    assert power.ndim == len(img_batch_dim) + 1
+    assert saturation.ndim == len(img_batch_dim)
+
+    # Inverse saturation adjustment
+    img = inverse_saturation(img, saturation)
+    # Inverse SOP adjustments
+    if clamp:
+        img = torch.clamp(img, 0.0, 1.0)
+    img = torch.where(
+        img > 1e-7, torch.pow(img, 1 / power.unsqueeze(-1).unsqueeze(-1)), img
+    )
+    img = (img - offset.unsqueeze(-1).unsqueeze(-1)) / slope.unsqueeze(-1).unsqueeze(-1)
+    # Clamp if specified
+    if clamp:
+        img = torch.clamp(img, 0.0, 1.0)
+    return img
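A round-trip sketch (not part of the commit): asc_cdl_reverse broadcasts unbatched per-channel parameters over the image's batch dimension, and with saturation = 1 it undoes the standard per-channel SOP grade out = (in * slope + offset) ** power exactly, provided the intermediate values stay inside [0, 1] (the clamps otherwise discard information). The import path is assumed, and the forward grade is written out inline rather than relying on the file's forward helper from earlier in this diff.

import torch

from src.utils.asc_cdl import asc_cdl_reverse  # assumed import path

img = torch.rand(4, 3, 16, 16)
slope = torch.tensor([0.9, 0.8, 0.85])    # (C,) tensor, broadcast to (*B, C)
offset = torch.tensor([0.05, 0.02, 0.03])
power = torch.tensor([1.2, 1.1, 0.9])
saturation = torch.tensor(1.0)            # scalar, broadcast to (*B,)

# Standard per-channel SOP grade, kept inside [0, 1] so it is invertible
graded = torch.clamp(
    img * slope.view(3, 1, 1) + offset.view(3, 1, 1), 0.0, 1.0
) ** power.view(3, 1, 1)

recovered = asc_cdl_reverse(graded, slope, offset, power, saturation)
assert torch.allclose(recovered, img, atol=1e-5)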
src/utils/color_space.py
ADDED
@@ -0,0 +1,80 @@
+import torch
+from jaxtyping import Float
+
+
+def srgb_to_linear(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
+    switch_val = 0.04045
+    return torch.where(
+        torch.greater(x, switch_val),
+        ((x.clip(min=switch_val) + 0.055) / 1.055).pow(2.4),
+        x / 12.92,
+    )
+
+
+def linear_to_srgb(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
+    switch_val = 0.0031308
+    return torch.where(
+        torch.greater(x, switch_val),
+        1.055 * x.clip(min=switch_val).pow(1.0 / 2.4) - 0.055,
+        x * 12.92,
+    )
+
+
+def rgb_to_lab(srgb: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
+    srgb_pixels = torch.reshape(srgb, [-1, 3])
+
+    linear_mask = srgb_pixels <= 0.04045
+    exponential_mask = srgb_pixels > 0.04045
+    rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (
+        ((srgb_pixels + 0.055) / 1.055) ** 2.4
+    ) * exponential_mask
+
+    rgb_to_xyz = (
+        torch.tensor(
+            [
+                #    X         Y         Z
+                [0.412453, 0.212671, 0.019334],  # R
+                [0.357580, 0.715160, 0.119193],  # G
+                [0.180423, 0.072169, 0.950227],  # B
+            ]
+        )
+        .to(srgb.dtype)
+        .to(srgb.device)
+    )
+
+    xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz)
+
+    xyz_normalized_pixels = torch.mul(
+        xyz_pixels,
+        torch.tensor([1 / 0.950456, 1.0, 1 / 1.088754]).to(srgb.dtype).to(srgb.device),
+    )
+
+    epsilon = 6.0 / 29.0
+    linear_mask = (xyz_normalized_pixels <= (epsilon**3)).to(srgb.dtype).to(srgb.device)
+
+    exponential_mask = (
+        (xyz_normalized_pixels > (epsilon**3)).to(srgb.dtype).to(srgb.device)
+    )
+
+    fxfyfz_pixels = (
+        xyz_normalized_pixels / (3 * epsilon**2) + 4.0 / 29.0
+    ) * linear_mask + (
+        (xyz_normalized_pixels + 0.000001) ** (1.0 / 3.0)
+    ) * exponential_mask
+
+    fxfyfz_to_lab = (
+        torch.tensor(
+            [
+                #   l      a      b
+                [0.0, 500.0, 0.0],  # fx
+                [116.0, -500.0, 200.0],  # fy
+                [0.0, 0.0, -200.0],  # fz
+            ]
+        )
+        .to(srgb.dtype)
+        .to(srgb.device)
+    )
+    lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + torch.tensor(
+        [-16.0, 0.0, 0.0]
+    ).to(srgb.dtype).to(srgb.device)
+    return torch.reshape(lab_pixels, srgb.shape)
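A sanity-check sketch (not part of the commit; the import path is an assumption): the two transfer functions are mutual inverses on [0, 1], and rgb_to_lab maps sRGB white to roughly L = 100, a = b = 0, since the conversion normalizes by the D65 white point (the 0.950456 and 1.088754 factors).

import torch

from src.utils.color_space import linear_to_srgb, rgb_to_lab, srgb_to_linear

x = torch.linspace(0.0, 1.0, 1001).unsqueeze(-1).expand(-1, 3)
assert torch.allclose(linear_to_srgb(srgb_to_linear(x)), x, atol=1e-5)

white = torch.ones(1, 3)
print(rgb_to_lab(white))  # approximately [[100., 0., 0.]]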
src/utils/image.py
ADDED
@@ -0,0 +1,33 @@
+import cv2
+import numpy as np
+import torch
+from jaxtyping import Float
+
+
+def read_img(path):
+    img = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+    if img.ndim == 3:
+        img = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2RGB)
+    elif img.ndim == 2:
+        img = img[..., np.newaxis]
+    dinfo = np.iinfo(img.dtype)
+    return (img.astype(np.float32) / dinfo.max) * 2 - 1
+
+
+def write_img(path: str, data: np.ndarray):
+    data = np.clip(data * 0.5 + 0.5, 0, 1)
+    if data.ndim == 3 and data.shape[-1] == 3:
+        data = cv2.cvtColor(data, cv2.COLOR_RGB2BGR)
+    elif data.ndim == 2:
+        data = data[..., np.newaxis]
+
+    data = (data * 255).astype(np.uint8)
+    cv2.imwrite(path, data)
+
+
+def to_torch(img: Float[np.ndarray, "H W C"]) -> Float[torch.Tensor, "C H W"]:
+    return torch.from_numpy(img).permute(2, 0, 1)
+
+
+def from_torch(img: Float[torch.Tensor, "C H W"]) -> Float[np.ndarray, "H W C"]:
+    return img.permute(1, 2, 0).detach().cpu().float().numpy()
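A usage sketch (not part of the commit) of the intended pipeline: read_img returns an (H, W, C) float32 array in [-1, 1], to_torch/from_torch convert between HWC numpy and CHW torch layouts, and write_img maps [-1, 1] back to [0, 1] and saves an 8-bit file. The import path is assumed; the example file is one of the repo's sample images. Note that read_img expects an integer-typed source (np.iinfo) and write_img always emits 8-bit output, even for 16-bit inputs.

from src.utils.image import from_torch, read_img, to_torch, write_img

img = read_img("example/guidance/lake.jpg")  # (H, W, 3) float32 in [-1, 1]
tensor = to_torch(img)                       # (3, H, W) for torch models
out = from_torch(tensor)                     # back to (H, W, 3) numpy
write_img("out.png", out)                    # clips to [0, 1], saves 8-bit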
src/utils/math.py
ADDED
@@ -0,0 +1,39 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+
+
+def _random_orthonormal_matrix(d: int, device: torch.device) -> torch.Tensor:
+    """Draw a random rotation matrix Q ∈ SO(d) (Haar) via QR factorisation."""
+    a = torch.randn(d, d, device=device)
+    # QR gives orthonormal columns; ensure right-handed
+    q, r = torch.linalg.qr(a, mode="reduced")
+    # make determinant +1 (special orthogonal) -> flip first column if needed
+    if torch.det(q) < 0:
+        q[:, 0] = -q[:, 0]
+    return q  # (d, d)
+
+
+def sobol_sphere(
+    n: int,
+    d: int,
+    device: torch.device,
+    sobol_engine: Optional[torch.quasirandom.SobolEngine] = None,
+) -> Tuple[torch.Tensor, torch.quasirandom.SobolEngine]:
+    """n unit vectors on S^{d-1} via scrambled Sobol + Gaussian + random rotation."""
+    if sobol_engine is None:
+        sob = torch.quasirandom.SobolEngine(dimension=d, scramble=True)
+    else:
+        sob = sobol_engine
+    # Draw in [0, 1)^d, then map to N(0, 1) via the inverse Gaussian CDF
+    u01 = sob.draw(n).to(device)
+
+    eps = 1e-7
+    u01 = u01.clamp(min=eps, max=1.0 - eps)  # avoid 0 and 1 exactly
+
+    z = torch.erfinv(2.0 * u01 - 1.0) * math.sqrt(2.0)  # inverse CDF of Normal
+    z = z / (z.norm(dim=1, keepdim=True) + 1e-8)  # project to sphere
+    # Random global rotation (RQMC) to make the estimator unbiased
+    Q = _random_orthonormal_matrix(d, device)
+    return z @ Q.T, sob  # (n, d)
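A usage sketch (not part of the commit; the import path is an assumption): quasi-random projection directions of the kind a sliced-Wasserstein loss such as src/loss/vector_swd.py would presumably consume. Passing the returned engine back in continues the Sobol sequence instead of restarting it.

import torch

from src.utils.math import sobol_sphere

device = torch.device("cpu")
dirs, engine = sobol_sphere(n=128, d=3, device=device)            # fresh engine
more, engine = sobol_sphere(64, 3, device, sobol_engine=engine)   # continues the sequence

# Each row is (approximately) a unit vector on S^2
assert torch.allclose(dirs.norm(dim=1), torch.ones(128), atol=1e-4)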