mboss committed
Commit 7349148 · 1 Parent(s): 93bac7e

Initial commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ # Output
13
+ output*/
14
+ out*/
15
+ out/
16
+ .gradio/
17
+ /*.png
18
+ /*.csv
LICENSE.md ADDED
@@ -0,0 +1,25 @@
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT Last Updated: July 5, 2024
2
+
3
+ I. INTRODUCTION
4
+
5
+ This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
6
+
7
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
8
+
9
+ By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
10
+
11
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
12
+
13
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
14
+
15
+ III. COMMERCIAL USE LICENSE
16
+
17
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations. If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
18
+
19
+ IV. GENERAL TERMS
20
+
21
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms. a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright Β© Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified. b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works). c. Intellectual Property. (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein. (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI. (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law. (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement. (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). 
You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback. d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement. g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
22
+
23
+ V. DEFINITIONS
24
+
25
+ "Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity. "Agreement" means this Stability AI Community License Agreement. "AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. "Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model. "Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models. "Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time. "Stability AI" or "we" means Stability AI Ltd. and its Affiliates. "Software" means Stability AI's proprietary software made available under this Agreement now or in the future. "Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement. "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED
@@ -1,12 +1,32 @@
1
  ---
2
  title: ReSWD
3
- emoji: 🌍
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.47.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: ReSWD
3
+ emoji: πŸ“Š
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.47.1
8
  app_file: app.py
9
  pinned: false
10
+ models:
11
+ - stabilityai/stable-diffusion-3.5-large-turbo
12
+ license: other
13
+ license_name: stabilityai-ai-community
14
+ license_link: LICENSE.md
15
  ---
16
 
17
+ # ReSWD: ReSTIR'd, not shaken. Combining Reservoir Sampling and Sliced Wasserstein Distance for Variance Reduction.
18
+
19
+ <a href="https://reservoirswd.github.io/"><img src="https://img.shields.io/badge/Project%20Page-5CE1BC.svg"></a> <a href="https://reservoirswd.github.io/static/paper.pdf"><img src="https://img.shields.io/badge/Arxiv-2408.00653-B31B1B.svg"></a> <a href="https://huggingface.co/spaces/stabilityai/reswd"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a>
20
+
21
+ This is the official codebase for **ReSWD**, a state-of-the-art algorithm for distribution matching with reduced variance. It has several applications, such as diffusion guidance and color matching.
22
+
23
+ ## Citation
24
+
25
+ ```BibTeX
26
+ @article{boss2025reswd,
27
+ title={ReSWD: ReSTIR'd, not shaken. Combining Reservoir Sampling and Sliced Wasserstein Distance for Variance Reduction.},
28
+ author={Boss, Mark and Engelhardt, Andreas and DonnΓ©, Simon and Jampani, Varun},
29
+ journal={arXiv preprint},
30
+ year={2025}
31
+ }
32
+ ```
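
For context on what "distribution matching" means here, the sketch below shows the plain Monte Carlo sliced Wasserstein distance (SWD) estimator that ReSWD sets out to de-noise: project both point sets onto random unit directions, sort the 1-D projections, and average the per-slice transport cost. This is an illustrative baseline only (it assumes two equally sized point sets and is not part of this commit); the repository's estimator, `VectorSWDLoss` in `src/loss/vector_swd.py` below, extends it with reservoir-reused projection directions and optional control variates.

```python
import torch

def sliced_wasserstein(x: torch.Tensor, y: torch.Tensor, num_proj: int = 64, p: int = 1) -> torch.Tensor:
    """Plain Monte Carlo SWD between point sets x, y of shape (N, D)."""
    d = x.shape[1]
    dirs = torch.randn(num_proj, d, device=x.device)
    dirs = dirs / dirs.norm(dim=1, keepdim=True)    # random unit-length slice directions
    px = (x @ dirs.t()).sort(dim=0).values          # (N, K) sorted 1-D projections
    py = (y @ dirs.t()).sort(dim=0).values
    cost = (px - py).abs() if p == 1 else (px - py).pow(2)
    return cost.mean()                              # average over samples and slices
```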
app.py ADDED
@@ -0,0 +1,25 @@
1
+ import gradio as gr
2
+ import lightning as L
3
+
4
+ from src.gradio_demo.color_matching import create_color_matching
5
+ from src.gradio_demo.sw_guidance import create_sw_guidance
6
+
7
+ with gr.Blocks() as demo:
8
+ gr.Markdown(
9
+ """
10
+ # ReSWD
11
+
12
+ ReSTIR'd, not shaken. Combining Reservoir Sampling and Sliced Wasserstein
13
+ Distance for Variance Reduction.
14
+
15
+ ReSWD is a method for distribution matching with reduced variance.
16
+ """
17
+ )
18
+ fabric = L.Fabric(devices=1, accelerator="auto", precision="16-mixed")
19
+
20
+ with gr.Tab("SW Guidance (SD 3.5 Large Turbo)"):
21
+ create_sw_guidance(fabric, "stabilityai/stable-diffusion-3.5-large-turbo")
22
+ with gr.Tab("Color Matching"):
23
+ create_color_matching(fabric)
24
+
25
+ demo.launch()
example/color_matching/cutting_fruit.jpg ADDED

Git LFS Details

  • SHA256: 1ca550dfa59f2c5c6fd39d097a86edb5ec703cfce7ef5f604f308860fbc9f3db
  • Pointer size: 131 Bytes
  • Size of remote file: 975 kB
example/color_matching/field.jpg ADDED

Git LFS Details

  • SHA256: 95c87921c33e5ca2fe96874befaa74038a416f65f4eb503f64f51fd3bd6e088a
  • Pointer size: 131 Bytes
  • Size of remote file: 682 kB
example/color_matching/fruit_stand.jpg ADDED

Git LFS Details

  • SHA256: 088d65eee1da1b16e233f99bf3e64f67f0be058dd78dd53de0018ba138a41a89
  • Pointer size: 131 Bytes
  • Size of remote file: 841 kB
example/color_matching/portrait.jpg ADDED

Git LFS Details

  • SHA256: a9a1b84d0f5358522019e12ce64084a29ff9fee8ee6f01b60bbdcefa3b2fb70f
  • Pointer size: 131 Bytes
  • Size of remote file: 557 kB
example/guidance/arch.jpg ADDED

Git LFS Details

  • SHA256: 20983f0bd28bb783e2619be9e999be025b3f903e73f19c6f9e9e0d2c889475c5
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
example/guidance/blue-cast.jpg ADDED

Git LFS Details

  • SHA256: 02b033a744a99f9c9d10ef06269d513f9caaf6395f537e78e508ef580c8503b8
  • Pointer size: 130 Bytes
  • Size of remote file: 26.6 kB
example/guidance/boat-gray.jpg ADDED

Git LFS Details

  • SHA256: 96c89af97e5e66da3e9bd6e8cbc6c5ac2fdb2e1908932bdd309f977a87b6357b
  • Pointer size: 130 Bytes
  • Size of remote file: 32.8 kB
example/guidance/building.jpg ADDED

Git LFS Details

  • SHA256: ed437eb31b43feb351b19f54d31fd87eebf467458cd0a207250bf15f442f8907
  • Pointer size: 131 Bytes
  • Size of remote file: 292 kB
example/guidance/canyon.jpg ADDED

Git LFS Details

  • SHA256: 0e742e327d3c138d56f920e60f96e1380cf4080626a5214e50515611b5a79465
  • Pointer size: 131 Bytes
  • Size of remote file: 408 kB
example/guidance/colorful-buildings.jpg ADDED

Git LFS Details

  • SHA256: e52d68f3efe1a7b0077cf52dddb587542a587e249b2b5adf3e12563de8187c64
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
example/guidance/dark-sky.jpg ADDED

Git LFS Details

  • SHA256: 326fd77963438ed4b85ee9ca12e5ba9fea4f92ecbce78f5942c66e43458814e2
  • Pointer size: 130 Bytes
  • Size of remote file: 20.4 kB
example/guidance/fisher.jpg ADDED

Git LFS Details

  • SHA256: 632a87c65b6703a41c525ad41582b01e085e561966e9ef06f666412a5228299d
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
example/guidance/food_stand.jpg ADDED

Git LFS Details

  • SHA256: 1a79fbd456ac21b068f1d486f71c489c3d1d194712417eacef3b8fd27fa91b4a
  • Pointer size: 130 Bytes
  • Size of remote file: 87.5 kB
example/guidance/gray-power.jpg ADDED

Git LFS Details

  • SHA256: 17cc8d1d5e4f95b47a430dd907875e902b412b74132fe5018b280b92deafa6f5
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
example/guidance/greenhouse_2.jpg ADDED

Git LFS Details

  • SHA256: 34554564d7e395cff2aa6a2645bb34f9822916977582f7db0b0e7d3b6b9dec9e
  • Pointer size: 130 Bytes
  • Size of remote file: 99.3 kB
example/guidance/industrial.jpg ADDED

Git LFS Details

  • SHA256: 8cd9c78fdb31d378e1b54f4c211aeeaf8e5f95c2ed772f2fd03062a4de237fd7
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
example/guidance/lake.jpg ADDED

Git LFS Details

  • SHA256: c8672eb27358786075c9ee6cbd9ced813fc3927d847a6e8a532b5447f6b9d331
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
example/guidance/lake_green.jpg ADDED

Git LFS Details

  • SHA256: ff40e518b6ffe4dab79e76903424cab249d3a3ec0b5e6071d17763c6b12b2d32
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
example/guidance/lake_sunset.jpg ADDED

Git LFS Details

  • SHA256: 3d908f24b2f4eda01583d09386ec48e2da854f9fe677a66d8abcab1cc1eead04
  • Pointer size: 130 Bytes
  • Size of remote file: 60.9 kB
example/guidance/mountain.jpg ADDED

Git LFS Details

  • SHA256: 031356e0edc7b2ebc339b01b398ff8258f987a73aa392d6374ecc1af880537a0
  • Pointer size: 130 Bytes
  • Size of remote file: 93.3 kB
example/guidance/ornament.jpg ADDED

Git LFS Details

  • SHA256: 71a2f6efbc7eab1745731e1d49bb29454068b08a3fe118dd48069391712277bf
  • Pointer size: 131 Bytes
  • Size of remote file: 215 kB
example/guidance/path.jpg ADDED

Git LFS Details

  • SHA256: 2ca87c111a3e6f2450ac5896abbccaaff6abec928286a60b45dcf5e2777264b2
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
example/guidance/sky_pier.jpg ADDED

Git LFS Details

  • SHA256: 9773fb6b512f53121ea5ef57b8766ced0b254d058395c6d6de63a8c4220f8e62
  • Pointer size: 130 Bytes
  • Size of remote file: 58.9 kB
example/guidance/snow.jpg ADDED

Git LFS Details

  • SHA256: 8ace5984603adca5d2853bad57f75ec62d9c482d3136c36a8b6156ff5fbaad2e
  • Pointer size: 131 Bytes
  • Size of remote file: 167 kB
example/guidance/sunken_boat.jpg ADDED

Git LFS Details

  • SHA256: 4dbbfc0220becb16e364b51a9063a3c8707295a8726df4103b1355d49cefecd9
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
example/guidance/waterfall.jpg ADDED

Git LFS Details

  • SHA256: 6b424a96166a2f333d237c403e676323e4574137763c9ba09a0370e5bf63c992
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ accelerate>=1.7.1
2
+ diffusers>=0.33.1
3
+ imageio>=2.37.0
4
+ jaxtyping>=0.3.2
5
+ lightning>=2.5.1.post0
6
+ lxml>=5.4.0
7
+ matplotlib>=3.10.3
8
+ numpy>=2.2.5
9
+ opencv-python>=4.11.0.86
10
+ pillow>=11.2.1
11
+ protobuf>=6.32.1
12
+ scipy>=1.15.3
13
+ sentencepiece>=0.2.1
14
+ torch>=2.7.0
15
+ torchvision>=0.22.0
16
+ tqdm>=4.67.1
17
+ transformers>=4.52.4
src/color_matcher.py ADDED
@@ -0,0 +1,205 @@
1
+ import gc
2
+ import os
3
+ from typing import List, Literal, Optional, Tuple
4
+
5
+ import imageio
6
+ import lightning as L
7
+ import torch
8
+ from jaxtyping import Float
9
+ from torchvision.transforms import Resize
10
+ from tqdm import tqdm
11
+
12
+ from src.loss import AbstractLoss
13
+ from src.loss.vector_swd import VectorSWDLoss
14
+ from src.utils.asc_cdl import asc_cdl_forward, save_asc_cdl
15
+ from src.utils.color_space import rgb_to_lab
16
+ from src.utils.image import from_torch, read_img, to_torch, write_img
17
+
18
+
19
+ class CDL(torch.nn.Module):
20
+ def __init__(self, batch_size: int):
21
+ super().__init__()
22
+ self.cdl_slope = torch.nn.Parameter(torch.ones(batch_size, 3))
23
+ self.cdl_offset = torch.nn.Parameter(torch.zeros(batch_size, 3))
24
+ self.cdl_power = torch.nn.Parameter(torch.ones(batch_size, 3))
25
+ self.cdl_saturation = torch.nn.Parameter(torch.ones(batch_size))
26
+
27
+ def forward(
28
+ self, x: Float[torch.Tensor, "*B C H W"]
29
+ ) -> Float[torch.Tensor, "*B C H W"]:
30
+ return asc_cdl_forward(
31
+ x, self.cdl_slope, self.cdl_offset, self.cdl_power, self.cdl_saturation
32
+ )
33
+
34
+ def to_cdl_xml(self) -> List[str]:
35
+ ret = []
36
+ for b in range(self.cdl_slope.shape[0]):
37
+ ret.append(
38
+ save_asc_cdl(
39
+ {
40
+ "slope": self.cdl_slope[b],
41
+ "offset": self.cdl_offset[b],
42
+ "power": self.cdl_power[b],
43
+ "saturation": self.cdl_saturation[b],
44
+ },
45
+ None,
46
+ )
47
+ )
48
+ return ret
49
+
50
+ def save(self, path: str):
51
+ for b in range(self.cdl_slope.shape[0]):
52
+ save_asc_cdl(
53
+ {
54
+ "slope": self.cdl_slope[b],
55
+ "offset": self.cdl_offset[b],
56
+ "power": self.cdl_power[b],
57
+ "saturation": self.cdl_saturation[b],
58
+ },
59
+ os.path.join(path, f"cdl_{b}.xml"),
60
+ )
61
+
62
+
63
+ def train(
64
+ fabric: L.Fabric,
65
+ criteria: AbstractLoss,
66
+ source_img: Float[torch.Tensor, "B C H W"],
67
+ target_img: Float[torch.Tensor, "B C H W"],
68
+ num_steps: int,
69
+ lr: float,
70
+ match_resolution: int,
71
+ silent: bool = False,
72
+ write_video_animation_path: Optional[str] = None,
73
+ ) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
74
+ criteria = fabric.setup(criteria)
75
+
76
+ source_max_res = Resize(match_resolution, antialias=True)(source_img)
77
+ target_max_res = Resize(match_resolution, antialias=True)(target_img)
78
+
79
+ target_cielab = (
80
+ fabric.to_device(rgb_to_lab(target_max_res).permute(0, 3, 1, 2))
81
+ .permute(0, 2, 3, 1)
82
+ .contiguous()
83
+ )
84
+
85
+ source_max_res = fabric.to_device(source_max_res)
86
+ source_img = fabric.to_device(source_img)
87
+
88
+ batch_size = source_img.shape[0]
89
+ cdl = CDL(batch_size)
90
+
91
+ optim = torch.optim.Adam(cdl.parameters(), lr=lr)
92
+ cdl, optim = fabric.setup(cdl, optim)
93
+
94
+ losses = []
95
+ for i in tqdm(range(num_steps), disable=silent):
96
+ optim.zero_grad(set_to_none=True)
97
+
98
+ cdl_source = cdl(source_max_res)
99
+ source_cielab = (
100
+ rgb_to_lab(cdl_source.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
101
+ )
102
+
103
+ loss = criteria(
104
+ source_cielab.view(source_cielab.shape[0], source_cielab.shape[1], -1),
105
+ target_cielab.view(target_cielab.shape[0], target_cielab.shape[1], -1),
106
+ i,
107
+ )
108
+
109
+ fabric.backward(loss)
110
+ optim.step()
111
+
112
+ losses.append(loss.item())
113
+
114
+ if write_video_animation_path is not None:
115
+ write_img(
116
+ os.path.join(write_video_animation_path, f"{i:05d}.jpg"),
117
+ from_torch(cdl(source_img).squeeze(0) * 2 - 1),
118
+ )
119
+
120
+ source_full_res_cdl = cdl(source_img)
121
+
122
+ gc.collect()
123
+ torch.cuda.empty_cache()
124
+
125
+ return source_full_res_cdl, cdl, losses
126
+
127
+
128
+ def run(
129
+ save_dir: str,
130
+ source_img: List[str],
131
+ target_img: List[str],
132
+ matching_resolution: int,
133
+ precision: Literal["32-true", "16-mixed"] = "16-mixed",
134
+ num_projections: int = 64,
135
+ lr: float = 0.01,
136
+ steps: int = 300,
137
+ use_ucv: bool = False,
138
+ use_lcv: bool = False,
139
+ distance: Literal["l1", "l2"] = "l1",
140
+ refresh_projections_every_n_steps: int = 1,
141
+ num_new_candidates: int = 32,
142
+ sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
143
+ write_video: bool = False,
144
+ **kwargs,
145
+ ):
146
+ fabric = L.Fabric(devices=1, accelerator="auto", precision=precision)
147
+
148
+ source_imgs = torch.stack(
149
+ [to_torch(read_img(s)) * 0.5 + 0.5 for s in source_img], dim=0
150
+ )
151
+ target_imgs = torch.stack(
152
+ [to_torch(read_img(t)) * 0.5 + 0.5 for t in target_img], dim=0
153
+ )
154
+
155
+ criteria = VectorSWDLoss(
156
+ num_proj=num_projections,
157
+ distance=distance,
158
+ use_ucv=use_ucv,
159
+ use_lcv=use_lcv,
160
+ refresh_projections_every_n_steps=refresh_projections_every_n_steps,
161
+ num_new_candidates=num_new_candidates,
162
+ sampling_mode=sampling_mode,
163
+ )
164
+
165
+ os.makedirs(save_dir, exist_ok=True)
166
+ animation_dir = os.path.join(save_dir, "animation")
167
+
168
+ if write_video:
169
+ os.makedirs(animation_dir, exist_ok=True)
170
+
171
+ source_full_res_cdl, cdl, losses = train(
172
+ fabric,
173
+ criteria,
174
+ source_imgs,
175
+ target_imgs,
176
+ steps,
177
+ lr,
178
+ matching_resolution,
179
+ write_video_animation_path=animation_dir if write_video else None,
180
+ )
181
+
182
+ cdl.save(save_dir)
183
+
184
+ for i, img in enumerate(source_full_res_cdl):
185
+ write_img(
186
+ os.path.join(save_dir, f"color_matched_{i}.png"),
187
+ from_torch(img * 2 - 1),
188
+ )
189
+
190
+ if write_video:
191
+ # Get the list of image files in the animation directory
192
+ image_files = [f for f in os.listdir(animation_dir) if f.endswith(".jpg")]
193
+ image_files.sort(
194
+ key=lambda x: int(x.split(".")[0])
195
+ ) # Ensure they are in the correct order
196
+
197
+ # Create a video from the images
198
+ with imageio.get_writer(
199
+ os.path.join(save_dir, "animation.mp4"), fps=30, codec="libx264"
200
+ ) as writer:
201
+ for image_file in image_files:
202
+ image = imageio.imread(os.path.join(animation_dir, image_file))
203
+ writer.append_data(image)
204
+
205
+ return source_full_res_cdl, cdl, losses
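
The `CDL` module above delegates to `asc_cdl_forward` from `src/utils/asc_cdl.py`, which is not part of this commit. For readers unfamiliar with the colour decision list it fits, the sketch below shows the standard ASC CDL transform (slope/offset/power per channel, followed by a saturation blend against Rec. 709 luma); it is only an assumption about what that helper computes, and the repository's version may differ in tensor layout and clamping details.

```python
import torch

def asc_cdl_reference(rgb: torch.Tensor, slope, offset, power, saturation) -> torch.Tensor:
    """Standard ASC CDL on an (..., 3) RGB tensor in [0, 1] (illustrative only)."""
    sop = (rgb * slope + offset).clamp(min=0.0) ** power        # slope, offset, power per channel
    weights = torch.tensor([0.2126, 0.7152, 0.0722]).to(rgb)    # Rec. 709 luma weights
    luma = (sop * weights).sum(dim=-1, keepdim=True)
    return luma + saturation * (sop - luma)                     # saturation blend

```

With `slope = 1`, `offset = 0`, `power = 1`, and `saturation = 1` (the initial values of the `CDL` parameters above), this transform is the identity, so optimization starts from the unmodified source image.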
src/gradio_demo/color_matching.py ADDED
@@ -0,0 +1,164 @@
1
+ import os
2
+
3
+ import gradio as gr
4
+ import lightning as L
5
+ import numpy as np
6
+ import spaces
7
+
8
+ from src.color_matcher import train
9
+ from src.loss import VectorSWDLoss
10
+ from src.utils.image import from_torch, to_torch
11
+
12
+
13
+ def create_color_matching(fabric: L.Fabric):
14
+ """
15
+ Creates the Gradio interface for color matching between source and target images.
16
+ """
17
+ gr.Markdown(
18
+ """
19
+ # Color Matching
20
+ Matches the color of source images to target images using ASC CDL parameters.
21
+ """
22
+ )
23
+
24
+ gr.Markdown("## Source Images")
25
+ with gr.Row(variant="compact"):
26
+ source_image = gr.Image(label="Source Image", height=512)
27
+ target_image = gr.Image(label="Target Image", height=512)
28
+ with gr.Row(variant="compact"):
29
+ example_paths = os.path.join("example", "color_matching")
30
+ # Find all images in the example directory
31
+ example_images = [
32
+ os.path.join(example_paths, f)
33
+ for f in os.listdir(example_paths)
34
+ if os.path.isfile(os.path.join(example_paths, f))
35
+ and os.path.splitext(f)[-1] in [".png", ".jpg", ".jpeg"]
36
+ ]
37
+
38
+ gr.Examples(
39
+ examples=example_images,
40
+ inputs=[source_image],
41
+ )
42
+ gr.Examples(
43
+ examples=example_images,
44
+ inputs=[target_image],
45
+ )
46
+
47
+ gr.Markdown(
48
+ """
49
+ # Configuration
50
+ Adjust the parameters for the color matching process.
51
+ """
52
+ )
53
+
54
+ with gr.Accordion("Advanced Config", open=False):
55
+ learning_rate_slider = gr.Slider(
56
+ 1e-4, 0.1, value=1e-3, step=1e-4, label="Learning rate"
57
+ )
58
+ match_resolution_slider = gr.Slider(
59
+ 64, 1024, value=128, step=64, label="Match resolution"
60
+ )
61
+ num_steps_slider = gr.Slider(
62
+ 50, 500, value=150, step=25, label="Number of optimization steps"
63
+ )
64
+ control_variates_dropdown = gr.Dropdown(
65
+ choices=["None", "Lower", "Upper"],
66
+ value="None",
67
+ label="Control Variates",
68
+ info="Select which control variates to use for optimization",
69
+ )
70
+ candidates_per_pass_slider = gr.Slider(
71
+ 0,
72
+ 64,
73
+ value=16,
74
+ step=1,
75
+ label="Number of new candidates per pass.",
76
+ info=(
77
+ "The number of new candidates to generate per pass in the reservoir "
78
+ "sampling."
79
+ ),
80
+ )
81
+ num_projections_slider = gr.Slider(
82
+ 16, 1024, value=64, step=16, label="Number of projections"
83
+ )
84
+ sampling_mode_dropdown = gr.Dropdown(
85
+ choices=["gaussian", "qmc"],
86
+ value="gaussian",
87
+ label="Sampling Mode",
88
+ info="Select which sampling mode to use for projections",
89
+ )
90
+
91
+ run_button = gr.Button("Run Color Matching", variant="primary")
92
+
93
+ with gr.Column(variant="compact"):
94
+ output_image = gr.Image(label="Color Matched Image", height=512)
95
+ output_cdl = gr.Textbox(label="CDL Parameters")
96
+
97
+ @spaces.GPU
98
+ def run_color_matching(
99
+ match_resolution: int,
100
+ num_steps: int,
101
+ control_variates: str,
102
+ num_projections: int,
103
+ candidates_per_pass: int,
104
+ sampling_mode: str,
105
+ learning_rate: float,
106
+ *images,
107
+ ):
108
+ """
109
+ Runs the color matching process between source and target images.
110
+ """
111
+ # Split images into source and target pair
112
+ source_img = images[0] / 255.0
113
+ target_img = images[1] / 255.0
114
+
115
+ # Convert images to tensors
116
+ source_tensors = to_torch(source_img).float().unsqueeze(0)
117
+ target_tensors = to_torch(target_img).float().unsqueeze(0)
118
+
119
+ # Configure loss function
120
+ criteria = VectorSWDLoss(
121
+ num_proj=num_projections,
122
+ use_ucv=control_variates == "Upper",
123
+ use_lcv=control_variates == "Lower",
124
+ num_new_candidates=candidates_per_pass,
125
+ sampling_mode=sampling_mode,
126
+ )
127
+
128
+ # Run color matching
129
+ source_matched, cdl, losses = train(
130
+ fabric=fabric,
131
+ criteria=criteria,
132
+ source_img=source_tensors,
133
+ target_img=target_tensors,
134
+ num_steps=num_steps,
135
+ lr=learning_rate,
136
+ match_resolution=match_resolution,
137
+ )
138
+
139
+ return [
140
+ (from_torch(source_matched.squeeze(0)) * 255).astype(np.uint8),
141
+ cdl.to_cdl_xml()[0],
142
+ ]
143
+
144
+ run_button.click(
145
+ run_color_matching,
146
+ inputs=[
147
+ match_resolution_slider,
148
+ num_steps_slider,
149
+ control_variates_dropdown,
150
+ num_projections_slider,
151
+ candidates_per_pass_slider,
152
+ sampling_mode_dropdown,
153
+ learning_rate_slider,
154
+ source_image,
155
+ target_image,
156
+ ],
157
+ outputs=[output_image, output_cdl],
158
+ )
159
+
160
+ clear_button = gr.ClearButton(
161
+ [source_image, target_image, output_image, output_cdl]
162
+ )
163
+
164
+ clear_button.click(lambda: None)
src/gradio_demo/sw_guidance.py ADDED
@@ -0,0 +1,336 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import gradio as gr
5
+ import lightning as L
6
+ import numpy as np
7
+ import spaces
8
+ from PIL import Image
9
+
10
+ from src.sw_sdthree_guidance import create_pipeline
11
+ from src.sw_sdthree_guidance import run as sw_guidance_run
12
+
13
+ preset_lrs = [1e-6, 1.0]
14
+
15
+ log_lrs = np.log10(preset_lrs)
16
+
17
+ models = {
18
+ "stabilityai/stable-diffusion-3.5-large": {
19
+ "num_inference_steps": 30,
20
+ "guidance_scale": 7.0,
21
+ "sw_u_lr": np.log10(3.2e-3),
22
+ "sw_steps": 6,
23
+ "cfg_rescale_phi": 0.7,
24
+ },
25
+ "stabilityai/stable-diffusion-3.5-medium": {
26
+ "num_inference_steps": 30,
27
+ "guidance_scale": 3.5,
28
+ "sw_u_lr": np.log10(3.2e-3),
29
+ "sw_steps": 6,
30
+ "cfg_rescale_phi": 0.7,
31
+ },
32
+ "stabilityai/stable-diffusion-3.5-large-turbo": {
33
+ "num_inference_steps": 4,
34
+ "guidance_scale": 1.0,
35
+ "sw_u_lr": np.log10(4e-3),
36
+ "sw_steps": 6,
37
+ "cfg_rescale_phi": 0.65,
38
+ },
39
+ }
40
+
41
+
42
+ def log_slider_to_lr(log_lr):
43
+ return float(f"{10**log_lr:.1e}")
44
+
45
+
46
+ def create_sw_guidance(
47
+ fabric: L.Fabric, model_name: str = "stabilityai/stable-diffusion-3.5-large"
48
+ ):
49
+ """
50
+ Creates the Gradio interface for SW guidance with SD3.5.
51
+
52
+ Args:
53
+ fabric: Lightning Fabric instance
54
+ model_name: The model to use for guidance
55
+ """
56
+ gr.Markdown(
57
+ f"""
58
+ # SW Guidance with {model_name.split('/')[-1]}
59
+ Generates images using SW Guidance with a reference image and text prompt.
60
+ """
61
+ )
62
+
63
+ pipe = create_pipeline(
64
+ model_name,
65
+ device="cuda",
66
+ compile=True,
67
+ )
68
+
69
+ model_config = models[model_name]
70
+
71
+ @spaces.GPU
72
+ def run_sw_guidance(
73
+ num_inference_steps: int,
74
+ num_guided_steps_perc: float,
75
+ guidance_scale: float,
76
+ sw_u_lr: float,
77
+ sw_steps: int,
78
+ num_projections: int,
79
+ control_variates: str,
80
+ distance: str,
81
+ candidates_per_pass: int,
82
+ subsampling_factor: int,
83
+ sampling_mode: str,
84
+ cfg_rescale_phi: float,
85
+ prompt: str,
86
+ reference_image: np.ndarray,
87
+ seed: Optional[int] = None,
88
+ ):
89
+ """
90
+ Runs the SW guidance process with the given parameters.
91
+ """
92
+ if reference_image is None:
93
+ raise gr.Error("Please provide a reference image")
94
+ if not prompt:
95
+ raise gr.Error("Please provide a prompt")
96
+
97
+ # Convert numpy array to PIL Image
98
+ ref_img = Image.fromarray(reference_image)
99
+
100
+ # Run SW guidance
101
+ image = sw_guidance_run(
102
+ prompt=prompt,
103
+ reference_image=ref_img,
104
+ model_path=model_name,
105
+ num_inference_steps=num_inference_steps,
106
+ num_guided_steps=int(num_guided_steps_perc * num_inference_steps),
107
+ guidance_scale=guidance_scale,
108
+ sw_u_lr=log_slider_to_lr(sw_u_lr),
109
+ sw_steps=sw_steps,
110
+ height=1024,
111
+ width=1024,
112
+ device="cuda",
113
+ num_projections=num_projections,
114
+ use_ucv=control_variates == "Upper",
115
+ use_lcv=control_variates == "Lower",
116
+ distance=distance,
117
+ num_new_candidates=candidates_per_pass,
118
+ subsampling_factor=subsampling_factor,
119
+ sampling_mode=sampling_mode,
120
+ pipe=pipe,
121
+ compile=True,
122
+ seed=seed,
123
+ cfg_rescale_phi=cfg_rescale_phi,
124
+ )
125
+
126
+ return np.array(image)
127
+
128
+ gr.Markdown("## Input")
129
+ with gr.Row(equal_height=True):
130
+ with gr.Column(variant="panel"):
131
+ prompt = gr.Textbox(
132
+ label="Prompt",
133
+ placeholder="Enter your prompt here...",
134
+ lines=2,
135
+ )
136
+ reference_image = gr.Image(label="Reference Image", height=512)
137
+
138
+ with gr.Column(variant="panel"):
139
+ output_image = gr.Image(label="Generated Image", height=512)
140
+
141
+ example_pairs = [
142
+ ("A diver discovering an underwater city", "sunken_boat.jpg"),
143
+ ("A raccoon in a forest", "waterfall.jpg"),
144
+ ("A cat detective solving a mystery", "building.jpg"),
145
+ ("A raccoon reading a book by candlelight.", "food_stand.jpg"),
146
+ ("A lion meditating on a mountain", "mountain.jpg"),
147
+ ("A squirrel kayaking", "lake_green.jpg"),
148
+ ("A young dragon roasting marshmallows", "canyon.jpg"),
149
+ ("A family picnic beneath floating lanterns", "lake_sunset.jpg"),
150
+ ("Elephants holding umbrellas in a rainstorm", "fisher.jpg"),
151
+ ("A boy and his dog exploring a crystal cave", "lake.jpg"),
152
+ ("A snowman sharing cocoa with woodland animals", "snow.jpg"),
153
+ ("A kitten exploring an antique library", "ornament.jpg"),
154
+ ("Children riding flying bicycles over mountains", "sky_pier.jpg"),
155
+ ("An ancient tree whispering stories to deer", "greenhouse_2.jpg"),
156
+ ("Owls with monocles in treetops", "path.jpg"),
157
+ ]
158
+
159
+ default_config = {
160
+ "num_guided_steps_perc": 0.95,
161
+ "num_projections": 32,
162
+ "control_variates": "None",
163
+ "distance": "l1",
164
+ "candidates_per_pass": 8,
165
+ "subsampling_factor": 1,
166
+ "sampling_mode": "gaussian",
167
+ } | model_config
168
+
169
+ def run_example(prompt: str, reference_image: np.ndarray):
170
+ return run_sw_guidance(
171
+ **default_config,
172
+ prompt=prompt,
173
+ reference_image=reference_image,
174
+ )
175
+
176
+ example_inputs = [
177
+ [prompt, os.path.join("example", "guidance", img_file)]
178
+ for prompt, img_file in example_pairs
179
+ ]
180
+
181
+ run_button = gr.Button("Generate Image", variant="primary")
182
+
183
+ gr.Examples(
184
+ examples=example_inputs,
185
+ inputs=[prompt, reference_image],
186
+ outputs=[output_image],
187
+ fn=run_example,
188
+ label="Prompt + Reference Image Examples",
189
+ examples_per_page=5,
190
+ cache_examples=True,
191
+ cache_mode="lazy",
192
+ )
193
+
194
+ gr.Markdown(
195
+ """
196
+ # Configuration
197
+ Adjust the parameters for the SW guidance process.
198
+ """
199
+ )
200
+
201
+ with gr.Accordion("Basic Config", open=True):
202
+ num_inference_steps_slider = gr.Slider(
203
+ 1,
204
+ 100,
205
+ value=model_config["num_inference_steps"],
206
+ step=1,
207
+ label="Number of inference steps",
208
+ )
209
+
210
+ guidance_scale_slider = gr.Slider(
211
+ 0.0,
212
+ 20.0,
213
+ value=model_config["guidance_scale"],
214
+ step=0.1,
215
+ label="Guidance scale",
216
+ )
217
+
218
+ cfg_rescale_phi_slider = gr.Slider(
219
+ 0.0,
220
+ 1.0,
221
+ value=model_config["cfg_rescale_phi"],
222
+ step=0.05,
223
+ label="CFG Rescale Phi",
224
+ info="Controls the rescaling of classifier-free guidance",
225
+ )
226
+
227
+ seed_input = gr.Number(
228
+ value=lambda: None,
229
+ label="Seed (leave empty for random)",
230
+ precision=0,
231
+ interactive=True,
232
+ )
233
+
234
+ with gr.Row(variant="panel", equal_height=True):
235
+ sw_u_lr_slider = gr.Slider(
236
+ minimum=log_lrs.min(),
237
+ maximum=log_lrs.max(),
238
+ value=model_config["sw_u_lr"],
239
+ step=0.05,
240
+ label="SW guidance learning rate (Log scale)",
241
+ interactive=True,
242
+ scale=4,
243
+ )
244
+ lr_display = gr.Textbox(
245
+ label="Learning Rate",
246
+ value=f"{log_slider_to_lr(model_config['sw_u_lr']):.1e}",
247
+ interactive=False,
248
+ scale=1,
249
+ )
250
+ sw_u_lr_slider.change(
251
+ lambda x: gr.update(value=f"{log_slider_to_lr(x):.1e}"),
252
+ inputs=sw_u_lr_slider,
253
+ outputs=lr_display,
254
+ show_progress=False,
255
+ )
256
+
257
+ with gr.Accordion("Advanced Config", open=False):
258
+ num_guided_steps_perc_slider = gr.Slider(
259
+ 0.0,
260
+ 1.0,
261
+ value=default_config["num_guided_steps_perc"],
262
+ step=0.05,
263
+ label="Percentage of steps to apply SW guidance",
264
+ )
265
+ sw_steps_slider = gr.Slider(
266
+ 0,
267
+ 32,
268
+ value=model_config["sw_steps"],
269
+ step=1,
270
+ label="Number of SW guidance steps, 0 means no SW guidance",
271
+ )
272
+ num_projections_slider = gr.Slider(
273
+ 16,
274
+ 1024,
275
+ value=default_config["num_projections"],
276
+ step=16,
277
+ label="Number of projections",
278
+ )
279
+ control_variates_dropdown = gr.Dropdown(
280
+ choices=["None", "Lower", "Upper"],
281
+ value=default_config["control_variates"],
282
+ label="Control Variates",
283
+ info="Select which control variates to use for optimization",
284
+ )
285
+ distance_dropdown = gr.Dropdown(
286
+ choices=["l1", "l2"],
287
+ value=default_config["distance"],
288
+ label="Distance metric",
289
+ info="Select which distance metric to use",
290
+ )
291
+ candidates_per_pass_slider = gr.Slider(
292
+ 0,
293
+ 64,
294
+ value=default_config["candidates_per_pass"],
295
+ step=1,
296
+ label="Number of new candidates per pass. 0 means no reservoir sampling",
297
+ )
298
+ subsampling_factor_slider = gr.Slider(
299
+ 1,
300
+ 16,
301
+ value=default_config["subsampling_factor"],
302
+ step=1,
303
+ label="Subsampling factor",
304
+ )
305
+ sampling_mode_dropdown = gr.Dropdown(
306
+ choices=["gaussian", "qmc"],
307
+ value="qmc",
308
+ label="Sampling Mode",
309
+ info="Select which sampling mode to use for projections",
310
+ )
311
+
312
+ run_button.click(
313
+ run_sw_guidance,
314
+ inputs=[
315
+ num_inference_steps_slider,
316
+ num_guided_steps_perc_slider,
317
+ guidance_scale_slider,
318
+ sw_u_lr_slider,
319
+ sw_steps_slider,
320
+ num_projections_slider,
321
+ control_variates_dropdown,
322
+ distance_dropdown,
323
+ candidates_per_pass_slider,
324
+ subsampling_factor_slider,
325
+ sampling_mode_dropdown,
326
+ cfg_rescale_phi_slider,
327
+ prompt,
328
+ reference_image,
329
+ seed_input,
330
+ ],
331
+ outputs=[output_image],
332
+ )
333
+
334
+ clear_button = gr.ClearButton([prompt, reference_image, output_image, seed_input])
335
+
336
+ clear_button.click(lambda: None)
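
One parameter in this tab that may need unpacking is `cfg_rescale_phi`. Its actual application lives in `src/sw_sdthree_guidance.py` (whose body is not shown in this section), so the snippet below is only a sketch of the commonly used classifier-free guidance rescaling the slider presumably refers to, with `phi` blending between plain CFG and a variance-matched version of it.

```python
import torch

def rescale_cfg(cond: torch.Tensor, uncond: torch.Tensor, guidance_scale: float, phi: float) -> torch.Tensor:
    """Classifier-free guidance with the common std-rescaling trick (illustrative sketch)."""
    guided = uncond + guidance_scale * (cond - uncond)        # standard CFG combination
    dims = tuple(range(1, cond.ndim))
    std_cond = cond.std(dim=dims, keepdim=True)
    std_guided = guided.std(dim=dims, keepdim=True)
    rescaled = guided * (std_cond / (std_guided + 1e-8))      # match the conditional std
    return phi * rescaled + (1.0 - phi) * guided              # phi = 0 -> plain CFG
```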
src/loss/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .abstract_loss import AbstractLoss
2
+ from .vector_swd import VectorSWDLoss
3
+
4
+ __all__ = ["AbstractLoss", "VectorSWDLoss"]
src/loss/abstract_loss.py ADDED
@@ -0,0 +1,20 @@
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from jaxtyping import Float
6
+
7
+
8
+ class AbstractLoss(nn.Module, abc.ABC):
9
+ @abc.abstractmethod
10
+ def forward(
11
+ self,
12
+ pred: Float[torch.Tensor, "B C H W"],
13
+ gt: Float[torch.Tensor, "B C H W"],
14
+ step: int,
15
+ **kwargs,
16
+ ) -> Float[torch.Tensor, ""]:
17
+ pass
18
+
19
+ def reset(self):
20
+ pass
src/loss/vector_swd.py ADDED
@@ -0,0 +1,431 @@
1
+ import typing
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from jaxtyping import Float
6
+
7
+ from src.loss.abstract_loss import AbstractLoss
8
+ from src.utils.math import sobol_sphere
9
+
10
+
11
+ def process_vector(
12
+ x: Float[torch.Tensor, "B D N"],
13
+ dirs: Float[torch.Tensor, "K D"],
14
+ ) -> Float[torch.Tensor, "K B*N_valid"]:
15
+ """
16
+ Project a 1-D sequence with a bank of linear directions.
17
+
18
+ Args
19
+ ----
20
+ x : (B, D, N) tensor – predictions or ground truth
21
+ dirs : (K, D) tensor – unit-length projection directions
22
+
23
+ Returns
24
+ -------
25
+ proj : (K, B*N_valid) tensor of flattened projections
26
+ """
27
+ B, D, N = x.shape
28
+ K, _ = dirs.shape
29
+
30
+ # linear projection: x (B,D,N) -> (B,N,K) -> (K,B*N)
31
+ proj = F.linear(x.transpose(1, 2).to(torch.float32), dirs.to(torch.float32))
32
+ proj = proj.permute(2, 0, 1).reshape(K, -1).to(x.dtype)
33
+
34
+ return proj
35
+
36
+
37
+ class VectorSWDLoss(AbstractLoss):
38
+ """
39
+ 1-D Sliced-Wasserstein Distance on sequences.
40
+
41
+ This loss computes the sliced Wasserstein distance between predicted and ground
42
+ truth sequences by projecting them onto random directions and computing the
43
+ Wasserstein distance in 1D. It supports reservoir sampling for adaptive direction
44
+ selection and various variance reduction techniques.
45
+
46
+ Parameters
47
+ ----------
48
+ num_proj : int, default=64
49
+ Number of random projections to use per step (K).
50
+
51
+ distance : {"l1", "l2"}, default="l1"
52
+ Distance metric to use for computing the Wasserstein distance.
53
+
54
+ use_ucv : bool, default=False
55
+ Whether to use upper bounds control variates for variance reduction.
56
+ Mutually exclusive with use_lcv.
57
+
58
+ use_lcv : bool, default=False
59
+ Whether to use lower bounds control variates for variance reduction.
60
+ Mutually exclusive with use_ucv.
61
+
62
+ refresh_projections_every_n_steps : int, default=1
63
+ How often to refresh the projection directions. A value of 1 means
64
+ refresh every step, higher values reuse directions for multiple steps.
65
+
66
+ num_new_candidates : int, default=16
67
+ Number of new candidate directions to generate per step (M).
68
+ If 0, reservoir sampling is disabled. Must not exceed num_proj.
69
+
70
+ ess_alpha : float, default=0.5
71
+ Effective sample size threshold for resetting the reservoir.
72
+ When ESS drops below ess_alpha * reservoir_size, the reservoir is reset.
73
+
74
+ time_decay_tau : float or None, default=30.0
75
+ Time decay parameter for reservoir weights. If None, no time decay is applied.
76
+ Weights decay exponentially with age: exp(-age / time_decay_tau).
77
+
78
+ missing_value_method : {"random_replicate", "interpolate"},
79
+ default="random_replicate"
80
+ Method for handling sequences of different lengths:
81
+ - "random_replicate": Randomly replicate shorter sequences
82
+ - "interpolate": Use linear interpolation to match lengths
83
+
84
+ sampling_mode : {"gaussian", "qmc"}, default="qmc"
85
+ Method for generating random projection directions:
86
+ - "gaussian": Standard Gaussian sampling
87
+ - "qmc": Quasi-Monte Carlo sampling using Sobol sequences
88
+
89
+ Notes
90
+ -----
91
+ - Reservoir sampling is enabled when num_new_candidates > 0
92
+ - Reservoir size = num_proj - num_new_candidates
93
+ - When use_ucv or use_lcv is True, variance reduction is applied using
94
+ control variates based on the difference between sample and population means
95
+ - The loss automatically handles sequences of different lengths using the
96
+ specified missing_value_method
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ num_proj: int = 64,
102
+ distance: typing.Literal["l1", "l2"] = "l1",
103
+ use_ucv: bool = False,
104
+ use_lcv: bool = False,
105
+ refresh_projections_every_n_steps: int = 1,
106
+ num_new_candidates: int = 16,
107
+ ess_alpha: float = 0.5,
108
+ time_decay_tau: float | None = 30.0,
109
+ missing_value_method: typing.Literal[
110
+ "random_replicate", "interpolate"
111
+ ] = "random_replicate",
112
+ sampling_mode: typing.Literal[
113
+ "gaussian",
114
+ "qmc",
115
+ ] = "qmc",
116
+ ):
117
+ super().__init__()
118
+
119
+ assert not (use_ucv and use_lcv), "use_ucv and use_lcv cannot both be True"
120
+
121
+ self.num_proj = num_proj
122
+ self.distance = distance
123
+ self.use_ucv = use_ucv
124
+ self.use_lcv = use_lcv
125
+
126
+ self.refresh_projections_every_n_steps = refresh_projections_every_n_steps
127
+ self.num_new_candidates = num_new_candidates # M
128
+ self.ess_alpha = ess_alpha
129
+ self.time_decay_tau = time_decay_tau
130
+ self.missing_value_method = missing_value_method
131
+
132
+ if num_new_candidates > 0 and self.refresh_projections_every_n_steps != 1:
133
+ # Print a warning that this is not recommended
134
+ print(
135
+ "WARNING: num_new_candidates > 0 (enabling reservoir sampling) and "
136
+ "refresh_projections_every_n_steps != 1 is not recommended"
137
+ )
138
+ assert (
139
+ num_new_candidates <= num_proj
140
+ ), "`num_new_candidates` must not exceed `num_proj`"
141
+
142
+ # internal state for reservoir sampling
143
+ self.restir_enabled = self.num_new_candidates > 0
144
+ self.reservoir_size = self.num_proj - self.num_new_candidates
145
+ self.register_buffer("_reservoir_filters", torch.empty(0))
146
+ self.register_buffer("_reservoir_weights", torch.empty(0))
147
+ self.register_buffer("_reservoir_steps", torch.empty(0, dtype=torch.long))
148
+ self.register_buffer("_reservoir_keys", torch.empty(0))
149
+ self.register_buffer("_cumulative_weights", torch.tensor(0.0))
150
+ self.register_buffer("_has_reservoir", torch.tensor(False, dtype=torch.bool))
151
+
152
+ self._cached_dirs: typing.Optional[torch.Tensor] = None
153
+ self.sampling_mode = sampling_mode
154
+ self.sobol_engine = None
155
+
156
+ def _gaussian_proposals(self, k: int, d: int, device: torch.device) -> torch.Tensor:
157
+ """Generate Gaussian random projection directions."""
158
+ w = torch.randn(k, d, device=device)
159
+ return w / (w.norm(dim=1, keepdim=True) + 1e-8) # unit length
160
+
161
+ def _qmc_proposals(self, k: int, d: int, device: torch.device) -> torch.Tensor:
162
+ """Generate quasi-Monte Carlo projection directions using Sobol sequences."""
163
+ vecs, self.sobol_engine = sobol_sphere(k, d, device, self.sobol_engine)
164
+ return vecs.view(k, d)
165
+
166
+ def _draw_dirs(self, k: int, d: int, device: torch.device) -> torch.Tensor:
167
+ """Draw projection directions using the specified sampling mode."""
168
+ if self.sampling_mode == "gaussian":
169
+ return self._gaussian_proposals(k, d, device)
170
+ if self.sampling_mode == "qmc":
171
+ return self._qmc_proposals(k, d, device)
172
+ raise ValueError("bad sampling_mode")
173
+
174
+ @staticmethod
175
+ def _duplicate_to_match(a: torch.Tensor, b: torch.Tensor, method: str):
176
+ """
177
+ Make two tensors have the same length by duplicating the shorter one.
178
+
179
+ Args
180
+ ----
181
+ a, b : (K, N1) and (K, N2) tensors
182
+ method : "random_replicate" or "interpolate"
183
+
184
+ Returns
185
+ -------
186
+ a, b : Tensors with matching second dimension
187
+ """
188
+ if a.shape[1] == b.shape[1]:
189
+ return a, b
190
+ if a.shape[1] < b.shape[1]:
191
+ a, b = b, a # swap so that `a` is the larger
192
+
193
+ K, NA = a.shape
194
+ NB = b.shape[1]
195
+
196
+ # repeat / interpolate B until it matches A
197
+ if method == "random_replicate":
198
+ repeats = NA // NB
199
+ b = torch.cat([b] * repeats, dim=1)
200
+ if b.shape[1] < NA:
201
+ idx = torch.randint(0, NB, (NA - b.shape[1],), device=b.device)
202
+ b = torch.cat([b, b[:, idx]], dim=1)
203
+ else: # interpolate
204
+ b = F.interpolate(
205
+ b.unsqueeze(0), size=(NA,), mode="linear", align_corners=False
206
+ ).squeeze(0)
207
+ return a, b
208
+
209
+ def reset(self):
210
+ """Reset the reservoir sampling state."""
211
+ if self.restir_enabled:
212
+ self._reservoir_filters = torch.empty(0)
213
+ self._reservoir_weights = torch.empty(0)
214
+ self._cumulative_weights.data.fill_(0)
215
+ self._has_reservoir.fill_(False)
216
+ self._reservoir_steps = torch.empty(0, dtype=torch.long)
217
+ self._reservoir_keys = torch.empty(0)
218
+
219
+ def _wrs_multi(
220
+ self, filters: torch.Tensor, weights: torch.Tensor, step: int
221
+ ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
222
+ """
223
+ Weighted reservoir sampling that keeps exactly self.reservoir_size samples and
224
+ returns their indices inside the concatenated candidate set.
225
+
226
+ Args
227
+ ----
228
+ filters : (K+M, D) tensor of candidate directions
229
+ weights : (K+M,) tensor of importance weights
230
+ step : Current training step
231
+
232
+ Returns
233
+ -------
234
+ keep_idx : Indices of kept samples
235
+ keep_w : Normalized weights of kept samples
236
+ """
237
+ R = self.reservoir_size
238
+ device = weights.device
239
+
240
+ u = torch.rand_like(weights)
241
+ keys = u.pow(1.0 / weights.clamp_min(1e-9))
242
+
243
+ if not self._has_reservoir.item():
244
+ self._reservoir_filters = filters[:R]
245
+ self._reservoir_weights = weights[:R]
246
+ self._reservoir_keys = keys[:R]
247
+ self._reservoir_steps = torch.full(
248
+ (R,), step, dtype=torch.long, device=device
249
+ )
250
+ self._has_reservoir.fill_(True)
251
+
252
+ new_filters = filters[R:]
253
+ new_keys = keys[R:]
254
+ new_weights = weights[R:]
255
+ new_steps = torch.full(
256
+ (new_filters.size(0),), step, dtype=torch.long, device=device
257
+ )
258
+
259
+ all_filters = torch.cat([self._reservoir_filters, new_filters], 0)
260
+ all_keys = torch.cat([self._reservoir_keys, new_keys], 0)
261
+ all_weights = torch.cat([self._reservoir_weights, new_weights], 0)
262
+ all_steps = torch.cat([self._reservoir_steps, new_steps], 0)
263
+
264
+ topk_keys, topk_idx = torch.topk(all_keys, R, largest=True)
265
+
266
+ self._reservoir_filters = all_filters[topk_idx]
267
+ self._reservoir_weights = all_weights[topk_idx]
268
+ self._reservoir_keys = topk_keys
269
+ self._reservoir_steps = all_steps[topk_idx]
270
+
271
+ # indices w.r.t. current cand_dirs (old R first, then new M)
272
+ keep_idx = torch.cat(
273
+ [
274
+ torch.arange(R, device=device),
275
+ torch.arange(R, R + new_filters.size(0), device=device),
276
+ ]
277
+ )[topk_idx]
278
+ keep_w = self._reservoir_weights / self._reservoir_weights.sum().clamp_min(
279
+ 1e-12
280
+ )
281
+ return keep_idx, keep_w
282
+
283
+ def _apply_time_decay(self, step: int):
284
+ """
285
+ Apply exponential time decay to stored reservoir weights.
286
+
287
+ Args
288
+ ----
289
+ step : Current training step
290
+ """
291
+ if self.time_decay_tau is None or not self._has_reservoir.item():
292
+ return
293
+ age = (step - self._reservoir_steps).to(torch.float32)
294
+ decay = torch.exp(-age / self.time_decay_tau).to(self._reservoir_weights.dtype)
295
+ self._reservoir_weights.mul_(decay)
296
+ self._reservoir_keys.mul_(decay) # preserve ordering consistency
297
+
298
+ def forward(
299
+ self,
300
+ pred: Float[torch.Tensor, "B D N"],
301
+ gt: Float[torch.Tensor, "B D N"],
302
+ step: int,
303
+ ):
304
+ """
305
+ Compute the sliced Wasserstein distance between predicted and ground truth
306
+ sequences.
307
+
308
+ Args
309
+ ----
310
+ pred : (B, D, N) tensor of predicted sequences
311
+ gt : (B, D, N) tensor of ground truth sequences
312
+ step : Current training step for reservoir sampling
313
+
314
+ Returns
315
+ -------
316
+ loss : Scalar tensor containing the computed loss
317
+ """
318
+ B, D, N = pred.shape
319
+ K = self.num_proj
320
+ M = self.num_new_candidates
321
+ R = self.reservoir_size
322
+ device = pred.device
323
+ gt = gt.detach()
324
+
325
+ self._apply_time_decay(step)
326
+
327
+ # Get candidate directions
328
+ if step % self.refresh_projections_every_n_steps == 0:
329
+ new_dirs = self._draw_dirs(
330
+ M if self.restir_enabled and self._has_reservoir.item() else K,
331
+ D,
332
+ device,
333
+ )
334
+ self._cached_dirs = new_dirs
335
+ else:
336
+ new_dirs = self._cached_dirs
337
+
338
+ if self.restir_enabled and self._has_reservoir.item():
339
+ cand_dirs = torch.cat(
340
+ [self._reservoir_filters, new_dirs], dim=0
341
+ ) # (K+M, D)
342
+ else:
343
+ cand_dirs = new_dirs
344
+
345
+ # Project sequences
346
+ cand_pred = process_vector(pred, cand_dirs)
347
+ cand_gt = process_vector(gt, cand_dirs)
348
+
349
+ cand_pred, cand_gt = self._duplicate_to_match(
350
+ cand_pred, cand_gt, self.missing_value_method
351
+ )
352
+
353
+ cand_pred = cand_pred.sort(dim=1).values
354
+ cand_gt = cand_gt.sort(dim=1).values
355
+
356
+ # Select K directions (reservoir) & importance weights
357
+ if self.restir_enabled:
358
+ with torch.no_grad():
359
+ base = cand_pred - cand_gt
360
+ base = base.abs() if self.distance == "l1" else base.square()
361
+ ris_weights = base.mean(1) # one weight per candidate direction (R+M once the reservoir is filled)
362
+ keep_idx, keep_w = self._wrs_multi(cand_dirs, ris_weights, step)
363
+
364
+ w = keep_w
365
+ w_hat = keep_w
366
+
367
+ dirs = cand_dirs[keep_idx]
368
+ proj_pred = cand_pred[keep_idx]
369
+ proj_gt = cand_gt[keep_idx]
370
+ else:
371
+ dirs = cand_dirs
372
+ proj_pred = cand_pred
373
+ proj_gt = cand_gt
374
+ w = torch.full((dirs.shape[0],), 1.0 / K, device=device)
375
+
376
+ # Compute SWD
377
+ diff = proj_pred - proj_gt
378
+ diff = diff.abs() if self.distance == "l1" else diff.square()
379
+ per_slice = diff.mean(1) # (L,)
380
+
381
+ if self.use_ucv or self.use_lcv:
382
+ X_vecs = pred.permute(0, 2, 1).reshape(-1, D) # (BΒ·N, D)
383
+ Y_vecs = gt.permute(0, 2, 1).reshape(-1, D) # (BΒ·N, D)
384
+
385
+ m1 = X_vecs.mean(0) # (D,)
386
+ m2 = Y_vecs.mean(0)
387
+ diff_m = m1 - m2 # (D,)
388
+
389
+ theta = dirs # (L, D) already unit-norm
390
+
391
+ if self.use_ucv:
392
+ diff_X = X_vecs - m1
393
+ diff_Y = Y_vecs - m2
394
+
395
+ d = D
396
+ trSigX = diff_X.pow(2).mean()
397
+ trSigY = diff_Y.pow(2).mean()
398
+ G_bar = (diff_m @ diff_m) / d + (trSigX + trSigY)
399
+
400
+ delta2 = (theta @ diff_m) ** 2 # (L,)
401
+
402
+ proj_X = diff_X @ theta.t() # (BΒ·N, L)
403
+ proj_Y = diff_Y @ theta.t()
404
+ varX = proj_X.pow(2).mean(0) # (L,)
405
+ varY = proj_Y.pow(2).mean(0)
406
+ G_hat = delta2 + varX + varY
407
+ else: # LCV
408
+ d = D
409
+ G_bar = (diff_m @ diff_m) / d
410
+ G_hat = (theta @ diff_m) ** 2
411
+
412
+ diff_hat_G_mean_G = G_hat - G_bar
413
+
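+ # Regression control variate: G_hat has the (approximately) known mean G_bar
+ # under the projection distribution, so subtracting hat_alpha * (G_hat - G_bar)
+ # with hat_alpha = Cov(A, G) / Var(G) reduces the variance of the sliced
+ # estimate without changing its expectation.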
414
+ hat_A = (w * per_slice).sum()
415
+ var_G = (w * diff_hat_G_mean_G.pow(2)).sum()
416
+ cov_AG = (w * (per_slice - hat_A) * diff_hat_G_mean_G).sum()
417
+ hat_alpha = cov_AG / (var_G + 1e-12)
418
+ loss = hat_A - hat_alpha * (w * diff_hat_G_mean_G).sum()
419
+ else:
420
+ loss = (w * per_slice).sum()
421
+
422
+ # Reservoir update
423
+ if self.restir_enabled and self.ess_alpha > 0:
424
+ with torch.no_grad():
425
+ ess = (w_hat.sum().square()) / (w_hat.square().sum() + 1e-12)
426
+ ess = torch.nan_to_num(ess, nan=0.0, posinf=R, neginf=0.0).item()
427
+ if ess < self.ess_alpha * R:
428
+ print(f"ESS: {ess} is less than {self.ess_alpha * R}, resetting")
429
+ self.reset()
430
+
431
+ return loss
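For orientation, here is a minimal usage sketch of the loss defined above. The shapes follow the forward docstring and the constructor arguments mirror the setup_swd call further down in this commit; the remaining options (reservoir size, refresh interval, UCV/LCV flags) are assumed to keep their defaults.

import torch
from src.loss.vector_swd import VectorSWDLoss

# 64 projections, L1 ground distance, QMC direction sampling.
swd = VectorSWDLoss(
    num_proj=64,
    distance="l1",
    num_new_candidates=32,
    missing_value_method="interpolate",
    ess_alpha=-1,
    sampling_mode="qmc",
)

pred = torch.rand(1, 3, 4096, requires_grad=True)  # (B, D, N) predicted samples
gt = torch.rand(1, 3, 4096)                        # (B, D, N) reference samples

loss = swd(pred=pred, gt=gt, step=0)  # step drives projection refresh and the reservoir
loss.backward()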
src/sw_sdthree_guidance.py ADDED
@@ -0,0 +1,654 @@
1
+ # Implementation of the SW Guidance method with our enhanced SWD implementation
2
+ # See: https://github.com/alobashev/sw-guidance/ for the original implementation
3
+ #
4
+ # Alexander Lobashev, Maria Larchenko, Dmitry Guskov
5
+ # Color Conditional Generation with Sliced Wasserstein Guidance
6
+ # https://arxiv.org/abs/2503.19034
7
+
8
+
9
+ import gc
10
+ import os
11
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
12
+
13
+ import numpy as np
14
+ import PIL
15
+ import torch
16
+ from diffusers import (
17
+ FlowMatchEulerDiscreteScheduler,
18
+ StableDiffusion3Pipeline,
19
+ )
20
+ from diffusers.image_processor import PipelineImageInput
21
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
22
+ XLA_AVAILABLE,
23
+ StableDiffusion3PipelineOutput,
24
+ calculate_shift,
25
+ retrieve_timesteps,
26
+ )
27
+
28
+ from src.loss.vector_swd import VectorSWDLoss
29
+ from src.utils.color_space import rgb_to_lab
30
+ from src.utils.image import from_torch, write_img
31
+
32
+ if XLA_AVAILABLE:
33
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import xm
34
+
35
+
36
+ def _no_grad_noise(model, *args, **kw):
37
+ """Forward pass with grad disabled; result is returned detached."""
38
+ with torch.no_grad():
39
+ return model(*args, **kw, return_dict=False)[0].detach()
40
+
41
+
42
+ # ---------------- explicit pipeline forward call
43
+ class SWStableDiffusion3Pipeline(StableDiffusion3Pipeline):
44
+ swd: VectorSWDLoss = None
45
+
46
+ def setup_swd(
47
+ self,
48
+ num_projections: int = 64,
49
+ use_ucv: bool = False,
50
+ use_lcv: bool = False,
51
+ distance: Literal["l1", "l2"] = "l1",
52
+ num_new_candidates: int = 32,
53
+ subsampling_factor: int = 1,
54
+ sampling_mode: Literal["gaussian", "qmc"] = "qmc",
55
+ ):
56
+ self.swd = VectorSWDLoss(
57
+ num_proj=num_projections,
58
+ distance=distance,
59
+ use_ucv=use_ucv,
60
+ use_lcv=use_lcv,
61
+ num_new_candidates=num_new_candidates,
62
+ missing_value_method="interpolate",
63
+ ess_alpha=-1,
64
+ sampling_mode=sampling_mode,
65
+ ).to(self.device)
66
+ self.subsampling_factor = subsampling_factor
67
+
68
+ def do_sw_guidance(
69
+ self,
70
+ sw_steps,
71
+ sw_u_lr,
72
+ latents,
73
+ t,
74
+ prompt_embeds,
75
+ pooled_prompt_embeds,
76
+ pixels_ref,
77
+ cur_iter_step,
78
+ write_video_animation_path: Optional[str] = None,
79
+ ):
80
+ if sw_steps == 0:
81
+ return latents
82
+
83
+ if latents.shape[0] != prompt_embeds.shape[0]:
84
+ prompt_embeds = prompt_embeds[1].unsqueeze(0)
85
+ pooled_prompt_embeds = pooled_prompt_embeds[1].unsqueeze(0)
86
+
87
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
88
+ timestep = t.expand(latents.shape[0])
89
+
90
+ pixels_ref = (
91
+ rgb_to_lab(pixels_ref.unsqueeze(0).clamp(0, 1).permute(0, 3, 1, 2))
92
+ .permute(0, 2, 3, 1)
93
+ .contiguous()
94
+ )
95
+
96
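+ # LAB ranges: L lies in [0, 100] and a/b roughly in [-128, 128]; dividing by the
+ # scaler and adding the bias maps each channel approximately into [0, 1] before
+ # computing the SWD.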
+ csc_scaler = torch.tensor(
97
+ [100, 2 * 128, 2 * 128], dtype=torch.bfloat16, device=latents.device
98
+ ).view(1, 3, 1)
99
+ csc_bias = torch.tensor(
100
+ [0, 0.5, 0.5], dtype=torch.bfloat16, device=latents.device
101
+ ).view(1, 3, 1)
102
+
103
+ u = torch.nn.Parameter(
104
+ torch.zeros_like(latents, dtype=torch.bfloat16, device=latents.device)
105
+ )
106
+ optimizer = torch.optim.Adam([u], lr=sw_u_lr)
107
+
108
+ for tt in range(sw_steps):
109
+ optimizer.zero_grad()
110
+
111
+ x_hat_t = latents.detach() + u
112
+ noise_pred = _no_grad_noise(
113
+ self.transformer,
114
+ hidden_states=x_hat_t,
115
+ timestep=timestep,
116
+ encoder_hidden_states=prompt_embeds,
117
+ pooled_projections=pooled_prompt_embeds,
118
+ joint_attention_kwargs=self.joint_attention_kwargs,
119
+ )
120
+
121
+ # ------------ Compute x_0
122
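+ # For the flow-matching parameterization used here, x_t = (1 - sigma_t) * x_0
+ # + sigma_t * noise and the transformer predicts the velocity (noise - x_0),
+ # so the clean-sample estimate is x_0 = x_t - sigma_t * v_pred.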
+ sigma_t = self.scheduler.sigmas[
123
+ self.scheduler.index_for_timestep(t)
124
+ ] # scalar
125
+ while sigma_t.ndim < x_hat_t.ndim:
126
+ sigma_t = sigma_t.unsqueeze(-1)
127
+ sigma_t = sigma_t.to(x_hat_t.dtype).to(latents.device)
128
+
129
+ x_0 = x_hat_t - sigma_t * noise_pred
130
+
131
+ # ------------ Compute loss
132
+ img_unscaled = self.vae.decode(
133
+ (x_0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
134
+ return_dict=False,
135
+ )[0]
136
+ image = (img_unscaled * 0.5 + 0.5).clamp(0, 1)
137
+ image_matched = (
138
+ rgb_to_lab(image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
139
+ )
140
+ # reshape to (B, D, N) where D=3, N = H*W
141
+ pred_seq = image_matched.view(1, 3, -1) / csc_scaler + csc_bias
142
+ ref_seq = pixels_ref.view(1, 3, -1) / csc_scaler + csc_bias
143
+
144
+ # Apply subsampling if enabled
145
+ if self.subsampling_factor > 1:
146
+ pred_seq = pred_seq[..., :: self.subsampling_factor]
147
+ ref_seq = ref_seq[..., :: self.subsampling_factor]
148
+
149
+ loss = self.swd(pred=pred_seq, gt=ref_seq, step=tt)
150
+
151
+ loss.backward()
152
+ optimizer.step()
153
+
154
+ if write_video_animation_path is not None:
155
+ frame_idx = cur_iter_step * sw_steps + tt
156
+ write_img(
157
+ os.path.join(write_video_animation_path, f"{frame_idx:05d}.jpg"),
158
+ from_torch(img_unscaled.squeeze(0)),
159
+ )
160
+
161
+ latents = latents.detach() + u.detach()
162
+
163
+ gc.collect()
164
+ torch.cuda.empty_cache()
165
+ return latents
166
+
167
+ def __call__(
168
+ self,
169
+ sw_reference: PIL.Image = None,
170
+ sw_steps: int = 8,
171
+ sw_u_lr: float = 0.05 * 10**3,
172
+ num_guided_steps: int = None,
173
+ # -----------------------------------
174
+ prompt: Union[str, List[str]] = None,
175
+ prompt_2: Optional[Union[str, List[str]]] = None,
176
+ prompt_3: Optional[Union[str, List[str]]] = None,
177
+ height: Optional[int] = None,
178
+ width: Optional[int] = None,
179
+ num_inference_steps: int = 28,
180
+ sigmas: Optional[List[float]] = None,
181
+ guidance_scale: float = 7.0,
182
+ cfg_rescale_phi: float = 0.7,
183
+ negative_prompt: Optional[Union[str, List[str]]] = None,
184
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
185
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
186
+ num_images_per_prompt: Optional[int] = 1,
187
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
188
+ latents: Optional[torch.FloatTensor] = None,
189
+ prompt_embeds: Optional[torch.FloatTensor] = None,
190
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
191
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
192
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
193
+ ip_adapter_image: Optional[PipelineImageInput] = None,
194
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
195
+ output_type: Optional[str] = "pil",
196
+ return_dict: bool = True,
197
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
198
+ clip_skip: Optional[int] = None,
199
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
200
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
201
+ max_sequence_length: int = 256,
202
+ skip_guidance_layers: List[int] = None,
203
+ skip_layer_guidance_scale: float = 2.8,
204
+ skip_layer_guidance_stop: float = 0.2,
205
+ skip_layer_guidance_start: float = 0.01,
206
+ mu: Optional[float] = None,
207
+ write_video_animation_path: Optional[str] = None,
208
+ ):
209
+ assert self.swd is not None, "SWD not initialized; call setup_swd() first"
210
+
211
+ height = height or self.default_sample_size * self.vae_scale_factor
212
+ width = width or self.default_sample_size * self.vae_scale_factor
213
+
214
+ # 1. Check inputs. Raise error if not correct
215
+ self.check_inputs(
216
+ prompt,
217
+ prompt_2,
218
+ prompt_3,
219
+ height,
220
+ width,
221
+ negative_prompt=negative_prompt,
222
+ negative_prompt_2=negative_prompt_2,
223
+ negative_prompt_3=negative_prompt_3,
224
+ prompt_embeds=prompt_embeds,
225
+ negative_prompt_embeds=negative_prompt_embeds,
226
+ pooled_prompt_embeds=pooled_prompt_embeds,
227
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
228
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
229
+ max_sequence_length=max_sequence_length,
230
+ )
231
+
232
+ self._guidance_scale = guidance_scale
233
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
234
+ self._clip_skip = clip_skip
235
+ self._joint_attention_kwargs = joint_attention_kwargs
236
+ self._interrupt = False
237
+
238
+ # 2. Define call parameters
239
+ if prompt is not None and isinstance(prompt, str):
240
+ batch_size = 1
241
+ elif prompt is not None and isinstance(prompt, list):
242
+ batch_size = len(prompt)
243
+ else:
244
+ batch_size = prompt_embeds.shape[0]
245
+
246
+ device = self._execution_device
247
+
248
+ lora_scale = (
249
+ self.joint_attention_kwargs.get("scale", None)
250
+ if self.joint_attention_kwargs is not None
251
+ else None
252
+ )
253
+ (
254
+ prompt_embeds,
255
+ negative_prompt_embeds,
256
+ pooled_prompt_embeds,
257
+ negative_pooled_prompt_embeds,
258
+ ) = self.encode_prompt(
259
+ prompt=prompt,
260
+ prompt_2=prompt_2,
261
+ prompt_3=prompt_3,
262
+ negative_prompt=negative_prompt,
263
+ negative_prompt_2=negative_prompt_2,
264
+ negative_prompt_3=negative_prompt_3,
265
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
266
+ prompt_embeds=prompt_embeds,
267
+ negative_prompt_embeds=negative_prompt_embeds,
268
+ pooled_prompt_embeds=pooled_prompt_embeds,
269
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
270
+ device=device,
271
+ clip_skip=self.clip_skip,
272
+ num_images_per_prompt=num_images_per_prompt,
273
+ max_sequence_length=max_sequence_length,
274
+ lora_scale=lora_scale,
275
+ )
276
+
277
+ if self.do_classifier_free_guidance:
278
+ if skip_guidance_layers is not None:
279
+ original_prompt_embeds = prompt_embeds
280
+ original_pooled_prompt_embeds = pooled_prompt_embeds
281
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
282
+ pooled_prompt_embeds = torch.cat(
283
+ [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
284
+ )
285
+
286
+ # 4. Prepare latent variables
287
+ num_channels_latents = self.transformer.config.in_channels
288
+ latents = self.prepare_latents(
289
+ batch_size * num_images_per_prompt,
290
+ num_channels_latents,
291
+ height,
292
+ width,
293
+ prompt_embeds.dtype,
294
+ device,
295
+ generator,
296
+ latents,
297
+ )
298
+
299
+ # 5. Prepare timesteps
300
+ scheduler_kwargs = {}
301
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
302
+ _, _, height, width = latents.shape
303
+ image_seq_len = (height // self.transformer.config.patch_size) * (
304
+ width // self.transformer.config.patch_size
305
+ )
306
+ mu = calculate_shift(
307
+ image_seq_len,
308
+ self.scheduler.config.get("base_image_seq_len", 256),
309
+ self.scheduler.config.get("max_image_seq_len", 4096),
310
+ self.scheduler.config.get("base_shift", 0.5),
311
+ self.scheduler.config.get("max_shift", 1.16),
312
+ )
313
+ scheduler_kwargs["mu"] = mu
314
+ elif mu is not None:
315
+ scheduler_kwargs["mu"] = mu
316
+ timesteps, num_inference_steps = retrieve_timesteps(
317
+ self.scheduler,
318
+ num_inference_steps,
319
+ device,
320
+ sigmas=sigmas,
321
+ **scheduler_kwargs,
322
+ )
323
+ num_warmup_steps = max(
324
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
325
+ )
326
+ self._num_timesteps = len(timesteps)
327
+
328
+ # 6. Prepare image embeddings
329
+ if (
330
+ ip_adapter_image is not None and self.is_ip_adapter_active
331
+ ) or ip_adapter_image_embeds is not None:
332
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
333
+ ip_adapter_image,
334
+ ip_adapter_image_embeds,
335
+ device,
336
+ batch_size * num_images_per_prompt,
337
+ self.do_classifier_free_guidance,
338
+ )
339
+
340
+ if self.joint_attention_kwargs is None:
341
+ self._joint_attention_kwargs = {
342
+ "ip_adapter_image_embeds": ip_adapter_image_embeds
343
+ }
344
+ else:
345
+ self._joint_attention_kwargs.update(
346
+ ip_adapter_image_embeds=ip_adapter_image_embeds
347
+ )
348
+
349
+ if sw_reference is not None:
350
+ # Resize so the reference is maximal width or height of the output image
351
+
352
+ target_max_size = max(height, width)
353
+ reference_max_size = max(sw_reference.width, sw_reference.height)
354
+ scale_factor = target_max_size / reference_max_size
355
+
356
+ sw_reference = sw_reference.resize(
357
+ (
358
+ int(sw_reference.width * scale_factor),
359
+ int(sw_reference.height * scale_factor),
360
+ )
361
+ )
362
+ pixels_ref = (
363
+ torch.Tensor(np.array(sw_reference).astype(np.float32) / 255)
364
+ .permute(2, 0, 1)
365
+ .to(device)
366
+ .to(torch.bfloat16)
367
+ )
368
+
369
+ # 7. Denoising loop
370
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
371
+ for i, t in enumerate(timesteps):
372
+ if self.interrupt:
373
+ continue
374
+
375
+ # broadcast to batch dimension in a way that's compatible
376
+ # with ONNX/Core ML
377
+ timestep = t.expand(latents.shape[0])
378
+
379
+ # SW Guidance
380
+ if sw_reference is not None:
381
+ if num_guided_steps is None or i < num_guided_steps:
382
+ latents = self.do_sw_guidance(
383
+ sw_steps,
384
+ sw_u_lr,
385
+ latents,
386
+ t,
387
+ prompt_embeds,
388
+ pooled_prompt_embeds,
389
+ pixels_ref,
390
+ cur_iter_step=i,
391
+ write_video_animation_path=write_video_animation_path,
392
+ )
393
+ if num_guided_steps is not None and i == num_guided_steps // 2:
394
+ self.swd.reset()
395
+
396
+ # expand the latents if we are doing classifier free guidance
397
+ latent_model_input = (
398
+ torch.cat([latents] * 2)
399
+ if self.do_classifier_free_guidance
400
+ else latents
401
+ )
402
+
403
+ with torch.no_grad():
404
+ timestep = t.expand(latent_model_input.shape[0])
405
+
406
+ noise_pred = self.transformer(
407
+ hidden_states=latent_model_input,
408
+ timestep=timestep,
409
+ encoder_hidden_states=prompt_embeds,
410
+ pooled_projections=pooled_prompt_embeds,
411
+ joint_attention_kwargs=self.joint_attention_kwargs,
412
+ return_dict=False,
413
+ )[0]
414
+
415
+ # perform guidance
416
+ if self.do_classifier_free_guidance:
417
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
418
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
419
+ noise_pred_text - noise_pred_uncond
420
+ )
421
+
422
+ should_skip_layers = (
423
+ True
424
+ if i > num_inference_steps * skip_layer_guidance_start
425
+ and i < num_inference_steps * skip_layer_guidance_stop
426
+ else False
427
+ )
428
+ if skip_guidance_layers is not None and should_skip_layers:
429
+ timestep = t.expand(latents.shape[0])
430
+ latent_model_input = latents
431
+ noise_pred_skip_layers = self.transformer(
432
+ hidden_states=latent_model_input,
433
+ timestep=timestep,
434
+ encoder_hidden_states=original_prompt_embeds,
435
+ pooled_projections=original_pooled_prompt_embeds,
436
+ joint_attention_kwargs=self.joint_attention_kwargs,
437
+ return_dict=False,
438
+ skip_layers=skip_guidance_layers,
439
+ )[0]
440
+ noise_pred = (
441
+ noise_pred
442
+ + (noise_pred_text - noise_pred_skip_layers)
443
+ * self._skip_layer_guidance_scale
444
+ )
445
+
446
+ # Based on Sec. 3.4 of Lin, Liu, Li, Yang -
447
+ # Common Diffusion Noise Schedules and Sample Steps are Flawed
448
+ # https://arxiv.org/abs/2305.08891
449
+ # While Flow matching is free of most issues, a high CFG scale
450
+ # can still cause over-exposure issues as discussed in the work.
451
+ if cfg_rescale_phi is not None and cfg_rescale_phi > 0:
452
+ # Οƒ_pos and Οƒ_cfg are per-sample (BΓ—1Γ—1Γ—1) stdevs
453
+ sigma_pos = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
454
+ sigma_cfg = noise_pred.std(dim=(1, 2, 3), keepdim=True)
455
+
456
+ # Linear blend between the raw ratio and 1,
457
+ # cf. Eq. (15–16) in the paper
458
+ factor = torch.lerp(
459
+ sigma_pos / (sigma_cfg + 1e-8), # avoid div-by-zero
460
+ torch.ones_like(sigma_cfg),
461
+ 1.0 - cfg_rescale_phi,
462
+ )
463
+ noise_pred = noise_pred * factor
464
+ else:
465
+ noise_pred = noise_pred
466
+
467
+ # compute the previous noisy sample x_t -> x_t-1
468
+ latents_dtype = latents.dtype
469
+ latents = self.scheduler.step(
470
+ noise_pred, t, latents, return_dict=False
471
+ )[0]
472
+
473
+ if latents.dtype != latents_dtype:
474
+ if torch.backends.mps.is_available():
475
+ # some platforms (eg. apple mps) misbehave due to a
476
+ # pytorch bug: https://github.com/pytorch/pytorch/pull/99272
477
+ latents = latents.to(latents_dtype)
478
+
479
+ if callback_on_step_end is not None:
480
+ callback_kwargs = {}
481
+ for k in callback_on_step_end_tensor_inputs:
482
+ callback_kwargs[k] = locals()[k]
483
+ callback_outputs = callback_on_step_end(
484
+ self, i, t, callback_kwargs
485
+ )
486
+
487
+ latents = callback_outputs.pop("latents", latents)
488
+ prompt_embeds = callback_outputs.pop(
489
+ "prompt_embeds", prompt_embeds
490
+ )
491
+ negative_prompt_embeds = callback_outputs.pop(
492
+ "negative_prompt_embeds", negative_prompt_embeds
493
+ )
494
+ negative_pooled_prompt_embeds = callback_outputs.pop(
495
+ "negative_pooled_prompt_embeds",
496
+ negative_pooled_prompt_embeds,
497
+ )
498
+
499
+ if write_video_animation_path is not None and i >= num_guided_steps:
500
+ with torch.no_grad():
501
+ image = self.vae.decode(
502
+ (latents / self.vae.config.scaling_factor)
503
+ + self.vae.config.shift_factor,
504
+ return_dict=False,
505
+ )[0]
506
+ cur_frame_idx = i * sw_steps
507
+ write_img(
508
+ os.path.join(
509
+ write_video_animation_path,
510
+ f"{cur_frame_idx:05d}.jpg",
511
+ ),
512
+ from_torch(image.squeeze(0)),
513
+ )
514
+
515
+ # call the callback, if provided
516
+ if i == len(timesteps) - 1 or (
517
+ (i + 1) > num_warmup_steps
518
+ and (i + 1) % self.scheduler.order == 0
519
+ ):
520
+ progress_bar.update()
521
+
522
+ if XLA_AVAILABLE:
523
+ xm.mark_step()
524
+
525
+ if output_type == "latent":
526
+ image = latents
527
+
528
+ else:
529
+ latents = (
530
+ latents / self.vae.config.scaling_factor
531
+ ) + self.vae.config.shift_factor
532
+
533
+ image = self.vae.decode(latents, return_dict=False)[0]
534
+ image = self.image_processor.postprocess(
535
+ image.detach(), output_type=output_type
536
+ )
537
+
538
+ # Offload all models
539
+ self.maybe_free_model_hooks()
540
+
541
+ if not return_dict:
542
+ return (image,)
543
+
544
+ return StableDiffusion3PipelineOutput(images=image)
545
+
546
+
547
+ def run(
548
+ prompt: str,
549
+ reference_image: PIL.Image.Image,
550
+ model_path: str,
551
+ num_inference_steps: int = 30,
552
+ num_guided_steps: int = 28,
553
+ guidance_scale: float = 5.0,
554
+ cfg_rescale_phi: float = 0.7,
555
+ sw_u_lr: float = 3e-3,
556
+ sw_steps: int = 8,
557
+ height: int = 768,
558
+ width: int = 768,
559
+ device: str = "cuda",
560
+ seed: Optional[int] = None,
561
+ # Add new SW-related parameters
562
+ num_projections: int = 64,
563
+ use_ucv: bool = False,
564
+ use_lcv: bool = False,
565
+ distance: Literal["l1", "l2"] = "l1",
566
+ num_new_candidates: int = 32,
567
+ subsampling_factor: int = 1,
568
+ sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
569
+ pipe: Optional[SWStableDiffusion3Pipeline] = None,
570
+ compile: bool = False,
571
+ video_animation_path: Optional[str] = None,
572
+ ) -> PIL.Image.Image:
573
+ """
574
+ Generate an image using SW Guidance with a given prompt and reference image.
575
+
576
+ Args:
577
+ prompt (str): Text prompt to guide the generation
578
+ reference_image (PIL.Image.Image): Reference image to guide the generation
579
+ model_path (str): Path to the model weights
580
+ num_inference_steps (int): Number of denoising steps
581
+ num_guided_steps (int): Number of denoising steps during which SW guidance is applied
582
+ guidance_scale (float): Scale for classifier-free guidance
583
+ cfg_rescale_phi (float): Rescale factor for classifier-free guidance
584
+ sw_u_lr (float): Learning rate for SW guidance
585
+ sw_steps (int): Number of SW optimization steps per guided denoising step
586
+ height (int): Output image height
587
+ width (int): Output image width
588
+ device (str): Device to run the model on
589
+ num_projections (int): Number of random projections for VectorSWDLoss
590
+ use_ucv (bool): Use UCV variant of VectorSWDLoss
591
+ use_lcv (bool): Use LCV variant of VectorSWDLoss
592
+ distance (str): Distance metric for VectorSWDLoss ("l1" or "l2")
593
+ seed (Optional[int]): Random seed for reproducible generation (None for random)
594
+ num_new_candidates (int): Number of new candidates for the reservoir
595
+ subsampling_factor (int): Factor to subsample points for SW computation.
596
+ Higher values reduce memory usage but may affect quality.
597
+ sampling_mode (str): Sampling mode for VectorSWDLoss.
598
+ pipe (SWStableDiffusion3Pipeline): Pipeline to use for generation.
599
+ If None, a new pipeline is created.
600
+ compile (bool): Whether to compile the pipeline.
601
+
602
+ Returns:
603
+ PIL.Image.Image: Generated image
604
+ """
605
+ # Normalize device to torch.device for robustness
606
+ device = torch.device(device) if not isinstance(device, torch.device) else device
607
+ if pipe is None:
608
+ pipe = create_pipeline(model_path, device, compile=compile)
609
+
610
+ pipe.setup_swd(
611
+ num_projections=num_projections,
612
+ use_ucv=use_ucv,
613
+ use_lcv=use_lcv,
614
+ distance=distance,
615
+ num_new_candidates=num_new_candidates,
616
+ subsampling_factor=subsampling_factor,
617
+ sampling_mode=sampling_mode,
618
+ )
619
+
620
+ if seed is not None:
621
+ print(f"Using seed: {seed}")
622
+ generator = torch.Generator(device=device).manual_seed(seed)
623
+ else:
624
+ generator = None
625
+
626
+ image = pipe(
627
+ prompt=prompt,
628
+ num_inference_steps=num_inference_steps,
629
+ num_guided_steps=num_guided_steps,
630
+ guidance_scale=guidance_scale,
631
+ cfg_rescale_phi=cfg_rescale_phi,
632
+ sw_u_lr=sw_u_lr,
633
+ sw_steps=sw_steps,
634
+ height=height,
635
+ width=width,
636
+ sw_reference=reference_image,
637
+ generator=generator,
638
+ write_video_animation_path=video_animation_path,
639
+ ).images[0]
640
+
641
+ return image
642
+
643
+
644
+ def create_pipeline(model_path, device: str = "cuda", compile: bool = False):
645
+ pipe = SWStableDiffusion3Pipeline.from_pretrained(
646
+ model_path,
647
+ torch_dtype=torch.bfloat16,
648
+ )
649
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
650
+ pipe.to(device)
651
+ if compile:
652
+ pipe.transformer = torch.compile(pipe.transformer)
653
+ pipe.vae.decoder = torch.compile(pipe.vae.decoder)
654
+ return pipe
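For reference, a minimal end-to-end sketch of the run() entry point above, assuming src is importable as a package. The model identifier, reference file, and output path are placeholders, not values taken from this commit.

import PIL.Image
from src.sw_sdthree_guidance import run

reference = PIL.Image.open("reference.jpg").convert("RGB")  # placeholder reference image

image = run(
    prompt="a misty harbor at dawn",
    reference_image=reference,
    model_path="stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder model path
    num_inference_steps=30,
    num_guided_steps=28,
    seed=0,
)
image.save("sw_guided.png")  # placeholder output path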
src/utils/asc_cdl.py ADDED
@@ -0,0 +1,279 @@
1
+ import warnings
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from jaxtyping import Float
6
+ from lxml import etree
7
+
8
+
9
+ def load_asc_cdl(cdl_path: str, device: torch.device = torch.device("cpu")) -> dict:
10
+ """
11
+ Loads ASC CDL parameters from an XML file.
12
+
13
+ Parameters:
14
+ cdl_path (str): Path to the ASC CDL XML file
15
+
16
+ Returns:
17
+ Dict:
18
+ slope, offset, power, and saturation values as torch tensors
19
+ """
20
+ try:
21
+ tree = etree.parse(cdl_path)
22
+ root = tree.getroot()
23
+ except Exception as e:
24
+ raise ValueError(f"Error loading ASC CDL from {cdl_path}: {e}")
25
+
26
+ # Extract SOP values
27
+ sop_node = root.find(".//SOPNode")
28
+ slope = torch.tensor(
29
+ [float(x) for x in sop_node.find("Slope").text.split()], device=device
30
+ )
31
+ offset = torch.tensor(
32
+ [float(x) for x in sop_node.find("Offset").text.split()], device=device
33
+ )
34
+ power = torch.tensor(
35
+ [float(x) for x in sop_node.find("Power").text.split()], device=device
36
+ )
37
+
38
+ # Extract Saturation value
39
+ sat_node = root.find(".//SatNode")
40
+ saturation = torch.tensor(float(sat_node.find("Saturation").text), device=device)
41
+
42
+ return {"slope": slope, "offset": offset, "power": power, "saturation": saturation}
43
+
44
+
45
+ def save_asc_cdl(cdl_dict: dict, cdl_path: Optional[str]):
46
+ """
47
+ Saves ASC CDL parameters to an XML file.
48
+
49
+ Parameters:
50
+ cdl_dict (dict): Dictionary containing slope, offset, power, and
51
+ saturation values
52
+ """
53
+ root = etree.Element("ASC_CDL")
54
+ sop_node = etree.SubElement(root, "SOPNode")
55
+ etree.SubElement(sop_node, "Slope").text = " ".join(
56
+ str(x) for x in cdl_dict["slope"].detach().cpu().numpy()
57
+ )
58
+ etree.SubElement(sop_node, "Offset").text = " ".join(
59
+ str(x) for x in cdl_dict["offset"].detach().cpu().numpy()
60
+ )
61
+ etree.SubElement(sop_node, "Power").text = " ".join(
62
+ str(x) for x in cdl_dict["power"].detach().cpu().numpy()
63
+ )
64
+ sat_node = etree.SubElement(root, "SatNode")
65
+ etree.SubElement(sat_node, "Saturation").text = str(
66
+ cdl_dict["saturation"].detach().cpu().numpy()
67
+ )
68
+
69
+ tree = etree.ElementTree(root)
70
+ if cdl_path is not None:
71
+ try:
72
+ tree.write(
73
+ cdl_path, pretty_print=True, xml_declaration=True, encoding="utf-8"
74
+ )
75
+ except Exception as e:
76
+ raise ValueError(f"Error saving ASC CDL to {cdl_path}: {e}")
77
+ else:
78
+ return etree.tostring(
79
+ root, pretty_print=True, xml_declaration=True, encoding="utf-8"
80
+ ).decode("utf-8")
81
+
82
+
83
+ def apply_sop(
84
+ img: Float[torch.Tensor, "*B C H W"],
85
+ slope: Float[torch.Tensor, "*B C"],
86
+ offset: Float[torch.Tensor, "*B C"],
87
+ power: Float[torch.Tensor, "*B C"],
88
+ clamp: bool = True,
89
+ ) -> Float[torch.Tensor, "*B C H W"]:
90
+ """
91
+ Applies Slope, Offset, and Power adjustments.
92
+
93
+ Parameters:
94
+ img (torch.Tensor): Input image tensor (*B, C, H, W)
95
+ slope (torch.Tensor): Slope per channel (*B, C)
96
+ offset (torch.Tensor): Offset per channel (*B, C)
97
+ power (torch.Tensor): Power per channel (*B, C)
98
+
99
+ Returns:
100
+ torch.Tensor: Image after SOP adjustments.
101
+ """
102
+ so = img * slope.unsqueeze(-1).unsqueeze(-1) + offset.unsqueeze(-1).unsqueeze(-1)
103
+ if clamp:
104
+ so = torch.clamp(so, min=0.0, max=1.0)
105
+ return torch.where(
106
+ so > 1e-7, torch.pow(so.clamp(min=1e-7), power.unsqueeze(-1).unsqueeze(-1)), so
107
+ )
108
+
109
+
110
+ def apply_saturation(
111
+ img: Float[torch.Tensor, "*B C H W"],
112
+ saturation: Float[torch.Tensor, "*B"],
113
+ ) -> Float[torch.Tensor, "*B C H W"]:
114
+ """
115
+ Applies saturation adjustment.
116
+
117
+ Parameters:
118
+ img (torch.Tensor): Image tensor (*B, C, H, W)
119
+ saturation (torch.Tensor): Saturation factor (*B)
120
+
121
+ Returns:
122
+ torch.Tensor: Image after saturation adjustment.
123
+ """
124
+ # Calculate luminance using Rec. 709 coefficients
125
+ lum = (
126
+ 0.2126 * img[..., 0, :, :]
127
+ + 0.7152 * img[..., 1, :, :]
128
+ + 0.0722 * img[..., 2, :, :]
129
+ )
130
+ lum = lum.unsqueeze(-3) # Add channel dimension
131
+ return lum + (img - lum) * saturation.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
132
+
133
+
134
+ def asc_cdl_forward(
135
+ img: Float[torch.Tensor, "*B C H W"],
136
+ slope: Float[torch.Tensor, "*B C"],
137
+ offset: Float[torch.Tensor, "*B C"],
138
+ power: Float[torch.Tensor, "*B C"],
139
+ saturation: Float[torch.Tensor, "*B"],
140
+ clamp: bool = True,
141
+ ) -> Float[torch.Tensor, "*B C H W"]:
142
+ """
143
+ Applies ASC CDL transformation in Fwd or FwdNoClamp mode.
144
+
145
+ Parameters:
146
+ img (torch.Tensor): Input image tensor (*B, C, H, W)
147
+ slope (torch.Tensor): Slope per channel (*B, C)
148
+ offset (torch.Tensor): Offset per channel (*B, C)
149
+ power (torch.Tensor): Power per channel (*B, C)
150
+ saturation (torch.Tensor): Saturation factor (*B)
151
+ clamp (bool): If True, clamps output to [0, 1] (Fwd mode).
152
+ If False, no clamping (FwdNoClamp mode).
153
+
154
+ Returns:
155
+ torch.Tensor: Transformed image tensor.
156
+ """
157
+ # Add warning if saturation, slope, power are below 0
158
+ if (saturation < 0).any():
159
+ warnings.warn("Saturation is below 0, this will result in a color shift.")
160
+ if (slope < 0).any():
161
+ warnings.warn("Slope is below 0, this will result in a color shift.")
162
+ if (power < 0).any():
163
+ warnings.warn("Power is below 0, this will result in a color shift.")
164
+
165
+ img_batch_dim = img.shape[:-3]
166
+ # Check if slope, offset, power, saturation have the same batch dimension
167
+ # If they do not have any batch dimensions, add a single batch dimension
168
+ if slope.ndim == 1:
169
+ slope = slope.view(*[1] * len(img_batch_dim), *slope.shape)
170
+ if offset.ndim == 1:
171
+ offset = offset.view(*[1] * len(img_batch_dim), *offset.shape)
172
+ if power.ndim == 1:
173
+ power = power.view(*[1] * len(img_batch_dim), *power.shape)
174
+ if saturation.ndim == 0:
175
+ saturation = saturation.view(*[1] * len(img_batch_dim), *saturation.shape)
176
+
177
+ # Now check that the lengths are matching
178
+ assert slope.ndim == len(img_batch_dim) + 1
179
+ assert offset.ndim == len(img_batch_dim) + 1
180
+ assert power.ndim == len(img_batch_dim) + 1
181
+ assert saturation.ndim == len(img_batch_dim)
182
+
183
+ # Apply Slope, Offset, and Power adjustments
184
+ img = apply_sop(img, slope, offset, power, clamp=clamp)
185
+ # print("img after sop", img.min(), img.max())
186
+ # Apply Saturation adjustment
187
+ img = apply_saturation(img, saturation)
188
+ # print("img after saturation", img.min(), img.max())
189
+ # Clamp if in Fwd mode
190
+ if clamp:
191
+ img = torch.clamp(img, 0.0, 1.0)
192
+ return img
193
+
194
+
195
+ def inverse_saturation(
196
+ img: Float[torch.Tensor, "*B C H W"],
197
+ saturation: Float[torch.Tensor, "*B"],
198
+ ) -> Float[torch.Tensor, "*B C H W"]:
199
+ """
200
+ Reverts saturation adjustment.
201
+
202
+ Parameters:
203
+ img (torch.Tensor): Image tensor (*B, C, H, W)
204
+ saturation (torch.Tensor): Saturation factor (*B)
205
+
206
+ Returns:
207
+ torch.Tensor: Image after reversing saturation adjustment.
208
+ """
209
+ # Calculate luminance using Rec. 709 coefficients
210
+ lum = (
211
+ 0.2126 * img[..., 0, :, :]
212
+ + 0.7152 * img[..., 1, :, :]
213
+ + 0.0722 * img[..., 2, :, :]
214
+ )
215
+ lum = lum.unsqueeze(-3) # Add channel dimension
216
+ return lum + (img - lum) / saturation.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
217
+
218
+
219
+ def asc_cdl_reverse(
220
+ img: Float[torch.Tensor, "*B C H W"],
221
+ slope: Float[torch.Tensor, "*B C"],
222
+ offset: Float[torch.Tensor, "*B C"],
223
+ power: Float[torch.Tensor, "*B C"],
224
+ saturation: Float[torch.Tensor, "*B"],
225
+ clamp: bool = True,
226
+ ) -> Float[torch.Tensor, "*B C H W"]:
227
+ """
228
+ Applies reverse ASC CDL transformation.
229
+
230
+ Parameters:
231
+ img (torch.Tensor): Transformed image tensor (*B, C, H, W)
232
+ slope (torch.Tensor): Slope per channel (*B, C)
233
+ offset (torch.Tensor): Offset per channel (*B, C)
234
+ power (torch.Tensor): Power per channel (*B, C)
235
+ saturation (torch.Tensor): Saturation factor (*B)
236
+ clamp (bool): If True, clamps output to [0, 1].
237
+
238
+ Returns:
239
+ torch.Tensor: Recovered input image tensor.
240
+ """
241
+ # Add warning if saturation, slope, power are below 0
242
+ if (saturation < 0).any():
243
+ warnings.warn("Saturation is below 0, this will result in a color shift.")
244
+ if (slope < 0).any():
245
+ warnings.warn("Slope is below 0, this will result in a color shift.")
246
+ if (power < 0).any():
247
+ warnings.warn("Power is below 0, this will result in a color shift.")
248
+
249
+ img_batch_dim = img.shape[:-3]
250
+ # Check if slope, offset, power, saturation have the same batch dimension
251
+ # If they do not have any batch dimensions, add a single batch dimension
252
+ if slope.ndim == 1:
253
+ slope = slope.view(*[1] * len(img_batch_dim), *slope.shape)
254
+ if offset.ndim == 1:
255
+ offset = offset.view(*[1] * len(img_batch_dim), *offset.shape)
256
+ if power.ndim == 1:
257
+ power = power.view(*[1] * len(img_batch_dim), *power.shape)
258
+ if saturation.ndim == 0:
259
+ saturation = saturation.view(*[1] * len(img_batch_dim), *saturation.shape)
260
+
261
+ # Now check that the lengths are matching
262
+ assert slope.ndim == len(img_batch_dim) + 1
263
+ assert offset.ndim == len(img_batch_dim) + 1
264
+ assert power.ndim == len(img_batch_dim) + 1
265
+ assert saturation.ndim == len(img_batch_dim)
266
+
267
+ # Inverse Saturation adjustment
268
+ img = inverse_saturation(img, saturation)
269
+ # Inverse SOP adjustments
270
+ if clamp:
271
+ img = torch.clamp(img, 0.0, 1.0)
272
+ img = torch.where(
273
+ img > 1e-7, torch.pow(img, 1 / power.unsqueeze(-1).unsqueeze(-1)), img
274
+ )
275
+ img = (img - offset.unsqueeze(-1).unsqueeze(-1)) / slope.unsqueeze(-1).unsqueeze(-1)
276
+ # Clamp if specified
277
+ if clamp:
278
+ img = torch.clamp(img, 0.0, 1.0)
279
+ return img
src/utils/color_space.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ from jaxtyping import Float
3
+
4
+
5
+ def srgb_to_linear(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
6
+ switch_val = 0.04045
7
+ return torch.where(
8
+ torch.greater(x, switch_val),
9
+ ((x.clip(min=switch_val) + 0.055) / 1.055).pow(2.4),
10
+ x / 12.92,
11
+ )
12
+
13
+
14
+ def linear_to_srgb(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
15
+ switch_val = 0.0031308
16
+ return torch.where(
17
+ torch.greater(x, switch_val),
18
+ 1.055 * x.clip(min=switch_val).pow(1.0 / 2.4) - 0.055,
19
+ x * 12.92,
20
+ )
21
+
22
+
23
+ def rgb_to_lab(srgb: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
24
+ srgb_pixels = torch.reshape(srgb, [-1, 3])
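+ # Conversion chain: sRGB -> linear RGB -> CIE XYZ (D65 white point) -> CIELAB.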
25
+
26
+ linear_mask = srgb_pixels <= 0.04045
27
+ exponential_mask = srgb_pixels > 0.04045
28
+ rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (
29
+ ((srgb_pixels + 0.055) / 1.055) ** 2.4
30
+ ) * exponential_mask
31
+
32
+ rgb_to_xyz = (
33
+ torch.tensor(
34
+ [
35
+ # X Y Z
36
+ [0.412453, 0.212671, 0.019334], # R
37
+ [0.357580, 0.715160, 0.119193], # G
38
+ [0.180423, 0.072169, 0.950227], # B
39
+ ]
40
+ )
41
+ .to(srgb.dtype)
42
+ .to(srgb.device)
43
+ )
44
+
45
+ xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz)
46
+
47
+ xyz_normalized_pixels = torch.mul(
48
+ xyz_pixels,
49
+ torch.tensor([1 / 0.950456, 1.0, 1 / 1.088754]).to(srgb.dtype).to(srgb.device),
50
+ )
51
+
52
+ epsilon = 6.0 / 29.0
53
+ linear_mask = (xyz_normalized_pixels <= (epsilon**3)).to(srgb.dtype).to(srgb.device)
54
+
55
+ exponential_mask = (
56
+ (xyz_normalized_pixels > (epsilon**3)).to(srgb.dtype).to(srgb.device)
57
+ )
58
+
59
+ fxfyfz_pixels = (
60
+ xyz_normalized_pixels / (3 * epsilon**2) + 4.0 / 29.0
61
+ ) * linear_mask + (
62
+ (xyz_normalized_pixels + 0.000001) ** (1.0 / 3.0)
63
+ ) * exponential_mask
64
+
65
+ fxfyfz_to_lab = (
66
+ torch.tensor(
67
+ [
68
+ # l a b
69
+ [0.0, 500.0, 0.0], # fx
70
+ [116.0, -500.0, 200.0], # fy
71
+ [0.0, 0.0, -200.0], # fz
72
+ ]
73
+ )
74
+ .to(srgb.dtype)
75
+ .to(srgb.device)
76
+ )
77
+ lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + torch.tensor(
78
+ [-16.0, 0.0, 0.0]
79
+ ).to(srgb.dtype).to(srgb.device)
80
+ return torch.reshape(lab_pixels, srgb.shape)
src/utils/image.py ADDED
@@ -0,0 +1,33 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from jaxtyping import Float
5
+
6
+
7
+ def read_img(path):
8
+ img = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
9
+ if img.ndim == 3:
10
+ img = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2RGB)
11
+ elif img.ndim == 2:
12
+ img = img[..., np.newaxis]
13
+ dinfo = np.iinfo(img.dtype)
14
+ return (img.astype(np.float32) / dinfo.max) * 2 - 1
15
+
16
+
17
+ def write_img(path: str, data: np.ndarray):
18
+ data = np.clip(data * 0.5 + 0.5, 0, 1)
19
+ if data.ndim == 3 and data.shape[-1] == 3:
20
+ data = cv2.cvtColor(data, cv2.COLOR_RGB2BGR)
21
+ elif data.ndim == 2:
22
+ data = data[..., np.newaxis]
23
+
24
+ data = (data * 255).astype(np.uint8)
25
+ cv2.imwrite(path, data)
26
+
27
+
28
+ def to_torch(img: Float[np.ndarray, "H W C"]) -> Float[torch.Tensor, "C H W"]:
29
+ return torch.from_numpy(img).permute(2, 0, 1)
30
+
31
+
32
+ def from_torch(img: Float[torch.Tensor, "C H W"]) -> Float[np.ndarray, "H W C"]:
33
+ return img.permute(1, 2, 0).detach().cpu().float().numpy()
src/utils/math.py ADDED
@@ -0,0 +1,39 @@
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+
6
+
7
+ def _random_orthonormal_matrix(d: int, device: torch.device) -> torch.Tensor:
8
+ """Draw a random rotation matrix Q ∈ SO(d) (Haar) via QR-factorisation."""
9
+ a = torch.randn(d, d, device=device)
10
+ # QR gives orthonormal columns; ensure right-handed
11
+ q, r = torch.linalg.qr(a, mode="reduced")
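+ # Note: for an exactly Haar-distributed Q one would also multiply the columns of q
+ # by sign(diag(r)) (Mezzadri, 2007); here only the overall determinant sign is fixed.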
12
+ # make determinant +1 (special orthogonal) – flip first column if needed
13
+ if torch.det(q) < 0:
14
+ q[:, 0] = -q[:, 0]
15
+ return q # (d,d)
16
+
17
+
18
+ def sobol_sphere(
19
+ n: int,
20
+ d: int,
21
+ device: torch.device,
22
+ sobol_engine: Optional[torch.quasirandom.SobolEngine] = None,
23
+ ) -> Tuple[torch.Tensor, torch.quasirandom.SobolEngine]:
24
+ """n unit vectors on S^{d-1} via scrambled Sobol + Gaussian + random rotation."""
25
+ if sobol_engine is None:
26
+ sob = torch.quasirandom.SobolEngine(dimension=d, scramble=True)
27
+ else:
28
+ sob = sobol_engine
29
+ # Draw in [0,1)^d then map β†’ 𝒩(0,1)
30
+ u01 = sob.draw(n).to(device)
31
+
32
+ eps = 1e-7
33
+ u01 = u01.clamp(min=eps, max=1.0 - eps) # avoid 0 and 1 exactly
34
+
35
+ z = torch.erfinv(2.0 * u01 - 1.0) * math.sqrt(2.0) # inverse-CDF of Normal
36
+ z = z / (z.norm(dim=1, keepdim=True) + 1e-8) # project to sphere
37
+ # Random global rotation (RQMC) to make estimator unbiased
38
+ Q = _random_orthonormal_matrix(d, device)
39
+ return z @ Q.T, sob # (n,d)