boatbomber commited on
Commit
3050f1b
·
0 Parent(s):

Initial release

Browse files
Files changed (48) hide show
  1. .gitattributes +36 -0
  2. .gitignore +176 -0
  3. LICENSE.md +176 -0
  4. README.md +447 -0
  5. ae.safetensors +3 -0
  6. assets/NisabaRelief-Logo.png +3 -0
  7. assets/example_diff_0.png +3 -0
  8. assets/example_diff_1.png +3 -0
  9. assets/example_diff_2.png +3 -0
  10. assets/example_diff_3.png +3 -0
  11. assets/example_diff_4.png +3 -0
  12. assets/example_input_0.png +3 -0
  13. assets/example_input_1.png +3 -0
  14. assets/example_input_2.png +3 -0
  15. assets/example_input_3.png +3 -0
  16. assets/example_input_4.png +3 -0
  17. assets/example_output_0.png +3 -0
  18. assets/example_output_1.png +3 -0
  19. assets/example_output_2.png +3 -0
  20. assets/example_output_3.png +3 -0
  21. assets/example_output_4.png +3 -0
  22. assets/example_truth_0.png +3 -0
  23. assets/example_truth_1.png +3 -0
  24. assets/example_truth_2.png +3 -0
  25. assets/example_truth_3.png +3 -0
  26. assets/example_truth_4.png +3 -0
  27. data/val_tablet_ids.json +90 -0
  28. dev_scripts/benchmark.py +149 -0
  29. dev_scripts/evaluation.py +162 -0
  30. dev_scripts/process_images.py +197 -0
  31. dev_scripts/util/load_val_dataset.py +24 -0
  32. dev_scripts/util/metrics.py +67 -0
  33. dev_scripts/util/psnr_hvsm.py +137 -0
  34. model.safetensors +3 -0
  35. nisaba_relief/__init__.py +7 -0
  36. nisaba_relief/constants.py +42 -0
  37. nisaba_relief/flux/__init__.py +0 -0
  38. nisaba_relief/flux/autoencoder.py +351 -0
  39. nisaba_relief/flux/layers.py +341 -0
  40. nisaba_relief/flux/model.py +147 -0
  41. nisaba_relief/flux/sampling.py +92 -0
  42. nisaba_relief/image_utils.py +153 -0
  43. nisaba_relief/model.py +474 -0
  44. nisaba_relief/py.typed +0 -0
  45. nisaba_relief/weights.py +23 -0
  46. prompt_embedding.safetensors +3 -0
  47. pyproject.toml +69 -0
  48. uv.lock +0 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/*
2
+ !data/val_tablet_ids.json
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ **/__marimo__/
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # UV
103
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ #uv.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ #pdm.lock
118
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119
+ # in version control.
120
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
121
+ .pdm.toml
122
+ .pdm-python
123
+ .pdm-build/
124
+
125
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
+ __pypackages__/
127
+
128
+ # Celery stuff
129
+ celerybeat-schedule
130
+ celerybeat.pid
131
+
132
+ # SageMath parsed files
133
+ *.sage.py
134
+
135
+ # Environments
136
+ .env
137
+ .venv
138
+ env/
139
+ venv/
140
+ ENV/
141
+ env.bak/
142
+ venv.bak/
143
+
144
+ # Spyder project settings
145
+ .spyderproject
146
+ .spyproject
147
+
148
+ # Rope project settings
149
+ .ropeproject
150
+
151
+ # mkdocs documentation
152
+ /site
153
+
154
+ # mypy
155
+ .mypy_cache/
156
+ .dmypy.json
157
+ dmypy.json
158
+
159
+ # Pyre type checker
160
+ .pyre/
161
+
162
+ # pytype static type analyzer
163
+ .pytype/
164
+
165
+ # Cython debug symbols
166
+ cython_debug/
167
+
168
+ # PyCharm
169
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
172
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
+ #.idea/
174
+
175
+ # PyPI configuration file
176
+ .pypirc
LICENSE.md ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
README.md ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ pipeline_tag: image-to-image
4
+ base_model:
5
+ - black-forest-labs/FLUX.2-klein-base-4B
6
+ base_model_relation: finetune
7
+ datasets:
8
+ - boatbomber/CuneiformPhotosMSII
9
+ tags:
10
+ - image-to-image
11
+ - cuneiform
12
+ - geometry
13
+ - curvature
14
+ - multi-scale-integral-invariant
15
+ - msii
16
+ - Flux
17
+ ---
18
+
19
+ <div align="center">
20
+ <h1 align="center">
21
+ NisabaRelief
22
+ </h1>
23
+
24
+ <img src="./assets/NisabaRelief-Logo.png" width="600"/>
25
+ </div>
26
+
27
+
28
+ # NisabaRelief
29
+
30
+ NisabaRelief is a rectified flow transformer that converts ordinary photographs of cuneiform clay tablets into Multi-Scale Integral Invariant (MSII) curvature visualizations, without requiring 3D scanning hardware. Traditional MSII computation requires a high-resolution 3D scanner and GigaMesh postprocessing, averaging approximately 68 minutes per tablet. NisabaRelief processes a photograph in approximately 7 seconds.
31
+
32
+ Photographic images introduce a variety of noise sources: lighting direction, clay color, surface sheen, photography conditions, and surface staining. Any of these can cause wedge impressions to appear as shadows or shadows to appear as wedge impressions. MSII filtering discards this photometric variation, retaining only the geometric signal pressed into the clay. See [What is MSII?](#what-is-msii) for full technical details.
33
+
34
+ Built by fine-tuning [Flux.2 Klein Base 4B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B) on paired photo/MSII data generated from 3D scans in the [HeiCuBeDa](https://doi.org/10.11588/data/IE8CCN) corpus. Training data is made available here: [CuneiformPhotosMSII](https://huggingface.co/datasets/boatbomber/CuneiformPhotosMSII).
35
+
36
+ Named for Nisaba, the early Sumerian goddess of writing and scribes, NisabaRelief will serve as the preprocessing backbone of NabuOCR V2, a cuneiform OCR system currently in development.
37
+
38
+ ---
39
+
40
+ ## Contents
41
+
42
+ - [NisabaRelief](#nisabarelief)
43
+ - [Contents](#contents)
44
+ - [Example Output](#example-output)
45
+ - [Quickstart](#quickstart)
46
+ - [Installation](#installation)
47
+ - [Usage](#usage)
48
+ - [Hardware Requirements](#hardware-requirements)
49
+ - [Performance](#performance)
50
+ - [What is MSII?](#what-is-msii)
51
+ - [Intended Use \& Limitations](#intended-use--limitations)
52
+ - [Evaluation](#evaluation)
53
+ - [Step Sweep](#step-sweep)
54
+ - [Training Data](#training-data)
55
+ - [Training Pipeline](#training-pipeline)
56
+ - [Key Technical Decision: Text-Encoder-Free Training](#key-technical-decision-text-encoder-free-training)
57
+ - [Key Technical Decision: VAE BatchNorm Domain Calibration](#key-technical-decision-vae-batchnorm-domain-calibration)
58
+ - [Stage 1: Pretrain (Domain Initialization)](#stage-1-pretrain-domain-initialization)
59
+ - [Stage 2: Train (Image-to-Image Adaptation)](#stage-2-train-image-to-image-adaptation)
60
+ - [Augmentation Pipeline](#augmentation-pipeline)
61
+ - [Loss](#loss)
62
+ - [Stage 3: Rectify (Trajectory Straightening)](#stage-3-rectify-trajectory-straightening)
63
+ - [Acknowledgements \& Citations](#acknowledgements--citations)
64
+
65
+ ---
66
+
67
+ ## Example Output
68
+
69
+ <table>
70
+ <thead>
71
+ <tr>
72
+ <th align="center" width="25%">Input</th>
73
+ <th align="center" width="25%">Output</th>
74
+ <th align="center" width="25%">Ground Truth</th>
75
+ <th align="center" width="25%">Difference</th>
76
+ </tr>
77
+ </thead>
78
+ <tbody>
79
+
80
+ <tr>
81
+ <td align="center"><img src="./assets/example_input_0.png" width="200"/></td>
82
+ <td align="center"><img src="./assets/example_output_0.png" width="200"/></td>
83
+ <td align="center"><img src="./assets/example_truth_0.png" width="200"/></td>
84
+ <td align="center"><img src="./assets/example_diff_0.png" width="200"/></td>
85
+ </tr>
86
+ <tr>
87
+ <td colspan="4" align="center"><b>Dice: 0.9652</b> &nbsp;·&nbsp; RMSE: 0.0775 &nbsp;·&nbsp; MS-SSIM: 0.9295 &nbsp;·&nbsp; PSNR: 22.22 dB &nbsp;·&nbsp; PSNR-HVS-M: 17.77 dB &nbsp;·&nbsp; SRE: 58.34 dB</td>
88
+ </tr>
89
+
90
+ <tr>
91
+ <td align="center"><img src="./assets/example_input_1.png" width="200"/></td>
92
+ <td align="center"><img src="./assets/example_output_1.png" width="200"/></td>
93
+ <td align="center"><img src="./assets/example_truth_1.png" width="200"/></td>
94
+ <td align="center"><img src="./assets/example_diff_1.png" width="200"/></td>
95
+ </tr>
96
+ <tr>
97
+ <td colspan="4" align="center"><b>Dice: 0.9555</b> &nbsp;·&nbsp; RMSE: 0.0788 &nbsp;·&nbsp; MS-SSIM: 0.9219 &nbsp;·&nbsp; PSNR: 22.07 dB &nbsp;·&nbsp; PSNR-HVS-M: 17.80 dB &nbsp;·&nbsp; SRE: 57.89 dB</td>
98
+ </tr>
99
+
100
+ <tr>
101
+ <td align="center"><img src="./assets/example_input_2.png" width="200"/></td>
102
+ <td align="center"><img src="./assets/example_output_2.png" width="200"/></td>
103
+ <td align="center"><img src="./assets/example_truth_2.png" width="200"/></td>
104
+ <td align="center"><img src="./assets/example_diff_2.png" width="200"/></td>
105
+ </tr>
106
+ <tr>
107
+ <td colspan="4" align="center"><b>Dice: 0.9630</b> &nbsp;·&nbsp; RMSE: 0.1108 &nbsp;·&nbsp; MS-SSIM: 0.8513 &nbsp;·&nbsp; PSNR: 19.11 dB &nbsp;·&nbsp; PSNR-HVS-M: 14.65 dB &nbsp;·&nbsp; SRE: 59.60 dB</td>
108
+ </tr>
109
+
110
+ <tr>
111
+ <td align="center"><img src="./assets/example_input_3.png" width="200"/></td>
112
+ <td align="center"><img src="./assets/example_output_3.png" width="200"/></td>
113
+ <td align="center"><img src="./assets/example_truth_3.png" width="200"/></td>
114
+ <td align="center"><img src="./assets/example_diff_3.png" width="200"/></td>
115
+ </tr>
116
+ <tr>
117
+ <td colspan="4" align="center"><b>Dice: 0.9713</b> &nbsp;·&nbsp; RMSE: 0.1035 &nbsp;·&nbsp; MS-SSIM: 0.8748 &nbsp;·&nbsp; PSNR: 19.70 dB &nbsp;·&nbsp; PSNR-HVS-M: 15.33 dB &nbsp;·&nbsp; SRE: 59.41 dB</td>
118
+ </tr>
119
+
120
+ <tr>
121
+ <td align="center"><img src="./assets/example_input_4.png" width="200"/></td>
122
+ <td align="center"><img src="./assets/example_output_4.png" width="200"/></td>
123
+ <td align="center"><img src="./assets/example_truth_4.png" width="200"/></td>
124
+ <td align="center"><img src="./assets/example_diff_4.png" width="200"/></td>
125
+ </tr>
126
+ <tr>
127
+ <td colspan="4" align="center"><b>Dice: 0.9564</b> &nbsp;·&nbsp; RMSE: 0.1054 &nbsp;·&nbsp; MS-SSIM: 0.9325 &nbsp;·&nbsp; PSNR: 19.55 dB &nbsp;·&nbsp; PSNR-HVS-M: 15.18 dB &nbsp;·&nbsp; SRE: 57.36 dB</td>
128
+ </tr>
129
+
130
+ </tbody>
131
+ </table>
132
+
133
+ ---
134
+
135
+ ## Quickstart
136
+
137
+ ### Installation
138
+
139
+ **Prerequisites:**
140
+
141
+ - Python >= 3.10
142
+ - PyTorch with CUDA support. See https://pytorch.org/get-started/locally/.
143
+
144
+ ```bash
145
+ # Install PyTorch (CUDA 12.8 example)
146
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
147
+
148
+ # Windows only: install Triton (included automatically on Linux)
149
+ pip install triton-windows
150
+ ```
151
+
152
+ **Install:**
153
+
154
+ ```bash
155
+ pip install nisaba-relief
156
+ ```
157
+
158
+ ### Usage
159
+
160
+ ```python
161
+ from nisaba_relief import NisabaRelief
162
+
163
+ model = NisabaRelief() # downloads weights from HF Hub automatically if needed
164
+ result = model.process("tablet.jpg")
165
+ result.save("tablet_msii.png")
166
+ ```
167
+
168
+ **Constructor parameters:**
169
+
170
+ | Parameter | Default | Description |
171
+ |---|---|---|
172
+ | `device` | `"cuda"` if available | Device for inference |
173
+ | `num_steps` | `2` | Denoising steps |
174
+ | `weights_dir` | `None` | Local weights directory; if `None`, downloads from HF Hub or uses HF cache. Expected dir contents: `model.safetensors`, `ae.safetensors`, `prompt_embedding.safetensors` |
175
+ | `batch_size` | `None` | Batch size for processing tiles during inference. `None` (default) auto-selects the largest batch that fits in available VRAM. Set an explicit integer to override. Higher values are faster but see note below. |
176
+ | `seed` | `None` | Optional random seed for reproducible noise generation; if `None`, randomized |
177
+ | `compile` | `True` | Use `torch.compile` for faster repeated inference. Requires Triton. Set to `False` if Triton is not installed or for one-off runs. |
178
+
179
+ > **Reproducibility note:** Results are pixel-exact across repeated runs with the same `batch_size` and `seed`. However, changing `batch_size` between runs (including letting `None` auto-select a different value as available VRAM changes) will produce outputs that differ by up to ~1-2 pixel values (mean < 0.25) due to GPU floating-point non-determinism: CUDA selects different kernel implementations for different matrix shapes, which changes the floating-point accumulation order in the transformer attention and linear layers. The visual difference is imperceptible. If exact cross-run reproducibility is required, set a constant `batch_size`.
180
+
181
+ **`process()` parameters:**
182
+
183
+ | Parameter | Default | Description |
184
+ |---|---|---|
185
+ | `image` | required | File path (str/Path) or PIL Image |
186
+ | `show_pbar` | `None` | Progress bar visibility. `None` = auto (shows when >= 2 batches); `True`/`False` = always show/hide |
187
+
188
+ **Returns:** Grayscale `PIL.Image.Image` containing the MSII visualization.
189
+
190
+ **Input requirements:**
191
+ - Any PIL-readable format (PNG, JPG, WEBP, ...)
192
+ - Minimum 64 px on the short side; maximum aspect ratio 8:1
193
+
194
+ **Large image support:**
195
+
196
+ The model's native tile size is 1024 px. For images where either side exceeds 1024 px, the model automatically applies a sliding-window tiling pass. Tiles are blended with raised-cosine overlap weights to avoid seams. Each tile is also conditioned on a 128 px thumbnail of the full image with a red rectangle marking the tile's position, so the model retains global context while processing local detail.
197
+
198
+ There is no practical upper limit on input resolution, though the model may perform unexpectedly if the 1024 px tile is only a small fraction of the total image area.
199
+
200
+ ---
201
+
202
+ ## Hardware Requirements
203
+
204
+ While CPU inference is technically supported, it is too slow for practical use. A GPU with at least 9GB VRAM is required, with 12GB+ being recommended for better batching.
205
+
206
+ The 9 GB figure is substantially lower than the ~18 GB a standard FLUX.2-klein-base-4B deployment would require because the Qwen3-4B text encoder is never loaded at runtime. The conditioning prompt is pre-computed once and shipped as a 7.8 MB embedding file alongside the model weights.
207
+
208
+ ---
209
+
210
+ ## Performance
211
+
212
+ Traditional pipelines require a high-resolution 3D scanner and GigaMesh postprocessing: across the HeiCuBeDa corpus, this averages approximately 68 minutes per tablet, totalling over 2,200 hours for the full collection. NisabaRelief processes a tablet photograph in approximately 7 seconds, roughly 600x faster, with no scanning equipment required.
213
+
214
+ On a 1064x2048px photo, an RTX 3090 performs as follows:
215
+
216
+ | Run | Time |
217
+ |---|---|
218
+ | *compile warmup* | 11.61s |
219
+ | 1 | 7.05s |
220
+ | 2 | 7.07s |
221
+ | 3 | 7.09s |
222
+ | **Mean** | **7.07 ± 0.02s** |
223
+
224
+ ---
225
+
226
+ ## What is MSII?
227
+
228
+ Multi-Scale Integral Invariant (MSII) filtering is a geometry-processing algorithm that computes a robust curvature measure at every point on a 3D surface mesh. At each vertex, a sphere of radius *r* is centered on the surface and the algorithm measures how much of the sphere's volume falls below the surface (the "interior" volume). On a perfectly flat surface the ratio is exactly one half. Concave regions (such as the channel cut by a wedge impression) admit more of the sphere below the surface, pushing the ratio above 0.5. Convex regions such as ridges or the rounded back of a tablet expose less interior volume, pulling the ratio below 0.5. The signed difference from the flat baseline maps directly to the sign and magnitude of mean curvature at that point.
229
+
230
+ The multi-scale component repeats this computation at several sphere radii simultaneously. Small radii resolve fine wedge tips and hairline details; large radii capture broader curvature trends such as the tablet's overall convexity. The per-vertex measurements across all radii form a compact feature vector, and the final scalar output conventionally displayed as a grayscale image is the maximum component of that feature vector, capturing the strongest curvature response across all scales into a single value per pixel.
231
+
232
+ By convention the scalar is displayed with its sign inverted relative to the mean curvature: concave regions (ratio > 0.5) map to darker pixel values and convex regions (ratio < 0.5) to lighter ones. This places the flat-surface baseline at mid-gray and renders wedge channels as dark strokes against a bright background, similar to ink on paper.
233
+
234
+ Because the result depends only on the 3D shape of the surface rather than on lighting, clay color, or photograph angle, wedge impressions appear as consistent dark strokes against a bright background. This makes the surface structure considerably more legible to machine-vision OCR systems than raw photographs.
235
+
236
+ ---
237
+
238
+ ## Intended Use & Limitations
239
+
240
+ Generating an MSII visualization of a tablet requires a high-resolution laser scanner and substantial per-vertex computation. The vast majority of cuneiform tablets do not have a 3D scan available, and the computational cost is difficult to scale across large corpora.
241
+
242
+ To reduce this barrier and increase the availability of readable images, this model is trained to predict the MSII visualization directly from photographs.
243
+
244
+ **Intended use:**
245
+ - Preprocessing step for cuneiform OCR (specifically NabuOCR V2)
246
+ - Visualizing cuneiform tablet geometry for research and digital humanities
247
+
248
+ **Limitations:**
249
+ - Trained exclusively using [HeiCuBeDa](https://doi.org/10.11588/data/IE8CCN) 3D-scan data; performance on tablet types or scribal traditions not well-represented in that corpus is unknown
250
+ - Outputs are MSII approximations inferred from 2D photographs, not computed from true 3D geometry. They are suitable for OCR preprocessing but are not a substitute for physical scanning
251
+ - Not a general-purpose MSII model; behavior on non-cuneiform inputs is undefined and out of distribution
252
+ - Designed for photographs following [CDLI photography guidelines](https://cdli.earth/docs/images-acquisition-and-processing): high-resolution fatcross layout on a black background. The model may underperform on low-resolution or visually cluttered inputs such as older black-and-white excavation photographs where the background blends into the tablet
253
+
254
+ ---
255
+
256
+ ## Evaluation
257
+
258
+ The model was evaluated on 704 held-out validation pairs, all tablets whose geometry was never seen during training (see [Training Data](#training-data)). Each validation image was processed through the model and the output compared against the ground-truth MSII visualization computed from the 3D scan. Ran with `seed=42` and `batch_size=4`.
259
+
260
+ | Metric | Value |
261
+ |------------|------------------|
262
+ | Dice | 0.9639 ± 0.0138 |
263
+ | RMSE | 0.0877 ± 0.0208 |
264
+ | MS-SSIM | 0.9026 ± 0.0308 |
265
+ | PSNR | 21.36 ± 1.91 dB |
266
+ | PSNR-HVS-M | 16.98 ± 1.89 dB |
267
+ | SRE | 59.57 ± 1.92 dB |
268
+
269
+ **Dice** (Binarized Dice Coefficient) thresholds both images to isolate wedge stroke regions, then measures overlap between predicted and ground-truth strokes on a 0-1 scale. This is the most task-relevant metric, as it directly measures whether the model correctly localizes wedge impressions for downstream OCR.
270
+
271
+ **RMSE** (Root Mean Squared Error) measures average pixel-level reconstruction error; lower is better.
272
+
273
+ **MS-SSIM** (Multi-Scale Structural Similarity Index) measures perceptual image similarity by comparing luminance, contrast, and local structure at multiple spatial scales simultaneously. Coarser scales capture global shape agreement; finer scales capture edge and texture detail. Scores range from 0 to 1, where 1 is a perfect match; higher is better.
274
+
275
+ **PSNR** (Peak Signal-to-Noise Ratio) expresses reconstruction fidelity in decibels relative to the maximum pixel value; higher is better.
276
+
277
+ **PSNR-HVS-M** (Peak Signal-to-Noise Ratio - Human Visual System and Masking) measures reconstruction fidelity in decibels relative to the maximum pixel value while taking into account Contrast Sensitivity Function (CSF) and between-coefficient contrast masking of DCT basis functions.
278
+
279
+ **SRE** (Signal-to-Reconstruction Error) ratio measures reconstruction fidelity in decibels based on signal energy vs. error energy; higher is better.
280
+
281
+ ### Step Sweep
282
+
283
+ A sweep of step counts was run on a subset of 175 validation samples and found that 2 steps is ideal for this model, adding one corrective step over the already solid single-step result. The rectified flow field is extremely straight (straightness_ratio=0.9989, path_length_ratio=1.0011, velocity_std=0.1565). For near-perfectly straight ODE trajectories, a single Euler step is theoretically near-exact, and each additional step accumulates small model prediction errors faster than it reduces discretization error. Where throughput is the primary concern, one step is acceptable. Ran with `seed=42` and `batch_size=4`.
284
+
285
+ | Metric | Steps=1 | Steps=2 | Steps=4 | Steps=8 |
286
+ |------------|------------------|----------------------|------------------|------------------|
287
+ | Dice | 0.9582 ± 0.0153 | **0.9634** ± 0.0139 | 0.9612 ± 0.0142 | 0.9580 ± 0.0148 |
288
+ | RMSE | 0.0909 ± 0.0209 | **0.0859** ± 0.0212 | 0.0900 ± 0.0203 | 0.0949 ± 0.0197 |
289
+ | MS-SSIM | 0.8987 ± 0.0326 | **0.9081** ± 0.0310 | 0.9039 ± 0.0314 | 0.8959 ± 0.0326 |
290
+ | PSNR | 21.03 ± 1.83 dB | **21.56** ± 1.97 dB | 21.11 ± 1.84 dB | 20.63 ± 1.72 dB |
291
+ | PSNR-HVS-M | 16.65 ± 1.80 dB | **17.19** ± 1.96 dB | 16.70 ± 1.83 dB | 16.18 ± 1.70 dB |
292
+ | SRE | 58.81 ± 1.81 dB | **59.07** ± 1.87 dB | 58.85 ± 1.87 dB | 58.61 ± 1.86 dB |
293
+
294
+ ---
295
+
296
+ ## Training Data
297
+
298
+ Training uses the [CuneiformPhotosMSII](https://huggingface.co/datasets/boatbomber/CuneiformPhotosMSII) dataset: 13,928 paired image pairs generated from 1,741 tablets sourced from the HeiCuBeDa (Heidelberg Cuneiform Benchmark Dataset), a professional research collection of 3D-scanned clay tablets. Each tablet was rendered multiple times in Blender at up to 4096 px, producing synthetic photographs alongside their corresponding MSII curvature visualizations.
299
+
300
+ Each render variant randomizes which faces of the tablet are shown, camera focal length (80-150 mm), tablet rotation (±5° Euler XYZ), lighting position/color/intensity, and background (fabric, grunge, stone, or none). This diversity encourages the model to generalize across realistic shooting conditions rather than overfitting to a specific lighting or composition style.
301
+
302
+ The dataset was split tablet-wise: 13,224 pairs (~95% of tablets) for training and 704 pairs (~5% of tablets) held out for validation. Because the split is by tablet identity, the model never sees a validation tablet's geometry during training.
303
+
304
+ ---
305
+
306
+ ## Training Pipeline
307
+
308
+ Training proceeded in three sequential stages: Pretrain, Train, and Rectify. Each stage builds directly on the weights from the previous one.
309
+
310
+ ### Key Technical Decision: Text-Encoder-Free Training
311
+
312
+ All three stages skip the Qwen3-4B text encoder entirely. Text embeddings are pre-computed once and cached to disk, reducing VRAM consumption from ~18 GB to ~9 GB without any loss in conditioning fidelity.
313
+
314
+ ### Key Technical Decision: VAE BatchNorm Domain Calibration
315
+
316
+ The FLUX.2 VAE contains a BatchNorm layer whose running statistics (`running_mean` and `running_var` across 128 channels: 32 latent channels × 2×2 patch size) were originally computed on diverse natural images. Applying this encoder to cuneiform tablets and MSII renderings introduces a latent-space distribution shift that manifests as screen-door dithering artifacts in decoded outputs.
317
+
318
+ To correct this, the BatchNorm statistics were recalibrated on the target domain before training began. 3,000 CDLI cuneiform tablet photographs and 2,000 synthetic MSII visualizations (5,000 images total) were encoded through the frozen VAE encoder; running mean and variance were accumulated across 19,301,093 spatial samples using float64 accumulators for numerical stability. Images from both domains were interleaved to ensure balanced sampling. The calibrated statistics are baked directly into the `ae.safetensors` weights shipped with this model.
319
+
320
+ ---
321
+
322
+ ### Stage 1: Pretrain (Domain Initialization)
323
+
324
+ The pretrain stage adapts the base FLUX.2 model to the cuneiform domain before any image-to-image translation is attempted. It runs standard text-to-image flow-matching training on two sources of real cuneiform imagery:
325
+
326
+ - ~60% CDLI archive photographs: real museum photos of tablets, paired with per-image text embeddings generated from CDLI metadata (period, material, object type, provenience, genre, language). Eight prompt templates were used and varied randomly.
327
+ - ~40% synthetic MSII renders: MSII visualization images from the training set, paired with MSII-specific text embeddings emphasizing curvature, surface topology, and wedge impression terminology.
328
+
329
+ Each image has its own unique cached embedding rather than a shared prompt, preventing the model from memorizing specimen identifiers and encouraging generalization.
330
+
331
+ | Hyperparameter | Value |
332
+ |---|---|
333
+ | Steps | 75,000 |
334
+ | Learning rate | 2e-4 (cosine decay, 1k warmup) |
335
+ | Effective batch size | 2 (batch 1, grad accum 2) |
336
+ | LoRA rank | 256 |
337
+ | LoRA init | PiSSA (8-iteration fast SVD) |
338
+ | Optimizer | 8-bit Adam |
339
+ | Precision | bfloat16 autocast |
340
+ | Timestep sampling | Logit-normal (mean=0, std=1) |
341
+ | Gradient clipping | 1.0 |
342
+
343
+ Images are resized to fit within 1 megapixel and rounded to 128-pixel multiples. Light augmentations are applied (horizontal flip, ±5° rotation, minor color jitter). Validation generates text-conditioned images across four aspect ratios every 1,000 steps.
344
+
345
+ ---
346
+
347
+ ### Stage 2: Train (Image-to-Image Adaptation)
348
+
349
+ The main training stage fine-tunes the pretrained weights for the target task: translating cuneiform tablet photographs into MSII visualizations. This stage introduces two significant changes over standard FLUX.2 fine-tuning.
350
+
351
+ **Tile and global context conditioning**
352
+
353
+ Rather than processing full images, the model trains on dynamic tile crops (128-1024 px, depending on image resolution) while simultaneously receiving a downscaled 128 px thumbnail of the full image with a red rectangle marking the tile's location, providing both local detail and global context.
354
+
355
+ **Paired crop with geometric consistency**
356
+
357
+ The same crop coordinates and geometric transforms (flip, rotation, perspective distortion) are applied to both the input photograph and the target MSII image, ensuring the model always receives spatially aligned pairs.
358
+
359
+ #### Augmentation Pipeline
360
+
361
+ Augmentations are split into two categories applied in sequence:
362
+
363
+ Geometric (applied identically to input and target):
364
+ - Horizontal flip (50%), vertical flip (40%), rotation ±8° (50%), perspective distortion strength 0.02 (30%)
365
+
366
+ Domain adaptation (applied to input only, to simulate real photographic variation):
367
+ - Perlin noise illumination (20%), vignette (40%), directional lighting gradient (50%), dust particles (50%), Gaussian noise (80%), gamma correction (50%), contrast adjustment (50%), brightness shift (50%), hue/saturation shift (40%), Gaussian blur (30%), grayscale conversion (3%)
368
+
369
+ Spatially-dependent effects (Perlin noise, vignette, gradient) use crop coordinates so the tile and its global thumbnail receive matching effects.
370
+
371
+ #### Loss
372
+
373
+ Flow-matching loss with Min-SNR-γ weighting (γ=5.0) to down-weight noisy high-timestep predictions, plus a multi-scale latent gradient loss weighted at 0.25. The gradient loss computes spatial gradient differences between predicted and target latents at four downsampling scales, encouraging sharp edge structure in outputs.
374
+
375
+ | Hyperparameter | Value |
376
+ |---|---|
377
+ | Steps | 150,000 |
378
+ | Learning rate | 3e-4 (cosine decay to 6e-6, 1k warmup) |
379
+ | Effective batch size | 8 (batch 1, grad accum 8) |
380
+ | LoRA rank | 256, alpha √rank, RSLoRA |
381
+ | LoRA init | PiSSA (8-iteration fast SVD) |
382
+ | EMA decay | 0.999 (used for validation and final save) |
383
+ | Optimizer | 8-bit Adam |
384
+ | Gradient clipping | 0.8 (with spike detection: skip if >2.5× EMA norm) |
385
+ | Precision | bfloat16 autocast |
386
+ | Gradient loss weight | 0.25 |
387
+ | Min-SNR-γ | 5.0 |
388
+ | Timestep sampling | Logit-normal (mean=0, std=1) |
389
+
390
+ Validation runs every 2,000 steps, generating 8 sample images with 8 denoising steps.
391
+
392
+ ---
393
+
394
+ ### Stage 3: Rectify (Trajectory Straightening)
395
+
396
+ The rectify stage implements [Rectified Flow](https://arxiv.org/abs/2209.03003) to reduce the number of inference steps required at runtime.
397
+
398
+ Standard flow-matching trains on random (noise, real target) pairs, producing curved ODE trajectories that require 25-50 denoising steps to traverse accurately. Rectified training instead pairs each noise sample with the output the fully-trained model generates from that noise, creating straight-line trajectories that can be traversed in 1-4 steps without quality loss.
399
+
400
+ Before training, a one-time preprocessing pass runs the trained model over the training set. Each image is cropped deterministically (seeded RNG, same tile-sizing logic as training), then fully denoised with the trained weights to produce a (noise, generated_output) coupled pair saved to disk. This eliminates VAE encoding from the training loop, reducing VRAM further.
401
+
402
+ The loss trains the model to predict the velocity between a coupled (noise, generated) pair at a random interpolated timestep. A pseudo-Huber loss replaces the MSE used in earlier stages, providing better gradient stability when predictions are far from target.
403
+
404
+ | Hyperparameter | Value |
405
+ |---|---|
406
+ | Steps | 50,000 |
407
+ | Learning rate | 3e-6 (cosine decay, 500 warmup) |
408
+ | Effective batch size | 4 (batch 1, grad accum 4) |
409
+ | LoRA rank | 256 |
410
+ | LoRA init | Loaded from Stage 2 weights (warm-start) |
411
+ | Loss | Pseudo-Huber (c=0.001) |
412
+ | Optimizer | 8-bit Adam |
413
+ | Gradient clipping | 1.0 |
414
+ | Precision | bfloat16 autocast |
415
+ | Timestep sampling | Logit-normal (mean=0, std=1) |
416
+
417
+ Validation runs every 2,000 steps using real validation images (not coupled pairs), generating outputs with only 2 denoising steps to directly measure few-step inference quality.
418
+
419
+ The result is usable MSII visualizations in 1-2 denoising steps, compared to the 25-50 steps standard flow-matching requires.
420
+
421
+ ---
422
+
423
+ ## Acknowledgements & Citations
424
+
425
+ **3D Scan Data (HeiCuBeDa)**
426
+
427
+ 3D scans used to generate the training dataset are from the Heidelberg Cuneiform Benchmark Dataset (HeiCuBeDa):
428
+
429
+ > Bogacz, B., Gertz, M., & Mara, H. (2015). *Character Proposals for Cuneiform Script Digitization*. Proceedings of the 15th International Conference on Frontiers in Handwriting Recognition (ICFHR). doi:[10.11588/data/IE8CCN](https://doi.org/10.11588/data/IE8CCN)
430
+
431
+ **Archive Photographs (CDLI)**
432
+
433
+ Real tablet photographs used in Stage 1 pretraining are sourced from the [Cuneiform Digital Library Initiative (CDLI)](https://cdli.mpiwg-berlin.mpg.de/).
434
+
435
+ **MSII Curvature (GigaMesh)**
436
+
437
+ MSII curvature values embedded in the HeiCuBeDa PLY files were computed using the [GigaMesh Software Framework](https://gigamesh.eu/).
438
+
439
+ **Rectified Flow**
440
+
441
+ Stage 3 (Rectify) implements the trajectory-straightening approach from:
442
+
443
+ > Liu, X., et al. (2022). *Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow*. arXiv:[2209.03003](https://arxiv.org/abs/2209.03003)
444
+
445
+ **Base Model (FLUX.2 Klein Base 4B)**
446
+
447
+ Fine-tuned from [FLUX.2-klein-base-4B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B) by Black Forest Labs.
ae.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:570cc44d0301b006a34b2604735cf296ef6083a95564b45042c1788eae246977
3
+ size 336211292
assets/NisabaRelief-Logo.png ADDED

Git LFS Details

  • SHA256: 62b1fa428e2dea3b963eae0bf5d58cb369e0b2674d6e951fe96206c7e7b9becf
  • Pointer size: 132 Bytes
  • Size of remote file: 2.16 MB
assets/example_diff_0.png ADDED

Git LFS Details

  • SHA256: 837769a8e9d4a223e9575476e22f809f0d0b4305937586eb36d8cf235999a3bd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
assets/example_diff_1.png ADDED

Git LFS Details

  • SHA256: 6b155cd60651da23389b43cce096fe2ed6ec5e5c730ba69feec166e2728a402d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.91 MB
assets/example_diff_2.png ADDED

Git LFS Details

  • SHA256: 4f457cfedc797934ffb66c6878c342e26d2d1c5e89bf19bc3c3287136a971a7a
  • Pointer size: 131 Bytes
  • Size of remote file: 927 kB
assets/example_diff_3.png ADDED

Git LFS Details

  • SHA256: b0aa7d54ff0ec6483c0ceda87793563a1b3bf7049ab45cafae96680dc1e6fb42
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
assets/example_diff_4.png ADDED

Git LFS Details

  • SHA256: 2b154de211ec623532f835473aedcaf42f6a2969b53a35c79125ced14393c336
  • Pointer size: 131 Bytes
  • Size of remote file: 962 kB
assets/example_input_0.png ADDED

Git LFS Details

  • SHA256: e131fa5b64a19d113db548d9bf10181e1ffa9e5dce97185cf667ebd80113b228
  • Pointer size: 132 Bytes
  • Size of remote file: 6.2 MB
assets/example_input_1.png ADDED

Git LFS Details

  • SHA256: 3c0b2547da571247bd18452b8e956c850b47978ad8371b00ef0c6c6e1425c675
  • Pointer size: 132 Bytes
  • Size of remote file: 7.08 MB
assets/example_input_2.png ADDED

Git LFS Details

  • SHA256: 2b7dd866b8a9ff119b15b7471b67519bcf60af901255a18f6ff4d580d4181bbc
  • Pointer size: 132 Bytes
  • Size of remote file: 4.75 MB
assets/example_input_3.png ADDED

Git LFS Details

  • SHA256: 84843078d6da5c46f9afe2b9470dab73536ad62aaff33b9f3144b4cc7c61ee42
  • Pointer size: 132 Bytes
  • Size of remote file: 6.65 MB
assets/example_input_4.png ADDED

Git LFS Details

  • SHA256: f18ca16e4e4d18c7bbe6936bd80ba8d1d25ffebf90091c581ff2605778837a50
  • Pointer size: 132 Bytes
  • Size of remote file: 4.96 MB
assets/example_output_0.png ADDED

Git LFS Details

  • SHA256: 41b7420375095bfa79b929a45afe81a3076b5accf274971d6d6c154725e97ddc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
assets/example_output_1.png ADDED

Git LFS Details

  • SHA256: 317e6785571d72fe2e4de4629af384e03b59812d22a6c82781b5166392199f0c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.38 MB
assets/example_output_2.png ADDED

Git LFS Details

  • SHA256: 99f48ab700e9f5aad5dc797a3b3e2afbe9e42ed2d408fdf392b12d683a7cc3ce
  • Pointer size: 132 Bytes
  • Size of remote file: 1.61 MB
assets/example_output_3.png ADDED

Git LFS Details

  • SHA256: da55bebf52e374b84769f4845e0f0ac7d3f7ec22e50a1f5003f0e6fe05414748
  • Pointer size: 132 Bytes
  • Size of remote file: 2.07 MB
assets/example_output_4.png ADDED

Git LFS Details

  • SHA256: 83275746a8232b448331d6b6738f64140a5cd091afc74cdb12851a54514c9325
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
assets/example_truth_0.png ADDED

Git LFS Details

  • SHA256: d6df5e794dddc87e61fe021d6d85d17f18f3e06f675c7907a4a3d44ff8bfd09a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.39 MB
assets/example_truth_1.png ADDED

Git LFS Details

  • SHA256: f79ffd7d9b76de5f11c12c29abdab46f1248e6727c385668561c81880c8ed31d
  • Pointer size: 132 Bytes
  • Size of remote file: 4.81 MB
assets/example_truth_2.png ADDED

Git LFS Details

  • SHA256: 503dbab87683109c6df9848b1ccfa3b1de79503bb5e9ed7b64f87d4a5893030e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.21 MB
assets/example_truth_3.png ADDED

Git LFS Details

  • SHA256: 5413190d0c3473a03f15d2d9d6b30d7c1ef769e68abae37a7819705b23a6edfd
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
assets/example_truth_4.png ADDED

Git LFS Details

  • SHA256: debdccfc969a8329bd79ba3fc9439d768a2a5373185141d1de76e2b24a26a5e2
  • Pointer size: 132 Bytes
  • Size of remote file: 3.09 MB
data/val_tablet_ids.json ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "HS_1746",
3
+ "HS_1059",
4
+ "HS_1660",
5
+ "HS_2631",
6
+ "HS_2072",
7
+ "HS_890",
8
+ "HS_883",
9
+ "HS_0713",
10
+ "HS_919",
11
+ "HS_0459",
12
+ "HS_1327",
13
+ "HS_736",
14
+ "HS_1200",
15
+ "HS_294",
16
+ "HS_0205",
17
+ "HS_0362",
18
+ "HS_510",
19
+ "HS_1122",
20
+ "HS_2467",
21
+ "HS_1650",
22
+ "HS_2590",
23
+ "HS_2616",
24
+ "HS_1336",
25
+ "HS_2355",
26
+ "HS_0449",
27
+ "HS_1770",
28
+ "HS_0898",
29
+ "HS_2309",
30
+ "HS_2084",
31
+ "HS_566",
32
+ "HS_0199",
33
+ "HS_843",
34
+ "HS_1275",
35
+ "HS_2556",
36
+ "HS_1506",
37
+ "HS_1643",
38
+ "HS_0661",
39
+ "HS_1774",
40
+ "HS_0626",
41
+ "HS_933",
42
+ "HS_1485",
43
+ "HS_665",
44
+ "HS_1175",
45
+ "HS_1045",
46
+ "HS_901",
47
+ "HS_1494",
48
+ "HS_194a",
49
+ "HS_491",
50
+ "HS_1052",
51
+ "HS_841",
52
+ "HS_653",
53
+ "HS_0102",
54
+ "HS_848",
55
+ "HS_1304",
56
+ "HS_2503",
57
+ "HS_2061",
58
+ "HS_1186",
59
+ "HS_1944",
60
+ "HS_929",
61
+ "HS_501",
62
+ "HS_2673",
63
+ "HS_535",
64
+ "HS_1139",
65
+ "HS_2373",
66
+ "HS_0151",
67
+ "HS_2550",
68
+ "HS_2249",
69
+ "HS_1210",
70
+ "HS_1182",
71
+ "HS_0628",
72
+ "HS_0158b",
73
+ "HS_0164",
74
+ "HS_1949",
75
+ "HS_2511",
76
+ "HS_0570",
77
+ "HS_2337",
78
+ "HS_598",
79
+ "HS_435",
80
+ "HS_0717",
81
+ "HS_588",
82
+ "HS_1010",
83
+ "HS_1192",
84
+ "HS_1235",
85
+ "HS_1298",
86
+ "HS_600",
87
+ "HS_0147",
88
+ "HS_0749",
89
+ "HS_2641"
90
+ ]
dev_scripts/benchmark.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Benchmark script for NisabaRelief inference pipeline."""
2
+
3
+ import argparse
4
+ import statistics
5
+ import time
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ from rich.console import Console
12
+ from rich.progress import (
13
+ BarColumn,
14
+ MofNCompleteColumn,
15
+ Progress,
16
+ TextColumn,
17
+ TimeElapsedColumn,
18
+ )
19
+ from rich.table import Table
20
+
21
+ from nisaba_relief import NisabaRelief
22
+ from util.load_val_dataset import load_val_dataset
23
+
24
+ BENCHMARK_DIR = Path(__file__).parent.parent / "data" / "benchmark"
25
+ BASELINE = BENCHMARK_DIR / "benchmark_baseline.png"
26
+ WARMUP_RUNS = 2
27
+ BENCH_RUNS = 3
28
+
29
+
30
+ def build_timing_table(timings: list[float], n_warmup: int) -> Table:
31
+ bench_timings = timings[n_warmup:]
32
+ mean = statistics.mean(bench_timings)
33
+ stdev = statistics.stdev(bench_timings) if len(bench_timings) > 1 else 0.0
34
+ table = Table(title="Inference Timings")
35
+ table.add_column("Run", justify="right")
36
+ table.add_column("Time", justify="right")
37
+ for i, t in enumerate(timings, 1):
38
+ label = f"[dim]{i} (warmup)[/dim]" if i <= n_warmup else str(i - n_warmup)
39
+ time_str = f"[dim]{t:.2f}s[/dim]" if i <= n_warmup else f"{t:.2f}s"
40
+ table.add_row(label, time_str)
41
+ table.add_section()
42
+ table.add_row("[bold]Mean[/bold]", f"[bold]{mean:.2f} ± {stdev:.2f}s[/bold]")
43
+ return table
44
+
45
+
46
+ def build_diff_table(flat: np.ndarray, max_diff: int) -> Table:
47
+ percentile_vals = np.percentile(flat, [50, 90, 95, 96, 97, 98, 99])
48
+ p98 = percentile_vals[5]
49
+ status = "PASS" if p98 <= 1 else "FAIL"
50
+ status_style = "green" if status == "PASS" else "red"
51
+ table = Table(
52
+ title=f"Pixel Diff vs Baseline — [{status_style}]{status}[/{status_style}]"
53
+ )
54
+ table.add_column("Stat", style="bold")
55
+ table.add_column("Value", justify="right")
56
+ table.add_row("Mean", f"{flat.mean():.4f}")
57
+ for label, val in zip(
58
+ ["p50", "p90", "p95", "p96", "p97", "p98", "p99"], percentile_vals
59
+ ):
60
+ table.add_row(label, f"{val:.0f}")
61
+ table.add_row("Max", str(max_diff))
62
+ return table
63
+
64
+
65
+ def main():
66
+ parser = argparse.ArgumentParser(
67
+ description="Benchmark NisabaRelief inference pipeline"
68
+ )
69
+ parser.add_argument(
70
+ "--weights-dir",
71
+ default=".",
72
+ metavar="PATH",
73
+ help="path to weights directory (default: .)",
74
+ )
75
+ parser.add_argument(
76
+ "--device",
77
+ default=None,
78
+ metavar="DEVICE",
79
+ help="device to run inference on, e.g. cuda, cpu (default: cuda if available, else cpu)",
80
+ )
81
+ args = parser.parse_args()
82
+
83
+ console = Console()
84
+ rows = load_val_dataset()
85
+ test_image = rows[0]["photo"]
86
+ max_dim = max(test_image.size)
87
+ if max_dim > 2048:
88
+ scale = 2048 / max_dim
89
+ new_size = (round(test_image.width * scale), round(test_image.height * scale))
90
+ test_image = test_image.resize(new_size, Image.LANCZOS)
91
+ console.print(f"Input size: [cyan]{test_image.width}x{test_image.height}[/cyan]")
92
+
93
+ model_kwargs = dict(seed=42, weights_dir=Path(args.weights_dir))
94
+ if args.device is not None:
95
+ model_kwargs["device"] = args.device
96
+ model = NisabaRelief(**model_kwargs)
97
+
98
+ timings = []
99
+ output = None
100
+ total_runs = WARMUP_RUNS + BENCH_RUNS
101
+ progress = Progress(
102
+ TextColumn("[progress.description]{task.description}"),
103
+ BarColumn(),
104
+ MofNCompleteColumn(),
105
+ TimeElapsedColumn(),
106
+ )
107
+ with progress:
108
+ task = progress.add_task("Benchmarking", total=total_runs)
109
+ for i in range(total_runs):
110
+ t0 = time.perf_counter()
111
+ result = model.process(test_image, show_pbar=False)
112
+ timings.append(time.perf_counter() - t0)
113
+ progress.advance(task)
114
+ if i == WARMUP_RUNS:
115
+ output = result
116
+
117
+ console.print(build_timing_table(timings, WARMUP_RUNS))
118
+
119
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
120
+ run_path = BENCHMARK_DIR / f"benchmark_{timestamp}.png"
121
+ run_path.parent.mkdir(parents=True, exist_ok=True)
122
+ output.save(run_path)
123
+ console.print(f"Run image saved to [cyan]{run_path}[/cyan]")
124
+
125
+ output_arr = np.array(output)
126
+
127
+ if not BASELINE.exists():
128
+ output.save(BASELINE)
129
+ console.print(f"Baseline saved to [cyan]{BASELINE}[/cyan]")
130
+ else:
131
+ baseline_arr = np.array(Image.open(BASELINE))
132
+ diff = np.abs(output_arr.astype(int) - baseline_arr.astype(int))
133
+ flat = diff.flatten()
134
+ max_diff = int(flat.max())
135
+ console.print(build_diff_table(flat, max_diff))
136
+
137
+ if max_diff > 0:
138
+ diff_img = Image.fromarray(
139
+ np.clip(diff * (255 // max_diff), 0, 255).astype("uint8")
140
+ )
141
+ diff_path = Path(f"benchmark_{timestamp}_diff.png")
142
+ diff_img.save(diff_path)
143
+ console.print(
144
+ f"Diff image saved to [cyan]{diff_path}[/cyan] (amplified {255 // max_diff}x)"
145
+ )
146
+
147
+
148
+ if __name__ == "__main__":
149
+ main()
dev_scripts/evaluation.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluate NisabaRelief on the validation set, optionally sweeping over step counts.
3
+
4
+ Usage:
5
+ python evaluation.py # full dataset, num_steps=2
6
+ python evaluation.py --sweep # subset, steps=[1,2,4,8]
7
+ """
8
+
9
+ import argparse
10
+ import time
11
+ from datetime import timedelta
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+ from rich.console import Console, Group
17
+ from rich.live import Live
18
+ from rich.progress import (
19
+ BarColumn,
20
+ MofNCompleteColumn,
21
+ Progress,
22
+ TextColumn,
23
+ TimeElapsedColumn,
24
+ )
25
+ from rich.table import Table
26
+
27
+ from nisaba_relief import NisabaRelief
28
+ from util.metrics import compute_metrics, METRIC_NAMES, LABELS
29
+ from util.load_val_dataset import load_val_dataset
30
+
31
+
32
+ SWEEP_STEPS = [1, 2, 4, 8]
33
+ DEFAULT_STEPS = 2
34
+ SWEEP_STRIDE = 4
35
+ SWEEP_MAX = 175
36
+ EVALS_DIR = Path(__file__).parent.parent / "data" / "evals"
37
+
38
+
39
+ def _eta(n_done: int, n_total: int, elapsed: float) -> str:
40
+ if n_done >= n_total > 0:
41
+ return "Done"
42
+ if n_done > 0:
43
+ return str(timedelta(seconds=int(elapsed / n_done * (n_total - n_done))))
44
+ return "?"
45
+
46
+
47
+ def build_table(
48
+ results: dict,
49
+ n_done: int = 0,
50
+ n_total: int = 0,
51
+ elapsed: float = 0.0,
52
+ ) -> Table:
53
+ eta = _eta(n_done, n_total, elapsed)
54
+ steps = list(results.keys())
55
+ table = Table(title=f"Results — ETA: {eta}")
56
+ table.add_column("Metric", style="bold")
57
+ for s in steps:
58
+ table.add_column(f"Steps={s}", justify="right")
59
+ for name in METRIC_NAMES:
60
+ cells = []
61
+ for s in steps:
62
+ arr = np.array(results[s][name])
63
+ if len(arr) == 0:
64
+ cells.append("—")
65
+ elif name in ("psnr", "psnr_hvsm", "sre"):
66
+ cells.append(f"{arr.mean():.2f} ± {arr.std():.2f} dB")
67
+ else:
68
+ cells.append(f"{arr.mean():.4f} ± {arr.std():.4f}")
69
+ table.add_row(LABELS[name], *cells)
70
+ return table
71
+
72
+
73
+ def load_grayscale(img: Image.Image) -> np.ndarray:
74
+ return np.array(img.convert("L"))
75
+
76
+
77
+ def main():
78
+ parser = argparse.ArgumentParser(description="Evaluate NisabaRelief model")
79
+ parser.add_argument(
80
+ "--weights-dir",
81
+ default=".",
82
+ metavar="PATH",
83
+ help="path to weights directory (default: .)",
84
+ )
85
+ parser.add_argument(
86
+ "--sweep",
87
+ action="store_true",
88
+ help="sweep over steps=[1,2,4,8] on a dataset subset",
89
+ )
90
+ args = parser.parse_args()
91
+
92
+ rows = load_val_dataset()
93
+ if args.sweep:
94
+ rows = rows.select(
95
+ range(0, min(len(rows), SWEEP_MAX * SWEEP_STRIDE), SWEEP_STRIDE)
96
+ )
97
+ steps_to_run = SWEEP_STEPS
98
+ else:
99
+ steps_to_run = [DEFAULT_STEPS]
100
+ results = {s: {m: [] for m in METRIC_NAMES} for s in steps_to_run}
101
+
102
+ model = NisabaRelief(seed=42, batch_size=4, weights_dir=Path(args.weights_dir))
103
+
104
+ progress = Progress(
105
+ TextColumn("[progress.description]{task.description}"),
106
+ BarColumn(),
107
+ MofNCompleteColumn(),
108
+ TimeElapsedColumn(),
109
+ TextColumn("[cyan]{task.fields[hs_number]}"),
110
+ )
111
+ task_desc = "Step Sweep" if args.sweep else "Evaluating"
112
+ task = progress.add_task(task_desc, total=len(rows), hs_number="")
113
+
114
+ start_time = time.monotonic()
115
+ with Live(
116
+ Group(progress, build_table(results)),
117
+ refresh_per_second=4,
118
+ transient=True,
119
+ ) as live:
120
+ for n_done, row in enumerate(rows):
121
+ progress.update(task, hs_number=row["hs_number"])
122
+ gt = load_grayscale(row["msii"])
123
+
124
+ for num_steps in steps_to_run:
125
+ model.num_steps = num_steps
126
+ save_name = f"{row['hs_number']}_photo_fullview_{int(row['variation']):02d}-step{num_steps}.png"
127
+ save_path = EVALS_DIR / save_name
128
+ save_path.parent.mkdir(parents=True, exist_ok=True)
129
+
130
+ if save_path.exists():
131
+ pred_img = Image.open(save_path)
132
+ else:
133
+ pred_img = model.process(row["photo"], show_pbar=False)
134
+ pred_img.save(save_path)
135
+
136
+ pred = load_grayscale(pred_img)
137
+ pred_img.close()
138
+
139
+ if pred.shape != gt.shape:
140
+ pred = np.array(
141
+ Image.fromarray(pred).resize(
142
+ (gt.shape[1], gt.shape[0]), Image.LANCZOS
143
+ )
144
+ )
145
+
146
+ m = compute_metrics(pred, gt)
147
+ for name, val in m.items():
148
+ results[num_steps][name].append(val)
149
+
150
+ elapsed = time.monotonic() - start_time
151
+ live.update(
152
+ Group(progress, build_table(results, n_done + 1, len(rows), elapsed))
153
+ )
154
+
155
+ progress.advance(task)
156
+
157
+ final_elapsed = time.monotonic() - start_time
158
+ Console().print(build_table(results, len(rows), len(rows), final_elapsed))
159
+
160
+
161
+ if __name__ == "__main__":
162
+ main()
dev_scripts/process_images.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Process a directory of images through NisabaRelief and save as PNG."""
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ from PIL import Image
7
+ from rich.console import Console
8
+ from rich.progress import (
9
+ BarColumn,
10
+ MofNCompleteColumn,
11
+ Progress,
12
+ ProgressColumn,
13
+ SpinnerColumn,
14
+ Task,
15
+ TextColumn,
16
+ TimeElapsedColumn,
17
+ )
18
+ from rich.text import Text
19
+
20
+ from nisaba_relief import NisabaRelief
21
+ from nisaba_relief.constants import MAX_TILE, MIN_IMAGE_DIMENSION
22
+
23
+ Image.MAX_IMAGE_PIXELS = None
24
+
25
+ IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}
26
+
27
+
28
+ class SimpleTimeRemainingColumn(ProgressColumn):
29
+ """Estimates remaining time from the average duration of the last 10 iterations.
30
+
31
+ Only recomputes when a new step completes so the display is stable.
32
+ """
33
+
34
+ def __init__(self, window: int = 10) -> None:
35
+ super().__init__()
36
+ self._last_completed: float = 0
37
+ self._last_elapsed: float = 0.0
38
+ self._durations: list[float] = []
39
+ self._window: int = window
40
+ self._cached: Text = Text("-:--:--", style="progress.remaining")
41
+
42
+ def render(self, task: Task) -> Text:
43
+ if task.completed <= self._last_completed:
44
+ return self._cached
45
+ elapsed = task.finished_time if task.finished else task.elapsed
46
+ if not elapsed or not task.completed:
47
+ self._last_completed = task.completed
48
+ self._cached = Text("-:--:--", style="progress.remaining")
49
+ return self._cached
50
+ step_duration = elapsed - self._last_elapsed
51
+ steps = task.completed - self._last_completed
52
+ if steps > 0 and self._last_completed > 0:
53
+ per_step = step_duration / steps
54
+ self._durations.append(per_step)
55
+ if len(self._durations) > self._window:
56
+ self._durations = self._durations[-self._window :]
57
+ self._last_completed = task.completed
58
+ self._last_elapsed = elapsed
59
+ if not self._durations:
60
+ self._cached = Text("-:--:--", style="progress.remaining")
61
+ return self._cached
62
+ avg = sum(self._durations) / len(self._durations)
63
+ remaining = task.total - task.completed
64
+ eta_seconds = avg * remaining
65
+ hours, rem = divmod(int(eta_seconds), 3600)
66
+ minutes, seconds = divmod(rem, 60)
67
+ if hours:
68
+ self._cached = Text(
69
+ f"{hours}:{minutes:02d}:{seconds:02d}", style="progress.remaining"
70
+ )
71
+ else:
72
+ self._cached = Text(f"{minutes}:{seconds:02d}", style="progress.remaining")
73
+ return self._cached
74
+
75
+
76
+ def main():
77
+ parser = argparse.ArgumentParser(
78
+ description="Process images through NisabaRelief and save as PNG."
79
+ )
80
+ parser.add_argument(
81
+ "--input-dir", type=Path, required=True, help="Source image directory"
82
+ )
83
+ parser.add_argument(
84
+ "--output-dir", type=Path, required=True, help="Destination directory (created if needed)"
85
+ )
86
+ parser.add_argument(
87
+ "--max-size", type=int, default=MAX_TILE * 5,
88
+ help="Downsample images larger than this before processing (default: %(default)s)",
89
+ )
90
+ parser.add_argument(
91
+ "--min-size", type=int, default=1536,
92
+ help="Skip images where max dimension < this (default: %(default)s)",
93
+ )
94
+ parser.add_argument("--seed", type=int, default=None, help="Reproducibility seed")
95
+ parser.add_argument("--weights-dir", type=Path, default=None, help="Local weights directory")
96
+ parser.add_argument("--batch-size", type=int, default=None, help="Tile batch size")
97
+ parser.add_argument("--num-steps", type=int, default=2, help="Solver steps (default: %(default)s)")
98
+ parser.add_argument("--device", default="cuda", help="Torch device (default: %(default)s)")
99
+ parser.add_argument(
100
+ "--overwrite", action="store_true", help="Re-process even if output file exists"
101
+ )
102
+ args = parser.parse_args()
103
+
104
+ console = Console()
105
+
106
+ input_dir: Path = args.input_dir
107
+ output_dir: Path = args.output_dir
108
+
109
+ if not input_dir.is_dir():
110
+ console.print(f"[red]Input directory not found:[/red] [cyan]{input_dir}[/cyan]")
111
+ return
112
+
113
+ input_images = sorted(
114
+ p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS
115
+ )
116
+ if not input_images:
117
+ console.print(f"[red]No images found in[/red] [cyan]{input_dir}[/cyan]")
118
+ return
119
+
120
+ output_dir.mkdir(parents=True, exist_ok=True)
121
+
122
+ to_process = []
123
+ skipped_existing = 0
124
+ skipped_small = 0
125
+ for src in input_images:
126
+ dst = output_dir / (src.stem + ".png")
127
+ if not args.overwrite and dst.exists():
128
+ skipped_existing += 1
129
+ continue
130
+ with Image.open(src) as img:
131
+ if max(img.size) < args.min_size or min(img.size) < MIN_IMAGE_DIMENSION:
132
+ skipped_small += 1
133
+ continue
134
+ to_process.append((src, dst))
135
+
136
+ if skipped_existing:
137
+ console.print(
138
+ f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
139
+ )
140
+ if skipped_small:
141
+ console.print(
142
+ f"[dim]Skipping {skipped_small} image(s) smaller than {args.min_size}px[/dim]"
143
+ )
144
+
145
+ if not to_process:
146
+ console.print("[green]All images already processed.[/green]")
147
+ return
148
+
149
+ console.print(
150
+ f"Processing [bold]{len(to_process)}[/bold] / {len(input_images)} images "
151
+ f"[dim]({input_dir} → {output_dir})[/dim]"
152
+ )
153
+
154
+ model_kwargs = dict(num_steps=args.num_steps, device=args.device)
155
+ if args.seed is not None:
156
+ model_kwargs["seed"] = args.seed
157
+ if args.weights_dir is not None:
158
+ model_kwargs["weights_dir"] = args.weights_dir
159
+ if args.batch_size is not None:
160
+ model_kwargs["batch_size"] = args.batch_size
161
+ model = NisabaRelief(**model_kwargs)
162
+
163
+ progress = Progress(
164
+ SpinnerColumn(),
165
+ TextColumn("[progress.description]{task.description}"),
166
+ BarColumn(),
167
+ MofNCompleteColumn(),
168
+ TimeElapsedColumn(),
169
+ TextColumn("eta"),
170
+ SimpleTimeRemainingColumn(),
171
+ )
172
+ with progress:
173
+ task = progress.add_task("Processing", total=len(to_process))
174
+ for src, dst in to_process:
175
+ progress.update(task, description=f"[cyan]{src.name}[/cyan]")
176
+ image = Image.open(src).convert("RGB")
177
+ original_size = image.size
178
+ if max(image.size) > args.max_size:
179
+ scale = args.max_size / max(image.size)
180
+ new_size = (
181
+ round(image.width * scale) // 16 * 16,
182
+ round(image.height * scale) // 16 * 16,
183
+ )
184
+ image = image.resize(new_size, Image.LANCZOS)
185
+ result = model.process(image, show_pbar=False)
186
+ if result.size != original_size:
187
+ result = result.resize(original_size, Image.LANCZOS)
188
+ result.save(dst)
189
+ progress.advance(task)
190
+
191
+ console.print(
192
+ f"[green]Done.[/green] {len(to_process)} image(s) saved to [cyan]{output_dir}[/cyan]"
193
+ )
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
dev_scripts/util/load_val_dataset.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Load the validation set from the CuneiformPhotosMSII dataset.
3
+ """
4
+
5
+ from datasets import load_dataset, Dataset
6
+
7
+ from pathlib import Path
8
+ import json
9
+
10
+ VAL_IDS_PATH = Path(__file__).parent.parent.parent / "data" / "val_tablet_ids.json"
11
+ VAL_IDS = set(json.load(open(VAL_IDS_PATH)))
12
+
13
+
14
+ def load_val_dataset() -> Dataset:
15
+ ds = load_dataset("boatbomber/CuneiformPhotosMSII", split="train", num_proc=4)
16
+
17
+ # First pass: parquet column projection reads only the ID strings, skipping image bytes
18
+ indices = [
19
+ i
20
+ for i, row in enumerate(ds.select_columns(["hs_number"]))
21
+ if row["hs_number"] in VAL_IDS
22
+ ]
23
+
24
+ return ds.select(indices)
dev_scripts/util/metrics.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared metric computation for NisabaRelief evaluation scripts."""
2
+
3
+ from concurrent.futures import ThreadPoolExecutor
4
+
5
+ import numpy as np
6
+ from image_similarity_measures.quality_metrics import (
7
+ rmse,
8
+ psnr,
9
+ sre,
10
+ )
11
+ import torch
12
+ from pytorch_msssim import ms_ssim as _pt_msssim
13
+ from util.psnr_hvsm import psnr_hvsm
14
+
15
+ DICE_THRESHOLD = 130
16
+
17
+ METRIC_NAMES = [
18
+ "dice",
19
+ "rmse",
20
+ "msssim",
21
+ "psnr",
22
+ "psnr_hvsm",
23
+ "sre",
24
+ ]
25
+
26
+ LABELS = {
27
+ "dice": "**Dice**",
28
+ "rmse": "RMSE",
29
+ "msssim": "MS-SSIM",
30
+ "psnr": "PSNR",
31
+ "psnr_hvsm": "PSNR-HVS-M",
32
+ "sre": "SRE",
33
+ }
34
+
35
+
36
+ def _to_tensor(arr: np.ndarray) -> torch.Tensor:
37
+ return torch.from_numpy(arr).float().unsqueeze(0).unsqueeze(0)
38
+
39
+
40
+ def _msssim(gt: np.ndarray, pred: np.ndarray) -> float:
41
+ return _pt_msssim(
42
+ _to_tensor(gt), _to_tensor(pred), data_range=255, size_average=True
43
+ ).item()
44
+
45
+
46
+ def compute_metrics(pred: np.ndarray, gt: np.ndarray) -> dict[str, float]:
47
+ """Compute all metrics for a pair of equal-shape grayscale uint8 images."""
48
+ pred_3d = pred[:, :, np.newaxis]
49
+ gt_3d = gt[:, :, np.newaxis]
50
+
51
+ pred_bin = pred > DICE_THRESHOLD
52
+ gt_bin = gt > DICE_THRESHOLD
53
+ denom = pred_bin.sum() + gt_bin.sum()
54
+ dice = float(2 * np.logical_and(pred_bin, gt_bin).sum() / denom) if denom > 0 else 1.0
55
+
56
+ tasks = {
57
+ "rmse": lambda: rmse(gt_3d, pred_3d, max_p=255),
58
+ "psnr": lambda: psnr(gt_3d, pred_3d, max_p=255),
59
+ "msssim": lambda: _msssim(gt, pred),
60
+ "sre": lambda: sre(gt_3d, pred_3d),
61
+ "psnr_hvsm": lambda: psnr_hvsm(gt, pred)[0],
62
+ "dice": lambda: dice,
63
+ }
64
+
65
+ with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
66
+ futures = {name: executor.submit(fn) for name, fn in tasks.items()}
67
+ return {name: future.result() for name, future in futures.items()}
dev_scripts/util/psnr_hvsm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PSNR-HVS-M and PSNR-HVS metrics (Ponomarenko et al., 2006/2007).
2
+
3
+ Direct Python translation of the MATLAB reference implementation at
4
+ https://www.ponomarenko.info/psnrhvsm.m
5
+
6
+ Returns (p_hvs_m, p_hvs) as a tuple.
7
+ Uses CUDA if available, otherwise falls back to CPU.
8
+ """
9
+
10
+ import math
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ _N = 8
16
+
17
+
18
+ def _make_dct_matrix() -> torch.Tensor:
19
+ """8x8 orthonormal DCT-II matrix: D[0,n]=1/√N, D[k>0,n]=√(2/N)·cos(π·k·(2n+1)/(2N))."""
20
+ k = torch.arange(_N, dtype=torch.float64).unsqueeze(1)
21
+ n = torch.arange(_N, dtype=torch.float64).unsqueeze(0)
22
+ D = torch.cos(math.pi * k * (2 * n + 1) / (2 * _N))
23
+ D[0] = D[0] / math.sqrt(_N)
24
+ D[1:] = D[1:] * math.sqrt(2.0 / _N)
25
+ return D
26
+
27
+
28
+ _DCT8 = _make_dct_matrix() # (8, 8), CPU float64
29
+
30
+ _CSF = torch.tensor(
31
+ [
32
+ [1.608443, 2.339554, 2.573509, 1.608443, 1.072295, 0.643377, 0.504610, 0.421887],
33
+ [2.144591, 2.144591, 1.838221, 1.354478, 0.989811, 0.443708, 0.428918, 0.467911],
34
+ [1.838221, 1.979622, 1.608443, 1.072295, 0.643377, 0.451493, 0.372972, 0.459555],
35
+ [1.838221, 1.513829, 1.169777, 0.887417, 0.504610, 0.295806, 0.321689, 0.415082],
36
+ [1.429727, 1.169777, 0.695543, 0.459555, 0.378457, 0.236102, 0.249855, 0.334222],
37
+ [1.072295, 0.735288, 0.467911, 0.402111, 0.317717, 0.247453, 0.227744, 0.279729],
38
+ [0.525206, 0.402111, 0.329937, 0.295806, 0.249855, 0.212687, 0.214459, 0.254803],
39
+ [0.357432, 0.279729, 0.270896, 0.262603, 0.229778, 0.257351, 0.249855, 0.259950],
40
+ ],
41
+ dtype=torch.float64,
42
+ )
43
+ _MASKCOF = torch.tensor(
44
+ [
45
+ [0.390625, 0.826446, 1.000000, 0.390625, 0.173611, 0.062500, 0.038447, 0.026874],
46
+ [0.694444, 0.694444, 0.510204, 0.277008, 0.147929, 0.029727, 0.027778, 0.033058],
47
+ [0.510204, 0.591716, 0.390625, 0.173611, 0.062500, 0.030779, 0.021004, 0.031888],
48
+ [0.510204, 0.346021, 0.206612, 0.118906, 0.038447, 0.013212, 0.015625, 0.026015],
49
+ [0.308642, 0.206612, 0.073046, 0.031888, 0.021626, 0.008417, 0.009426, 0.016866],
50
+ [0.173611, 0.081633, 0.033058, 0.024414, 0.015242, 0.009246, 0.007831, 0.011815],
51
+ [0.041649, 0.024414, 0.016437, 0.013212, 0.009426, 0.006830, 0.006944, 0.009803],
52
+ [0.019290, 0.011815, 0.011080, 0.010412, 0.007972, 0.010000, 0.009426, 0.010203],
53
+ ],
54
+ dtype=torch.float64,
55
+ )
56
+
57
+ # True everywhere except the DC coefficient at (0, 0)
58
+ _AC_MASK = torch.ones((_N, _N), dtype=torch.bool)
59
+ _AC_MASK[0, 0] = False
60
+
61
+
62
+ def _vari_batch(blocks: torch.Tensor) -> torch.Tensor:
63
+ """Unbiased variance * N for a batch of blocks. (B, H, W) -> (B,)"""
64
+ flat = blocks.reshape(blocks.shape[0], -1)
65
+ return flat.var(dim=-1, correction=1) * flat.shape[-1]
66
+
67
+
68
+ def _maskeff_batch(blocks: torch.Tensor, dct_blocks: torch.Tensor) -> torch.Tensor:
69
+ """Perceptual masking strength for a batch of 8x8 blocks. Returns (B,)."""
70
+ dev = blocks.device
71
+ ac = _AC_MASK.to(dev)
72
+ mc = _MASKCOF.to(dev)
73
+
74
+ m = (dct_blocks[:, ac] ** 2 * mc[ac]).sum(dim=-1) # (B,)
75
+
76
+ pop = _vari_batch(blocks)
77
+ quad = (
78
+ _vari_batch(blocks[:, :4, :4])
79
+ + _vari_batch(blocks[:, :4, 4:])
80
+ + _vari_batch(blocks[:, 4:, :4])
81
+ + _vari_batch(blocks[:, 4:, 4:])
82
+ )
83
+ pop_ratio = torch.where(pop > 0, quad / pop, torch.zeros_like(pop))
84
+ return torch.sqrt(m * pop_ratio) / 32.0
85
+
86
+
87
+ def psnr_hvsm(img1: np.ndarray, img2: np.ndarray) -> tuple[float, float]:
88
+ """Return (PSNR-HVS-M, PSNR-HVS) for two uint8 grayscale arrays.
89
+
90
+ Direct translation of the MATLAB reference (Ponomarenko et al.).
91
+ Partial edge blocks are skipped (truncate to nearest multiple of 8).
92
+ """
93
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
+ D = _DCT8.to(device)
95
+ csf = _CSF.to(device)
96
+ maskcof = _MASKCOF.to(device)
97
+ ac_mask = _AC_MASK.to(device)
98
+
99
+ a = torch.from_numpy(img1.astype(np.float64)).to(device)
100
+ b = torch.from_numpy(img2.astype(np.float64)).to(device)
101
+
102
+ h, w = a.shape
103
+ h = (h // 8) * 8
104
+ w = (w // 8) * 8
105
+ a = a[:h, :w]
106
+ b = b[:h, :w]
107
+
108
+ num_blocks = (h // 8) * (w // 8)
109
+ if num_blocks == 0:
110
+ return 100000.0, 100000.0
111
+
112
+ # Extract all non-overlapping 8x8 blocks: (B, 8, 8)
113
+ ba = a.unfold(0, 8, 8).unfold(1, 8, 8).contiguous().reshape(-1, 8, 8)
114
+ bb = b.unfold(0, 8, 8).unfold(1, 8, 8).contiguous().reshape(-1, 8, 8)
115
+
116
+ # 2D DCT-II (ortho) via separable matrix product: D @ block @ D.T
117
+ da = D @ ba @ D.t()
118
+ db = D @ bb @ D.t()
119
+
120
+ mask = torch.maximum(_maskeff_batch(ba, da), _maskeff_batch(bb, db)) # (B,)
121
+
122
+ diff = torch.abs(da - db) # (B, 8, 8)
123
+
124
+ # PSNR-HVS: CSF-weighted squared error (no masking)
125
+ S2 = float(((diff * csf) ** 2).sum())
126
+
127
+ # PSNR-HVS-M: soft-threshold AC coefficients by local mask, keep DC as-is
128
+ thresh = mask[:, None, None] / maskcof[None, :, :]
129
+ u = torch.where(ac_mask[None, :, :], torch.clamp(diff - thresh, min=0.0), diff)
130
+ S1 = float(((u * csf) ** 2).sum())
131
+
132
+ denom = num_blocks * 64
133
+ S1 /= denom
134
+ S2 /= denom
135
+ p_hvs_m = 100000.0 if S1 == 0 else float(10.0 * np.log10(255.0**2 / S1))
136
+ p_hvs = 100000.0 if S2 == 0 else float(10.0 * np.log10(255.0**2 / S2))
137
+ return p_hvs_m, p_hvs
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b83231048fc47de658665425368daadce6791bfd95456397b8b595aa0e5d05d
3
+ size 7751105712
nisaba_relief/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """NisabaRelief: Transform cuneiform tablet photos into MSII relief visualizations."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ from .model import NisabaRelief
6
+
7
+ __all__ = ["NisabaRelief"]
nisaba_relief/constants.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Named constants for NisabaRelief magic numbers."""
2
+
3
+ # Flux model processes images in 16×16 pixel patches
4
+ PATCH_SIZE = 16
5
+
6
+ # Tile size bounds: 12 patches (192px) to 64 patches (1024px)
7
+ MIN_TILE = PATCH_SIZE * 12 # 192
8
+ MAX_TILE = PATCH_SIZE * 64 # 1024
9
+
10
+ # Aims for ~4 tiles along the longest axis when computing tile size
11
+ TARGET_TILES_PER_SIDE = 4
12
+
13
+ # Overlap is 1/8 of the tile size, giving a smooth cosine blend region
14
+ TILE_OVERLAP_DIVISOR = 8
15
+
16
+ # Smallest accepted input side in pixels
17
+ MIN_IMAGE_DIMENSION = MIN_TILE * 2
18
+
19
+ # Maximum allowed aspect ratio (width:height or height:width)
20
+ MAX_ASPECT_RATIO = 8.0
21
+
22
+ # Maximum size (px) for the global context thumbnail
23
+ MAX_GLOBAL_CONTEXT_SIZE = 128
24
+
25
+ # Positional sequence ID for conditioning tokens (image being processed)
26
+ COND_SEQ_ID = 10
27
+
28
+ # Positional sequence ID for global context tokens (thumbnail overview)
29
+ GLOBAL_CTX_ID = 20
30
+
31
+ # Number of latent channels in the Flux model's latent space
32
+ LATENT_CHANNELS = 128
33
+
34
+ # Dynamic batch_size constants. Determined empirically on an RTX 3090.
35
+ MAX_BATCH_SIZE = 16
36
+ MIN_BATCH_SIZE = 1
37
+ VRAM_MB_PER_PIXEL = 0.0035
38
+ VRAM_FIXED_OVERHEAD_MB = 15.0
39
+ VRAM_HEADROOM_MB = 1024.0
40
+
41
+ # Divisor for AE decoder sub-batching (decoder needs more VRAM than denoiser)
42
+ DECODE_BATCH_SIZE_DIVISOR = 5
nisaba_relief/flux/__init__.py ADDED
File without changes
nisaba_relief/flux/autoencoder.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+
9
+
10
+ @dataclass
11
+ class AutoEncoderParams:
12
+ resolution: int = 256
13
+ in_channels: int = 3
14
+ ch: int = 128
15
+ out_ch: int = 3
16
+ ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
17
+ num_res_blocks: int = 2
18
+ z_channels: int = 32
19
+
20
+
21
+ class AttnBlock(nn.Module):
22
+ def __init__(self, in_channels: int):
23
+ super().__init__()
24
+ self.in_channels = in_channels
25
+
26
+ self.norm = nn.GroupNorm(
27
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
28
+ )
29
+
30
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
31
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
32
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+
35
+ def attention(self, h_: Tensor) -> Tensor:
36
+ h_ = self.norm(h_)
37
+ q = self.q(h_)
38
+ k = self.k(h_)
39
+ v = self.v(h_)
40
+
41
+ b, c, h, w = q.shape
42
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
43
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
44
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
45
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
46
+
47
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ return x + self.proj_out(self.attention(x))
51
+
52
+
53
+ class ResnetBlock(nn.Module):
54
+ def __init__(self, in_channels: int, out_channels: int):
55
+ super().__init__()
56
+ self.in_channels = in_channels
57
+ out_channels = in_channels if out_channels is None else out_channels
58
+ self.out_channels = out_channels
59
+
60
+ self.norm1 = nn.GroupNorm(
61
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
62
+ )
63
+ self.conv1 = nn.Conv2d(
64
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
65
+ )
66
+ self.norm2 = nn.GroupNorm(
67
+ num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
68
+ )
69
+ self.conv2 = nn.Conv2d(
70
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
71
+ )
72
+ if self.in_channels != self.out_channels:
73
+ self.nin_shortcut = nn.Conv2d(
74
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
75
+ )
76
+
77
+ def forward(self, x: Tensor) -> Tensor:
78
+ h = x
79
+ h = self.norm1(h)
80
+ h = F.silu(h)
81
+ h = self.conv1(h)
82
+
83
+ h = self.norm2(h)
84
+ h = F.silu(h)
85
+ h = self.conv2(h)
86
+
87
+ if self.in_channels != self.out_channels:
88
+ x = self.nin_shortcut(x)
89
+
90
+ return x + h
91
+
92
+
93
+ class Downsample(nn.Module):
94
+ def __init__(self, in_channels: int):
95
+ super().__init__()
96
+ # no asymmetric padding in torch conv, must do it ourselves
97
+ self.conv = nn.Conv2d(
98
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
99
+ )
100
+
101
+ def forward(self, x: Tensor) -> Tensor:
102
+ pad = (0, 1, 0, 1)
103
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
104
+ x = self.conv(x)
105
+ return x
106
+
107
+
108
+ class Upsample(nn.Module):
109
+ def __init__(self, in_channels: int):
110
+ super().__init__()
111
+ self.conv = nn.Conv2d(
112
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
113
+ )
114
+
115
+ def forward(self, x: Tensor) -> Tensor:
116
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
117
+ x = self.conv(x)
118
+ return x
119
+
120
+
121
+ class Encoder(nn.Module):
122
+ def __init__(
123
+ self,
124
+ resolution: int,
125
+ in_channels: int,
126
+ ch: int,
127
+ ch_mult: list[int],
128
+ num_res_blocks: int,
129
+ z_channels: int,
130
+ ):
131
+ super().__init__()
132
+ self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
133
+ self.ch = ch
134
+ self.num_resolutions = len(ch_mult)
135
+ self.num_res_blocks = num_res_blocks
136
+ self.resolution = resolution
137
+ self.in_channels = in_channels
138
+ # downsampling
139
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
140
+
141
+ curr_res = resolution
142
+ in_ch_mult = (1,) + tuple(ch_mult)
143
+ self.in_ch_mult = in_ch_mult
144
+ self.down = nn.ModuleList()
145
+ block_in = self.ch
146
+ for i_level in range(self.num_resolutions):
147
+ block = nn.ModuleList()
148
+ attn = nn.ModuleList()
149
+ block_in = ch * in_ch_mult[i_level]
150
+ block_out = ch * ch_mult[i_level]
151
+ for _ in range(self.num_res_blocks):
152
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
153
+ block_in = block_out
154
+ down = nn.Module()
155
+ down.block = block
156
+ down.attn = attn
157
+ if i_level != self.num_resolutions - 1:
158
+ down.downsample = Downsample(block_in)
159
+ curr_res = curr_res // 2
160
+ self.down.append(down)
161
+
162
+ # middle
163
+ self.mid = nn.Module()
164
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
165
+ self.mid.attn_1 = AttnBlock(block_in)
166
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167
+
168
+ # end
169
+ self.norm_out = nn.GroupNorm(
170
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
171
+ )
172
+ self.conv_out = nn.Conv2d(
173
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
174
+ )
175
+
176
+ def forward(self, x: Tensor) -> Tensor:
177
+ # downsampling
178
+ h = self.conv_in(x)
179
+ for i_level in range(self.num_resolutions):
180
+ for i_block in range(self.num_res_blocks):
181
+ h = self.down[i_level].block[i_block](h)
182
+ if len(self.down[i_level].attn) > 0:
183
+ h = self.down[i_level].attn[i_block](h)
184
+ if i_level != self.num_resolutions - 1:
185
+ h = self.down[i_level].downsample(h)
186
+
187
+ # middle
188
+ h = self.mid.block_1(h)
189
+ h = self.mid.attn_1(h)
190
+ h = self.mid.block_2(h)
191
+ # end
192
+ h = self.norm_out(h)
193
+ h = F.silu(h)
194
+ h = self.conv_out(h)
195
+ h = self.quant_conv(h)
196
+ return h
197
+
198
+
199
+ class Decoder(nn.Module):
200
+ def __init__(
201
+ self,
202
+ ch: int,
203
+ out_ch: int,
204
+ ch_mult: list[int],
205
+ num_res_blocks: int,
206
+ in_channels: int,
207
+ resolution: int,
208
+ z_channels: int,
209
+ ):
210
+ super().__init__()
211
+ self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
212
+ self.ch = ch
213
+ self.num_resolutions = len(ch_mult)
214
+ self.num_res_blocks = num_res_blocks
215
+ self.resolution = resolution
216
+ self.in_channels = in_channels
217
+ self.ffactor = 2 ** (self.num_resolutions - 1)
218
+
219
+ # compute in_ch_mult, block_in and curr_res at lowest res
220
+ block_in = ch * ch_mult[self.num_resolutions - 1]
221
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
222
+ self.z_shape = (1, z_channels, curr_res, curr_res)
223
+
224
+ # z to block_in
225
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
226
+
227
+ # middle
228
+ self.mid = nn.Module()
229
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
230
+ self.mid.attn_1 = AttnBlock(block_in)
231
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
232
+
233
+ # upsampling
234
+ self.up = nn.ModuleList()
235
+ for i_level in reversed(range(self.num_resolutions)):
236
+ block = nn.ModuleList()
237
+ attn = nn.ModuleList()
238
+ block_out = ch * ch_mult[i_level]
239
+ for _ in range(self.num_res_blocks + 1):
240
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
241
+ block_in = block_out
242
+ up = nn.Module()
243
+ up.block = block
244
+ up.attn = attn
245
+ if i_level != 0:
246
+ up.upsample = Upsample(block_in)
247
+ curr_res = curr_res * 2
248
+ self.up.insert(0, up) # prepend to get consistent order
249
+
250
+ # end
251
+ self.norm_out = nn.GroupNorm(
252
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
253
+ )
254
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
255
+
256
+ def forward(self, z: Tensor) -> Tensor:
257
+ z = self.post_quant_conv(z)
258
+
259
+ # get dtype for proper tracing
260
+ upscale_dtype = next(self.up.parameters()).dtype
261
+
262
+ # z to block_in
263
+ h = self.conv_in(z)
264
+
265
+ # middle
266
+ h = self.mid.block_1(h)
267
+ h = self.mid.attn_1(h)
268
+ h = self.mid.block_2(h)
269
+
270
+ # cast to proper dtype
271
+ h = h.to(upscale_dtype)
272
+ # upsampling
273
+ for i_level in reversed(range(self.num_resolutions)):
274
+ for i_block in range(self.num_res_blocks + 1):
275
+ h = self.up[i_level].block[i_block](h)
276
+ if len(self.up[i_level].attn) > 0:
277
+ h = self.up[i_level].attn[i_block](h)
278
+ if i_level != 0:
279
+ h = self.up[i_level].upsample(h)
280
+
281
+ # end
282
+ h = self.norm_out(h)
283
+ h = F.silu(h)
284
+ h = self.conv_out(h)
285
+ return h
286
+
287
+
288
+ class AutoEncoder(nn.Module):
289
+ def __init__(self, params: AutoEncoderParams = AutoEncoderParams()):
290
+ super().__init__()
291
+ self.params = params
292
+ self.encoder = Encoder(
293
+ resolution=params.resolution,
294
+ in_channels=params.in_channels,
295
+ ch=params.ch,
296
+ ch_mult=params.ch_mult,
297
+ num_res_blocks=params.num_res_blocks,
298
+ z_channels=params.z_channels,
299
+ )
300
+ self.decoder = Decoder(
301
+ resolution=params.resolution,
302
+ in_channels=params.in_channels,
303
+ ch=params.ch,
304
+ out_ch=params.out_ch,
305
+ ch_mult=params.ch_mult,
306
+ num_res_blocks=params.num_res_blocks,
307
+ z_channels=params.z_channels,
308
+ )
309
+
310
+ self.bn_eps = 1e-4
311
+ self.bn_momentum = 0.1
312
+ self.ps = [2, 2]
313
+ self.bn = torch.nn.BatchNorm2d(
314
+ math.prod(self.ps) * params.z_channels,
315
+ eps=self.bn_eps,
316
+ momentum=self.bn_momentum,
317
+ affine=False,
318
+ track_running_stats=True,
319
+ )
320
+
321
+ def normalize(self, z: Tensor) -> Tensor:
322
+ return self.bn(z)
323
+
324
+ def inv_normalize(self, z: Tensor) -> Tensor:
325
+ s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
326
+ m = self.bn.running_mean.view(1, -1, 1, 1)
327
+ return z * s + m
328
+
329
+ def encode(self, x: Tensor) -> Tensor:
330
+ moments = self.encoder(x)
331
+ mean = torch.chunk(moments, 2, dim=1)[0]
332
+
333
+ z = rearrange(
334
+ mean,
335
+ "... c (i pi) (j pj) -> ... (c pi pj) i j",
336
+ pi=self.ps[0],
337
+ pj=self.ps[1],
338
+ )
339
+ z = self.normalize(z)
340
+ return z
341
+
342
+ def decode(self, z: Tensor) -> Tensor:
343
+ z = self.inv_normalize(z)
344
+ z = rearrange(
345
+ z,
346
+ "... (c pi pj) i j -> ... c (i pi) (j pj)",
347
+ pi=self.ps[0],
348
+ pj=self.ps[1],
349
+ )
350
+ dec = self.decoder(z.to(next(self.decoder.parameters()).dtype))
351
+ return dec
nisaba_relief/flux/layers.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Building-block nn.Module primitives and standalone functions for Flux2."""
2
+
3
+ import math
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+
9
+
10
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
11
+ t = time_factor * t
12
+ half = dim // 2
13
+ freqs = torch.exp(
14
+ -math.log(max_period)
15
+ * torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
16
+ / half
17
+ )
18
+
19
+ args = t[:, None].float() * freqs[None]
20
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
21
+ if dim % 2:
22
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
23
+ if torch.is_floating_point(t):
24
+ embedding = embedding.to(t)
25
+ return embedding
26
+
27
+
28
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
29
+ q, k = apply_rope(q, k, pe)
30
+
31
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
32
+ x = rearrange(x, "B H L D -> B L (H D)")
33
+
34
+ return x
35
+
36
+
37
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
38
+ assert dim % 2 == 0
39
+ scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
40
+ omega = 1.0 / (theta**scale)
41
+ out = torch.einsum("...n,d->...nd", pos.float(), omega)
42
+ out = torch.stack(
43
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)],
44
+ dim=-1,
45
+ )
46
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
47
+ return out
48
+
49
+
50
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
51
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
52
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
53
+ freqs_cis = freqs_cis.to(xq.dtype)
54
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
55
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
56
+ return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)
57
+
58
+
59
+ class SelfAttention(nn.Module):
60
+ def __init__(
61
+ self,
62
+ dim: int,
63
+ num_heads: int = 8,
64
+ ):
65
+ super().__init__()
66
+ self.num_heads = num_heads
67
+ head_dim = dim // num_heads
68
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
69
+
70
+ self.norm = QKNorm(head_dim)
71
+ self.proj = nn.Linear(dim, dim, bias=False)
72
+
73
+
74
+ class SiLUActivation(nn.Module):
75
+ def __init__(self):
76
+ super().__init__()
77
+ self.gate_fn = nn.SiLU()
78
+
79
+ def forward(self, x: Tensor) -> Tensor:
80
+ x1, x2 = x.chunk(2, dim=-1)
81
+ return self.gate_fn(x1) * x2
82
+
83
+
84
+ class Modulation(nn.Module):
85
+ def __init__(self, dim: int, double: bool, disable_bias: bool = False):
86
+ super().__init__()
87
+ self.is_double = double
88
+ self.multiplier = 6 if double else 3
89
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
90
+
91
+ def forward(self, vec: torch.Tensor):
92
+ out = self.lin(nn.functional.silu(vec))
93
+ if out.ndim == 2:
94
+ out = out[:, None, :]
95
+ out = out.chunk(self.multiplier, dim=-1)
96
+ return out[:3], out[3:] if self.is_double else None
97
+
98
+
99
+ class LastLayer(nn.Module):
100
+ def __init__(
101
+ self,
102
+ hidden_size: int,
103
+ out_channels: int,
104
+ ):
105
+ super().__init__()
106
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
107
+ self.linear = nn.Linear(hidden_size, out_channels, bias=False)
108
+ self.adaLN_modulation = nn.Sequential(
109
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False)
110
+ )
111
+
112
+ def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
113
+ mod = self.adaLN_modulation(vec)
114
+ shift, scale = mod.chunk(2, dim=-1)
115
+ if shift.ndim == 2:
116
+ shift = shift[:, None, :]
117
+ scale = scale[:, None, :]
118
+ x = (1 + scale) * self.norm_final(x) + shift
119
+ x = self.linear(x)
120
+ return x
121
+
122
+
123
+ class SingleStreamBlock(nn.Module):
124
+ def __init__(
125
+ self,
126
+ hidden_size: int,
127
+ num_heads: int,
128
+ mlp_ratio: float = 4.0,
129
+ ):
130
+ super().__init__()
131
+
132
+ self.hidden_dim = hidden_size
133
+ self.num_heads = num_heads
134
+ head_dim = hidden_size // num_heads
135
+ self.scale = head_dim**-0.5
136
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
137
+ self.mlp_mult_factor = 2
138
+
139
+ self.linear1 = nn.Linear(
140
+ hidden_size,
141
+ hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
142
+ bias=False,
143
+ )
144
+
145
+ self.linear2 = nn.Linear(
146
+ hidden_size + self.mlp_hidden_dim, hidden_size, bias=False
147
+ )
148
+
149
+ self.norm = QKNorm(head_dim)
150
+
151
+ self.hidden_size = hidden_size
152
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
+
154
+ self.mlp_act = SiLUActivation()
155
+
156
+ def forward(
157
+ self,
158
+ x: Tensor,
159
+ pe: Tensor,
160
+ mod: tuple[Tensor, Tensor],
161
+ ) -> Tensor:
162
+ mod_shift, mod_scale, mod_gate = mod
163
+ x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
164
+
165
+ qkv, mlp = torch.split(
166
+ self.linear1(x_mod),
167
+ [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
168
+ dim=-1,
169
+ )
170
+
171
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
172
+ q, k = self.norm(q, k, v)
173
+
174
+ attn = attention(q, k, v, pe)
175
+
176
+ # compute activation in mlp stream, cat again and run second linear layer
177
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
178
+ return x + mod_gate * output
179
+
180
+
181
+ class DoubleStreamBlock(nn.Module):
182
+ def __init__(
183
+ self,
184
+ hidden_size: int,
185
+ num_heads: int,
186
+ mlp_ratio: float,
187
+ ):
188
+ super().__init__()
189
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
190
+ self.num_heads = num_heads
191
+ assert hidden_size % num_heads == 0, (
192
+ f"{hidden_size=} must be divisible by {num_heads=}"
193
+ )
194
+
195
+ self.hidden_size = hidden_size
196
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
197
+ self.mlp_mult_factor = 2
198
+
199
+ self.img_attn = SelfAttention(
200
+ dim=hidden_size,
201
+ num_heads=num_heads,
202
+ )
203
+
204
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
205
+ self.img_mlp = nn.Sequential(
206
+ nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
207
+ SiLUActivation(),
208
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
209
+ )
210
+
211
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
212
+ self.txt_attn = SelfAttention(
213
+ dim=hidden_size,
214
+ num_heads=num_heads,
215
+ )
216
+
217
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
218
+ self.txt_mlp = nn.Sequential(
219
+ nn.Linear(
220
+ hidden_size,
221
+ mlp_hidden_dim * self.mlp_mult_factor,
222
+ bias=False,
223
+ ),
224
+ SiLUActivation(),
225
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
226
+ )
227
+
228
+ def forward(
229
+ self,
230
+ img: Tensor,
231
+ txt: Tensor,
232
+ pe: Tensor,
233
+ pe_ctx: Tensor,
234
+ mod_img: tuple[Tensor, Tensor],
235
+ mod_txt: tuple[Tensor, Tensor],
236
+ ) -> tuple[Tensor, Tensor]:
237
+ img_mod1, img_mod2 = mod_img
238
+ txt_mod1, txt_mod2 = mod_txt
239
+
240
+ img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
241
+ img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
242
+ txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
243
+ txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2
244
+
245
+ # prepare image for attention
246
+ img_modulated = self.img_norm1(img)
247
+ img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
248
+
249
+ img_qkv = self.img_attn.qkv(img_modulated)
250
+ img_q, img_k, img_v = rearrange(
251
+ img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
252
+ )
253
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
254
+
255
+ # prepare txt for attention
256
+ txt_modulated = self.txt_norm1(txt)
257
+ txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
258
+
259
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
260
+ txt_q, txt_k, txt_v = rearrange(
261
+ txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
262
+ )
263
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
264
+
265
+ q = torch.cat((txt_q, img_q), dim=2)
266
+ k = torch.cat((txt_k, img_k), dim=2)
267
+ v = torch.cat((txt_v, img_v), dim=2)
268
+
269
+ pe = torch.cat((pe_ctx, pe), dim=2)
270
+ attn = attention(q, k, v, pe)
271
+ txt_attn, img_attn = (
272
+ attn[:, : txt_q.shape[2]],
273
+ attn[:, txt_q.shape[2] :],
274
+ )
275
+
276
+ # calculate the img blocks
277
+ img = img + img_mod1_gate * self.img_attn.proj(img_attn)
278
+ img = img + img_mod2_gate * self.img_mlp(
279
+ (1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift
280
+ )
281
+
282
+ # calculate the txt blocks
283
+ txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
284
+ txt = txt + txt_mod2_gate * self.txt_mlp(
285
+ (1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift
286
+ )
287
+ return img, txt
288
+
289
+
290
+ class MLPEmbedder(nn.Module):
291
+ def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
292
+ super().__init__()
293
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
294
+ self.silu = nn.SiLU()
295
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)
296
+
297
+ def forward(self, x: Tensor) -> Tensor:
298
+ return self.out_layer(self.silu(self.in_layer(x)))
299
+
300
+
301
+ class EmbedND(nn.Module):
302
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
303
+ super().__init__()
304
+ self.dim = dim
305
+ self.theta = theta
306
+ self.axes_dim = axes_dim
307
+
308
+ def forward(self, ids: Tensor) -> Tensor:
309
+ emb = torch.cat(
310
+ [
311
+ rope(ids[..., i], self.axes_dim[i], self.theta)
312
+ for i in range(len(self.axes_dim))
313
+ ],
314
+ dim=-3,
315
+ )
316
+
317
+ return emb.unsqueeze(1)
318
+
319
+
320
+ class RMSNorm(torch.nn.Module):
321
+ def __init__(self, dim: int):
322
+ super().__init__()
323
+ self.scale = nn.Parameter(torch.ones(dim))
324
+
325
+ def forward(self, x: Tensor):
326
+ x_dtype = x.dtype
327
+ x = x.float()
328
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
329
+ return (x * rrms).to(dtype=x_dtype) * self.scale
330
+
331
+
332
+ class QKNorm(torch.nn.Module):
333
+ def __init__(self, dim: int):
334
+ super().__init__()
335
+ self.query_norm = RMSNorm(dim)
336
+ self.key_norm = RMSNorm(dim)
337
+
338
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
339
+ q = self.query_norm(q)
340
+ k = self.key_norm(k)
341
+ return q.to(v), k.to(v)
nisaba_relief/flux/model.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from .layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ Modulation,
12
+ SingleStreamBlock,
13
+ timestep_embedding,
14
+ )
15
+
16
+
17
+ @dataclass
18
+ class Klein4BParams:
19
+ in_channels: int = 128
20
+ context_in_dim: int = 7680
21
+ hidden_size: int = 3072
22
+ num_heads: int = 24
23
+ depth: int = 5
24
+ depth_single_blocks: int = 20
25
+ axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
26
+ theta: int = 2000
27
+ mlp_ratio: float = 3.0
28
+
29
+
30
+ class Flux2(nn.Module):
31
+ def __init__(self, params: Klein4BParams = Klein4BParams()):
32
+ super().__init__()
33
+
34
+ self.in_channels = params.in_channels
35
+ self.out_channels = params.in_channels
36
+ if params.hidden_size % params.num_heads != 0:
37
+ raise ValueError(
38
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
39
+ )
40
+ pe_dim = params.hidden_size // params.num_heads
41
+ if sum(params.axes_dim) != pe_dim:
42
+ raise ValueError(
43
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
44
+ )
45
+ self.hidden_size = params.hidden_size
46
+ self.num_heads = params.num_heads
47
+ self.pe_embedder = EmbedND(
48
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
49
+ )
50
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
51
+ self.time_in = MLPEmbedder(
52
+ in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
53
+ )
54
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
55
+
56
+ self.double_blocks = nn.ModuleList(
57
+ [
58
+ DoubleStreamBlock(
59
+ self.hidden_size,
60
+ self.num_heads,
61
+ mlp_ratio=params.mlp_ratio,
62
+ )
63
+ for _ in range(params.depth)
64
+ ]
65
+ )
66
+
67
+ self.single_blocks = nn.ModuleList(
68
+ [
69
+ SingleStreamBlock(
70
+ self.hidden_size,
71
+ self.num_heads,
72
+ mlp_ratio=params.mlp_ratio,
73
+ )
74
+ for _ in range(params.depth_single_blocks)
75
+ ]
76
+ )
77
+
78
+ self.double_stream_modulation_img = Modulation(
79
+ self.hidden_size,
80
+ double=True,
81
+ disable_bias=True,
82
+ )
83
+ self.double_stream_modulation_txt = Modulation(
84
+ self.hidden_size,
85
+ double=True,
86
+ disable_bias=True,
87
+ )
88
+ self.single_stream_modulation = Modulation(
89
+ self.hidden_size, double=False, disable_bias=True
90
+ )
91
+
92
+ self.final_layer = LastLayer(
93
+ self.hidden_size,
94
+ self.out_channels,
95
+ )
96
+
97
+ def forward(
98
+ self,
99
+ x: Tensor,
100
+ x_ids: Tensor,
101
+ timesteps: Tensor,
102
+ ctx: Tensor,
103
+ ctx_ids: Tensor,
104
+ pe_x: Tensor | None = None,
105
+ pe_ctx: Tensor | None = None,
106
+ ) -> Tensor:
107
+ num_txt_tokens = ctx.shape[1]
108
+
109
+ timestep_emb = timestep_embedding(timesteps, 256)
110
+ vec = self.time_in(timestep_emb)
111
+
112
+ double_block_mod_img = self.double_stream_modulation_img(vec)
113
+ double_block_mod_txt = self.double_stream_modulation_txt(vec)
114
+ single_block_mod, _ = self.single_stream_modulation(vec)
115
+
116
+ img = self.img_in(x)
117
+ txt = self.txt_in(ctx)
118
+
119
+ if pe_x is None:
120
+ pe_x = self.pe_embedder(x_ids)
121
+ if pe_ctx is None:
122
+ pe_ctx = self.pe_embedder(ctx_ids)
123
+
124
+ for block in self.double_blocks:
125
+ img, txt = block(
126
+ img,
127
+ txt,
128
+ pe_x,
129
+ pe_ctx,
130
+ double_block_mod_img,
131
+ double_block_mod_txt,
132
+ )
133
+
134
+ img = torch.cat((txt, img), dim=1)
135
+ pe = torch.cat((pe_ctx, pe_x), dim=2)
136
+
137
+ for block in self.single_blocks:
138
+ img = block(
139
+ img,
140
+ pe,
141
+ single_block_mod,
142
+ )
143
+
144
+ img = img[:, num_txt_tokens:, ...]
145
+
146
+ img = self.final_layer(img, vec)
147
+ return img
nisaba_relief/flux/sampling.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor
6
+
7
+ from .model import Flux2
8
+
9
+
10
+ def prc_img_batch(x: Tensor) -> tuple[Tensor, Tensor]:
11
+ b, _, h, w = x.shape
12
+ x_ids = torch.cartesian_prod(
13
+ torch.arange(1),
14
+ torch.arange(h),
15
+ torch.arange(w),
16
+ torch.arange(1),
17
+ )
18
+ x_ids = x_ids.unsqueeze(0).expand(b, -1, -1)
19
+ x = rearrange(x, "b c h w -> b (h w) c")
20
+ return x, x_ids.to(x.device)
21
+
22
+
23
+ def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
24
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
25
+
26
+
27
+ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
28
+ a1, b1 = 8.73809524e-05, 1.89833333
29
+ a2, b2 = 0.00016927, 0.45666666
30
+
31
+ if image_seq_len > 4300:
32
+ mu = a2 * image_seq_len + b2
33
+ return float(mu)
34
+
35
+ m_200 = a2 * image_seq_len + b2
36
+ m_10 = a1 * image_seq_len + b1
37
+
38
+ a = (m_200 - m_10) / 190.0
39
+ b = m_200 - 200.0 * a
40
+ mu = a * num_steps + b
41
+
42
+ return float(mu)
43
+
44
+
45
+ def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
46
+ mu = compute_empirical_mu(image_seq_len, num_steps)
47
+ timesteps = torch.linspace(1, 0, num_steps + 1)
48
+ timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
49
+ return timesteps.tolist()
50
+
51
+
52
+ def denoise(
53
+ model: Flux2,
54
+ img: Tensor,
55
+ img_ids: Tensor,
56
+ txt: Tensor,
57
+ txt_ids: Tensor,
58
+ timesteps: list[float],
59
+ img_cond_seq: Tensor | None = None,
60
+ img_cond_seq_ids: Tensor | None = None,
61
+ ) -> Tensor:
62
+ if img_cond_seq is not None:
63
+ assert img_cond_seq_ids is not None, (
64
+ "You need to provide either both or neither of the sequence conditioning"
65
+ )
66
+ combined_ids = torch.cat((img_ids, img_cond_seq_ids), dim=1)
67
+ else:
68
+ combined_ids = img_ids
69
+
70
+ # Pre-compute positional embeddings once (constant across all timesteps)
71
+ pe_x = model.pe_embedder(combined_ids)
72
+ pe_ctx = model.pe_embedder(txt_ids)
73
+
74
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
75
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
76
+ img_input = img
77
+ if img_cond_seq is not None:
78
+ img_input = torch.cat((img_input, img_cond_seq), dim=1)
79
+ pred = model(
80
+ x=img_input,
81
+ x_ids=combined_ids,
82
+ timesteps=t_vec,
83
+ ctx=txt,
84
+ ctx_ids=txt_ids,
85
+ pe_x=pe_x,
86
+ pe_ctx=pe_ctx,
87
+ )
88
+ pred = pred[:, : img.shape[1]]
89
+
90
+ img = img + (t_prev - t_curr) * pred
91
+
92
+ return img
nisaba_relief/image_utils.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pure image and tensor helper functions for NisabaRelief."""
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+
10
+ from .constants import (
11
+ MAX_TILE,
12
+ MIN_TILE,
13
+ PATCH_SIZE,
14
+ TARGET_TILES_PER_SIDE,
15
+ TILE_OVERLAP_DIVISOR,
16
+ )
17
+
18
+ _to_tensor = transforms.ToTensor()
19
+
20
+
21
+ def round_to_patch(value: float) -> int:
22
+ """Round a pixel value to the nearest multiple of PATCH_SIZE (minimum PATCH_SIZE)."""
23
+ return max(PATCH_SIZE, PATCH_SIZE * round(value / PATCH_SIZE))
24
+
25
+
26
+ def ceil_to_patch(value: float) -> int:
27
+ """Ceil a pixel value to the next multiple of PATCH_SIZE (minimum PATCH_SIZE)."""
28
+ return max(PATCH_SIZE, PATCH_SIZE * math.ceil(value / PATCH_SIZE))
29
+
30
+
31
+ def compute_tile_size(max_side: int) -> int:
32
+ """Compute the optimal square tile side length for a given image maximum side."""
33
+ raw = ceil_to_patch(
34
+ max_side
35
+ * TILE_OVERLAP_DIVISOR
36
+ / (TARGET_TILES_PER_SIDE * (TILE_OVERLAP_DIVISOR - 1) + 1)
37
+ )
38
+ return max(min(raw, MAX_TILE), MIN_TILE)
39
+
40
+
41
+ def compute_tile_grid(
42
+ orig_w: int, orig_h: int, tile_size: int
43
+ ) -> tuple[int, int, int, int, int, int, int, int]:
44
+ """Compute tiled grid layout for an image.
45
+
46
+ Returns (n_cols, n_rows, padded_w, padded_h, pad_left, pad_top, overlap, stride).
47
+ """
48
+ overlap = tile_size // TILE_OVERLAP_DIVISOR
49
+ stride = tile_size - overlap
50
+ n_cols = max(1, math.ceil((orig_w - overlap) / stride))
51
+ n_rows = max(1, math.ceil((orig_h - overlap) / stride))
52
+ padded_w = tile_size + (n_cols - 1) * stride
53
+ padded_h = tile_size + (n_rows - 1) * stride
54
+ pad_left = (padded_w - orig_w) // 2
55
+ pad_top = (padded_h - orig_h) // 2
56
+ return n_cols, n_rows, padded_w, padded_h, pad_left, pad_top, overlap, stride
57
+
58
+
59
+ def image_to_tensor(image: Image.Image, device: str) -> torch.Tensor:
60
+ """Convert a PIL image to a normalised [-1, 1] float tensor on device."""
61
+ return (2 * _to_tensor(image) - 1).to(device)
62
+
63
+
64
+ def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
65
+ """Convert a normalised [-1, 1] CHW tensor to a PIL RGB image."""
66
+ img = (tensor.clamp(-1, 1) + 1) / 2
67
+ img = img.permute(1, 2, 0).float().cpu().numpy()
68
+ return Image.fromarray((img * 255).astype("uint8"))
69
+
70
+
71
+ def pad_to_patch_multiple(image: Image.Image) -> Image.Image:
72
+ """Pad image width and height up to the next multiple of PATCH_SIZE."""
73
+ w, h = image.size
74
+ pad_w = (PATCH_SIZE - w % PATCH_SIZE) % PATCH_SIZE
75
+ pad_h = (PATCH_SIZE - h % PATCH_SIZE) % PATCH_SIZE
76
+ if pad_w == 0 and pad_h == 0:
77
+ return image
78
+ padded = Image.new("RGB", (w + pad_w, h + pad_h), (0, 0, 0))
79
+ padded.paste(image, (0, 0))
80
+ return padded
81
+
82
+
83
+ def postprocess(image: Image.Image, shadow_strength: float = 0.7) -> Image.Image:
84
+ """Apply adaptive gamma correction and convert to grayscale."""
85
+ arr = np.array(image, dtype=np.float32) / 255.0
86
+ gamma = 1.0 + shadow_strength * (1.0 - arr)
87
+ arr = np.power(arr, gamma)
88
+ return Image.fromarray((arr * 255).clip(0, 255).astype(np.uint8)).convert("L")
89
+
90
+
91
+ def draw_tile_indicator(
92
+ tensor: torch.Tensor,
93
+ full_w: int,
94
+ full_h: int,
95
+ tile_x: int,
96
+ tile_y: int,
97
+ tile_w: int,
98
+ tile_h: int,
99
+ line_width: int = 1,
100
+ ) -> torch.Tensor:
101
+ """Draw a red rectangle on a CHW tensor to mark the current tile position."""
102
+ C, H, W = tensor.shape
103
+ result = tensor.clone()
104
+
105
+ scale_x = W / full_w
106
+ scale_y = H / full_h
107
+
108
+ x1 = max(0, min(int(tile_x * scale_x), W - 1))
109
+ y1 = max(0, min(int(tile_y * scale_y), H - 1))
110
+ x2 = max(0, min(int((tile_x + tile_w) * scale_x), W))
111
+ y2 = max(0, min(int((tile_y + tile_h) * scale_y), H))
112
+
113
+ red = torch.tensor([1.0, -1.0, -1.0], device=tensor.device, dtype=tensor.dtype)
114
+
115
+ for dy in range(line_width):
116
+ if y1 + dy < H:
117
+ result[:, y1 + dy, x1:x2] = red.view(3, 1)
118
+ if 0 <= y2 - 1 - dy < H:
119
+ result[:, y2 - 1 - dy, x1:x2] = red.view(3, 1)
120
+
121
+ for dx in range(line_width):
122
+ if x1 + dx < W:
123
+ result[:, y1:y2, x1 + dx] = red.view(3, 1)
124
+ if 0 <= x2 - 1 - dx < W:
125
+ result[:, y1:y2, x2 - 1 - dx] = red.view(3, 1)
126
+
127
+ return result
128
+
129
+
130
+ def create_blend_weights(
131
+ tile_size: int,
132
+ overlap: int,
133
+ is_top: bool = False,
134
+ is_bottom: bool = False,
135
+ is_left: bool = False,
136
+ is_right: bool = False,
137
+ device: str = "cpu",
138
+ ) -> torch.Tensor:
139
+ """Create cosine blend weights for a tile, ramping down at non-edge overlaps."""
140
+ weights = torch.ones(tile_size, tile_size, device=device)
141
+
142
+ if overlap > 0:
143
+ ramp = 0.5 * (1 - torch.cos(torch.linspace(0, torch.pi, overlap, device=device)))
144
+ if not is_top:
145
+ weights[:overlap, :] *= ramp.view(-1, 1)
146
+ if not is_bottom:
147
+ weights[-overlap:, :] *= ramp.flip(0).view(-1, 1)
148
+ if not is_left:
149
+ weights[:, :overlap] *= ramp.view(1, -1)
150
+ if not is_right:
151
+ weights[:, -overlap:] *= ramp.flip(0).view(1, -1)
152
+
153
+ return weights
nisaba_relief/model.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NisabaRelief inference model.
3
+ Transforms cuneiform tablet images into MSII visualizations.
4
+ """
5
+
6
+ import contextlib
7
+ import logging
8
+ from os import PathLike
9
+ from pathlib import Path
10
+
11
+ import gc
12
+ import torch
13
+ from einops import rearrange
14
+ from PIL import Image
15
+ from tqdm.auto import tqdm
16
+ from safetensors.torch import load_file
17
+
18
+ from .constants import (
19
+ COND_SEQ_ID,
20
+ DECODE_BATCH_SIZE_DIVISOR,
21
+ GLOBAL_CTX_ID,
22
+ LATENT_CHANNELS,
23
+ MAX_ASPECT_RATIO,
24
+ MAX_GLOBAL_CONTEXT_SIZE,
25
+ MIN_IMAGE_DIMENSION,
26
+ VRAM_FIXED_OVERHEAD_MB,
27
+ VRAM_HEADROOM_MB,
28
+ VRAM_MB_PER_PIXEL,
29
+ MIN_BATCH_SIZE,
30
+ MAX_BATCH_SIZE,
31
+ )
32
+
33
+ from .image_utils import (
34
+ _to_tensor,
35
+ compute_tile_grid,
36
+ compute_tile_size,
37
+ create_blend_weights,
38
+ draw_tile_indicator,
39
+ image_to_tensor,
40
+ pad_to_patch_multiple,
41
+ postprocess,
42
+ round_to_patch,
43
+ tensor_to_image,
44
+ )
45
+ from .weights import WEIGHT_FILES, download_weights
46
+ from .flux.autoencoder import AutoEncoder
47
+ from .flux.model import Flux2
48
+ from .flux.sampling import (
49
+ denoise,
50
+ get_schedule,
51
+ prc_img_batch,
52
+ )
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
+ class NisabaRelief:
58
+ """Transform cuneiform tablet images into MSII relief visualizations.
59
+
60
+ Args:
61
+ device: Device to run inference on (default "cuda" if available).
62
+ num_steps: Number of denoising steps (default 2).
63
+ weights_dir: Optional local weights directory. If None, uses HuggingFace Hub (boatbomber/NisabaRelief).
64
+ batch_size: Batch size for processing tiles during inference.
65
+ None (default) = auto-select based on available GPU memory each call.
66
+ Set an explicit int to override.
67
+ seed: Optional random seed for reproducible noise generation (default None).
68
+ compile: Whether to use torch.compile for faster repeated inference (default True).
69
+ Requires Triton. Set to False if Triton is not installed or for one-off runs.
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
75
+ num_steps: int = 2,
76
+ weights_dir: PathLike | None = None,
77
+ batch_size: int | None = None,
78
+ seed: int | None = None,
79
+ compile: bool = True,
80
+ ):
81
+ if batch_size is not None and batch_size < 1:
82
+ raise ValueError(f"batch_size must be >= 1 or None, got {batch_size}")
83
+
84
+ self.num_steps = num_steps
85
+ self.device = device
86
+ self.batch_size = batch_size
87
+ self.seed = seed
88
+
89
+ if weights_dir is not None:
90
+ weights_dir = Path(weights_dir)
91
+ if not weights_dir.is_dir():
92
+ raise FileNotFoundError(f"weights_dir does not exist: {weights_dir}")
93
+
94
+ missing = [f for f in WEIGHT_FILES if not (weights_dir / f).exists()]
95
+ if missing:
96
+ raise FileNotFoundError(
97
+ f"Missing weight files in {weights_dir}: {missing}"
98
+ )
99
+
100
+ weight_paths = {f: str(weights_dir / f) for f in WEIGHT_FILES}
101
+ else:
102
+ logger.info("Downloading weights from HuggingFace Hub...")
103
+ weight_paths = download_weights()
104
+
105
+ # Load AutoEncoder
106
+ logger.debug("Loading AutoEncoder...")
107
+ with torch.device("meta"):
108
+ self.ae = AutoEncoder()
109
+ ae_weights = load_file(weight_paths["ae.safetensors"], device=device)
110
+ self.ae.load_state_dict(ae_weights, strict=True, assign=True)
111
+ self.ae.decoder = self.ae.decoder.to(self.dtype)
112
+ self.ae.eval()
113
+
114
+ # Load finetuned FLUX.2 model (merged weights)
115
+ logger.debug("Loading Transformer...")
116
+ with torch.device("meta"):
117
+ self.model = Flux2().to(self.dtype)
118
+ model_weights = load_file(weight_paths["model.safetensors"], device=device)
119
+ self.model.load_state_dict(model_weights, strict=True, assign=True)
120
+ self.model = self.model.to(device=device, dtype=self.dtype).eval()
121
+
122
+ # Load pre-computed text embedding
123
+ logger.debug("Loading text embedding...")
124
+ text_data = load_file(weight_paths["prompt_embedding.safetensors"], device=device)
125
+ self.prompt_embedding = text_data["prompt_embedding"].to(self.dtype)
126
+ self.ctx_ids = text_data["ctx_ids"]
127
+
128
+ if compile and self.device_type == "cuda":
129
+ try:
130
+ self.model = torch.compile(self.model)
131
+ self.ae = torch.compile(self.ae)
132
+ logger.debug(
133
+ "Model compile mode enabled. First run will be slow, but subsequent runs will be faster."
134
+ )
135
+ except Exception as e:
136
+ logger.error("Error compiling model: %s", e, exc_info=True)
137
+ logger.warning("Falling back to non-compiled model")
138
+
139
+ logger.info("NisabaRelief model loaded and ready")
140
+
141
+ @property
142
+ def device_type(self) -> str:
143
+ return self.device.split(":")[0]
144
+
145
+ @property
146
+ def dtype(self) -> torch.dtype:
147
+ if self.device_type == "cuda":
148
+ return torch.bfloat16
149
+ return torch.float32
150
+
151
+ def _pick_batch_size(self, tile_size: int) -> int:
152
+ """Estimate the largest safe batch size for a given tile size."""
153
+ if self.device_type != "cuda":
154
+ return MIN_BATCH_SIZE
155
+
156
+ gc.collect()
157
+ torch.cuda.empty_cache()
158
+
159
+ try:
160
+ device_idx = torch.device(self.device).index or 0
161
+ free_vram_mb = (
162
+ torch.cuda.get_device_properties(device_idx).total_memory
163
+ - torch.cuda.memory_allocated(device_idx)
164
+ ) / (1024**2)
165
+ available = free_vram_mb - VRAM_HEADROOM_MB
166
+ per_tile = VRAM_MB_PER_PIXEL * tile_size**2 + VRAM_FIXED_OVERHEAD_MB
167
+ batch = max(MIN_BATCH_SIZE, min(MAX_BATCH_SIZE, int(available / per_tile)))
168
+ logger.debug(
169
+ "Auto batch_size=%d (tile=%d, free=%.0f MB, per_tile=%.0f MB)",
170
+ batch,
171
+ tile_size,
172
+ free_vram_mb,
173
+ per_tile,
174
+ )
175
+ return batch
176
+ except Exception as e:
177
+ logger.error("Error picking batch size: %s", e, exc_info=True)
178
+ return MIN_BATCH_SIZE
179
+
180
+ def __repr__(self) -> str:
181
+ return (
182
+ f"NisabaRelief(device={self.device!r}, num_steps={self.num_steps}, "
183
+ f"batch_size={self.batch_size}, seed={self.seed!r})"
184
+ )
185
+
186
+ def process(
187
+ self,
188
+ image: PathLike | Image.Image,
189
+ show_pbar: bool | None = None,
190
+ ) -> Image.Image:
191
+ """Transform a cuneiform tablet image into MSII visualization.
192
+
193
+ Args:
194
+ image: Input image (path or PIL Image).
195
+ show_pbar: Whether to show a progress bar during tiled inference.
196
+ If None (default), shows the bar only when there are at least 2 batches to run.
197
+
198
+ Returns:
199
+ PIL Image (grayscale) with MSII visualization.
200
+ """
201
+ if isinstance(image, (str, PathLike)):
202
+ image = Image.open(image)
203
+ if image.mode != "RGB":
204
+ image = image.convert("RGB")
205
+
206
+ w, h = image.size
207
+ max_side = max(w, h)
208
+ min_side = min(w, h)
209
+ if min_side < MIN_IMAGE_DIMENSION:
210
+ raise ValueError(
211
+ f"Image too small: {min_side}px minimum side (need >= {MIN_IMAGE_DIMENSION}px)"
212
+ )
213
+ if max_side / min_side > MAX_ASPECT_RATIO:
214
+ raise ValueError(
215
+ f"Aspect ratio too extreme: {max_side / min_side:.1f}:1 (max {MAX_ASPECT_RATIO:.0f}:1)"
216
+ )
217
+
218
+ tile_size = compute_tile_size(max_side)
219
+ output_image = self._process_tiled(
220
+ image, tile_size=tile_size, show_pbar=show_pbar
221
+ )
222
+
223
+ return postprocess(output_image)
224
+
225
+ def _prepare_global_context_tensor(
226
+ self,
227
+ image: Image.Image,
228
+ max_size: int = MAX_GLOBAL_CONTEXT_SIZE,
229
+ ) -> torch.Tensor:
230
+ w, h = image.size
231
+ scale = min(max_size / w, max_size / h)
232
+ new_w = round_to_patch(w * scale)
233
+ new_h = round_to_patch(h * scale)
234
+
235
+ resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
236
+ return image_to_tensor(resized, self.device)
237
+
238
+ def _encode_global_context_batch(
239
+ self,
240
+ img_tensors: list[torch.Tensor],
241
+ ) -> tuple[torch.Tensor, torch.Tensor]:
242
+ batch = torch.stack(img_tensors)
243
+ with torch.inference_mode():
244
+ global_latent = self.ae.encode(batch)
245
+ global_tokens, global_ids = prc_img_batch(global_latent)
246
+ global_ids[..., 0] = GLOBAL_CTX_ID
247
+ return global_tokens.to(self.dtype), global_ids
248
+
249
+ def _process_tile_batch(
250
+ self,
251
+ tiles: list[Image.Image],
252
+ global_ctx_tokens: torch.Tensor,
253
+ global_ctx_ids: torch.Tensor,
254
+ tile_index_offset: int = 0,
255
+ ) -> list[Image.Image]:
256
+ b = len(tiles)
257
+ original_sizes = [tile.size for tile in tiles]
258
+
259
+ padded_tiles = [pad_to_patch_multiple(tile) for tile in tiles]
260
+ img_tensors = torch.stack(
261
+ [image_to_tensor(tile, self.device) for tile in padded_tiles]
262
+ )
263
+
264
+ with torch.inference_mode():
265
+ input_latent = self.ae.encode(img_tensors)
266
+
267
+ input_tokens, input_ids = prc_img_batch(input_latent)
268
+ input_ids_cond = input_ids.clone()
269
+ input_ids_cond[..., 0] = COND_SEQ_ID
270
+
271
+ cond_tokens = torch.cat([input_tokens, global_ctx_tokens], dim=1)
272
+ cond_ids = torch.cat([input_ids_cond, global_ctx_ids], dim=1)
273
+
274
+ latent_h = input_latent.shape[2]
275
+ latent_w = input_latent.shape[3]
276
+
277
+ if self.seed is None:
278
+ noise = torch.randn(
279
+ b,
280
+ LATENT_CHANNELS,
281
+ latent_h,
282
+ latent_w,
283
+ device=self.device,
284
+ dtype=self.dtype,
285
+ )
286
+ else:
287
+ noise_list = []
288
+ for i in range(b):
289
+ tile_seed = self.seed ^ (tile_index_offset + i)
290
+ generator = torch.Generator(device=self.device).manual_seed(tile_seed)
291
+ noise_list.append(
292
+ torch.randn(
293
+ LATENT_CHANNELS,
294
+ latent_h,
295
+ latent_w,
296
+ device=self.device,
297
+ dtype=self.dtype,
298
+ generator=generator,
299
+ )
300
+ )
301
+ noise = torch.stack(noise_list)
302
+
303
+ noise_tokens, _ = prc_img_batch(noise)
304
+ noise_ids = input_ids
305
+
306
+ seq_len = noise_tokens.shape[1]
307
+ timesteps = get_schedule(self.num_steps, seq_len)
308
+
309
+ ctx = self.prompt_embedding.unsqueeze(0).expand(b, -1, -1)
310
+ ctx_ids = self.ctx_ids.unsqueeze(0).expand(b, -1, -1)
311
+
312
+ autocast_ctx = (
313
+ torch.autocast(device_type=self.device_type, dtype=self.dtype)
314
+ if self.device_type == "cuda"
315
+ else contextlib.nullcontext()
316
+ )
317
+ with autocast_ctx:
318
+ output_tokens = denoise(
319
+ model=self.model,
320
+ img=noise_tokens.to(self.dtype),
321
+ img_ids=noise_ids,
322
+ txt=ctx.to(self.dtype),
323
+ txt_ids=ctx_ids,
324
+ timesteps=timesteps,
325
+ img_cond_seq=cond_tokens.to(self.dtype),
326
+ img_cond_seq_ids=cond_ids,
327
+ )
328
+
329
+ output_latent = rearrange(
330
+ output_tokens,
331
+ "b (h w) c -> b c h w",
332
+ h=latent_h,
333
+ w=latent_w,
334
+ )
335
+
336
+ # Free tensors from encode/denoise phases before AE decode to
337
+ # avoid CUDA memory fragmentation (the decoder needs large
338
+ # full-resolution float32 allocations that differ in shape from
339
+ # the transformer's cached blocks).
340
+ del img_tensors, input_latent, input_tokens, input_ids
341
+ del input_ids_cond, cond_tokens, cond_ids
342
+ del noise, noise_tokens, noise_ids
343
+ del output_tokens, ctx, ctx_ids
344
+ if self.device_type == "cuda":
345
+ torch.cuda.empty_cache()
346
+
347
+ # The AE decoder operates at full pixel resolution in float32,
348
+ # requiring much more VRAM per tile than the latent-space denoiser.
349
+ # Sub-batch to avoid overflowing into shared memory.
350
+ decode_bs = max(1, b // DECODE_BATCH_SIZE_DIVISOR)
351
+ if decode_bs >= b:
352
+ output_imgs = self.ae.decode(output_latent)
353
+ else:
354
+ chunks = []
355
+ for i in range(0, b, decode_bs):
356
+ chunks.append(self.ae.decode(output_latent[i : i + decode_bs]))
357
+ if self.device_type == "cuda":
358
+ torch.cuda.empty_cache()
359
+ output_imgs = torch.cat(chunks, dim=0)
360
+
361
+ results = []
362
+ for i, (orig_w, orig_h) in enumerate(original_sizes):
363
+ result = tensor_to_image(output_imgs[i])
364
+ if padded_tiles[i].size != (orig_w, orig_h):
365
+ result = result.crop((0, 0, orig_w, orig_h))
366
+ results.append(result)
367
+
368
+ return results
369
+
370
+ def _process_tiled(
371
+ self, image: Image.Image, tile_size: int, show_pbar: bool | None = None
372
+ ) -> Image.Image:
373
+ orig_w, orig_h = image.size
374
+ n_cols, n_rows, w, h, pad_left, pad_top, overlap, stride = compute_tile_grid(
375
+ orig_w, orig_h, tile_size
376
+ )
377
+
378
+ # Pad canvas so tiles land at exact stride positions with uniform overlap.
379
+ # Center the image so padding is distributed evenly on all sides.
380
+ padded = Image.new("RGB", (w, h), (0, 0, 0))
381
+ padded.paste(image, (pad_left, pad_top))
382
+ image = padded
383
+
384
+ global_base_tensor = self._prepare_global_context_tensor(image)
385
+
386
+ output = torch.zeros(3, h, w, device=self.device)
387
+ weights = torch.zeros(1, h, w, device=self.device)
388
+
389
+ tile_specs = [
390
+ (row, col, col * stride, row * stride)
391
+ for row in range(n_rows)
392
+ for col in range(n_cols)
393
+ ]
394
+
395
+ blend_cache: dict[tuple[bool, bool, bool, bool], torch.Tensor] = {}
396
+
397
+ batch_size = (
398
+ self.batch_size
399
+ if self.batch_size is not None
400
+ else self._pick_batch_size(tile_size)
401
+ )
402
+
403
+ if show_pbar is None:
404
+ show_pbar = len(tile_specs) >= 2 * batch_size
405
+
406
+ pbar = tqdm(
407
+ total=len(tile_specs),
408
+ desc=f"Processing {orig_w}x{orig_h} px image with {n_cols}x{n_rows} tiles ({tile_size} px each, {overlap} px overlap)",
409
+ unit="tile",
410
+ leave=False,
411
+ disable=not show_pbar,
412
+ )
413
+
414
+ for batch_start in range(0, len(tile_specs), batch_size):
415
+ batch_specs = tile_specs[batch_start : batch_start + batch_size]
416
+
417
+ pbar.clear()
418
+ logger.debug(
419
+ "Processing %d batched tiles: %s",
420
+ len(batch_specs),
421
+ " + ".join([f"({row},{col})" for row, col, _, _ in batch_specs]),
422
+ )
423
+ pbar.refresh()
424
+
425
+ ctx_tensors = [
426
+ draw_tile_indicator(global_base_tensor, w, h, x, y, tile_size, tile_size)
427
+ for (row, col, x, y) in batch_specs
428
+ ]
429
+ global_tokens, global_ids = self._encode_global_context_batch(ctx_tensors)
430
+
431
+ tiles = [
432
+ image.crop((x, y, x + tile_size, y + tile_size))
433
+ for (row, col, x, y) in batch_specs
434
+ ]
435
+
436
+ result_tiles = self._process_tile_batch(
437
+ tiles, global_tokens, global_ids, tile_index_offset=batch_start
438
+ )
439
+
440
+ for i, (row, col, x, y) in enumerate(batch_specs):
441
+ edge_key = (
442
+ row == 0,
443
+ row == n_rows - 1,
444
+ col == 0,
445
+ col == n_cols - 1,
446
+ )
447
+ if edge_key not in blend_cache:
448
+ blend_cache[edge_key] = create_blend_weights(
449
+ tile_size,
450
+ overlap,
451
+ is_top=edge_key[0],
452
+ is_bottom=edge_key[1],
453
+ is_left=edge_key[2],
454
+ is_right=edge_key[3],
455
+ device=self.device,
456
+ )
457
+ blend = blend_cache[edge_key]
458
+ result_tensor = _to_tensor(result_tiles[i]).to(self.device)
459
+ output[:, y : y + tile_size, x : x + tile_size] += result_tensor * blend
460
+ weights[:, y : y + tile_size, x : x + tile_size] += blend
461
+
462
+ if self.device_type == "cuda":
463
+ torch.cuda.empty_cache()
464
+
465
+ pbar.update(len(batch_specs))
466
+
467
+ pbar.close()
468
+
469
+ output = output / weights.clamp(min=1e-6)
470
+ output = output.permute(1, 2, 0).cpu().numpy()
471
+ output = (output * 255).clip(0, 255).astype("uint8")
472
+ return Image.fromarray(output).crop(
473
+ (pad_left, pad_top, pad_left + orig_w, pad_top + orig_h)
474
+ )
nisaba_relief/py.typed ADDED
File without changes
nisaba_relief/weights.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace Hub weight downloading for NisabaRelief."""
2
+
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ HF_REPO_ID = "boatbomber/NisabaRelief"
6
+ WEIGHT_FILES = [
7
+ "ae.safetensors",
8
+ "model.safetensors",
9
+ "prompt_embedding.safetensors",
10
+ ]
11
+
12
+
13
+ def download_weights(repo_id: str = HF_REPO_ID) -> dict[str, str]:
14
+ """Download all weight files from HF Hub, returning {filename: local_path}."""
15
+ paths = {}
16
+ for filename in WEIGHT_FILES:
17
+ try:
18
+ paths[filename] = hf_hub_download(repo_id=repo_id, filename=filename)
19
+ except Exception as e:
20
+ raise RuntimeError(
21
+ f"Failed to download {filename} from {repo_id}: {e}"
22
+ ) from e
23
+ return paths
prompt_embedding.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc9b70751370039f6af10f5c803f9854354f7029f7d9521c6a4ee7c5ae28f999
3
+ size 7880872
pyproject.toml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [tool.hatch.build.targets.sdist]
6
+ exclude = ["*.safetensors", "assets/**", "data/**", "dev_scripts/**", "uv.lock"]
7
+
8
+ [tool.hatch.build.targets.wheel]
9
+ packages = ["nisaba_relief"]
10
+
11
+ [project]
12
+ name = "nisaba-relief"
13
+ version = "0.1.0"
14
+ description = "Transform cuneiform tablet photos into MSII relief visualizations"
15
+ readme = { file = "README.md", content-type = "text/markdown" }
16
+ license = "Apache-2.0"
17
+ requires-python = ">=3.10,<3.14"
18
+ authors = [{ name = "Zack Williams", email = "zack@boatbomber.com" }]
19
+ keywords = ["cuneiform", "msii", "relief", "ocr", "flux", "deep-learning"]
20
+ classifiers = [
21
+ "Development Status :: 4 - Beta",
22
+ "Intended Audience :: Science/Research",
23
+ "License :: OSI Approved :: Apache Software License",
24
+ "Programming Language :: Python :: 3",
25
+ "Programming Language :: Python :: 3.10",
26
+ "Programming Language :: Python :: 3.11",
27
+ "Programming Language :: Python :: 3.12",
28
+ "Programming Language :: Python :: 3.13",
29
+ "Topic :: Scientific/Engineering :: Image Processing",
30
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
31
+ ]
32
+ dependencies = [
33
+ "einops>=0.8.2",
34
+ "safetensors",
35
+ "numpy",
36
+ "pillow",
37
+ "huggingface-hub",
38
+ "tqdm",
39
+ ]
40
+
41
+ [project.urls]
42
+ Homepage = "https://huggingface.co/boatbomber/NisabaRelief"
43
+ Repository = "https://huggingface.co/boatbomber/NisabaRelief"
44
+ Issues = "https://huggingface.co/boatbomber/NisabaRelief/discussions"
45
+
46
+ [dependency-groups]
47
+ dev = [
48
+ "ruff>=0.15.4",
49
+ "scikit-image>=0.25.2",
50
+ "scipy>=1.15.3",
51
+ "image-similarity-measures[speedups]>=0.3.5",
52
+ "pytorch-msssim>=1.0.0",
53
+ "rich>=14.3.3",
54
+ "datasets>=4.6.1",
55
+ ]
56
+
57
+ [[tool.uv.index]]
58
+ name = "pytorch-cu128"
59
+ url = "https://download.pytorch.org/whl/cu128"
60
+ explicit = true
61
+
62
+ [tool.uv.sources]
63
+ torch = { index = "pytorch-cu128" }
64
+ torchvision = { index = "pytorch-cu128" }
65
+ triton = { index = "pytorch-cu128" }
66
+
67
+
68
+ [tool.ruff]
69
+ line-length = 90
uv.lock ADDED
The diff for this file is too large to render. See raw diff