Commit d283d84 by toshas · 1 Parent(s): 72cbe32

initial commit

LICENSE.txt ADDED
@@ -0,0 +1,177 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
-title: Windowseat Reflection Removal
-emoji: 👁
-colorFrom: gray
-colorTo: purple
+title: WindowSeat Reflection Removal
+emoji: 🪟
+colorFrom: purple
+colorTo: blue
 sdk: gradio
-sdk_version: 6.0.2
+sdk_version: 5.29.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py ADDED
@@ -0,0 +1,125 @@
+import os
+
+os.system("pip freeze")
+import spaces
+import tempfile
+import shutil
+
+import gradio as gr
+import torch as torch
+from gradio_dualvision import DualVisionApp
+from huggingface_hub import login
+from PIL import Image
+from windowseat_inference import load_network, run_inference
+
+uri_base = "Qwen/Qwen-Image-Edit-2509"
+uri_lora = "huawei-bayerlab/windowseat-reflection-removal-v1-0"
+
+if "HF_TOKEN_LOGIN" in os.environ:
+    login(token=os.environ["HF_TOKEN_LOGIN"])
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+vae, transformer, embeds_dict, processing_resolution = load_network(uri_base, uri_lora, device)
+
+# # As of transformers==4.57.1 , xformers is not supported in QwenImageTransformer2DModel
+# try:
+#     transformer.enable_xformers_memory_efficient_attention()
+#     print("xformers enabled")
+# except:
+#     print("xformers not enabled")
+
+
+class WindowSeatApp(DualVisionApp):
+    DEFAULT_SEED = 2025
+
+    def make_header(self):
+        gr.Markdown(
+            """
+            ## WindowSeat Reflection Removal
+            """
+        )
+        with gr.Row(elem_classes="remove-elements"):
+            gr.Markdown(
+                f"""
+                <p align="center">
+                <a title="Website" href="https://hf.co/spaces/huawei-bayerlab/windowseat-reflection-removal-web" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                    <img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
+                </a>
+                <a title="arXiv" href="https://arxiv.org/abs/2512.05000" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                    <img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv%20-Paper-AF3436">
+                </a>
+                <a title="Github" href="https://github.com/huawei-bayerlab/windowseat-reflection-removal" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                    <img src="https://img.shields.io/github/stars/huawei-bayerlab/windowseat-reflection-removal?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
+                </a>
+                <a title="Model weights" href="https://hf.co/huawei-bayerlab/windowseat-reflection-removal-v1-0" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                    <img src="https://img.shields.io/badge/%F0%9F%A4%97%20WindowSeat%20Model%20-Weights-yellow" alt="imagedepth">
+                </a>
+                <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                    <img src="https://shields.io/twitter/follow/:?label=Subscribe%20for%20updates!" alt="social">
+                </a>
+                </p>
+                <p align="center" style="margin-top: 0px;">
+                    Upload a photo or pick an example below to remove reflections, wait for the result, then explore it with the slider.
+                    If a quota limit appears, duplicate the space to continue.
+                </p>
+                """
+            )
+
+    def build_user_components(self):
+        return {}
+
+    def process(self, image_in: Image.Image, **kwargs):
+        input_temp_dir = tempfile.mkdtemp()
+        output_temp_dir = tempfile.mkdtemp()
+
+        try:
+            input_image_path = os.path.join(input_temp_dir, "image.png")
+            image_in.save(input_image_path)
+            run_inference(
+                vae,
+                transformer,
+                embeds_dict,
+                processing_resolution,
+                input_temp_dir,
+                output_temp_dir,
+                use_short_edge_tile=True,
+                save_comparison=False,
+                save_alternating=False,
+            )
+            output_image_path = os.path.join(output_temp_dir, "image_windowseat_output.png")
+            result_image = Image.open(output_image_path)
+            result_image.load()
+
+            out_modalities = {
+                "Result": result_image,
+            }
+
+            out_settings = {}
+
+            return out_modalities, out_settings
+
+        finally:
+            if os.path.exists(input_temp_dir):
+                shutil.rmtree(input_temp_dir)
+            if os.path.exists(output_temp_dir):
+                shutil.rmtree(output_temp_dir)
+
+
+with WindowSeatApp(
+    title="WindowSeat Reflection Removal",
+    examples_path="example_images",
+    examples_per_page=12,
+    right_selector_visible=False,
+    advanced_settings_visible=False,
+    squeeze_canvas=True,
+    spaces_zero_gpu_enabled=True,
+) as demo:
+    demo.queue(
+        api_open=False,
+    ).launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        ssr_mode=False,
+    )
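
The `process` method in `app.py` round-trips the image through two temporary directories and removes them in a `finally` block. As a rough illustration only, the same pattern can be sketched with `tempfile.TemporaryDirectory`, which handles the cleanup itself. `fake_inference` and `process_bytes` are hypothetical names, not part of the commit; the stand-in merely copies files using the `<stem>_windowseat_output` naming that `app.py` expects from `run_inference`.

```python
import shutil
import tempfile
from pathlib import Path


def fake_inference(input_dir: str, output_dir: str) -> None:
    """Hypothetical stand-in for run_inference: copies each input file to
    <stem>_windowseat_output<suffix> in the output directory."""
    for src in Path(input_dir).iterdir():
        dst = Path(output_dir) / f"{src.stem}_windowseat_output{src.suffix}"
        shutil.copyfile(src, dst)


def process_bytes(payload: bytes) -> bytes:
    """Write the input, run the folder-based step, read the result back.

    Both directories are deleted when the with-block exits, even on error.
    """
    with tempfile.TemporaryDirectory() as input_dir, tempfile.TemporaryDirectory() as output_dir:
        (Path(input_dir) / "image.png").write_bytes(payload)
        fake_inference(input_dir, output_dir)
        # Read the result fully before the directories disappear,
        # mirroring result_image.load() in the diff.
        return (Path(output_dir) / "image_windowseat_output.png").read_bytes()


result = process_bytes(b"not-a-real-png")
```

The result has to be read into memory inside the `with` block, which is presumably why the committed code calls `result_image.load()` before its `finally` cleanup runs.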
example_images/0_bakery.jpg ADDED

Git LFS Details

  • SHA256: 69619874eef8986d8138255d71ea88009b284c99ef235d1cbd779e3c172232c6
  • Pointer size: 130 Bytes
  • Size of remote file: 69 kB
example_images/0_cafe.jpg ADDED

Git LFS Details

  • SHA256: 98e890212b1c3e792842b3800358fbead1d15b94454857bfaa2680be5e5e770b
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB
example_images/0_car_wheel.jpg ADDED

Git LFS Details

  • SHA256: eab333cbd083cc4cc43552ae8a94c3abbdfe1d5a9f324b748453b1dcb6beb617
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
example_images/0_cats.png ADDED

Git LFS Details

  • SHA256: b4f530d181c3a2e7ee856b58ebff7b0bcff4c84a2266016cce1fc714ddaba57c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.71 MB
example_images/0_dog.jpg ADDED

Git LFS Details

  • SHA256: d176019ae3e15012efcf22293d6ed5b4ad6ec314523acb1bcdd3e9f6b679f7cf
  • Pointer size: 130 Bytes
  • Size of remote file: 95.1 kB
example_images/0_entrance.jpg ADDED

Git LFS Details

  • SHA256: 0f3ff3846fda0d9731a6cc3e9fe96c03cda7d9067f9b60cb47d7ef740d38a653
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
example_images/0_misty_train.jpg ADDED

Git LFS Details

  • SHA256: d3ef62338ccf42776aa5a3b6412fabef4380e2b8e44cea4a6ef9f6fe5433c94c
  • Pointer size: 130 Bytes
  • Size of remote file: 97.9 kB
example_images/0_museum.jpg ADDED

Git LFS Details

  • SHA256: 8096a520cf18314b3532bbc7ddfc95742791c01645c7825e3e616582eff0d8d4
  • Pointer size: 130 Bytes
  • Size of remote file: 69.6 kB
example_images/0_park_cart.jpg ADDED

Git LFS Details

  • SHA256: 069fa2d7d89f311f24f131eafbda5316b650716c831a69838bb687a0c1fde771
  • Pointer size: 132 Bytes
  • Size of remote file: 3.23 MB
example_images/0_pharaoh.jpg ADDED

Git LFS Details

  • SHA256: 920f0a030a742668b98a97dcc6abd140e1beb6360a11a9c287b40c39a414e514
  • Pointer size: 130 Bytes
  • Size of remote file: 87.1 kB
example_images/0_phone_booth.jpg ADDED

Git LFS Details

  • SHA256: c43c6265905cf9ea3bb8be7f3ed7670fd4532cbaf50ef2c5c2b25913e6db990d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
example_images/0_store_front.jpg ADDED

Git LFS Details

  • SHA256: fb657590961e92a3ad6fcda26024f7a649ef6c6ee1527d4b1afee24bfe28c5ff
  • Pointer size: 130 Bytes
  • Size of remote file: 54.1 kB
example_images/0_uniqlo.jpg ADDED

Git LFS Details

  • SHA256: eaee88a3387cba03dfac9b35a43f4c48156f553112be400bdacd03c4073a6661
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB
example_images/0_wolf.jpg ADDED

Git LFS Details

  • SHA256: f90ea67fb0da7c58f9ee88cd5caf6326e6758a30a50432a5c1758ce141c39418
  • Pointer size: 130 Bytes
  • Size of remote file: 54.4 kB
example_images/0_zoo.jpg ADDED

Git LFS Details

  • SHA256: c68c29774f0414daa22f76074d9276b889509db46ffe114ecd43cadcba3b24df
  • Pointer size: 130 Bytes
  • Size of remote file: 92.7 kB
example_images/1_window_airplane.png ADDED

Git LFS Details

  • SHA256: 9edda131ec22d7ad6d587f310009d55c62be9b514678da17fb07596660adf7c9
  • Pointer size: 131 Bytes
  • Size of remote file: 605 kB
example_images/1_window_airport.jpg ADDED

Git LFS Details

  • SHA256: 7b6b9c651bf7fb59c86aa3d46e0b0cb68ad2ad6739674cb42f6fd4189a476baa
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
example_images/2_postcards_008.png ADDED

Git LFS Details

  • SHA256: 15388ad5e182ecced2b025f5b7e1069b6712b5ab1576a9fb1a4e86725ed085f2
  • Pointer size: 131 Bytes
  • Size of remote file: 327 kB
example_images/2_postcards_050.png ADDED

Git LFS Details

  • SHA256: 95da1c6345b5b7461595972bad7b48dbdfe90a03112d2a98a6a2a0b05c2f8d59
  • Pointer size: 131 Bytes
  • Size of remote file: 288 kB
example_images/2_real_110.jpg ADDED

Git LFS Details

  • SHA256: 991989f0369307af5633d3c3e370c2ea204d0214b1287a1fc78a64c686eef2c7
  • Pointer size: 130 Bytes
  • Size of remote file: 38 kB
example_images/2_wild_026.jpg ADDED

Git LFS Details

  • SHA256: d1f4865dd3104931aa0ae2f7ba4346dd2b29614cf5633c39b903fbaf345c0ae6
  • Pointer size: 130 Bytes
  • Size of remote file: 25.3 kB
requirements.txt ADDED
@@ -0,0 +1,18 @@
+diffusers>=0.33.0
+# gradio-dualvision @ git+https://github.com/toshas/gradio-dualvision.git@gradio-5.29.0
+gradio_dualvision @ git+https://github.com/toshas/gradio-dualvision.git@59d338e9
+accelerate
+bitsandbytes
+huggingface_hub
+imageio
+imageio-ffmpeg
+peft
+Pillow
+safetensors
+scipy
+torch
+torchvision
+tqdm
+transformers
+# --only-binary=xformers
+# xformers
windowseat_inference.py ADDED
@@ -0,0 +1,871 @@
+import argparse
+import functools
+import json
+import math
+import os
+import sys
+import warnings
+
+import imageio.v2 as imageio
+import numpy as np
+import safetensors
+import torch
+import torchvision
+from diffusers import (
+    AutoencoderKLQwenImage,
+    BitsAndBytesConfig,
+    QwenImageEditPipeline,
+    QwenImageTransformer2DModel,
+)
+from huggingface_hub import hf_hub_download
+from peft import LoraConfig
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+SUPPORTED_MODEL_URIS = [
+    "Qwen/Qwen-Image-Edit-2509",
+]
+LORA_MODEL_URI = "toshas/WindowSeat-Qwen-Image-Edit-2509"
+
+
+def fetch_state_dict(
+    pretrained_model_name_or_path_or_dict: str,
+    weight_name: str,
+    use_safetensors: bool = True,
+    subfolder: str | None = None,
+):
+    file_path = hf_hub_download(pretrained_model_name_or_path_or_dict, weight_name, subfolder=subfolder)
+    if use_safetensors:
+        state_dict = safetensors.torch.load_file(file_path)
+    else:
+        state_dict = torch.load(file_path, weights_only=True)
+    return state_dict
+
+
+def load_qwen_vae(uri: str, device: torch.device):
+    vae = AutoencoderKLQwenImage.from_pretrained(
+        uri,
+        subfolder="vae",
+        torch_dtype=torch.bfloat16,
+        device_map=device,
+        low_cpu_mem_usage=True,
+        use_safetensors=True,
+    )
+    vae.to(device, dtype=torch.bfloat16)
+    return vae
+
+
+def load_qwen_transformer(uri: str, device: torch.device):
+    nf4 = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
+    )
+
+    transformer = QwenImageTransformer2DModel.from_pretrained(
+        uri,
+        subfolder="transformer",
+        torch_dtype=torch.bfloat16,
+        quantization_config=nf4,
+        device_map=device,
+    )
+
+    return transformer
+
+
+def load_lora_into_transformer(uri: str, transformer: QwenImageTransformer2DModel):
+    lora_config = LoraConfig.from_pretrained(uri, subfolder="transformer_lora")
+    transformer.add_adapter(lora_config)
+    state_dict = fetch_state_dict(uri, "pytorch_lora_weights.safetensors", subfolder="transformer_lora")
+    missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
+    if len(unexpected) > 0:
+        raise ValueError(f"Unexpected keys in transformer state dict: {unexpected}")
+    return transformer
+
+
+def load_embeds_dict(uri: str):
+    embeds_dict = fetch_state_dict(uri, "state_dict.safetensors", subfolder="text_embeddings")
+    return embeds_dict
+
+
+def load_network(uri_base: str, uri_lora: str, device: torch.device):
+    config_file = hf_hub_download(uri_lora, "model_index.json")
+    with open(config_file, "r") as f:
+        config_dict = json.load(f)
+    base_model_uri = config_dict["base_model"]
+    processing_resolution = config_dict["processing_resolution"]
+    if base_model_uri not in SUPPORTED_MODEL_URIS:
+        raise ValueError(f"Unsupported base model URI: {base_model_uri}")
+
+    vae = load_qwen_vae(uri_base, device)
+    transformer = load_qwen_transformer(uri_base, device)
+    load_lora_into_transformer(uri_lora, transformer)
+    embeds_dict = load_embeds_dict(uri_lora)
+    return vae, transformer, embeds_dict, processing_resolution
+
+
+def encode(image: torch.Tensor, vae: AutoencoderKLQwenImage) -> torch.Tensor:
+    image = image.to(device=vae.device, dtype=vae.dtype)
+    out = vae.encode(image.unsqueeze(2)).latent_dist.sample()
+    latents_mean = torch.tensor(vae.config.latents_mean, device=out.device, dtype=out.dtype)
+    latents_mean = latents_mean.view(1, vae.config.z_dim, 1, 1, 1)
+    latents_std_inv = 1.0 / torch.tensor(vae.config.latents_std, device=out.device, dtype=out.dtype)
+    latents_std_inv = latents_std_inv.view(1, vae.config.z_dim, 1, 1, 1)
+    out = (out - latents_mean) * latents_std_inv
+    return out
+
+
+def decode(latents: torch.Tensor, vae: AutoencoderKLQwenImage) -> torch.Tensor:
+    latents_mean = torch.tensor(vae.config.latents_mean, device=latents.device, dtype=latents.dtype)
+    latents_mean = latents_mean.view(1, vae.config.z_dim, 1, 1, 1)
+    latents_std_inv = (1.0 / torch.tensor(vae.config.latents_std, device=latents.device, dtype=latents.dtype))
+    latents_std_inv = latents_std_inv.view(1, vae.config.z_dim, 1, 1, 1)
+    latents = latents / latents_std_inv + latents_mean
+    out = vae.decode(latents)
+    out = out.sample[:, :, 0]
+    return out
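
The `encode` and `decode` functions above standardize latents per channel with `(x - mean) * (1 / std)` and invert that with `z / (1 / std) + mean`. A minimal pure-Python sketch of the same arithmetic follows, using nested lists instead of tensors. The function names and the sample statistics are illustrative assumptions, not values from the VAE config.

```python
import math


def normalize(latents, mean, std):
    """Per-channel standardization, as in encode(): (x - mean) * (1 / std)."""
    return [
        [(value - m) * (1.0 / s) for value in channel]
        for channel, m, s in zip(latents, mean, std)
    ]


def denormalize(latents, mean, std):
    """Inverse mapping, as in decode(): z / (1 / std) + mean."""
    return [
        [value / (1.0 / s) + m for value in channel]
        for channel, m, s in zip(latents, mean, std)
    ]


# Made-up statistics for two channels (hypothetical, not from vae.config).
mean = [0.5, -1.0]
std = [2.0, 0.25]
raw = [[0.5, 2.5], [-1.0, -0.75]]

normalized = normalize(raw, mean, std)
restored = denormalize(normalized, mean, std)
```

Dividing by the inverse standard deviation is the same as multiplying by the standard deviation, so `decode` undoes `encode` up to floating-point error.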
+
+
+def _match_batch(t: torch.Tensor, B: int) -> torch.Tensor:
+    if t.size(0) == B:
+        return t
+    if t.size(0) == 1 and B > 1:
+        return t.expand(B, *t.shape[1:])
+    if t.size(0) > B:
+        return t[:B]
+    reps = (B + t.size(0) - 1) // t.size(0)
+    return t.repeat((reps,) + (1,) * (t.ndim - 1))[:B]
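
`_match_batch` adapts the leading dimension of the cached prompt tensors to the batch size: pass through when equal, broadcast a single row, truncate when too long, otherwise tile and cut. The sketch below mirrors those four branches on plain Python lists so the behaviour can be checked without torch. `match_batch` is a hypothetical list-based analogue, not the committed function.

```python
def match_batch(rows: list, batch_size: int) -> list:
    """List-based analogue of _match_batch (illustrative only)."""
    n = len(rows)
    if n == batch_size:
        return list(rows)            # already the right length
    if n == 1 and batch_size > 1:
        return rows * batch_size     # broadcast the single row
    if n > batch_size:
        return rows[:batch_size]     # truncate extra rows
    reps = (batch_size + n - 1) // n  # ceil(batch_size / n)
    return (rows * reps)[:batch_size]  # tile, then cut to size


tiled = match_batch(["a", "b", "c"], 7)
```

With three rows and a batch of seven, the rows are repeated three times and cut after the seventh entry, so the cycle restarts from the first row.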
+
+
+def flow_step(
+    model_input: torch.Tensor,
+    transformer: QwenImageTransformer2DModel,
+    vae: AutoencoderKLQwenImage,
+    embeds_dict: dict[str, torch.Tensor],
+) -> torch.Tensor:
+    prompt_embeds = embeds_dict["prompt_embeds"]  # [N_ctx, L, D]
+    prompt_mask = embeds_dict["prompt_mask"]  # [N_ctx, L]
+
+    if prompt_mask.dtype != torch.bool:
+        prompt_mask = prompt_mask > 0
+
+    # Accept [B, C, 1, H, W] or [B, C, H, W]
+    if model_input.ndim == 5 and model_input.shape[2] == 1:
+        model_input_4d = model_input[:, :, 0]  # [B, C, H, W]
+    elif model_input.ndim == 4:
+        model_input_4d = model_input
+    else:
+        raise ValueError(f"Unexpected lat_encoding shape: {model_input.shape}")
+
+    B, C, H, W = model_input_4d.shape
+    device = next(transformer.parameters()).device
+
+    prompt_embeds = _match_batch(prompt_embeds, B).to(
+        device=device, dtype=torch.bfloat16, non_blocking=True
+    )  # [B, L, D]
+
+    prompt_mask = _match_batch(prompt_mask, B).to(
+        device=device, dtype=torch.bool, non_blocking=True
+    )  # [B, L]
+
+    num_channels_latents = C
+    packed_model_input = QwenImageEditPipeline._pack_latents(
+        model_input_4d,
+        batch_size=B,
+        num_channels_latents=num_channels_latents,
+        height=H,
+        width=W,
+    )  # [B, N_patches, C * 4], where N_patches = (H // 2) * (W // 2)
+    packed_model_input = packed_model_input.to(torch.bfloat16)
+
+    t_const = 499
+    timestep = torch.full(
+        (B,),
+        float(t_const),
+        device=device,
+        dtype=torch.bfloat16,
+    )
+    timestep = timestep / 1000.0
+
+    h_img = H // 2
+    w_img = W // 2
+
+    img_shapes = [[(1, h_img, w_img)]] * B
+    txt_seq_lens = prompt_mask.sum(dim=1).tolist() if prompt_mask is not None else None
+
+    if getattr(transformer, "attention_kwargs", None) is None:
+        attention_kwargs = {}
+    else:
+        attention_kwargs = transformer.attention_kwargs
+
+    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
+        model_pred = transformer(
+            hidden_states=packed_model_input,  # [B, N_patches, C*4]
+            timestep=timestep,  # [B], float / 1000
+            encoder_hidden_states=prompt_embeds,  # [B, L, D]
+            encoder_hidden_states_mask=prompt_mask,  # [B, L]
+            img_shapes=img_shapes,  # single stream per batch
+            txt_seq_lens=txt_seq_lens,
+            guidance=None,
+            attention_kwargs=attention_kwargs,
+            return_dict=False,
+        )[0]  # [B, N_patches, C*4]
+
+    temperal_downsample = vae.config.get("temperal_downsample", None)
+    if temperal_downsample is not None:
+        vae_scale_factor = 2 ** len(temperal_downsample)
+    else:
+        vae_scale_factor = 8
+
+    model_pred = QwenImageEditPipeline._unpack_latents(
+        model_pred,
+        height=H * vae_scale_factor,  # H, W here are latent H,W from encode
+        width=W * vae_scale_factor,
+        vae_scale_factor=vae_scale_factor,
+    )  # [B, C, 1, H_lat, W_lat]
+
+    latent_output = model_input.to(vae.dtype) - model_pred.to(vae.dtype)
+
+    return latent_output
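
The shape comments in `flow_step` say that latents of shape `[B, C, H, W]` are packed into `[B, N_patches, C * 4]` with `N_patches = (H // 2) * (W // 2)`, that is, each 2x2 latent patch becomes one token. The helper below only does that shape bookkeeping and is a sketch. `packed_shape` is a hypothetical name, the even-size check is an assumption about what the packing needs, and the 16-channel, 96x96 example values are illustrative rather than read from the model config.

```python
def packed_shape(batch: int, channels: int, height: int, width: int) -> tuple[int, int, int]:
    """Shape of the token sequence after 2x2 patch packing.

    Follows the comment in flow_step: [B, (H // 2) * (W // 2), C * 4].
    Assumes even latent height and width (an assumption of this sketch).
    """
    if height % 2 or width % 2:
        raise ValueError("latent height and width are assumed to be even")
    num_patches = (height // 2) * (width // 2)
    return (batch, num_patches, channels * 4)


# Illustrative numbers only: a 96x96 latent grid with 16 channels.
example = packed_shape(1, 16, 96, 96)
```

For these example numbers the transformer would see 2304 tokens of width 64 per image, before the text tokens are attached.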
+
+
+def _supports_color() -> bool:
+    return sys.stdout.isatty()
+
+
+def _style(text: str, *, color: str | None = None, bold: bool = False) -> str:
+    if not _supports_color():
+        return text
+
+    codes = []
+    if bold:
+        codes.append("1")
+    if color == "red":
+        codes.append("31")
+    elif color == "green":
+        codes.append("32")
+    elif color == "yellow":
+        codes.append("33")
+    elif color == "blue":
+        codes.append("34")
+    elif color == "magenta":
+        codes.append("35")
+    elif color == "cyan":
+        codes.append("36")
+
+    if not codes:
+        return text
+    return f"\033[{';'.join(codes)}m{text}\033[0m"
+
+
+def print_banner(title: str):
+    title = f" {title} "
+    bar = "═" * len(title)
+    print(_style(f"╔{bar}╗", color="cyan", bold=True))
+    print(_style(f"║{title}║", color="cyan", bold=True))
+    print(_style(f"╚{bar}╝", color="cyan", bold=True))
+
+
+def print_step(step: str, msg: str):
+    prefix = _style(f"[{step}] ", color="yellow", bold=True)
+    print(prefix + msg)
+
+
+def print_ok(msg: str):
+    print(_style("✔ ", color="green", bold=True) + msg)
+
+
+def print_info(msg: str):
+    print(_style("ℹ ", color="blue", bold=True) + msg)
+
+
+def print_error(msg: str):
+    print(_style("✖ ", color="red", bold=True) + msg)
+
+
+def print_final_success(output_dir: str):
+    print_ok("Inference finished successfully!")
+    print_info("Predictions have been written to:")
+    print(" " + _style(output_dir, color="cyan", bold=True))
+    print(_style("Thank you for trying out WindowSeat! 🪟", color="green"))
293
+
294
+
295
+ def _required_side_for_axis(size: int, nmax: int, min_overlap: int) -> int:
296
+ """Smallest tile side T (1D) so that #tiles <= nmax with overlap >= min_overlap."""
297
+ nmax = max(1, int(nmax))
298
+ if nmax == 1:
299
+ return size
300
+ return math.ceil((size + (nmax - 1) * min_overlap) / nmax)
301
+
302
+
303
+ def _starts(size: int, T: int, min_overlap: int):
304
+ """Uniform stepping with stride = T - min_overlap; last tile flush with edge."""
305
+ if size <= T:
306
+ return [0]
307
+ stride = max(1, T - min_overlap)
308
+ xs = list(range(0, size - T + 1, stride))
309
+ last = size - T
310
+ if xs[-1] != last:
311
+ xs.append(last)
312
+ # monotonic dedupe
313
+ out = []
314
+ for v in xs:
315
+ if not out or v > out[-1]:
316
+ out.append(v)
317
+ return out
318
+
319
+
320
+ class TilingDataset(Dataset):
321
+ def __init__(
322
+ self,
323
+ transform_graph,
324
+ input_folder,
325
+ tiling_w=768,
326
+ tiling_h=768,
327
+ processing_resolution=768,
328
+ max_num_tiles_w=4,
329
+ max_num_tiles_h=4,
330
+ min_overlap_w=64,
331
+ min_overlap_h=64,
332
+ use_short_edge_tile=False,
333
+ **kwargs,
334
+ ) -> None:
335
+ super().__init__()
336
+ self.transform_graph = transform_graph
337
+ self.kwargs = kwargs
338
+ self.disp_name = kwargs.get("disp_name", "tiling_dataset")
339
+
340
+ img_paths = sorted(
341
+ os.path.join(input_folder, f)
342
+ for f in os.listdir(input_folder)
343
+ if os.path.isfile(os.path.join(input_folder, f))
344
+ )
345
+
346
+ self.filenames = []
347
+
348
+ Nw, Nh = int(max_num_tiles_w), int(max_num_tiles_h)
349
+ ow, oh = int(min_overlap_w), int(min_overlap_h)
350
+
351
+ for i, p in enumerate(img_paths):
352
+ with Image.open(p) as im:
353
+ W, H = im.size
354
+
355
+ # Choose preferred tile size for this image
356
+ if use_short_edge_tile:
357
+ short_edge = min(W, H)
358
+ short_edge = max(short_edge, processing_resolution)
359
+ tiling_w_i = short_edge
360
+ tiling_h_i = short_edge
361
+ else:
362
+ tiling_w_i = tiling_w
363
+ tiling_h_i = tiling_h
364
+
365
+ # Optional upscaling if image is smaller than desired tile
366
+ if W < tiling_w_i or H < tiling_h_i:
367
+ min_side = min(W, H)
368
+ scale_ratio = tiling_w_i / min_side
369
+ W = round(scale_ratio * W)
370
+ H = round(scale_ratio * H)
371
+
372
+ pref_side = max(int(tiling_w_i), int(tiling_h_i))
373
+
374
+ # Feasible square-side interval [T_low, T_high]
375
+ T_low = max(
376
+ _required_side_for_axis(W, Nw, ow),
377
+ _required_side_for_axis(H, Nh, oh),
378
+ ow + 1,
379
+ oh + 1,
380
+ )
381
+ T_high = min(W, H)
382
+
383
+ if T_low > T_high:
384
+ msg = (
385
+ f"Infeasible square constraints for {os.path.basename(p)}: "
386
+ f"need T >= {T_low}, but max square inside is {T_high}. "
387
+ f"Relax max_num_tiles_w/h or overlaps, allow non-square tiles, or pad."
388
+ )
389
+ raise ValueError(msg)
390
+ else:
391
+ T = max(T_low, min(pref_side, T_high))
392
+ Tw = Th = T
393
+
394
+ # Build tile start offsets per axis (tiles are square here: Tw == Th)
395
+ xs = _starts(W, Tw, ow)
396
+ ys = _starts(H, Th, oh)
397
+
398
+ for y0 in ys:
399
+ for x0 in xs:
400
+ x1, y1 = x0 + Tw, y0 + Th
401
+ self.filenames.append([str(p), (x0, y0, x1, y1), False])
402
+
403
+ if self.filenames:
404
+ self.filenames[-1][-1] = True
405
+
406
+ def __len__(self):
407
+ return len(self.filenames)
408
+
409
+ def __getitem__(self, index):
410
+ sample = {}
411
+ sample["line"] = self.filenames[index]
412
+ sample["idx"] = index
413
+ self.transform_graph(sample)
414
+ return sample
415
+
416
+
417
+ def read_scalars(sample):
418
+ scalar_dict = {"tile_info": 1, "is_last_tile": 2}
419
+ for name, col in scalar_dict.items():
420
+ sample[name] = sample["line"][col]
421
+
422
+
423
+ def load_rgb_data(rgb_path, key_prefix="input"):
424
+ rgb = read_rgb_file(rgb_path)
425
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0
426
+ outputs = {
427
+ f"{key_prefix}_int": torch.from_numpy(rgb).int(),
428
+ f"{key_prefix}_norm": torch.from_numpy(rgb_norm),
429
+ }
430
+ return outputs
431
+
432
+
433
+ def read_rgb_file(rgb_path) -> np.ndarray:
434
+ img = Image.open(rgb_path).convert("RGB")
435
+ arr = np.array(img, dtype=np.uint8) # [H, W, 3]
436
+ return arr.transpose(2, 0, 1) # [3, H, W]
437
+
438
+
439
+ def read_rgb_image(sample):
440
+ column = 0
441
+ name = "input"
442
+
443
+ img_path = sample["line"][column]
444
+ img = load_rgb_data(img_path, name)
445
+ sample.update(img)
446
+ sample.setdefault("meta", {})
447
+ sample["meta"]["orig_res"] = [
448
+ sample[name + "_norm"].shape[-2],
449
+ sample[name + "_norm"].shape[-1],
450
+ ]
451
+
452
+
453
+ def _lanczos_resize_chw(x, out_hw):
454
+ H_out, W_out = map(int, out_hw)
455
+
456
+ is_torch = isinstance(x, torch.Tensor)
457
+ if is_torch:
458
+ dev = x.device
459
+ arr = x.detach().cpu().numpy()
460
+ else:
461
+ arr = x
462
+
463
+ assert isinstance(arr, np.ndarray) and arr.ndim == 3, "expect CHW"
464
+ chw = arr.astype(np.float32, copy=False)
465
+ C, _, _ = chw.shape
466
+
467
+ out_chw = np.empty((C, H_out, W_out), dtype=np.float32)
468
+ for c in range(C):
469
+ ch = chw[c]
470
+ img = Image.fromarray(ch).convert("F")
471
+ img = img.resize((W_out, H_out), resample=Image.LANCZOS)
472
+ out_chw[c] = np.asarray(img, dtype=np.float32)
473
+
474
+ if is_torch:
475
+ return torch.from_numpy(out_chw).to(dev)
476
+ return out_chw
477
+
478
+
479
+ def reshape(sample, height, width):
480
+ Ht, Wt = height, width
481
+ for k, v in list(sample.items()):
482
+ if not (torch.is_tensor(v) and v.ndim >= 2) or "orig" in k:
483
+ continue
484
+ x = v.to(torch.float32)
485
+ x = _lanczos_resize_chw(x, (Ht, Wt))
486
+ if v.dtype == torch.bool:
487
+ x = x > 0.5
488
+ elif not torch.is_floating_point(v):
489
+ x = x.round().to(v.dtype)
490
+ sample[k] = x
491
+
492
+ return sample
493
+
494
+
495
+ def tile(sample, processing_resolution: int):
496
+ x0, y0, x1, y1 = map(int, sample["tile_info"])
497
+ processing_width = x1 - x0
498
+ processing_height = y1 - y0
499
+
500
+ # Reshape input while keeping aspect ratio
501
+ H, W = sample["input_norm"].shape[-2:]
502
+ if W < processing_width or H < processing_height:
503
+ min_side = min(W, H)
504
+ scale_ratio = processing_width / min_side
505
+ W = round(scale_ratio * W)
506
+ H = round(scale_ratio * H)
507
+
508
+ reshape(sample, height=H, width=W)
509
+ sample["input_int"] = sample["input_int"][:, y0:y1, x0:x1]
510
+ sample["input_norm"] = sample["input_norm"][:, y0:y1, x0:x1]
511
+ reshape(sample, height=processing_resolution, width=processing_resolution)
512
+
513
+
514
+ @torch.no_grad()
515
+ def validate_single_dataset(
516
+ vae: AutoencoderKLQwenImage,
517
+ transformer: QwenImageTransformer2DModel,
518
+ embeds_dict: dict[str, torch.Tensor],
519
+ data_loader: DataLoader,
520
+ save_to_dir: str | None = None,
521
+ save_comparison: bool = True,
522
+ save_alternating: bool = True,
523
+ ):
524
+ preds = []
525
+
526
+ for i, batch in enumerate(
527
+ tqdm(data_loader, desc="Reflection Removal Progress"),
528
+ start=1,
529
+ ):
530
+ batch["out"] = {}
531
+ with torch.no_grad():
532
+ latents = encode(batch["input_norm"], vae)
533
+ latents = flow_step(latents, transformer, vae, embeds_dict)
534
+ batch["out"]["pixel_pred"] = decode(latents, vae)
535
+
536
+ for b in range(len(batch["idx"])):
537
+ preds.append(
538
+ {
539
+ "file": batch["line"][0][b],
540
+
541
+ # [x0, y0, x1, y1] tuple for the tile
542
+ "tile_info": [batch["tile_info"][i][b] for i in range(4)],
543
+
544
+ # Shape 1, 3, H, W, torch tensor in range -1 to 1
545
+ "pred": batch["out"]["pixel_pred"][b].to("cpu"),
546
+ }
547
+ )
548
+
549
+ if batch["is_last_tile"][b]:
550
+ # Stitch predictions together
551
+ W = max(int(t["tile_info"][2]) for t in preds)
552
+ H = max(int(t["tile_info"][3]) for t in preds)
553
+
554
+ acc = torch.zeros(3, H, W, dtype=torch.float32)
555
+ wsum = torch.zeros(H, W, dtype=torch.float32)
556
+
557
+ for t in preds:
558
+ tile_info = [t["tile_info"][i] for i in range(4)]
559
+ x0, y0, x1, y1 = map(int, tile_info)
560
+ tile = t["pred"].squeeze(0).float() # [3, h, w], [-1,1]
561
+
562
+ h, w = tile.shape[-2:]
563
+ tH, tW = (y1 - y0), (x1 - x0)
564
+ if (h != tH) or (w != tW):
565
+ tile = _lanczos_resize_chw(tile, (tH, tW))
566
+ h, w = tH, tW
567
+
568
+ # triangular window for the tile
569
+ # fmt: off
570
+ wx = 1 - (2 * torch.arange(w, dtype=torch.float32) / (max(w - 1, 1)) - 1).abs()
571
+ wy = 1 - (2 * torch.arange(h, dtype=torch.float32) / (max(h - 1, 1)) - 1).abs()
572
+ # fmt: on
573
+ w2 = (wy[:, None] * wx[None, :]).clamp_min(1e-3)
574
+ acc[:, y0:y1, x0:x1] += tile * w2
575
+ wsum[y0:y1, x0:x1] += w2
576
+ stitched = (acc / wsum.clamp_min(1e-6)).unsqueeze(0) # [1,3,H,W], [-1,1]
577
+
578
+ # Lanczos resize to the original input resolution
579
+ orig_H, orig_W = (
580
+ batch["meta"]["orig_res"][0][b].item(),
581
+ batch["meta"]["orig_res"][1][b].item(),
582
+ )
583
+
584
+ x = stitched.squeeze(0)
585
+ x01 = ((x + 1.0) / 2.0).clamp(0.0, 1.0)
586
+ device = x01.device
587
+
588
+ pil = torchvision.transforms.functional.to_pil_image(x01.cpu())
589
+ pil_resized = pil.resize((orig_W, orig_H), resample=Image.LANCZOS)
590
+ pred_ts = torchvision.transforms.functional.to_tensor(pil_resized).to(device) # [3,H,W], [0,1]
591
+ pred = pred_ts.cpu().numpy()
592
+ preds = []
593
+ else:
594
+ continue
595
+
596
+ pred_ts = torch.from_numpy(pred).to(device) # [3,H,W]
597
+ scene_path = batch["line"][0][b]
598
+ scene_name = os.path.splitext(os.path.basename(scene_path))[0]
599
+
600
+ # Load original input image (CHW, uint8 in [0,255])
601
+ input_chw = read_rgb_file(scene_path)
602
+ input_hwc = (
603
+ np.transpose(input_chw, (1, 2, 0)).astype(np.float32) / 255.0
604
+ ) # [H,W,3], [0,1]
605
+
606
+ pred_hwc = np.transpose(pred, (1, 2, 0))
607
+ if input_hwc.shape[:2] != pred_hwc.shape[:2]:
608
+ pil_pred = Image.fromarray(
609
+ (pred_hwc.clip(0, 1) * 255).round().astype(np.uint8)
610
+ )
611
+ H_in, W_in = input_hwc.shape[:2]
612
+ pil_pred = pil_pred.resize((W_in, H_in), resample=Image.LANCZOS)
613
+ pred_hwc = (np.array(pil_pred, dtype=np.uint8) / 255.0).clip(0, 1)
614
+
615
+ visualize(
616
+ file_prefix=scene_name,
617
+ input_hwc=input_hwc,
618
+ pred_hwc=pred_hwc,
619
+ output_dir=save_to_dir,
620
+ save_comparison=save_comparison,
621
+ save_alternating=save_alternating,
622
+ )
623
+
624
+ return
625
+
626
+
627
+ def save_prediction_only(
628
+ file_prefix: str,
629
+ pred_uint8: np.ndarray,
630
+ output_dir: str,
631
+ ) -> None:
632
+ imageio.imwrite(
633
+ os.path.join(output_dir, f"{file_prefix}_windowseat_output.png"),
634
+ pred_uint8,
635
+ plugin="pillow",
636
+ )
637
+
638
+
639
+ def save_comparison_image(
640
+ file_prefix: str,
641
+ pred_uint8: np.ndarray,
642
+ input_uint8: np.ndarray,
643
+ output_dir: str,
644
+ margin_width: int = 10,
645
+ ) -> None:
646
+ H_in, W_in, _ = input_uint8.shape
647
+ if pred_uint8.shape[:2] != (H_in, W_in):
648
+ pil_pred = Image.fromarray(pred_uint8)
649
+ pil_pred = pil_pred.resize((W_in, H_in), resample=Image.LANCZOS)
650
+ pred_uint8 = np.asarray(pil_pred, dtype=np.uint8)
651
+
652
+ margin = np.ones((H_in, margin_width, 3), dtype=np.uint8) * 255
653
+ comparison = np.concatenate([input_uint8, margin, pred_uint8], axis=1)
654
+
655
+ imageio.imwrite(
656
+ os.path.join(output_dir, f"{file_prefix}_windowseat_side_by_side.png"),
657
+ comparison,
658
+ plugin="pillow",
659
+ )
660
+
661
+
662
+ def save_alternating_video(
663
+ file_prefix: str,
664
+ input_uint8: np.ndarray,
665
+ pred_uint8: np.ndarray,
666
+ output_dir: str,
667
+ fps: float = 1.0,
668
+ total_frames: int = 20,
669
+ ) -> None:
670
+ video_path = os.path.join(output_dir, f"{file_prefix}_windowseat_alternating.mp4")
671
+
672
+ H, W = input_uint8.shape[:2]
673
+ pad_h = (0, H % 2)
674
+ pad_w = (0, W % 2)
675
+ if pad_h[1] or pad_w[1]:
676
+ input_uint8 = np.pad(input_uint8, (pad_h, pad_w, (0, 0)), mode="edge")
677
+ pred_uint8 = np.pad(pred_uint8, (pad_h, pad_w, (0, 0)), mode="edge")
678
+
679
+ with imageio.get_writer(
680
+ video_path, fps=fps, macro_block_size=1, ffmpeg_params=["-loglevel", "quiet"]
681
+ ) as writer:
682
+ for i in range(total_frames):
683
+ frame = input_uint8 if i % 2 == 0 else pred_uint8
684
+ writer.append_data(frame)
685
+
686
+
687
+ def visualize(
688
+ file_prefix: str,
689
+ input_hwc: np.ndarray,
690
+ pred_hwc: np.ndarray,
691
+ output_dir: str,
692
+ save_comparison: bool = True,
693
+ save_alternating: bool = True,
694
+ ) -> None:
695
+ pred_hwc = pred_hwc.clip(0, 1)
696
+ pred_uint8 = (pred_hwc * 255).round().astype(np.uint8)
697
+ input_hwc = np.asarray(input_hwc, dtype=np.float32)
698
+ if input_hwc.max() > 1.0:
699
+ input_hwc = input_hwc / 255.0
700
+ input_uint8 = (input_hwc.clip(0, 1) * 255).round().astype(np.uint8)
701
+
702
+ save_prediction_only(
703
+ file_prefix=file_prefix,
704
+ pred_uint8=pred_uint8,
705
+ output_dir=output_dir,
706
+ )
707
+
708
+ if save_comparison:
709
+ save_comparison_image(
710
+ file_prefix=file_prefix,
711
+ pred_uint8=pred_uint8,
712
+ input_uint8=input_uint8,
713
+ output_dir=output_dir,
714
+ )
715
+
716
+ if save_alternating:
717
+ save_alternating_video(
718
+ file_prefix=file_prefix,
719
+ input_uint8=input_uint8,
720
+ pred_uint8=pred_uint8,
721
+ output_dir=output_dir,
722
+ )
723
+
724
+
725
+ def data_transform(sample, processing_resolution=None):
726
+ read_scalars(sample)
727
+ read_rgb_image(sample)
728
+ tile(sample, processing_resolution)
729
+
730
+
731
+ def run_inference(
732
+ vae: AutoencoderKLQwenImage,
733
+ transformer: QwenImageTransformer2DModel,
734
+ embeds_dict: dict[str, torch.Tensor],
735
+ processing_resolution: int,
736
+ image_dir: str,
737
+ output_dir: str,
738
+ use_short_edge_tile=True,
739
+ save_comparison=True,
740
+ save_alternating=True,
741
+ ):
742
+ dataset = TilingDataset(
743
+ transform_graph=functools.partial(data_transform, processing_resolution=processing_resolution),
744
+ input_folder=image_dir,
745
+ gt_folder=image_dir,
746
+ use_short_edge_tile=use_short_edge_tile,
747
+ tiling_w=processing_resolution,
748
+ tiling_h=processing_resolution,
749
+ processing_resolution=processing_resolution,
750
+ )
751
+
752
+ data_loader = DataLoader(
753
+ dataset=dataset,
754
+ batch_size=2,
755
+ shuffle=False,
756
+ num_workers=0,
757
+ )
758
+
759
+ os.makedirs(output_dir, exist_ok=True)
760
+
761
+ validate_single_dataset(
762
+ vae,
763
+ transformer,
764
+ embeds_dict,
765
+ data_loader=data_loader,
766
+ save_to_dir=output_dir,
767
+ save_comparison=save_comparison,
768
+ save_alternating=save_alternating,
769
+ )
770
+
771
+
772
+ def parse_args():
773
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
774
+ IMAGE_DIR = os.path.join(SCRIPT_DIR, "example_images")
775
+ OUTPUT_DIR = os.path.join(SCRIPT_DIR, "outputs")
776
+
777
+ parser = argparse.ArgumentParser(
778
+ description="WindowSeat: reflection removal inference"
779
+ )
780
+ parser.add_argument(
781
+ "--input-dir",
782
+ type=str,
783
+ default=IMAGE_DIR,
784
+ help="Directory with input images (default: %(default)s)",
785
+ )
786
+ parser.add_argument(
787
+ "--output-dir",
788
+ type=str,
789
+ default=OUTPUT_DIR,
790
+ help="Directory to write predictions (default: %(default)s)",
791
+ )
792
+ parser.add_argument(
793
+ "--uri-base",
794
+ type=str,
795
+ default=SUPPORTED_MODEL_URIS[0],
796
+ help="URI of the base model (default: %(default)s)",
797
+ )
798
+ parser.add_argument(
799
+ "--uri-lora",
800
+ type=str,
801
+ default=LORA_MODEL_URI,
802
+ help="URI of the LoRA model (default: %(default)s)",
803
+ )
804
+ parser.add_argument(
805
+ "--more-tiles",
806
+ action="store_true",
807
+ help="Use fixed tiles at the processing resolution instead of short-edge-sized tiles (more tiles per image).",
808
+ )
809
+ parser.add_argument(
810
+ "--no-save-comparison",
811
+ dest="save_comparison",
812
+ action="store_false",
813
+ help="Do NOT save comparison image between input and prediction.",
814
+ )
815
+ parser.add_argument(
816
+ "--no-save-alternating",
817
+ dest="save_alternating",
818
+ action="store_false",
819
+ help="Do NOT save alternating video.",
820
+ )
821
+ parser.add_argument(
822
+ "--device",
823
+ type=str,
824
+ default="cuda",
825
+ help="Device used for inference (default: %(default)s)",
826
+ )
827
+ return parser.parse_args()
828
+
829
+
830
+ def main():
831
+ args = parse_args()
832
+ image_dir = args.input_dir
833
+ output_dir = args.output_dir
834
+ uri_base = args.uri_base
835
+ uri_lora = args.uri_lora
836
+ use_short_edge_tile = not args.more_tiles
837
+ save_comparison = args.save_comparison
838
+ save_alternating = args.save_alternating
839
+ device = torch.device(args.device)
840
+ if device.type != "cuda":
841
+ warnings.warn(
842
+ "WindowSeat inference was only tested with 'cuda'. "
843
+ f"Device {device} is not officially supported and may be slow or fail."
844
+ )
845
+
846
+ if not os.path.isdir(image_dir):
847
+ print_error(f"Input image directory does not exist: {image_dir}")
848
+ sys.exit(1)
849
+
850
+ os.makedirs(output_dir, exist_ok=True)
851
+
852
+ print_banner("WindowSeat: Reflection Removal")
853
+ print_step("1/2", "Loading network components:")
854
+ print_info(f"Base: {uri_base}")
855
+ print_info(f"WindowSeat: {uri_lora}")
856
+
857
+ try:
858
+ vae, transformer, embeds_dict, processing_resolution = load_network(uri_base, uri_lora, device)
859
+ except Exception as e:
860
+ print_error(f"Failed to load network: {e}")
861
+ raise
862
+
863
+ print_step("2/2", f"Running reflection removal inference on: {image_dir}")
864
+ run_inference(
865
+ vae, transformer, embeds_dict, processing_resolution, image_dir, output_dir, use_short_edge_tile, save_comparison, save_alternating
866
+ )
867
+ print_final_success(output_dir)
868
+
869
+
870
+ if __name__ == "__main__":
871
+ main()