Commit ·
3050f1b
0
Parent(s):
Initial release
Browse files- .gitattributes +36 -0
- .gitignore +176 -0
- LICENSE.md +176 -0
- README.md +447 -0
- ae.safetensors +3 -0
- assets/NisabaRelief-Logo.png +3 -0
- assets/example_diff_0.png +3 -0
- assets/example_diff_1.png +3 -0
- assets/example_diff_2.png +3 -0
- assets/example_diff_3.png +3 -0
- assets/example_diff_4.png +3 -0
- assets/example_input_0.png +3 -0
- assets/example_input_1.png +3 -0
- assets/example_input_2.png +3 -0
- assets/example_input_3.png +3 -0
- assets/example_input_4.png +3 -0
- assets/example_output_0.png +3 -0
- assets/example_output_1.png +3 -0
- assets/example_output_2.png +3 -0
- assets/example_output_3.png +3 -0
- assets/example_output_4.png +3 -0
- assets/example_truth_0.png +3 -0
- assets/example_truth_1.png +3 -0
- assets/example_truth_2.png +3 -0
- assets/example_truth_3.png +3 -0
- assets/example_truth_4.png +3 -0
- data/val_tablet_ids.json +90 -0
- dev_scripts/benchmark.py +149 -0
- dev_scripts/evaluation.py +162 -0
- dev_scripts/process_images.py +197 -0
- dev_scripts/util/load_val_dataset.py +24 -0
- dev_scripts/util/metrics.py +67 -0
- dev_scripts/util/psnr_hvsm.py +137 -0
- model.safetensors +3 -0
- nisaba_relief/__init__.py +7 -0
- nisaba_relief/constants.py +42 -0
- nisaba_relief/flux/__init__.py +0 -0
- nisaba_relief/flux/autoencoder.py +351 -0
- nisaba_relief/flux/layers.py +341 -0
- nisaba_relief/flux/model.py +147 -0
- nisaba_relief/flux/sampling.py +92 -0
- nisaba_relief/image_utils.py +153 -0
- nisaba_relief/model.py +474 -0
- nisaba_relief/py.typed +0 -0
- nisaba_relief/weights.py +23 -0
- prompt_embedding.safetensors +3 -0
- pyproject.toml +69 -0
- uv.lock +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/*
|
| 2 |
+
!data/val_tablet_ids.json
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
share/python-wheels/
|
| 27 |
+
*.egg-info/
|
| 28 |
+
.installed.cfg
|
| 29 |
+
*.egg
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.nox/
|
| 46 |
+
.coverage
|
| 47 |
+
.coverage.*
|
| 48 |
+
.cache
|
| 49 |
+
nosetests.xml
|
| 50 |
+
coverage.xml
|
| 51 |
+
*.cover
|
| 52 |
+
*.py,cover
|
| 53 |
+
.hypothesis/
|
| 54 |
+
.pytest_cache/
|
| 55 |
+
cover/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
.pybuilder/
|
| 79 |
+
target/
|
| 80 |
+
|
| 81 |
+
# Jupyter Notebook
|
| 82 |
+
.ipynb_checkpoints
|
| 83 |
+
|
| 84 |
+
**/__marimo__/
|
| 85 |
+
|
| 86 |
+
# IPython
|
| 87 |
+
profile_default/
|
| 88 |
+
ipython_config.py
|
| 89 |
+
|
| 90 |
+
# pyenv
|
| 91 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 93 |
+
# .python-version
|
| 94 |
+
|
| 95 |
+
# pipenv
|
| 96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 99 |
+
# install all needed dependencies.
|
| 100 |
+
#Pipfile.lock
|
| 101 |
+
|
| 102 |
+
# UV
|
| 103 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 105 |
+
# commonly ignored for libraries.
|
| 106 |
+
#uv.lock
|
| 107 |
+
|
| 108 |
+
# poetry
|
| 109 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 110 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 111 |
+
# commonly ignored for libraries.
|
| 112 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 113 |
+
#poetry.lock
|
| 114 |
+
|
| 115 |
+
# pdm
|
| 116 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 117 |
+
#pdm.lock
|
| 118 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 119 |
+
# in version control.
|
| 120 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 121 |
+
.pdm.toml
|
| 122 |
+
.pdm-python
|
| 123 |
+
.pdm-build/
|
| 124 |
+
|
| 125 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 126 |
+
__pypackages__/
|
| 127 |
+
|
| 128 |
+
# Celery stuff
|
| 129 |
+
celerybeat-schedule
|
| 130 |
+
celerybeat.pid
|
| 131 |
+
|
| 132 |
+
# SageMath parsed files
|
| 133 |
+
*.sage.py
|
| 134 |
+
|
| 135 |
+
# Environments
|
| 136 |
+
.env
|
| 137 |
+
.venv
|
| 138 |
+
env/
|
| 139 |
+
venv/
|
| 140 |
+
ENV/
|
| 141 |
+
env.bak/
|
| 142 |
+
venv.bak/
|
| 143 |
+
|
| 144 |
+
# Spyder project settings
|
| 145 |
+
.spyderproject
|
| 146 |
+
.spyproject
|
| 147 |
+
|
| 148 |
+
# Rope project settings
|
| 149 |
+
.ropeproject
|
| 150 |
+
|
| 151 |
+
# mkdocs documentation
|
| 152 |
+
/site
|
| 153 |
+
|
| 154 |
+
# mypy
|
| 155 |
+
.mypy_cache/
|
| 156 |
+
.dmypy.json
|
| 157 |
+
dmypy.json
|
| 158 |
+
|
| 159 |
+
# Pyre type checker
|
| 160 |
+
.pyre/
|
| 161 |
+
|
| 162 |
+
# pytype static type analyzer
|
| 163 |
+
.pytype/
|
| 164 |
+
|
| 165 |
+
# Cython debug symbols
|
| 166 |
+
cython_debug/
|
| 167 |
+
|
| 168 |
+
# PyCharm
|
| 169 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 170 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 171 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 172 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 173 |
+
#.idea/
|
| 174 |
+
|
| 175 |
+
# PyPI configuration file
|
| 176 |
+
.pypirc
|
LICENSE.md
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
README.md
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
pipeline_tag: image-to-image
|
| 4 |
+
base_model:
|
| 5 |
+
- black-forest-labs/FLUX.2-klein-base-4B
|
| 6 |
+
base_model_relation: finetune
|
| 7 |
+
datasets:
|
| 8 |
+
- boatbomber/CuneiformPhotosMSII
|
| 9 |
+
tags:
|
| 10 |
+
- image-to-image
|
| 11 |
+
- cuneiform
|
| 12 |
+
- geometry
|
| 13 |
+
- curvature
|
| 14 |
+
- multi-scale-integral-invariant
|
| 15 |
+
- msii
|
| 16 |
+
- Flux
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
<div align="center">
|
| 20 |
+
<h1 align="center">
|
| 21 |
+
NisabaRelief
|
| 22 |
+
</h1>
|
| 23 |
+
|
| 24 |
+
<img src="./assets/NisabaRelief-Logo.png" width="600"/>
|
| 25 |
+
</div>
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# NisabaRelief
|
| 29 |
+
|
| 30 |
+
NisabaRelief is a rectified flow transformer that converts ordinary photographs of cuneiform clay tablets into Multi-Scale Integral Invariant (MSII) curvature visualizations, without requiring 3D scanning hardware. Traditional MSII computation requires a high-resolution 3D scanner and GigaMesh postprocessing, averaging approximately 68 minutes per tablet. NisabaRelief processes a photograph in approximately 7 seconds.
|
| 31 |
+
|
| 32 |
+
Photographic images introduce a variety of noise sources: lighting direction, clay color, surface sheen, photography conditions, and surface staining. Any of these can cause wedge impressions to appear as shadows or shadows to appear as wedge impressions. MSII filtering discards this photometric variation, retaining only the geometric signal pressed into the clay. See [What is MSII?](#what-is-msii) for full technical details.
|
| 33 |
+
|
| 34 |
+
Built by fine-tuning [Flux.2 Klein Base 4B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B) on paired photo/MSII data generated from 3D scans in the [HeiCuBeDa](https://doi.org/10.11588/data/IE8CCN) corpus. Training data is made available here: [CuneiformPhotosMSII](https://huggingface.co/datasets/boatbomber/CuneiformPhotosMSII).
|
| 35 |
+
|
| 36 |
+
Named for Nisaba, the early Sumerian goddess of writing and scribes, NisabaRelief will serve as the preprocessing backbone of NabuOCR V2, a cuneiform OCR system currently in development.
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## Contents
|
| 41 |
+
|
| 42 |
+
- [NisabaRelief](#nisabarelief)
|
| 43 |
+
- [Contents](#contents)
|
| 44 |
+
- [Example Output](#example-output)
|
| 45 |
+
- [Quickstart](#quickstart)
|
| 46 |
+
- [Installation](#installation)
|
| 47 |
+
- [Usage](#usage)
|
| 48 |
+
- [Hardware Requirements](#hardware-requirements)
|
| 49 |
+
- [Performance](#performance)
|
| 50 |
+
- [What is MSII?](#what-is-msii)
|
| 51 |
+
- [Intended Use \& Limitations](#intended-use--limitations)
|
| 52 |
+
- [Evaluation](#evaluation)
|
| 53 |
+
- [Step Sweep](#step-sweep)
|
| 54 |
+
- [Training Data](#training-data)
|
| 55 |
+
- [Training Pipeline](#training-pipeline)
|
| 56 |
+
- [Key Technical Decision: Text-Encoder-Free Training](#key-technical-decision-text-encoder-free-training)
|
| 57 |
+
- [Key Technical Decision: VAE BatchNorm Domain Calibration](#key-technical-decision-vae-batchnorm-domain-calibration)
|
| 58 |
+
- [Stage 1: Pretrain (Domain Initialization)](#stage-1-pretrain-domain-initialization)
|
| 59 |
+
- [Stage 2: Train (Image-to-Image Adaptation)](#stage-2-train-image-to-image-adaptation)
|
| 60 |
+
- [Augmentation Pipeline](#augmentation-pipeline)
|
| 61 |
+
- [Loss](#loss)
|
| 62 |
+
- [Stage 3: Rectify (Trajectory Straightening)](#stage-3-rectify-trajectory-straightening)
|
| 63 |
+
- [Acknowledgements \& Citations](#acknowledgements--citations)
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## Example Output
|
| 68 |
+
|
| 69 |
+
<table>
|
| 70 |
+
<thead>
|
| 71 |
+
<tr>
|
| 72 |
+
<th align="center" width="25%">Input</th>
|
| 73 |
+
<th align="center" width="25%">Output</th>
|
| 74 |
+
<th align="center" width="25%">Ground Truth</th>
|
| 75 |
+
<th align="center" width="25%">Difference</th>
|
| 76 |
+
</tr>
|
| 77 |
+
</thead>
|
| 78 |
+
<tbody>
|
| 79 |
+
|
| 80 |
+
<tr>
|
| 81 |
+
<td align="center"><img src="./assets/example_input_0.png" width="200"/></td>
|
| 82 |
+
<td align="center"><img src="./assets/example_output_0.png" width="200"/></td>
|
| 83 |
+
<td align="center"><img src="./assets/example_truth_0.png" width="200"/></td>
|
| 84 |
+
<td align="center"><img src="./assets/example_diff_0.png" width="200"/></td>
|
| 85 |
+
</tr>
|
| 86 |
+
<tr>
|
| 87 |
+
<td colspan="4" align="center"><b>Dice: 0.9652</b> · RMSE: 0.0775 · MS-SSIM: 0.9295 · PSNR: 22.22 dB · PSNR-HVS-M: 17.77 dB · SRE: 58.34 dB</td>
|
| 88 |
+
</tr>
|
| 89 |
+
|
| 90 |
+
<tr>
|
| 91 |
+
<td align="center"><img src="./assets/example_input_1.png" width="200"/></td>
|
| 92 |
+
<td align="center"><img src="./assets/example_output_1.png" width="200"/></td>
|
| 93 |
+
<td align="center"><img src="./assets/example_truth_1.png" width="200"/></td>
|
| 94 |
+
<td align="center"><img src="./assets/example_diff_1.png" width="200"/></td>
|
| 95 |
+
</tr>
|
| 96 |
+
<tr>
|
| 97 |
+
<td colspan="4" align="center"><b>Dice: 0.9555</b> · RMSE: 0.0788 · MS-SSIM: 0.9219 · PSNR: 22.07 dB · PSNR-HVS-M: 17.80 dB · SRE: 57.89 dB</td>
|
| 98 |
+
</tr>
|
| 99 |
+
|
| 100 |
+
<tr>
|
| 101 |
+
<td align="center"><img src="./assets/example_input_2.png" width="200"/></td>
|
| 102 |
+
<td align="center"><img src="./assets/example_output_2.png" width="200"/></td>
|
| 103 |
+
<td align="center"><img src="./assets/example_truth_2.png" width="200"/></td>
|
| 104 |
+
<td align="center"><img src="./assets/example_diff_2.png" width="200"/></td>
|
| 105 |
+
</tr>
|
| 106 |
+
<tr>
|
| 107 |
+
<td colspan="4" align="center"><b>Dice: 0.9630</b> · RMSE: 0.1108 · MS-SSIM: 0.8513 · PSNR: 19.11 dB · PSNR-HVS-M: 14.65 dB · SRE: 59.60 dB</td>
|
| 108 |
+
</tr>
|
| 109 |
+
|
| 110 |
+
<tr>
|
| 111 |
+
<td align="center"><img src="./assets/example_input_3.png" width="200"/></td>
|
| 112 |
+
<td align="center"><img src="./assets/example_output_3.png" width="200"/></td>
|
| 113 |
+
<td align="center"><img src="./assets/example_truth_3.png" width="200"/></td>
|
| 114 |
+
<td align="center"><img src="./assets/example_diff_3.png" width="200"/></td>
|
| 115 |
+
</tr>
|
| 116 |
+
<tr>
|
| 117 |
+
<td colspan="4" align="center"><b>Dice: 0.9713</b> · RMSE: 0.1035 · MS-SSIM: 0.8748 · PSNR: 19.70 dB · PSNR-HVS-M: 15.33 dB · SRE: 59.41 dB</td>
|
| 118 |
+
</tr>
|
| 119 |
+
|
| 120 |
+
<tr>
|
| 121 |
+
<td align="center"><img src="./assets/example_input_4.png" width="200"/></td>
|
| 122 |
+
<td align="center"><img src="./assets/example_output_4.png" width="200"/></td>
|
| 123 |
+
<td align="center"><img src="./assets/example_truth_4.png" width="200"/></td>
|
| 124 |
+
<td align="center"><img src="./assets/example_diff_4.png" width="200"/></td>
|
| 125 |
+
</tr>
|
| 126 |
+
<tr>
|
| 127 |
+
<td colspan="4" align="center"><b>Dice: 0.9564</b> · RMSE: 0.1054 · MS-SSIM: 0.9325 · PSNR: 19.55 dB · PSNR-HVS-M: 15.18 dB · SRE: 57.36 dB</td>
|
| 128 |
+
</tr>
|
| 129 |
+
|
| 130 |
+
</tbody>
|
| 131 |
+
</table>
|
| 132 |
+
|
| 133 |
+
---
|
| 134 |
+
|
| 135 |
+
## Quickstart
|
| 136 |
+
|
| 137 |
+
### Installation
|
| 138 |
+
|
| 139 |
+
**Prerequisites:**
|
| 140 |
+
|
| 141 |
+
- Python >= 3.10
|
| 142 |
+
- PyTorch with CUDA support. See https://pytorch.org/get-started/locally/.
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
# Install PyTorch (CUDA 12.8 example)
|
| 146 |
+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
|
| 147 |
+
|
| 148 |
+
# Windows only: install Triton (included automatically on Linux)
|
| 149 |
+
pip install triton-windows
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
**Install:**
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
pip install nisaba-relief
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
### Usage
|
| 159 |
+
|
| 160 |
+
```python
|
| 161 |
+
from nisaba_relief import NisabaRelief
|
| 162 |
+
|
| 163 |
+
model = NisabaRelief() # downloads weights from HF Hub automatically if needed
|
| 164 |
+
result = model.process("tablet.jpg")
|
| 165 |
+
result.save("tablet_msii.png")
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
**Constructor parameters:**
|
| 169 |
+
|
| 170 |
+
| Parameter | Default | Description |
|
| 171 |
+
|---|---|---|
|
| 172 |
+
| `device` | `"cuda"` if available | Device for inference |
|
| 173 |
+
| `num_steps` | `2` | Denoising steps |
|
| 174 |
+
| `weights_dir` | `None` | Local weights directory; if `None`, downloads from HF Hub or uses HF cache. Expected dir contents: `model.safetensors`, `ae.safetensors`, `prompt_embedding.safetensors` |
|
| 175 |
+
| `batch_size` | `None` | Batch size for processing tiles during inference. `None` (default) auto-selects the largest batch that fits in available VRAM. Set an explicit integer to override. Higher values are faster but see note below. |
|
| 176 |
+
| `seed` | `None` | Optional random seed for reproducible noise generation; if `None`, randomized |
|
| 177 |
+
| `compile` | `True` | Use `torch.compile` for faster repeated inference. Requires Triton. Set to `False` if Triton is not installed or for one-off runs. |
|
| 178 |
+
|
| 179 |
+
> **Reproducibility note:** Results are pixel-exact across repeated runs with the same `batch_size` and `seed`. However, changing `batch_size` between runs (including letting `None` auto-select a different value as available VRAM changes) will produce outputs that differ by up to ~1-2 pixel values (mean < 0.25) due to GPU floating-point non-determinism: CUDA selects different kernel implementations for different matrix shapes, which changes the floating-point accumulation order in the transformer attention and linear layers. The visual difference is imperceptible. If exact cross-run reproducibility is required, set a constant `batch_size`.
|
| 180 |
+
|
| 181 |
+
**`process()` parameters:**
|
| 182 |
+
|
| 183 |
+
| Parameter | Default | Description |
|
| 184 |
+
|---|---|---|
|
| 185 |
+
| `image` | required | File path (str/Path) or PIL Image |
|
| 186 |
+
| `show_pbar` | `None` | Progress bar visibility. `None` = auto (shows when >= 2 batches); `True`/`False` = always show/hide |
|
| 187 |
+
|
| 188 |
+
**Returns:** Grayscale `PIL.Image.Image` containing the MSII visualization.
|
| 189 |
+
|
| 190 |
+
**Input requirements:**
|
| 191 |
+
- Any PIL-readable format (PNG, JPG, WEBP, ...)
|
| 192 |
+
- Minimum 64 px on the short side; maximum aspect ratio 8:1
|
| 193 |
+
|
| 194 |
+
**Large image support:**
|
| 195 |
+
|
| 196 |
+
The model's native tile size is 1024 px. For images where either side exceeds 1024 px, the model automatically applies a sliding-window tiling pass. Tiles are blended with raised-cosine overlap weights to avoid seams. Each tile is also conditioned on a 128 px thumbnail of the full image with a red rectangle marking the tile's position, so the model retains global context while processing local detail.
|
| 197 |
+
|
| 198 |
+
There is no practical upper limit on input resolution, though the model may perform unexpectedly if the 1024 px tile is only a small fraction of the total image area.
|
| 199 |
+
|
| 200 |
+
---
|
| 201 |
+
|
| 202 |
+
## Hardware Requirements
|
| 203 |
+
|
| 204 |
+
While CPU inference is technically supported, it is too slow for practical use. A GPU with at least 9GB VRAM is required, with 12GB+ being recommended for better batching.
|
| 205 |
+
|
| 206 |
+
The 9 GB figure is substantially lower than the ~18 GB a standard FLUX.2-klein-base-4B deployment would require because the Qwen3-4B text encoder is never loaded at runtime. The conditioning prompt is pre-computed once and shipped as a 7.8 MB embedding file alongside the model weights.
|
| 207 |
+
|
| 208 |
+
---
|
| 209 |
+
|
| 210 |
+
## Performance
|
| 211 |
+
|
| 212 |
+
Traditional pipelines require a high-resolution 3D scanner and GigaMesh postprocessing: across the HeiCuBeDa corpus, this averages approximately 68 minutes per tablet, totalling over 2,200 hours for the full collection. NisabaRelief processes a tablet photograph in approximately 7 seconds, roughly 600x faster, with no scanning equipment required.
|
| 213 |
+
|
| 214 |
+
On a 1064x2048px photo, an RTX 3090 performs as follows:
|
| 215 |
+
|
| 216 |
+
| Run | Time |
|
| 217 |
+
|---|---|
|
| 218 |
+
| *compile warmup* | 11.61s |
|
| 219 |
+
| 1 | 7.05s |
|
| 220 |
+
| 2 | 7.07s |
|
| 221 |
+
| 3 | 7.09s |
|
| 222 |
+
| **Mean** | **7.07 ± 0.02s** |
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
## What is MSII?
|
| 227 |
+
|
| 228 |
+
Multi-Scale Integral Invariant (MSII) filtering is a geometry-processing algorithm that computes a robust curvature measure at every point on a 3D surface mesh. At each vertex, a sphere of radius *r* is centered on the surface and the algorithm measures how much of the sphere's volume falls below the surface (the "interior" volume). On a perfectly flat surface the ratio is exactly one half. Concave regions (such as the channel cut by a wedge impression) admit more of the sphere below the surface, pushing the ratio above 0.5. Convex regions such as ridges or the rounded back of a tablet expose less interior volume, pulling the ratio below 0.5. The signed difference from the flat baseline maps directly to the sign and magnitude of mean curvature at that point.
|
| 229 |
+
|
| 230 |
+
The multi-scale component repeats this computation at several sphere radii simultaneously. Small radii resolve fine wedge tips and hairline details; large radii capture broader curvature trends such as the tablet's overall convexity. The per-vertex measurements across all radii form a compact feature vector, and the final scalar output conventionally displayed as a grayscale image is the maximum component of that feature vector, capturing the strongest curvature response across all scales into a single value per pixel.
|
| 231 |
+
|
| 232 |
+
By convention the scalar is displayed with its sign inverted relative to the mean curvature: concave regions (ratio > 0.5) map to darker pixel values and convex regions (ratio < 0.5) to lighter ones. This places the flat-surface baseline at mid-gray and renders wedge channels as dark strokes against a bright background, similar to ink on paper.
|
| 233 |
+
|
| 234 |
+
Because the result depends only on the 3D shape of the surface rather than on lighting, clay color, or photograph angle, wedge impressions appear as consistent dark strokes against a bright background. This makes the surface structure considerably more legible to machine-vision OCR systems than raw photographs.
|
| 235 |
+
|
| 236 |
+
---
|
| 237 |
+
|
| 238 |
+
## Intended Use & Limitations
|
| 239 |
+
|
| 240 |
+
Generating an MSII visualization of a tablet requires a high-resolution laser scanner and substantial per-vertex computation. The vast majority of cuneiform tablets do not have a 3D scan available, and the computational cost is difficult to scale across large corpora.
|
| 241 |
+
|
| 242 |
+
To reduce this barrier and increase the availability of readable images, this model is trained to predict the MSII visualization directly from photographs.
|
| 243 |
+
|
| 244 |
+
**Intended use:**
|
| 245 |
+
- Preprocessing step for cuneiform OCR (specifically NabuOCR V2)
|
| 246 |
+
- Visualizing cuneiform tablet geometry for research and digital humanities
|
| 247 |
+
|
| 248 |
+
**Limitations:**
|
| 249 |
+
- Trained exclusively using [HeiCuBeDa](https://doi.org/10.11588/data/IE8CCN) 3D-scan data; performance on tablet types or scribal traditions not well-represented in that corpus is unknown
|
| 250 |
+
- Outputs are MSII approximations inferred from 2D photographs, not computed from true 3D geometry. They are suitable for OCR preprocessing but are not a substitute for physical scanning
|
| 251 |
+
- Not a general-purpose MSII model; behavior on non-cuneiform inputs is undefined and out of distribution
|
| 252 |
+
- Designed for photographs following [CDLI photography guidelines](https://cdli.earth/docs/images-acquisition-and-processing): high-resolution fatcross layout on a black background. The model may underperform on low-resolution or visually cluttered inputs such as older black-and-white excavation photographs where the background blends into the tablet
|
| 253 |
+
|
| 254 |
+
---
|
| 255 |
+
|
| 256 |
+
## Evaluation
|
| 257 |
+
|
| 258 |
+
The model was evaluated on 704 held-out validation pairs, all tablets whose geometry was never seen during training (see [Training Data](#training-data)). Each validation image was processed through the model and the output compared against the ground-truth MSII visualization computed from the 3D scan. Ran with `seed=42` and `batch_size=4`.
|
| 259 |
+
|
| 260 |
+
| Metric | Value |
|
| 261 |
+
|------------|------------------|
|
| 262 |
+
| Dice | 0.9639 ± 0.0138 |
|
| 263 |
+
| RMSE | 0.0877 ± 0.0208 |
|
| 264 |
+
| MS-SSIM | 0.9026 ± 0.0308 |
|
| 265 |
+
| PSNR | 21.36 ± 1.91 dB |
|
| 266 |
+
| PSNR-HVS-M | 16.98 ± 1.89 dB |
|
| 267 |
+
| SRE | 59.57 ± 1.92 dB |
|
| 268 |
+
|
| 269 |
+
**Dice** (Binarized Dice Coefficient) thresholds both images to isolate wedge stroke regions, then measures overlap between predicted and ground-truth strokes on a 0-1 scale. This is the most task-relevant metric, as it directly measures whether the model correctly localizes wedge impressions for downstream OCR.
|
| 270 |
+
|
| 271 |
+
**RMSE** (Root Mean Squared Error) measures average pixel-level reconstruction error; lower is better.
|
| 272 |
+
|
| 273 |
+
**MS-SSIM** (Multi-Scale Structural Similarity Index) measures perceptual image similarity by comparing luminance, contrast, and local structure at multiple spatial scales simultaneously. Coarser scales capture global shape agreement; finer scales capture edge and texture detail. Scores range from 0 to 1, where 1 is a perfect match; higher is better.
|
| 274 |
+
|
| 275 |
+
**PSNR** (Peak Signal-to-Noise Ratio) expresses reconstruction fidelity in decibels relative to the maximum pixel value; higher is better.
|
| 276 |
+
|
| 277 |
+
**PSNR-HVS-M** (Peak Signal-to-Noise Ratio - Human Visual System and Masking) measures reconstruction fidelity in decibels relative to the maximum pixel value while taking into account Contrast Sensitivity Function (CSF) and between-coefficient contrast masking of DCT basis functions.
|
| 278 |
+
|
| 279 |
+
**SRE** (Signal-to-Reconstruction Error) ratio measures reconstruction fidelity in decibels based on signal energy vs. error energy; higher is better.
|
| 280 |
+
|
| 281 |
+
### Step Sweep
|
| 282 |
+
|
| 283 |
+
A sweep of step counts was run on a subset of 175 validation samples and found that 2 steps is ideal for this model, adding one corrective step over the already solid single-step result. The rectified flow field is extremely straight (straightness_ratio=0.9989, path_length_ratio=1.0011, velocity_std=0.1565). For near-perfectly straight ODE trajectories, a single Euler step is theoretically near-exact, and each additional step accumulates small model prediction errors faster than it reduces discretization error. Where throughput is the primary concern, one step is acceptable. Ran with `seed=42` and `batch_size=4`.
|
| 284 |
+
|
| 285 |
+
| Metric | Steps=1 | Steps=2 | Steps=4 | Steps=8 |
|
| 286 |
+
|------------|------------------|----------------------|------------------|------------------|
|
| 287 |
+
| Dice | 0.9582 ± 0.0153 | **0.9634** ± 0.0139 | 0.9612 ± 0.0142 | 0.9580 ± 0.0148 |
|
| 288 |
+
| RMSE | 0.0909 ± 0.0209 | **0.0859** ± 0.0212 | 0.0900 ± 0.0203 | 0.0949 ± 0.0197 |
|
| 289 |
+
| MS-SSIM | 0.8987 ± 0.0326 | **0.9081** ± 0.0310 | 0.9039 ± 0.0314 | 0.8959 ± 0.0326 |
|
| 290 |
+
| PSNR | 21.03 ± 1.83 dB | **21.56** ± 1.97 dB | 21.11 ± 1.84 dB | 20.63 ± 1.72 dB |
|
| 291 |
+
| PSNR-HVS-M | 16.65 ± 1.80 dB | **17.19** ± 1.96 dB | 16.70 ± 1.83 dB | 16.18 ± 1.70 dB |
|
| 292 |
+
| SRE | 58.81 ± 1.81 dB | **59.07** ± 1.87 dB | 58.85 ± 1.87 dB | 58.61 ± 1.86 dB |
|
| 293 |
+
|
| 294 |
+
---
|
| 295 |
+
|
| 296 |
+
## Training Data
|
| 297 |
+
|
| 298 |
+
Training uses the [CuneiformPhotosMSII](https://huggingface.co/datasets/boatbomber/CuneiformPhotosMSII) dataset: 13,928 paired image pairs generated from 1,741 tablets sourced from the HeiCuBeDa (Heidelberg Cuneiform Benchmark Dataset), a professional research collection of 3D-scanned clay tablets. Each tablet was rendered multiple times in Blender at up to 4096 px, producing synthetic photographs alongside their corresponding MSII curvature visualizations.
|
| 299 |
+
|
| 300 |
+
Each render variant randomizes which faces of the tablet are shown, camera focal length (80-150 mm), tablet rotation (±5° Euler XYZ), lighting position/color/intensity, and background (fabric, grunge, stone, or none). This diversity encourages the model to generalize across realistic shooting conditions rather than overfitting to a specific lighting or composition style.
|
| 301 |
+
|
| 302 |
+
The dataset was split tablet-wise: 13,224 pairs (~95% of tablets) for training and 704 pairs (~5% of tablets) held out for validation. Because the split is by tablet identity, the model never sees a validation tablet's geometry during training.
|
| 303 |
+
|
| 304 |
+
---
|
| 305 |
+
|
| 306 |
+
## Training Pipeline
|
| 307 |
+
|
| 308 |
+
Training proceeded in three sequential stages: Pretrain, Train, and Rectify. Each stage builds directly on the weights from the previous one.
|
| 309 |
+
|
| 310 |
+
### Key Technical Decision: Text-Encoder-Free Training
|
| 311 |
+
|
| 312 |
+
All three stages skip the Qwen3-4B text encoder entirely. Text embeddings are pre-computed once and cached to disk, reducing VRAM consumption from ~18 GB to ~9 GB without any loss in conditioning fidelity.
|
| 313 |
+
|
| 314 |
+
### Key Technical Decision: VAE BatchNorm Domain Calibration
|
| 315 |
+
|
| 316 |
+
The FLUX.2 VAE contains a BatchNorm layer whose running statistics (`running_mean` and `running_var` across 128 channels: 32 latent channels × 2×2 patch size) were originally computed on diverse natural images. Applying this encoder to cuneiform tablets and MSII renderings introduces a latent-space distribution shift that manifests as screen-door dithering artifacts in decoded outputs.
|
| 317 |
+
|
| 318 |
+
To correct this, the BatchNorm statistics were recalibrated on the target domain before training began. 3,000 CDLI cuneiform tablet photographs and 2,000 synthetic MSII visualizations (5,000 images total) were encoded through the frozen VAE encoder; running mean and variance were accumulated across 19,301,093 spatial samples using float64 accumulators for numerical stability. Images from both domains were interleaved to ensure balanced sampling. The calibrated statistics are baked directly into the `ae.safetensors` weights shipped with this model.
|
| 319 |
+
|
| 320 |
+
---
|
| 321 |
+
|
| 322 |
+
### Stage 1: Pretrain (Domain Initialization)
|
| 323 |
+
|
| 324 |
+
The pretrain stage adapts the base FLUX.2 model to the cuneiform domain before any image-to-image translation is attempted. It runs standard text-to-image flow-matching training on two sources of real cuneiform imagery:
|
| 325 |
+
|
| 326 |
+
- ~60% CDLI archive photographs: real museum photos of tablets, paired with per-image text embeddings generated from CDLI metadata (period, material, object type, provenience, genre, language). Eight prompt templates were used and varied randomly.
|
| 327 |
+
- ~40% synthetic MSII renders: MSII visualization images from the training set, paired with MSII-specific text embeddings emphasizing curvature, surface topology, and wedge impression terminology.
|
| 328 |
+
|
| 329 |
+
Each image has its own unique cached embedding rather than a shared prompt, preventing the model from memorizing specimen identifiers and encouraging generalization.
|
| 330 |
+
|
| 331 |
+
| Hyperparameter | Value |
|
| 332 |
+
|---|---|
|
| 333 |
+
| Steps | 75,000 |
|
| 334 |
+
| Learning rate | 2e-4 (cosine decay, 1k warmup) |
|
| 335 |
+
| Effective batch size | 2 (batch 1, grad accum 2) |
|
| 336 |
+
| LoRA rank | 256 |
|
| 337 |
+
| LoRA init | PiSSA (8-iteration fast SVD) |
|
| 338 |
+
| Optimizer | 8-bit Adam |
|
| 339 |
+
| Precision | bfloat16 autocast |
|
| 340 |
+
| Timestep sampling | Logit-normal (mean=0, std=1) |
|
| 341 |
+
| Gradient clipping | 1.0 |
|
| 342 |
+
|
| 343 |
+
Images are resized to fit within 1 megapixel and rounded to 128-pixel multiples. Light augmentations are applied (horizontal flip, ±5° rotation, minor color jitter). Validation generates text-conditioned images across four aspect ratios every 1,000 steps.
|
| 344 |
+
|
| 345 |
+
---
|
| 346 |
+
|
| 347 |
+
### Stage 2: Train (Image-to-Image Adaptation)
|
| 348 |
+
|
| 349 |
+
The main training stage fine-tunes the pretrained weights for the target task: translating cuneiform tablet photographs into MSII visualizations. This stage introduces two significant changes over standard FLUX.2 fine-tuning.
|
| 350 |
+
|
| 351 |
+
**Tile and global context conditioning**
|
| 352 |
+
|
| 353 |
+
Rather than processing full images, the model trains on dynamic tile crops (128-1024 px, depending on image resolution) while simultaneously receiving a downscaled 128 px thumbnail of the full image with a red rectangle marking the tile's location, providing both local detail and global context.
|
| 354 |
+
|
| 355 |
+
**Paired crop with geometric consistency**
|
| 356 |
+
|
| 357 |
+
The same crop coordinates and geometric transforms (flip, rotation, perspective distortion) are applied to both the input photograph and the target MSII image, ensuring the model always receives spatially aligned pairs.
|
| 358 |
+
|
| 359 |
+
#### Augmentation Pipeline
|
| 360 |
+
|
| 361 |
+
Augmentations are split into two categories applied in sequence:
|
| 362 |
+
|
| 363 |
+
Geometric (applied identically to input and target):
|
| 364 |
+
- Horizontal flip (50%), vertical flip (40%), rotation ±8° (50%), perspective distortion strength 0.02 (30%)
|
| 365 |
+
|
| 366 |
+
Domain adaptation (applied to input only, to simulate real photographic variation):
|
| 367 |
+
- Perlin noise illumination (20%), vignette (40%), directional lighting gradient (50%), dust particles (50%), Gaussian noise (80%), gamma correction (50%), contrast adjustment (50%), brightness shift (50%), hue/saturation shift (40%), Gaussian blur (30%), grayscale conversion (3%)
|
| 368 |
+
|
| 369 |
+
Spatially-dependent effects (Perlin noise, vignette, gradient) use crop coordinates so the tile and its global thumbnail receive matching effects.
|
| 370 |
+
|
| 371 |
+
#### Loss
|
| 372 |
+
|
| 373 |
+
Flow-matching loss with Min-SNR-γ weighting (γ=5.0) to down-weight noisy high-timestep predictions, plus a multi-scale latent gradient loss weighted at 0.25. The gradient loss computes spatial gradient differences between predicted and target latents at four downsampling scales, encouraging sharp edge structure in outputs.
|
| 374 |
+
|
| 375 |
+
| Hyperparameter | Value |
|
| 376 |
+
|---|---|
|
| 377 |
+
| Steps | 150,000 |
|
| 378 |
+
| Learning rate | 3e-4 (cosine decay to 6e-6, 1k warmup) |
|
| 379 |
+
| Effective batch size | 8 (batch 1, grad accum 8) |
|
| 380 |
+
| LoRA rank | 256, alpha √rank, RSLoRA |
|
| 381 |
+
| LoRA init | PiSSA (8-iteration fast SVD) |
|
| 382 |
+
| EMA decay | 0.999 (used for validation and final save) |
|
| 383 |
+
| Optimizer | 8-bit Adam |
|
| 384 |
+
| Gradient clipping | 0.8 (with spike detection: skip if >2.5× EMA norm) |
|
| 385 |
+
| Precision | bfloat16 autocast |
|
| 386 |
+
| Gradient loss weight | 0.25 |
|
| 387 |
+
| Min-SNR-γ | 5.0 |
|
| 388 |
+
| Timestep sampling | Logit-normal (mean=0, std=1) |
|
| 389 |
+
|
| 390 |
+
Validation runs every 2,000 steps, generating 8 sample images with 8 denoising steps.
|
| 391 |
+
|
| 392 |
+
---
|
| 393 |
+
|
| 394 |
+
### Stage 3: Rectify (Trajectory Straightening)
|
| 395 |
+
|
| 396 |
+
The rectify stage implements [Rectified Flow](https://arxiv.org/abs/2209.03003) to reduce the number of inference steps required at runtime.
|
| 397 |
+
|
| 398 |
+
Standard flow-matching trains on random (noise, real target) pairs, producing curved ODE trajectories that require 25-50 denoising steps to traverse accurately. Rectified training instead pairs each noise sample with the output the fully-trained model generates from that noise, creating straight-line trajectories that can be traversed in 1-4 steps without quality loss.
|
| 399 |
+
|
| 400 |
+
Before training, a one-time preprocessing pass runs the trained model over the training set. Each image is cropped deterministically (seeded RNG, same tile-sizing logic as training), then fully denoised with the trained weights to produce a (noise, generated_output) coupled pair saved to disk. This eliminates VAE encoding from the training loop, reducing VRAM further.
|
| 401 |
+
|
| 402 |
+
The loss trains the model to predict the velocity between a coupled (noise, generated) pair at a random interpolated timestep. A pseudo-Huber loss replaces the MSE used in earlier stages, providing better gradient stability when predictions are far from target.
|
| 403 |
+
|
| 404 |
+
| Hyperparameter | Value |
|
| 405 |
+
|---|---|
|
| 406 |
+
| Steps | 50,000 |
|
| 407 |
+
| Learning rate | 3e-6 (cosine decay, 500 warmup) |
|
| 408 |
+
| Effective batch size | 4 (batch 1, grad accum 4) |
|
| 409 |
+
| LoRA rank | 256 |
|
| 410 |
+
| LoRA init | Loaded from Stage 2 weights (warm-start) |
|
| 411 |
+
| Loss | Pseudo-Huber (c=0.001) |
|
| 412 |
+
| Optimizer | 8-bit Adam |
|
| 413 |
+
| Gradient clipping | 1.0 |
|
| 414 |
+
| Precision | bfloat16 autocast |
|
| 415 |
+
| Timestep sampling | Logit-normal (mean=0, std=1) |
|
| 416 |
+
|
| 417 |
+
Validation runs every 2,000 steps using real validation images (not coupled pairs), generating outputs with only 2 denoising steps to directly measure few-step inference quality.
|
| 418 |
+
|
| 419 |
+
The result is usable MSII visualizations in 1-2 denoising steps, compared to the 25-50 steps standard flow-matching requires.
|
| 420 |
+
|
| 421 |
+
---
|
| 422 |
+
|
| 423 |
+
## Acknowledgements & Citations
|
| 424 |
+
|
| 425 |
+
**3D Scan Data (HeiCuBeDa)**
|
| 426 |
+
|
| 427 |
+
3D scans used to generate the training dataset are from the Heidelberg Cuneiform Benchmark Dataset (HeiCuBeDa):
|
| 428 |
+
|
| 429 |
+
> Bogacz, B., Gertz, M., & Mara, H. (2015). *Character Proposals for Cuneiform Script Digitization*. Proceedings of the 15th International Conference on Frontiers in Handwriting Recognition (ICFHR). doi:[10.11588/data/IE8CCN](https://doi.org/10.11588/data/IE8CCN)
|
| 430 |
+
|
| 431 |
+
**Archive Photographs (CDLI)**
|
| 432 |
+
|
| 433 |
+
Real tablet photographs used in Stage 1 pretraining are sourced from the [Cuneiform Digital Library Initiative (CDLI)](https://cdli.mpiwg-berlin.mpg.de/).
|
| 434 |
+
|
| 435 |
+
**MSII Curvature (GigaMesh)**
|
| 436 |
+
|
| 437 |
+
MSII curvature values embedded in the HeiCuBeDa PLY files were computed using the [GigaMesh Software Framework](https://gigamesh.eu/).
|
| 438 |
+
|
| 439 |
+
**Rectified Flow**
|
| 440 |
+
|
| 441 |
+
Stage 3 (Rectify) implements the trajectory-straightening approach from:
|
| 442 |
+
|
| 443 |
+
> Liu, X., et al. (2022). *Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow*. arXiv:[2209.03003](https://arxiv.org/abs/2209.03003)
|
| 444 |
+
|
| 445 |
+
**Base Model (FLUX.2 Klein Base 4B)**
|
| 446 |
+
|
| 447 |
+
Fine-tuned from [FLUX.2-klein-base-4B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B) by Black Forest Labs.
|
ae.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:570cc44d0301b006a34b2604735cf296ef6083a95564b45042c1788eae246977
|
| 3 |
+
size 336211292
|
assets/NisabaRelief-Logo.png
ADDED
|
Git LFS Details
|
assets/example_diff_0.png
ADDED
|
Git LFS Details
|
assets/example_diff_1.png
ADDED
|
Git LFS Details
|
assets/example_diff_2.png
ADDED
|
Git LFS Details
|
assets/example_diff_3.png
ADDED
|
Git LFS Details
|
assets/example_diff_4.png
ADDED
|
Git LFS Details
|
assets/example_input_0.png
ADDED
|
Git LFS Details
|
assets/example_input_1.png
ADDED
|
Git LFS Details
|
assets/example_input_2.png
ADDED
|
Git LFS Details
|
assets/example_input_3.png
ADDED
|
Git LFS Details
|
assets/example_input_4.png
ADDED
|
Git LFS Details
|
assets/example_output_0.png
ADDED
|
Git LFS Details
|
assets/example_output_1.png
ADDED
|
Git LFS Details
|
assets/example_output_2.png
ADDED
|
Git LFS Details
|
assets/example_output_3.png
ADDED
|
Git LFS Details
|
assets/example_output_4.png
ADDED
|
Git LFS Details
|
assets/example_truth_0.png
ADDED
|
Git LFS Details
|
assets/example_truth_1.png
ADDED
|
Git LFS Details
|
assets/example_truth_2.png
ADDED
|
Git LFS Details
|
assets/example_truth_3.png
ADDED
|
Git LFS Details
|
assets/example_truth_4.png
ADDED
|
Git LFS Details
|
data/val_tablet_ids.json
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"HS_1746",
|
| 3 |
+
"HS_1059",
|
| 4 |
+
"HS_1660",
|
| 5 |
+
"HS_2631",
|
| 6 |
+
"HS_2072",
|
| 7 |
+
"HS_890",
|
| 8 |
+
"HS_883",
|
| 9 |
+
"HS_0713",
|
| 10 |
+
"HS_919",
|
| 11 |
+
"HS_0459",
|
| 12 |
+
"HS_1327",
|
| 13 |
+
"HS_736",
|
| 14 |
+
"HS_1200",
|
| 15 |
+
"HS_294",
|
| 16 |
+
"HS_0205",
|
| 17 |
+
"HS_0362",
|
| 18 |
+
"HS_510",
|
| 19 |
+
"HS_1122",
|
| 20 |
+
"HS_2467",
|
| 21 |
+
"HS_1650",
|
| 22 |
+
"HS_2590",
|
| 23 |
+
"HS_2616",
|
| 24 |
+
"HS_1336",
|
| 25 |
+
"HS_2355",
|
| 26 |
+
"HS_0449",
|
| 27 |
+
"HS_1770",
|
| 28 |
+
"HS_0898",
|
| 29 |
+
"HS_2309",
|
| 30 |
+
"HS_2084",
|
| 31 |
+
"HS_566",
|
| 32 |
+
"HS_0199",
|
| 33 |
+
"HS_843",
|
| 34 |
+
"HS_1275",
|
| 35 |
+
"HS_2556",
|
| 36 |
+
"HS_1506",
|
| 37 |
+
"HS_1643",
|
| 38 |
+
"HS_0661",
|
| 39 |
+
"HS_1774",
|
| 40 |
+
"HS_0626",
|
| 41 |
+
"HS_933",
|
| 42 |
+
"HS_1485",
|
| 43 |
+
"HS_665",
|
| 44 |
+
"HS_1175",
|
| 45 |
+
"HS_1045",
|
| 46 |
+
"HS_901",
|
| 47 |
+
"HS_1494",
|
| 48 |
+
"HS_194a",
|
| 49 |
+
"HS_491",
|
| 50 |
+
"HS_1052",
|
| 51 |
+
"HS_841",
|
| 52 |
+
"HS_653",
|
| 53 |
+
"HS_0102",
|
| 54 |
+
"HS_848",
|
| 55 |
+
"HS_1304",
|
| 56 |
+
"HS_2503",
|
| 57 |
+
"HS_2061",
|
| 58 |
+
"HS_1186",
|
| 59 |
+
"HS_1944",
|
| 60 |
+
"HS_929",
|
| 61 |
+
"HS_501",
|
| 62 |
+
"HS_2673",
|
| 63 |
+
"HS_535",
|
| 64 |
+
"HS_1139",
|
| 65 |
+
"HS_2373",
|
| 66 |
+
"HS_0151",
|
| 67 |
+
"HS_2550",
|
| 68 |
+
"HS_2249",
|
| 69 |
+
"HS_1210",
|
| 70 |
+
"HS_1182",
|
| 71 |
+
"HS_0628",
|
| 72 |
+
"HS_0158b",
|
| 73 |
+
"HS_0164",
|
| 74 |
+
"HS_1949",
|
| 75 |
+
"HS_2511",
|
| 76 |
+
"HS_0570",
|
| 77 |
+
"HS_2337",
|
| 78 |
+
"HS_598",
|
| 79 |
+
"HS_435",
|
| 80 |
+
"HS_0717",
|
| 81 |
+
"HS_588",
|
| 82 |
+
"HS_1010",
|
| 83 |
+
"HS_1192",
|
| 84 |
+
"HS_1235",
|
| 85 |
+
"HS_1298",
|
| 86 |
+
"HS_600",
|
| 87 |
+
"HS_0147",
|
| 88 |
+
"HS_0749",
|
| 89 |
+
"HS_2641"
|
| 90 |
+
]
|
dev_scripts/benchmark.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmark script for NisabaRelief inference pipeline."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import statistics
|
| 5 |
+
import time
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from rich.console import Console
|
| 12 |
+
from rich.progress import (
|
| 13 |
+
BarColumn,
|
| 14 |
+
MofNCompleteColumn,
|
| 15 |
+
Progress,
|
| 16 |
+
TextColumn,
|
| 17 |
+
TimeElapsedColumn,
|
| 18 |
+
)
|
| 19 |
+
from rich.table import Table
|
| 20 |
+
|
| 21 |
+
from nisaba_relief import NisabaRelief
|
| 22 |
+
from util.load_val_dataset import load_val_dataset
|
| 23 |
+
|
| 24 |
+
BENCHMARK_DIR = Path(__file__).parent.parent / "data" / "benchmark"
|
| 25 |
+
BASELINE = BENCHMARK_DIR / "benchmark_baseline.png"
|
| 26 |
+
WARMUP_RUNS = 2
|
| 27 |
+
BENCH_RUNS = 3
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def build_timing_table(timings: list[float], n_warmup: int) -> Table:
|
| 31 |
+
bench_timings = timings[n_warmup:]
|
| 32 |
+
mean = statistics.mean(bench_timings)
|
| 33 |
+
stdev = statistics.stdev(bench_timings) if len(bench_timings) > 1 else 0.0
|
| 34 |
+
table = Table(title="Inference Timings")
|
| 35 |
+
table.add_column("Run", justify="right")
|
| 36 |
+
table.add_column("Time", justify="right")
|
| 37 |
+
for i, t in enumerate(timings, 1):
|
| 38 |
+
label = f"[dim]{i} (warmup)[/dim]" if i <= n_warmup else str(i - n_warmup)
|
| 39 |
+
time_str = f"[dim]{t:.2f}s[/dim]" if i <= n_warmup else f"{t:.2f}s"
|
| 40 |
+
table.add_row(label, time_str)
|
| 41 |
+
table.add_section()
|
| 42 |
+
table.add_row("[bold]Mean[/bold]", f"[bold]{mean:.2f} ± {stdev:.2f}s[/bold]")
|
| 43 |
+
return table
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_diff_table(flat: np.ndarray, max_diff: int) -> Table:
|
| 47 |
+
percentile_vals = np.percentile(flat, [50, 90, 95, 96, 97, 98, 99])
|
| 48 |
+
p98 = percentile_vals[5]
|
| 49 |
+
status = "PASS" if p98 <= 1 else "FAIL"
|
| 50 |
+
status_style = "green" if status == "PASS" else "red"
|
| 51 |
+
table = Table(
|
| 52 |
+
title=f"Pixel Diff vs Baseline — [{status_style}]{status}[/{status_style}]"
|
| 53 |
+
)
|
| 54 |
+
table.add_column("Stat", style="bold")
|
| 55 |
+
table.add_column("Value", justify="right")
|
| 56 |
+
table.add_row("Mean", f"{flat.mean():.4f}")
|
| 57 |
+
for label, val in zip(
|
| 58 |
+
["p50", "p90", "p95", "p96", "p97", "p98", "p99"], percentile_vals
|
| 59 |
+
):
|
| 60 |
+
table.add_row(label, f"{val:.0f}")
|
| 61 |
+
table.add_row("Max", str(max_diff))
|
| 62 |
+
return table
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def main():
|
| 66 |
+
parser = argparse.ArgumentParser(
|
| 67 |
+
description="Benchmark NisabaRelief inference pipeline"
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--weights-dir",
|
| 71 |
+
default=".",
|
| 72 |
+
metavar="PATH",
|
| 73 |
+
help="path to weights directory (default: .)",
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--device",
|
| 77 |
+
default=None,
|
| 78 |
+
metavar="DEVICE",
|
| 79 |
+
help="device to run inference on, e.g. cuda, cpu (default: cuda if available, else cpu)",
|
| 80 |
+
)
|
| 81 |
+
args = parser.parse_args()
|
| 82 |
+
|
| 83 |
+
console = Console()
|
| 84 |
+
rows = load_val_dataset()
|
| 85 |
+
test_image = rows[0]["photo"]
|
| 86 |
+
max_dim = max(test_image.size)
|
| 87 |
+
if max_dim > 2048:
|
| 88 |
+
scale = 2048 / max_dim
|
| 89 |
+
new_size = (round(test_image.width * scale), round(test_image.height * scale))
|
| 90 |
+
test_image = test_image.resize(new_size, Image.LANCZOS)
|
| 91 |
+
console.print(f"Input size: [cyan]{test_image.width}x{test_image.height}[/cyan]")
|
| 92 |
+
|
| 93 |
+
model_kwargs = dict(seed=42, weights_dir=Path(args.weights_dir))
|
| 94 |
+
if args.device is not None:
|
| 95 |
+
model_kwargs["device"] = args.device
|
| 96 |
+
model = NisabaRelief(**model_kwargs)
|
| 97 |
+
|
| 98 |
+
timings = []
|
| 99 |
+
output = None
|
| 100 |
+
total_runs = WARMUP_RUNS + BENCH_RUNS
|
| 101 |
+
progress = Progress(
|
| 102 |
+
TextColumn("[progress.description]{task.description}"),
|
| 103 |
+
BarColumn(),
|
| 104 |
+
MofNCompleteColumn(),
|
| 105 |
+
TimeElapsedColumn(),
|
| 106 |
+
)
|
| 107 |
+
with progress:
|
| 108 |
+
task = progress.add_task("Benchmarking", total=total_runs)
|
| 109 |
+
for i in range(total_runs):
|
| 110 |
+
t0 = time.perf_counter()
|
| 111 |
+
result = model.process(test_image, show_pbar=False)
|
| 112 |
+
timings.append(time.perf_counter() - t0)
|
| 113 |
+
progress.advance(task)
|
| 114 |
+
if i == WARMUP_RUNS:
|
| 115 |
+
output = result
|
| 116 |
+
|
| 117 |
+
console.print(build_timing_table(timings, WARMUP_RUNS))
|
| 118 |
+
|
| 119 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 120 |
+
run_path = BENCHMARK_DIR / f"benchmark_{timestamp}.png"
|
| 121 |
+
run_path.parent.mkdir(parents=True, exist_ok=True)
|
| 122 |
+
output.save(run_path)
|
| 123 |
+
console.print(f"Run image saved to [cyan]{run_path}[/cyan]")
|
| 124 |
+
|
| 125 |
+
output_arr = np.array(output)
|
| 126 |
+
|
| 127 |
+
if not BASELINE.exists():
|
| 128 |
+
output.save(BASELINE)
|
| 129 |
+
console.print(f"Baseline saved to [cyan]{BASELINE}[/cyan]")
|
| 130 |
+
else:
|
| 131 |
+
baseline_arr = np.array(Image.open(BASELINE))
|
| 132 |
+
diff = np.abs(output_arr.astype(int) - baseline_arr.astype(int))
|
| 133 |
+
flat = diff.flatten()
|
| 134 |
+
max_diff = int(flat.max())
|
| 135 |
+
console.print(build_diff_table(flat, max_diff))
|
| 136 |
+
|
| 137 |
+
if max_diff > 0:
|
| 138 |
+
diff_img = Image.fromarray(
|
| 139 |
+
np.clip(diff * (255 // max_diff), 0, 255).astype("uint8")
|
| 140 |
+
)
|
| 141 |
+
diff_path = Path(f"benchmark_{timestamp}_diff.png")
|
| 142 |
+
diff_img.save(diff_path)
|
| 143 |
+
console.print(
|
| 144 |
+
f"Diff image saved to [cyan]{diff_path}[/cyan] (amplified {255 // max_diff}x)"
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
main()
|
dev_scripts/evaluation.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate NisabaRelief on the validation set, optionally sweeping over step counts.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python evaluation.py # full dataset, num_steps=2
|
| 6 |
+
python evaluation.py --sweep # subset, steps=[1,2,4,8]
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import time
|
| 11 |
+
from datetime import timedelta
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from rich.console import Console, Group
|
| 17 |
+
from rich.live import Live
|
| 18 |
+
from rich.progress import (
|
| 19 |
+
BarColumn,
|
| 20 |
+
MofNCompleteColumn,
|
| 21 |
+
Progress,
|
| 22 |
+
TextColumn,
|
| 23 |
+
TimeElapsedColumn,
|
| 24 |
+
)
|
| 25 |
+
from rich.table import Table
|
| 26 |
+
|
| 27 |
+
from nisaba_relief import NisabaRelief
|
| 28 |
+
from util.metrics import compute_metrics, METRIC_NAMES, LABELS
|
| 29 |
+
from util.load_val_dataset import load_val_dataset
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
SWEEP_STEPS = [1, 2, 4, 8]
|
| 33 |
+
DEFAULT_STEPS = 2
|
| 34 |
+
SWEEP_STRIDE = 4
|
| 35 |
+
SWEEP_MAX = 175
|
| 36 |
+
EVALS_DIR = Path(__file__).parent.parent / "data" / "evals"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _eta(n_done: int, n_total: int, elapsed: float) -> str:
|
| 40 |
+
if n_done >= n_total > 0:
|
| 41 |
+
return "Done"
|
| 42 |
+
if n_done > 0:
|
| 43 |
+
return str(timedelta(seconds=int(elapsed / n_done * (n_total - n_done))))
|
| 44 |
+
return "?"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_table(
|
| 48 |
+
results: dict,
|
| 49 |
+
n_done: int = 0,
|
| 50 |
+
n_total: int = 0,
|
| 51 |
+
elapsed: float = 0.0,
|
| 52 |
+
) -> Table:
|
| 53 |
+
eta = _eta(n_done, n_total, elapsed)
|
| 54 |
+
steps = list(results.keys())
|
| 55 |
+
table = Table(title=f"Results — ETA: {eta}")
|
| 56 |
+
table.add_column("Metric", style="bold")
|
| 57 |
+
for s in steps:
|
| 58 |
+
table.add_column(f"Steps={s}", justify="right")
|
| 59 |
+
for name in METRIC_NAMES:
|
| 60 |
+
cells = []
|
| 61 |
+
for s in steps:
|
| 62 |
+
arr = np.array(results[s][name])
|
| 63 |
+
if len(arr) == 0:
|
| 64 |
+
cells.append("—")
|
| 65 |
+
elif name in ("psnr", "psnr_hvsm", "sre"):
|
| 66 |
+
cells.append(f"{arr.mean():.2f} ± {arr.std():.2f} dB")
|
| 67 |
+
else:
|
| 68 |
+
cells.append(f"{arr.mean():.4f} ± {arr.std():.4f}")
|
| 69 |
+
table.add_row(LABELS[name], *cells)
|
| 70 |
+
return table
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_grayscale(img: Image.Image) -> np.ndarray:
|
| 74 |
+
return np.array(img.convert("L"))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def main():
|
| 78 |
+
parser = argparse.ArgumentParser(description="Evaluate NisabaRelief model")
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--weights-dir",
|
| 81 |
+
default=".",
|
| 82 |
+
metavar="PATH",
|
| 83 |
+
help="path to weights directory (default: .)",
|
| 84 |
+
)
|
| 85 |
+
parser.add_argument(
|
| 86 |
+
"--sweep",
|
| 87 |
+
action="store_true",
|
| 88 |
+
help="sweep over steps=[1,2,4,8] on a dataset subset",
|
| 89 |
+
)
|
| 90 |
+
args = parser.parse_args()
|
| 91 |
+
|
| 92 |
+
rows = load_val_dataset()
|
| 93 |
+
if args.sweep:
|
| 94 |
+
rows = rows.select(
|
| 95 |
+
range(0, min(len(rows), SWEEP_MAX * SWEEP_STRIDE), SWEEP_STRIDE)
|
| 96 |
+
)
|
| 97 |
+
steps_to_run = SWEEP_STEPS
|
| 98 |
+
else:
|
| 99 |
+
steps_to_run = [DEFAULT_STEPS]
|
| 100 |
+
results = {s: {m: [] for m in METRIC_NAMES} for s in steps_to_run}
|
| 101 |
+
|
| 102 |
+
model = NisabaRelief(seed=42, batch_size=4, weights_dir=Path(args.weights_dir))
|
| 103 |
+
|
| 104 |
+
progress = Progress(
|
| 105 |
+
TextColumn("[progress.description]{task.description}"),
|
| 106 |
+
BarColumn(),
|
| 107 |
+
MofNCompleteColumn(),
|
| 108 |
+
TimeElapsedColumn(),
|
| 109 |
+
TextColumn("[cyan]{task.fields[hs_number]}"),
|
| 110 |
+
)
|
| 111 |
+
task_desc = "Step Sweep" if args.sweep else "Evaluating"
|
| 112 |
+
task = progress.add_task(task_desc, total=len(rows), hs_number="")
|
| 113 |
+
|
| 114 |
+
start_time = time.monotonic()
|
| 115 |
+
with Live(
|
| 116 |
+
Group(progress, build_table(results)),
|
| 117 |
+
refresh_per_second=4,
|
| 118 |
+
transient=True,
|
| 119 |
+
) as live:
|
| 120 |
+
for n_done, row in enumerate(rows):
|
| 121 |
+
progress.update(task, hs_number=row["hs_number"])
|
| 122 |
+
gt = load_grayscale(row["msii"])
|
| 123 |
+
|
| 124 |
+
for num_steps in steps_to_run:
|
| 125 |
+
model.num_steps = num_steps
|
| 126 |
+
save_name = f"{row['hs_number']}_photo_fullview_{int(row['variation']):02d}-step{num_steps}.png"
|
| 127 |
+
save_path = EVALS_DIR / save_name
|
| 128 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
| 129 |
+
|
| 130 |
+
if save_path.exists():
|
| 131 |
+
pred_img = Image.open(save_path)
|
| 132 |
+
else:
|
| 133 |
+
pred_img = model.process(row["photo"], show_pbar=False)
|
| 134 |
+
pred_img.save(save_path)
|
| 135 |
+
|
| 136 |
+
pred = load_grayscale(pred_img)
|
| 137 |
+
pred_img.close()
|
| 138 |
+
|
| 139 |
+
if pred.shape != gt.shape:
|
| 140 |
+
pred = np.array(
|
| 141 |
+
Image.fromarray(pred).resize(
|
| 142 |
+
(gt.shape[1], gt.shape[0]), Image.LANCZOS
|
| 143 |
+
)
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
m = compute_metrics(pred, gt)
|
| 147 |
+
for name, val in m.items():
|
| 148 |
+
results[num_steps][name].append(val)
|
| 149 |
+
|
| 150 |
+
elapsed = time.monotonic() - start_time
|
| 151 |
+
live.update(
|
| 152 |
+
Group(progress, build_table(results, n_done + 1, len(rows), elapsed))
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
progress.advance(task)
|
| 156 |
+
|
| 157 |
+
final_elapsed = time.monotonic() - start_time
|
| 158 |
+
Console().print(build_table(results, len(rows), len(rows), final_elapsed))
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
main()
|
dev_scripts/process_images.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Process a directory of images through NisabaRelief and save as PNG."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from rich.console import Console
|
| 8 |
+
from rich.progress import (
|
| 9 |
+
BarColumn,
|
| 10 |
+
MofNCompleteColumn,
|
| 11 |
+
Progress,
|
| 12 |
+
ProgressColumn,
|
| 13 |
+
SpinnerColumn,
|
| 14 |
+
Task,
|
| 15 |
+
TextColumn,
|
| 16 |
+
TimeElapsedColumn,
|
| 17 |
+
)
|
| 18 |
+
from rich.text import Text
|
| 19 |
+
|
| 20 |
+
from nisaba_relief import NisabaRelief
|
| 21 |
+
from nisaba_relief.constants import MAX_TILE, MIN_IMAGE_DIMENSION
|
| 22 |
+
|
| 23 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 24 |
+
|
| 25 |
+
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SimpleTimeRemainingColumn(ProgressColumn):
|
| 29 |
+
"""Estimates remaining time from the average duration of the last 10 iterations.
|
| 30 |
+
|
| 31 |
+
Only recomputes when a new step completes so the display is stable.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, window: int = 10) -> None:
|
| 35 |
+
super().__init__()
|
| 36 |
+
self._last_completed: float = 0
|
| 37 |
+
self._last_elapsed: float = 0.0
|
| 38 |
+
self._durations: list[float] = []
|
| 39 |
+
self._window: int = window
|
| 40 |
+
self._cached: Text = Text("-:--:--", style="progress.remaining")
|
| 41 |
+
|
| 42 |
+
def render(self, task: Task) -> Text:
|
| 43 |
+
if task.completed <= self._last_completed:
|
| 44 |
+
return self._cached
|
| 45 |
+
elapsed = task.finished_time if task.finished else task.elapsed
|
| 46 |
+
if not elapsed or not task.completed:
|
| 47 |
+
self._last_completed = task.completed
|
| 48 |
+
self._cached = Text("-:--:--", style="progress.remaining")
|
| 49 |
+
return self._cached
|
| 50 |
+
step_duration = elapsed - self._last_elapsed
|
| 51 |
+
steps = task.completed - self._last_completed
|
| 52 |
+
if steps > 0 and self._last_completed > 0:
|
| 53 |
+
per_step = step_duration / steps
|
| 54 |
+
self._durations.append(per_step)
|
| 55 |
+
if len(self._durations) > self._window:
|
| 56 |
+
self._durations = self._durations[-self._window :]
|
| 57 |
+
self._last_completed = task.completed
|
| 58 |
+
self._last_elapsed = elapsed
|
| 59 |
+
if not self._durations:
|
| 60 |
+
self._cached = Text("-:--:--", style="progress.remaining")
|
| 61 |
+
return self._cached
|
| 62 |
+
avg = sum(self._durations) / len(self._durations)
|
| 63 |
+
remaining = task.total - task.completed
|
| 64 |
+
eta_seconds = avg * remaining
|
| 65 |
+
hours, rem = divmod(int(eta_seconds), 3600)
|
| 66 |
+
minutes, seconds = divmod(rem, 60)
|
| 67 |
+
if hours:
|
| 68 |
+
self._cached = Text(
|
| 69 |
+
f"{hours}:{minutes:02d}:{seconds:02d}", style="progress.remaining"
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
self._cached = Text(f"{minutes}:{seconds:02d}", style="progress.remaining")
|
| 73 |
+
return self._cached
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main():
|
| 77 |
+
parser = argparse.ArgumentParser(
|
| 78 |
+
description="Process images through NisabaRelief and save as PNG."
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--input-dir", type=Path, required=True, help="Source image directory"
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--output-dir", type=Path, required=True, help="Destination directory (created if needed)"
|
| 85 |
+
)
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--max-size", type=int, default=MAX_TILE * 5,
|
| 88 |
+
help="Downsample images larger than this before processing (default: %(default)s)",
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--min-size", type=int, default=1536,
|
| 92 |
+
help="Skip images where max dimension < this (default: %(default)s)",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument("--seed", type=int, default=None, help="Reproducibility seed")
|
| 95 |
+
parser.add_argument("--weights-dir", type=Path, default=None, help="Local weights directory")
|
| 96 |
+
parser.add_argument("--batch-size", type=int, default=None, help="Tile batch size")
|
| 97 |
+
parser.add_argument("--num-steps", type=int, default=2, help="Solver steps (default: %(default)s)")
|
| 98 |
+
parser.add_argument("--device", default="cuda", help="Torch device (default: %(default)s)")
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--overwrite", action="store_true", help="Re-process even if output file exists"
|
| 101 |
+
)
|
| 102 |
+
args = parser.parse_args()
|
| 103 |
+
|
| 104 |
+
console = Console()
|
| 105 |
+
|
| 106 |
+
input_dir: Path = args.input_dir
|
| 107 |
+
output_dir: Path = args.output_dir
|
| 108 |
+
|
| 109 |
+
if not input_dir.is_dir():
|
| 110 |
+
console.print(f"[red]Input directory not found:[/red] [cyan]{input_dir}[/cyan]")
|
| 111 |
+
return
|
| 112 |
+
|
| 113 |
+
input_images = sorted(
|
| 114 |
+
p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS
|
| 115 |
+
)
|
| 116 |
+
if not input_images:
|
| 117 |
+
console.print(f"[red]No images found in[/red] [cyan]{input_dir}[/cyan]")
|
| 118 |
+
return
|
| 119 |
+
|
| 120 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 121 |
+
|
| 122 |
+
to_process = []
|
| 123 |
+
skipped_existing = 0
|
| 124 |
+
skipped_small = 0
|
| 125 |
+
for src in input_images:
|
| 126 |
+
dst = output_dir / (src.stem + ".png")
|
| 127 |
+
if not args.overwrite and dst.exists():
|
| 128 |
+
skipped_existing += 1
|
| 129 |
+
continue
|
| 130 |
+
with Image.open(src) as img:
|
| 131 |
+
if max(img.size) < args.min_size or min(img.size) < MIN_IMAGE_DIMENSION:
|
| 132 |
+
skipped_small += 1
|
| 133 |
+
continue
|
| 134 |
+
to_process.append((src, dst))
|
| 135 |
+
|
| 136 |
+
if skipped_existing:
|
| 137 |
+
console.print(
|
| 138 |
+
f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
|
| 139 |
+
)
|
| 140 |
+
if skipped_small:
|
| 141 |
+
console.print(
|
| 142 |
+
f"[dim]Skipping {skipped_small} image(s) smaller than {args.min_size}px[/dim]"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if not to_process:
|
| 146 |
+
console.print("[green]All images already processed.[/green]")
|
| 147 |
+
return
|
| 148 |
+
|
| 149 |
+
console.print(
|
| 150 |
+
f"Processing [bold]{len(to_process)}[/bold] / {len(input_images)} images "
|
| 151 |
+
f"[dim]({input_dir} → {output_dir})[/dim]"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
model_kwargs = dict(num_steps=args.num_steps, device=args.device)
|
| 155 |
+
if args.seed is not None:
|
| 156 |
+
model_kwargs["seed"] = args.seed
|
| 157 |
+
if args.weights_dir is not None:
|
| 158 |
+
model_kwargs["weights_dir"] = args.weights_dir
|
| 159 |
+
if args.batch_size is not None:
|
| 160 |
+
model_kwargs["batch_size"] = args.batch_size
|
| 161 |
+
model = NisabaRelief(**model_kwargs)
|
| 162 |
+
|
| 163 |
+
progress = Progress(
|
| 164 |
+
SpinnerColumn(),
|
| 165 |
+
TextColumn("[progress.description]{task.description}"),
|
| 166 |
+
BarColumn(),
|
| 167 |
+
MofNCompleteColumn(),
|
| 168 |
+
TimeElapsedColumn(),
|
| 169 |
+
TextColumn("eta"),
|
| 170 |
+
SimpleTimeRemainingColumn(),
|
| 171 |
+
)
|
| 172 |
+
with progress:
|
| 173 |
+
task = progress.add_task("Processing", total=len(to_process))
|
| 174 |
+
for src, dst in to_process:
|
| 175 |
+
progress.update(task, description=f"[cyan]{src.name}[/cyan]")
|
| 176 |
+
image = Image.open(src).convert("RGB")
|
| 177 |
+
original_size = image.size
|
| 178 |
+
if max(image.size) > args.max_size:
|
| 179 |
+
scale = args.max_size / max(image.size)
|
| 180 |
+
new_size = (
|
| 181 |
+
round(image.width * scale) // 16 * 16,
|
| 182 |
+
round(image.height * scale) // 16 * 16,
|
| 183 |
+
)
|
| 184 |
+
image = image.resize(new_size, Image.LANCZOS)
|
| 185 |
+
result = model.process(image, show_pbar=False)
|
| 186 |
+
if result.size != original_size:
|
| 187 |
+
result = result.resize(original_size, Image.LANCZOS)
|
| 188 |
+
result.save(dst)
|
| 189 |
+
progress.advance(task)
|
| 190 |
+
|
| 191 |
+
console.print(
|
| 192 |
+
f"[green]Done.[/green] {len(to_process)} image(s) saved to [cyan]{output_dir}[/cyan]"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
|
| 197 |
+
main()
|
dev_scripts/util/load_val_dataset.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Load the validation set from the CuneiformPhotosMSII dataset.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datasets import load_dataset, Dataset
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
VAL_IDS_PATH = Path(__file__).parent.parent.parent / "data" / "val_tablet_ids.json"
|
| 11 |
+
VAL_IDS = set(json.load(open(VAL_IDS_PATH)))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_val_dataset() -> Dataset:
|
| 15 |
+
ds = load_dataset("boatbomber/CuneiformPhotosMSII", split="train", num_proc=4)
|
| 16 |
+
|
| 17 |
+
# First pass: parquet column projection reads only the ID strings, skipping image bytes
|
| 18 |
+
indices = [
|
| 19 |
+
i
|
| 20 |
+
for i, row in enumerate(ds.select_columns(["hs_number"]))
|
| 21 |
+
if row["hs_number"] in VAL_IDS
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
return ds.select(indices)
|
dev_scripts/util/metrics.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared metric computation for NisabaRelief evaluation scripts."""
|
| 2 |
+
|
| 3 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from image_similarity_measures.quality_metrics import (
|
| 7 |
+
rmse,
|
| 8 |
+
psnr,
|
| 9 |
+
sre,
|
| 10 |
+
)
|
| 11 |
+
import torch
|
| 12 |
+
from pytorch_msssim import ms_ssim as _pt_msssim
|
| 13 |
+
from util.psnr_hvsm import psnr_hvsm
|
| 14 |
+
|
| 15 |
+
DICE_THRESHOLD = 130
|
| 16 |
+
|
| 17 |
+
METRIC_NAMES = [
|
| 18 |
+
"dice",
|
| 19 |
+
"rmse",
|
| 20 |
+
"msssim",
|
| 21 |
+
"psnr",
|
| 22 |
+
"psnr_hvsm",
|
| 23 |
+
"sre",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
LABELS = {
|
| 27 |
+
"dice": "**Dice**",
|
| 28 |
+
"rmse": "RMSE",
|
| 29 |
+
"msssim": "MS-SSIM",
|
| 30 |
+
"psnr": "PSNR",
|
| 31 |
+
"psnr_hvsm": "PSNR-HVS-M",
|
| 32 |
+
"sre": "SRE",
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _to_tensor(arr: np.ndarray) -> torch.Tensor:
|
| 37 |
+
return torch.from_numpy(arr).float().unsqueeze(0).unsqueeze(0)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _msssim(gt: np.ndarray, pred: np.ndarray) -> float:
|
| 41 |
+
return _pt_msssim(
|
| 42 |
+
_to_tensor(gt), _to_tensor(pred), data_range=255, size_average=True
|
| 43 |
+
).item()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def compute_metrics(pred: np.ndarray, gt: np.ndarray) -> dict[str, float]:
|
| 47 |
+
"""Compute all metrics for a pair of equal-shape grayscale uint8 images."""
|
| 48 |
+
pred_3d = pred[:, :, np.newaxis]
|
| 49 |
+
gt_3d = gt[:, :, np.newaxis]
|
| 50 |
+
|
| 51 |
+
pred_bin = pred > DICE_THRESHOLD
|
| 52 |
+
gt_bin = gt > DICE_THRESHOLD
|
| 53 |
+
denom = pred_bin.sum() + gt_bin.sum()
|
| 54 |
+
dice = float(2 * np.logical_and(pred_bin, gt_bin).sum() / denom) if denom > 0 else 1.0
|
| 55 |
+
|
| 56 |
+
tasks = {
|
| 57 |
+
"rmse": lambda: rmse(gt_3d, pred_3d, max_p=255),
|
| 58 |
+
"psnr": lambda: psnr(gt_3d, pred_3d, max_p=255),
|
| 59 |
+
"msssim": lambda: _msssim(gt, pred),
|
| 60 |
+
"sre": lambda: sre(gt_3d, pred_3d),
|
| 61 |
+
"psnr_hvsm": lambda: psnr_hvsm(gt, pred)[0],
|
| 62 |
+
"dice": lambda: dice,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
|
| 66 |
+
futures = {name: executor.submit(fn) for name, fn in tasks.items()}
|
| 67 |
+
return {name: future.result() for name, future in futures.items()}
|
dev_scripts/util/psnr_hvsm.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PSNR-HVS-M and PSNR-HVS metrics (Ponomarenko et al., 2006/2007).
|
| 2 |
+
|
| 3 |
+
Direct Python translation of the MATLAB reference implementation at
|
| 4 |
+
https://www.ponomarenko.info/psnrhvsm.m
|
| 5 |
+
|
| 6 |
+
Returns (p_hvs_m, p_hvs) as a tuple.
|
| 7 |
+
Uses CUDA if available, otherwise falls back to CPU.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
_N = 8
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _make_dct_matrix() -> torch.Tensor:
|
| 19 |
+
"""8x8 orthonormal DCT-II matrix: D[0,n]=1/√N, D[k>0,n]=√(2/N)·cos(π·k·(2n+1)/(2N))."""
|
| 20 |
+
k = torch.arange(_N, dtype=torch.float64).unsqueeze(1)
|
| 21 |
+
n = torch.arange(_N, dtype=torch.float64).unsqueeze(0)
|
| 22 |
+
D = torch.cos(math.pi * k * (2 * n + 1) / (2 * _N))
|
| 23 |
+
D[0] = D[0] / math.sqrt(_N)
|
| 24 |
+
D[1:] = D[1:] * math.sqrt(2.0 / _N)
|
| 25 |
+
return D
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_DCT8 = _make_dct_matrix() # (8, 8), CPU float64
|
| 29 |
+
|
| 30 |
+
_CSF = torch.tensor(
|
| 31 |
+
[
|
| 32 |
+
[1.608443, 2.339554, 2.573509, 1.608443, 1.072295, 0.643377, 0.504610, 0.421887],
|
| 33 |
+
[2.144591, 2.144591, 1.838221, 1.354478, 0.989811, 0.443708, 0.428918, 0.467911],
|
| 34 |
+
[1.838221, 1.979622, 1.608443, 1.072295, 0.643377, 0.451493, 0.372972, 0.459555],
|
| 35 |
+
[1.838221, 1.513829, 1.169777, 0.887417, 0.504610, 0.295806, 0.321689, 0.415082],
|
| 36 |
+
[1.429727, 1.169777, 0.695543, 0.459555, 0.378457, 0.236102, 0.249855, 0.334222],
|
| 37 |
+
[1.072295, 0.735288, 0.467911, 0.402111, 0.317717, 0.247453, 0.227744, 0.279729],
|
| 38 |
+
[0.525206, 0.402111, 0.329937, 0.295806, 0.249855, 0.212687, 0.214459, 0.254803],
|
| 39 |
+
[0.357432, 0.279729, 0.270896, 0.262603, 0.229778, 0.257351, 0.249855, 0.259950],
|
| 40 |
+
],
|
| 41 |
+
dtype=torch.float64,
|
| 42 |
+
)
|
| 43 |
+
_MASKCOF = torch.tensor(
|
| 44 |
+
[
|
| 45 |
+
[0.390625, 0.826446, 1.000000, 0.390625, 0.173611, 0.062500, 0.038447, 0.026874],
|
| 46 |
+
[0.694444, 0.694444, 0.510204, 0.277008, 0.147929, 0.029727, 0.027778, 0.033058],
|
| 47 |
+
[0.510204, 0.591716, 0.390625, 0.173611, 0.062500, 0.030779, 0.021004, 0.031888],
|
| 48 |
+
[0.510204, 0.346021, 0.206612, 0.118906, 0.038447, 0.013212, 0.015625, 0.026015],
|
| 49 |
+
[0.308642, 0.206612, 0.073046, 0.031888, 0.021626, 0.008417, 0.009426, 0.016866],
|
| 50 |
+
[0.173611, 0.081633, 0.033058, 0.024414, 0.015242, 0.009246, 0.007831, 0.011815],
|
| 51 |
+
[0.041649, 0.024414, 0.016437, 0.013212, 0.009426, 0.006830, 0.006944, 0.009803],
|
| 52 |
+
[0.019290, 0.011815, 0.011080, 0.010412, 0.007972, 0.010000, 0.009426, 0.010203],
|
| 53 |
+
],
|
| 54 |
+
dtype=torch.float64,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# True everywhere except the DC coefficient at (0, 0)
|
| 58 |
+
_AC_MASK = torch.ones((_N, _N), dtype=torch.bool)
|
| 59 |
+
_AC_MASK[0, 0] = False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _vari_batch(blocks: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
"""Unbiased variance * N for a batch of blocks. (B, H, W) -> (B,)"""
|
| 64 |
+
flat = blocks.reshape(blocks.shape[0], -1)
|
| 65 |
+
return flat.var(dim=-1, correction=1) * flat.shape[-1]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _maskeff_batch(blocks: torch.Tensor, dct_blocks: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
"""Perceptual masking strength for a batch of 8x8 blocks. Returns (B,)."""
|
| 70 |
+
dev = blocks.device
|
| 71 |
+
ac = _AC_MASK.to(dev)
|
| 72 |
+
mc = _MASKCOF.to(dev)
|
| 73 |
+
|
| 74 |
+
m = (dct_blocks[:, ac] ** 2 * mc[ac]).sum(dim=-1) # (B,)
|
| 75 |
+
|
| 76 |
+
pop = _vari_batch(blocks)
|
| 77 |
+
quad = (
|
| 78 |
+
_vari_batch(blocks[:, :4, :4])
|
| 79 |
+
+ _vari_batch(blocks[:, :4, 4:])
|
| 80 |
+
+ _vari_batch(blocks[:, 4:, :4])
|
| 81 |
+
+ _vari_batch(blocks[:, 4:, 4:])
|
| 82 |
+
)
|
| 83 |
+
pop_ratio = torch.where(pop > 0, quad / pop, torch.zeros_like(pop))
|
| 84 |
+
return torch.sqrt(m * pop_ratio) / 32.0
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def psnr_hvsm(img1: np.ndarray, img2: np.ndarray) -> tuple[float, float]:
|
| 88 |
+
"""Return (PSNR-HVS-M, PSNR-HVS) for two uint8 grayscale arrays.
|
| 89 |
+
|
| 90 |
+
Direct translation of the MATLAB reference (Ponomarenko et al.).
|
| 91 |
+
Partial edge blocks are skipped (truncate to nearest multiple of 8).
|
| 92 |
+
"""
|
| 93 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 94 |
+
D = _DCT8.to(device)
|
| 95 |
+
csf = _CSF.to(device)
|
| 96 |
+
maskcof = _MASKCOF.to(device)
|
| 97 |
+
ac_mask = _AC_MASK.to(device)
|
| 98 |
+
|
| 99 |
+
a = torch.from_numpy(img1.astype(np.float64)).to(device)
|
| 100 |
+
b = torch.from_numpy(img2.astype(np.float64)).to(device)
|
| 101 |
+
|
| 102 |
+
h, w = a.shape
|
| 103 |
+
h = (h // 8) * 8
|
| 104 |
+
w = (w // 8) * 8
|
| 105 |
+
a = a[:h, :w]
|
| 106 |
+
b = b[:h, :w]
|
| 107 |
+
|
| 108 |
+
num_blocks = (h // 8) * (w // 8)
|
| 109 |
+
if num_blocks == 0:
|
| 110 |
+
return 100000.0, 100000.0
|
| 111 |
+
|
| 112 |
+
# Extract all non-overlapping 8x8 blocks: (B, 8, 8)
|
| 113 |
+
ba = a.unfold(0, 8, 8).unfold(1, 8, 8).contiguous().reshape(-1, 8, 8)
|
| 114 |
+
bb = b.unfold(0, 8, 8).unfold(1, 8, 8).contiguous().reshape(-1, 8, 8)
|
| 115 |
+
|
| 116 |
+
# 2D DCT-II (ortho) via separable matrix product: D @ block @ D.T
|
| 117 |
+
da = D @ ba @ D.t()
|
| 118 |
+
db = D @ bb @ D.t()
|
| 119 |
+
|
| 120 |
+
mask = torch.maximum(_maskeff_batch(ba, da), _maskeff_batch(bb, db)) # (B,)
|
| 121 |
+
|
| 122 |
+
diff = torch.abs(da - db) # (B, 8, 8)
|
| 123 |
+
|
| 124 |
+
# PSNR-HVS: CSF-weighted squared error (no masking)
|
| 125 |
+
S2 = float(((diff * csf) ** 2).sum())
|
| 126 |
+
|
| 127 |
+
# PSNR-HVS-M: soft-threshold AC coefficients by local mask, keep DC as-is
|
| 128 |
+
thresh = mask[:, None, None] / maskcof[None, :, :]
|
| 129 |
+
u = torch.where(ac_mask[None, :, :], torch.clamp(diff - thresh, min=0.0), diff)
|
| 130 |
+
S1 = float(((u * csf) ** 2).sum())
|
| 131 |
+
|
| 132 |
+
denom = num_blocks * 64
|
| 133 |
+
S1 /= denom
|
| 134 |
+
S2 /= denom
|
| 135 |
+
p_hvs_m = 100000.0 if S1 == 0 else float(10.0 * np.log10(255.0**2 / S1))
|
| 136 |
+
p_hvs = 100000.0 if S2 == 0 else float(10.0 * np.log10(255.0**2 / S2))
|
| 137 |
+
return p_hvs_m, p_hvs
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b83231048fc47de658665425368daadce6791bfd95456397b8b595aa0e5d05d
|
| 3 |
+
size 7751105712
|
nisaba_relief/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""NisabaRelief: Transform cuneiform tablet photos into MSII relief visualizations."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
| 4 |
+
|
| 5 |
+
from .model import NisabaRelief
|
| 6 |
+
|
| 7 |
+
__all__ = ["NisabaRelief"]
|
nisaba_relief/constants.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Named constants for NisabaRelief magic numbers."""
|
| 2 |
+
|
| 3 |
+
# Flux model processes images in 16×16 pixel patches
|
| 4 |
+
PATCH_SIZE = 16
|
| 5 |
+
|
| 6 |
+
# Tile size bounds: 12 patches (192px) to 64 patches (1024px)
|
| 7 |
+
MIN_TILE = PATCH_SIZE * 12 # 192
|
| 8 |
+
MAX_TILE = PATCH_SIZE * 64 # 1024
|
| 9 |
+
|
| 10 |
+
# Aims for ~4 tiles along the longest axis when computing tile size
|
| 11 |
+
TARGET_TILES_PER_SIDE = 4
|
| 12 |
+
|
| 13 |
+
# Overlap is 1/8 of the tile size, giving a smooth cosine blend region
|
| 14 |
+
TILE_OVERLAP_DIVISOR = 8
|
| 15 |
+
|
| 16 |
+
# Smallest accepted input side in pixels
|
| 17 |
+
MIN_IMAGE_DIMENSION = MIN_TILE * 2
|
| 18 |
+
|
| 19 |
+
# Maximum allowed aspect ratio (width:height or height:width)
|
| 20 |
+
MAX_ASPECT_RATIO = 8.0
|
| 21 |
+
|
| 22 |
+
# Maximum size (px) for the global context thumbnail
|
| 23 |
+
MAX_GLOBAL_CONTEXT_SIZE = 128
|
| 24 |
+
|
| 25 |
+
# Positional sequence ID for conditioning tokens (image being processed)
|
| 26 |
+
COND_SEQ_ID = 10
|
| 27 |
+
|
| 28 |
+
# Positional sequence ID for global context tokens (thumbnail overview)
|
| 29 |
+
GLOBAL_CTX_ID = 20
|
| 30 |
+
|
| 31 |
+
# Number of latent channels in the Flux model's latent space
|
| 32 |
+
LATENT_CHANNELS = 128
|
| 33 |
+
|
| 34 |
+
# Dynamic batch_size constants. Determined empirically on an RTX 3090.
|
| 35 |
+
MAX_BATCH_SIZE = 16
|
| 36 |
+
MIN_BATCH_SIZE = 1
|
| 37 |
+
VRAM_MB_PER_PIXEL = 0.0035
|
| 38 |
+
VRAM_FIXED_OVERHEAD_MB = 15.0
|
| 39 |
+
VRAM_HEADROOM_MB = 1024.0
|
| 40 |
+
|
| 41 |
+
# Divisor for AE decoder sub-batching (decoder needs more VRAM than denoiser)
|
| 42 |
+
DECODE_BATCH_SIZE_DIVISOR = 5
|
nisaba_relief/flux/__init__.py
ADDED
|
File without changes
|
nisaba_relief/flux/autoencoder.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class AutoEncoderParams:
|
| 12 |
+
resolution: int = 256
|
| 13 |
+
in_channels: int = 3
|
| 14 |
+
ch: int = 128
|
| 15 |
+
out_ch: int = 3
|
| 16 |
+
ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
|
| 17 |
+
num_res_blocks: int = 2
|
| 18 |
+
z_channels: int = 32
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AttnBlock(nn.Module):
|
| 22 |
+
def __init__(self, in_channels: int):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.in_channels = in_channels
|
| 25 |
+
|
| 26 |
+
self.norm = nn.GroupNorm(
|
| 27 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 31 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 32 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 33 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 34 |
+
|
| 35 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 36 |
+
h_ = self.norm(h_)
|
| 37 |
+
q = self.q(h_)
|
| 38 |
+
k = self.k(h_)
|
| 39 |
+
v = self.v(h_)
|
| 40 |
+
|
| 41 |
+
b, c, h, w = q.shape
|
| 42 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 43 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 44 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 45 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 46 |
+
|
| 47 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
| 48 |
+
|
| 49 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 50 |
+
return x + self.proj_out(self.attention(x))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ResnetBlock(nn.Module):
|
| 54 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.in_channels = in_channels
|
| 57 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 58 |
+
self.out_channels = out_channels
|
| 59 |
+
|
| 60 |
+
self.norm1 = nn.GroupNorm(
|
| 61 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
| 62 |
+
)
|
| 63 |
+
self.conv1 = nn.Conv2d(
|
| 64 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 65 |
+
)
|
| 66 |
+
self.norm2 = nn.GroupNorm(
|
| 67 |
+
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
|
| 68 |
+
)
|
| 69 |
+
self.conv2 = nn.Conv2d(
|
| 70 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 71 |
+
)
|
| 72 |
+
if self.in_channels != self.out_channels:
|
| 73 |
+
self.nin_shortcut = nn.Conv2d(
|
| 74 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 78 |
+
h = x
|
| 79 |
+
h = self.norm1(h)
|
| 80 |
+
h = F.silu(h)
|
| 81 |
+
h = self.conv1(h)
|
| 82 |
+
|
| 83 |
+
h = self.norm2(h)
|
| 84 |
+
h = F.silu(h)
|
| 85 |
+
h = self.conv2(h)
|
| 86 |
+
|
| 87 |
+
if self.in_channels != self.out_channels:
|
| 88 |
+
x = self.nin_shortcut(x)
|
| 89 |
+
|
| 90 |
+
return x + h
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class Downsample(nn.Module):
|
| 94 |
+
def __init__(self, in_channels: int):
|
| 95 |
+
super().__init__()
|
| 96 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 97 |
+
self.conv = nn.Conv2d(
|
| 98 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 102 |
+
pad = (0, 1, 0, 1)
|
| 103 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 104 |
+
x = self.conv(x)
|
| 105 |
+
return x
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Upsample(nn.Module):
|
| 109 |
+
def __init__(self, in_channels: int):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.conv = nn.Conv2d(
|
| 112 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 116 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 117 |
+
x = self.conv(x)
|
| 118 |
+
return x
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class Encoder(nn.Module):
|
| 122 |
+
def __init__(
|
| 123 |
+
self,
|
| 124 |
+
resolution: int,
|
| 125 |
+
in_channels: int,
|
| 126 |
+
ch: int,
|
| 127 |
+
ch_mult: list[int],
|
| 128 |
+
num_res_blocks: int,
|
| 129 |
+
z_channels: int,
|
| 130 |
+
):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
|
| 133 |
+
self.ch = ch
|
| 134 |
+
self.num_resolutions = len(ch_mult)
|
| 135 |
+
self.num_res_blocks = num_res_blocks
|
| 136 |
+
self.resolution = resolution
|
| 137 |
+
self.in_channels = in_channels
|
| 138 |
+
# downsampling
|
| 139 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 140 |
+
|
| 141 |
+
curr_res = resolution
|
| 142 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 143 |
+
self.in_ch_mult = in_ch_mult
|
| 144 |
+
self.down = nn.ModuleList()
|
| 145 |
+
block_in = self.ch
|
| 146 |
+
for i_level in range(self.num_resolutions):
|
| 147 |
+
block = nn.ModuleList()
|
| 148 |
+
attn = nn.ModuleList()
|
| 149 |
+
block_in = ch * in_ch_mult[i_level]
|
| 150 |
+
block_out = ch * ch_mult[i_level]
|
| 151 |
+
for _ in range(self.num_res_blocks):
|
| 152 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 153 |
+
block_in = block_out
|
| 154 |
+
down = nn.Module()
|
| 155 |
+
down.block = block
|
| 156 |
+
down.attn = attn
|
| 157 |
+
if i_level != self.num_resolutions - 1:
|
| 158 |
+
down.downsample = Downsample(block_in)
|
| 159 |
+
curr_res = curr_res // 2
|
| 160 |
+
self.down.append(down)
|
| 161 |
+
|
| 162 |
+
# middle
|
| 163 |
+
self.mid = nn.Module()
|
| 164 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 165 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 166 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 167 |
+
|
| 168 |
+
# end
|
| 169 |
+
self.norm_out = nn.GroupNorm(
|
| 170 |
+
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
| 171 |
+
)
|
| 172 |
+
self.conv_out = nn.Conv2d(
|
| 173 |
+
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 177 |
+
# downsampling
|
| 178 |
+
h = self.conv_in(x)
|
| 179 |
+
for i_level in range(self.num_resolutions):
|
| 180 |
+
for i_block in range(self.num_res_blocks):
|
| 181 |
+
h = self.down[i_level].block[i_block](h)
|
| 182 |
+
if len(self.down[i_level].attn) > 0:
|
| 183 |
+
h = self.down[i_level].attn[i_block](h)
|
| 184 |
+
if i_level != self.num_resolutions - 1:
|
| 185 |
+
h = self.down[i_level].downsample(h)
|
| 186 |
+
|
| 187 |
+
# middle
|
| 188 |
+
h = self.mid.block_1(h)
|
| 189 |
+
h = self.mid.attn_1(h)
|
| 190 |
+
h = self.mid.block_2(h)
|
| 191 |
+
# end
|
| 192 |
+
h = self.norm_out(h)
|
| 193 |
+
h = F.silu(h)
|
| 194 |
+
h = self.conv_out(h)
|
| 195 |
+
h = self.quant_conv(h)
|
| 196 |
+
return h
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class Decoder(nn.Module):
|
| 200 |
+
def __init__(
|
| 201 |
+
self,
|
| 202 |
+
ch: int,
|
| 203 |
+
out_ch: int,
|
| 204 |
+
ch_mult: list[int],
|
| 205 |
+
num_res_blocks: int,
|
| 206 |
+
in_channels: int,
|
| 207 |
+
resolution: int,
|
| 208 |
+
z_channels: int,
|
| 209 |
+
):
|
| 210 |
+
super().__init__()
|
| 211 |
+
self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
|
| 212 |
+
self.ch = ch
|
| 213 |
+
self.num_resolutions = len(ch_mult)
|
| 214 |
+
self.num_res_blocks = num_res_blocks
|
| 215 |
+
self.resolution = resolution
|
| 216 |
+
self.in_channels = in_channels
|
| 217 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 218 |
+
|
| 219 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 220 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 221 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 222 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 223 |
+
|
| 224 |
+
# z to block_in
|
| 225 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 226 |
+
|
| 227 |
+
# middle
|
| 228 |
+
self.mid = nn.Module()
|
| 229 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 230 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 231 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 232 |
+
|
| 233 |
+
# upsampling
|
| 234 |
+
self.up = nn.ModuleList()
|
| 235 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 236 |
+
block = nn.ModuleList()
|
| 237 |
+
attn = nn.ModuleList()
|
| 238 |
+
block_out = ch * ch_mult[i_level]
|
| 239 |
+
for _ in range(self.num_res_blocks + 1):
|
| 240 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 241 |
+
block_in = block_out
|
| 242 |
+
up = nn.Module()
|
| 243 |
+
up.block = block
|
| 244 |
+
up.attn = attn
|
| 245 |
+
if i_level != 0:
|
| 246 |
+
up.upsample = Upsample(block_in)
|
| 247 |
+
curr_res = curr_res * 2
|
| 248 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 249 |
+
|
| 250 |
+
# end
|
| 251 |
+
self.norm_out = nn.GroupNorm(
|
| 252 |
+
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
| 253 |
+
)
|
| 254 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 255 |
+
|
| 256 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 257 |
+
z = self.post_quant_conv(z)
|
| 258 |
+
|
| 259 |
+
# get dtype for proper tracing
|
| 260 |
+
upscale_dtype = next(self.up.parameters()).dtype
|
| 261 |
+
|
| 262 |
+
# z to block_in
|
| 263 |
+
h = self.conv_in(z)
|
| 264 |
+
|
| 265 |
+
# middle
|
| 266 |
+
h = self.mid.block_1(h)
|
| 267 |
+
h = self.mid.attn_1(h)
|
| 268 |
+
h = self.mid.block_2(h)
|
| 269 |
+
|
| 270 |
+
# cast to proper dtype
|
| 271 |
+
h = h.to(upscale_dtype)
|
| 272 |
+
# upsampling
|
| 273 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 274 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 275 |
+
h = self.up[i_level].block[i_block](h)
|
| 276 |
+
if len(self.up[i_level].attn) > 0:
|
| 277 |
+
h = self.up[i_level].attn[i_block](h)
|
| 278 |
+
if i_level != 0:
|
| 279 |
+
h = self.up[i_level].upsample(h)
|
| 280 |
+
|
| 281 |
+
# end
|
| 282 |
+
h = self.norm_out(h)
|
| 283 |
+
h = F.silu(h)
|
| 284 |
+
h = self.conv_out(h)
|
| 285 |
+
return h
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class AutoEncoder(nn.Module):
|
| 289 |
+
def __init__(self, params: AutoEncoderParams = AutoEncoderParams()):
|
| 290 |
+
super().__init__()
|
| 291 |
+
self.params = params
|
| 292 |
+
self.encoder = Encoder(
|
| 293 |
+
resolution=params.resolution,
|
| 294 |
+
in_channels=params.in_channels,
|
| 295 |
+
ch=params.ch,
|
| 296 |
+
ch_mult=params.ch_mult,
|
| 297 |
+
num_res_blocks=params.num_res_blocks,
|
| 298 |
+
z_channels=params.z_channels,
|
| 299 |
+
)
|
| 300 |
+
self.decoder = Decoder(
|
| 301 |
+
resolution=params.resolution,
|
| 302 |
+
in_channels=params.in_channels,
|
| 303 |
+
ch=params.ch,
|
| 304 |
+
out_ch=params.out_ch,
|
| 305 |
+
ch_mult=params.ch_mult,
|
| 306 |
+
num_res_blocks=params.num_res_blocks,
|
| 307 |
+
z_channels=params.z_channels,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
self.bn_eps = 1e-4
|
| 311 |
+
self.bn_momentum = 0.1
|
| 312 |
+
self.ps = [2, 2]
|
| 313 |
+
self.bn = torch.nn.BatchNorm2d(
|
| 314 |
+
math.prod(self.ps) * params.z_channels,
|
| 315 |
+
eps=self.bn_eps,
|
| 316 |
+
momentum=self.bn_momentum,
|
| 317 |
+
affine=False,
|
| 318 |
+
track_running_stats=True,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
def normalize(self, z: Tensor) -> Tensor:
|
| 322 |
+
return self.bn(z)
|
| 323 |
+
|
| 324 |
+
def inv_normalize(self, z: Tensor) -> Tensor:
|
| 325 |
+
s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
|
| 326 |
+
m = self.bn.running_mean.view(1, -1, 1, 1)
|
| 327 |
+
return z * s + m
|
| 328 |
+
|
| 329 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 330 |
+
moments = self.encoder(x)
|
| 331 |
+
mean = torch.chunk(moments, 2, dim=1)[0]
|
| 332 |
+
|
| 333 |
+
z = rearrange(
|
| 334 |
+
mean,
|
| 335 |
+
"... c (i pi) (j pj) -> ... (c pi pj) i j",
|
| 336 |
+
pi=self.ps[0],
|
| 337 |
+
pj=self.ps[1],
|
| 338 |
+
)
|
| 339 |
+
z = self.normalize(z)
|
| 340 |
+
return z
|
| 341 |
+
|
| 342 |
+
def decode(self, z: Tensor) -> Tensor:
|
| 343 |
+
z = self.inv_normalize(z)
|
| 344 |
+
z = rearrange(
|
| 345 |
+
z,
|
| 346 |
+
"... (c pi pj) i j -> ... c (i pi) (j pj)",
|
| 347 |
+
pi=self.ps[0],
|
| 348 |
+
pj=self.ps[1],
|
| 349 |
+
)
|
| 350 |
+
dec = self.decoder(z.to(next(self.decoder.parameters()).dtype))
|
| 351 |
+
return dec
|
nisaba_relief/flux/layers.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Building-block nn.Module primitives and standalone functions for Flux2."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
| 11 |
+
t = time_factor * t
|
| 12 |
+
half = dim // 2
|
| 13 |
+
freqs = torch.exp(
|
| 14 |
+
-math.log(max_period)
|
| 15 |
+
* torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
|
| 16 |
+
/ half
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
args = t[:, None].float() * freqs[None]
|
| 20 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 21 |
+
if dim % 2:
|
| 22 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 23 |
+
if torch.is_floating_point(t):
|
| 24 |
+
embedding = embedding.to(t)
|
| 25 |
+
return embedding
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 29 |
+
q, k = apply_rope(q, k, pe)
|
| 30 |
+
|
| 31 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
| 32 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
| 33 |
+
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
| 38 |
+
assert dim % 2 == 0
|
| 39 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
|
| 40 |
+
omega = 1.0 / (theta**scale)
|
| 41 |
+
out = torch.einsum("...n,d->...nd", pos.float(), omega)
|
| 42 |
+
out = torch.stack(
|
| 43 |
+
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)],
|
| 44 |
+
dim=-1,
|
| 45 |
+
)
|
| 46 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
| 47 |
+
return out
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
| 51 |
+
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
|
| 52 |
+
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
|
| 53 |
+
freqs_cis = freqs_cis.to(xq.dtype)
|
| 54 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
| 55 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
| 56 |
+
return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class SelfAttention(nn.Module):
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
dim: int,
|
| 63 |
+
num_heads: int = 8,
|
| 64 |
+
):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.num_heads = num_heads
|
| 67 |
+
head_dim = dim // num_heads
|
| 68 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
| 69 |
+
|
| 70 |
+
self.norm = QKNorm(head_dim)
|
| 71 |
+
self.proj = nn.Linear(dim, dim, bias=False)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class SiLUActivation(nn.Module):
|
| 75 |
+
def __init__(self):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.gate_fn = nn.SiLU()
|
| 78 |
+
|
| 79 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 80 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 81 |
+
return self.gate_fn(x1) * x2
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class Modulation(nn.Module):
|
| 85 |
+
def __init__(self, dim: int, double: bool, disable_bias: bool = False):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.is_double = double
|
| 88 |
+
self.multiplier = 6 if double else 3
|
| 89 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
|
| 90 |
+
|
| 91 |
+
def forward(self, vec: torch.Tensor):
|
| 92 |
+
out = self.lin(nn.functional.silu(vec))
|
| 93 |
+
if out.ndim == 2:
|
| 94 |
+
out = out[:, None, :]
|
| 95 |
+
out = out.chunk(self.multiplier, dim=-1)
|
| 96 |
+
return out[:3], out[3:] if self.is_double else None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class LastLayer(nn.Module):
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
hidden_size: int,
|
| 103 |
+
out_channels: int,
|
| 104 |
+
):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 107 |
+
self.linear = nn.Linear(hidden_size, out_channels, bias=False)
|
| 108 |
+
self.adaLN_modulation = nn.Sequential(
|
| 109 |
+
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False)
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
|
| 113 |
+
mod = self.adaLN_modulation(vec)
|
| 114 |
+
shift, scale = mod.chunk(2, dim=-1)
|
| 115 |
+
if shift.ndim == 2:
|
| 116 |
+
shift = shift[:, None, :]
|
| 117 |
+
scale = scale[:, None, :]
|
| 118 |
+
x = (1 + scale) * self.norm_final(x) + shift
|
| 119 |
+
x = self.linear(x)
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class SingleStreamBlock(nn.Module):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
hidden_size: int,
|
| 127 |
+
num_heads: int,
|
| 128 |
+
mlp_ratio: float = 4.0,
|
| 129 |
+
):
|
| 130 |
+
super().__init__()
|
| 131 |
+
|
| 132 |
+
self.hidden_dim = hidden_size
|
| 133 |
+
self.num_heads = num_heads
|
| 134 |
+
head_dim = hidden_size // num_heads
|
| 135 |
+
self.scale = head_dim**-0.5
|
| 136 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 137 |
+
self.mlp_mult_factor = 2
|
| 138 |
+
|
| 139 |
+
self.linear1 = nn.Linear(
|
| 140 |
+
hidden_size,
|
| 141 |
+
hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
|
| 142 |
+
bias=False,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.linear2 = nn.Linear(
|
| 146 |
+
hidden_size + self.mlp_hidden_dim, hidden_size, bias=False
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.norm = QKNorm(head_dim)
|
| 150 |
+
|
| 151 |
+
self.hidden_size = hidden_size
|
| 152 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 153 |
+
|
| 154 |
+
self.mlp_act = SiLUActivation()
|
| 155 |
+
|
| 156 |
+
def forward(
|
| 157 |
+
self,
|
| 158 |
+
x: Tensor,
|
| 159 |
+
pe: Tensor,
|
| 160 |
+
mod: tuple[Tensor, Tensor],
|
| 161 |
+
) -> Tensor:
|
| 162 |
+
mod_shift, mod_scale, mod_gate = mod
|
| 163 |
+
x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
|
| 164 |
+
|
| 165 |
+
qkv, mlp = torch.split(
|
| 166 |
+
self.linear1(x_mod),
|
| 167 |
+
[3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
|
| 168 |
+
dim=-1,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
| 172 |
+
q, k = self.norm(q, k, v)
|
| 173 |
+
|
| 174 |
+
attn = attention(q, k, v, pe)
|
| 175 |
+
|
| 176 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
| 177 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
| 178 |
+
return x + mod_gate * output
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class DoubleStreamBlock(nn.Module):
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
hidden_size: int,
|
| 185 |
+
num_heads: int,
|
| 186 |
+
mlp_ratio: float,
|
| 187 |
+
):
|
| 188 |
+
super().__init__()
|
| 189 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 190 |
+
self.num_heads = num_heads
|
| 191 |
+
assert hidden_size % num_heads == 0, (
|
| 192 |
+
f"{hidden_size=} must be divisible by {num_heads=}"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
self.hidden_size = hidden_size
|
| 196 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 197 |
+
self.mlp_mult_factor = 2
|
| 198 |
+
|
| 199 |
+
self.img_attn = SelfAttention(
|
| 200 |
+
dim=hidden_size,
|
| 201 |
+
num_heads=num_heads,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 205 |
+
self.img_mlp = nn.Sequential(
|
| 206 |
+
nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
|
| 207 |
+
SiLUActivation(),
|
| 208 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 212 |
+
self.txt_attn = SelfAttention(
|
| 213 |
+
dim=hidden_size,
|
| 214 |
+
num_heads=num_heads,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 218 |
+
self.txt_mlp = nn.Sequential(
|
| 219 |
+
nn.Linear(
|
| 220 |
+
hidden_size,
|
| 221 |
+
mlp_hidden_dim * self.mlp_mult_factor,
|
| 222 |
+
bias=False,
|
| 223 |
+
),
|
| 224 |
+
SiLUActivation(),
|
| 225 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def forward(
|
| 229 |
+
self,
|
| 230 |
+
img: Tensor,
|
| 231 |
+
txt: Tensor,
|
| 232 |
+
pe: Tensor,
|
| 233 |
+
pe_ctx: Tensor,
|
| 234 |
+
mod_img: tuple[Tensor, Tensor],
|
| 235 |
+
mod_txt: tuple[Tensor, Tensor],
|
| 236 |
+
) -> tuple[Tensor, Tensor]:
|
| 237 |
+
img_mod1, img_mod2 = mod_img
|
| 238 |
+
txt_mod1, txt_mod2 = mod_txt
|
| 239 |
+
|
| 240 |
+
img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
|
| 241 |
+
img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
|
| 242 |
+
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
|
| 243 |
+
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2
|
| 244 |
+
|
| 245 |
+
# prepare image for attention
|
| 246 |
+
img_modulated = self.img_norm1(img)
|
| 247 |
+
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
|
| 248 |
+
|
| 249 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
| 250 |
+
img_q, img_k, img_v = rearrange(
|
| 251 |
+
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
| 252 |
+
)
|
| 253 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
| 254 |
+
|
| 255 |
+
# prepare txt for attention
|
| 256 |
+
txt_modulated = self.txt_norm1(txt)
|
| 257 |
+
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
|
| 258 |
+
|
| 259 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
| 260 |
+
txt_q, txt_k, txt_v = rearrange(
|
| 261 |
+
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
| 262 |
+
)
|
| 263 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
| 264 |
+
|
| 265 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
| 266 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
| 267 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
| 268 |
+
|
| 269 |
+
pe = torch.cat((pe_ctx, pe), dim=2)
|
| 270 |
+
attn = attention(q, k, v, pe)
|
| 271 |
+
txt_attn, img_attn = (
|
| 272 |
+
attn[:, : txt_q.shape[2]],
|
| 273 |
+
attn[:, txt_q.shape[2] :],
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
# calculate the img blocks
|
| 277 |
+
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
|
| 278 |
+
img = img + img_mod2_gate * self.img_mlp(
|
| 279 |
+
(1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# calculate the txt blocks
|
| 283 |
+
txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
|
| 284 |
+
txt = txt + txt_mod2_gate * self.txt_mlp(
|
| 285 |
+
(1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift
|
| 286 |
+
)
|
| 287 |
+
return img, txt
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class MLPEmbedder(nn.Module):
|
| 291 |
+
def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
|
| 292 |
+
super().__init__()
|
| 293 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
|
| 294 |
+
self.silu = nn.SiLU()
|
| 295 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)
|
| 296 |
+
|
| 297 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 298 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class EmbedND(nn.Module):
|
| 302 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.dim = dim
|
| 305 |
+
self.theta = theta
|
| 306 |
+
self.axes_dim = axes_dim
|
| 307 |
+
|
| 308 |
+
def forward(self, ids: Tensor) -> Tensor:
|
| 309 |
+
emb = torch.cat(
|
| 310 |
+
[
|
| 311 |
+
rope(ids[..., i], self.axes_dim[i], self.theta)
|
| 312 |
+
for i in range(len(self.axes_dim))
|
| 313 |
+
],
|
| 314 |
+
dim=-3,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
return emb.unsqueeze(1)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class RMSNorm(torch.nn.Module):
|
| 321 |
+
def __init__(self, dim: int):
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
| 324 |
+
|
| 325 |
+
def forward(self, x: Tensor):
|
| 326 |
+
x_dtype = x.dtype
|
| 327 |
+
x = x.float()
|
| 328 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
| 329 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class QKNorm(torch.nn.Module):
|
| 333 |
+
def __init__(self, dim: int):
|
| 334 |
+
super().__init__()
|
| 335 |
+
self.query_norm = RMSNorm(dim)
|
| 336 |
+
self.key_norm = RMSNorm(dim)
|
| 337 |
+
|
| 338 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 339 |
+
q = self.query_norm(q)
|
| 340 |
+
k = self.key_norm(k)
|
| 341 |
+
return q.to(v), k.to(v)
|
nisaba_relief/flux/model.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
|
| 6 |
+
from .layers import (
|
| 7 |
+
DoubleStreamBlock,
|
| 8 |
+
EmbedND,
|
| 9 |
+
LastLayer,
|
| 10 |
+
MLPEmbedder,
|
| 11 |
+
Modulation,
|
| 12 |
+
SingleStreamBlock,
|
| 13 |
+
timestep_embedding,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class Klein4BParams:
|
| 19 |
+
in_channels: int = 128
|
| 20 |
+
context_in_dim: int = 7680
|
| 21 |
+
hidden_size: int = 3072
|
| 22 |
+
num_heads: int = 24
|
| 23 |
+
depth: int = 5
|
| 24 |
+
depth_single_blocks: int = 20
|
| 25 |
+
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
|
| 26 |
+
theta: int = 2000
|
| 27 |
+
mlp_ratio: float = 3.0
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Flux2(nn.Module):
|
| 31 |
+
def __init__(self, params: Klein4BParams = Klein4BParams()):
|
| 32 |
+
super().__init__()
|
| 33 |
+
|
| 34 |
+
self.in_channels = params.in_channels
|
| 35 |
+
self.out_channels = params.in_channels
|
| 36 |
+
if params.hidden_size % params.num_heads != 0:
|
| 37 |
+
raise ValueError(
|
| 38 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
| 39 |
+
)
|
| 40 |
+
pe_dim = params.hidden_size // params.num_heads
|
| 41 |
+
if sum(params.axes_dim) != pe_dim:
|
| 42 |
+
raise ValueError(
|
| 43 |
+
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
| 44 |
+
)
|
| 45 |
+
self.hidden_size = params.hidden_size
|
| 46 |
+
self.num_heads = params.num_heads
|
| 47 |
+
self.pe_embedder = EmbedND(
|
| 48 |
+
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
| 49 |
+
)
|
| 50 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
|
| 51 |
+
self.time_in = MLPEmbedder(
|
| 52 |
+
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
| 53 |
+
)
|
| 54 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
|
| 55 |
+
|
| 56 |
+
self.double_blocks = nn.ModuleList(
|
| 57 |
+
[
|
| 58 |
+
DoubleStreamBlock(
|
| 59 |
+
self.hidden_size,
|
| 60 |
+
self.num_heads,
|
| 61 |
+
mlp_ratio=params.mlp_ratio,
|
| 62 |
+
)
|
| 63 |
+
for _ in range(params.depth)
|
| 64 |
+
]
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
self.single_blocks = nn.ModuleList(
|
| 68 |
+
[
|
| 69 |
+
SingleStreamBlock(
|
| 70 |
+
self.hidden_size,
|
| 71 |
+
self.num_heads,
|
| 72 |
+
mlp_ratio=params.mlp_ratio,
|
| 73 |
+
)
|
| 74 |
+
for _ in range(params.depth_single_blocks)
|
| 75 |
+
]
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.double_stream_modulation_img = Modulation(
|
| 79 |
+
self.hidden_size,
|
| 80 |
+
double=True,
|
| 81 |
+
disable_bias=True,
|
| 82 |
+
)
|
| 83 |
+
self.double_stream_modulation_txt = Modulation(
|
| 84 |
+
self.hidden_size,
|
| 85 |
+
double=True,
|
| 86 |
+
disable_bias=True,
|
| 87 |
+
)
|
| 88 |
+
self.single_stream_modulation = Modulation(
|
| 89 |
+
self.hidden_size, double=False, disable_bias=True
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.final_layer = LastLayer(
|
| 93 |
+
self.hidden_size,
|
| 94 |
+
self.out_channels,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def forward(
|
| 98 |
+
self,
|
| 99 |
+
x: Tensor,
|
| 100 |
+
x_ids: Tensor,
|
| 101 |
+
timesteps: Tensor,
|
| 102 |
+
ctx: Tensor,
|
| 103 |
+
ctx_ids: Tensor,
|
| 104 |
+
pe_x: Tensor | None = None,
|
| 105 |
+
pe_ctx: Tensor | None = None,
|
| 106 |
+
) -> Tensor:
|
| 107 |
+
num_txt_tokens = ctx.shape[1]
|
| 108 |
+
|
| 109 |
+
timestep_emb = timestep_embedding(timesteps, 256)
|
| 110 |
+
vec = self.time_in(timestep_emb)
|
| 111 |
+
|
| 112 |
+
double_block_mod_img = self.double_stream_modulation_img(vec)
|
| 113 |
+
double_block_mod_txt = self.double_stream_modulation_txt(vec)
|
| 114 |
+
single_block_mod, _ = self.single_stream_modulation(vec)
|
| 115 |
+
|
| 116 |
+
img = self.img_in(x)
|
| 117 |
+
txt = self.txt_in(ctx)
|
| 118 |
+
|
| 119 |
+
if pe_x is None:
|
| 120 |
+
pe_x = self.pe_embedder(x_ids)
|
| 121 |
+
if pe_ctx is None:
|
| 122 |
+
pe_ctx = self.pe_embedder(ctx_ids)
|
| 123 |
+
|
| 124 |
+
for block in self.double_blocks:
|
| 125 |
+
img, txt = block(
|
| 126 |
+
img,
|
| 127 |
+
txt,
|
| 128 |
+
pe_x,
|
| 129 |
+
pe_ctx,
|
| 130 |
+
double_block_mod_img,
|
| 131 |
+
double_block_mod_txt,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
img = torch.cat((txt, img), dim=1)
|
| 135 |
+
pe = torch.cat((pe_ctx, pe_x), dim=2)
|
| 136 |
+
|
| 137 |
+
for block in self.single_blocks:
|
| 138 |
+
img = block(
|
| 139 |
+
img,
|
| 140 |
+
pe,
|
| 141 |
+
single_block_mod,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
img = img[:, num_txt_tokens:, ...]
|
| 145 |
+
|
| 146 |
+
img = self.final_layer(img, vec)
|
| 147 |
+
return img
|
nisaba_relief/flux/sampling.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
from .model import Flux2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def prc_img_batch(x: Tensor) -> tuple[Tensor, Tensor]:
|
| 11 |
+
b, _, h, w = x.shape
|
| 12 |
+
x_ids = torch.cartesian_prod(
|
| 13 |
+
torch.arange(1),
|
| 14 |
+
torch.arange(h),
|
| 15 |
+
torch.arange(w),
|
| 16 |
+
torch.arange(1),
|
| 17 |
+
)
|
| 18 |
+
x_ids = x_ids.unsqueeze(0).expand(b, -1, -1)
|
| 19 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
| 20 |
+
return x, x_ids.to(x.device)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
|
| 24 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
| 28 |
+
a1, b1 = 8.73809524e-05, 1.89833333
|
| 29 |
+
a2, b2 = 0.00016927, 0.45666666
|
| 30 |
+
|
| 31 |
+
if image_seq_len > 4300:
|
| 32 |
+
mu = a2 * image_seq_len + b2
|
| 33 |
+
return float(mu)
|
| 34 |
+
|
| 35 |
+
m_200 = a2 * image_seq_len + b2
|
| 36 |
+
m_10 = a1 * image_seq_len + b1
|
| 37 |
+
|
| 38 |
+
a = (m_200 - m_10) / 190.0
|
| 39 |
+
b = m_200 - 200.0 * a
|
| 40 |
+
mu = a * num_steps + b
|
| 41 |
+
|
| 42 |
+
return float(mu)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
|
| 46 |
+
mu = compute_empirical_mu(image_seq_len, num_steps)
|
| 47 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
| 48 |
+
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
|
| 49 |
+
return timesteps.tolist()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def denoise(
|
| 53 |
+
model: Flux2,
|
| 54 |
+
img: Tensor,
|
| 55 |
+
img_ids: Tensor,
|
| 56 |
+
txt: Tensor,
|
| 57 |
+
txt_ids: Tensor,
|
| 58 |
+
timesteps: list[float],
|
| 59 |
+
img_cond_seq: Tensor | None = None,
|
| 60 |
+
img_cond_seq_ids: Tensor | None = None,
|
| 61 |
+
) -> Tensor:
|
| 62 |
+
if img_cond_seq is not None:
|
| 63 |
+
assert img_cond_seq_ids is not None, (
|
| 64 |
+
"You need to provide either both or neither of the sequence conditioning"
|
| 65 |
+
)
|
| 66 |
+
combined_ids = torch.cat((img_ids, img_cond_seq_ids), dim=1)
|
| 67 |
+
else:
|
| 68 |
+
combined_ids = img_ids
|
| 69 |
+
|
| 70 |
+
# Pre-compute positional embeddings once (constant across all timesteps)
|
| 71 |
+
pe_x = model.pe_embedder(combined_ids)
|
| 72 |
+
pe_ctx = model.pe_embedder(txt_ids)
|
| 73 |
+
|
| 74 |
+
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
| 75 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
| 76 |
+
img_input = img
|
| 77 |
+
if img_cond_seq is not None:
|
| 78 |
+
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
| 79 |
+
pred = model(
|
| 80 |
+
x=img_input,
|
| 81 |
+
x_ids=combined_ids,
|
| 82 |
+
timesteps=t_vec,
|
| 83 |
+
ctx=txt,
|
| 84 |
+
ctx_ids=txt_ids,
|
| 85 |
+
pe_x=pe_x,
|
| 86 |
+
pe_ctx=pe_ctx,
|
| 87 |
+
)
|
| 88 |
+
pred = pred[:, : img.shape[1]]
|
| 89 |
+
|
| 90 |
+
img = img + (t_prev - t_curr) * pred
|
| 91 |
+
|
| 92 |
+
return img
|
nisaba_relief/image_utils.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pure image and tensor helper functions for NisabaRelief."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
|
| 10 |
+
from .constants import (
|
| 11 |
+
MAX_TILE,
|
| 12 |
+
MIN_TILE,
|
| 13 |
+
PATCH_SIZE,
|
| 14 |
+
TARGET_TILES_PER_SIDE,
|
| 15 |
+
TILE_OVERLAP_DIVISOR,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
_to_tensor = transforms.ToTensor()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def round_to_patch(value: float) -> int:
|
| 22 |
+
"""Round a pixel value to the nearest multiple of PATCH_SIZE (minimum PATCH_SIZE)."""
|
| 23 |
+
return max(PATCH_SIZE, PATCH_SIZE * round(value / PATCH_SIZE))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def ceil_to_patch(value: float) -> int:
|
| 27 |
+
"""Ceil a pixel value to the next multiple of PATCH_SIZE (minimum PATCH_SIZE)."""
|
| 28 |
+
return max(PATCH_SIZE, PATCH_SIZE * math.ceil(value / PATCH_SIZE))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def compute_tile_size(max_side: int) -> int:
|
| 32 |
+
"""Compute the optimal square tile side length for a given image maximum side."""
|
| 33 |
+
raw = ceil_to_patch(
|
| 34 |
+
max_side
|
| 35 |
+
* TILE_OVERLAP_DIVISOR
|
| 36 |
+
/ (TARGET_TILES_PER_SIDE * (TILE_OVERLAP_DIVISOR - 1) + 1)
|
| 37 |
+
)
|
| 38 |
+
return max(min(raw, MAX_TILE), MIN_TILE)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def compute_tile_grid(
|
| 42 |
+
orig_w: int, orig_h: int, tile_size: int
|
| 43 |
+
) -> tuple[int, int, int, int, int, int, int, int]:
|
| 44 |
+
"""Compute tiled grid layout for an image.
|
| 45 |
+
|
| 46 |
+
Returns (n_cols, n_rows, padded_w, padded_h, pad_left, pad_top, overlap, stride).
|
| 47 |
+
"""
|
| 48 |
+
overlap = tile_size // TILE_OVERLAP_DIVISOR
|
| 49 |
+
stride = tile_size - overlap
|
| 50 |
+
n_cols = max(1, math.ceil((orig_w - overlap) / stride))
|
| 51 |
+
n_rows = max(1, math.ceil((orig_h - overlap) / stride))
|
| 52 |
+
padded_w = tile_size + (n_cols - 1) * stride
|
| 53 |
+
padded_h = tile_size + (n_rows - 1) * stride
|
| 54 |
+
pad_left = (padded_w - orig_w) // 2
|
| 55 |
+
pad_top = (padded_h - orig_h) // 2
|
| 56 |
+
return n_cols, n_rows, padded_w, padded_h, pad_left, pad_top, overlap, stride
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def image_to_tensor(image: Image.Image, device: str) -> torch.Tensor:
|
| 60 |
+
"""Convert a PIL image to a normalised [-1, 1] float tensor on device."""
|
| 61 |
+
return (2 * _to_tensor(image) - 1).to(device)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
|
| 65 |
+
"""Convert a normalised [-1, 1] CHW tensor to a PIL RGB image."""
|
| 66 |
+
img = (tensor.clamp(-1, 1) + 1) / 2
|
| 67 |
+
img = img.permute(1, 2, 0).float().cpu().numpy()
|
| 68 |
+
return Image.fromarray((img * 255).astype("uint8"))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def pad_to_patch_multiple(image: Image.Image) -> Image.Image:
|
| 72 |
+
"""Pad image width and height up to the next multiple of PATCH_SIZE."""
|
| 73 |
+
w, h = image.size
|
| 74 |
+
pad_w = (PATCH_SIZE - w % PATCH_SIZE) % PATCH_SIZE
|
| 75 |
+
pad_h = (PATCH_SIZE - h % PATCH_SIZE) % PATCH_SIZE
|
| 76 |
+
if pad_w == 0 and pad_h == 0:
|
| 77 |
+
return image
|
| 78 |
+
padded = Image.new("RGB", (w + pad_w, h + pad_h), (0, 0, 0))
|
| 79 |
+
padded.paste(image, (0, 0))
|
| 80 |
+
return padded
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def postprocess(image: Image.Image, shadow_strength: float = 0.7) -> Image.Image:
|
| 84 |
+
"""Apply adaptive gamma correction and convert to grayscale."""
|
| 85 |
+
arr = np.array(image, dtype=np.float32) / 255.0
|
| 86 |
+
gamma = 1.0 + shadow_strength * (1.0 - arr)
|
| 87 |
+
arr = np.power(arr, gamma)
|
| 88 |
+
return Image.fromarray((arr * 255).clip(0, 255).astype(np.uint8)).convert("L")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def draw_tile_indicator(
|
| 92 |
+
tensor: torch.Tensor,
|
| 93 |
+
full_w: int,
|
| 94 |
+
full_h: int,
|
| 95 |
+
tile_x: int,
|
| 96 |
+
tile_y: int,
|
| 97 |
+
tile_w: int,
|
| 98 |
+
tile_h: int,
|
| 99 |
+
line_width: int = 1,
|
| 100 |
+
) -> torch.Tensor:
|
| 101 |
+
"""Draw a red rectangle on a CHW tensor to mark the current tile position."""
|
| 102 |
+
C, H, W = tensor.shape
|
| 103 |
+
result = tensor.clone()
|
| 104 |
+
|
| 105 |
+
scale_x = W / full_w
|
| 106 |
+
scale_y = H / full_h
|
| 107 |
+
|
| 108 |
+
x1 = max(0, min(int(tile_x * scale_x), W - 1))
|
| 109 |
+
y1 = max(0, min(int(tile_y * scale_y), H - 1))
|
| 110 |
+
x2 = max(0, min(int((tile_x + tile_w) * scale_x), W))
|
| 111 |
+
y2 = max(0, min(int((tile_y + tile_h) * scale_y), H))
|
| 112 |
+
|
| 113 |
+
red = torch.tensor([1.0, -1.0, -1.0], device=tensor.device, dtype=tensor.dtype)
|
| 114 |
+
|
| 115 |
+
for dy in range(line_width):
|
| 116 |
+
if y1 + dy < H:
|
| 117 |
+
result[:, y1 + dy, x1:x2] = red.view(3, 1)
|
| 118 |
+
if 0 <= y2 - 1 - dy < H:
|
| 119 |
+
result[:, y2 - 1 - dy, x1:x2] = red.view(3, 1)
|
| 120 |
+
|
| 121 |
+
for dx in range(line_width):
|
| 122 |
+
if x1 + dx < W:
|
| 123 |
+
result[:, y1:y2, x1 + dx] = red.view(3, 1)
|
| 124 |
+
if 0 <= x2 - 1 - dx < W:
|
| 125 |
+
result[:, y1:y2, x2 - 1 - dx] = red.view(3, 1)
|
| 126 |
+
|
| 127 |
+
return result
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def create_blend_weights(
|
| 131 |
+
tile_size: int,
|
| 132 |
+
overlap: int,
|
| 133 |
+
is_top: bool = False,
|
| 134 |
+
is_bottom: bool = False,
|
| 135 |
+
is_left: bool = False,
|
| 136 |
+
is_right: bool = False,
|
| 137 |
+
device: str = "cpu",
|
| 138 |
+
) -> torch.Tensor:
|
| 139 |
+
"""Create cosine blend weights for a tile, ramping down at non-edge overlaps."""
|
| 140 |
+
weights = torch.ones(tile_size, tile_size, device=device)
|
| 141 |
+
|
| 142 |
+
if overlap > 0:
|
| 143 |
+
ramp = 0.5 * (1 - torch.cos(torch.linspace(0, torch.pi, overlap, device=device)))
|
| 144 |
+
if not is_top:
|
| 145 |
+
weights[:overlap, :] *= ramp.view(-1, 1)
|
| 146 |
+
if not is_bottom:
|
| 147 |
+
weights[-overlap:, :] *= ramp.flip(0).view(-1, 1)
|
| 148 |
+
if not is_left:
|
| 149 |
+
weights[:, :overlap] *= ramp.view(1, -1)
|
| 150 |
+
if not is_right:
|
| 151 |
+
weights[:, -overlap:] *= ramp.flip(0).view(1, -1)
|
| 152 |
+
|
| 153 |
+
return weights
|
nisaba_relief/model.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NisabaRelief inference model.
|
| 3 |
+
Transforms cuneiform tablet images into MSII visualizations.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import contextlib
|
| 7 |
+
import logging
|
| 8 |
+
from os import PathLike
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import gc
|
| 12 |
+
import torch
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from tqdm.auto import tqdm
|
| 16 |
+
from safetensors.torch import load_file
|
| 17 |
+
|
| 18 |
+
from .constants import (
|
| 19 |
+
COND_SEQ_ID,
|
| 20 |
+
DECODE_BATCH_SIZE_DIVISOR,
|
| 21 |
+
GLOBAL_CTX_ID,
|
| 22 |
+
LATENT_CHANNELS,
|
| 23 |
+
MAX_ASPECT_RATIO,
|
| 24 |
+
MAX_GLOBAL_CONTEXT_SIZE,
|
| 25 |
+
MIN_IMAGE_DIMENSION,
|
| 26 |
+
VRAM_FIXED_OVERHEAD_MB,
|
| 27 |
+
VRAM_HEADROOM_MB,
|
| 28 |
+
VRAM_MB_PER_PIXEL,
|
| 29 |
+
MIN_BATCH_SIZE,
|
| 30 |
+
MAX_BATCH_SIZE,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
from .image_utils import (
|
| 34 |
+
_to_tensor,
|
| 35 |
+
compute_tile_grid,
|
| 36 |
+
compute_tile_size,
|
| 37 |
+
create_blend_weights,
|
| 38 |
+
draw_tile_indicator,
|
| 39 |
+
image_to_tensor,
|
| 40 |
+
pad_to_patch_multiple,
|
| 41 |
+
postprocess,
|
| 42 |
+
round_to_patch,
|
| 43 |
+
tensor_to_image,
|
| 44 |
+
)
|
| 45 |
+
from .weights import WEIGHT_FILES, download_weights
|
| 46 |
+
from .flux.autoencoder import AutoEncoder
|
| 47 |
+
from .flux.model import Flux2
|
| 48 |
+
from .flux.sampling import (
|
| 49 |
+
denoise,
|
| 50 |
+
get_schedule,
|
| 51 |
+
prc_img_batch,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class NisabaRelief:
|
| 58 |
+
"""Transform cuneiform tablet images into MSII relief visualizations.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
device: Device to run inference on (default "cuda" if available).
|
| 62 |
+
num_steps: Number of denoising steps (default 2).
|
| 63 |
+
weights_dir: Optional local weights directory. If None, uses HuggingFace Hub (boatbomber/NisabaRelief).
|
| 64 |
+
batch_size: Batch size for processing tiles during inference.
|
| 65 |
+
None (default) = auto-select based on available GPU memory each call.
|
| 66 |
+
Set an explicit int to override.
|
| 67 |
+
seed: Optional random seed for reproducible noise generation (default None).
|
| 68 |
+
compile: Whether to use torch.compile for faster repeated inference (default True).
|
| 69 |
+
Requires Triton. Set to False if Triton is not installed or for one-off runs.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 75 |
+
num_steps: int = 2,
|
| 76 |
+
weights_dir: PathLike | None = None,
|
| 77 |
+
batch_size: int | None = None,
|
| 78 |
+
seed: int | None = None,
|
| 79 |
+
compile: bool = True,
|
| 80 |
+
):
|
| 81 |
+
if batch_size is not None and batch_size < 1:
|
| 82 |
+
raise ValueError(f"batch_size must be >= 1 or None, got {batch_size}")
|
| 83 |
+
|
| 84 |
+
self.num_steps = num_steps
|
| 85 |
+
self.device = device
|
| 86 |
+
self.batch_size = batch_size
|
| 87 |
+
self.seed = seed
|
| 88 |
+
|
| 89 |
+
if weights_dir is not None:
|
| 90 |
+
weights_dir = Path(weights_dir)
|
| 91 |
+
if not weights_dir.is_dir():
|
| 92 |
+
raise FileNotFoundError(f"weights_dir does not exist: {weights_dir}")
|
| 93 |
+
|
| 94 |
+
missing = [f for f in WEIGHT_FILES if not (weights_dir / f).exists()]
|
| 95 |
+
if missing:
|
| 96 |
+
raise FileNotFoundError(
|
| 97 |
+
f"Missing weight files in {weights_dir}: {missing}"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
weight_paths = {f: str(weights_dir / f) for f in WEIGHT_FILES}
|
| 101 |
+
else:
|
| 102 |
+
logger.info("Downloading weights from HuggingFace Hub...")
|
| 103 |
+
weight_paths = download_weights()
|
| 104 |
+
|
| 105 |
+
# Load AutoEncoder
|
| 106 |
+
logger.debug("Loading AutoEncoder...")
|
| 107 |
+
with torch.device("meta"):
|
| 108 |
+
self.ae = AutoEncoder()
|
| 109 |
+
ae_weights = load_file(weight_paths["ae.safetensors"], device=device)
|
| 110 |
+
self.ae.load_state_dict(ae_weights, strict=True, assign=True)
|
| 111 |
+
self.ae.decoder = self.ae.decoder.to(self.dtype)
|
| 112 |
+
self.ae.eval()
|
| 113 |
+
|
| 114 |
+
# Load finetuned FLUX.2 model (merged weights)
|
| 115 |
+
logger.debug("Loading Transformer...")
|
| 116 |
+
with torch.device("meta"):
|
| 117 |
+
self.model = Flux2().to(self.dtype)
|
| 118 |
+
model_weights = load_file(weight_paths["model.safetensors"], device=device)
|
| 119 |
+
self.model.load_state_dict(model_weights, strict=True, assign=True)
|
| 120 |
+
self.model = self.model.to(device=device, dtype=self.dtype).eval()
|
| 121 |
+
|
| 122 |
+
# Load pre-computed text embedding
|
| 123 |
+
logger.debug("Loading text embedding...")
|
| 124 |
+
text_data = load_file(weight_paths["prompt_embedding.safetensors"], device=device)
|
| 125 |
+
self.prompt_embedding = text_data["prompt_embedding"].to(self.dtype)
|
| 126 |
+
self.ctx_ids = text_data["ctx_ids"]
|
| 127 |
+
|
| 128 |
+
if compile and self.device_type == "cuda":
|
| 129 |
+
try:
|
| 130 |
+
self.model = torch.compile(self.model)
|
| 131 |
+
self.ae = torch.compile(self.ae)
|
| 132 |
+
logger.debug(
|
| 133 |
+
"Model compile mode enabled. First run will be slow, but subsequent runs will be faster."
|
| 134 |
+
)
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error("Error compiling model: %s", e, exc_info=True)
|
| 137 |
+
logger.warning("Falling back to non-compiled model")
|
| 138 |
+
|
| 139 |
+
logger.info("NisabaRelief model loaded and ready")
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def device_type(self) -> str:
|
| 143 |
+
return self.device.split(":")[0]
|
| 144 |
+
|
| 145 |
+
@property
|
| 146 |
+
def dtype(self) -> torch.dtype:
|
| 147 |
+
if self.device_type == "cuda":
|
| 148 |
+
return torch.bfloat16
|
| 149 |
+
return torch.float32
|
| 150 |
+
|
| 151 |
+
def _pick_batch_size(self, tile_size: int) -> int:
|
| 152 |
+
"""Estimate the largest safe batch size for a given tile size."""
|
| 153 |
+
if self.device_type != "cuda":
|
| 154 |
+
return MIN_BATCH_SIZE
|
| 155 |
+
|
| 156 |
+
gc.collect()
|
| 157 |
+
torch.cuda.empty_cache()
|
| 158 |
+
|
| 159 |
+
try:
|
| 160 |
+
device_idx = torch.device(self.device).index or 0
|
| 161 |
+
free_vram_mb = (
|
| 162 |
+
torch.cuda.get_device_properties(device_idx).total_memory
|
| 163 |
+
- torch.cuda.memory_allocated(device_idx)
|
| 164 |
+
) / (1024**2)
|
| 165 |
+
available = free_vram_mb - VRAM_HEADROOM_MB
|
| 166 |
+
per_tile = VRAM_MB_PER_PIXEL * tile_size**2 + VRAM_FIXED_OVERHEAD_MB
|
| 167 |
+
batch = max(MIN_BATCH_SIZE, min(MAX_BATCH_SIZE, int(available / per_tile)))
|
| 168 |
+
logger.debug(
|
| 169 |
+
"Auto batch_size=%d (tile=%d, free=%.0f MB, per_tile=%.0f MB)",
|
| 170 |
+
batch,
|
| 171 |
+
tile_size,
|
| 172 |
+
free_vram_mb,
|
| 173 |
+
per_tile,
|
| 174 |
+
)
|
| 175 |
+
return batch
|
| 176 |
+
except Exception as e:
|
| 177 |
+
logger.error("Error picking batch size: %s", e, exc_info=True)
|
| 178 |
+
return MIN_BATCH_SIZE
|
| 179 |
+
|
| 180 |
+
def __repr__(self) -> str:
|
| 181 |
+
return (
|
| 182 |
+
f"NisabaRelief(device={self.device!r}, num_steps={self.num_steps}, "
|
| 183 |
+
f"batch_size={self.batch_size}, seed={self.seed!r})"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def process(
|
| 187 |
+
self,
|
| 188 |
+
image: PathLike | Image.Image,
|
| 189 |
+
show_pbar: bool | None = None,
|
| 190 |
+
) -> Image.Image:
|
| 191 |
+
"""Transform a cuneiform tablet image into MSII visualization.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
image: Input image (path or PIL Image).
|
| 195 |
+
show_pbar: Whether to show a progress bar during tiled inference.
|
| 196 |
+
If None (default), shows the bar only when there are at least 2 batches to run.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
PIL Image (grayscale) with MSII visualization.
|
| 200 |
+
"""
|
| 201 |
+
if isinstance(image, (str, PathLike)):
|
| 202 |
+
image = Image.open(image)
|
| 203 |
+
if image.mode != "RGB":
|
| 204 |
+
image = image.convert("RGB")
|
| 205 |
+
|
| 206 |
+
w, h = image.size
|
| 207 |
+
max_side = max(w, h)
|
| 208 |
+
min_side = min(w, h)
|
| 209 |
+
if min_side < MIN_IMAGE_DIMENSION:
|
| 210 |
+
raise ValueError(
|
| 211 |
+
f"Image too small: {min_side}px minimum side (need >= {MIN_IMAGE_DIMENSION}px)"
|
| 212 |
+
)
|
| 213 |
+
if max_side / min_side > MAX_ASPECT_RATIO:
|
| 214 |
+
raise ValueError(
|
| 215 |
+
f"Aspect ratio too extreme: {max_side / min_side:.1f}:1 (max {MAX_ASPECT_RATIO:.0f}:1)"
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
tile_size = compute_tile_size(max_side)
|
| 219 |
+
output_image = self._process_tiled(
|
| 220 |
+
image, tile_size=tile_size, show_pbar=show_pbar
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
return postprocess(output_image)
|
| 224 |
+
|
| 225 |
+
def _prepare_global_context_tensor(
|
| 226 |
+
self,
|
| 227 |
+
image: Image.Image,
|
| 228 |
+
max_size: int = MAX_GLOBAL_CONTEXT_SIZE,
|
| 229 |
+
) -> torch.Tensor:
|
| 230 |
+
w, h = image.size
|
| 231 |
+
scale = min(max_size / w, max_size / h)
|
| 232 |
+
new_w = round_to_patch(w * scale)
|
| 233 |
+
new_h = round_to_patch(h * scale)
|
| 234 |
+
|
| 235 |
+
resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
| 236 |
+
return image_to_tensor(resized, self.device)
|
| 237 |
+
|
| 238 |
+
def _encode_global_context_batch(
|
| 239 |
+
self,
|
| 240 |
+
img_tensors: list[torch.Tensor],
|
| 241 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 242 |
+
batch = torch.stack(img_tensors)
|
| 243 |
+
with torch.inference_mode():
|
| 244 |
+
global_latent = self.ae.encode(batch)
|
| 245 |
+
global_tokens, global_ids = prc_img_batch(global_latent)
|
| 246 |
+
global_ids[..., 0] = GLOBAL_CTX_ID
|
| 247 |
+
return global_tokens.to(self.dtype), global_ids
|
| 248 |
+
|
| 249 |
+
def _process_tile_batch(
|
| 250 |
+
self,
|
| 251 |
+
tiles: list[Image.Image],
|
| 252 |
+
global_ctx_tokens: torch.Tensor,
|
| 253 |
+
global_ctx_ids: torch.Tensor,
|
| 254 |
+
tile_index_offset: int = 0,
|
| 255 |
+
) -> list[Image.Image]:
|
| 256 |
+
b = len(tiles)
|
| 257 |
+
original_sizes = [tile.size for tile in tiles]
|
| 258 |
+
|
| 259 |
+
padded_tiles = [pad_to_patch_multiple(tile) for tile in tiles]
|
| 260 |
+
img_tensors = torch.stack(
|
| 261 |
+
[image_to_tensor(tile, self.device) for tile in padded_tiles]
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
with torch.inference_mode():
|
| 265 |
+
input_latent = self.ae.encode(img_tensors)
|
| 266 |
+
|
| 267 |
+
input_tokens, input_ids = prc_img_batch(input_latent)
|
| 268 |
+
input_ids_cond = input_ids.clone()
|
| 269 |
+
input_ids_cond[..., 0] = COND_SEQ_ID
|
| 270 |
+
|
| 271 |
+
cond_tokens = torch.cat([input_tokens, global_ctx_tokens], dim=1)
|
| 272 |
+
cond_ids = torch.cat([input_ids_cond, global_ctx_ids], dim=1)
|
| 273 |
+
|
| 274 |
+
latent_h = input_latent.shape[2]
|
| 275 |
+
latent_w = input_latent.shape[3]
|
| 276 |
+
|
| 277 |
+
if self.seed is None:
|
| 278 |
+
noise = torch.randn(
|
| 279 |
+
b,
|
| 280 |
+
LATENT_CHANNELS,
|
| 281 |
+
latent_h,
|
| 282 |
+
latent_w,
|
| 283 |
+
device=self.device,
|
| 284 |
+
dtype=self.dtype,
|
| 285 |
+
)
|
| 286 |
+
else:
|
| 287 |
+
noise_list = []
|
| 288 |
+
for i in range(b):
|
| 289 |
+
tile_seed = self.seed ^ (tile_index_offset + i)
|
| 290 |
+
generator = torch.Generator(device=self.device).manual_seed(tile_seed)
|
| 291 |
+
noise_list.append(
|
| 292 |
+
torch.randn(
|
| 293 |
+
LATENT_CHANNELS,
|
| 294 |
+
latent_h,
|
| 295 |
+
latent_w,
|
| 296 |
+
device=self.device,
|
| 297 |
+
dtype=self.dtype,
|
| 298 |
+
generator=generator,
|
| 299 |
+
)
|
| 300 |
+
)
|
| 301 |
+
noise = torch.stack(noise_list)
|
| 302 |
+
|
| 303 |
+
noise_tokens, _ = prc_img_batch(noise)
|
| 304 |
+
noise_ids = input_ids
|
| 305 |
+
|
| 306 |
+
seq_len = noise_tokens.shape[1]
|
| 307 |
+
timesteps = get_schedule(self.num_steps, seq_len)
|
| 308 |
+
|
| 309 |
+
ctx = self.prompt_embedding.unsqueeze(0).expand(b, -1, -1)
|
| 310 |
+
ctx_ids = self.ctx_ids.unsqueeze(0).expand(b, -1, -1)
|
| 311 |
+
|
| 312 |
+
autocast_ctx = (
|
| 313 |
+
torch.autocast(device_type=self.device_type, dtype=self.dtype)
|
| 314 |
+
if self.device_type == "cuda"
|
| 315 |
+
else contextlib.nullcontext()
|
| 316 |
+
)
|
| 317 |
+
with autocast_ctx:
|
| 318 |
+
output_tokens = denoise(
|
| 319 |
+
model=self.model,
|
| 320 |
+
img=noise_tokens.to(self.dtype),
|
| 321 |
+
img_ids=noise_ids,
|
| 322 |
+
txt=ctx.to(self.dtype),
|
| 323 |
+
txt_ids=ctx_ids,
|
| 324 |
+
timesteps=timesteps,
|
| 325 |
+
img_cond_seq=cond_tokens.to(self.dtype),
|
| 326 |
+
img_cond_seq_ids=cond_ids,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
output_latent = rearrange(
|
| 330 |
+
output_tokens,
|
| 331 |
+
"b (h w) c -> b c h w",
|
| 332 |
+
h=latent_h,
|
| 333 |
+
w=latent_w,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# Free tensors from encode/denoise phases before AE decode to
|
| 337 |
+
# avoid CUDA memory fragmentation (the decoder needs large
|
| 338 |
+
# full-resolution float32 allocations that differ in shape from
|
| 339 |
+
# the transformer's cached blocks).
|
| 340 |
+
del img_tensors, input_latent, input_tokens, input_ids
|
| 341 |
+
del input_ids_cond, cond_tokens, cond_ids
|
| 342 |
+
del noise, noise_tokens, noise_ids
|
| 343 |
+
del output_tokens, ctx, ctx_ids
|
| 344 |
+
if self.device_type == "cuda":
|
| 345 |
+
torch.cuda.empty_cache()
|
| 346 |
+
|
| 347 |
+
# The AE decoder operates at full pixel resolution in float32,
|
| 348 |
+
# requiring much more VRAM per tile than the latent-space denoiser.
|
| 349 |
+
# Sub-batch to avoid overflowing into shared memory.
|
| 350 |
+
decode_bs = max(1, b // DECODE_BATCH_SIZE_DIVISOR)
|
| 351 |
+
if decode_bs >= b:
|
| 352 |
+
output_imgs = self.ae.decode(output_latent)
|
| 353 |
+
else:
|
| 354 |
+
chunks = []
|
| 355 |
+
for i in range(0, b, decode_bs):
|
| 356 |
+
chunks.append(self.ae.decode(output_latent[i : i + decode_bs]))
|
| 357 |
+
if self.device_type == "cuda":
|
| 358 |
+
torch.cuda.empty_cache()
|
| 359 |
+
output_imgs = torch.cat(chunks, dim=0)
|
| 360 |
+
|
| 361 |
+
results = []
|
| 362 |
+
for i, (orig_w, orig_h) in enumerate(original_sizes):
|
| 363 |
+
result = tensor_to_image(output_imgs[i])
|
| 364 |
+
if padded_tiles[i].size != (orig_w, orig_h):
|
| 365 |
+
result = result.crop((0, 0, orig_w, orig_h))
|
| 366 |
+
results.append(result)
|
| 367 |
+
|
| 368 |
+
return results
|
| 369 |
+
|
| 370 |
+
def _process_tiled(
|
| 371 |
+
self, image: Image.Image, tile_size: int, show_pbar: bool | None = None
|
| 372 |
+
) -> Image.Image:
|
| 373 |
+
orig_w, orig_h = image.size
|
| 374 |
+
n_cols, n_rows, w, h, pad_left, pad_top, overlap, stride = compute_tile_grid(
|
| 375 |
+
orig_w, orig_h, tile_size
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Pad canvas so tiles land at exact stride positions with uniform overlap.
|
| 379 |
+
# Center the image so padding is distributed evenly on all sides.
|
| 380 |
+
padded = Image.new("RGB", (w, h), (0, 0, 0))
|
| 381 |
+
padded.paste(image, (pad_left, pad_top))
|
| 382 |
+
image = padded
|
| 383 |
+
|
| 384 |
+
global_base_tensor = self._prepare_global_context_tensor(image)
|
| 385 |
+
|
| 386 |
+
output = torch.zeros(3, h, w, device=self.device)
|
| 387 |
+
weights = torch.zeros(1, h, w, device=self.device)
|
| 388 |
+
|
| 389 |
+
tile_specs = [
|
| 390 |
+
(row, col, col * stride, row * stride)
|
| 391 |
+
for row in range(n_rows)
|
| 392 |
+
for col in range(n_cols)
|
| 393 |
+
]
|
| 394 |
+
|
| 395 |
+
blend_cache: dict[tuple[bool, bool, bool, bool], torch.Tensor] = {}
|
| 396 |
+
|
| 397 |
+
batch_size = (
|
| 398 |
+
self.batch_size
|
| 399 |
+
if self.batch_size is not None
|
| 400 |
+
else self._pick_batch_size(tile_size)
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
if show_pbar is None:
|
| 404 |
+
show_pbar = len(tile_specs) >= 2 * batch_size
|
| 405 |
+
|
| 406 |
+
pbar = tqdm(
|
| 407 |
+
total=len(tile_specs),
|
| 408 |
+
desc=f"Processing {orig_w}x{orig_h} px image with {n_cols}x{n_rows} tiles ({tile_size} px each, {overlap} px overlap)",
|
| 409 |
+
unit="tile",
|
| 410 |
+
leave=False,
|
| 411 |
+
disable=not show_pbar,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
for batch_start in range(0, len(tile_specs), batch_size):
|
| 415 |
+
batch_specs = tile_specs[batch_start : batch_start + batch_size]
|
| 416 |
+
|
| 417 |
+
pbar.clear()
|
| 418 |
+
logger.debug(
|
| 419 |
+
"Processing %d batched tiles: %s",
|
| 420 |
+
len(batch_specs),
|
| 421 |
+
" + ".join([f"({row},{col})" for row, col, _, _ in batch_specs]),
|
| 422 |
+
)
|
| 423 |
+
pbar.refresh()
|
| 424 |
+
|
| 425 |
+
ctx_tensors = [
|
| 426 |
+
draw_tile_indicator(global_base_tensor, w, h, x, y, tile_size, tile_size)
|
| 427 |
+
for (row, col, x, y) in batch_specs
|
| 428 |
+
]
|
| 429 |
+
global_tokens, global_ids = self._encode_global_context_batch(ctx_tensors)
|
| 430 |
+
|
| 431 |
+
tiles = [
|
| 432 |
+
image.crop((x, y, x + tile_size, y + tile_size))
|
| 433 |
+
for (row, col, x, y) in batch_specs
|
| 434 |
+
]
|
| 435 |
+
|
| 436 |
+
result_tiles = self._process_tile_batch(
|
| 437 |
+
tiles, global_tokens, global_ids, tile_index_offset=batch_start
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
for i, (row, col, x, y) in enumerate(batch_specs):
|
| 441 |
+
edge_key = (
|
| 442 |
+
row == 0,
|
| 443 |
+
row == n_rows - 1,
|
| 444 |
+
col == 0,
|
| 445 |
+
col == n_cols - 1,
|
| 446 |
+
)
|
| 447 |
+
if edge_key not in blend_cache:
|
| 448 |
+
blend_cache[edge_key] = create_blend_weights(
|
| 449 |
+
tile_size,
|
| 450 |
+
overlap,
|
| 451 |
+
is_top=edge_key[0],
|
| 452 |
+
is_bottom=edge_key[1],
|
| 453 |
+
is_left=edge_key[2],
|
| 454 |
+
is_right=edge_key[3],
|
| 455 |
+
device=self.device,
|
| 456 |
+
)
|
| 457 |
+
blend = blend_cache[edge_key]
|
| 458 |
+
result_tensor = _to_tensor(result_tiles[i]).to(self.device)
|
| 459 |
+
output[:, y : y + tile_size, x : x + tile_size] += result_tensor * blend
|
| 460 |
+
weights[:, y : y + tile_size, x : x + tile_size] += blend
|
| 461 |
+
|
| 462 |
+
if self.device_type == "cuda":
|
| 463 |
+
torch.cuda.empty_cache()
|
| 464 |
+
|
| 465 |
+
pbar.update(len(batch_specs))
|
| 466 |
+
|
| 467 |
+
pbar.close()
|
| 468 |
+
|
| 469 |
+
output = output / weights.clamp(min=1e-6)
|
| 470 |
+
output = output.permute(1, 2, 0).cpu().numpy()
|
| 471 |
+
output = (output * 255).clip(0, 255).astype("uint8")
|
| 472 |
+
return Image.fromarray(output).crop(
|
| 473 |
+
(pad_left, pad_top, pad_left + orig_w, pad_top + orig_h)
|
| 474 |
+
)
|
nisaba_relief/py.typed
ADDED
|
File without changes
|
nisaba_relief/weights.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HuggingFace Hub weight downloading for NisabaRelief."""
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
|
| 5 |
+
HF_REPO_ID = "boatbomber/NisabaRelief"
|
| 6 |
+
WEIGHT_FILES = [
|
| 7 |
+
"ae.safetensors",
|
| 8 |
+
"model.safetensors",
|
| 9 |
+
"prompt_embedding.safetensors",
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def download_weights(repo_id: str = HF_REPO_ID) -> dict[str, str]:
|
| 14 |
+
"""Download all weight files from HF Hub, returning {filename: local_path}."""
|
| 15 |
+
paths = {}
|
| 16 |
+
for filename in WEIGHT_FILES:
|
| 17 |
+
try:
|
| 18 |
+
paths[filename] = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 19 |
+
except Exception as e:
|
| 20 |
+
raise RuntimeError(
|
| 21 |
+
f"Failed to download {filename} from {repo_id}: {e}"
|
| 22 |
+
) from e
|
| 23 |
+
return paths
|
prompt_embedding.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc9b70751370039f6af10f5c803f9854354f7029f7d9521c6a4ee7c5ae28f999
|
| 3 |
+
size 7880872
|
pyproject.toml
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[tool.hatch.build.targets.sdist]
|
| 6 |
+
exclude = ["*.safetensors", "assets/**", "data/**", "dev_scripts/**", "uv.lock"]
|
| 7 |
+
|
| 8 |
+
[tool.hatch.build.targets.wheel]
|
| 9 |
+
packages = ["nisaba_relief"]
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "nisaba-relief"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Transform cuneiform tablet photos into MSII relief visualizations"
|
| 15 |
+
readme = { file = "README.md", content-type = "text/markdown" }
|
| 16 |
+
license = "Apache-2.0"
|
| 17 |
+
requires-python = ">=3.10,<3.14"
|
| 18 |
+
authors = [{ name = "Zack Williams", email = "zack@boatbomber.com" }]
|
| 19 |
+
keywords = ["cuneiform", "msii", "relief", "ocr", "flux", "deep-learning"]
|
| 20 |
+
classifiers = [
|
| 21 |
+
"Development Status :: 4 - Beta",
|
| 22 |
+
"Intended Audience :: Science/Research",
|
| 23 |
+
"License :: OSI Approved :: Apache Software License",
|
| 24 |
+
"Programming Language :: Python :: 3",
|
| 25 |
+
"Programming Language :: Python :: 3.10",
|
| 26 |
+
"Programming Language :: Python :: 3.11",
|
| 27 |
+
"Programming Language :: Python :: 3.12",
|
| 28 |
+
"Programming Language :: Python :: 3.13",
|
| 29 |
+
"Topic :: Scientific/Engineering :: Image Processing",
|
| 30 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 31 |
+
]
|
| 32 |
+
dependencies = [
|
| 33 |
+
"einops>=0.8.2",
|
| 34 |
+
"safetensors",
|
| 35 |
+
"numpy",
|
| 36 |
+
"pillow",
|
| 37 |
+
"huggingface-hub",
|
| 38 |
+
"tqdm",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
[project.urls]
|
| 42 |
+
Homepage = "https://huggingface.co/boatbomber/NisabaRelief"
|
| 43 |
+
Repository = "https://huggingface.co/boatbomber/NisabaRelief"
|
| 44 |
+
Issues = "https://huggingface.co/boatbomber/NisabaRelief/discussions"
|
| 45 |
+
|
| 46 |
+
[dependency-groups]
|
| 47 |
+
dev = [
|
| 48 |
+
"ruff>=0.15.4",
|
| 49 |
+
"scikit-image>=0.25.2",
|
| 50 |
+
"scipy>=1.15.3",
|
| 51 |
+
"image-similarity-measures[speedups]>=0.3.5",
|
| 52 |
+
"pytorch-msssim>=1.0.0",
|
| 53 |
+
"rich>=14.3.3",
|
| 54 |
+
"datasets>=4.6.1",
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
[[tool.uv.index]]
|
| 58 |
+
name = "pytorch-cu128"
|
| 59 |
+
url = "https://download.pytorch.org/whl/cu128"
|
| 60 |
+
explicit = true
|
| 61 |
+
|
| 62 |
+
[tool.uv.sources]
|
| 63 |
+
torch = { index = "pytorch-cu128" }
|
| 64 |
+
torchvision = { index = "pytorch-cu128" }
|
| 65 |
+
triton = { index = "pytorch-cu128" }
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
[tool.ruff]
|
| 69 |
+
line-length = 90
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|