Julien Blanchon committed
Commit d62394f · 0 Parent(s)

Deploy optimized Image-GS with dynamic dependencies


- Pre-built gsplat wheel stored in blanchon/image-gs-models-utils
- Models automatically downloaded from HF models repository
- Dynamic installation of the gsplat wheel at runtime (sketched below)
- Optimized Docker build without CUDA compilation
- Clean repository without binary files
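
As a sketch of what the dynamic installation could look like at runtime, the snippet below downloads the pre-built gsplat wheel from `blanchon/image-gs-models-utils` and pip-installs it before the app imports `gsplat`. The wheel filename here is hypothetical; the actual name depends on the build stored in that repo.

```python
import subprocess
import sys

from huggingface_hub import hf_hub_download

# Hypothetical filename; the real one depends on the stored build
GSPLAT_WHEEL = "gsplat-0.1.0-cp310-cp310-linux_x86_64.whl"


def install_gsplat_wheel() -> None:
    """Fetch the pre-built gsplat wheel from the HF repo and pip-install it."""
    wheel_path = hf_hub_download(
        repo_id="blanchon/image-gs-models-utils",
        filename=GSPLAT_WHEEL,
        repo_type="model",
    )
    subprocess.check_call([sys.executable, "-m", "pip", "install", wheel_path])
```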

.dockerignore ADDED
@@ -0,0 +1,46 @@
# Git
.git
.gitignore

# Python
__pycache__
*.pyc
*.pyo
*.pyd
.Python
*.so
.pytest_cache
.coverage

# Virtual environments
.venv
.env
venv/
env/

# IDE
.vscode
.idea
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Project specific
results/
temp_*
*.log

# Documentation
docs/
*.md
!README.md

# Assets (if large)
assets/images/
assets/fonts/

# GSplat documentation (not needed for runtime)
gsplat/src/gsplat/cuda/csrc/third_party/glm/doc/
Dockerfile ADDED
@@ -0,0 +1,62 @@
# Use NVIDIA CUDA image that matches PyTorch's CUDA 12.4 compilation
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04

# Install Python 3.10 and dependencies with cache mounts
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
    apt-get update && apt-get install -y \
    python3.10 \
    python3.10-venv \
    python3.10-dev \
    python3-pip \
    git \
    build-essential \
    curl \
    ninja-build \
    wget

# Create symlinks for python
RUN ln -sf /usr/bin/python3.10 /usr/bin/python3 && \
    ln -sf /usr/bin/python3.10 /usr/bin/python

# Set CUDA environment variables for runtime
ENV CUDA_HOME=/usr/local/cuda \
    PATH=/usr/local/cuda/bin:$PATH \
    LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

# Install uv globally
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv

# Set up user with ID 1000 (required for HF Spaces)
RUN useradd -m -u 1000 user

# Switch to user and set working directory
USER user
WORKDIR /home/user/app

# Set environment variables
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    PYTHONUNBUFFERED=1 \
    GRADIO_SERVER_NAME=0.0.0.0 \
    GRADIO_SERVER_PORT=7860 \
    UV_CACHE_DIR=/home/user/.cache/uv

# Copy dependency files first for better caching
COPY --chown=user pyproject.toml uv.lock ./

# Copy the pre-built wheels directory
COPY --chown=user wheels/ ./wheels/

# Install dependencies with uv (using the pre-built wheel, much faster)
RUN --mount=type=cache,target=/tmp/uv-cache,sharing=locked,uid=1000,gid=1000 \
    UV_CACHE_DIR=/tmp/uv-cache uv sync --frozen --no-dev

# Copy the rest of the application
COPY --chown=user . .

# Expose port 7860 (default for HF Spaces)
EXPOSE 7860

# Launch the Gradio app
CMD ["uv", "run", "python", "gradio_app.py"]
README.md ADDED
@@ -0,0 +1,231 @@
---
title: Image GS
emoji: 💻
colorFrom: blue
colorTo: green
sdk: docker
app_port: 7860
pinned: false
---

<div align="center">

<h1>Image-GS: Content-Adaptive Image Representation via 2D Gaussians</h1>

[**Yunxiang Zhang**](https://yunxiangzhang.github.io/)<sup>1\*</sup>,
[**Bingxuan Li**](https://bingxuan-li.github.io/)<sup>1\*</sup>,
[**Alexandr Kuznetsov**](https://alexku.me/)<sup>3&dagger;</sup>,
[**Akshay Jindal**](https://www.akshayjindal.com/)<sup>2</sup>,
[**Stavros Diolatzis**](https://www.sdiolatz.info/)<sup>2</sup>,
[**Kenneth Chen**](https://kenchen10.github.io/)<sup>1</sup>,
[**Anton Sochenov**](https://www.intel.com/content/www/us/en/developer/articles/community/gpu-researchers-anton-sochenov.html)<sup>2</sup>,
[**Anton Kaplanyan**](http://kaplanyan.com/)<sup>2</sup>,
[**Qi Sun**](https://qisun.me/)<sup>1</sup>

\* Equal contribution &emsp; &dagger; Work done while at Intel

<sup>1</sup>
<a href="https://www.immersivecomputinglab.org/research/"><img width="30%" src="assets/images/NYU-logo.png" style="vertical-align: top;" alt="NYU logo"></a>
&emsp;
<sup>2</sup>
<a href="https://www.intel.com/content/www/us/en/developer/topic-technology/graphics-research/overview.html"><img width="22%" src="assets/images/Intel-logo.png" style="vertical-align: top;" alt="Intel logo"></a>
&emsp;
<sup>3</sup>
<a href="https://www.amd.com/en.html"><img width="33%" src="assets/images/AMD-logo.png" style="vertical-align: top;" alt="AMD logo"></a>

<a href="https://arxiv.org/abs/2407.01866"><img src="https://img.shields.io/badge/arXiv-2407.01866-red" alt="arXiv"></a>
<a href="https://www.immersivecomputinglab.org/publication/image-gs-content-adaptive-image-representation-via-2d-gaussians/"><img src="https://img.shields.io/badge/project page-ImageGS-blue" alt="project page"></a>
<a href="https://github.com/NYU-ICL/image-gs"><img src="https://visitor-badge.laobi.icu/badge?page_id=NYU-ICL.image-gs&left_color=green&right_color=red" alt="visitors"></a>

</div>

<div style="width: 90%; margin: 0 auto;">
Neural image representations have emerged as a promising approach for encoding and rendering visual data. Combined with learning-based workflows, they demonstrate impressive trade-offs between visual fidelity and memory footprint. Existing methods in this domain, however, often rely on fixed data structures that suboptimally allocate memory or compute-intensive implicit models, hindering their practicality for real-time graphics applications.

Inspired by recent advancements in radiance field rendering, we introduce Image-GS, a content-adaptive image representation based on 2D Gaussians. Leveraging a custom differentiable renderer, Image-GS reconstructs images by adaptively allocating and progressively optimizing a group of anisotropic, colored 2D Gaussians. It achieves a favorable balance between visual fidelity and memory efficiency across a variety of stylized images frequently seen in graphics workflows, especially for those showing non-uniformly distributed features and in low-bitrate regimes. Moreover, it supports hardware-friendly rapid random access for real-time usage, requiring only 0.3K MACs to decode a pixel. Through error-guided progressive optimization, Image-GS naturally constructs a smooth level-of-detail hierarchy. We demonstrate its versatility with several applications, including texture compression, semantics-aware compression, and joint image compression and restoration.

<img src="assets/images/teaser.jpg" width="100%" />
<sub>
Figure 1: Image-GS reconstructs an image by adaptively allocating and progressively optimizing a set of colored 2D Gaussians. It achieves favorable rate-distortion trade-offs, hardware-friendly random access, and flexible quality control through a smooth level-of-detail stack. (a) visualizes the optimized spatial distribution of Gaussians (20% randomly sampled for clarity). (b) Image-GS's explicit content-adaptive design effectively captures non-uniformly distributed image features and better preserves fine details under constrained memory budgets. In the inset error maps, brighter colors indicate larger errors.
</sub>
</div>

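For intuition about the representation, here is a conceptual sketch (not the project's renderer, which is a custom CUDA rasterizer in the bundled gsplat fork) of how one pixel can be decoded from a set of anisotropic, colored 2D Gaussians, using the top-K weight normalization that `cfgs/default.yaml` exposes via `topk` and `disable_topk_norm`:

```python
import numpy as np


def decode_pixel(p, xy, scale, theta, feat, k=10, eps=1e-7):
    """Decode one pixel from N anisotropic 2D Gaussians (conceptual sketch).

    p: (2,) pixel position; xy: (N, 2) centers; scale: (N, 2) per-axis scales;
    theta: (N,) rotation angles; feat: (N, C) color features.
    """
    c, s = np.cos(theta), np.sin(theta)
    d = p - xy  # (N, 2) pixel-to-center offsets
    # Rotate offsets into each Gaussian's local frame and normalize by its scale
    local = np.stack(
        [c * d[:, 0] + s * d[:, 1], -s * d[:, 0] + c * d[:, 1]], axis=1
    ) / scale
    w = np.exp(-0.5 * (local**2).sum(axis=1))  # per-Gaussian weights
    top = np.argsort(w)[-k:]  # keep only the top-K contributors
    return (w[top, None] * feat[top]).sum(axis=0) / (w[top].sum() + eps)
```

Because only the K largest contributions are blended per pixel, decoding stays cheap and random access remains hardware-friendly, consistent with the MACs-per-pixel figure quoted above.
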
## Setup
1. Create a dedicated Python environment and install the dependencies
```bash
git clone https://github.com/NYU-ICL/image-gs.git
cd image-gs
conda env create -f environment.yml
conda activate image-gs
pip install git+https://github.com/rahul-goel/fused-ssim/ --no-build-isolation
cd gsplat
pip install -e ".[dev]"
cd ..
```
2. Download the image and texture datasets from [OneDrive](https://1drv.ms/u/c/3a8968df8a027819/EeshjZJlMtdCmvvmESiN2pABM71EDaoLYmEwuOvecg0tAA?e=GybqBv) and organize the folder structure as follows
```
image-gs
└── media
    ├── images
    └── textures
```
3. (Optional) To run saliency-guided Gaussian position initialization, download the pre-trained [EML-Net](https://github.com/SenJia/EML-NET-Saliency) models ([res_imagenet.pth](https://drive.google.com/open?id=1-a494canr9qWKLdm-DUDMgbGwtlAJz71), [res_places.pth](https://drive.google.com/open?id=18nRz0JSRICLqnLQtAvq01azZAsH0SEzS), [res_decoder.pth](https://drive.google.com/open?id=1vwrkz3eX-AMtXQE08oivGMwS4lKB74sH)) and place them under the `models/emlnet/` folder
```
image-gs
└── models
    └── emlnet
        ├── res_decoder.pth
        ├── res_imagenet.pth
        └── res_places.pth
```

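Note: in the Space deployment built from this commit, these checkpoints are fetched automatically at startup by `ensure_models_available()` in `gradio_app.py`, which pulls them from the `blanchon/image-gs-models-utils` repo. The equivalent manual download for a single checkpoint looks like this:

```python
from huggingface_hub import hf_hub_download

# Mirrors what ensure_models_available() does for each checkpoint
path = hf_hub_download(
    repo_id="blanchon/image-gs-models-utils",
    filename="emlnet/res_decoder.pth",
    repo_type="model",
)
```
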
## Quick Start

#### Image Compression
- Optimize an Image-GS representation for an input image `anime-1_2k.png` using `10000` Gaussians with half-precision parameters
```bash
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize
```
- Render the corresponding optimized Image-GS representation at a new resolution with height `4000` (aspect ratio is maintained)
```bash
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize --eval --render_height=4000
```

#### Texture Stack Compression
- Optimize an Image-GS representation for an input texture stack `alarm-clock_2k` using `30000` Gaussians with half-precision parameters
```bash
python main.py --input_path="textures/alarm-clock_2k" --exp_name="test/alarm-clock_2k" --num_gaussians=30000 --quantize
```
- Render the corresponding optimized Image-GS representation at a new resolution with height `3000` (aspect ratio is maintained)
```bash
python main.py --input_path="textures/alarm-clock_2k" --exp_name="test/alarm-clock_2k" --num_gaussians=30000 --quantize --eval --render_height=3000
```

#### Control bit precision of Gaussian parameters
- Optimize an Image-GS representation for an input image `anime-1_2k.png` using `10000` Gaussians with 12-bit-precision parameters (see the size estimate after this example)
```bash
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize --pos_bits=12 --scale_bits=12 --rot_bits=12 --feat_bits=12
```
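
To estimate the footprint these bit widths imply, the training code (`_log_compression_rate` in `gradio_models.py`) charges each Gaussian two position values, two scale values, one rotation angle, and one feature value per channel. A quick back-of-the-envelope check for the 12-bit example above, assuming a 3-channel RGB image:

```python
num_gaussians = 10000
pos_bits = scale_bits = rot_bits = feat_bits = 12
feat_dim = 3  # assuming an RGB input

# Same formula as _log_compression_rate in gradio_models.py
bits = (2 * pos_bits + 2 * scale_bits + rot_bits + feat_dim * feat_bits) * num_gaussians
print(f"{bits / 8 / 1e3:.1f} KB")  # 120.0 KB
```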

#### Switch to saliency-guided Gaussian position initialization
- Optimize an Image-GS representation for an input image `anime-1_2k.png` using `10000` Gaussians with half-precision parameters and saliency-guided initialization
```bash
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize --init_mode="saliency"
```

## Gradio Web Interface

We provide a user-friendly web interface built with Gradio for easy experimentation and training visualization.

### Setup for Web Interface

1. Install Gradio (in addition to the main dependencies):
```bash
pip install "gradio>=4.0.0"
```

2. Launch the web interface:
```bash
python gradio_app.py
```

3. Open your browser and navigate to `http://localhost:7860`

### Features

The Gradio interface provides:

- **Interactive Parameter Configuration**: adjust all training parameters through an intuitive UI
- **Image Upload**: drag and drop any image to train on
- **Real-Time Training Progress**: stream training logs and intermediate results
- **Live Visualization**: watch Gaussian placement and rendering progress during training
- **Result Gallery**: view final renders, gradient maps, and saliency maps
- **Easy Experimentation**: no need to remember command-line arguments

### Interface Sections

1. **Configuration Panel**:
   - Basic parameters (number of Gaussians, training steps)
   - Quantization settings for memory efficiency
   - Initialization modes (gradient, saliency, random)
   - Advanced optimization parameters (learning rates, loss weights)

2. **Training Progress**:
   - Real-time streaming logs
   - Current render and Gaussian visualization updates
   - Training status and control buttons

3. **Results Display**:
   - Final optimized image
   - Gradient and saliency maps used for initialization
   - Download capabilities for all results

### Usage Tips

- Start with the default parameters for your first run
- Use **saliency initialization** for better results on complex images
- Enable **Gaussian visualization** to see how the representation evolves
- Adjust **save image steps** to control visualization frequency (lower = more updates, but slower)
- For quick tests, reduce **max steps** to 500-1000

## Command Line Arguments
Please refer to `cfgs/default.yaml` for the full list of arguments and their default values.

**Post-optimization rendering**
- `--eval` render the optimized Image-GS representation.
- `--render_height` image height for rendering (aspect ratio is maintained).

**Bit precision control**: 32 bits (float32) per dimension by default
- `--quantize` enable bit precision control of Gaussian parameters.
- `--pos_bits` bit precision of each coordinate dimension.
- `--scale_bits` bit precision of each scale dimension.
- `--rot_bits` bit precision of the Gaussian orientation angle.
- `--feat_bits` bit precision of each feature dimension.

**Logging**
- `--exp_name` path to the logging directory.
- `--vis_gaussians` visualize Gaussians during optimization.
- `--save_image_steps` frequency of rendering intermediate results during optimization.
- `--save_ckpt_steps` frequency of checkpointing during optimization.

**Input image**
- `--input_path` path to an image file or a directory containing a texture stack.
- `--downsample` load a downsampled version of the input image or texture stack as the optimization target, to evaluate image upsampling performance.
- `--downsample_ratio` downsampling ratio.
- `--gamma` optimize in a gamma-corrected space; modify with caution.

**Gaussian**
- `--num_gaussians` number of Gaussians (for compression rate control).
- `--init_scale` initial Gaussian scale in number of pixels.
- `--disable_topk_norm` disable top-K normalization.
- `--disable_inverse_scale` disable inverse Gaussian scale optimization.
- `--init_mode` Gaussian position initialization mode; valid values are "gradient", "saliency", and "random".
- `--init_random_ratio` ratio of Gaussians with randomly initialized positions.

**Optimization**
- `--disable_tiles` disable tile-based rendering (warning: optimization and rendering without tiles are significantly slower).
- `--max_steps` maximum number of optimization steps.
- `--pos_lr` Gaussian position learning rate.
- `--scale_lr` Gaussian scale learning rate.
- `--rot_lr` Gaussian orientation angle learning rate.
- `--feat_lr` Gaussian feature learning rate.
- `--disable_lr_schedule` disable learning rate decay and early stopping.
- `--disable_prog_optim` disable error-guided progressive optimization.

+ ## Acknowledgements
216
+ We would like to thank the [gsplat](https://github.com/nerfstudio-project/gsplat) team, and the authors of [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), [fused-ssim](https://github.com/rahul-goel/fused-ssim), and [EML-Net](https://github.com/SenJia/EML-NET-Saliency) for their great work, based on which Image-GS was developed.
217
+
218
+ ## License
219
+ This project is licensed under the terms of the MIT license.
220
+
221
+ ## Citation
222
+ If you find this project helpful to your research, please consider citing [BibTeX](assets/docs/image-gs.bib):
223
+ ```bibtex
224
+ @inproceedings{zhang2025image,
225
+ title={Image-gs: Content-adaptive image representation via 2d gaussians},
226
+ author={Zhang, Yunxiang and Li, Bingxuan and Kuznetsov, Alexandr and Jindal, Akshay and Diolatzis, Stavros and Chen, Kenneth and Sochenov, Anton and Kaplanyan, Anton and Sun, Qi},
227
+ booktitle={Proceedings of the Special Interest Group on Computer Graphics and Interactive Techniques Conference Conference Papers},
228
+ pages={1--11},
229
+ year={2025}
230
+ }
231
+ ```
cfgs/default.yaml ADDED
@@ -0,0 +1,57 @@
seed: 123
device: "cuda:0"
# Evaluation
eval: False  # Render the optimized Image-GS representation
render_height: 2048  # Image height for rendering (aspect ratio is maintained)
# Bit precision
quantize: False  # Enable bit precision control of Gaussian parameters
pos_bits: 16  # Bit precision of individual coordinate dimension
scale_bits: 16  # Bit precision of individual scale dimension
rot_bits: 16  # Bit precision of Gaussian orientation angle
feat_bits: 16  # Bit precision of individual feature dimension
# Logging
log_root: "results"
exp_name: "test/anime-1_2k"  # Path to the logging directory
log_level: "INFO"
vis_gaussians: False  # Visualize Gaussians during optimization
save_image_steps: 100000  # Frequency of rendering intermediate results during optimization
save_ckpt_steps: 100000  # Frequency of checkpointing during optimization
eval_steps: 100
# Target images
gamma: 1.0  # Optimize in a gamma-corrected space, modify with caution
data_root: "media"
input_path: "images/anime-1_2k.png"  # Path to an image file or a directory containing a texture stack
downsample: False  # Load a downsampled version of the input as the optimization target to evaluate image upsampling performance
downsample_ratio: 2.0
# Gaussians
num_gaussians: 10000  # Number of Gaussians (for compression rate control)
init_scale: 5.0  # Initial Gaussian scale in number of pixels
topk: 10  # Warning: must match the hardcoded value in the CUDA kernel, modify with caution
disable_topk_norm: False  # Disable top-K normalization
disable_inverse_scale: False  # Disable inverse Gaussian scale optimization
ckpt_file: ""
disable_color_init: False
init_mode: "gradient"  # Gaussian position initialization mode, valid values include "gradient", "saliency", and "random"
init_random_ratio: 0.3  # Ratio of Gaussians with randomly initialized position
smap_filter_size: 20  # Gaussian filter size for smoothing saliency maps
# Loss functions
l1_loss_ratio: 1.0
l2_loss_ratio: 0.0
ssim_loss_ratio: 0.1
# Optimization
disable_tiles: False  # Disable tile-based rendering (warning: optimization and rendering without tiles will be significantly slower)
max_steps: 10000  # Maximum number of optimization steps
pos_lr: 5.0e-4
scale_lr: 2.0e-3
rot_lr: 2.0e-3
feat_lr: 5.0e-3
disable_lr_schedule: False  # Disable learning rate schedule and early stopping
decay_ratio: 10.0
check_decay_steps: 1000
max_decay_times: 1
decay_threshold: 1.0e-3
disable_prog_optim: False  # Disable error-guided progressive optimization
initial_ratio: 0.5
add_steps: 500
add_times: 4
post_min_steps: 3000
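
These values are loaded as argparse defaults, so any of them can be overridden from the command line. A minimal sketch of the pattern `gradio_app.py` uses with `load_cfg`:

```python
import argparse

from utils.misc_utils import load_cfg

# Populate a parser with defaults from the YAML; CLI flags then override them
parser = argparse.ArgumentParser()
parser = load_cfg(cfg_path="cfgs/default.yaml", parser=parser)
args = parser.parse_args()  # e.g. main.py --num_gaussians=30000 --quantize
print(args.num_gaussians, args.max_steps)
```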
gradio_app.py ADDED
@@ -0,0 +1,809 @@
import os
import sys
import time
import threading
import argparse
import tempfile
import shutil
from typing import Generator, Optional, Tuple
import logging

try:
    import gradio as gr
except ImportError:
    print("❌ Gradio not found. Please install it with: pip install gradio>=4.0.0")
    sys.exit(1)

try:
    from huggingface_hub import hf_hub_download, snapshot_download
except ImportError:
    print(
        "❌ huggingface_hub not found. Please install it with: pip install huggingface_hub"
    )
    sys.exit(1)

import torch
from PIL import Image

# Add the project root to the path so we can import the modules
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from gradio_models import GradioGaussianSplatting2D, StreamingResults
from utils.misc_utils import load_cfg
from main import get_log_dir


class TrainingState:
    """Manages the state of training sessions"""

    def __init__(self):
        self.is_training = False
        self.training_thread = None
        self.model = None
        self.temp_dir = None
        self.results = StreamingResults()

    def reset(self):
        self.is_training = False
        if self.temp_dir and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)
        self.temp_dir = None
        self.results = StreamingResults()


# Global training state
training_state = TrainingState()


def ensure_models_available():
    """Download models from HuggingFace if they're not available locally"""
    models_dir = "models"

    # Check if models directory exists and has the required files
    required_files = [
        "models/emlnet/res_decoder.pth",
        "models/emlnet/res_imagenet.pth",
        "models/emlnet/res_places.pth",
        "models/torch/checkpoints/alexnet-owt-7be5be79.pth",
    ]

    # Check if all required files exist
    all_files_exist = all(os.path.exists(file_path) for file_path in required_files)

    if not all_files_exist:
        print("📥 Downloading model files from HuggingFace...")
        try:
            # Create models directory if it doesn't exist
            os.makedirs(models_dir, exist_ok=True)

            # Download individual model files to ensure they end up in the right place
            model_files_remote = [
                "emlnet/res_decoder.pth",
                "emlnet/res_imagenet.pth",
                "emlnet/res_places.pth",
                "torch/checkpoints/alexnet-owt-7be5be79.pth",
            ]

            model_files_local = [
                "models/emlnet/res_decoder.pth",
                "models/emlnet/res_imagenet.pth",
                "models/emlnet/res_places.pth",
                "models/torch/checkpoints/alexnet-owt-7be5be79.pth",
            ]

            for remote_file, local_file in zip(model_files_remote, model_files_local):
                if not os.path.exists(local_file):
                    # Create directory structure
                    os.makedirs(os.path.dirname(local_file), exist_ok=True)

                    # Download the specific file
                    print(f"📥 Downloading {remote_file} -> {local_file}...")
                    downloaded_path = hf_hub_download(
                        repo_id="blanchon/image-gs-models-utils",
                        filename=remote_file,
                        repo_type="model",
                    )

                    # Copy to the expected local path
                    shutil.copy2(downloaded_path, local_file)

            print("✅ Model files downloaded successfully!")
        except Exception as e:
            print(f"❌ Failed to download model files: {e}")
            print("⚠️ The app may not work properly without these model files.")
    else:
        print("✅ Model files are already available locally.")


def create_args_from_config(
    image_path: str,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> argparse.Namespace:
    """Create arguments object from Gradio inputs"""

    # Load default config
    parser = argparse.ArgumentParser()
    parser = load_cfg(cfg_path="cfgs/default.yaml", parser=parser)
    args = parser.parse_args([])  # Parse empty args to get defaults

    # Override with user inputs
    args.input_path = image_path
    args.exp_name = exp_name
    args.num_gaussians = num_gaussians
    args.quantize = quantize
    args.pos_bits = pos_bits
    args.scale_bits = scale_bits
    args.rot_bits = rot_bits
    args.feat_bits = feat_bits
    args.init_mode = init_mode
    args.init_random_ratio = init_random_ratio
    args.max_steps = max_steps
    args.vis_gaussians = vis_gaussians
    args.save_image_steps = save_image_steps
    args.l1_loss_ratio = l1_loss_ratio
    args.l2_loss_ratio = l2_loss_ratio
    args.ssim_loss_ratio = ssim_loss_ratio
    args.pos_lr = pos_lr
    args.scale_lr = scale_lr
    args.rot_lr = rot_lr
    args.feat_lr = feat_lr
    args.disable_lr_schedule = disable_lr_schedule
    args.disable_prog_optim = disable_prog_optim
    args.eval = False

    # Set up logging directory
    args.log_dir = get_log_dir(args)

    return args


def train_model(args: argparse.Namespace) -> None:
    """Training function that runs in a separate thread"""
    try:
        # Create and train model with streaming results
        training_state.model = GradioGaussianSplatting2D(args, training_state.results)

        # Start training
        training_state.model.optimize()

    except Exception as e:
        training_state.results.training_logs.append(f"ERROR: {str(e)}")
        logging.error(f"Training failed: {str(e)}")
    finally:
        training_state.is_training = False


def start_training_and_stream(
    image_file,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> Generator[
    Tuple[
        str,
        str,
        Optional[Image.Image],  # initialization_map
        Optional[Image.Image],  # current_render
        Optional[Image.Image],  # current_gaussian_id
        bool,  # start_btn_interactive
        bool,  # stop_btn_interactive
    ],
    None,
    None,
]:
    """Start training and stream progress with images"""

    if training_state.is_training:
        yield (
            "Training is already in progress!",
            "",
            None,
            None,
            None,
            False,  # start_btn disabled
            True,  # stop_btn enabled
        )
        return

    if image_file is None:
        yield (
            "Please upload an image first!",
            "",
            None,
            None,
            None,
            True,  # start_btn enabled
            False,  # stop_btn disabled
        )
        return

    try:
        # Reset training state
        training_state.reset()

        # Create temporary directory for the uploaded image
        training_state.temp_dir = tempfile.mkdtemp()

        # Save uploaded image
        image_path = os.path.join(training_state.temp_dir, "input_image.png")
        image_file.save(image_path)

        # Create args
        args = create_args_from_config(
            image_path=image_path,
            exp_name=exp_name,
            num_gaussians=num_gaussians,
            quantize=quantize,
            pos_bits=pos_bits,
            scale_bits=scale_bits,
            rot_bits=rot_bits,
            feat_bits=feat_bits,
            init_mode=init_mode,
            init_random_ratio=init_random_ratio,
            max_steps=max_steps,
            vis_gaussians=vis_gaussians,
            save_image_steps=save_image_steps,
            l1_loss_ratio=l1_loss_ratio,
            l2_loss_ratio=l2_loss_ratio,
            ssim_loss_ratio=ssim_loss_ratio,
            pos_lr=pos_lr,
            scale_lr=scale_lr,
            rot_lr=rot_lr,
            feat_lr=feat_lr,
            disable_lr_schedule=disable_lr_schedule,
            disable_prog_optim=disable_prog_optim,
        )

        # Update data_root to use temp directory
        args.data_root = training_state.temp_dir
        args.input_path = "input_image.png"

        # Start training in separate thread
        training_state.is_training = True
        training_state.training_thread = threading.Thread(
            target=train_model, args=(args,)
        )
        training_state.training_thread.start()

        # Initial yield
        yield (
            "Training started! Check the progress below.",
            "Initializing training...",
            None,  # initialization_map
            None,  # current_render
            None,  # current_gaussian_id
            False,  # start_btn disabled
            True,  # stop_btn enabled
        )

        # Stream training progress
        while training_state.is_training or not training_state.results.is_complete:
            # Check if stop was requested
            if (
                not training_state.is_training
                and training_state.training_thread
                and training_state.training_thread.is_alive()
            ):
                # Force stop the training thread if needed
                training_state.results.training_logs.append(
                    "🛑 Training stopped by user request"
                )
                break

            # Get training logs
            if training_state.results.training_logs:
                logs_text = "\n".join(training_state.results.training_logs)

                # Add current metrics if available
                if training_state.results.step > 0:
                    # Break if step is greater than total steps
                    if training_state.results.step > training_state.results.total_steps:
                        break

                    metrics = training_state.results.metrics
                    status_line = (
                        f"\nCurrent: Step {training_state.results.step}/{training_state.results.total_steps} | "
                        f"PSNR: {metrics['psnr']:.2f} | SSIM: {metrics['ssim']:.4f} | "
                        f"Loss: {metrics['loss']:.4f}"
                    )
                    logs_text += status_line

                    # Add image status info for debugging
                    if training_state.results.current_render is not None:
                        logs_text += f"\n📸 Current render: {training_state.results.current_render.size}"
                    else:
                        logs_text += "\n📸 Current render: None"

                    if training_state.results.current_gaussian_id is not None:
                        logs_text += f"\n🆔 Gaussian ID: {training_state.results.current_gaussian_id.size}"
                    else:
                        logs_text += "\n🆔 Gaussian ID: None"

                    logs_text += (
                        f"\n💾 Stored steps: {len(training_state.results.step_renders)}"
                    )
            else:
                logs_text = "Waiting for training to start..."

            # Get current images
            initialization_map = training_state.results.initialization_map
            current_render = training_state.results.current_render
            current_gaussian_id = training_state.results.current_gaussian_id

            # Simple status based on training state
            current_step = training_state.results.step
            if training_state.results.is_complete:
                status = "✅ Training completed successfully!"
                start_btn_interactive = True
                stop_btn_interactive = False
            elif not training_state.is_training:
                status = "⏹️ Training stopped."
                start_btn_interactive = True
                stop_btn_interactive = False
            else:
                status = f"🔄 Training in progress... Step {current_step}/{training_state.results.total_steps}"
                start_btn_interactive = False
                stop_btn_interactive = True

            # Always yield, even if images haven't changed
            yield (
                status,
                logs_text,
                initialization_map,
                current_render,
                current_gaussian_id,
                start_btn_interactive,
                stop_btn_interactive,
            )

            # Stop if training is complete
            if training_state.results.is_complete or not training_state.is_training:
                break
            if current_step > training_state.results.total_steps:
                break

            time.sleep(0.5)  # Update more frequently for better responsiveness

    except Exception as e:
        training_state.reset()
        yield (
            f"Failed to start training: {str(e)}",
            "",
            None,
            None,
            None,
            True,  # start_btn enabled
            False,  # stop_btn disabled
        )


def stop_training() -> str:
    """Stop the current training"""
    if not training_state.is_training:
        return "No training in progress."

    training_state.is_training = False
    training_state.results.training_logs.append(
        "🛑 STOP: Training stop requested by user..."
    )

    # Set a flag in the model to stop training
    if training_state.model:
        training_state.model.stop_requested = True

    return "Training stop requested. Will complete current step and stop."


def get_final_results() -> Tuple[Optional[Image.Image], Optional[str]]:
    """Get final training results"""
    final_render = training_state.results.final_render
    checkpoint_path = training_state.results.final_checkpoint_path
    return final_render, checkpoint_path


def browse_step_results(
    step: int,
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Browse results from a specific training step"""
    if not training_state.results.is_complete:
        return None, None

    # Find the closest available step
    available_steps = list(training_state.results.step_renders.keys())
    if not available_steps:
        return None, None

    closest_step = min(available_steps, key=lambda x: abs(x - step))

    render_img = training_state.results.step_renders.get(closest_step)
    gaussian_id_img = training_state.results.step_gaussian_ids.get(closest_step)

    return render_img, gaussian_id_img


def update_step_slider_after_training() -> gr.Slider:
    """Update step slider range and enable it after training completes"""
    if not training_state.results.is_complete:
        return gr.Slider(
            minimum=0,
            maximum=10000,
            value=0,
            step=100,
            label="Browse Training Steps",
            info="Training not complete yet",
            interactive=False,
        )

    available_steps = list(training_state.results.step_renders.keys())
    if not available_steps:
        return gr.Slider(
            minimum=0,
            maximum=10000,
            value=0,
            step=100,
            label="Browse Training Steps",
            info="No training steps available",
            interactive=False,
        )

    max_step = max(available_steps)
    min_step = min(available_steps)
    # Use the step size from save_image_steps if available, otherwise use the difference between steps
    if len(available_steps) > 1:
        step_size = available_steps[1] - available_steps[0]
    else:
        step_size = 100

    return gr.Slider(
        minimum=min_step,
        maximum=max_step,
        value=max_step,
        step=step_size,
        label="Browse Training Steps",
        info=f"Browse results from steps {min_step}-{max_step} (interactive)",
        interactive=True,
    )


def create_interface():
    """Create the Gradio interface"""

    with gr.Blocks(
        title="Image-GS: 2D Gaussian Splatting", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("""
        # Image-GS: Content-Adaptive Image Representation via 2D Gaussians

        Upload an image and configure parameters to train a 2D Gaussian Splatting representation.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Configuration")

                # Image upload
                image_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    height=300,
                    sources=["upload"],
                    show_label=True,
                )

                # Basic parameters
                with gr.Group():
                    gr.Markdown("### Basic Parameters")
                    exp_name = gr.Textbox(
                        label="Experiment Name",
                        value="gradio_experiment",
                        info="Name for this training run",
                    )
                    num_gaussians = gr.Slider(
                        minimum=1000,
                        maximum=50000,
                        value=10000,
                        step=1000,
                        label="Number of Gaussians",
                        info="Number of Gaussians (for compression rate control). More = higher quality but slower training",
                    )
                    max_steps = gr.Slider(
                        minimum=100,
                        maximum=20000,
                        value=10000,
                        step=100,
                        label="Maximum Training Steps",
                        info="Maximum number of optimization steps. Default: 10000",
                    )

                # Quantization parameters
                with gr.Group():
                    gr.Markdown("### Quantization")
                    quantize = gr.Checkbox(
                        label="Enable Quantization",
                        value=False,
                        info="Enable bit precision control of Gaussian parameters. Reduces memory usage.",
                    )
                    with gr.Row():
                        pos_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Position Bits",
                            info="Bit precision of individual coordinate dimension",
                        )
                        scale_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Scale Bits",
                            info="Bit precision of individual scale dimension",
                        )
                    with gr.Row():
                        rot_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Rotation Bits",
                            info="Bit precision of Gaussian orientation angle",
                        )
                        feat_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Feature Bits",
                            info="Bit precision of individual feature dimension",
                        )

                # Initialization parameters
                with gr.Group():
                    gr.Markdown("### Initialization")
                    init_mode = gr.Radio(
                        choices=["gradient", "saliency", "random"],
                        value="saliency",
                        label="Initialization Mode",
                        info="Gaussian position initialization mode. Gradient uses image gradients, saliency uses attention maps.",
                    )
                    init_random_ratio = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.3,
                        step=0.1,
                        label="Random Ratio",
                        info="Ratio of Gaussians with randomly initialized position (default: 0.3)",
                    )

                # Advanced parameters (collapsible)
                with gr.Accordion("Advanced Parameters", open=False):
                    # Loss parameters
                    gr.Markdown("#### Loss Weights")
                    with gr.Row():
                        l1_loss_ratio = gr.Slider(
                            0.0, 2.0, 1.0, step=0.1, label="L1 Loss"
                        )
                        l2_loss_ratio = gr.Slider(
                            0.0, 2.0, 0.0, step=0.1, label="L2 Loss"
                        )
                        ssim_loss_ratio = gr.Slider(
                            0.0, 1.0, 0.1, step=0.01, label="SSIM Loss"
                        )

                    # Learning rates
                    gr.Markdown("#### Learning Rates")
                    with gr.Row():
                        pos_lr = gr.Number(value=5e-4, label="Position LR", precision=6)
                        scale_lr = gr.Number(value=2e-3, label="Scale LR", precision=6)
                    with gr.Row():
                        rot_lr = gr.Number(value=2e-3, label="Rotation LR", precision=6)
                        feat_lr = gr.Number(value=5e-3, label="Feature LR", precision=6)

                    # Optimization options
                    gr.Markdown("#### Optimization")
                    disable_lr_schedule = gr.Checkbox(
                        label="Disable LR Schedule",
                        value=False,
                        info="Keep learning rate constant",
                    )
                    disable_prog_optim = gr.Checkbox(
                        label="Disable Progressive Optimization",
                        value=False,
                        info="Don't add Gaussians during training",
                    )

                # Visualization parameters
                with gr.Group():
                    gr.Markdown("### Visualization")
                    vis_gaussians = gr.Checkbox(
                        label="Visualize Gaussians",
                        value=True,
                        info="Visualize Gaussians during optimization (default: True)",
                    )
                    save_image_steps = gr.Slider(
                        minimum=200,
                        maximum=10000,
                        value=200,
                        step=100,
                        label="Save Image Every N Steps",
                        info="Frequency of rendering intermediate results during optimization (default: 200)",
                    )

                # Control buttons
                with gr.Row():
                    start_btn = gr.Button(
                        "Start Training", variant="primary", size="lg"
                    )
                    stop_btn = gr.Button("Stop Training", variant="stop", size="lg")

                status_text = gr.Textbox(label="Status", interactive=False, lines=2)

            with gr.Column(scale=2):
                gr.Markdown("## Training Progress")

                # Progress logs (streaming)
                progress_logs = gr.Textbox(
                    label="Training Logs",
                    lines=10,
                    max_lines=15,
                    interactive=False,
                    autoscroll=True,
                )

                # Initial map (computed at start based on initialization mode)
                gr.Markdown("### Initialization Map")
                initialization_map = gr.Image(
                    label="Initialization Map",
                    type="pil",
                    height=200,
                )

                # Training images (streaming)
                gr.Markdown("### Current Training Results")
                with gr.Row():
                    current_render = gr.Image(
                        label="Current Render",
                        type="pil",
                        height=300,
                        show_label=True,
                        show_download_button=True,
                    )
                    current_gaussian_id = gr.Image(
                        label="Gaussian ID",
                        type="pil",
                        height=300,
                        show_label=True,
                        show_download_button=True,
                    )

                # Step slider for interactive browsing (will be updated dynamically)
                step_slider = gr.Slider(
                    minimum=0,
                    maximum=10000,
                    value=0,
                    step=100,
                    label="Browse Training Steps",
                    info="Slide to view results from different training steps (disabled during training)",
                    interactive=False,
                )

                gr.Markdown("## Final Results")
                with gr.Row():
                    final_render = gr.Image(
                        label="Final Render", type="pil", height=300
                    )
                    final_checkpoint = gr.File(label="Download Final Checkpoint (.pt)")

                # Results buttons
                with gr.Row():
                    results_btn = gr.Button("Load Final Results", size="lg")
                    enable_slider_btn = gr.Button(
                        "Enable Step Browsing", size="lg", variant="secondary"
                    )

        # Event handlers
        start_btn.click(
            fn=start_training_and_stream,
            inputs=[
                image_input,
                exp_name,
                num_gaussians,
                quantize,
                pos_bits,
                scale_bits,
                rot_bits,
                feat_bits,
                init_mode,
                init_random_ratio,
                max_steps,
                vis_gaussians,
                save_image_steps,
                l1_loss_ratio,
                l2_loss_ratio,
                ssim_loss_ratio,
                pos_lr,
                scale_lr,
                rot_lr,
                feat_lr,
                disable_lr_schedule,
                disable_prog_optim,
            ],
            outputs=[
                status_text,
                progress_logs,
                initialization_map,
                current_render,
                current_gaussian_id,
                start_btn,
                stop_btn,
            ],
        )

        stop_btn.click(fn=stop_training, outputs=status_text)

        results_btn.click(
            fn=get_final_results, outputs=[final_render, final_checkpoint]
        )

        enable_slider_btn.click(
            fn=update_step_slider_after_training, outputs=[step_slider]
        )

        step_slider.change(
            fn=browse_step_results,
            inputs=[step_slider],
            outputs=[current_render, current_gaussian_id],
        )

    return demo


if __name__ == "__main__":
    # Ensure model files are available (download from HF if needed)
    ensure_models_available()

    # Set torch hub directory
    torch.hub.set_dir("models/torch")

    # Create and launch the interface
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
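
The streaming behavior above relies on a standard Gradio pattern: when an event handler is a generator, every `yield` pushes a fresh update to all of its output components. A stripped-down sketch of that pattern, independent of Image-GS:

```python
import time

import gradio as gr


def stream_progress(steps: float):
    # Each yield updates the Textbox live, like the training loop above
    for i in range(1, int(steps) + 1):
        time.sleep(0.5)
        yield f"Step {i}/{int(steps)}"


with gr.Blocks() as demo:
    n = gr.Slider(1, 20, value=5, step=1, label="Steps")
    log = gr.Textbox(label="Progress")
    gr.Button("Run").click(fn=stream_progress, inputs=n, outputs=log)

demo.launch()
```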
gradio_models.py ADDED
@@ -0,0 +1,827 @@
import logging
import math
import os
import threading
from time import perf_counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fused_ssim import fused_ssim
from torchvision.transforms.functional import gaussian_blur
from PIL import Image

from gsplat import (
    project_gaussians_2d_scale_rot,
    rasterize_gaussians_no_tiles,
    rasterize_gaussians_sum,
)
from utils.image_utils import (
    compute_image_gradients,
    get_grid,
    get_psnr,
    load_images,
    to_output_format,
)
from utils.misc_utils import set_random_seed
from utils.quantization_utils import ste_quantize
from utils.saliency_utils import get_smap


class StreamingResults:
    """Container for streaming training results"""

    def __init__(self):
        self.step = 0
        self.total_steps = 0
        self.current_render = None
        self.current_gaussian_id = None
        self.initialization_map = None  # Single map for current initialization mode
        self.final_render = None
        self.final_checkpoint_path = None
        self.training_logs = []
        self.metrics = {
            "psnr": 0.0,
            "ssim": 0.0,
            "loss": 0.0,
            "render_time": 0.0,
            "total_time": 0.0,
        }
        self.is_complete = False
        # Store all step results for interactive browsing
        self.step_renders = {}  # {step: PIL_Image}
        self.step_gaussian_ids = {}  # {step: PIL_Image}
        # For async visualization generation
        self.vis_lock = threading.Lock()


class GradioStreamingHandler(logging.Handler):
    """Custom logging handler that captures logs for Gradio streaming"""

    def __init__(self, results_container: StreamingResults):
        super().__init__()
        self.results = results_container

    def emit(self, record):
        log_entry = self.format(record)
        self.results.training_logs.append(log_entry)
        # Keep only the last 100 log entries to avoid memory issues
        if len(self.results.training_logs) > 100:
            self.results.training_logs = self.results.training_logs[-100:]


class GradioGaussianSplatting2D(nn.Module):
    """Gradio-optimized version of GaussianSplatting2D with streaming capabilities"""

    def __init__(self, args, results_container: StreamingResults):
        super(GradioGaussianSplatting2D, self).__init__()
        self.results = results_container
        self.evaluate = args.eval
        set_random_seed(seed=args.seed)

        # Device setup
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        self.dtype = torch.float32

        # Initialize components
        self._init_logging(args)
        self._init_target(args)
        self._init_bit_precision(args)
        self._init_gaussians(args)
        self._init_loss(args)
        self._init_optimization(args)

        # Initialization
        if self.evaluate:
            self.ckpt_file = args.ckpt_file
            self._load_model()
        else:
            self._init_pos_scale_feat(args)

    def _init_logging(self, args):
        self.log_dir = getattr(args, "log_dir", "temp_gradio_logs")
        self.vis_gaussians = args.vis_gaussians
        self.save_image_steps = args.save_image_steps
        self.eval_steps = args.eval_steps

        # Set up streaming logger
        self.worklog = logging.getLogger("GradioImageGS")
        self.worklog.handlers.clear()  # Remove existing handlers

        # Add our streaming handler
        stream_handler = GradioStreamingHandler(self.results)
        stream_handler.setFormatter(
            logging.Formatter(fmt="[{asctime}] {message}", style="{")
        )
        self.worklog.addHandler(stream_handler)
        self.worklog.setLevel(logging.INFO)

        self.worklog.info(
            f"Start optimizing {args.num_gaussians:d} Gaussians for '{args.input_path}'"
        )

    def _init_target(self, args):
        self.gamma = args.gamma
        self.downsample = args.downsample
        if self.downsample:
            self.downsample_ratio = float(args.downsample_ratio)

        self.block_h, self.block_w = 16, 16
        self._load_target_images(path=os.path.join(args.data_root, args.input_path))

        if self.downsample:
            self.gt_images_upsampled = self.gt_images
            self.img_h_upsampled, self.img_w_upsampled = self.img_h, self.img_w
            self.tile_bounds_upsampled = self.tile_bounds
            self._load_target_images(
                path=os.path.join(args.data_root, args.input_path),
                downsample_ratio=self.downsample_ratio,
            )

        self.num_pixels = self.img_h * self.img_w

    def _load_target_images(self, path, downsample_ratio=None):
        self.gt_images, self.input_channels, self.image_fnames = load_images(
            load_path=path, downsample_ratio=downsample_ratio, gamma=self.gamma
        )
        self.gt_images = torch.from_numpy(self.gt_images).to(
            dtype=self.dtype, device=self.device
        )
        self.img_h, self.img_w = self.gt_images.shape[1:]
        self.tile_bounds = (
            (self.img_w + self.block_w - 1) // self.block_w,
            (self.img_h + self.block_h - 1) // self.block_h,
            1,
        )

    def _init_bit_precision(self, args):
        self.quantize = args.quantize
        self.pos_bits = args.pos_bits
        self.scale_bits = args.scale_bits
        self.rot_bits = args.rot_bits
        self.feat_bits = args.feat_bits

    def _init_gaussians(self, args):
        self.num_gaussians = args.num_gaussians
        self.total_num_gaussians = args.num_gaussians
        self.disable_prog_optim = args.disable_prog_optim

        if not self.disable_prog_optim and not self.evaluate:
            self.initial_ratio = args.initial_ratio
            self.add_times = args.add_times
            self.add_steps = args.add_steps
            self.num_gaussians = math.ceil(
                self.initial_ratio * self.total_num_gaussians
            )
            self.max_add_num = math.ceil(
                float(self.total_num_gaussians - self.num_gaussians) / self.add_times
            )
            min_steps = self.add_steps * self.add_times + args.post_min_steps
            if args.max_steps < min_steps:
                self.worklog.info(
                    f"Max steps ({args.max_steps:d}) is too small for progressive optimization. Resetting to {min_steps:d}"
                )
                args.max_steps = min_steps

        self.topk = args.topk
        self.eps = 1e-7 if args.disable_tiles else 1e-4
        self.init_scale = args.init_scale
        self.disable_topk_norm = args.disable_topk_norm
        self.disable_inverse_scale = args.disable_inverse_scale
        self.disable_color_init = args.disable_color_init

        # Initialize parameters
        self.xy = nn.Parameter(
            torch.rand(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.scale = nn.Parameter(
            torch.ones(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.rot = nn.Parameter(
            torch.zeros(self.num_gaussians, 1, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.feat_dim = sum(self.input_channels)
        self.feat = nn.Parameter(
            torch.rand(
                self.num_gaussians, self.feat_dim, dtype=self.dtype, device=self.device
            ),
            requires_grad=True,
        )
        self.vis_feat = nn.Parameter(torch.rand_like(self.feat), requires_grad=False)

        self._log_compression_rate()

    def _log_compression_rate(self):
        bytes_uncompressed = float(self.gt_images.numel())
        bpp_uncompressed = float(8 * self.feat_dim)
        self.worklog.info(
            f"Uncompressed: {bytes_uncompressed / 1e3:.2f} KB | {bpp_uncompressed:.3f} bpp | 8.0 bppc"
        )
        bits_compressed = (
            2 * self.pos_bits
            + 2 * self.scale_bits
            + self.rot_bits
            + self.feat_dim * self.feat_bits
        ) * self.total_num_gaussians
        bytes_compressed = bits_compressed / 8.0
        bpp_compressed = float(bits_compressed) / self.num_pixels
        bppc_compressed = bpp_compressed / self.feat_dim
        self.num_bytes = bytes_compressed
        self.worklog.info(
            f"Compressed: {bytes_compressed / 1e3:.2f} KB | {bpp_compressed:.3f} bpp | {bppc_compressed:.3f} bppc"
        )
        self.worklog.info(
            f"Compression rate: {bpp_uncompressed / bpp_compressed:.2f}x | {100.0 * bpp_compressed / bpp_uncompressed:.2f}%"
        )

    def _init_loss(self, args):
        self.l1_loss_ratio = args.l1_loss_ratio
        self.l2_loss_ratio = args.l2_loss_ratio
        self.ssim_loss_ratio = args.ssim_loss_ratio

    def _init_optimization(self, args):
        self.disable_tiles = args.disable_tiles
        self.start_step = 1
        self.max_steps = args.max_steps
        self.results.total_steps = args.max_steps  # Set total steps for streaming progress
        self.pos_lr = args.pos_lr
        self.scale_lr = args.scale_lr
        self.rot_lr = args.rot_lr
        self.feat_lr = args.feat_lr

        self.optimizer = torch.optim.Adam(
            [
                {"params": self.xy, "lr": self.pos_lr},
                {"params": self.scale, "lr": self.scale_lr},
                {"params": self.rot, "lr": self.rot_lr},
                {"params": self.feat, "lr": self.feat_lr},
            ]
        )

        self.disable_lr_schedule = args.disable_lr_schedule
        if not self.disable_lr_schedule:
            self.decay_ratio = args.decay_ratio
            self.check_decay_steps = args.check_decay_steps
            self.max_decay_times = args.max_decay_times
            self.decay_threshold = args.decay_threshold

    def _init_pos_scale_feat(self, args):
        self.init_mode = args.init_mode
        self.init_random_ratio = args.init_random_ratio
        self.pixel_xy = (
            get_grid(h=self.img_h, w=self.img_w)
            .to(dtype=self.dtype, device=self.device)
            .reshape(-1, 2)
        )

        with torch.no_grad():
            # Position initialization
            if self.init_mode == "gradient":
                self._compute_gmap()
                self.xy.copy_(self._sample_pos(prob=self.image_gradients))
            elif self.init_mode == "saliency":
                self.smap_filter_size = args.smap_filter_size
                self._compute_smap()
                self.xy.copy_(self._sample_pos(prob=self.saliency))
            else:  # random mode
                selected = np.random.choice(
                    self.num_pixels, self.num_gaussians, replace=False, p=None
                )
                self.xy.copy_(self.pixel_xy.detach().clone()[selected])
                # For random mode, create a simple random noise pattern
                if self.init_mode == "random":
                    random_pattern = np.random.rand(self.img_h, self.img_w)
                    self.results.initialization_map = Image.fromarray(
                        (random_pattern * 255).astype(np.uint8)
                    )

            # Scale initialization
            self.scale.fill_(
                self.init_scale if self.disable_inverse_scale else 1.0 / self.init_scale
            )

            # Feature initialization
            if not self.disable_color_init:
                self.feat.copy_(
                    self._get_target_features(positions=self.xy).detach().clone()
                )

    def _sample_pos(self, prob):
        num_random = round(self.init_random_ratio * self.num_gaussians)
        selected_random = np.random.choice(
            self.num_pixels, num_random, replace=False, p=None
        )
        selected_other = np.random.choice(
            self.num_pixels, self.num_gaussians - num_random, replace=False, p=prob
        )
        return torch.cat(
            [
                self.pixel_xy.detach().clone()[selected_random],
                self.pixel_xy.detach().clone()[selected_other],
            ],
            dim=0,
        )

    def _compute_gmap(self):
        gy, gx = compute_image_gradients(
            np.power(self.gt_images.detach().cpu().clone().numpy(), 1.0 / self.gamma)
        )
        g_norm = np.hypot(gy, gx).astype(np.float32)
        g_norm = g_norm / g_norm.max()

        # Store gradient map for streaming (only if this is the selected initialization mode)
        if self.init_mode == "gradient":
            self.results.initialization_map = Image.fromarray(
                (g_norm * 255).astype(np.uint8)
            )

        g_norm = np.power(g_norm.reshape(-1), 2.0)
        self.image_gradients = g_norm / g_norm.sum()
        self.worklog.info("Image gradient map computed")

    def _compute_smap(self):
        smap = get_smap(
            torch.pow(self.gt_images.detach().clone(), 1.0 / self.gamma),
            "models",
            self.smap_filter_size,
        )

        # Store saliency map for streaming (only if this is the selected initialization mode)
        if self.init_mode == "saliency":
            self.results.initialization_map = Image.fromarray(
                (smap * 255).astype(np.uint8)
            )

        self.saliency = (smap / smap.sum()).reshape(-1)
        self.worklog.info("Saliency map computed")

    def _get_target_features(self, positions):
        with torch.no_grad():
            target_features = F.grid_sample(
                self.gt_images.unsqueeze(0),
                positions[None, None, ...] * 2.0 - 1.0,
373
+ align_corners=False,
374
+ )
375
+ target_features = target_features[0, :, 0, :].permute(1, 0)
376
+ return target_features
377
+
378
+ def forward(self, img_h, img_w, tile_bounds, upsample_ratio=None, benchmark=False):
379
+ scale = self._get_scale(upsample_ratio=upsample_ratio)
380
+ xy, rot, feat = self.xy, self.rot, self.feat
381
+
382
+ if self.quantize:
383
+ xy, scale, rot, feat = (
384
+ ste_quantize(xy, self.pos_bits),
385
+ ste_quantize(scale, self.scale_bits),
386
+ ste_quantize(rot, self.rot_bits),
387
+ ste_quantize(feat, self.feat_bits),
388
+ )
389
+
390
+ begin = perf_counter()
391
+ tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
392
+ xy, radii, conics, num_tiles_hit = tmp
393
+
394
+ if not self.disable_tiles:
395
+ enable_topk_norm = not self.disable_topk_norm
396
+ tmp = (
397
+ xy,
398
+ radii,
399
+ conics,
400
+ num_tiles_hit,
401
+ feat,
402
+ img_h,
403
+ img_w,
404
+ self.block_h,
405
+ self.block_w,
406
+ enable_topk_norm,
407
+ )
408
+ out_image = rasterize_gaussians_sum(*tmp)
409
+ else:
410
+ tmp = xy, conics, feat, img_h, img_w
411
+ out_image = rasterize_gaussians_no_tiles(*tmp)
412
+
413
+ render_time = perf_counter() - begin
414
+
415
+ if benchmark:
416
+ return render_time
417
+
418
+ out_image = (
419
+ out_image.view(-1, img_h, img_w, self.feat_dim)
420
+ .permute(0, 3, 1, 2)
421
+ .contiguous()
422
+ )
423
+ return out_image.squeeze(dim=0), render_time
424
+
425
+ def _get_scale(self, upsample_ratio=None):
426
+ scale = self.scale
427
+ if not self.disable_inverse_scale:
428
+ scale = 1.0 / scale
429
+ if upsample_ratio is not None:
430
+ scale = upsample_ratio * scale
431
+ return scale
432
+
433
+ def _tensor_to_pil_image(self, tensor_image):
434
+ """Convert tensor image to PIL Image for streaming"""
435
+ if tensor_image is None:
436
+ return None
437
+
438
+ # Convert to numpy and apply gamma correction
439
+ image_np = (
440
+ torch.pow(torch.clamp(tensor_image, 0.0, 1.0), 1.0 / self.gamma)
441
+ .detach()
442
+ .cpu()
443
+ .numpy()
444
+ )
445
+
446
+ # Convert to uint8 format
447
+ image_formatted = to_output_format(image_np, gamma=None)
448
+ return Image.fromarray(image_formatted)
449
+
450
+ def _create_gaussian_id_visualization(self):
451
+ """Create Gaussian ID visualization as PIL Image using rasterization with vis_feat"""
452
+ if not self.vis_gaussians:
453
+ return None
454
+
455
+ try:
456
+ # Use vis_feat for ID visualization (this creates unique colors per Gaussian)
457
+ feat = self.vis_feat * self.feat.norm(dim=-1, keepdim=True)
458
+
459
+ # Render with ID features
460
+ scale = self._get_scale()
461
+ xy, rot = self.xy, self.rot
462
+
463
+ if self.quantize:
464
+ xy, scale, rot, feat = (
465
+ ste_quantize(xy, self.pos_bits),
466
+ ste_quantize(scale, self.scale_bits),
467
+ ste_quantize(rot, self.rot_bits),
468
+ ste_quantize(feat, self.feat_bits),
469
+ )
470
+
471
+ tmp = project_gaussians_2d_scale_rot(
472
+ xy, scale, rot, self.img_h, self.img_w, self.tile_bounds
473
+ )
474
+ xy, radii, conics, num_tiles_hit = tmp
475
+
476
+ if not self.disable_tiles:
477
+ enable_topk_norm = not self.disable_topk_norm
478
+ tmp = (
479
+ xy,
480
+ radii,
481
+ conics,
482
+ num_tiles_hit,
483
+ feat,
484
+ self.img_h,
485
+ self.img_w,
486
+ self.block_h,
487
+ self.block_w,
488
+ enable_topk_norm,
489
+ )
490
+ out_image = rasterize_gaussians_sum(*tmp)
491
+ else:
492
+ tmp = xy, conics, feat, self.img_h, self.img_w
493
+ out_image = rasterize_gaussians_no_tiles(*tmp)
494
+
495
+ out_image = (
496
+ out_image.view(-1, self.img_h, self.img_w, self.feat_dim)
497
+ .permute(0, 3, 1, 2)
498
+ .contiguous()
499
+ ).squeeze(dim=0)
500
+
501
+ return self._tensor_to_pil_image(out_image)
502
+
503
+ except Exception as e:
504
+ self.worklog.error(f"Error creating Gaussian ID visualization: {e}")
505
+ return None
506
+
507
+ def optimize(self):
508
+ """Main optimization loop with streaming updates"""
509
+ self.psnr_curr, self.ssim_curr = 0.0, 0.0
510
+ self.best_psnr, self.best_ssim = 0.0, 0.0
511
+ self.decay_times, self.no_improvement_steps = 0, 0
512
+ self.render_time_accum, self.total_time_accum = 0.0, 0.0
513
+
514
+ # Initialize attributes needed for evaluation
515
+ self.l1_loss = None
516
+ self.l2_loss = None
517
+ self.ssim_loss = None
518
+ self.stop_requested = False
519
+
520
+ # Initial render and update
521
+ with torch.no_grad():
522
+ images, _ = self.forward(self.img_h, self.img_w, self.tile_bounds)
523
+ self.results.current_render = self._tensor_to_pil_image(images)
524
+ if self.vis_gaussians:
525
+ try:
526
+ self.results.current_gaussian_id = (
527
+ self._create_gaussian_id_visualization()
528
+ )
529
+ self.worklog.info(
530
+ f"Initial visualizations created - Render: {'✓' if self.results.current_render else '✗'}, ID: {'✓' if self.results.current_gaussian_id else '✗'}"
531
+ )
532
+ except Exception as e:
533
+ self.worklog.error(f"Error creating initial visualizations: {e}")
534
+ self.results.current_gaussian_id = None
535
+
536
+ # Store initial results (step 0)
537
+ self.results.step_renders[0] = self.results.current_render
538
+ if self.vis_gaussians:
539
+ self.results.step_gaussian_ids[0] = self.results.current_gaussian_id
540
+
541
+ for step in range(self.start_step, self.max_steps + 1):
542
+ self.step = step
543
+ self.results.step = step
544
+
545
+ self.optimizer.zero_grad()
546
+
547
+ # Forward pass
548
+ images, render_time = self.forward(self.img_h, self.img_w, self.tile_bounds)
549
+ self.render_time_accum += render_time
550
+
551
+ # Compute loss
552
+ begin = perf_counter()
553
+ self._get_total_loss(images)
554
+ self.total_loss.backward()
555
+ self.optimizer.step()
556
+ self.total_time_accum += perf_counter() - begin + render_time
557
+
558
+ # Update streaming results
559
+ with torch.no_grad():
560
+ if step % self.eval_steps == 0:
561
+ self._evaluate_and_update_stream(images)
562
+
563
+ # Update render image more frequently, but visualizations less frequently
564
+ render_update_freq = max(
565
+ 50, self.save_image_steps // 2
566
+ ) # Update the render preview no more often than every 50 steps
567
+ vis_update_freq = max(
568
+ 200, self.save_image_steps
569
+ ) # Update Gaussian ID visualizations no more often than every 200 steps
570
+
571
+ if step % render_update_freq == 0:
572
+ render_img = self._tensor_to_pil_image(images)
573
+ self.results.current_render = render_img
574
+
575
+ # Only update Gaussian ID visualization less frequently
576
+ if step % vis_update_freq == 0 and self.vis_gaussians:
577
+ # Generate Gaussian ID visualization asynchronously
578
+ def generate_gaussian_id_async():
579
+ try:
580
+ with self.results.vis_lock:
581
+ gaussian_id_vis = (
582
+ self._create_gaussian_id_visualization()
583
+ )
584
+ self.results.current_gaussian_id = gaussian_id_vis
585
+
586
+ except Exception as e:
587
+ self.worklog.error(
588
+ f"Error creating Gaussian ID visualization at step {step}: {e}"
589
+ )
590
+ with self.results.vis_lock:
591
+ self.results.current_gaussian_id = None
592
+
593
+ # Start async visualization generation
594
+ vis_thread = threading.Thread(target=generate_gaussian_id_async)
595
+ vis_thread.daemon = True
596
+ vis_thread.start()
597
+
598
+ # Store results for interactive browsing only at save_image_steps intervals
599
+ if step % self.save_image_steps == 0:
600
+ # Store the current render for browsing
601
+ if self.results.current_render:
602
+ self.results.step_renders[step] = self.results.current_render
603
+
604
+ # Store Gaussian ID visualization for browsing
605
+ if self.vis_gaussians and self.results.current_gaussian_id:
606
+ self.results.step_gaussian_ids[step] = (
607
+ self.results.current_gaussian_id
608
+ )
609
+
610
+ # Progressive optimization
611
+ if (
612
+ not self.disable_prog_optim
613
+ and step % self.add_steps == 0
614
+ and self.num_gaussians < self.total_num_gaussians
615
+ ):
616
+ self._add_gaussians(self.max_add_num)
617
+
618
+ # Learning rate schedule
619
+ terminate = False
620
+ if (
621
+ not self.disable_lr_schedule
622
+ and self.num_gaussians == self.total_num_gaussians
623
+ and step % self.eval_steps == 0
624
+ ):
625
+ terminate = self._lr_schedule()
626
+
627
+ if terminate or self.stop_requested:
628
+ if self.stop_requested:
629
+ self.worklog.info("Training stopped by user request")
630
+ break
631
+
632
+ # Final updates
633
+ with torch.no_grad():
634
+ images, _ = self.forward(self.img_h, self.img_w, self.tile_bounds)
635
+ self.results.final_render = self._tensor_to_pil_image(images)
636
+
637
+ # Save final checkpoint and store path
638
+ self._save_final_checkpoint()
639
+
640
+ self.results.is_complete = True
641
+ self.worklog.info("Optimization completed")
642
+
643
+ def _get_total_loss(self, images):
644
+ self.total_loss = 0
645
+
646
+ if self.l1_loss_ratio > 1e-7:
647
+ self.l1_loss = self.l1_loss_ratio * F.l1_loss(images, self.gt_images)
648
+ self.total_loss += self.l1_loss
649
+ else:
650
+ self.l1_loss = None
651
+
652
+ if self.l2_loss_ratio > 1e-7:
653
+ self.l2_loss = self.l2_loss_ratio * F.mse_loss(images, self.gt_images)
654
+ self.total_loss += self.l2_loss
655
+ else:
656
+ self.l2_loss = None
657
+
658
+ if self.ssim_loss_ratio > 1e-7:
659
+ self.ssim_loss = self.ssim_loss_ratio * (
660
+ 1 - fused_ssim(images.unsqueeze(0), self.gt_images.unsqueeze(0))
661
+ )
662
+ self.total_loss += self.ssim_loss
663
+ else:
664
+ self.ssim_loss = None
665
+
666
+ def _evaluate_and_update_stream(self, images):
667
+ """Evaluate current state and update streaming results"""
668
+ gamma_corrected_images = torch.pow(
669
+ torch.clamp(images, 0.0, 1.0), 1.0 / self.gamma
670
+ )
671
+ gamma_corrected_gt = torch.pow(self.gt_images, 1.0 / self.gamma)
672
+
673
+ psnr = get_psnr(gamma_corrected_images, gamma_corrected_gt).item()
674
+ ssim = fused_ssim(
675
+ gamma_corrected_images.unsqueeze(0), gamma_corrected_gt.unsqueeze(0)
676
+ ).item()
677
+
678
+ self.psnr_curr, self.ssim_curr = psnr, ssim
679
+
680
+ # Update metrics
681
+ self.results.metrics.update(
682
+ {
683
+ "psnr": psnr,
684
+ "ssim": ssim,
685
+ "loss": self.total_loss.item(),
686
+ "render_time": self.render_time_accum,
687
+ "total_time": self.total_time_accum,
688
+ }
689
+ )
690
+
691
+ # Log progress
692
+ loss_results = f"Loss: {self.total_loss.item():.4f}"
693
+ if self.l1_loss is not None:
694
+ loss_results += f", L1: {self.l1_loss.item():.4f}"
695
+ if self.l2_loss is not None:
696
+ loss_results += f", L2: {self.l2_loss.item():.4f}"
697
+ if self.ssim_loss is not None:
698
+ loss_results += f", SSIM: {self.ssim_loss.item():.4f}"
699
+
700
+ time_results = f"Total: {self.total_time_accum:.2f} s | Render: {self.render_time_accum:.2f} s"
701
+
702
+ self.worklog.info(
703
+ f"Step: {self.step:d} | {time_results} | {loss_results} | PSNR: {psnr:.2f} | SSIM: {ssim:.4f}"
704
+ )
705
+
706
+ def _save_final_checkpoint(self):
707
+ """Save final checkpoint and store the path"""
708
+ if self.quantize:
709
+ with torch.no_grad():
710
+ self.xy.copy_(ste_quantize(self.xy, self.pos_bits))
711
+ self.scale.copy_(ste_quantize(self.scale, self.scale_bits))
712
+ self.rot.copy_(ste_quantize(self.rot, self.rot_bits))
713
+ self.feat.copy_(ste_quantize(self.feat, self.feat_bits))
714
+
715
+ # Create checkpoint directory
716
+ ckpt_dir = os.path.join(self.log_dir, "checkpoints")
717
+ os.makedirs(ckpt_dir, exist_ok=True)
718
+
719
+ psnr = self.results.metrics.get("psnr", 0.0)
720
+ ssim = self.results.metrics.get("ssim", 0.0)
721
+
722
+ ckpt_data = {
723
+ "step": self.step,
724
+ "psnr": psnr,
725
+ "ssim": ssim,
726
+ "bytes": getattr(self, "num_bytes", 0),
727
+ "time": self.total_time_accum,
728
+ "state_dict": self.state_dict(),
729
+ "optim_state_dict": self.optimizer.state_dict(),
730
+ }
731
+
732
+ ckpt_path = os.path.join(ckpt_dir, f"ckpt_step-{self.step:d}.pt")
733
+ torch.save(ckpt_data, ckpt_path)
734
+ self.results.final_checkpoint_path = ckpt_path
735
+
736
+ self.worklog.info(f"Final checkpoint saved: {ckpt_path}")
737
+
738
+ def _lr_schedule(self):
739
+ """Learning rate scheduling logic"""
740
+ if (
741
+ self.psnr_curr <= self.best_psnr + 100 * self.decay_threshold
742
+ or self.ssim_curr <= self.best_ssim + self.decay_threshold
743
+ ):
744
+ self.no_improvement_steps += self.eval_steps
745
+ if self.no_improvement_steps >= self.check_decay_steps:
746
+ self.no_improvement_steps = 0
747
+ self.decay_times += 1
748
+ if self.decay_times > self.max_decay_times:
749
+ return True
750
+ for param_group in self.optimizer.param_groups:
751
+ param_group["lr"] /= self.decay_ratio
752
+ self.worklog.info(f"Learning rate decayed by {self.decay_ratio:.1f}")
753
+ return False
754
+ else:
755
+ self.best_psnr = self.psnr_curr
756
+ self.best_ssim = self.ssim_curr
757
+ self.no_improvement_steps = 0
758
+ return False
759
+
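# Illustrative trace of the plateau schedule above (hypothetical values):
# with eval_steps=100 and check_decay_steps=500, five consecutive evaluations
# where PSNR or SSIM fails to beat its best by the threshold margin trigger
# one decay; each trigger divides every param group's lr by decay_ratio.
#   lr, decay_ratio, max_decay_times = 2e-3, 2.0, 3
#   for _ in range(max_decay_times):
#       lr /= decay_ratio          # 1e-3, 5e-4, 2.5e-4
#   # a fourth trigger exceeds max_decay_times and ends the optimization loop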
760
+ def _add_gaussians(self, add_num):
761
+ """Add Gaussians during progressive optimization"""
762
+ add_num = min(
763
+ add_num, self.max_add_num, self.total_num_gaussians - self.num_gaussians
764
+ )
765
+ if add_num <= 0:
766
+ return
767
+
768
+ # Compute error map for new Gaussian placement
769
+ raw_images, _ = self.forward(self.img_h, self.img_w, self.tile_bounds)
770
+ images = torch.pow(torch.clamp(raw_images, 0.0, 1.0), 1.0 / self.gamma)
771
+ gt_images = torch.pow(self.gt_images, 1.0 / self.gamma)
772
+
773
+ kernel_size = round(np.sqrt(self.img_h * self.img_w) // 400)
774
+ if kernel_size >= 1:
775
+ kernel_size = max(3, kernel_size)
776
+ kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
777
+ gt_images = gaussian_blur(img=gt_images, kernel_size=kernel_size)
778
+
779
+ diff_map = (gt_images - images).detach().clone()
780
+ error_map = torch.pow(torch.abs(diff_map).mean(dim=0).reshape(-1), 2.0)
781
+ sample_prob = (error_map / error_map.sum()).cpu().numpy()
782
+ selected = np.random.choice(
783
+ self.num_pixels, add_num, replace=False, p=sample_prob
784
+ )
785
+
786
+ # Create new Gaussians
787
+ new_xy = self.pixel_xy.detach().clone()[selected]
788
+ new_scale = torch.ones(add_num, 2, dtype=self.dtype, device=self.device)
789
+ init_scale = self.init_scale
790
+ new_scale.fill_(init_scale if self.disable_inverse_scale else 1.0 / init_scale)
791
+ new_rot = torch.zeros(add_num, 1, dtype=self.dtype, device=self.device)
792
+ new_feat = diff_map.permute(1, 2, 0).reshape(-1, self.feat_dim)[selected]
793
+ new_vis_feat = torch.rand_like(new_feat)
794
+
795
+ # Update parameters
796
+ old_xy = self.xy.detach().clone()
797
+ old_scale = self.scale.detach().clone()
798
+ old_rot = self.rot.detach().clone()
799
+ old_feat = self.feat.detach().clone()
800
+ old_vis_feat = self.vis_feat.detach().clone()
801
+
802
+ self.num_gaussians += add_num
803
+ all_xy = torch.cat([old_xy, new_xy], dim=0)
804
+ all_scale = torch.cat([old_scale, new_scale], dim=0)
805
+ all_rot = torch.cat([old_rot, new_rot], dim=0)
806
+ all_feat = torch.cat([old_feat, new_feat], dim=0)
807
+ all_vis_feat = torch.cat([old_vis_feat, new_vis_feat], dim=0)
808
+
809
+ self.xy = nn.Parameter(all_xy, requires_grad=True)
810
+ self.scale = nn.Parameter(all_scale, requires_grad=True)
811
+ self.rot = nn.Parameter(all_rot, requires_grad=True)
812
+ self.feat = nn.Parameter(all_feat, requires_grad=True)
813
+ self.vis_feat = nn.Parameter(all_vis_feat, requires_grad=False)
814
+
815
+ # Update optimizer
816
+ self.optimizer = torch.optim.Adam(
817
+ [
818
+ {"params": self.xy, "lr": self.pos_lr},
819
+ {"params": self.scale, "lr": self.scale_lr},
820
+ {"params": self.rot, "lr": self.rot_lr},
821
+ {"params": self.feat, "lr": self.feat_lr},
822
+ ]
823
+ )
824
+
825
+ self.worklog.info(
826
+ f"Step: {self.step:d} | Adding {add_num:d} Gaussians ({self.num_gaussians - add_num:d} -> {self.num_gaussians:d})"
827
+ )
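The optimize() loop above publishes progress through the shared results object (step, metrics, current_render, final_render, is_complete). A minimal polling sketch of how a UI thread might consume it; the helper name and the 0.5 s interval are assumptions, only the results fields come from the code above:

import time

def stream_progress(results, poll_s=0.5):
    """Yield (step, image, metrics) snapshots until the trainer reports completion."""
    while not results.is_complete:
        # metrics is mutated by the training thread; copy before handing it out
        yield results.step, results.current_render, dict(results.metrics)
        time.sleep(poll_s)
    yield results.step, results.final_render, dict(results.metrics)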
main.py ADDED
@@ -0,0 +1,57 @@
1
+ import argparse
2
+
3
+ import torch
4
+
5
+ from model import GaussianSplatting2D
6
+ from utils.misc_utils import load_cfg
7
+
8
+
9
+ def get_gaussian_cfg(args):
10
+ gaussian_cfg = f"num-{args.num_gaussians:d}"
11
+ if args.disable_inverse_scale:
12
+ gaussian_cfg += f"_scale-{args.init_scale:.1f}"
13
+ else:
14
+ gaussian_cfg += f"_inv-scale-{args.init_scale:.1f}"
15
+ if not args.quantize:
16
+ args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits = 32, 32, 32, 32
17
+ min_bits = min(args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits)
18
+ max_bits = max(args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits)
19
+ if min_bits < 4 or max_bits > 32:
20
+ raise ValueError(
21
+ f"Bit precision must be between 4 and 32 but got: {args.pos_bits:d}, {args.scale_bits:d}, {args.rot_bits:d}, {args.feat_bits:d}"
22
+ )
23
+ gaussian_cfg += f"_bits-{args.pos_bits:d}-{args.scale_bits:d}-{args.rot_bits:d}-{args.feat_bits:d}"
24
+ if not args.disable_topk_norm:
25
+ gaussian_cfg += f"_top-{args.topk:d}"
26
+ gaussian_cfg += f"_{args.init_mode[0]}-{args.init_random_ratio:.1f}"
27
+ return gaussian_cfg
28
+
29
+
30
+ def get_log_dir(args):
31
+ gaussian_cfg = get_gaussian_cfg(args)
32
+ loss_cfg = f"l1-{args.l1_loss_ratio:.1f}_l2-{args.l2_loss_ratio:.1f}_ssim-{args.ssim_loss_ratio:.1f}"
33
+ folder = f"{gaussian_cfg}_{loss_cfg}"
34
+ if args.downsample:
35
+ folder += f"_ds-{args.downsample_ratio:.1f}"
36
+ if not args.disable_lr_schedule:
37
+ folder += f"_decay-{args.max_decay_times:d}-{args.decay_ratio:.1f}"
38
+ if not args.disable_prog_optim:
39
+ folder += "_prog"
40
+ return f"{args.log_root}/{args.exp_name}/{folder}"
41
+
42
+
43
+ def main(args):
44
+ args.log_dir = get_log_dir(args)
45
+ ImageGS = GaussianSplatting2D(args)
46
+ if args.eval:
47
+ ImageGS.render(render_height=args.render_height)
48
+ else:
49
+ ImageGS.optimize()
50
+
51
+
52
+ if __name__ == "__main__":
53
+ torch.hub.set_dir("models/torch")
54
+ parser = argparse.ArgumentParser()
55
+ parser = load_cfg(cfg_path="cfgs/default.yaml", parser=parser)
56
+ arguments = parser.parse_args()
57
+ main(arguments)
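As a usage sketch, the run directory get_log_dir would produce for a hypothetical config (every value below is an assumption, mirroring the fields the function reads from cfgs/default.yaml):

from types import SimpleNamespace
# assumes get_log_dir from main.py is in scope (importing main pulls in torch/gsplat)

args = SimpleNamespace(
    num_gaussians=10000, disable_inverse_scale=False, init_scale=4.0,
    quantize=False, pos_bits=32, scale_bits=32, rot_bits=32, feat_bits=32,
    disable_topk_norm=False, topk=8, init_mode="gradient", init_random_ratio=0.3,
    l1_loss_ratio=1.0, l2_loss_ratio=0.0, ssim_loss_ratio=0.1,
    downsample=False, disable_lr_schedule=False, max_decay_times=3,
    decay_ratio=2.0, disable_prog_optim=False, log_root="results", exp_name="demo",
)
print(get_log_dir(args))
# results/demo/num-10000_inv-scale-4.0_bits-32-32-32-32_top-8_g-0.3_l1-1.0_l2-0.0_ssim-0.1_decay-3-2.0_prog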
model.py ADDED
@@ -0,0 +1,824 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ import sys
5
+ from time import perf_counter
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from fused_ssim import fused_ssim
12
+ from lpips import LPIPS
13
+ from pytorch_msssim import MS_SSIM
14
+ from torchvision.transforms.functional import gaussian_blur
15
+
16
+ from gsplat import (
17
+ project_gaussians_2d_scale_rot,
18
+ rasterize_gaussians_no_tiles,
19
+ rasterize_gaussians_sum,
20
+ )
21
+ from utils.flip import LDRFLIPLoss
22
+ from utils.image_utils import (
23
+ compute_image_gradients,
24
+ get_grid,
25
+ get_psnr,
26
+ load_images,
27
+ save_image,
28
+ separate_image_channels,
29
+ visualize_added_gaussians,
30
+ visualize_gaussians,
31
+ )
32
+ from utils.misc_utils import clean_dir, get_latest_ckpt_step, save_cfg, set_random_seed
33
+ from utils.quantization_utils import ste_quantize
34
+ from utils.saliency_utils import get_smap
35
+
36
+
37
+ class GaussianSplatting2D(nn.Module):
38
+ def __init__(self, args):
39
+ super(GaussianSplatting2D, self).__init__()
40
+ self.evaluate = args.eval
41
+ set_random_seed(seed=args.seed)
42
+ # Ensure we're using the correct CUDA device
43
+ if torch.cuda.is_available():
44
+ torch.cuda.set_device(0) # Force device 0
45
+ self.device = torch.device("cuda:0")
46
+ else:
47
+ self.device = torch.device("cpu")
48
+ self.dtype = torch.float32
49
+ self._init_logging(args)
50
+ self._init_target(args)
51
+ self._init_bit_precision(args)
52
+ self._init_gaussians(args)
53
+ self._init_loss(args)
54
+ self._init_optimization(args)
55
+ # Initialization
56
+ if self.evaluate:
57
+ self.ckpt_file = args.ckpt_file
58
+ self._load_model()
59
+ else:
60
+ self._init_pos_scale_feat(args)
61
+
62
+ def _init_logging(self, args):
63
+ self.log_dir = args.log_dir
64
+ self.log_level = args.log_level
65
+ self.ckpt_dir = os.path.join(self.log_dir, "checkpoints")
66
+ self.train_dir = os.path.join(self.log_dir, "train")
67
+ self.eval_dir = os.path.join(self.log_dir, "eval")
68
+ self.vis_gaussians = args.vis_gaussians
69
+ self.save_image_steps = args.save_image_steps
70
+ self.save_ckpt_steps = args.save_ckpt_steps
71
+ self.eval_steps = args.eval_steps
72
+ if not self.evaluate:
73
+ clean_dir(path=self.log_dir)
74
+ os.makedirs(self.log_dir, exist_ok=False)
75
+ os.makedirs(self.ckpt_dir, exist_ok=False)
76
+ os.makedirs(self.train_dir, exist_ok=False)
77
+ else:
78
+ os.makedirs(self.eval_dir, exist_ok=True)
79
+ self._gen_logger(args)
80
+ if not self.evaluate:
81
+ save_cfg(path=f"{self.log_dir}/cfg_train.yaml", args=args)
82
+
83
+ def _gen_logger(self, args):
84
+ log_fname = "log_train"
85
+ if self.evaluate:
86
+ log_fname = "log_eval"
87
+ log_level = getattr(logging, self.log_level, logging.INFO)
88
+ logging.basicConfig(level=log_level)
89
+ self.worklog = logging.getLogger("Image-GS Logger")
90
+ self.worklog.propagate = False
91
+ datefmt = "%Y/%m/%d %H:%M:%S"
92
+ fileHandler = logging.FileHandler(
93
+ f"{self.log_dir}/{log_fname}.txt", mode="a", encoding="utf8"
94
+ )
95
+ fileHandler.setFormatter(
96
+ logging.Formatter(fmt="[{asctime}] {message}", datefmt=datefmt, style="{")
97
+ )
98
+ consoleHandler = logging.StreamHandler(sys.stdout)
99
+ consoleHandler.setFormatter(
100
+ logging.Formatter(
101
+ fmt="\x1b[32m[{asctime}] \x1b[0m{message}", datefmt=datefmt, style="{"
102
+ )
103
+ )
104
+ self.worklog.handlers = [fileHandler, consoleHandler]
105
+ action = "rendering" if self.evaluate else "optimizing"
106
+ self.worklog.info(
107
+ f"Start {action} {args.num_gaussians:d} Gaussians for '{args.input_path}'"
108
+ )
109
+ self.worklog.info("***********************************************")
110
+
111
+ def _init_target(self, args):
112
+ self.gamma = args.gamma
113
+ self.downsample = args.downsample
114
+ if self.downsample:
115
+ self.downsample_ratio = float(args.downsample_ratio)
116
+ self.block_h, self.block_w = (
117
+ 16,
118
+ 16,
119
+ ) # Warning: Must match hardcoded value in CUDA kernel, modify with caution
120
+ self._load_target_images(path=os.path.join(args.data_root, args.input_path))
121
+ if self.downsample:
122
+ self.gt_images_upsampled = self.gt_images
123
+ self.img_h_upsampled, self.img_w_upsampled = self.img_h, self.img_w
124
+ self.tile_bounds_upsampled = self.tile_bounds
125
+ self._load_target_images(
126
+ path=os.path.join(args.data_root, args.input_path),
127
+ downsample_ratio=self.downsample_ratio,
128
+ )
129
+ if not self.evaluate:
130
+ path = f"{self.log_dir}/gt_upsample-{self.downsample_ratio:.1f}_res-{self.img_h_upsampled:d}x{self.img_w_upsampled:d}"
131
+ self._separate_and_save_images(
132
+ images=self.gt_images_upsampled,
133
+ channels=self.input_channels,
134
+ path=path,
135
+ )
136
+ self.num_pixels = self.img_h * self.img_w
137
+ if not self.evaluate:
138
+ path = f"{self.log_dir}/gt_res-{self.img_h:d}x{self.img_w:d}"
139
+ self._separate_and_save_images(
140
+ images=self.gt_images, channels=self.input_channels, path=path
141
+ )
142
+
143
+ def _load_target_images(self, path, downsample_ratio=None):
144
+ self.gt_images, self.input_channels, self.image_fnames = load_images(
145
+ load_path=path, downsample_ratio=downsample_ratio, gamma=self.gamma
146
+ )
147
+ self.gt_images = torch.from_numpy(self.gt_images).to(
148
+ dtype=self.dtype, device=self.device
149
+ )
150
+ self.img_h, self.img_w = self.gt_images.shape[1:]
151
+ self.tile_bounds = (
152
+ (self.img_w + self.block_w - 1) // self.block_w,
153
+ (self.img_h + self.block_h - 1) // self.block_h,
154
+ 1,
155
+ )
156
+
157
+ def _separate_and_save_images(self, images, channels, path):
158
+ images_sep = separate_image_channels(images=images, input_channels=channels)
159
+ for idx, image in enumerate(images_sep, 1):
160
+ suffix = "" if len(images_sep) == 1 else f"_{idx:d}"
161
+ save_image(image, f"{path}{suffix}.png", gamma=self.gamma)
162
+
163
+ def _init_bit_precision(self, args):
164
+ self.quantize = args.quantize
165
+ self.pos_bits = args.pos_bits
166
+ self.scale_bits = args.scale_bits
167
+ self.rot_bits = args.rot_bits
168
+ self.feat_bits = args.feat_bits
169
+
170
+ def _init_gaussians(self, args):
171
+ self.num_gaussians = args.num_gaussians
172
+ self.total_num_gaussians = args.num_gaussians
173
+ self.disable_prog_optim = args.disable_prog_optim
174
+ if not self.disable_prog_optim and not self.evaluate:
175
+ self.initial_ratio = args.initial_ratio
176
+ self.add_times = args.add_times
177
+ self.add_steps = args.add_steps
178
+ self.num_gaussians = math.ceil(
179
+ self.initial_ratio * self.total_num_gaussians
180
+ )
181
+ self.max_add_num = math.ceil(
182
+ float(self.total_num_gaussians - self.num_gaussians) / self.add_times
183
+ )
184
+ min_steps = self.add_steps * self.add_times + args.post_min_steps
185
+ if args.max_steps < min_steps:
186
+ self.worklog.info(
187
+ f"Max steps ({args.max_steps:d}) is too small for progressive optimization. Resetting to {min_steps:d}"
188
+ )
189
+ args.max_steps = min_steps
190
+ self.topk = (
191
+ args.topk
192
+ ) # Warning: Must match hardcoded value in CUDA kernel, modify with caution
193
+ self.eps = (
194
+ 1e-7 if args.disable_tiles else 1e-4
195
+ ) # Warning: Must match hardcoded value in CUDA kernel, modify with caution
196
+ self.init_scale = args.init_scale
197
+ self.disable_topk_norm = args.disable_topk_norm
198
+ self.disable_inverse_scale = args.disable_inverse_scale
199
+ self.disable_color_init = args.disable_color_init
200
+ self.xy = nn.Parameter(
201
+ torch.rand(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
202
+ requires_grad=True,
203
+ )
204
+ self.scale = nn.Parameter(
205
+ torch.ones(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
206
+ requires_grad=True,
207
+ )
208
+ self.rot = nn.Parameter(
209
+ torch.zeros(self.num_gaussians, 1, dtype=self.dtype, device=self.device),
210
+ requires_grad=True,
211
+ )
212
+ self.feat_dim = sum(self.input_channels)
213
+ self.feat = nn.Parameter(
214
+ torch.rand(
215
+ self.num_gaussians, self.feat_dim, dtype=self.dtype, device=self.device
216
+ ),
217
+ requires_grad=True,
218
+ )
219
+ self.vis_feat = nn.Parameter(
220
+ torch.rand_like(self.feat), requires_grad=False
221
+ ) # Only used for Gaussian ID visualization
222
+ self._log_compression_rate()
223
+
224
+ def _log_compression_rate(self):
225
+ bytes_uncompressed = float(self.gt_images.numel())
226
+ bpp_uncompressed = float(8 * self.feat_dim)
227
+ self.worklog.info(
228
+ f"Uncompressed: {bytes_uncompressed / 1e3:.2f} KB | {bpp_uncompressed:.3f} bpp | 8.0 bppc"
229
+ )
230
+ bits_compressed = (
231
+ 2 * self.pos_bits
232
+ + 2 * self.scale_bits
233
+ + self.rot_bits
234
+ + self.feat_dim * self.feat_bits
235
+ ) * self.total_num_gaussians
236
+ bytes_compressed = bits_compressed / 8.0
237
+ bpp_compressed = float(bits_compressed) / self.num_pixels
238
+ bppc_compressed = bpp_compressed / self.feat_dim
239
+ self.num_bytes = bytes_compressed
240
+ self.worklog.info(
241
+ f"Compressed: {bytes_compressed / 1e3:.2f} KB | {bpp_compressed:.3f} bpp | {bppc_compressed:.3f} bppc"
242
+ )
243
+ self.worklog.info(
244
+ f"Compression rate: {bpp_uncompressed / bpp_compressed:.2f}x | {100.0 * bpp_compressed / bpp_uncompressed:.2f}%"
245
+ )
246
+ self.worklog.info("***********************************************")
247
+
248
+ def _init_loss(self, args):
249
+ self.l1_loss = None
250
+ self.l2_loss = None
251
+ self.ssim_loss = None
252
+ self.l1_loss_ratio = args.l1_loss_ratio
253
+ self.l2_loss_ratio = args.l2_loss_ratio
254
+ self.ssim_loss_ratio = args.ssim_loss_ratio
255
+
256
+ def _init_optimization(self, args):
257
+ self.disable_tiles = args.disable_tiles
258
+ self.start_step = 1
259
+ self.max_steps = args.max_steps
260
+ self.pos_lr = args.pos_lr
261
+ self.scale_lr = args.scale_lr
262
+ self.rot_lr = args.rot_lr
263
+ self.feat_lr = args.feat_lr
264
+ self.optimizer = torch.optim.Adam(
265
+ [
266
+ {"params": self.xy, "lr": self.pos_lr},
267
+ {"params": self.scale, "lr": self.scale_lr},
268
+ {"params": self.rot, "lr": self.rot_lr},
269
+ {"params": self.feat, "lr": self.feat_lr},
270
+ ]
271
+ )
272
+ self.disable_lr_schedule = args.disable_lr_schedule
273
+ if not self.disable_lr_schedule:
274
+ self.decay_ratio = args.decay_ratio
275
+ self.check_decay_steps = args.check_decay_steps
276
+ self.max_decay_times = args.max_decay_times
277
+ self.decay_threshold = args.decay_threshold
278
+
279
+ def _init_pos_scale_feat(self, args):
280
+ self.init_mode = args.init_mode
281
+ self.init_random_ratio = args.init_random_ratio
282
+ self.pixel_xy = (
283
+ get_grid(h=self.img_h, w=self.img_w)
284
+ .to(dtype=self.dtype, device=self.device)
285
+ .reshape(-1, 2)
286
+ )
287
+ with torch.no_grad():
288
+ # Position
289
+ if self.init_mode == "gradient":
290
+ self._compute_gmap()
291
+ self.xy.copy_(self._sample_pos(prob=self.image_gradients))
292
+ elif self.init_mode == "saliency":
293
+ self.smap_filter_size = args.smap_filter_size
294
+ self._compute_smap(path="models")
295
+ self.xy.copy_(self._sample_pos(prob=self.saliency))
296
+ else:
297
+ selected = np.random.choice(
298
+ self.num_pixels, self.num_gaussians, replace=False, p=None
299
+ )
300
+ self.xy.copy_(self.pixel_xy.detach().clone()[selected])
301
+ # Scale
302
+ self.scale.fill_(
303
+ self.init_scale if self.disable_inverse_scale else 1.0 / self.init_scale
304
+ )
305
+ # Feature
306
+ if not self.disable_color_init:
307
+ self.feat.copy_(
308
+ self._get_target_features(positions=self.xy).detach().clone()
309
+ )
310
+
311
+ def _sample_pos(self, prob):
312
+ num_random = round(self.init_random_ratio * self.num_gaussians)
313
+ selected_random = np.random.choice(
314
+ self.num_pixels, num_random, replace=False, p=None
315
+ )
316
+ selected_other = np.random.choice(
317
+ self.num_pixels, self.num_gaussians - num_random, replace=False, p=prob
318
+ )
319
+ return torch.cat(
320
+ [
321
+ self.pixel_xy.detach().clone()[selected_random],
322
+ self.pixel_xy.detach().clone()[selected_other],
323
+ ],
324
+ dim=0,
325
+ )
326
+
327
+ def _compute_gmap(self):
328
+ gy, gx = compute_image_gradients(
329
+ np.power(self.gt_images.detach().cpu().clone().numpy(), 1.0 / self.gamma)
330
+ )
331
+ g_norm = np.hypot(gy, gx).astype(np.float32)
332
+ g_norm = g_norm / g_norm.max()
333
+ save_image(g_norm, f"{self.log_dir}/gmap_res-{self.img_h:d}x{self.img_w:d}.png")
334
+ g_norm = np.power(g_norm.reshape(-1), 2.0)
335
+ self.image_gradients = g_norm / g_norm.sum()
336
+ self.worklog.info("Image gradient map successfully saved")
337
+ self.worklog.info("***********************************************")
338
+
339
+ def _compute_smap(self, path):
340
+ smap = get_smap(
341
+ torch.pow(self.gt_images.detach().clone(), 1.0 / self.gamma),
342
+ path,
343
+ self.smap_filter_size,
344
+ )
345
+ save_image(smap, f"{self.log_dir}/smap_res-{self.img_h:d}x{self.img_w:d}.png")
346
+ self.saliency = (smap / smap.sum()).reshape(-1)
347
+ self.worklog.info("Saliency map successfully saved")
348
+ self.worklog.info("***********************************************")
349
+
350
+ def _get_target_features(self, positions):
351
+ with torch.no_grad():
352
+ # gt_images [1, C, H, W]; positions [1, 1, P, 2]; top-left [-1, -1]; bottom-right [1, 1]
353
+ target_features = F.grid_sample(
354
+ self.gt_images.unsqueeze(0),
355
+ positions[None, None, ...] * 2.0 - 1.0,
356
+ align_corners=False,
357
+ )
358
+ target_features = target_features[0, :, 0, :].permute(1, 0) # [P, C]
359
+ return target_features
360
+
361
+ def _load_model(self):
362
+ if self.ckpt_file != "":
363
+ ckpt_path = os.path.join(self.ckpt_dir, self.ckpt_file)
364
+ else:
365
+ latest_step = get_latest_ckpt_step(self.ckpt_dir)
366
+ if latest_step == -1:
367
+ raise FileNotFoundError(f"No checkpoint found in '{self.ckpt_dir}'")
368
+ ckpt_path = os.path.join(self.ckpt_dir, f"ckpt_step-{latest_step:d}.pt")
369
+ checkpoint = torch.load(ckpt_path, weights_only=False)
370
+ self.load_state_dict(checkpoint["state_dict"])
371
+ self.optimizer.load_state_dict(checkpoint["optim_state_dict"])
372
+ self.start_step = checkpoint["step"] + 1
373
+ self.worklog.info(f"Checkpoint '{ckpt_path}' successfully loaded")
374
+ self.worklog.info("***********************************************")
375
+
376
+ def _save_model(self):
377
+ if self.quantize:
378
+ self._quantize()
379
+ psnr, ssim = self._evaluate(log=False, upsample=False)
380
+ self._evaluate_extra()
381
+ ckpt_data = {
382
+ "step": self.step,
383
+ "psnr": psnr,
384
+ "ssim": ssim,
385
+ "lpips": self.lpips_final,
386
+ "flip": self.flip_final,
387
+ "msssim": self.msssim_final,
388
+ "bytes": self.num_bytes,
389
+ "time": self.total_time_accum,
390
+ "state_dict": self.state_dict(),
391
+ "optim_state_dict": self.optimizer.state_dict(),
392
+ }
393
+ save_path = f"{self.ckpt_dir}/ckpt_step-{self.step:d}.pt"
394
+ torch.save(ckpt_data, save_path)
395
+ self.worklog.info(f"Checkpoint 'ckpt_step-{self.step:d}.pt' successfully saved")
396
+ self.worklog.info(
397
+ f"PSNR: {psnr:.2f} | SSIM: {ssim:.4f} | LPIPS: {self.lpips_final:.4f} | FLIP: {self.flip_final:.4f} | MS-SSIM: {self.msssim_final:.4f}"
398
+ )
399
+ self.worklog.info("***********************************************")
400
+
401
+ def _quantize(self):
402
+ with torch.no_grad():
403
+ self.xy.copy_(ste_quantize(self.xy, self.pos_bits))
404
+ self.scale.copy_(ste_quantize(self.scale, self.scale_bits))
405
+ self.rot.copy_(ste_quantize(self.rot, self.rot_bits))
406
+ self.feat.copy_(ste_quantize(self.feat, self.feat_bits))
407
+
408
+ def render(self, render_height=None):
409
+ img_h, img_w = self.img_h, self.img_w
410
+ if render_height is not None:
411
+ img_h, img_w = render_height, round((float(render_height) / img_h) * img_w)
412
+ tile_bounds = (
413
+ (img_w + self.block_w - 1) // self.block_w,
414
+ (img_h + self.block_h - 1) // self.block_h,
415
+ 1,
416
+ )
417
+ upsample_ratio = float(img_h) / self.img_h
418
+ with torch.no_grad():
419
+ num_prep_runs = 2
420
+ for _ in range(num_prep_runs):
421
+ self.forward(img_h, img_w, tile_bounds, upsample_ratio, benchmark=True)
422
+ images, render_time = self.forward(
423
+ img_h, img_w, tile_bounds, upsample_ratio
424
+ )
425
+ path = f"{self.eval_dir}/render_upsample-{upsample_ratio:.1f}_res-{img_h:d}x{img_w:d}"
426
+ self._separate_and_save_images(
427
+ images=images, channels=self.input_channels, path=path
428
+ )
429
+ self.worklog.info(f"Step: {self.start_step - 1:d} | Time: {render_time:.6f} s")
430
+ self.worklog.info(f"Rendering at resolution ({img_h:d}, {img_w:d}) completed")
431
+ self.worklog.info("***********************************************")
432
+
433
+ def benchmark_render_time(self, num_reps, render_height=None):
434
+ img_h, img_w = self.img_h, self.img_w
435
+ if render_height is not None:
436
+ img_h, img_w = render_height, round((float(render_height) / img_h) * img_w)
437
+ tile_bounds = (
438
+ (img_w + self.block_w - 1) // self.block_w,
439
+ (img_h + self.block_h - 1) // self.block_h,
440
+ 1,
441
+ )
442
+ upsample_ratio = float(img_h) / self.img_h
443
+ with torch.no_grad():
444
+ render_time_all = np.zeros(num_reps, dtype=np.float32)
445
+ num_prep_runs = 2
446
+ for _ in range(num_prep_runs):
447
+ self.forward(img_h, img_w, tile_bounds, upsample_ratio, benchmark=True)
448
+ for rid in range(num_reps):
449
+ render_time = self.forward(
450
+ img_h, img_w, tile_bounds, upsample_ratio, benchmark=True
451
+ )
452
+ render_time_all[rid] = render_time
453
+ return render_time_all
454
+
455
+ def forward(self, img_h, img_w, tile_bounds, upsample_ratio=None, benchmark=False):
456
+ scale = self._get_scale(upsample_ratio=upsample_ratio)
457
+ xy, rot, feat = self.xy, self.rot, self.feat
458
+ if self.quantize:
459
+ xy, scale, rot, feat = (
460
+ ste_quantize(xy, self.pos_bits),
461
+ ste_quantize(scale, self.scale_bits),
462
+ ste_quantize(rot, self.rot_bits),
463
+ ste_quantize(feat, self.feat_bits),
464
+ )
465
+ begin = perf_counter()
466
+ tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
467
+ xy, radii, conics, num_tiles_hit = tmp
468
+ if not self.disable_tiles:
469
+ enable_topk_norm = not self.disable_topk_norm
470
+ tmp = (
471
+ xy,
472
+ radii,
473
+ conics,
474
+ num_tiles_hit,
475
+ feat,
476
+ img_h,
477
+ img_w,
478
+ self.block_h,
479
+ self.block_w,
480
+ enable_topk_norm,
481
+ )
482
+ out_image = rasterize_gaussians_sum(*tmp)
483
+ else:
484
+ tmp = xy, conics, feat, img_h, img_w
485
+ out_image = rasterize_gaussians_no_tiles(*tmp)
486
+ render_time = perf_counter() - begin
487
+ if benchmark:
488
+ return render_time
489
+ out_image = (
490
+ out_image.view(-1, img_h, img_w, self.feat_dim)
491
+ .permute(0, 3, 1, 2)
492
+ .contiguous()
493
+ )
494
+ return out_image.squeeze(dim=0), render_time
495
+
496
+ def _get_scale(self, upsample_ratio=None):
497
+ scale = self.scale
498
+ if not self.disable_inverse_scale:
499
+ scale = 1.0 / scale
500
+ if upsample_ratio is not None:
501
+ scale = upsample_ratio * scale
502
+ return scale
503
+
504
+ def _visualize_gaussian_id(self, img_h, img_w, tile_bounds, upsample_ratio=None):
505
+ scale = self._get_scale(upsample_ratio=upsample_ratio)
506
+ xy, rot, feat = self.xy, self.rot, self.feat
507
+ if self.quantize:
508
+ xy, scale, rot, feat = (
509
+ ste_quantize(xy, self.pos_bits),
510
+ ste_quantize(scale, self.scale_bits),
511
+ ste_quantize(rot, self.rot_bits),
512
+ ste_quantize(feat, self.feat_bits),
513
+ )
514
+ feat = self.vis_feat * feat.norm(dim=-1, keepdim=True)
515
+ tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
516
+ xy, radii, conics, num_tiles_hit = tmp
517
+ if not self.disable_tiles:
518
+ enable_topk_norm = not self.disable_topk_norm
519
+ tmp = (
520
+ xy,
521
+ radii,
522
+ conics,
523
+ num_tiles_hit,
524
+ feat,
525
+ img_h,
526
+ img_w,
527
+ self.block_h,
528
+ self.block_w,
529
+ enable_topk_norm,
530
+ )
531
+ out_image = rasterize_gaussians_sum(*tmp)
532
+ else:
533
+ tmp = xy, conics, feat, img_h, img_w
534
+ out_image = rasterize_gaussians_no_tiles(*tmp)
535
+ out_image = (
536
+ out_image.view(-1, img_h, img_w, self.feat_dim)
537
+ .permute(0, 3, 1, 2)
538
+ .contiguous()
539
+ )
540
+ return out_image.squeeze(dim=0)
541
+
542
+ def optimize(self):
543
+ self.psnr_curr, self.ssim_curr = 0.0, 0.0
544
+ self.best_psnr, self.best_ssim = 0.0, 0.0
545
+ self.decay_times, self.no_improvement_steps = 0, 0
546
+ self.render_time_accum, self.total_time_accum = 0.0, 0.0
547
+ self.lpips_final, self.flip_final, self.msssim_final = 1.0, 1.0, 0.0
548
+
549
+ self.step = 0
550
+ with torch.no_grad():
551
+ self._log_images(log_final=False, plot_gaussians=self.vis_gaussians)
552
+ for step in range(self.start_step, self.max_steps + 1):
553
+ self.step = step
554
+ self.optimizer.zero_grad()
555
+ # Rendering
556
+ images, render_time = self.forward(self.img_h, self.img_w, self.tile_bounds)
557
+ self.render_time_accum += render_time
558
+ # Optimization
559
+ begin = perf_counter()
560
+ self._get_total_loss(images)
561
+ self.total_loss.backward()
562
+ self.optimizer.step()
563
+ self.total_time_accum += perf_counter() - begin + render_time
564
+ # Logging
565
+ terminate = False
566
+ with torch.no_grad():
567
+ if self.step % self.eval_steps == 0:
568
+ self._evaluate(log=True, upsample=False)
569
+ if (
570
+ not self.disable_lr_schedule
571
+ and self.num_gaussians == self.total_num_gaussians
572
+ ):
573
+ terminate = self._lr_schedule()
574
+ if self.step % self.save_image_steps == 0:
575
+ self._log_images(log_final=False, plot_gaussians=self.vis_gaussians)
576
+ if (
577
+ self.step % self.save_ckpt_steps == 0
578
+ and self.num_gaussians == self.total_num_gaussians
579
+ ):
580
+ self._save_model()
581
+ if (
582
+ not self.disable_prog_optim
583
+ and self.step % self.add_steps == 0
584
+ and self.num_gaussians < self.total_num_gaussians
585
+ ):
586
+ self._add_gaussians(
587
+ self.max_add_num, plot_gaussians=self.vis_gaussians
588
+ )
589
+ if terminate:
590
+ break
591
+ with torch.no_grad():
592
+ self._log_images(log_final=True, plot_gaussians=self.vis_gaussians)
593
+ self._save_model()
594
+ self.worklog.info("Optimization completed")
595
+ self.worklog.info("***********************************************")
596
+ self.worklog.info(
597
+ f"Mean scale: {self._get_scale().mean().item():.4f} (pixel) | {self.scale.mean().item():.4f} (raw)"
598
+ )
599
+ self.worklog.info("***********************************************")
600
+ return self.psnr_curr, self.ssim_curr
601
+
602
+ def _get_total_loss(self, images):
603
+ self.total_loss = 0
604
+ if self.l1_loss_ratio > 1e-7:
605
+ self.l1_loss = self.l1_loss_ratio * F.l1_loss(images, self.gt_images)
606
+ self.total_loss += self.l1_loss
607
+ else:
608
+ self.l1_loss = None
609
+ if self.l2_loss_ratio > 1e-7:
610
+ self.l2_loss = self.l2_loss_ratio * F.mse_loss(images, self.gt_images)
611
+ self.total_loss += self.l2_loss
612
+ else:
613
+ self.l2_loss = None
614
+ if self.ssim_loss_ratio > 1e-7:
615
+ self.ssim_loss = self.ssim_loss_ratio * (
616
+ 1 - fused_ssim(images.unsqueeze(0), self.gt_images.unsqueeze(0))
617
+ )
618
+ self.total_loss += self.ssim_loss
619
+ else:
620
+ self.ssim_loss = None
621
+
622
+ def _evaluate(self, log=True, upsample=False):
623
+ if upsample: # Do not log performance metrics for upsampled images
624
+ log = False
625
+ images = torch.pow(
626
+ torch.clamp(self._render_images(upsample=upsample), 0.0, 1.0),
627
+ 1.0 / self.gamma,
628
+ )
629
+ gt_images = torch.pow(
630
+ self.gt_images_upsampled if upsample else self.gt_images, 1.0 / self.gamma
631
+ )
632
+ psnr = get_psnr(images, gt_images).item()
633
+ ssim = fused_ssim(images.unsqueeze(0), gt_images.unsqueeze(0)).item()
634
+ if log:
635
+ self.psnr_curr, self.ssim_curr = psnr, ssim
636
+ loss_results = f"Loss: {self.total_loss.item():.4f}"
637
+ loss_results += (
638
+ f", L1: {self.l1_loss.item():.4f}" if self.l1_loss is not None else ""
639
+ )
640
+ loss_results += (
641
+ f", L2: {self.l2_loss.item():.4f}" if self.l2_loss is not None else ""
642
+ )
643
+ loss_results += (
644
+ f", SSIM: {self.ssim_loss.item():.4f}"
645
+ if self.ssim_loss is not None
646
+ else ""
647
+ )
648
+ time_results = f"Total: {self.total_time_accum:.2f} s | Render: {self.render_time_accum:.2f} s"
649
+ self.worklog.info(
650
+ f"Step: {self.step:d} | {time_results} | {loss_results} | PSNR: {self.psnr_curr:.2f} | SSIM: {self.ssim_curr:.4f}"
651
+ )
652
+ return psnr, ssim
653
+
654
+ def _evaluate_extra(self):
655
+ images = torch.pow(
656
+ torch.clamp(self._render_images(upsample=False), 0.0, 1.0), 1.0 / self.gamma
657
+ )[None, ...]
658
+ gt_images = torch.pow(self.gt_images, 1.0 / self.gamma)[None, ...]
659
+ msssim_metric = (
660
+ MS_SSIM(data_range=1.0, size_average=True, channel=self.feat_dim)
661
+ .to(device=self.device)
662
+ .eval()
663
+ )
664
+ self.msssim_final = msssim_metric(images, gt_images).item()
665
+ lpips_metric = LPIPS(net="alex").to(device=self.device).eval()
666
+ flip_metric = LDRFLIPLoss().to(device=self.device).eval()
667
+ num_channels = 1 if self.feat_dim < 3 else 3
668
+ self.lpips_final = lpips_metric(
669
+ images[:, :num_channels], gt_images[:, :num_channels]
670
+ ).item()
671
+ if self.feat_dim >= 3:
672
+ self.flip_final = flip_metric(images[:, :3], gt_images[:, :3]).item()
673
+
674
+ def _log_images(self, log_final=False, plot_gaussians=False):
675
+ images = self._render_images(upsample=False)
676
+ if log_final:
677
+ path = f"{self.log_dir}/render_res-{self.img_h:d}x{self.img_w:d}"
678
+ self._separate_and_save_images(
679
+ images=images, channels=self.input_channels, path=path
680
+ )
681
+ psnr, ssim = self._evaluate(log=False, upsample=False)
682
+ path = f"{self.train_dir}/render_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{self.img_h:d}x{self.img_w:d}"
683
+ self._separate_and_save_images(
684
+ images=images, channels=self.input_channels, path=path
685
+ )
686
+ if plot_gaussians:
687
+ path = f"{self.train_dir}/gaussian_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{self.img_h:d}x{self.img_w:d}"
688
+ visualize_gaussians(
689
+ path,
690
+ self.xy,
691
+ self._get_scale(),
692
+ self.rot,
693
+ self.feat,
694
+ self.img_h,
695
+ self.img_w,
696
+ self.input_channels,
697
+ alpha=0.8,
698
+ gamma=self.gamma,
699
+ )
700
+ images = self._visualize_gaussian_id(
701
+ self.img_h, self.img_w, self.tile_bounds
702
+ )
703
+ path = f"{self.train_dir}/gaussian-id_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{self.img_h:d}x{self.img_w:d}"
704
+ self._separate_and_save_images(
705
+ images=images, channels=self.input_channels, path=path
706
+ )
707
+ if self.downsample:
708
+ images = self._render_images(upsample=True)
709
+ psnr, ssim = self._evaluate(log=False, upsample=True)
710
+ img_h, img_w = self.img_h_upsampled, self.img_w_upsampled
711
+ path = f"{self.train_dir}/render_upsample-{self.downsample_ratio:.1f}_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{img_h:d}x{img_w:d}"
712
+ self._separate_and_save_images(
713
+ images=images, channels=self.input_channels, path=path
714
+ )
715
+
716
+ def _render_images(self, upsample=False):
717
+ if upsample:
718
+ images, _ = self.forward(
719
+ self.img_h_upsampled,
720
+ self.img_w_upsampled,
721
+ self.tile_bounds_upsampled,
722
+ upsample_ratio=self.downsample_ratio,
723
+ )
724
+ else:
725
+ images, _ = self.forward(self.img_h, self.img_w, self.tile_bounds)
726
+ return images
727
+
728
+ def _lr_schedule(self):
729
+ if (
730
+ self.psnr_curr <= self.best_psnr + 100 * self.decay_threshold
731
+ or self.ssim_curr <= self.best_ssim + self.decay_threshold
732
+ ):
733
+ self.no_improvement_steps += self.eval_steps
734
+ if self.no_improvement_steps >= self.check_decay_steps:
735
+ self.no_improvement_steps = 0
736
+ self.decay_times += 1
737
+ if self.decay_times > self.max_decay_times:
738
+ return True
739
+ for param_group in self.optimizer.param_groups:
740
+ param_group["lr"] /= self.decay_ratio
741
+ self.worklog.info(f"Learning rate decayed by {self.decay_ratio:.1f}")
742
+ self.worklog.info("***********************************************")
743
+ return False
744
+ else:
745
+ self.best_psnr = self.psnr_curr
746
+ self.best_ssim = self.ssim_curr
747
+ self.no_improvement_steps = 0
748
+ return False
749
+
750
+ def _add_gaussians(self, add_num, plot_gaussians=False):
751
+ add_num = min(
752
+ add_num, self.max_add_num, self.total_num_gaussians - self.num_gaussians
753
+ )
754
+ if add_num <= 0:
755
+ return
756
+ raw_images = self._render_images(upsample=False)
757
+ images = torch.pow(torch.clamp(raw_images, 0.0, 1.0), 1.0 / self.gamma)
758
+ gt_images = torch.pow(self.gt_images, 1.0 / self.gamma)
759
+ kernel_size = round(np.sqrt(self.img_h * self.img_w) // 400)
760
+ if kernel_size >= 1:
761
+ kernel_size = max(3, kernel_size)
762
+ kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
763
+ gt_images = gaussian_blur(img=gt_images, kernel_size=kernel_size)
764
+ diff_map = (gt_images - images).detach().clone()
765
+ error_map = torch.pow(torch.abs(diff_map).mean(dim=0).reshape(-1), 2.0)
766
+ sample_prob = (error_map / error_map.sum()).cpu().numpy()
767
+ selected = np.random.choice(
768
+ self.num_pixels, add_num, replace=False, p=sample_prob
769
+ )
770
+ # New Gaussians
771
+ new_xy = self.pixel_xy.detach().clone()[selected]
772
+ new_scale = torch.ones(add_num, 2, dtype=self.dtype, device=self.device)
773
+ init_scale = self.init_scale
774
+ new_scale.fill_(init_scale if self.disable_inverse_scale else 1.0 / init_scale)
775
+ new_rot = torch.zeros(add_num, 1, dtype=self.dtype, device=self.device)
776
+ new_feat = diff_map.permute(1, 2, 0).reshape(-1, self.feat_dim)[selected]
777
+ new_vis_feat = torch.rand_like(new_feat)
778
+ # Old Gaussians
779
+ old_xy = self.xy.detach().clone()
780
+ old_scale = self.scale.detach().clone()
781
+ old_rot = self.rot.detach().clone()
782
+ old_feat = self.feat.detach().clone()
783
+ old_vis_feat = self.vis_feat.detach().clone()
784
+ # Update trainable parameters
785
+ self.num_gaussians += add_num
786
+ all_xy = torch.cat([old_xy, new_xy], dim=0)
787
+ all_scale = torch.cat([old_scale, new_scale], dim=0)
788
+ all_rot = torch.cat([old_rot, new_rot], dim=0)
789
+ all_feat = torch.cat([old_feat, new_feat], dim=0)
790
+ all_vis_feat = torch.cat([old_vis_feat, new_vis_feat], dim=0)
791
+ self.xy = nn.Parameter(all_xy, requires_grad=True)
792
+ self.scale = nn.Parameter(all_scale, requires_grad=True)
793
+ self.rot = nn.Parameter(all_rot, requires_grad=True)
794
+ self.feat = nn.Parameter(all_feat, requires_grad=True)
795
+ self.vis_feat = nn.Parameter(all_vis_feat, requires_grad=False)
796
+ # Plot Gaussians
797
+ if plot_gaussians:
798
+ path = f"{self.train_dir}/add-gaussian_step-{self.step:d}_num-{self.num_gaussians:d}_res-{self.img_h:d}x{self.img_w:d}"
799
+ every_n = max(1, self.total_num_gaussians // 2000)
800
+ size = (self.img_h * self.img_w) / 1e4
801
+ visualize_added_gaussians(
802
+ path,
803
+ raw_images,
804
+ old_xy,
805
+ new_xy,
806
+ self.input_channels,
807
+ size=size,
808
+ every_n=every_n,
809
+ alpha=0.8,
810
+ gamma=self.gamma,
811
+ )
812
+ # Update optimizer
813
+ self.optimizer = torch.optim.Adam(
814
+ [
815
+ {"params": self.xy, "lr": self.pos_lr},
816
+ {"params": self.scale, "lr": self.scale_lr},
817
+ {"params": self.rot, "lr": self.rot_lr},
818
+ {"params": self.feat, "lr": self.feat_lr},
819
+ ]
820
+ )
821
+ self.worklog.info(
822
+ f"Step: {self.step:d} | Adding {add_num:d} Gaussians ({self.num_gaussians - add_num:d} -> {self.num_gaussians:d})"
823
+ )
824
+ self.worklog.info("***********************************************")
pyproject.toml ADDED
@@ -0,0 +1,46 @@
1
+ [project]
2
+ name = "image-gs"
3
+ version = "0.1.0"
4
+ description = "2D Gaussian splatting for image representation (Image-GS)"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "lpips>=0.1.4",
9
+ "matplotlib>=3.10.6",
10
+ "numpy>=2.2.6",
11
+ "pytorch-msssim>=1.0.0",
12
+ "scikit-image>=0.25.2",
13
+ "scipy>=1.15.3",
14
+ "torch>=2.6.0",
15
+ "torchmetrics>=1.8.2",
16
+ "torchvision>=0.21.0",
17
+ "fused_ssim",
18
+ "pyyaml>=6.0.2",
19
+ "gsplat",
20
+ "gradio>=4.0.0",
21
+ "huggingface_hub>=0.24.0",
22
+ ]
23
+
24
+ # We use Python 3.10 and the CUDA 12.4 (cu124) wheels
25
+ [tool.uv.sources]
26
+ fused_ssim = { git = "https://github.com/rahul-goel/fused-ssim/" }
27
+ torch = [
28
+ { index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
29
+ ]
30
+ torchvision = [
31
+ { index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
32
+ ]
33
+
34
+
35
+ [tool.uv.extra-build-dependencies]
36
+ fused-ssim = ["torch", "numpy"]
37
+
38
+ [[tool.uv.index]]
39
+ name = "pytorch-cu124"
40
+ url = "https://download.pytorch.org/whl/cu124"
41
+ explicit = true
42
+
43
+ [dependency-groups]
44
+ dev = [
45
+ "huggingface-hub[cli]>=0.34.4",
46
+ ]
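Since the [tool.uv.sources] pins above route torch and torchvision through the cu124 index on Linux, a one-line sanity check (nothing repo-specific) verifies the CUDA build actually landed in the environment:

import torch

# Expect a "+cu124"-style build with torch.version.cuda == "12.4" when the
# pytorch-cu124 index was used; CPU-only wheels report None here.
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())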
utils/__init__.py ADDED
File without changes
utils/flip.py ADDED
@@ -0,0 +1,811 @@
1
+ """FLIP metric functions"""
2
+ #################################################################################
3
+ # Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Redistribution and use in source and binary forms, with or without
6
+ # modification, are permitted provided that the following conditions are met:
7
+ #
8
+ # 1. Redistributions of source code must retain the above copyright notice, this
9
+ # list of conditions and the following disclaimer.
10
+ #
11
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ # this list of conditions and the following disclaimer in the documentation
13
+ # and/or other materials provided with the distribution.
14
+ #
15
+ # 3. Neither the name of the copyright holder nor the names of its
16
+ # contributors may be used to endorse or promote products derived from
17
+ # this software without specific prior written permission.
18
+ #
19
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2024 NVIDIA CORPORATION & AFFILIATES
31
+ # SPDX-License-Identifier: BSD-3-Clause
32
+ #################################################################################
33
+
34
+ # Visualizing and Communicating Errors in Rendered Images
35
+ # Ray Tracing Gems II, 2021,
36
+ # by Pontus Andersson, Jim Nilsson, and Tomas Akenine-Moller.
37
+ # Pointer to the chapter: https://research.nvidia.com/publication/2021-08_Visualizing-and-Communicating.
38
+
39
+ # Visualizing Errors in Rendered High Dynamic Range Images
40
+ # Eurographics 2021,
41
+ # by Pontus Andersson, Jim Nilsson, Peter Shirley, and Tomas Akenine-Moller.
42
+ # Pointer to the paper: https://research.nvidia.com/publication/2021-05_HDR-FLIP.
43
+
44
+ # FLIP: A Difference Evaluator for Alternating Images
45
+ # High Performance Graphics 2020,
46
+ # by Pontus Andersson, Jim Nilsson, Tomas Akenine-Moller,
47
+ # Magnus Oskarsson, Kalle Astrom, and Mark D. Fairchild.
48
+ # Pointer to the paper: https://research.nvidia.com/publication/2020-07_FLIP.
49
+
50
+ # Code by Pontus Ebelin (formerly Andersson), Jim Nilsson, and Tomas Akenine-Moller.
51
+
52
+ import sys
53
+ import numpy as np
54
+ import torch
55
+ import torch.nn as nn
56
+
57
+
58
+ class HDRFLIPLoss(nn.Module):
59
+ """Class for computing HDR-FLIP"""
60
+
61
+ def __init__(self):
62
+ """Init"""
63
+ super().__init__()
64
+ self.qc = 0.7
65
+ self.qf = 0.5
66
+ self.pc = 0.4
67
+ self.pt = 0.95
68
+ self.tmax = 0.85
69
+ self.tmin = 0.85
70
+ self.eps = 1e-15
71
+
72
+ def forward(
73
+ self,
74
+ test,
75
+ reference,
76
+ pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180,
77
+ tone_mapper="aces",
78
+ start_exposure=None,
79
+ stop_exposure=None,
80
+ ):
81
+ """
82
+ Computes the HDR-FLIP error map between two HDR images,
83
+ assuming the images are observed at a certain number of
84
+ pixels per degree of visual angle
85
+
86
+ :param test: test tensor (with NxCxHxW layout with nonnegative values)
87
+ :param reference: reference tensor (with NxCxHxW layout with nonnegative values)
88
+ :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
89
+ default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
90
+ :param tone_mapper: (optional) string describing what tone mapper HDR-FLIP should assume
91
+ :param start_exposure: (optional) tensor (with Nx1x1x1 layout) with start exposures corresponding to each HDR reference/test pair
92
+ :param stop_exposure: (optional) tensor (with Nx1x1x1 layout) with stop exposures corresponding to each HDR reference/test pair
93
+ :return: float containing the mean FLIP error (in the range [0,1]) between the HDR reference and test images in the batch
94
+ """
95
+ # HDR-FLIP expects nonnegative and non-NaN values in the input
96
+ reference = torch.clamp(reference, 0, 65536.0)
97
+ test = torch.clamp(test, 0, 65536.0)
98
+
99
+ # Compute start and stop exposures, if they are not given
100
+ if start_exposure is None or stop_exposure is None:
101
+ c_start, c_stop = compute_start_stop_exposures(
102
+ reference, tone_mapper, self.tmax, self.tmin
103
+ )
104
+ if start_exposure is None:
105
+ start_exposure = c_start
106
+ if stop_exposure is None:
107
+ stop_exposure = c_stop
108
+
109
+ # Compute number of exposures
110
+ num_exposures = torch.max(
111
+ torch.tensor([2.0], requires_grad=False).cuda(),
112
+ torch.ceil(stop_exposure - start_exposure),
113
+ )
114
+ most_exposures = int(torch.amax(num_exposures, dim=0).item())
115
+
116
+ # Compute exposure step size
117
+ step_size = (stop_exposure - start_exposure) / torch.max(
118
+ num_exposures - 1, torch.tensor([1.0], requires_grad=False).cuda()
119
+ )
120
+
121
+ # Set the depth of the error tensor to the number of exposures given by the largest exposure range any reference image yielded.
122
+ # This allows us to do one loop for each image in our batch, while not affecting the HDR-FLIP error, as we fill up the error tensor with 0s.
123
+ # Note that the step size still depends on num_exposures and is therefore independent of most_exposures
124
+ dim = reference.size()
125
+ all_errors = torch.zeros(size=(dim[0], most_exposures, dim[2], dim[3])).cuda()
126
+
127
+ # Loop over exposures and compute LDR-FLIP for each pair of LDR reference and test
128
+ for i in range(0, most_exposures):
129
+ exposure = start_exposure + i * step_size
130
+
131
+ reference_tone_mapped = tone_map(reference, tone_mapper, exposure)
132
+ test_tone_mapped = tone_map(test, tone_mapper, exposure)
133
+
134
+ reference_opponent = color_space_transform(
135
+ reference_tone_mapped, "linrgb2ycxcz"
136
+ )
137
+ test_opponent = color_space_transform(test_tone_mapped, "linrgb2ycxcz")
138
+
139
+ all_errors[:, i, :, :] = compute_ldrflip(
140
+ test_opponent,
141
+ reference_opponent,
142
+ pixels_per_degree,
143
+ self.qc,
144
+ self.qf,
145
+ self.pc,
146
+ self.pt,
147
+ self.eps,
148
+ ).squeeze(1)
149
+
150
+ # Take per-pixel maximum over all LDR-FLIP errors to get HDR-FLIP
151
+ hdrflip_error = torch.amax(all_errors, dim=1, keepdim=True)
152
+ return torch.mean(hdrflip_error)
153
+
154
+
155
+ class LDRFLIPLoss(nn.Module):
156
+ """Class for computing LDR FLIP loss"""
157
+
158
+ def __init__(self):
159
+ """Init"""
160
+ super().__init__()
161
+ self.qc = 0.7
162
+ self.qf = 0.5
163
+ self.pc = 0.4
164
+ self.pt = 0.95
165
+ self.eps = 1e-15
166
+
167
+ def forward(
168
+ self, test, reference, pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180
169
+ ):
170
+ """
171
+ Computes the LDR-FLIP error map between two LDR images,
172
+ assuming the images are observed at a certain number of
173
+ pixels per degree of visual angle
174
+
175
+ :param test: test tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
176
+ :param reference: reference tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
177
+ :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
178
+ default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
179
+ :return: float containing the mean FLIP error (in the range [0,1]) between the LDR reference and test images in the batch
180
+ """
181
+ # LDR-FLIP expects non-NaN values in [0,1] as input
182
+ reference = torch.clamp(reference, 0, 1)
183
+ test = torch.clamp(test, 0, 1)
184
+
185
+ # Transform reference and test to opponent color space
186
+ reference_opponent = color_space_transform(reference, "srgb2ycxcz")
187
+ test_opponent = color_space_transform(test, "srgb2ycxcz")
188
+
189
+ deltaE = compute_ldrflip(
190
+ test_opponent,
191
+ reference_opponent,
192
+ pixels_per_degree,
193
+ self.qc,
194
+ self.qf,
195
+ self.pc,
196
+ self.pt,
197
+ self.eps,
198
+ )
199
+
200
+ return torch.mean(deltaE)
201
+
202
+
203
+ def compute_ldrflip(test, reference, pixels_per_degree, qc, qf, pc, pt, eps):
204
+ """
205
+ Computes the LDR-FLIP error map between two LDR images,
206
+ assuming the images are observed at a certain number of
207
+ pixels per degree of visual angle
208
+
209
+ :param reference: reference tensor (with NxCxHxW layout with values in the YCxCz color space)
210
+ :param test: test tensor (with NxCxHxW layout with values in the YCxCz color space)
211
+ :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
212
+ default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
213
+ :param qc: float describing the q_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
214
+ :param qf: float describing the q_f exponent in the LDR-FLIP feature pipeline (see FLIP paper for details)
215
+ :param pc: float describing the p_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
216
+ :param pt: float describing the p_t exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
217
+ :param eps: float containing a small value used to improve training stability
218
+ :return: tensor containing the per-pixel FLIP errors (with Nx1xHxW layout and values in the range [0, 1]) between LDR reference and test images
219
+ """
220
+ # --- Color pipeline ---
221
+ # Spatial filtering
222
+ s_a, radius_a = generate_spatial_filter(pixels_per_degree, "A")
223
+ s_rg, radius_rg = generate_spatial_filter(pixels_per_degree, "RG")
224
+ s_by, radius_by = generate_spatial_filter(pixels_per_degree, "BY")
225
+ radius = max(radius_a, radius_rg, radius_by)
226
+ filtered_reference = spatial_filter(reference, s_a, s_rg, s_by, radius)
227
+ filtered_test = spatial_filter(test, s_a, s_rg, s_by, radius)
228
+
229
+ # Perceptually Uniform Color Space
230
+ preprocessed_reference = hunt_adjustment(
231
+ color_space_transform(filtered_reference, "linrgb2lab")
232
+ )
233
+ preprocessed_test = hunt_adjustment(
234
+ color_space_transform(filtered_test, "linrgb2lab")
235
+ )
236
+
237
+ # Color metric
238
+ deltaE_hyab = hyab(preprocessed_reference, preprocessed_test, eps)
239
+ power_deltaE_hyab = torch.pow(deltaE_hyab, qc)
240
+ hunt_adjusted_green = hunt_adjustment(
241
+ color_space_transform(
242
+ torch.tensor([[[0.0]], [[1.0]], [[0.0]]]).unsqueeze(0), "linrgb2lab"
243
+ )
244
+ )
245
+ hunt_adjusted_blue = hunt_adjustment(
246
+ color_space_transform(
247
+ torch.tensor([[[0.0]], [[0.0]], [[1.0]]]).unsqueeze(0), "linrgb2lab"
248
+ )
249
+ )
250
+ cmax = torch.pow(hyab(hunt_adjusted_green, hunt_adjusted_blue, eps), qc).item()
251
+ deltaE_c = redistribute_errors(power_deltaE_hyab, cmax, pc, pt)
252
+
253
+ # --- Feature pipeline ---
254
+ # Extract and normalize Yy component
255
+ ref_y = (reference[:, 0:1, :, :] + 16) / 116
256
+ test_y = (test[:, 0:1, :, :] + 16) / 116
257
+
258
+ # Edge and point detection
259
+ edges_reference = feature_detection(ref_y, pixels_per_degree, "edge")
260
+ points_reference = feature_detection(ref_y, pixels_per_degree, "point")
261
+ edges_test = feature_detection(test_y, pixels_per_degree, "edge")
262
+ points_test = feature_detection(test_y, pixels_per_degree, "point")
263
+
264
+ # Feature metric
265
+ deltaE_f = torch.max(
266
+ torch.abs(
267
+ torch.norm(edges_reference, dim=1, keepdim=True)
268
+ - torch.norm(edges_test, dim=1, keepdim=True)
269
+ ),
270
+ torch.abs(
271
+ torch.norm(points_test, dim=1, keepdim=True)
272
+ - torch.norm(points_reference, dim=1, keepdim=True)
273
+ ),
274
+ )
275
+ deltaE_f = torch.clamp(deltaE_f, min=eps) # clamp to stabilize training
276
+ deltaE_f = torch.pow(((1 / np.sqrt(2)) * deltaE_f), qf)
277
+
278
+ # --- Final error ---
279
+ return torch.pow(deltaE_c, 1 - deltaE_f)
280
+
281
+
282
+ def tone_map(img, tone_mapper, exposure):
283
+ """
284
+ Applies exposure compensation and tone mapping.
285
+ Refer to the Visualizing Errors in Rendered High Dynamic Range Images
286
+ paper for details about the formulas.
287
+
288
+ :param img: float tensor (with NxCxHxW layout) containing nonnegative values
289
+ :param tone_mapper: string describing the tone mapper to apply
290
+ :param exposure: float tensor (with Nx1x1x1 layout) describing the exposure compensation factor
+ :return: float tensor (with NxCxHxW layout) containing the tone mapped image with values in [0, 1]
291
+ """
292
+ # Exposure compensation
293
+ x = (2**exposure) * img
294
+
295
+ # Set tone mapping coefficients depending on tone_mapper
296
+ if tone_mapper == "reinhard":
297
+ lum_coeff_r = 0.2126
298
+ lum_coeff_g = 0.7152
299
+ lum_coeff_b = 0.0722
300
+
301
+ Y = (
302
+ x[:, 0:1, :, :] * lum_coeff_r
303
+ + x[:, 1:2, :, :] * lum_coeff_g
304
+ + x[:, 2:3, :, :] * lum_coeff_b
305
+ )
306
+ return torch.clamp(torch.div(x, 1 + Y), 0.0, 1.0)
307
+
308
+ if tone_mapper == "hable":
309
+ # Source: https://64.github.io/tonemapping/
310
+ A = 0.15
311
+ B = 0.50
312
+ C = 0.10
313
+ D = 0.20
314
+ E = 0.02
315
+ F = 0.30
316
+ k0 = A * F - A * E
317
+ k1 = C * B * F - B * E
318
+ k2 = 0
319
+ k3 = A * F
320
+ k4 = B * F
321
+ k5 = D * F * F
322
+
323
+ W = 11.2
324
+ nom = k0 * torch.pow(W, torch.tensor([2.0]).cuda()) + k1 * W + k2
325
+ denom = k3 * torch.pow(W, torch.tensor([2.0]).cuda()) + k4 * W + k5
326
+ white_scale = torch.div(denom, nom) # = 1 / (nom / denom)
327
+
328
+ # Include white scale and exposure bias in rational polynomial coefficients
329
+ k0 = 4 * k0 * white_scale
330
+ k1 = 2 * k1 * white_scale
331
+ k2 = k2 * white_scale
332
+ k3 = 4 * k3
333
+ k4 = 2 * k4
334
+ # k5 = k5 # k5 is not changed
335
+ else:
336
+ # Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
337
+ # Include pre-exposure cancelation in constants
338
+ k0 = 0.6 * 0.6 * 2.51
339
+ k1 = 0.6 * 0.03
340
+ k2 = 0
341
+ k3 = 0.6 * 0.6 * 2.43
342
+ k4 = 0.6 * 0.59
343
+ k5 = 0.14
344
+
345
+ x2 = torch.pow(x, 2)
346
+ nom = k0 * x2 + k1 * x + k2
347
+ denom = k3 * x2 + k4 * x + k5
348
+ denom = torch.where(
349
+ torch.isinf(denom), torch.Tensor([1.0]).cuda(), denom
350
+ ) # if denom is inf, then so is nom => nan. Pixel is very bright. It becomes inf here, but 1 after clamp below
351
+ y = torch.div(nom, denom)
352
+ return torch.clamp(y, 0.0, 1.0)
353
+
354
+
355
+ def compute_start_stop_exposures(reference, tone_mapper, tmax, tmin):
356
+ """
357
+ Computes start and stop exposure for HDR-FLIP based on given tone mapper and reference image.
358
+ Refer to the Visualizing Errors in Rendered High Dynamic Range Images
359
+ paper for details about the formulas
360
+
361
+ :param reference: float tensor (with NxCxHxW layout) containing reference images (nonnegative values)
362
+ :param tone_mapper: string describing which tone mapper should be assumed
363
+ :param tmax: float describing the t value used to find the start exposure
364
+ :param tmin: float describing the t value used to find the stop exposure
365
+ :return: two float tensors (with Nx1x1x1 layout) containing start and stop exposures, respectively, to use for HDR-FLIP
366
+ """
367
+ if tone_mapper == "reinhard":
368
+ k0 = 0
369
+ k1 = 1
370
+ k2 = 0
371
+ k3 = 0
372
+ k4 = 1
373
+ k5 = 1
374
+
375
+ x_max = tmax * k5 / (k1 - tmax * k4)
376
+ x_min = tmin * k5 / (k1 - tmin * k4)
377
+ elif tone_mapper == "hable":
378
+ # Source: https://64.github.io/tonemapping/
379
+ A = 0.15
380
+ B = 0.50
381
+ C = 0.10
382
+ D = 0.20
383
+ E = 0.02
384
+ F = 0.30
385
+ k0 = A * F - A * E
386
+ k1 = C * B * F - B * E
387
+ k2 = 0
388
+ k3 = A * F
389
+ k4 = B * F
390
+ k5 = D * F * F
391
+
392
+ W = 11.2
393
+ nom = k0 * torch.pow(W, torch.tensor([2.0]).cuda()) + k1 * W + k2
394
+ denom = k3 * torch.pow(W, torch.tensor([2.0]).cuda()) + k4 * W + k5
395
+ white_scale = torch.div(denom, nom) # = 1 / (nom / denom)
396
+
397
+ # Include white scale and exposure bias in rational polynomial coefficients
398
+ k0 = 4 * k0 * white_scale
399
+ k1 = 2 * k1 * white_scale
400
+ k2 = k2 * white_scale
401
+ k3 = 4 * k3
402
+ k4 = 2 * k4
403
+ # k5 = k5 # k5 is not changed
404
+
405
+ c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
406
+ c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
407
+ x_max = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
408
+
409
+ c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
410
+ c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
411
+ x_min = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
412
+ else:
413
+ # Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
414
+ # Include pre-exposure cancelation in constants
415
+ k0 = 0.6 * 0.6 * 2.51
416
+ k1 = 0.6 * 0.03
417
+ k2 = 0
418
+ k3 = 0.6 * 0.6 * 2.43
419
+ k4 = 0.6 * 0.59
420
+ k5 = 0.14
421
+
422
+ c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
423
+ c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
424
+ x_max = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
425
+
426
+ c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
427
+ c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
428
+ x_min = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
429
+
430
+ # Convert reference to luminance
431
+ lum_coeff_r = 0.2126
432
+ lum_coeff_g = 0.7152
433
+ lum_coeff_b = 0.0722
434
+ Y_reference = (
435
+ reference[:, 0:1, :, :] * lum_coeff_r
436
+ + reference[:, 1:2, :, :] * lum_coeff_g
437
+ + reference[:, 2:3, :, :] * lum_coeff_b
438
+ )
439
+
440
+ # Compute start exposure
441
+ Y_hi = torch.amax(Y_reference, dim=(2, 3), keepdim=True)
442
+ start_exposure = torch.log2(x_max / Y_hi)
443
+
444
+ # Compute stop exposure
445
+ dim = Y_reference.size()
446
+ Y_ref = Y_reference.view(dim[0], dim[1], dim[2] * dim[3])
447
+ Y_lo = torch.median(Y_ref, dim=2).values.unsqueeze(2).unsqueeze(3)
448
+ stop_exposure = torch.log2(x_min / Y_lo)
449
+
450
+ return start_exposure, stop_exposure
451
+
452
+
453
+ def generate_spatial_filter(pixels_per_degree, channel):
454
+ """
455
+ Generates spatial contrast sensitivity filters with width depending on
456
+ the number of pixels per degree of visual angle of the observer
457
+
458
+ :param pixels_per_degree: float indicating number of pixels per degree of visual angle
459
+ :param channel: string describing what filter should be generated
460
+ :return: filter kernel corresponding to the spatial contrast sensitivity function of the given channel, and the kernel's radius
461
+ """
462
+ a1_A = 1
463
+ b1_A = 0.0047
464
+ a2_A = 0
465
+ b2_A = 1e-5 # avoid division by 0
466
+ a1_rg = 1
467
+ b1_rg = 0.0053
468
+ a2_rg = 0
469
+ b2_rg = 1e-5 # avoid division by 0
470
+ a1_by = 34.1
471
+ b1_by = 0.04
472
+ a2_by = 13.5
473
+ b2_by = 0.025
474
+ if channel == "A": # Achromatic CSF
475
+ a1 = a1_A
476
+ b1 = b1_A
477
+ a2 = a2_A
478
+ b2 = b2_A
479
+ elif channel == "RG": # Red-Green CSF
480
+ a1 = a1_rg
481
+ b1 = b1_rg
482
+ a2 = a2_rg
483
+ b2 = b2_rg
484
+ elif channel == "BY": # Blue-Yellow CSF
485
+ a1 = a1_by
486
+ b1 = b1_by
487
+ a2 = a2_by
488
+ b2 = b2_by
489
+
490
+ # Determine evaluation domain
491
+ max_scale_parameter = max([b1_A, b2_A, b1_rg, b2_rg, b1_by, b2_by])
492
+ r = np.ceil(3 * np.sqrt(max_scale_parameter / (2 * np.pi**2)) * pixels_per_degree)
493
+ r = int(r)
494
+ deltaX = 1.0 / pixels_per_degree
495
+ x, y = np.meshgrid(range(-r, r + 1), range(-r, r + 1))
496
+ z = (x * deltaX) ** 2 + (y * deltaX) ** 2
497
+
498
+ # Generate weights
499
+ g = a1 * np.sqrt(np.pi / b1) * np.exp(-(np.pi**2) * z / b1) + a2 * np.sqrt(
500
+ np.pi / b2
501
+ ) * np.exp(-(np.pi**2) * z / b2)
502
+ g = g / np.sum(g)
503
+ g = torch.Tensor(g).unsqueeze(0).unsqueeze(0).cuda()
504
+
505
+ return g, r
506
+
507
+
508
+ def spatial_filter(img, s_a, s_rg, s_by, radius):
509
+ """
510
+ Filters an image with channel specific spatial contrast sensitivity functions
511
+ and clips result to the unit cube in linear RGB
512
+
513
+ :param img: image tensor to filter (with NxCxHxW layout in the YCxCz color space)
514
+ :param s_a: spatial filter matrix for the achromatic channel
515
+ :param s_rg: spatial filter matrix for the red-green channel
516
+ :param s_by: spatial filter matrix for the blue-yellow channel
+ :param radius: integer radius (in pixels) of the largest filter kernel
517
+ :return: input image (with NxCxHxW layout) transformed to linear RGB after filtering with spatial contrast sensitivity functions
518
+ """
519
+ dim = img.size()
520
+ # Prepare image for convolution
521
+ img_pad = torch.zeros(
522
+ (dim[0], dim[1], dim[2] + 2 * radius, dim[3] + 2 * radius), device="cuda"
523
+ )
524
+ img_pad[:, 0:1, :, :] = nn.functional.pad(
525
+ img[:, 0:1, :, :], (radius, radius, radius, radius), mode="replicate"
526
+ )
527
+ img_pad[:, 1:2, :, :] = nn.functional.pad(
528
+ img[:, 1:2, :, :], (radius, radius, radius, radius), mode="replicate"
529
+ )
530
+ img_pad[:, 2:3, :, :] = nn.functional.pad(
531
+ img[:, 2:3, :, :], (radius, radius, radius, radius), mode="replicate"
532
+ )
533
+
534
+ # Apply Gaussian filters
535
+ img_tilde_opponent = torch.zeros((dim[0], dim[1], dim[2], dim[3]), device="cuda")
536
+ img_tilde_opponent[:, 0:1, :, :] = nn.functional.conv2d(
537
+ img_pad[:, 0:1, :, :], s_a.cuda(), padding=0
538
+ )
539
+ img_tilde_opponent[:, 1:2, :, :] = nn.functional.conv2d(
540
+ img_pad[:, 1:2, :, :], s_rg.cuda(), padding=0
541
+ )
542
+ img_tilde_opponent[:, 2:3, :, :] = nn.functional.conv2d(
543
+ img_pad[:, 2:3, :, :], s_by.cuda(), padding=0
544
+ )
545
+
546
+ # Transform to linear RGB for clamp
547
+ img_tilde_linear_rgb = color_space_transform(img_tilde_opponent, "ycxcz2linrgb")
548
+
549
+ # Clamp to RGB box
550
+ return torch.clamp(img_tilde_linear_rgb, 0.0, 1.0)
551
+
552
+
553
+ def hunt_adjustment(img):
554
+ """
555
+ Applies Hunt-adjustment to an image
556
+
557
+ :param img: image tensor to adjust (with NxCxHxW layout in the L*a*b* color space)
558
+ :return: Hunt-adjusted image tensor (with NxCxHxW layout in the Hunt-adjusted L*A*B* color space)
559
+ """
560
+ # Extract luminance component
561
+ L = img[:, 0:1, :, :]
562
+
563
+ # Apply Hunt adjustment
564
+ img_h = torch.zeros(img.size(), device="cuda")
565
+ img_h[:, 0:1, :, :] = L
566
+ img_h[:, 1:2, :, :] = torch.mul((0.01 * L), img[:, 1:2, :, :])
567
+ img_h[:, 2:3, :, :] = torch.mul((0.01 * L), img[:, 2:3, :, :])
568
+
569
+ return img_h
570
+
571
+
572
+ def hyab(reference, test, eps):
573
+ """
574
+ Computes the HyAB distance between reference and test images
575
+
576
+ :param reference: reference image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*A*B* color space)
577
+ :param test: test image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*a*b* color space)
578
+ :param eps: float containing a small value used to improve training stability
579
+ :return: image tensor (with Nx1xHxW layout) containing the per-pixel HyAB distances between reference and test images
580
+ """
581
+ delta = reference - test
582
+ root = torch.sqrt(torch.clamp(torch.pow(delta[:, 0:1, :, :], 2), min=eps))
583
+ delta_norm = torch.norm(delta[:, 1:3, :, :], dim=1, keepdim=True)
584
+ return root + delta_norm # sqrt of the clamped square replaces abs to stabilize training
585
+
586
+
587
+ def redistribute_errors(power_deltaE_hyab, cmax, pc, pt):
588
+ """
589
+ Redistributes exponentiated HyAB errors to the [0,1] range
590
+
591
+ :param power_deltaE_hyab: float tensor (with Nx1xHxW layout) containing the exponentiated HyAb distance
592
+ :param cmax: float containing the exponentiated, maximum HyAB difference between two colors in Hunt-adjusted L*A*B* space
593
+ :param pc: float containing the cmax multiplier p_c (see FLIP paper)
594
+ :param pt: float containing the target value, p_t, for p_c * cmax (see FLIP paper)
595
+ :return: image tensor (with Nx1xHxW layout) containing redistributed per-pixel HyAB distances (in range [0,1])
596
+ """
597
+ # Re-map error to 0-1 range. Values between 0 and
598
+ # pccmax are mapped to the range [0, pt],
599
+ # while the rest are mapped to the range (pt, 1]
600
+ deltaE_c = torch.zeros(power_deltaE_hyab.size(), device="cuda")
601
+ pccmax = pc * cmax
602
+ deltaE_c = torch.where(
603
+ power_deltaE_hyab < pccmax,
604
+ (pt / pccmax) * power_deltaE_hyab,
605
+ pt + ((power_deltaE_hyab - pccmax) / (cmax - pccmax)) * (1.0 - pt),
606
+ )
607
+
608
+ return deltaE_c
609
+
610
+
611
+ def feature_detection(img_y, pixels_per_degree, feature_type):
612
+ """
613
+ Detects edges and points (features) in the achromatic image
614
+
615
+ :param img_y: achromatic image tensor (with Nx1xHxW layout, containing normalized Y-values from YCxCz)
616
+ :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer
617
+ :param feature_type: string indicating the type of feature to detect
618
+ :return: image tensor (with Nx2xHxW layout, with values in range [0,1]) containing large values where features were detected
619
+ """
620
+ # Set peak to trough value (2x standard deviations) of human edge
621
+ # detection filter
622
+ w = 0.082
623
+
624
+ # Compute filter radius
625
+ sd = 0.5 * w * pixels_per_degree
626
+ radius = int(np.ceil(3 * sd))
627
+
628
+ # Compute 2D Gaussian
629
+ [x, y] = np.meshgrid(range(-radius, radius + 1), range(-radius, radius + 1))
630
+ g = np.exp(-(x**2 + y**2) / (2 * sd * sd))
631
+
632
+ if feature_type == "edge": # Edge detector
633
+ # Compute partial derivative in x-direction
634
+ Gx = np.multiply(-x, g)
635
+ else: # Point detector
636
+ # Compute second partial derivative in x-direction
637
+ Gx = np.multiply(x**2 / (sd * sd) - 1, g)
638
+
639
+ # Normalize positive weights to sum to 1 and negative weights to sum to -1
640
+ negative_weights_sum = -np.sum(Gx[Gx < 0])
641
+ positive_weights_sum = np.sum(Gx[Gx > 0])
642
+ Gx = torch.Tensor(Gx)
643
+ Gx = torch.where(Gx < 0, Gx / negative_weights_sum, Gx / positive_weights_sum)
644
+ Gx = Gx.unsqueeze(0).unsqueeze(0).cuda()
645
+
646
+ # Detect features
647
+ featuresX = nn.functional.conv2d(
648
+ nn.functional.pad(img_y, (radius, radius, radius, radius), mode="replicate"),
649
+ Gx,
650
+ padding=0,
651
+ )
652
+ featuresY = nn.functional.conv2d(
653
+ nn.functional.pad(img_y, (radius, radius, radius, radius), mode="replicate"),
654
+ torch.transpose(Gx, 2, 3),
655
+ padding=0,
656
+ )
657
+ return torch.cat((featuresX, featuresY), dim=1)
658
+
659
+
660
+ def color_space_transform(input_color, fromSpace2toSpace):
661
+ """
662
+ Transforms inputs between different color spaces
663
+
664
+ :param input_color: tensor of colors to transform (with NxCxHxW layout)
665
+ :param fromSpace2toSpace: string describing transform
666
+ :return: transformed tensor (with NxCxHxW layout)
667
+ """
668
+ dim = input_color.size()
669
+
670
+ # Assume D65 standard illuminant
671
+ reference_illuminant = torch.tensor(
672
+ [[[0.950428545]], [[1.000000000]], [[1.088900371]]]
673
+ ).cuda()
674
+ inv_reference_illuminant = torch.tensor(
675
+ [[[1.052156925]], [[1.000000000]], [[0.918357670]]]
676
+ ).cuda()
677
+
678
+ if fromSpace2toSpace == "srgb2linrgb":
679
+ limit = 0.04045
680
+ transformed_color = torch.where(
681
+ input_color > limit,
682
+ torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
683
+ input_color / 12.92,
684
+ ) # clamp to stabilize training
685
+
686
+ elif fromSpace2toSpace == "linrgb2srgb":
687
+ limit = 0.0031308
688
+ transformed_color = torch.where(
689
+ input_color > limit,
690
+ 1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
691
+ 12.92 * input_color,
692
+ )
693
+
694
+ elif fromSpace2toSpace in ["linrgb2xyz", "xyz2linrgb"]:
695
+ # Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
696
+ # Assumes D65 standard illuminant
697
+ if fromSpace2toSpace == "linrgb2xyz":
698
+ a11 = 10135552 / 24577794
699
+ a12 = 8788810 / 24577794
700
+ a13 = 4435075 / 24577794
701
+ a21 = 2613072 / 12288897
702
+ a22 = 8788810 / 12288897
703
+ a23 = 887015 / 12288897
704
+ a31 = 1425312 / 73733382
705
+ a32 = 8788810 / 73733382
706
+ a33 = 70074185 / 73733382
707
+ else:
708
+ # Constants found by taking the inverse of the matrix
709
+ # defined by the constants for linrgb2xyz
710
+ a11 = 3.241003275
711
+ a12 = -1.537398934
712
+ a13 = -0.498615861
713
+ a21 = -0.969224334
714
+ a22 = 1.875930071
715
+ a23 = 0.041554224
716
+ a31 = 0.055639423
717
+ a32 = -0.204011202
718
+ a33 = 1.057148933
719
+ A = torch.Tensor([[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]])
720
+
721
+ input_color = input_color.view(dim[0], dim[1], dim[2] * dim[3]).cuda() # NC(HW)
722
+
723
+ transformed_color = torch.matmul(A.cuda(), input_color)
724
+ transformed_color = transformed_color.view(dim[0], dim[1], dim[2], dim[3])
725
+
726
+ elif fromSpace2toSpace == "xyz2ycxcz":
727
+ input_color = torch.mul(input_color, inv_reference_illuminant)
728
+ y = 116 * input_color[:, 1:2, :, :] - 16
729
+ cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
730
+ cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
731
+ transformed_color = torch.cat((y, cx, cz), 1)
732
+
733
+ elif fromSpace2toSpace == "ycxcz2xyz":
734
+ y = (input_color[:, 0:1, :, :] + 16) / 116
735
+ cx = input_color[:, 1:2, :, :] / 500
736
+ cz = input_color[:, 2:3, :, :] / 200
737
+
738
+ x = y + cx
739
+ z = y - cz
740
+ transformed_color = torch.cat((x, y, z), 1)
741
+
742
+ transformed_color = torch.mul(transformed_color, reference_illuminant)
743
+
744
+ elif fromSpace2toSpace == "xyz2lab":
745
+ input_color = torch.mul(input_color, inv_reference_illuminant)
746
+ delta = 6 / 29
747
+ delta_square = delta * delta
748
+ delta_cube = delta * delta_square
749
+ factor = 1 / (3 * delta_square)
750
+
751
+ clamped_term = torch.pow(
752
+ torch.clamp(input_color, min=delta_cube), 1.0 / 3.0
753
+ ).to(dtype=input_color.dtype)
754
+ div = (factor * input_color + (4 / 29)).to(dtype=input_color.dtype)
755
+ input_color = torch.where(
756
+ input_color > delta_cube, clamped_term, div
757
+ ) # clamp to stabilize training
758
+
759
+ L = 116 * input_color[:, 1:2, :, :] - 16
760
+ a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
761
+ b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
762
+
763
+ transformed_color = torch.cat((L, a, b), 1)
764
+
765
+ elif fromSpace2toSpace == "lab2xyz":
766
+ y = (input_color[:, 0:1, :, :] + 16) / 116
767
+ a = input_color[:, 1:2, :, :] / 500
768
+ b = input_color[:, 2:3, :, :] / 200
769
+
770
+ x = y + a
771
+ z = y - b
772
+
773
+ xyz = torch.cat((x, y, z), 1)
774
+ delta = 6 / 29
775
+ delta_square = delta * delta
776
+ factor = 3 * delta_square
777
+ xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))
778
+
779
+ transformed_color = torch.mul(xyz, reference_illuminant)
780
+
781
+ elif fromSpace2toSpace == "srgb2xyz":
782
+ transformed_color = color_space_transform(input_color, "srgb2linrgb")
783
+ transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
784
+ elif fromSpace2toSpace == "srgb2ycxcz":
785
+ transformed_color = color_space_transform(input_color, "srgb2linrgb")
786
+ transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
787
+ transformed_color = color_space_transform(transformed_color, "xyz2ycxcz")
788
+ elif fromSpace2toSpace == "linrgb2ycxcz":
789
+ transformed_color = color_space_transform(input_color, "linrgb2xyz")
790
+ transformed_color = color_space_transform(transformed_color, "xyz2ycxcz")
791
+ elif fromSpace2toSpace == "srgb2lab":
792
+ transformed_color = color_space_transform(input_color, "srgb2linrgb")
793
+ transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
794
+ transformed_color = color_space_transform(transformed_color, "xyz2lab")
795
+ elif fromSpace2toSpace == "linrgb2lab":
796
+ transformed_color = color_space_transform(input_color, "linrgb2xyz")
797
+ transformed_color = color_space_transform(transformed_color, "xyz2lab")
798
+ elif fromSpace2toSpace == "ycxcz2linrgb":
799
+ transformed_color = color_space_transform(input_color, "ycxcz2xyz")
800
+ transformed_color = color_space_transform(transformed_color, "xyz2linrgb")
801
+ elif fromSpace2toSpace == "lab2srgb":
802
+ transformed_color = color_space_transform(input_color, "lab2xyz")
803
+ transformed_color = color_space_transform(transformed_color, "xyz2linrgb")
804
+ transformed_color = color_space_transform(transformed_color, "linrgb2srgb")
805
+ elif fromSpace2toSpace == "ycxcz2lab":
806
+ transformed_color = color_space_transform(input_color, "ycxcz2xyz")
807
+ transformed_color = color_space_transform(transformed_color, "xyz2lab")
808
+ else:
809
+ sys.exit("Error: The color transform %s is not defined!" % fromSpace2toSpace)
810
+
811
+ return transformed_color
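A hypothetical usage sketch for the LDR path above; the module allocates intermediates with .cuda(), so inputs must live on a CUDA device:

import torch

loss_fn = LDRFLIPLoss()
reference = torch.rand(1, 3, 256, 256, device="cuda")                # sRGB in [0, 1]
test = (reference + 0.05 * torch.randn_like(reference)).clamp(0, 1)  # perturbed copy
print(loss_fn(test, reference).item())                               # mean FLIP error in [0, 1]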
utils/image_utils.py ADDED
@@ -0,0 +1,253 @@
1
+ import os
2
+
3
+ import matplotlib
4
+ import matplotlib.font_manager as font_manager
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ from matplotlib.patches import Ellipse
9
+ from numpy.linalg import norm
10
+ from PIL import Image
11
+ from scipy.ndimage import sobel
12
+
13
+ FONT_PATH = "assets/fonts/linux_libertine/LinLibertine_R.ttf"
14
+
15
+ # Make font loading optional for deployment environments
16
+ try:
17
+ font_manager.fontManager.addfont(FONT_PATH)
18
+ FONT_PROP = font_manager.FontProperties(fname=FONT_PATH).get_name()
19
+ plt.rcParams["font.family"] = FONT_PROP
20
+ plt.rcParams["text.usetex"] = True
21
+ except (FileNotFoundError, OSError):
22
+ # Use default font if custom font is not available
23
+ FONT_PROP = "DejaVu Sans"
24
+ plt.rcParams["font.family"] = FONT_PROP
25
+ plt.rcParams["text.usetex"] = False # Disable LaTeX if custom font unavailable
26
+ matplotlib.rcParams["font.size"] = 16
27
+ matplotlib.rcParams["axes.titlesize"] = 16
28
+ matplotlib.rcParams["figure.titlesize"] = 16
29
+ matplotlib.rcParams["legend.fontsize"] = 16
30
+ matplotlib.rcParams["legend.title_fontsize"] = 16
31
+ matplotlib.rcParams["xtick.labelsize"] = 14
32
+ matplotlib.rcParams["ytick.labelsize"] = 14
33
+
34
+ ALLOWED_IMAGE_FILE_FORMATS = [".jpeg", ".jpg", ".png"]
35
+ ALLOWED_IMAGE_TYPES = {"RGB": 3, "RGBA": 3, "L": 1}  # RGBA: the alpha channel is dropped
36
+
37
+ PLOT_DPI = 72.0
38
+ GAUSSIAN_ZOOM = 5
39
+ GAUSSIAN_COLOR = "#80ed99"
40
+
41
+
42
+ def get_psnr(image1, image2, max_value=1.0):
43
+ mse = torch.mean((image1 - image2) ** 2)
44
+ if mse.item() <= 1e-7:
45
+ return float("inf")
46
+ psnr = 20 * torch.log10(max_value / torch.sqrt(mse))
47
+ return psnr
48
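# Sanity check (editor's sketch): identical images yield inf, and a uniform
# error of 0.1 gives 20 * log10(1.0 / 0.1) = 20 dB.
_a = torch.zeros(3, 4, 4)
_b = torch.full((3, 4, 4), 0.1)
assert get_psnr(_a, _a) == float("inf")
assert abs(get_psnr(_a, _b).item() - 20.0) < 1e-3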
+
49
+
50
+ def get_grid(h, w, x_lim=np.asarray([0, 1]), y_lim=np.asarray([0, 1])):
51
+ x = torch.linspace(x_lim[0], x_lim[1], steps=w + 1)[:-1] + 0.5 / w
52
+ y = torch.linspace(y_lim[0], y_lim[1], steps=h + 1)[:-1] + 0.5 / h
53
+ grid_x, grid_y = torch.meshgrid(x, y, indexing="xy")
54
+ grid = torch.stack([grid_x, grid_y], dim=-1)
55
+ return grid
56
+
57
+
58
+ def compute_image_gradients(image):
59
+ gy, gx = [], []
60
+ for image_channel in image:
61
+ gy.append(sobel(image_channel, 0))
62
+ gx.append(sobel(image_channel, 1))
63
+ gy = norm(np.stack(gy, axis=0), ord=2, axis=0).astype(np.float32)
64
+ gx = norm(np.stack(gx, axis=0), ord=2, axis=0).astype(np.float32)
65
+ return gy, gx
66
+
67
+
68
+ def load_images(load_path, downsample_ratio=None, gamma=None):
69
+ """
70
+ Load target images or textures from a directory or a single file.
71
+ """
72
+ image_list = []
73
+ image_path_list = []
74
+ image_fname_list = []
75
+ num_channels_list = []
76
+ if (
77
+ os.path.isfile(load_path)
78
+ and os.path.splitext(load_path)[1].lower() in ALLOWED_IMAGE_FILE_FORMATS
79
+ ):
80
+ image_path_list.append(load_path)
81
+ elif os.path.isdir(load_path):
82
+ for file in sorted(os.listdir(load_path), key=str.lower):
83
+ if os.path.splitext(file)[1].lower() in ALLOWED_IMAGE_FILE_FORMATS:
84
+ image_path_list.append(os.path.join(load_path, file))
85
+ if len(image_path_list) == 0:
86
+ raise FileNotFoundError(f"No supported image file found at '{load_path}'")
87
+ for image_path in image_path_list:
88
+ image_fname_list.append(os.path.splitext(os.path.basename(image_path))[0])
89
+ image = Image.open(image_path)
90
+ # Warning: only images of type L, RGB, or RGBA in JPEG or PNG format are supported
91
+ if image.mode not in ALLOWED_IMAGE_TYPES:
92
+ raise TypeError(
93
+ f"Only support images of type {list(ALLOWED_IMAGE_TYPES.keys())} in JPEG or PNG format"
94
+ )
95
+ num_channels = ALLOWED_IMAGE_TYPES[image.mode]
96
+ num_channels_list.append(num_channels)
97
+ if downsample_ratio is not None:
98
+ image = image.resize(
99
+ (
100
+ round(image.width / downsample_ratio),
101
+ round(image.height / downsample_ratio),
102
+ ),
103
+ resample=Image.Resampling.BILINEAR,
104
+ )
105
+ # Warning: assumes 8-bit color depth
106
+ image = np.asarray(image, dtype=np.float32) / 255.0
107
+ if gamma is not None:
108
+ image = np.power(image, gamma)
109
+ if len(image.shape) == 2:
110
+ image = np.expand_dims(image, axis=2)
111
+ image = image.transpose(2, 0, 1)
112
+ image = image[:num_channels]
113
+ image_list.append(image)
114
+ return np.concatenate(image_list, axis=0), num_channels_list, image_fname_list
115
+
116
+
117
+ def to_output_format(image, gamma):
118
+ if len(image.shape) not in [2, 3]:
119
+ raise ValueError(f"Wrong image format: shape = {image.shape}")
120
+ if isinstance(image, torch.Tensor):
121
+ image = image.detach().cpu().clone().numpy()
122
+ if len(image.shape) == 3 and image.shape[2] not in [1, 3]:
123
+ image = image.transpose(1, 2, 0)
124
+ if image.shape[2] not in [1, 3]:
125
+ raise ValueError(f"Wrong image format: shape = {image.shape}")
126
+ if len(image.shape) == 3 and image.shape[2] == 1:
127
+ image = image.squeeze(axis=2)
128
+ image = np.clip(image, 0.0, 1.0)
129
+ if gamma is not None:
130
+ image = np.power(image, 1.0 / gamma)
131
+ image = (255.0 * image).astype(np.uint8)
132
+ return image
133
+
134
+
135
+ def save_image(image, save_path, gamma=None, zoom=None):
136
+ image = to_output_format(image, gamma)
137
+ image = Image.fromarray(image)
138
+ if zoom is not None and zoom > 0.0:
139
+ width, height = image.size
140
+ image = image.resize(
141
+ (round(width * zoom), round(height * zoom)), resample=Image.Resampling.BOX
142
+ )
143
+ image.save(save_path)
144
+
145
+
146
+ def separate_image_channels(images, input_channels):
147
+ if len(images) != sum(input_channels):
148
+ raise ValueError(
149
+ f"Incompatible number of channels: {len(images):d} vs {sum(input_channels):d}"
150
+ )
151
+ image_list = []
152
+ curr_channel = 0
153
+ for num_channels in input_channels:
154
+ image_list.append(images[curr_channel : curr_channel + num_channels])
155
+ curr_channel += num_channels
156
+ return image_list
157
+
158
+
159
+ def visualize_gaussians(
160
+ filepath, xy, scale, rot, feat, img_h, img_w, input_channels, alpha=0.8, gamma=None
161
+ ):
162
+ """
163
+ Visualize Gaussians as colored elliptical disks.
164
+ """
165
+ if feat.shape[1] != sum(input_channels):
166
+ raise ValueError(
167
+ f"Incompatible number of channels: {feat.shape[1]:d} vs {sum(input_channels):d}"
168
+ )
169
+ xy = xy.detach().cpu().clone().numpy()
170
+ y, x = xy[:, 1] * img_h, xy[:, 0] * img_w
171
+ scale = GAUSSIAN_ZOOM * scale.detach().cpu().clone().numpy()
172
+ rot = rot.detach().cpu().clone().numpy()
173
+ if gamma is not None:
174
+ feat = torch.pow(feat, 1.0 / gamma)
175
+ feat = np.clip(feat.detach().cpu().clone().numpy(), 0.0, 1.0)
176
+
177
+ curr_channel = 0
178
+ for image_id, num_channels in enumerate(input_channels, 1):
179
+ curr_feat = feat[:, curr_channel : curr_channel + num_channels]
180
+ fig = plt.figure()
181
+ fig.set_dpi(PLOT_DPI)
182
+ fig.set_size_inches(w=img_w / PLOT_DPI, h=img_h / PLOT_DPI, forward=False)
183
+ ax = plt.gca()
184
+ for gid in range(len(xy)):
185
+ ellipse = Ellipse(
186
+ xy=(x[gid], y[gid]),
187
+ width=scale[gid, 0],
188
+ height=scale[gid, 1],
189
+ angle=rot[gid, 0] * 180 / np.pi,
190
+ alpha=alpha,
191
+ ec=None,
192
+ fc=curr_feat[gid],
193
+ lw=None,
194
+ )
195
+ ax.add_patch(ellipse)
196
+ plt.xlim(0, img_w)
197
+ plt.ylim(img_h, 0)
198
+ plt.axis("off")
199
+ plt.tight_layout()
200
+ suffix = "" if len(input_channels) == 1 else f"_{image_id:d}"
201
+ plt.savefig(
202
+ f"{filepath}{suffix}.png", bbox_inches="tight", pad_inches=0, dpi=PLOT_DPI
203
+ )
204
+ plt.close()
205
+ curr_channel += num_channels
206
+
207
+
208
+ def visualize_added_gaussians(
209
+ filepath,
210
+ images,
211
+ old_xy,
212
+ new_xy,
213
+ input_channels,
214
+ size=500,
215
+ every_n=5,
216
+ alpha=0.8,
217
+ gamma=None,
218
+ ):
219
+ """
220
+ Visualize the positions of added Gaussians during error-guided progressive optimization.
221
+ """
222
+ if len(images) != sum(input_channels):
223
+ raise ValueError(
224
+ f"Incompatible number of channels: {len(images):d} vs {sum(input_channels):d}"
225
+ )
226
+ image_height, image_width = images.shape[1:]
227
+ old_xy = old_xy.detach().cpu().clone().numpy()[::every_n]
228
+ new_xy = new_xy.detach().cpu().clone().numpy()[::every_n]
229
+ old_x, old_y = old_xy[:, 0] * image_width, old_xy[:, 1] * image_height
230
+ new_x, new_y = new_xy[:, 0] * image_width, new_xy[:, 1] * image_height
231
+
232
+ curr_channel = 0
233
+ for image_id, num_channels in enumerate(input_channels, 1):
234
+ image = images[curr_channel : curr_channel + num_channels]
235
+ image = to_output_format(image, gamma)
236
+ fig = plt.figure()
237
+ fig.set_dpi(PLOT_DPI)
238
+ fig.set_size_inches(
239
+ w=image_width / PLOT_DPI, h=image_height / PLOT_DPI, forward=False
240
+ )
241
+ plt.imshow(Image.fromarray(image), cmap="gray", vmin=0, vmax=255)
242
+ plt.scatter(old_x, old_y, s=size, c="#ef476f", marker="o", alpha=alpha) # red
243
+ plt.scatter(new_x, new_y, s=size, c="#06d6a0", marker="o", alpha=alpha) # green
244
+ plt.xlim(0, image_width)
245
+ plt.ylim(image_height, 0)
246
+ plt.axis("off")
247
+ plt.tight_layout()
248
+ suffix = "" if len(input_channels) == 1 else f"_{image_id:d}"
249
+ plt.savefig(
250
+ f"{filepath}{suffix}.png", bbox_inches="tight", pad_inches=0, dpi=PLOT_DPI
251
+ )
252
+ plt.close()
253
+ curr_channel += num_channels
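A hedged round-trip example for the helpers above; the input path is an assumption, and load_images returns channels stacked as CxHxW floats in [0, 1]:

# "assets/example.png" is a hypothetical path
images, num_channels_list, fnames = load_images("assets/example.png", gamma=2.2)
print(images.shape, num_channels_list, fnames)  # e.g. (3, H, W), [3], ["example"]
save_image(images, "example_copy.png", gamma=2.2)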
utils/misc_utils.py ADDED
@@ -0,0 +1,52 @@
1
+ import os
2
+ import random
3
+ import shutil
4
+ from argparse import ArgumentParser
5
+
6
+ import numpy as np
7
+ import torch
8
+ import yaml
9
+
10
+
11
+ def clean_dir(path):
12
+ if os.path.exists(path):
13
+ shutil.rmtree(path)
14
+
15
+
16
+ def get_latest_ckpt_step(load_path):
17
+ saved_steps = [
18
+ int(os.path.splitext(path)[0].split("-")[-1])
19
+ for path in os.listdir(load_path)
20
+ if path.endswith(".pt")
21
+ ]
22
+ latest_step = -1 if len(saved_steps) == 0 else max(saved_steps)
23
+ return latest_step
24
+
25
+
26
+ def set_random_seed(seed):
27
+ random.seed(seed)
28
+ np.random.seed(seed)
29
+ torch.manual_seed(seed)
30
+ torch.cuda.manual_seed(seed)
31
+ torch.cuda.manual_seed_all(seed)
32
+ torch.backends.cudnn.deterministic = True
33
+ torch.backends.cudnn.benchmark = False
34
+
35
+
36
+ def load_cfg(cfg_path: str, parser: ArgumentParser) -> ArgumentParser:
37
+ with open(cfg_path, "r", encoding="utf-8") as file:
38
+ cfg: dict = yaml.safe_load(file)
39
+ for key, value in cfg.items():
40
+ if value is None:
41
+ raise ValueError("'None' is not a supported value in the config file")
42
+ if isinstance(value, bool):
43
+ parser.add_argument(f"--{key}", action="store_true", default=value)
44
+ else:
45
+ parser.add_argument(f"--{key}", type=type(value), default=value)
46
+ return parser
47
+
48
+
49
+ def save_cfg(path: str, args, mode="w"):
50
+ with open(path, mode=mode, encoding="utf-8") as file:
51
+ print("#################### Training Config ####################", file=file)
52
+ yaml.dump(vars(args), file, default_flow_style=False)
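A hypothetical parser built with load_cfg; the YAML path and keys are assumptions:

from argparse import ArgumentParser

# Assumes a config such as:  num_gaussians: 10000
#                            quantize: false
parser = ArgumentParser()
parser = load_cfg("configs/default.yaml", parser)
args = parser.parse_args([])  # defaults are taken from the YAML values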
utils/quantization_utils.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+
3
+
4
+ def ste_quantize(x: torch.Tensor, num_bits: int = 16) -> torch.Tensor:
5
+ """
6
+ Bit precision control of Gaussian parameters using a straight-through estimator.
7
+ Reference: https://arxiv.org/abs/1308.3432
8
+ """
9
+ qmin, qmax = 0, 2**num_bits - 1
10
+ min_val, max_val = x.min().item(), x.max().item()
11
+ scale = max((max_val - min_val) / (qmax - qmin), 1e-8)
12
+ # Quantize in forward pass (non-differentiable)
13
+ q_x = torch.round((x - min_val) / scale).clamp(qmin, qmax)
14
+ dq_x = q_x * scale + min_val
15
+ # Restore gradients in backward pass
16
+ dq_x = x + (dq_x - x).detach()
17
+ return dq_x
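A small sketch of the straight-through estimator's behavior: the forward value is quantized, while the backward pass treats the rounding as the identity, so gradients are not blocked by the non-differentiable step:

import torch

x = torch.randn(8, requires_grad=True)
y = ste_quantize(x, num_bits=4)
y.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))  # identity gradient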
utils/saliency/decoder.py ADDED
@@ -0,0 +1,62 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Decoder(nn.Module):
9
+ def __init__(self, shape, num_img_feat, num_pla_feat):
10
+ super(Decoder, self).__init__()
11
+ self.shape = shape
12
+ self.img_model = self._make_layer(num_img_feat)
13
+ self.pla_model = self._make_layer(num_pla_feat)
14
+
15
+ self.combined = self._make_output(num_img_feat + num_pla_feat)
16
+
17
+ for m in self.modules():
18
+ if isinstance(m, nn.Conv2d):
19
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
20
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
21
+ elif isinstance(m, nn.BatchNorm2d):
22
+ m.weight.data.fill_(1)
23
+ m.bias.data.zero_()
24
+
25
+ def _make_layer(self, num_feat):
26
+ ans = nn.ModuleList()
27
+ for _ in range(num_feat):
28
+ m = nn.Sequential(
29
+ nn.Conv2d(1, 1, 3, padding=1), nn.BatchNorm2d(1), nn.ReLU(inplace=True)
30
+ )
31
+ ans.append(m)
32
+ return ans
33
+
34
+ def _make_output(self, planes, readout=1):
35
+ return nn.Sequential(
36
+ nn.Conv2d(planes, readout, 3, stride=1, padding=1),
37
+ nn.BatchNorm2d(readout),
38
+ nn.Sigmoid(),
39
+ )
40
+
41
+ def forward(self, x):
42
+ img_feat, pla_feat = x
43
+ feat = []
44
+
45
+ for a, b in zip(img_feat, self.img_model):
46
+ f = F.interpolate(b(a), self.shape)
47
+ feat.append(f)
48
+
49
+ for a, b in zip(pla_feat, self.pla_model):
50
+ f = F.interpolate(b(a), self.shape)
51
+ feat.append(f)
52
+
53
+ feat = torch.cat(feat, dim=1)
54
+ feat = self.combined(feat)
55
+ return feat
56
+
57
+
58
+ def build_decoder(model_path, *args):
59
+ decoder = Decoder(*args)
60
+ loaded = torch.load(model_path, weights_only=True)["state_dict"]
61
+ decoder.load_state_dict(loaded)
62
+ return decoder
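A minimal shape check for the decoder (CPU is fine here): five 1-channel maps per encoder are upsampled to `shape`, concatenated, and fused into a single saliency map. The feature resolutions are illustrative:

import torch

dec = Decoder((480, 640), num_img_feat=5, num_pla_feat=5).eval()
img_feat = [torch.rand(1, 1, 60, 80) for _ in range(5)]
pla_feat = [torch.rand(1, 1, 60, 80) for _ in range(5)]
with torch.no_grad():
    out = dec((img_feat, pla_feat))
print(out.shape)  # torch.Size([1, 1, 480, 640])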
utils/saliency/resnet.py ADDED
@@ -0,0 +1,175 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1):
9
+ conv = nn.Conv2d(
10
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
11
+ )
12
+ return conv
13
+
14
+
15
+ class BasicBlock(nn.Module):
16
+ expansion = 1
17
+
18
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
19
+ super(BasicBlock, self).__init__()
20
+ self.conv1 = conv3x3(inplanes, planes, stride)
21
+ self.bn1 = nn.BatchNorm2d(planes)
22
+ self.relu = nn.ReLU(inplace=True)
23
+ self.conv2 = conv3x3(planes, planes)
24
+ self.bn2 = nn.BatchNorm2d(planes)
25
+ self.downsample = downsample
26
+ self.stride = stride
27
+
28
+ def forward(self, x):
29
+ residual = x
30
+ out = self.conv1(x)
31
+ out = self.bn1(out)
32
+ out = self.relu(out)
33
+ out = self.conv2(out)
34
+ out = self.bn2(out)
35
+ if self.downsample is not None:
36
+ residual = self.downsample(x)
37
+ out += residual
38
+ out = self.relu(out)
39
+ return out
40
+
41
+
42
+ class Bottleneck(nn.Module):
43
+ expansion = 4
44
+
45
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
46
+ super(Bottleneck, self).__init__()
47
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
48
+ self.bn1 = nn.BatchNorm2d(planes)
49
+ self.conv2 = nn.Conv2d(
50
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
51
+ )
52
+ self.bn2 = nn.BatchNorm2d(planes)
53
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
54
+ self.bn3 = nn.BatchNorm2d(planes * 4)
55
+ self.relu = nn.ReLU(inplace=True)
56
+ self.downsample = downsample
57
+ self.stride = stride
58
+
59
+ def forward(self, x):
60
+ residual = x
61
+ out = self.conv1(x)
62
+ out = self.bn1(out)
63
+ out = self.relu(out)
64
+ out = self.conv2(out)
65
+ out = self.bn2(out)
66
+ out = self.relu(out)
67
+ out = self.conv3(out)
68
+ out = self.bn3(out)
69
+ if self.downsample is not None:
70
+ residual = self.downsample(x)
71
+ out += residual
72
+ out = self.relu(out)
73
+ return out
74
+
75
+
76
+ class ResNet(nn.Module):
77
+ def __init__(self, block, layers):
78
+ self.inplanes = 64
79
+ super(ResNet, self).__init__()
80
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
81
+ self.bn1 = nn.BatchNorm2d(64)
82
+ self.relu = nn.ReLU(inplace=True)
83
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
84
+ self.layer1 = self._make_layer(block, 64, layers[0])
85
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
86
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
87
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
88
+ self.out_channels = 1
89
+ self.output0 = self._make_output(64, readout=self.out_channels)
90
+ self.output1 = self._make_output(256, readout=self.out_channels)
91
+ self.output2 = self._make_output(512, readout=self.out_channels)
92
+ self.output3 = self._make_output(1024, readout=self.out_channels)
93
+ self.output4 = self._make_output(2048, readout=self.out_channels)
94
+ self.combined = self._make_output(5, sigmoid=True)
95
+ for m in self.modules():
96
+ if isinstance(m, nn.Conv2d):
97
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
98
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+ def _make_output(self, planes, readout=1, sigmoid=False):
104
+ layers = [
105
+ nn.Conv2d(planes, readout, kernel_size=3, padding=1),
106
+ nn.BatchNorm2d(readout),
107
+ ]
108
+ if sigmoid:
109
+ layers.append(nn.Sigmoid())
110
+ else:
111
+ layers.append(nn.ReLU(inplace=True))
112
+ return nn.Sequential(*layers)
113
+
114
+ def _make_layer(self, block, planes, blocks, stride=1):
115
+ downsample = None
116
+ if stride != 1 or self.inplanes != planes * block.expansion:
117
+ downsample = nn.Sequential(
118
+ nn.Conv2d(
119
+ self.inplanes,
120
+ planes * block.expansion,
121
+ kernel_size=1,
122
+ stride=stride,
123
+ bias=False,
124
+ ),
125
+ nn.BatchNorm2d(planes * block.expansion),
126
+ )
127
+ layers = []
128
+ layers.append(block(self.inplanes, planes, stride, downsample))
129
+ self.inplanes = planes * block.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(block(self.inplanes, planes))
132
+ return nn.Sequential(*layers)
133
+
134
+ def forward(self, x, decode=False):
135
+ h, w = x.size(2), x.size(3)
136
+ x = self.conv1(x)
137
+ x = self.bn1(x)
138
+ out0 = self.relu(x)
139
+ x = self.maxpool(out0)
140
+ out1 = self.layer1(x)
141
+ out2 = self.layer2(out1)
142
+ out3 = self.layer3(out2)
143
+ out4 = self.layer4(out3)
144
+ out0 = self.output0(out0)
145
+ r, c = out0.size(2), out0.size(3)
146
+ out1 = self.output1(out1)
147
+ out2 = self.output2(out2)
148
+ out3 = self.output3(out3)
149
+ out4 = self.output4(out4)
150
+ if decode:
151
+ return [out0, out1, out2, out3, out4]
152
+ out1 = F.interpolate(out1, (r, c))
153
+ out2 = F.interpolate(out2, (r, c))
154
+ out3 = F.interpolate(out3, (r, c))
155
+ out4 = F.interpolate(out4, (r, c))
156
+ x = torch.cat([out0, out1, out2, out3, out4], dim=1)
157
+ x = self.combined(x)
158
+ x = F.interpolate(x, (h, w))
159
+ return x
160
+
161
+
162
+ def resnet50(model_path, **kwargs):
163
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
164
+ if model_path is not None:
166
+ model_state = model.state_dict()
167
+ loaded_model = torch.load(model_path, weights_only=True)
168
+ if "state_dict" in loaded_model:
169
+ loaded_model = loaded_model["state_dict"]
170
+ pretrained = {k[7:]: v for k, v in loaded_model.items() if k[7:] in model_state}
171
+ if len(pretrained) == 0:
172
+ pretrained = {k: v for k, v in loaded_model.items() if k in model_state}
173
+ model_state.update(pretrained)
174
+ model.load_state_dict(model_state)
175
+ return model
utils/saliency_utils.py ADDED
@@ -0,0 +1,38 @@
1
+ import numpy as np
2
+ import torch
3
+ from skimage import filters
4
+ from torchvision.transforms.functional import resize
5
+
6
+ from utils.saliency import decoder, resnet
7
+
8
+
9
+ def get_smap(image, path, filter_size=15):
10
+ """
11
+ Compute the saliency map of the target image using EMLNet.
12
+ Reference: https://arxiv.org/abs/1805.01047
13
+ Reference: https://github.com/SenJia/EML-NET-Saliency
14
+ """
15
+ if image.shape[0] != 3:
16
+ raise ValueError("Saliency prediction only supports RGB images")
17
+ sod_res = (480, 640)
18
+ imagenet_model = resnet.resnet50(f"{path}/emlnet/res_imagenet.pth").cuda().eval()
19
+ places_model = resnet.resnet50(f"{path}/emlnet/res_places.pth").cuda().eval()
20
+ decoder_model = (
21
+ decoder.build_decoder(f"{path}/emlnet/res_decoder.pth", sod_res, 5, 5)
22
+ .cuda()
23
+ .eval()
24
+ )
25
+ image_sod = resize(image, sod_res).unsqueeze(0)
26
+ with torch.no_grad():
27
+ imagenet_feat = imagenet_model(image_sod, decode=True)
28
+ places_feat = places_model(image_sod, decode=True)
29
+ smap = decoder_model([imagenet_feat, places_feat])
30
+ smap = resize(smap.squeeze(0).detach().cpu(), image.shape[1:]).squeeze(0)
31
+
32
+ def post_process(smap):
33
+ smap = filters.gaussian(smap, filter_size)
34
+ smap -= smap.min()
35
+ smap /= smap.max()
36
+ return smap
37
+
38
+ return post_process(smap.numpy()).astype(np.float32)
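A hypothetical call; it assumes the EML-Net weights (res_imagenet.pth, res_places.pth, res_decoder.pth) live under `<path>/emlnet/` and that a CUDA device is available:

import torch

image = torch.rand(3, 512, 512, device="cuda")    # RGB in [0, 1]
smap = get_smap(image, "models", filter_size=15)  # "models" is an assumed path
print(smap.shape, smap.min(), smap.max())         # (512, 512), 0.0, 1.0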
uv.lock ADDED
The diff for this file is too large to render.