Julien Blanchon
committed on
Commit
·
d62394f
0
Parent(s):
Deploy optimized Image-GS with dynamic dependencies
Browse files- Pre-built gsplat wheel stored in blanchon/image-gs-models-utils
- Models automatically downloaded from HF models repository
- Dynamic installation of gsplat wheel at runtime
- Optimized Docker build without CUDA compilation
- Clean repository without binary files
- .dockerignore +46 -0
- Dockerfile +62 -0
- README.md +231 -0
- cfgs/default.yaml +57 -0
- gradio_app.py +809 -0
- gradio_models.py +827 -0
- main.py +57 -0
- model.py +824 -0
- pyproject.toml +46 -0
- utils/__init__.py +0 -0
- utils/flip.py +811 -0
- utils/image_utils.py +253 -0
- utils/misc_utils.py +52 -0
- utils/quantization_utils.py +17 -0
- utils/saliency/decoder.py +62 -0
- utils/saliency/resnet.py +175 -0
- utils/saliency_utils.py +38 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
|
| 5 |
+
# Python
|
| 6 |
+
__pycache__
|
| 7 |
+
*.pyc
|
| 8 |
+
*.pyo
|
| 9 |
+
*.pyd
|
| 10 |
+
.Python
|
| 11 |
+
*.so
|
| 12 |
+
.pytest_cache
|
| 13 |
+
.coverage
|
| 14 |
+
|
| 15 |
+
# Virtual environments
|
| 16 |
+
.venv
|
| 17 |
+
.env
|
| 18 |
+
venv/
|
| 19 |
+
env/
|
| 20 |
+
|
| 21 |
+
# IDE
|
| 22 |
+
.vscode
|
| 23 |
+
.idea
|
| 24 |
+
*.swp
|
| 25 |
+
*.swo
|
| 26 |
+
|
| 27 |
+
# OS
|
| 28 |
+
.DS_Store
|
| 29 |
+
Thumbs.db
|
| 30 |
+
|
| 31 |
+
# Project specific
|
| 32 |
+
results/
|
| 33 |
+
temp_*
|
| 34 |
+
*.log
|
| 35 |
+
|
| 36 |
+
# Documentation
|
| 37 |
+
docs/
|
| 38 |
+
*.md
|
| 39 |
+
!README.md
|
| 40 |
+
|
| 41 |
+
# Assets (if large)
|
| 42 |
+
assets/images/
|
| 43 |
+
assets/fonts/
|
| 44 |
+
|
| 45 |
+
# GSplat documentation (not needed for runtime)
|
| 46 |
+
gsplat/src/gsplat/cuda/csrc/third_party/glm/doc/
|
Dockerfile
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use NVIDIA CUDA image that matches PyTorch's CUDA 12.4 compilation
|
| 2 |
+
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
|
| 3 |
+
|
| 4 |
+
# Install Python 3.10 and dependencies with cache mounts
|
| 5 |
+
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
| 6 |
+
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
| 7 |
+
apt-get update && apt-get install -y \
|
| 8 |
+
python3.10 \
|
| 9 |
+
python3.10-venv \
|
| 10 |
+
python3.10-dev \
|
| 11 |
+
python3-pip \
|
| 12 |
+
git \
|
| 13 |
+
build-essential \
|
| 14 |
+
curl \
|
| 15 |
+
ninja-build \
|
| 16 |
+
wget
|
| 17 |
+
|
| 18 |
+
# Create symlinks for python
|
| 19 |
+
RUN ln -sf /usr/bin/python3.10 /usr/bin/python3 && \
|
| 20 |
+
ln -sf /usr/bin/python3.10 /usr/bin/python
|
| 21 |
+
|
| 22 |
+
# Set CUDA environment variables for runtime
|
| 23 |
+
ENV CUDA_HOME=/usr/local/cuda \
|
| 24 |
+
PATH=/usr/local/cuda/bin:$PATH \
|
| 25 |
+
LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
| 26 |
+
|
| 27 |
+
# Install uv globally
|
| 28 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
|
| 29 |
+
|
| 30 |
+
# Set up user with ID 1000 (required for HF Spaces)
|
| 31 |
+
RUN useradd -m -u 1000 user
|
| 32 |
+
|
| 33 |
+
# Switch to user and set working directory
|
| 34 |
+
USER user
|
| 35 |
+
WORKDIR /home/user/app
|
| 36 |
+
|
| 37 |
+
# Set environment variables
|
| 38 |
+
ENV HOME=/home/user \
|
| 39 |
+
PATH=/home/user/.local/bin:$PATH \
|
| 40 |
+
PYTHONUNBUFFERED=1 \
|
| 41 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
| 42 |
+
GRADIO_SERVER_PORT=7860 \
|
| 43 |
+
UV_CACHE_DIR=/home/user/.cache/uv
|
| 44 |
+
|
| 45 |
+
# Copy dependency files first for better caching
|
| 46 |
+
COPY --chown=user pyproject.toml uv.lock ./
|
| 47 |
+
|
| 48 |
+
# Copy the pre-built wheels directory
|
| 49 |
+
COPY --chown=user wheels/ ./wheels/
|
| 50 |
+
|
| 51 |
+
# Install dependencies with uv (using pre-built wheel - much faster!)
|
| 52 |
+
RUN --mount=type=cache,target=/tmp/uv-cache,sharing=locked,uid=1000,gid=1000 \
|
| 53 |
+
UV_CACHE_DIR=/tmp/uv-cache uv sync --frozen --no-dev
|
| 54 |
+
|
| 55 |
+
# Copy the rest of the application
|
| 56 |
+
COPY --chown=user . .
|
| 57 |
+
|
| 58 |
+
# Expose port 7860 (default for HF Spaces)
|
| 59 |
+
EXPOSE 7860
|
| 60 |
+
|
| 61 |
+
# Launch the Gradio app
|
| 62 |
+
CMD ["uv", "run", "python", "gradio_app.py"]
|
README.md
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Image GS
|
| 3 |
+
emoji: 💻
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
<div align="center">
|
| 12 |
+
|
| 13 |
+
<h1>Image-GS: Content-Adaptive Image Representation via 2D Gaussians</h1>
|
| 14 |
+
|
| 15 |
+
[**Yunxiang Zhang**](https://yunxiangzhang.github.io/)<sup>1\*</sup>,
|
| 16 |
+
[**Bingxuan Li**](https://bingxuan-li.github.io/)<sup>1\*</sup>,
|
| 17 |
+
[**Alexandr Kuznetsov**](https://alexku.me/)<sup>3†</sup>,
|
| 18 |
+
[**Akshay Jindal**](https://www.akshayjindal.com/)<sup>2</sup>,
|
| 19 |
+
[**Stavros Diolatzis**](https://www.sdiolatz.info/)<sup>2</sup>,
|
| 20 |
+
[**Kenneth Chen**](https://kenchen10.github.io/)<sup>1</sup>,
|
| 21 |
+
[**Anton Sochenov**](https://www.intel.com/content/www/us/en/developer/articles/community/gpu-researchers-anton-sochenov.html)<sup>2</sup>,
|
| 22 |
+
[**Anton Kaplanyan**](http://kaplanyan.com/)<sup>2</sup>,
|
| 23 |
+
[**Qi Sun**](https://qisun.me/)<sup>1</sup>
|
| 24 |
+
|
| 25 |
+
\* Equal contribution   † Work done while at Intel
|
| 26 |
+
|
| 27 |
+
<sup>1</sup>
|
| 28 |
+
<a href="https://www.immersivecomputinglab.org/research/"><img width="30%" src="assets/images/NYU-logo.png" style="vertical-align: top;" alt="NYU logo"></a>
|
| 29 |
+
 
|
| 30 |
+
<sup>2</sup>
|
| 31 |
+
<a href="https://www.intel.com/content/www/us/en/developer/topic-technology/graphics-research/overview.html"><img width="22%" src="assets/images/Intel-logo.png" style="vertical-align: top;" alt="Intel logo"></a>
|
| 32 |
+
 
|
| 33 |
+
<sup>3</sup>
|
| 34 |
+
<a href="https://www.amd.com/en.html"><img width="33%" src="assets/images/AMD-logo.png" style="vertical-align: top;" alt="AMD logo"></a>
|
| 35 |
+
|
| 36 |
+
<a href="https://arxiv.org/abs/2407.01866"><img src="https://img.shields.io/badge/arXiv-2407.01866-red" alt="arXiv"></a>
|
| 37 |
+
<a href="https://www.immersivecomputinglab.org/publication/image-gs-content-adaptive-image-representation-via-2d-gaussians/"><img src="https://img.shields.io/badge/project page-ImageGS-blue" alt="project page"></a>
|
| 38 |
+
<a href="https://github.com/NYU-ICL/image-gs"><img src="https://visitor-badge.laobi.icu/badge?page_id=NYU-ICL.image-gs&left_color=green&right_color=red" alt="visitors"></a>
|
| 39 |
+
|
| 40 |
+
</div>
|
| 41 |
+
|
| 42 |
+
<div style="width: 90%; margin: 0 auto;">
|
| 43 |
+
Neural image representations have emerged as a promising approach for encoding and rendering visual data. Combined with learning-based workflows, they demonstrate impressive trade-offs between visual fidelity and memory footprint. Existing methods in this domain, however, often rely on fixed data structures that suboptimally allocate memory or compute-intensive implicit models, hindering their practicality for real-time graphics applications.
|
| 44 |
+
|
| 45 |
+
Inspired by recent advancements in radiance field rendering, we introduce Image-GS, a content-adaptive image representation based on 2D Gaussians. Leveraging a custom differentiable renderer, Image-GS reconstructs images by adaptively allocating and progressively optimizing a group of anisotropic, colored 2D Gaussians. It achieves a favorable balance between visual fidelity and memory efficiency across a variety of stylized images frequently seen in graphics workflows, especially for those showing non-uniformly distributed features and in low-bitrate regimes. Moreover, it supports hardware-friendly rapid random access for real-time usage, requiring only 0.3K MACs to decode a pixel. Through error-guided progressive optimization, Image-GS naturally constructs a smooth level-of-detail hierarchy. We demonstrate its versatility with several applications, including texture compression, semantics-aware compression, and joint image compression and restoration.
|
| 46 |
+
|
| 47 |
+
<img src="assets/images/teaser.jpg" width="100%" />
|
| 48 |
+
<sub>
|
| 49 |
+
Figure 1: Image-GS reconstructs an image by adaptively allocating and progressively optimizing a set of colored 2D Gaussians. It achieves favorable rate-distortion trade-offs, hardware-friendly random access, and flexible quality control through a smooth level-of-detail stack. (a) visualizes the optimized spatial distribution of Gaussians (20% randomly sampled for clarity). (b) Image-GS’s explicit content-adaptive design effectively captures non-uniformly distributed image features and better preserves fine details under constrained memory budgets. In the inset error maps, brighter colors indicate larger errors.
|
| 50 |
+
</sub>
|
| 51 |
+
</div>
|
| 52 |
+
|
| 53 |
+
## Setup
|
| 54 |
+
1. Create a dedicated Python environment and install the dependencies
|
| 55 |
+
```bash
|
| 56 |
+
git clone https://github.com/NYU-ICL/image-gs.git
|
| 57 |
+
cd image-gs
|
| 58 |
+
conda env create -f environment.yml
|
| 59 |
+
conda activate image-gs
|
| 60 |
+
pip install git+https://github.com/rahul-goel/fused-ssim/ --no-build-isolation
|
| 61 |
+
cd gsplat
|
| 62 |
+
pip install -e ".[dev]"
|
| 63 |
+
cd ..
|
| 64 |
+
```
|
| 65 |
+
2. Download the image and texture datasets from [OneDrive](https://1drv.ms/u/c/3a8968df8a027819/EeshjZJlMtdCmvvmESiN2pABM71EDaoLYmEwuOvecg0tAA?e=GybqBv) and organize the folder structure as follows
|
| 66 |
+
```
|
| 67 |
+
image-gs
|
| 68 |
+
└── media
|
| 69 |
+
├── images
|
| 70 |
+
└── textures
|
| 71 |
+
```
|
| 72 |
+
3. (Optional) To run saliency-guided Gaussian position initialization, download the pre-trained [EML-Net](https://github.com/SenJia/EML-NET-Saliency) models ([res_imagenet.pth](https://drive.google.com/open?id=1-a494canr9qWKLdm-DUDMgbGwtlAJz71), [res_places.pth](https://drive.google.com/open?id=18nRz0JSRICLqnLQtAvq01azZAsH0SEzS), [res_decoder.pth](https://drive.google.com/open?id=1vwrkz3eX-AMtXQE08oivGMwS4lKB74sH)) and place them under the `models/emlnet/` folder
|
| 73 |
+
```
|
| 74 |
+
image-gs
|
| 75 |
+
└── models
|
| 76 |
+
└── emlnet
|
| 77 |
+
├── res_decoder.pth
|
| 78 |
+
├── res_imagenet.pth
|
| 79 |
+
└── res_places.pth
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## Quick Start
|
| 83 |
+
|
| 84 |
+
#### Image Compression
|
| 85 |
+
- Optimize an Image-GS representation for an input image `anime-1_2k.png` using `10000` Gaussians with half-precision parameters
|
| 86 |
+
```bash
|
| 87 |
+
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize
|
| 88 |
+
```
|
| 89 |
+
- Render the corresponding optimized Image-GS representation at a new resolution with height `4000` (aspect ratio is maintained)
|
| 90 |
+
```bash
|
| 91 |
+
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize --eval --render_height=4000
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
#### Texture Stack Compression
|
| 95 |
+
- Optimize an Image-GS representation for an input texture stack `alarm-clock_2k` using `30000` Gaussians with half-precision parameters
|
| 96 |
+
```bash
|
| 97 |
+
python main.py --input_path="textures/alarm-clock_2k" --exp_name="test/alarm-clock_2k" --num_gaussians=30000 --quantize
|
| 98 |
+
```
|
| 99 |
+
- Render the corresponding optimized Image-GS representation at a new resolution with height `3000` (aspect ratio is maintained)
|
| 100 |
+
```bash
|
| 101 |
+
python main.py --input_path="textures/alarm-clock_2k" --exp_name="test/alarm-clock_2k" --num_gaussians=30000 --quantize --eval --render_height=3000
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
#### Control bit precision of Gaussian parameters
|
| 105 |
+
- Optimize an Image-GS representation for an input image `anime-1_2k.png` using `10000` Gaussians with 12-bit-precision parameters
|
| 106 |
+
```bash
|
| 107 |
+
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize --pos_bits=12 --scale_bits 12 --rot_bits 12 --feat_bits 12
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
#### Switch to saliency-guided Gaussian position initialization
|
| 111 |
+
- Optimize an Image-GS representation for an input image `anime-1_2k.png` using `10000` Gaussians with half-precision parameters and saliency-guided initialization
|
| 112 |
+
```bash
|
| 113 |
+
python main.py --input_path="images/anime-1_2k.png" --exp_name="test/anime-1_2k" --num_gaussians=10000 --quantize --init_mode="saliency"
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
## Gradio Web Interface
|
| 117 |
+
|
| 118 |
+
We provide a user-friendly web interface built with Gradio for easy experimentation and training visualization.
|
| 119 |
+
|
| 120 |
+
### Setup for Web Interface
|
| 121 |
+
|
| 122 |
+
1. Install Gradio (in addition to the main dependencies):
|
| 123 |
+
```bash
|
| 124 |
+
pip install gradio>=4.0.0
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
2. Launch the web interface:
|
| 128 |
+
```bash
|
| 129 |
+
python gradio_app.py
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
3. Open your browser and navigate to `http://localhost:7860`
|
| 133 |
+
|
| 134 |
+
### Features
|
| 135 |
+
|
| 136 |
+
The Gradio interface provides:
|
| 137 |
+
|
| 138 |
+
- **Interactive Parameter Configuration**: Adjust all training parameters through an intuitive UI
|
| 139 |
+
- **Image Upload**: Drag and drop any image to train on
|
| 140 |
+
- **Real-time Training Progress**: Stream training logs and intermediate results
|
| 141 |
+
- **Live Visualization**: Watch Gaussian placement and rendering progress during training
|
| 142 |
+
- **Result Gallery**: View final renders, gradient maps, and saliency maps
|
| 143 |
+
- **Easy Experimentation**: No need to remember command-line arguments
|
| 144 |
+
|
| 145 |
+
### Interface Sections
|
| 146 |
+
|
| 147 |
+
1. **Configuration Panel**:
|
| 148 |
+
- Basic parameters (number of Gaussians, training steps)
|
| 149 |
+
- Quantization settings for memory efficiency
|
| 150 |
+
- Initialization modes (gradient, saliency, random)
|
| 151 |
+
- Advanced optimization parameters (learning rates, loss weights)
|
| 152 |
+
|
| 153 |
+
2. **Training Progress**:
|
| 154 |
+
- Real-time streaming logs
|
| 155 |
+
- Current render and Gaussian visualization updates
|
| 156 |
+
- Training status and control buttons
|
| 157 |
+
|
| 158 |
+
3. **Results Display**:
|
| 159 |
+
- Final optimized image
|
| 160 |
+
- Gradient and saliency maps used for initialization
|
| 161 |
+
- Download capabilities for all results
|
| 162 |
+
|
| 163 |
+
### Usage Tips
|
| 164 |
+
|
| 165 |
+
- Start with default parameters for your first run
|
| 166 |
+
- Use **saliency initialization** for better results on complex images
|
| 167 |
+
- Enable **Gaussian visualization** to see how the representation evolves
|
| 168 |
+
- Adjust **save image steps** to control visualization frequency (lower = more updates, but slower)
|
| 169 |
+
- For quick tests, reduce **max steps** to 500-1000
|
| 170 |
+
|
| 171 |
+
### Command Line Arguments
|
| 172 |
+
Please refer to `cfgs/default.yaml` for the full list of arguments and their default values.
|
| 173 |
+
|
| 174 |
+
**Post-optimization rendering**
|
| 175 |
+
- `--eval` render the optimized Image-GS representation.
|
| 176 |
+
- `--render_height` image height for rendering (aspect ratio is maintained).
|
| 177 |
+
|
| 178 |
+
**Bit precision control**: 32 bits (float32) per dimension by default
|
| 179 |
+
- `--quantize` enable bit precision control of Gaussian parameters.
|
| 180 |
+
- `--pos_bits` bit precision of individual coordinate dimension.
|
| 181 |
+
- `--scale_bits` bit precision of individual scale dimension.
|
| 182 |
+
- `--rot_bits` bit precision of Gaussian orientation angle.
|
| 183 |
+
- `--feat_bits` bit precision of individual feature dimension.
|
| 184 |
+
|
| 185 |
+
**Logging**
|
| 186 |
+
- `--exp_name` path to the logging directory.
|
| 187 |
+
- `--vis_gaussians`: visualize Gaussians during optimization.
|
| 188 |
+
- `--save_image_steps` frequency of rendering intermediate results during optimization.
|
| 189 |
+
- `--save_ckpt_steps` frequency of checkpointing during optimization.
|
| 190 |
+
|
| 191 |
+
**Input image**
|
| 192 |
+
- `--input_path` path to an image file or a directory containing a texture stack.
|
| 193 |
+
- `--downsample` load a downsampled version of the input image or texture stack as the optimization target to evaluate image upsampling performance.
|
| 194 |
+
- `--downsample_ratio` downsampling ratio.
|
| 195 |
+
- `--gamma` optimize in a gamma-corrected space, modify with caution.
|
| 196 |
+
|
| 197 |
+
**Gaussian**
|
| 198 |
+
- `--num_gaussians` number of Gaussians (for compression rate control).
|
| 199 |
+
- `--init_scale` initial Gaussian scale in number of pixels.
|
| 200 |
+
- `--disable_topk_norm` disable top-K normalization.
|
| 201 |
+
- `--disable_inverse_scale` disable inverse Gaussian scale optimization.
|
| 202 |
+
- `--init_mode` Gaussian position initialization mode, valid values include "gradient", "saliency", and "random".
|
| 203 |
+
- `--init_random_ratio` ratio of Gaussians with randomly initialized position.
|
| 204 |
+
|
| 205 |
+
**Optimization**
|
| 206 |
+
- `--disable_tiles` disable tile-based rendering (warning: optimization and rendering without tiles will be way slower).
|
| 207 |
+
- `--max_steps` maximum number of optimization steps.
|
| 208 |
+
- `--pos_lr` Gaussian position learning rate.
|
| 209 |
+
- `--scale_lr` Gaussian scale learning rate.
|
| 210 |
+
- `--rot_lr` Gaussian orientation angle learning rate.
|
| 211 |
+
- `--feat_lr` Gaussian feature learning rate.
|
| 212 |
+
- `--disable_lr_schedule` disable learning rate decay and early stopping schedule.
|
| 213 |
+
- `--disable_prog_optim` disable error-guided progressive optimization.
|
| 214 |
+
|
| 215 |
+
## Acknowledgements
|
| 216 |
+
We would like to thank the [gsplat](https://github.com/nerfstudio-project/gsplat) team, and the authors of [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), [fused-ssim](https://github.com/rahul-goel/fused-ssim), and [EML-Net](https://github.com/SenJia/EML-NET-Saliency) for their great work, based on which Image-GS was developed.
|
| 217 |
+
|
| 218 |
+
## License
|
| 219 |
+
This project is licensed under the terms of the MIT license.
|
| 220 |
+
|
| 221 |
+
## Citation
|
| 222 |
+
If you find this project helpful to your research, please consider citing [BibTeX](assets/docs/image-gs.bib):
|
| 223 |
+
```bibtex
|
| 224 |
+
@inproceedings{zhang2025image,
|
| 225 |
+
title={Image-gs: Content-adaptive image representation via 2d gaussians},
|
| 226 |
+
author={Zhang, Yunxiang and Li, Bingxuan and Kuznetsov, Alexandr and Jindal, Akshay and Diolatzis, Stavros and Chen, Kenneth and Sochenov, Anton and Kaplanyan, Anton and Sun, Qi},
|
| 227 |
+
booktitle={Proceedings of the Special Interest Group on Computer Graphics and Interactive Techniques Conference Conference Papers},
|
| 228 |
+
pages={1--11},
|
| 229 |
+
year={2025}
|
| 230 |
+
}
|
| 231 |
+
```
|
cfgs/default.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 123
|
| 2 |
+
device: "cuda:0"
|
| 3 |
+
# Evaluation
|
| 4 |
+
eval: False # Render the optimized Image-GS representation
|
| 5 |
+
render_height: 2048 # Image height for rendering (aspect ratio is maintained)
|
| 6 |
+
# Bit precision
|
| 7 |
+
quantize: False # Enable bit precision control of Gaussian parameters
|
| 8 |
+
pos_bits: 16 # Bit precision of individual coordinate dimension
|
| 9 |
+
scale_bits: 16 # Bit precision of individual scale dimension
|
| 10 |
+
rot_bits: 16 # Bit precision of Gaussian orientation angle
|
| 11 |
+
feat_bits: 16 # Bit precision of individual feature dimension
|
| 12 |
+
# Logging
|
| 13 |
+
log_root: "results"
|
| 14 |
+
exp_name: "test/anime-1_2k" # Path to the logging directory
|
| 15 |
+
log_level: "INFO"
|
| 16 |
+
vis_gaussians: False # Visualize Gaussians during optimization
|
| 17 |
+
save_image_steps: 100000 # Frequency of rendering intermediate results during optimization
|
| 18 |
+
save_ckpt_steps: 100000 # Frequency of checkpointing during optimization
|
| 19 |
+
eval_steps: 100
|
| 20 |
+
# Target images
|
| 21 |
+
gamma: 1.0 # Optimize in a gamma-corrected space, modify with caution
|
| 22 |
+
data_root: "media"
|
| 23 |
+
input_path: "images/anime-1_2k.png" # Path to an image file or a directory containing a texture stack
|
| 24 |
+
downsample: False # Load a downsampled version of the input image or texture stack as the optimization target to evaluate image upsampling performance
|
| 25 |
+
downsample_ratio: 2.0
|
| 26 |
+
# Gaussians
|
| 27 |
+
num_gaussians: 10000 # Number of Gaussians (for compression rate control)
|
| 28 |
+
init_scale: 5.0 # Initial Gaussian scale in number of pixels
|
| 29 |
+
topk: 10 # Warning: Must match hardcoded value in CUDA kernel, modify with caution
|
| 30 |
+
disable_topk_norm: False # Disable top-K normalization
|
| 31 |
+
disable_inverse_scale: False # Disable inverse Gaussian scale optimization
|
| 32 |
+
ckpt_file: ""
|
| 33 |
+
disable_color_init: False
|
| 34 |
+
init_mode: "gradient" # Gaussian position initialization mode, valid values include "gradient", "saliency", and "random"
|
| 35 |
+
init_random_ratio: 0.3 # Ratio of Gaussians with randomly initialized position
|
| 36 |
+
smap_filter_size: 20 # Gaussian filter size for smoothing saliency maps
|
| 37 |
+
# Loss functions
|
| 38 |
+
l1_loss_ratio: 1.0
|
| 39 |
+
l2_loss_ratio: 0.0
|
| 40 |
+
ssim_loss_ratio: 0.1
|
| 41 |
+
# Optimization
|
| 42 |
+
disable_tiles: False # Disable tile-based rendering (warning: optimization and rendering without tiles will be way slower)
|
| 43 |
+
max_steps: 10000 # Maximum number of optimization steps
|
| 44 |
+
pos_lr: 5.0e-4
|
| 45 |
+
scale_lr: 2.0e-3
|
| 46 |
+
rot_lr: 2.0e-3
|
| 47 |
+
feat_lr: 5.0e-3
|
| 48 |
+
disable_lr_schedule: False # Disable learning rate schedule and early stopping
|
| 49 |
+
decay_ratio: 10.0
|
| 50 |
+
check_decay_steps: 1000
|
| 51 |
+
max_decay_times: 1
|
| 52 |
+
decay_threshold: 1.0e-3
|
| 53 |
+
disable_prog_optim: False # Disable error-guided progressive optimization
|
| 54 |
+
initial_ratio: 0.5
|
| 55 |
+
add_steps: 500
|
| 56 |
+
add_times: 4
|
| 57 |
+
post_min_steps: 3000
|
gradio_app.py
ADDED
|
@@ -0,0 +1,809 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import threading
|
| 5 |
+
import argparse
|
| 6 |
+
import tempfile
|
| 7 |
+
import shutil
|
| 8 |
+
from typing import Generator, Optional, Tuple
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import gradio as gr
|
| 13 |
+
except ImportError:
|
| 14 |
+
print("❌ Gradio not found. Please install it with: pip install gradio>=4.0.0")
|
| 15 |
+
sys.exit(1)
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 19 |
+
except ImportError:
|
| 20 |
+
print(
|
| 21 |
+
"❌ huggingface_hub not found. Please install it with: pip install huggingface_hub"
|
| 22 |
+
)
|
| 23 |
+
sys.exit(1)
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from PIL import Image
|
| 27 |
+
|
| 28 |
+
# Add the project root to the path so we can import the modules
|
| 29 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 30 |
+
|
| 31 |
+
from gradio_models import GradioGaussianSplatting2D, StreamingResults
|
| 32 |
+
from utils.misc_utils import load_cfg
|
| 33 |
+
from main import get_log_dir
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TrainingState:
|
| 37 |
+
"""Manages the state of training sessions"""
|
| 38 |
+
|
| 39 |
+
def __init__(self):
|
| 40 |
+
self.is_training = False
|
| 41 |
+
self.training_thread = None
|
| 42 |
+
self.model = None
|
| 43 |
+
self.temp_dir = None
|
| 44 |
+
self.results = StreamingResults()
|
| 45 |
+
|
| 46 |
+
def reset(self):
|
| 47 |
+
self.is_training = False
|
| 48 |
+
if self.temp_dir and os.path.exists(self.temp_dir):
|
| 49 |
+
shutil.rmtree(self.temp_dir)
|
| 50 |
+
self.temp_dir = None
|
| 51 |
+
self.results = StreamingResults()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Global training state
|
| 55 |
+
training_state = TrainingState()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def ensure_models_available():
    """Download models from HuggingFace if they're not available locally.

    Checks each required checkpoint under ``models/`` and fetches any missing
    file from the ``blanchon/image-gs-models-utils`` repo. Download failures
    are reported but not raised, so the app can still start (degraded).
    """
    # Paths inside the HF repo; local copies live under models/<same path>.
    # Deriving the local list removes the previously duplicated parallel lists.
    model_files_remote = [
        "emlnet/res_decoder.pth",
        "emlnet/res_imagenet.pth",
        "emlnet/res_places.pth",
        "torch/checkpoints/alexnet-owt-7be5be79.pth",
    ]
    model_files_local = ["models/" + remote for remote in model_files_remote]

    # Check if all required files exist
    if all(os.path.exists(file_path) for file_path in model_files_local):
        print("✅ Model files are already available locally.")
        return

    print("📥 Downloading model files from HuggingFace...")
    try:
        # Create models directory if it doesn't exist
        os.makedirs("models", exist_ok=True)

        for remote_file, local_file in zip(model_files_remote, model_files_local):
            if os.path.exists(local_file):
                continue
            # Create directory structure for this specific file
            os.makedirs(os.path.dirname(local_file), exist_ok=True)

            # Download the specific file
            print(f"📥 Downloading {remote_file} -> {local_file}...")
            downloaded_path = hf_hub_download(
                repo_id="blanchon/image-gs-models-utils",
                filename=remote_file,
                repo_type="model",
            )

            # hf_hub_download stores files in its own cache; copy the file
            # into the layout the rest of the code expects.
            shutil.copy2(downloaded_path, local_file)

        print("✅ Model files downloaded successfully!")
    except Exception as e:
        print(f"❌ Failed to download model files: {e}")
        print("⚠️ The app may not work properly without these model files.")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def create_args_from_config(
    image_path: str,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> argparse.Namespace:
    """Build the training argument namespace from Gradio inputs.

    Defaults come from ``cfgs/default.yaml``; each UI value then overrides
    the corresponding attribute, and the log directory is derived last.
    """
    # Load default config and parse with no CLI arguments to obtain defaults.
    parser = load_cfg(cfg_path="cfgs/default.yaml", parser=argparse.ArgumentParser())
    args = parser.parse_args([])

    # Apply every user-supplied override in one pass.
    overrides = {
        "input_path": image_path,
        "exp_name": exp_name,
        "num_gaussians": num_gaussians,
        "quantize": quantize,
        "pos_bits": pos_bits,
        "scale_bits": scale_bits,
        "rot_bits": rot_bits,
        "feat_bits": feat_bits,
        "init_mode": init_mode,
        "init_random_ratio": init_random_ratio,
        "max_steps": max_steps,
        "vis_gaussians": vis_gaussians,
        "save_image_steps": save_image_steps,
        "l1_loss_ratio": l1_loss_ratio,
        "l2_loss_ratio": l2_loss_ratio,
        "ssim_loss_ratio": ssim_loss_ratio,
        "pos_lr": pos_lr,
        "scale_lr": scale_lr,
        "rot_lr": rot_lr,
        "feat_lr": feat_lr,
        "disable_lr_schedule": disable_lr_schedule,
        "disable_prog_optim": disable_prog_optim,
        "eval": False,
    }
    for attr, value in overrides.items():
        setattr(args, attr, value)

    # Set up logging directory (depends on the overridden values above).
    args.log_dir = get_log_dir(args)

    return args
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def train_model(args: argparse.Namespace) -> None:
    """Worker-thread entry point: build the model and run optimization.

    Any exception is surfaced to the streaming UI via ``training_logs`` and
    recorded with a full traceback; the global training flag is always
    cleared on exit so the UI polling loop can terminate.
    """
    try:
        # Create and train model with streaming results
        training_state.model = GradioGaussianSplatting2D(args, training_state.results)

        # Start training
        training_state.model.optimize()

    except Exception as e:
        training_state.results.training_logs.append(f"ERROR: {str(e)}")
        # logging.exception preserves the traceback (logging.error with an
        # f-string dropped it), and lazy %-args avoid eager formatting.
        logging.exception("Training failed: %s", e)
    finally:
        # Always flip the flag, even on failure, so the UI stops polling.
        training_state.is_training = False
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def start_training_and_stream(
    image_file,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> Generator[
    Tuple[
        str,
        str,
        Optional[Image.Image],  # initialization_map
        Optional[Image.Image],  # current_render
        Optional[Image.Image],  # current_gaussian_id
        bool,  # start_btn_interactive
        bool,  # stop_btn_interactive
    ],
    None,
    None,
]:
    """Start training and stream progress with images.

    Launches training in a background thread and then polls the shared
    ``training_state.results`` container roughly twice a second, yielding
    (status, logs, init map, render, gaussian-id image, start-btn state,
    stop-btn state) tuples for Gradio's streaming outputs. Returns early
    (single yield) if training is already running or no image was uploaded.
    """

    if training_state.is_training:
        yield (
            "Training is already in progress!",
            "",
            None,
            None,
            None,
            False,  # start_btn disabled
            True,  # stop_btn enabled
        )
        return

    if image_file is None:
        yield (
            "Please upload an image first!",
            "",
            None,
            None,
            None,
            True,  # start_btn enabled
            False,  # stop_btn disabled
        )
        return

    try:
        # Reset training state
        training_state.reset()

        # Create temporary directory for the uploaded image
        training_state.temp_dir = tempfile.mkdtemp()

        # Save uploaded image (image_file is a PIL image from gr.Image)
        image_path = os.path.join(training_state.temp_dir, "input_image.png")
        image_file.save(image_path)

        # Create args
        args = create_args_from_config(
            image_path=image_path,
            exp_name=exp_name,
            num_gaussians=num_gaussians,
            quantize=quantize,
            pos_bits=pos_bits,
            scale_bits=scale_bits,
            rot_bits=rot_bits,
            feat_bits=feat_bits,
            init_mode=init_mode,
            init_random_ratio=init_random_ratio,
            max_steps=max_steps,
            vis_gaussians=vis_gaussians,
            save_image_steps=save_image_steps,
            l1_loss_ratio=l1_loss_ratio,
            l2_loss_ratio=l2_loss_ratio,
            ssim_loss_ratio=ssim_loss_ratio,
            pos_lr=pos_lr,
            scale_lr=scale_lr,
            rot_lr=rot_lr,
            feat_lr=feat_lr,
            disable_lr_schedule=disable_lr_schedule,
            disable_prog_optim=disable_prog_optim,
        )

        # Update data_root to use temp directory (input_path becomes relative)
        args.data_root = training_state.temp_dir
        args.input_path = "input_image.png"

        # Start training in separate thread
        training_state.is_training = True
        training_state.training_thread = threading.Thread(
            target=train_model, args=(args,)
        )
        training_state.training_thread.start()

        # Initial yield
        yield (
            "Training started! Check the progress below.",
            "Initializing training...",
            None,  # initialization_map
            None,  # current_render
            None,  # current_gaussian_id
            False,  # start_btn disabled
            True,  # stop_btn enabled
        )

        # Stream training progress until the worker signals completion
        while training_state.is_training or not training_state.results.is_complete:
            # Check if stop was requested (flag cleared while worker alive)
            if (
                not training_state.is_training
                and training_state.training_thread
                and training_state.training_thread.is_alive()
            ):
                # Force stop the training thread if needed
                training_state.results.training_logs.append(
                    "🛑 Training stopped by user request"
                )
                break

            # Get training logs
            if training_state.results.training_logs:
                logs_text = "\n".join(training_state.results.training_logs)

                # Add current metrics if available
                if training_state.results.step > 0:
                    # Break if step is greater than total steps
                    if training_state.results.step > training_state.results.total_steps:
                        break

                    metrics = training_state.results.metrics
                    status_line = (
                        f"\nCurrent: Step {training_state.results.step}/{training_state.results.total_steps} | "
                        f"PSNR: {metrics['psnr']:.2f} | SSIM: {metrics['ssim']:.4f} | "
                        f"Loss: {metrics['loss']:.4f}"
                    )
                    logs_text += status_line

                    # Add image status info for debugging
                    if training_state.results.current_render is not None:
                        logs_text += f"\n📸 Current render: {training_state.results.current_render.size}"
                    else:
                        logs_text += "\n📸 Current render: None"

                    if training_state.results.current_gaussian_id is not None:
                        logs_text += f"\n🆔 Gaussian ID: {training_state.results.current_gaussian_id.size}"
                    else:
                        logs_text += "\n🆔 Gaussian ID: None"

                    logs_text += (
                        f"\n💾 Stored steps: {len(training_state.results.step_renders)}"
                    )
            else:
                logs_text = "Waiting for training to start..."

            # Get current images (may be None early in training)
            initialization_map = training_state.results.initialization_map
            current_render = training_state.results.current_render
            current_gaussian_id = training_state.results.current_gaussian_id

            # Simple status based on training state
            current_step = training_state.results.step
            if training_state.results.is_complete:
                status = "✅ Training completed successfully!"
                start_btn_interactive = True
                stop_btn_interactive = False
            elif not training_state.is_training:
                status = "⏹️ Training stopped."
                start_btn_interactive = True
                stop_btn_interactive = False
            else:
                status = f"🔄 Training in progress... Step {current_step}/{training_state.results.total_steps}"
                start_btn_interactive = False
                stop_btn_interactive = True

            # Always yield, even if images haven't changed
            yield (
                status,
                logs_text,
                initialization_map,
                current_render,
                current_gaussian_id,
                start_btn_interactive,
                stop_btn_interactive,
            )

            # Stop if training is complete
            if training_state.results.is_complete or not training_state.is_training:
                break
            if current_step > training_state.results.total_steps:
                break

            time.sleep(0.5)  # Update more frequently for better responsiveness

    except Exception as e:
        # NOTE(review): the worker thread (if started) is not joined here;
        # reset() only clears flags and temp files.
        training_state.reset()
        yield (
            f"Failed to start training: {str(e)}",
            "",
            None,
            None,
            None,
            True,  # start_btn enabled
            False,  # stop_btn disabled
        )
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def stop_training() -> str:
    """Request a graceful stop of the active training session."""
    if not training_state.is_training:
        return "No training in progress."

    # Flip the polling flag first so the streaming loop can wind down.
    training_state.is_training = False
    training_state.results.training_logs.append(
        "🛑 STOP: Training stop requested by user..."
    )

    # The optimizer checks this flag between steps and exits cleanly.
    model = training_state.model
    if model:
        model.stop_requested = True

    return "Training stop requested. Will complete current step and stop."
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def get_final_results() -> Tuple[Optional[Image.Image], Optional[str]]:
    """Return the final rendered image and checkpoint path (either may be None)."""
    results = training_state.results
    return results.final_render, results.final_checkpoint_path
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def browse_step_results(
    step: int,
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Return the stored render and Gaussian-ID images nearest to *step*.

    Yields (None, None) until training has completed or if no intermediate
    results were recorded.
    """
    results = training_state.results
    if not results.is_complete:
        return None, None

    saved_steps = list(results.step_renders.keys())
    if not saved_steps:
        return None, None

    # Snap the slider value to the closest recorded step.
    nearest = min(saved_steps, key=lambda s: abs(s - step))
    return (
        results.step_renders.get(nearest),
        results.step_gaussian_ids.get(nearest),
    )
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def update_step_slider_after_training() -> gr.Slider:
    """Update step slider range and enable it after training completes.

    Returns a disabled placeholder slider while training is incomplete or no
    intermediate renders were stored; otherwise an interactive slider that
    spans exactly the recorded step range.
    """

    def _disabled_slider(info: str) -> gr.Slider:
        # Shared construction for both "not ready" cases (was duplicated).
        return gr.Slider(
            minimum=0,
            maximum=10000,
            value=0,
            step=100,
            label="Browse Training Steps",
            info=info,
            interactive=False,
        )

    if not training_state.results.is_complete:
        return _disabled_slider("Training not complete yet")

    available_steps = list(training_state.results.step_renders.keys())
    if not available_steps:
        return _disabled_slider("No training steps available")

    max_step = max(available_steps)
    min_step = min(available_steps)
    # Use the spacing between the first two recorded steps as the slider
    # granularity; fall back to 100 when only one step was stored.
    if len(available_steps) > 1:
        step_size = available_steps[1] - available_steps[0]
    else:
        step_size = 100

    return gr.Slider(
        minimum=min_step,
        maximum=max_step,
        value=max_step,
        step=step_size,
        label="Browse Training Steps",
        info=f"Browse results from steps {min_step}-{max_step} (interactive)",
        interactive=True,
    )
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def create_interface():
    """Create the Gradio interface.

    Left column: image upload plus all training hyper-parameters.
    Right column: streamed logs/images, step browsing, and final results.
    Returns the assembled gr.Blocks app (launched by the caller).
    """

    with gr.Blocks(
        title="Image-GS: 2D Gaussian Splatting", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("""
        # Image-GS: Content-Adaptive Image Representation via 2D Gaussians

        Upload an image and configure parameters to train a 2D Gaussian Splatting representation.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Configuration")

                # Image upload
                image_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    height=300,
                    sources=["upload"],
                    show_label=True,
                )

                # Basic parameters
                with gr.Group():
                    gr.Markdown("### Basic Parameters")
                    exp_name = gr.Textbox(
                        label="Experiment Name",
                        value="gradio_experiment",
                        info="Name for this training run",
                    )
                    num_gaussians = gr.Slider(
                        minimum=1000,
                        maximum=50000,
                        value=10000,
                        step=1000,
                        label="Number of Gaussians",
                        info="Number of Gaussians (for compression rate control). More = higher quality but slower training",
                    )
                    max_steps = gr.Slider(
                        minimum=100,
                        maximum=20000,
                        value=10000,
                        step=100,
                        label="Maximum Training Steps",
                        info="Maximum number of optimization steps. Default: 10000",
                    )

                # Quantization parameters
                with gr.Group():
                    gr.Markdown("### Quantization")
                    quantize = gr.Checkbox(
                        label="Enable Quantization",
                        value=False,
                        info="Enable bit precision control of Gaussian parameters. Reduces memory usage.",
                    )
                    with gr.Row():
                        pos_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Position Bits",
                            info="Bit precision of individual coordinate dimension",
                        )
                        scale_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Scale Bits",
                            info="Bit precision of individual scale dimension",
                        )
                    with gr.Row():
                        rot_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Rotation Bits",
                            info="Bit precision of Gaussian orientation angle",
                        )
                        feat_bits = gr.Slider(
                            4,
                            32,
                            16,
                            step=1,
                            label="Feature Bits",
                            info="Bit precision of individual feature dimension",
                        )

                # Initialization parameters
                with gr.Group():
                    gr.Markdown("### Initialization")
                    init_mode = gr.Radio(
                        choices=["gradient", "saliency", "random"],
                        value="saliency",
                        label="Initialization Mode",
                        info="Gaussian position initialization mode. Gradient uses image gradients, saliency uses attention maps.",
                    )
                    init_random_ratio = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.3,
                        step=0.1,
                        label="Random Ratio",
                        info="Ratio of Gaussians with randomly initialized position (default: 0.3)",
                    )

                # Advanced parameters (collapsible)
                with gr.Accordion("Advanced Parameters", open=False):
                    # Loss parameters
                    gr.Markdown("#### Loss Weights")
                    with gr.Row():
                        l1_loss_ratio = gr.Slider(
                            0.0, 2.0, 1.0, step=0.1, label="L1 Loss"
                        )
                        l2_loss_ratio = gr.Slider(
                            0.0, 2.0, 0.0, step=0.1, label="L2 Loss"
                        )
                        ssim_loss_ratio = gr.Slider(
                            0.0, 1.0, 0.1, step=0.01, label="SSIM Loss"
                        )

                    # Learning rates
                    gr.Markdown("#### Learning Rates")
                    with gr.Row():
                        pos_lr = gr.Number(value=5e-4, label="Position LR", precision=6)
                        scale_lr = gr.Number(value=2e-3, label="Scale LR", precision=6)
                    with gr.Row():
                        rot_lr = gr.Number(value=2e-3, label="Rotation LR", precision=6)
                        feat_lr = gr.Number(value=5e-3, label="Feature LR", precision=6)

                    # Optimization options
                    gr.Markdown("#### Optimization")
                    disable_lr_schedule = gr.Checkbox(
                        label="Disable LR Schedule",
                        value=False,
                        info="Keep learning rate constant",
                    )
                    disable_prog_optim = gr.Checkbox(
                        label="Disable Progressive Optimization",
                        value=False,
                        info="Don't add Gaussians during training",
                    )

                # Visualization parameters
                with gr.Group():
                    gr.Markdown("### Visualization")
                    vis_gaussians = gr.Checkbox(
                        label="Visualize Gaussians",
                        value=True,
                        info="Visualize Gaussians during optimization (default: True)",
                    )
                    # NOTE(review): info text says "default: 100" but minimum
                    # and value are both 200 — confirm the intended default.
                    save_image_steps = gr.Slider(
                        minimum=200,
                        maximum=10000,
                        value=200,
                        step=100,
                        label="Save Image Every N Steps",
                        info="Frequency of rendering intermediate results during optimization (default: 100)",
                    )

                # Control buttons
                with gr.Row():
                    start_btn = gr.Button(
                        "Start Training", variant="primary", size="lg"
                    )
                    stop_btn = gr.Button("Stop Training", variant="stop", size="lg")

                status_text = gr.Textbox(label="Status", interactive=False, lines=2)

            with gr.Column(scale=2):
                gr.Markdown("## Training Progress")

                # Progress logs (streaming)
                progress_logs = gr.Textbox(
                    label="Training Logs",
                    lines=10,
                    max_lines=15,
                    interactive=False,
                    autoscroll=True,
                )

                # Initial map (computed at start based on initialization mode)
                gr.Markdown("### Initialization Map")
                initialization_map = gr.Image(
                    label="Initialization Map",
                    type="pil",
                    height=200,
                )

                # Training images (streaming)
                gr.Markdown("### Current Training Results")
                with gr.Row():
                    current_render = gr.Image(
                        label="Current Render",
                        type="pil",
                        height=300,
                        show_label=True,
                        show_download_button=True,
                    )
                    current_gaussian_id = gr.Image(
                        label="Gaussian ID",
                        type="pil",
                        height=300,
                        show_label=True,
                        show_download_button=True,
                    )

                # Step slider for interactive browsing (will be updated dynamically)
                step_slider = gr.Slider(
                    minimum=0,
                    maximum=10000,
                    value=0,
                    step=100,
                    label="Browse Training Steps",
                    info="Slide to view results from different training steps (disabled during training)",
                    interactive=False,
                )

                gr.Markdown("## Final Results")
                with gr.Row():
                    final_render = gr.Image(
                        label="Final Render", type="pil", height=300
                    )
                    final_checkpoint = gr.File(label="Download Final Checkpoint (.pt)")

                # Results buttons
                with gr.Row():
                    results_btn = gr.Button("Load Final Results", size="lg")
                    enable_slider_btn = gr.Button(
                        "Enable Step Browsing", size="lg", variant="secondary"
                    )

        # Event handlers
        # Generator callback: streams (status, logs, images, button states).
        start_btn.click(
            fn=start_training_and_stream,
            inputs=[
                image_input,
                exp_name,
                num_gaussians,
                quantize,
                pos_bits,
                scale_bits,
                rot_bits,
                feat_bits,
                init_mode,
                init_random_ratio,
                max_steps,
                vis_gaussians,
                save_image_steps,
                l1_loss_ratio,
                l2_loss_ratio,
                ssim_loss_ratio,
                pos_lr,
                scale_lr,
                rot_lr,
                feat_lr,
                disable_lr_schedule,
                disable_prog_optim,
            ],
            outputs=[
                status_text,
                progress_logs,
                initialization_map,
                current_render,
                current_gaussian_id,
                start_btn,
                stop_btn,
            ],
        )

        stop_btn.click(fn=stop_training, outputs=status_text)

        results_btn.click(
            fn=get_final_results, outputs=[final_render, final_checkpoint]
        )

        enable_slider_btn.click(
            fn=update_step_slider_after_training, outputs=[step_slider]
        )

        step_slider.change(
            fn=browse_step_results,
            inputs=[step_slider],
            outputs=[current_render, current_gaussian_id],
        )

    return demo
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
if __name__ == "__main__":
    # Ensure model files are available (download from HF if needed)
    ensure_models_available()

    # Set torch hub directory so torch-hub checkpoints resolve under models/torch
    torch.hub.set_dir("models/torch")

    # Create and launch the interface.
    # 0.0.0.0 binds all interfaces (container deployment); debug=True enables
    # verbose error output in the UI.
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
|
gradio_models.py
ADDED
|
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import threading
|
| 5 |
+
from time import perf_counter
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fused_ssim import fused_ssim
|
| 12 |
+
from torchvision.transforms.functional import gaussian_blur
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
from gsplat import (
|
| 16 |
+
project_gaussians_2d_scale_rot,
|
| 17 |
+
rasterize_gaussians_no_tiles,
|
| 18 |
+
rasterize_gaussians_sum,
|
| 19 |
+
)
|
| 20 |
+
from utils.image_utils import (
|
| 21 |
+
compute_image_gradients,
|
| 22 |
+
get_grid,
|
| 23 |
+
get_psnr,
|
| 24 |
+
load_images,
|
| 25 |
+
to_output_format,
|
| 26 |
+
)
|
| 27 |
+
from utils.misc_utils import set_random_seed
|
| 28 |
+
from utils.quantization_utils import ste_quantize
|
| 29 |
+
from utils.saliency_utils import get_smap
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class StreamingResults:
    """Mutable state shared between the training loop and the Gradio UI.

    The optimizer writes progress into this object while the UI thread polls
    it, so every field is a plain attribute that is cheap to read.
    """

    def __init__(self):
        # Progress counters (written by the trainer, read by the UI).
        self.step = 0
        self.total_steps = 0
        # Latest preview images (PIL) plus the initialization heat map.
        self.current_render = None
        self.current_gaussian_id = None
        self.initialization_map = None  # map for the active init mode only
        # Outputs produced once optimization finishes.
        self.final_render = None
        self.final_checkpoint_path = None
        self.is_complete = False
        # Rolling log buffer fed by GradioStreamingHandler.
        self.training_logs = []
        # Most recent evaluation numbers.
        self.metrics = dict(
            psnr=0.0, ssim=0.0, loss=0.0, render_time=0.0, total_time=0.0
        )
        # Per-step snapshots kept for interactive browsing: {step: PIL image}.
        self.step_renders = {}
        self.step_gaussian_ids = {}
        # Guards async Gaussian-ID visualization updates.
        self.vis_lock = threading.Lock()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class GradioStreamingHandler(logging.Handler):
    """Logging handler that mirrors formatted records into StreamingResults.

    The Gradio front end polls ``results.training_logs`` to stream training
    output, so each record is appended there instead of being written to a
    file or stream.
    """

    def __init__(self, results_container: StreamingResults):
        super().__init__()
        self.results = results_container

    def emit(self, record):
        """Format *record* and append it to the shared log buffer."""
        logs = self.results.training_logs
        logs.append(self.format(record))
        # Cap the buffer at the 100 newest entries so long runs don't grow
        # memory without bound.
        if len(logs) > 100:
            self.results.training_logs = logs[-100:]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class GradioGaussianSplatting2D(nn.Module):
|
| 75 |
+
"""Gradio-optimized version of GaussianSplatting2D with streaming capabilities"""
|
| 76 |
+
|
| 77 |
+
    def __init__(self, args, results_container: StreamingResults):
        """Build the model from config ``args`` and wire streaming output.

        ``results_container`` is shared with the UI thread; training progress,
        preview images and logs are published into it.
        """
        super(GradioGaussianSplatting2D, self).__init__()
        self.results = results_container
        self.evaluate = args.eval
        set_random_seed(seed=args.seed)

        # Device setup: prefer the first CUDA GPU, otherwise fall back to CPU.
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        self.dtype = torch.float32

        # Initialize components. Order matters: target images define the
        # sizes used by the Gaussian parameters and the optimizer setup.
        self._init_logging(args)
        self._init_target(args)
        self._init_bit_precision(args)
        self._init_gaussians(args)
        self._init_loss(args)
        self._init_optimization(args)

        # Initialization: either restore a checkpoint (eval mode) or sample
        # fresh Gaussian positions/scales/features for training.
        if self.evaluate:
            self.ckpt_file = args.ckpt_file
            self._load_model()
        else:
            self._init_pos_scale_feat(args)
|
| 105 |
+
|
| 106 |
+
    def _init_logging(self, args):
        """Set up a logger whose records stream into the results container."""
        self.log_dir = getattr(args, "log_dir", "temp_gradio_logs")
        self.vis_gaussians = args.vis_gaussians
        self.save_image_steps = args.save_image_steps
        self.eval_steps = args.eval_steps

        # Set up streaming logger; clear handlers so repeated runs in the
        # same process don't duplicate log lines.
        self.worklog = logging.getLogger("GradioImageGS")
        self.worklog.handlers.clear()  # Remove existing handlers

        # Add our streaming handler (writes into self.results.training_logs).
        stream_handler = GradioStreamingHandler(self.results)
        stream_handler.setFormatter(
            logging.Formatter(fmt="[{asctime}] {message}", style="{")
        )
        self.worklog.addHandler(stream_handler)
        self.worklog.setLevel(logging.INFO)

        self.worklog.info(
            f"Start optimizing {args.num_gaussians:d} Gaussians for '{args.input_path}'"
        )
|
| 127 |
+
|
| 128 |
+
    def _init_target(self, args):
        """Load the ground-truth image(s) and derive tile/raster dimensions."""
        self.gamma = args.gamma
        self.downsample = args.downsample
        if self.downsample:
            self.downsample_ratio = float(args.downsample_ratio)

        # Tile size used by the tiled rasterizer.
        self.block_h, self.block_w = 16, 16
        self._load_target_images(path=os.path.join(args.data_root, args.input_path))

        if self.downsample:
            # Keep the full-resolution copy around, then reload the image at
            # the downsampled resolution as the actual optimization target.
            self.gt_images_upsampled = self.gt_images
            self.img_h_upsampled, self.img_w_upsampled = self.img_h, self.img_w
            self.tile_bounds_upsampled = self.tile_bounds
            self._load_target_images(
                path=os.path.join(args.data_root, args.input_path),
                downsample_ratio=self.downsample_ratio,
            )

        self.num_pixels = self.img_h * self.img_w
|
| 147 |
+
|
| 148 |
+
    def _load_target_images(self, path, downsample_ratio=None):
        """Load images from *path* onto the device and compute tile bounds.

        Sets ``gt_images`` (H and W are taken from dims 1 and 2),
        ``input_channels``, ``image_fnames``, ``img_h``/``img_w`` and
        ``tile_bounds``.
        """
        self.gt_images, self.input_channels, self.image_fnames = load_images(
            load_path=path, downsample_ratio=downsample_ratio, gamma=self.gamma
        )
        self.gt_images = torch.from_numpy(self.gt_images).to(
            dtype=self.dtype, device=self.device
        )
        self.img_h, self.img_w = self.gt_images.shape[1:]
        # Number of block_w x block_h tiles covering the image (rounded up),
        # plus a depth of 1 for the rasterizer's grid layout.
        self.tile_bounds = (
            (self.img_w + self.block_w - 1) // self.block_w,
            (self.img_h + self.block_h - 1) // self.block_h,
            1,
        )
|
| 161 |
+
|
| 162 |
+
def _init_bit_precision(self, args):
|
| 163 |
+
self.quantize = args.quantize
|
| 164 |
+
self.pos_bits = args.pos_bits
|
| 165 |
+
self.scale_bits = args.scale_bits
|
| 166 |
+
self.rot_bits = args.rot_bits
|
| 167 |
+
self.feat_bits = args.feat_bits
|
| 168 |
+
|
| 169 |
+
    def _init_gaussians(self, args):
        """Create the learnable Gaussian parameters (position/scale/rot/feat).

        With progressive optimization enabled, only ``initial_ratio`` of the
        Gaussian budget is created now; the remainder is added in
        ``add_times`` batches every ``add_steps`` steps (see _add_gaussians).
        """
        self.num_gaussians = args.num_gaussians
        self.total_num_gaussians = args.num_gaussians
        self.disable_prog_optim = args.disable_prog_optim

        if not self.disable_prog_optim and not self.evaluate:
            self.initial_ratio = args.initial_ratio
            self.add_times = args.add_times
            self.add_steps = args.add_steps
            # Start with a fraction of the budget; the remainder is split
            # evenly (rounded up) across the scheduled additions.
            self.num_gaussians = math.ceil(
                self.initial_ratio * self.total_num_gaussians
            )
            self.max_add_num = math.ceil(
                float(self.total_num_gaussians - self.num_gaussians) / self.add_times
            )
            # Guarantee enough steps for every addition plus a tail of
            # post-addition refinement steps.
            min_steps = self.add_steps * self.add_times + args.post_min_steps
            if args.max_steps < min_steps:
                self.worklog.info(
                    f"Max steps ({args.max_steps:d}) is too small for progressive optimization. Resetting to {min_steps:d}"
                )
                args.max_steps = min_steps

        self.topk = args.topk
        # Numerical floor; a smaller eps is used on the no-tiles path.
        self.eps = 1e-7 if args.disable_tiles else 1e-4
        self.init_scale = args.init_scale
        self.disable_topk_norm = args.disable_topk_norm
        self.disable_inverse_scale = args.disable_inverse_scale
        self.disable_color_init = args.disable_color_init

        # Initialize parameters: random positions in [0, 1), unit scales,
        # zero rotations, random features.
        self.xy = nn.Parameter(
            torch.rand(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.scale = nn.Parameter(
            torch.ones(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.rot = nn.Parameter(
            torch.zeros(self.num_gaussians, 1, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        # One feature channel per input image channel (channels may be
        # stacked across several loaded images).
        self.feat_dim = sum(self.input_channels)
        self.feat = nn.Parameter(
            torch.rand(
                self.num_gaussians, self.feat_dim, dtype=self.dtype, device=self.device
            ),
            requires_grad=True,
        )
        # Frozen random colors used only for the Gaussian-ID visualization.
        self.vis_feat = nn.Parameter(torch.rand_like(self.feat), requires_grad=False)

        self._log_compression_rate()
|
| 221 |
+
|
| 222 |
+
    def _log_compression_rate(self):
        """Log uncompressed vs. Gaussian-compressed sizes and the ratio.

        Assumes the uncompressed source uses 8 bits per channel; the
        compressed size counts every Gaussian attribute at its bit width.
        """
        # One byte per stored value at 8 bits/channel.
        bytes_uncompressed = float(self.gt_images.numel())
        bpp_uncompressed = float(8 * self.feat_dim)
        self.worklog.info(
            f"Uncompressed: {bytes_uncompressed / 1e3:.2f} KB | {bpp_uncompressed:.3f} bpp | 8.0 bppc"
        )
        # Per Gaussian: 2D position + 2D scale + rotation + feature vector.
        bits_compressed = (
            2 * self.pos_bits
            + 2 * self.scale_bits
            + self.rot_bits
            + self.feat_dim * self.feat_bits
        ) * self.total_num_gaussians
        bytes_compressed = bits_compressed / 8.0
        bpp_compressed = float(bits_compressed) / self.num_pixels
        bppc_compressed = bpp_compressed / self.feat_dim
        self.num_bytes = bytes_compressed
        self.worklog.info(
            f"Compressed: {bytes_compressed / 1e3:.2f} KB | {bpp_compressed:.3f} bpp | {bppc_compressed:.3f} bppc"
        )
        self.worklog.info(
            f"Compression rate: {bpp_uncompressed / bpp_compressed:.2f}x | {100.0 * bpp_compressed / bpp_uncompressed:.2f}%"
        )
|
| 244 |
+
|
| 245 |
+
def _init_loss(self, args):
|
| 246 |
+
self.l1_loss_ratio = args.l1_loss_ratio
|
| 247 |
+
self.l2_loss_ratio = args.l2_loss_ratio
|
| 248 |
+
self.ssim_loss_ratio = args.ssim_loss_ratio
|
| 249 |
+
|
| 250 |
+
    def _init_optimization(self, args):
        """Create the Adam optimizer (per-attribute LRs) and decay settings."""
        self.disable_tiles = args.disable_tiles
        self.start_step = 1
        self.max_steps = args.max_steps
        self.results.total_steps = (
            args.max_steps
        )  # Set total steps for streaming progress
        self.pos_lr = args.pos_lr
        self.scale_lr = args.scale_lr
        self.rot_lr = args.rot_lr
        self.feat_lr = args.feat_lr

        # Each Gaussian attribute gets its own learning rate.
        self.optimizer = torch.optim.Adam(
            [
                {"params": self.xy, "lr": self.pos_lr},
                {"params": self.scale, "lr": self.scale_lr},
                {"params": self.rot, "lr": self.rot_lr},
                {"params": self.feat, "lr": self.feat_lr},
            ]
        )

        # Plateau-based LR decay configuration (see _lr_schedule).
        self.disable_lr_schedule = args.disable_lr_schedule
        if not self.disable_lr_schedule:
            self.decay_ratio = args.decay_ratio
            self.check_decay_steps = args.check_decay_steps
            self.max_decay_times = args.max_decay_times
            self.decay_threshold = args.decay_threshold
|
| 277 |
+
|
| 278 |
+
    def _init_pos_scale_feat(self, args):
        """Initialize Gaussian positions, scales and features for training.

        Positions are sampled from the pixel grid according to ``init_mode``
        (image-gradient magnitude, saliency, or uniform random); features are
        optionally initialized from the target image colors at the sampled
        positions.
        """
        self.init_mode = args.init_mode
        self.init_random_ratio = args.init_random_ratio
        # Flattened (num_pixels, 2) grid of pixel coordinates.
        self.pixel_xy = (
            get_grid(h=self.img_h, w=self.img_w)
            .to(dtype=self.dtype, device=self.device)
            .reshape(-1, 2)
        )

        with torch.no_grad():
            # Position initialization
            if self.init_mode == "gradient":
                self._compute_gmap()
                self.xy.copy_(self._sample_pos(prob=self.image_gradients))
            elif self.init_mode == "saliency":
                self.smap_filter_size = args.smap_filter_size
                self._compute_smap()
                self.xy.copy_(self._sample_pos(prob=self.saliency))
            else:  # random mode (or any unrecognized mode)
                selected = np.random.choice(
                    self.num_pixels, self.num_gaussians, replace=False, p=None
                )
                self.xy.copy_(self.pixel_xy.detach().clone()[selected])
                # For random mode, create a simple random noise pattern shown
                # in the UI in place of a gradient/saliency map.
                if self.init_mode == "random":
                    random_pattern = np.random.rand(self.img_h, self.img_w)
                    self.results.initialization_map = Image.fromarray(
                        (random_pattern * 255).astype(np.uint8)
                    )

            # Scale initialization (the parameter stores inverse scales unless
            # disable_inverse_scale is set; see _get_scale).
            self.scale.fill_(
                self.init_scale if self.disable_inverse_scale else 1.0 / self.init_scale
            )

            # Feature initialization: sample target colors at each position.
            if not self.disable_color_init:
                self.feat.copy_(
                    self._get_target_features(positions=self.xy).detach().clone()
                )
|
| 318 |
+
|
| 319 |
+
def _sample_pos(self, prob):
|
| 320 |
+
num_random = round(self.init_random_ratio * self.num_gaussians)
|
| 321 |
+
selected_random = np.random.choice(
|
| 322 |
+
self.num_pixels, num_random, replace=False, p=None
|
| 323 |
+
)
|
| 324 |
+
selected_other = np.random.choice(
|
| 325 |
+
self.num_pixels, self.num_gaussians - num_random, replace=False, p=prob
|
| 326 |
+
)
|
| 327 |
+
return torch.cat(
|
| 328 |
+
[
|
| 329 |
+
self.pixel_xy.detach().clone()[selected_random],
|
| 330 |
+
self.pixel_xy.detach().clone()[selected_other],
|
| 331 |
+
],
|
| 332 |
+
dim=0,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
    def _compute_gmap(self):
        """Compute a gradient-magnitude sampling distribution over pixels.

        Gradients are taken on the gamma-corrected (display-space) image; the
        squared, normalized magnitudes become a probability distribution
        stored in ``self.image_gradients`` (sums to 1, flattened).
        """
        gy, gx = compute_image_gradients(
            np.power(self.gt_images.detach().cpu().clone().numpy(), 1.0 / self.gamma)
        )
        g_norm = np.hypot(gy, gx).astype(np.float32)
        g_norm = g_norm / g_norm.max()

        # Store gradient map for streaming (only if this is the selected initialization mode)
        if self.init_mode == "gradient":
            self.results.initialization_map = Image.fromarray(
                (g_norm * 255).astype(np.uint8)
            )

        # Square to sharpen the distribution, then normalize to sum to 1.
        g_norm = np.power(g_norm.reshape(-1), 2.0)
        self.image_gradients = g_norm / g_norm.sum()
        self.worklog.info("Image gradient map computed")
|
| 351 |
+
|
| 352 |
+
    def _compute_smap(self):
        """Compute a saliency-based sampling distribution over pixels.

        Runs the saliency helper (model files under "models") on the
        gamma-corrected image, then normalizes the map into ``self.saliency``.
        """
        smap = get_smap(
            torch.pow(self.gt_images.detach().clone(), 1.0 / self.gamma),
            "models",
            self.smap_filter_size,
        )

        # Store saliency map for streaming (only if this is the selected initialization mode)
        if self.init_mode == "saliency":
            self.results.initialization_map = Image.fromarray(
                (smap * 255).astype(np.uint8)
            )

        # Normalize to a probability distribution for position sampling.
        self.saliency = (smap / smap.sum()).reshape(-1)
        self.worklog.info("Saliency map computed")
|
| 367 |
+
|
| 368 |
+
    def _get_target_features(self, positions):
        """Bilinearly sample target-image colors at normalized *positions*.

        *positions* are expected in [0, 1] and remapped to the [-1, 1] range
        that ``F.grid_sample`` uses. Returns a (num_positions, feat_dim)
        tensor (no gradients).
        """
        with torch.no_grad():
            target_features = F.grid_sample(
                self.gt_images.unsqueeze(0),
                positions[None, None, ...] * 2.0 - 1.0,
                align_corners=False,
            )
            # (1, C, 1, N) -> (N, C)
            target_features = target_features[0, :, 0, :].permute(1, 0)
        return target_features
|
| 377 |
+
|
| 378 |
+
    def forward(self, img_h, img_w, tile_bounds, upsample_ratio=None, benchmark=False):
        """Render all Gaussians into an image of size (img_h, img_w).

        Returns ``(image, render_time)`` where image is (feat_dim, H, W); when
        ``benchmark`` is True only the raw render time is returned. With
        ``quantize`` on, parameters pass through straight-through quantization
        so gradients still flow to the full-precision values.
        """
        scale = self._get_scale(upsample_ratio=upsample_ratio)
        xy, rot, feat = self.xy, self.rot, self.feat

        if self.quantize:
            xy, scale, rot, feat = (
                ste_quantize(xy, self.pos_bits),
                ste_quantize(scale, self.scale_bits),
                ste_quantize(rot, self.rot_bits),
                ste_quantize(feat, self.feat_bits),
            )

        # NOTE(review): timing has no torch.cuda.synchronize(), so on GPU this
        # may measure kernel launch rather than completion — confirm if exact
        # render times matter.
        begin = perf_counter()
        tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
        xy, radii, conics, num_tiles_hit = tmp

        # Two rasterization backends: tiled top-k sum, or a no-tiles fallback.
        if not self.disable_tiles:
            enable_topk_norm = not self.disable_topk_norm
            tmp = (
                xy,
                radii,
                conics,
                num_tiles_hit,
                feat,
                img_h,
                img_w,
                self.block_h,
                self.block_w,
                enable_topk_norm,
            )
            out_image = rasterize_gaussians_sum(*tmp)
        else:
            tmp = xy, conics, feat, img_h, img_w
            out_image = rasterize_gaussians_no_tiles(*tmp)

        render_time = perf_counter() - begin

        if benchmark:
            return render_time

        # Reshape raster output to (1, feat_dim, H, W), then drop the batch dim.
        out_image = (
            out_image.view(-1, img_h, img_w, self.feat_dim)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return out_image.squeeze(dim=0), render_time
|
| 424 |
+
|
| 425 |
+
def _get_scale(self, upsample_ratio=None):
|
| 426 |
+
scale = self.scale
|
| 427 |
+
if not self.disable_inverse_scale:
|
| 428 |
+
scale = 1.0 / scale
|
| 429 |
+
if upsample_ratio is not None:
|
| 430 |
+
scale = upsample_ratio * scale
|
| 431 |
+
return scale
|
| 432 |
+
|
| 433 |
+
def _tensor_to_pil_image(self, tensor_image):
|
| 434 |
+
"""Convert tensor image to PIL Image for streaming"""
|
| 435 |
+
if tensor_image is None:
|
| 436 |
+
return None
|
| 437 |
+
|
| 438 |
+
# Convert to numpy and apply gamma correction
|
| 439 |
+
image_np = (
|
| 440 |
+
torch.pow(torch.clamp(tensor_image, 0.0, 1.0), 1.0 / self.gamma)
|
| 441 |
+
.detach()
|
| 442 |
+
.cpu()
|
| 443 |
+
.numpy()
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# Convert to uint8 format
|
| 447 |
+
image_formatted = to_output_format(image_np, gamma=None)
|
| 448 |
+
return Image.fromarray(image_formatted)
|
| 449 |
+
|
| 450 |
+
    def _create_gaussian_id_visualization(self):
        """Render a "which Gaussian is where" image with per-Gaussian colors.

        Uses the frozen random ``vis_feat`` colors scaled by each Gaussian's
        feature magnitude and runs the same projection/rasterization path as
        ``forward``. Returns a PIL image, or None when disabled or on failure.
        """
        if not self.vis_gaussians:
            return None

        try:
            # Use vis_feat for ID visualization (this creates unique colors per Gaussian)
            feat = self.vis_feat * self.feat.norm(dim=-1, keepdim=True)

            # Render with ID features
            scale = self._get_scale()
            xy, rot = self.xy, self.rot

            # Match forward()'s quantization so the visualization reflects
            # the rendered model.
            if self.quantize:
                xy, scale, rot, feat = (
                    ste_quantize(xy, self.pos_bits),
                    ste_quantize(scale, self.scale_bits),
                    ste_quantize(rot, self.rot_bits),
                    ste_quantize(feat, self.feat_bits),
                )

            tmp = project_gaussians_2d_scale_rot(
                xy, scale, rot, self.img_h, self.img_w, self.tile_bounds
            )
            xy, radii, conics, num_tiles_hit = tmp

            if not self.disable_tiles:
                enable_topk_norm = not self.disable_topk_norm
                tmp = (
                    xy,
                    radii,
                    conics,
                    num_tiles_hit,
                    feat,
                    self.img_h,
                    self.img_w,
                    self.block_h,
                    self.block_w,
                    enable_topk_norm,
                )
                out_image = rasterize_gaussians_sum(*tmp)
            else:
                tmp = xy, conics, feat, self.img_h, self.img_w
                out_image = rasterize_gaussians_no_tiles(*tmp)

            # Reshape raster output to (feat_dim, H, W) as in forward().
            out_image = (
                out_image.view(-1, self.img_h, self.img_w, self.feat_dim)
                .permute(0, 3, 1, 2)
                .contiguous()
            ).squeeze(dim=0)

            return self._tensor_to_pil_image(out_image)

        except Exception as e:
            # Visualization is best-effort: log and return None rather than
            # interrupting training.
            self.worklog.error(f"Error creating Gaussian ID visualization: {e}")
            return None
|
| 506 |
+
|
| 507 |
+
    def optimize(self):
        """Main optimization loop with streaming updates.

        Runs up to ``max_steps`` Adam steps, publishing renders, metrics and
        logs into ``self.results`` for the UI, while handling progressive
        Gaussian addition, plateau LR decay and user-requested stops.
        """
        self.psnr_curr, self.ssim_curr = 0.0, 0.0
        self.best_psnr, self.best_ssim = 0.0, 0.0
        self.decay_times, self.no_improvement_steps = 0, 0
        self.render_time_accum, self.total_time_accum = 0.0, 0.0

        # Initialize attributes needed for evaluation
        self.l1_loss = None
        self.l2_loss = None
        self.ssim_loss = None
        self.stop_requested = False  # may be flipped by the UI thread

        # Initial render and update (step-0 preview before any training).
        with torch.no_grad():
            images, _ = self.forward(self.img_h, self.img_w, self.tile_bounds)
            self.results.current_render = self._tensor_to_pil_image(images)
            if self.vis_gaussians:
                try:
                    self.results.current_gaussian_id = (
                        self._create_gaussian_id_visualization()
                    )
                    self.worklog.info(
                        f"Initial visualizations created - Render: {'✓' if self.results.current_render else '✗'}, ID: {'✓' if self.results.current_gaussian_id else '✗'}"
                    )
                except Exception as e:
                    self.worklog.error(f"Error creating initial visualizations: {e}")
                    self.results.current_gaussian_id = None

        # Store initial results (step 0)
        self.results.step_renders[0] = self.results.current_render
        if self.vis_gaussians:
            self.results.step_gaussian_ids[0] = self.results.current_gaussian_id

        for step in range(self.start_step, self.max_steps + 1):
            self.step = step
            self.results.step = step

            self.optimizer.zero_grad()

            # Forward pass
            images, render_time = self.forward(self.img_h, self.img_w, self.tile_bounds)
            self.render_time_accum += render_time

            # Compute loss and take one optimizer step.
            begin = perf_counter()
            self._get_total_loss(images)
            self.total_loss.backward()
            self.optimizer.step()
            self.total_time_accum += perf_counter() - begin + render_time

            # Update streaming results
            with torch.no_grad():
                if step % self.eval_steps == 0:
                    self._evaluate_and_update_stream(images)

                # Update render image more frequently, but visualizations less frequently
                render_update_freq = max(
                    50, self.save_image_steps // 2
                )  # Render updates every 50 steps
                vis_update_freq = max(
                    200, self.save_image_steps
                )  # Visualizations every 200 steps

                if step % render_update_freq == 0:
                    render_img = self._tensor_to_pil_image(images)
                    self.results.current_render = render_img

                # Only update Gaussian ID visualization less frequently
                if step % vis_update_freq == 0 and self.vis_gaussians:
                    # Generate Gaussian ID visualization asynchronously so the
                    # training loop is not blocked by the extra rasterization.
                    def generate_gaussian_id_async():
                        try:
                            with self.results.vis_lock:
                                gaussian_id_vis = (
                                    self._create_gaussian_id_visualization()
                                )
                                self.results.current_gaussian_id = gaussian_id_vis

                        except Exception as e:
                            self.worklog.error(
                                f"Error creating Gaussian ID visualization at step {step}: {e}"
                            )
                            with self.results.vis_lock:
                                self.results.current_gaussian_id = None

                    # Start async visualization generation (daemon so it never
                    # blocks process exit).
                    vis_thread = threading.Thread(target=generate_gaussian_id_async)
                    vis_thread.daemon = True
                    vis_thread.start()

                # Store results for interactive browsing only at save_image_steps intervals
                if step % self.save_image_steps == 0:
                    # Store the current render for browsing
                    if self.results.current_render:
                        self.results.step_renders[step] = self.results.current_render

                    # Store Gaussian ID visualization for browsing
                    if self.vis_gaussians and self.results.current_gaussian_id:
                        self.results.step_gaussian_ids[step] = (
                            self.results.current_gaussian_id
                        )

            # Progressive optimization: grow the Gaussian set on schedule.
            if (
                not self.disable_prog_optim
                and step % self.add_steps == 0
                and self.num_gaussians < self.total_num_gaussians
            ):
                self._add_gaussians(self.max_add_num)

            # Learning rate schedule (only once all Gaussians are in place).
            terminate = False
            if (
                not self.disable_lr_schedule
                and self.num_gaussians == self.total_num_gaussians
                and step % self.eval_steps == 0
            ):
                terminate = self._lr_schedule()

            if terminate or self.stop_requested:
                if self.stop_requested:
                    self.worklog.info("Training stopped by user request")
                break

        # Final updates
        with torch.no_grad():
            images, _ = self.forward(self.img_h, self.img_w, self.tile_bounds)
            self.results.final_render = self._tensor_to_pil_image(images)

        # Save final checkpoint and store path
        self._save_final_checkpoint()

        self.results.is_complete = True
        self.worklog.info("Optimization completed")
|
| 642 |
+
|
| 643 |
+
    def _get_total_loss(self, images):
        """Accumulate the weighted L1 / L2 / SSIM losses into ``total_loss``.

        Terms whose ratio is ~0 (<= 1e-7) are skipped and their attribute set
        to None so logging can omit them.
        NOTE(review): if all three ratios are ~0, ``total_loss`` stays the int
        0 and the later ``total_loss.backward()`` call would fail — confirm
        that configs always enable at least one loss term.
        """
        self.total_loss = 0

        if self.l1_loss_ratio > 1e-7:
            self.l1_loss = self.l1_loss_ratio * F.l1_loss(images, self.gt_images)
            self.total_loss += self.l1_loss
        else:
            self.l1_loss = None

        if self.l2_loss_ratio > 1e-7:
            self.l2_loss = self.l2_loss_ratio * F.mse_loss(images, self.gt_images)
            self.total_loss += self.l2_loss
        else:
            self.l2_loss = None

        if self.ssim_loss_ratio > 1e-7:
            # SSIM is a similarity in [0, 1]; 1 - SSIM makes it a loss.
            self.ssim_loss = self.ssim_loss_ratio * (
                1 - fused_ssim(images.unsqueeze(0), self.gt_images.unsqueeze(0))
            )
            self.total_loss += self.ssim_loss
        else:
            self.ssim_loss = None
|
| 665 |
+
|
| 666 |
+
    def _evaluate_and_update_stream(self, images):
        """Evaluate current state and update streaming results.

        PSNR/SSIM are computed in display space (after gamma correction) so
        they match what the user sees; metrics and a progress log line are
        pushed into the results container.
        """
        gamma_corrected_images = torch.pow(
            torch.clamp(images, 0.0, 1.0), 1.0 / self.gamma
        )
        gamma_corrected_gt = torch.pow(self.gt_images, 1.0 / self.gamma)

        psnr = get_psnr(gamma_corrected_images, gamma_corrected_gt).item()
        ssim = fused_ssim(
            gamma_corrected_images.unsqueeze(0), gamma_corrected_gt.unsqueeze(0)
        ).item()

        # Current values consumed by the plateau LR schedule.
        self.psnr_curr, self.ssim_curr = psnr, ssim

        # Update metrics
        self.results.metrics.update(
            {
                "psnr": psnr,
                "ssim": ssim,
                "loss": self.total_loss.item(),
                "render_time": self.render_time_accum,
                "total_time": self.total_time_accum,
            }
        )

        # Log progress, listing only the loss terms that are active.
        loss_results = f"Loss: {self.total_loss.item():.4f}"
        if self.l1_loss is not None:
            loss_results += f", L1: {self.l1_loss.item():.4f}"
        if self.l2_loss is not None:
            loss_results += f", L2: {self.l2_loss.item():.4f}"
        if self.ssim_loss is not None:
            loss_results += f", SSIM: {self.ssim_loss.item():.4f}"

        time_results = f"Total: {self.total_time_accum:.2f} s | Render: {self.render_time_accum:.2f} s"

        self.worklog.info(
            f"Step: {self.step:d} | {time_results} | {loss_results} | PSNR: {psnr:.2f} | SSIM: {ssim:.4f}"
        )
|
| 705 |
+
|
| 706 |
+
    def _save_final_checkpoint(self):
        """Quantize (if enabled), then save the final checkpoint to disk.

        The saved dict bundles metrics, model and optimizer state; the path
        is published via ``results.final_checkpoint_path`` for download.
        """
        if self.quantize:
            # Bake the quantization into the stored parameters so the
            # checkpoint reproduces the quantized render exactly.
            with torch.no_grad():
                self.xy.copy_(ste_quantize(self.xy, self.pos_bits))
                self.scale.copy_(ste_quantize(self.scale, self.scale_bits))
                self.rot.copy_(ste_quantize(self.rot, self.rot_bits))
                self.feat.copy_(ste_quantize(self.feat, self.feat_bits))

        # Create checkpoint directory
        ckpt_dir = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)

        psnr = self.results.metrics.get("psnr", 0.0)
        ssim = self.results.metrics.get("ssim", 0.0)

        ckpt_data = {
            "step": self.step,
            "psnr": psnr,
            "ssim": ssim,
            "bytes": getattr(self, "num_bytes", 0),
            "time": self.total_time_accum,
            "state_dict": self.state_dict(),
            "optim_state_dict": self.optimizer.state_dict(),
        }

        ckpt_path = os.path.join(ckpt_dir, f"ckpt_step-{self.step:d}.pt")
        torch.save(ckpt_data, ckpt_path)
        self.results.final_checkpoint_path = ckpt_path

        self.worklog.info(f"Final checkpoint saved: {ckpt_path}")
|
| 737 |
+
|
| 738 |
+
def _lr_schedule(self):
|
| 739 |
+
"""Learning rate scheduling logic"""
|
| 740 |
+
if (
|
| 741 |
+
self.psnr_curr <= self.best_psnr + 100 * self.decay_threshold
|
| 742 |
+
or self.ssim_curr <= self.best_ssim + self.decay_threshold
|
| 743 |
+
):
|
| 744 |
+
self.no_improvement_steps += self.eval_steps
|
| 745 |
+
if self.no_improvement_steps >= self.check_decay_steps:
|
| 746 |
+
self.no_improvement_steps = 0
|
| 747 |
+
self.decay_times += 1
|
| 748 |
+
if self.decay_times > self.max_decay_times:
|
| 749 |
+
return True
|
| 750 |
+
for param_group in self.optimizer.param_groups:
|
| 751 |
+
param_group["lr"] /= self.decay_ratio
|
| 752 |
+
self.worklog.info(f"Learning rate decayed by {self.decay_ratio:.1f}")
|
| 753 |
+
return False
|
| 754 |
+
else:
|
| 755 |
+
self.best_psnr = self.psnr_curr
|
| 756 |
+
self.best_ssim = self.ssim_curr
|
| 757 |
+
self.no_improvement_steps = 0
|
| 758 |
+
return False
|
| 759 |
+
|
| 760 |
+
    def _add_gaussians(self, add_num):
        """Add Gaussians during progressive optimization.

        New Gaussians are placed where the current render disagrees most with
        the (optionally blurred) target, their features seeded from the
        residual, and the optimizer is rebuilt to include them.
        NOTE(review): rebuilding Adam discards its moment estimates for the
        existing Gaussians — presumably intentional, but worth confirming.
        """
        # Never exceed the per-batch cap or the total Gaussian budget.
        add_num = min(
            add_num, self.max_add_num, self.total_num_gaussians - self.num_gaussians
        )
        if add_num <= 0:
            return

        # Compute error map for new Gaussian placement (in display space).
        raw_images, _ = self.forward(self.img_h, self.img_w, self.tile_bounds)
        images = torch.pow(torch.clamp(raw_images, 0.0, 1.0), 1.0 / self.gamma)
        gt_images = torch.pow(self.gt_images, 1.0 / self.gamma)

        # Blur the target for large images so placement ignores pixel noise;
        # the kernel size is forced odd and at least 3.
        kernel_size = round(np.sqrt(self.img_h * self.img_w) // 400)
        if kernel_size >= 1:
            kernel_size = max(3, kernel_size)
            kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
            gt_images = gaussian_blur(img=gt_images, kernel_size=kernel_size)

        # Squared per-pixel mean absolute error -> sampling distribution.
        diff_map = (gt_images - images).detach().clone()
        error_map = torch.pow(torch.abs(diff_map).mean(dim=0).reshape(-1), 2.0)
        sample_prob = (error_map / error_map.sum()).cpu().numpy()
        selected = np.random.choice(
            self.num_pixels, add_num, replace=False, p=sample_prob
        )

        # Create new Gaussians at the sampled pixel positions.
        new_xy = self.pixel_xy.detach().clone()[selected]
        new_scale = torch.ones(add_num, 2, dtype=self.dtype, device=self.device)
        init_scale = self.init_scale
        new_scale.fill_(init_scale if self.disable_inverse_scale else 1.0 / init_scale)
        new_rot = torch.zeros(add_num, 1, dtype=self.dtype, device=self.device)
        # Seed features with the residual so new Gaussians correct the error.
        new_feat = diff_map.permute(1, 2, 0).reshape(-1, self.feat_dim)[selected]
        new_vis_feat = torch.rand_like(new_feat)

        # Update parameters: concatenate old and new values.
        old_xy = self.xy.detach().clone()
        old_scale = self.scale.detach().clone()
        old_rot = self.rot.detach().clone()
        old_feat = self.feat.detach().clone()
        old_vis_feat = self.vis_feat.detach().clone()

        self.num_gaussians += add_num
        all_xy = torch.cat([old_xy, new_xy], dim=0)
        all_scale = torch.cat([old_scale, new_scale], dim=0)
        all_rot = torch.cat([old_rot, new_rot], dim=0)
        all_feat = torch.cat([old_feat, new_feat], dim=0)
        all_vis_feat = torch.cat([old_vis_feat, new_vis_feat], dim=0)

        self.xy = nn.Parameter(all_xy, requires_grad=True)
        self.scale = nn.Parameter(all_scale, requires_grad=True)
        self.rot = nn.Parameter(all_rot, requires_grad=True)
        self.feat = nn.Parameter(all_feat, requires_grad=True)
        self.vis_feat = nn.Parameter(all_vis_feat, requires_grad=False)

        # Update optimizer (recreated so the new parameters are tracked).
        self.optimizer = torch.optim.Adam(
            [
                {"params": self.xy, "lr": self.pos_lr},
                {"params": self.scale, "lr": self.scale_lr},
                {"params": self.rot, "lr": self.rot_lr},
                {"params": self.feat, "lr": self.feat_lr},
            ]
        )

        self.worklog.info(
            f"Step: {self.step:d} | Adding {add_num:d} Gaussians ({self.num_gaussians - add_num:d} -> {self.num_gaussians:d})"
        )
|
main.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from model import GaussianSplatting2D
|
| 6 |
+
from utils.misc_utils import load_cfg
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_gaussian_cfg(args):
    """Build the Gaussian-configuration tag used in experiment folder names.

    Also normalizes the bit-width arguments: when quantization is disabled,
    all four attribute bit widths are forced to 32 (full precision) on the
    ``args`` namespace itself, so downstream code sees consistent values.

    Raises:
        ValueError: if any configured bit width falls outside [4, 32].
    """
    parts = [f"num-{args.num_gaussians:d}"]
    # Tag reflects whether scales are stored directly or as inverses.
    if args.disable_inverse_scale:
        parts.append(f"scale-{args.init_scale:.1f}")
    else:
        parts.append(f"inv-scale-{args.init_scale:.1f}")
    # Without quantization every attribute stays at full 32-bit precision.
    if not args.quantize:
        args.pos_bits = args.scale_bits = args.rot_bits = args.feat_bits = 32
    bit_widths = (args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits)
    if min(bit_widths) < 4 or max(bit_widths) > 32:
        raise ValueError(
            f"Bit precision must be between 4 and 32 but got: {args.pos_bits:d}, {args.scale_bits:d}, {args.rot_bits:d}, {args.feat_bits:d}"
        )
    parts.append(
        f"bits-{args.pos_bits:d}-{args.scale_bits:d}-{args.rot_bits:d}-{args.feat_bits:d}"
    )
    if not args.disable_topk_norm:
        parts.append(f"top-{args.topk:d}")
    # First letter of the init mode plus the random-position ratio.
    parts.append(f"{args.init_mode[0]}-{args.init_random_ratio:.1f}")
    return "_".join(parts)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_log_dir(args):
    """Compose the experiment log-directory path from the run configuration.

    The folder name concatenates the Gaussian tag, the loss-weight tag, and
    optional markers for downsampling, LR decay, and progressive optimization.
    """
    loss_cfg = (
        f"l1-{args.l1_loss_ratio:.1f}"
        f"_l2-{args.l2_loss_ratio:.1f}"
        f"_ssim-{args.ssim_loss_ratio:.1f}"
    )
    folder = f"{get_gaussian_cfg(args)}_{loss_cfg}"
    if args.downsample:
        folder += f"_ds-{args.downsample_ratio:.1f}"
    if not args.disable_lr_schedule:
        folder += f"_decay-{args.max_decay_times:d}-{args.decay_ratio:.1f}"
    if not args.disable_prog_optim:
        folder += "_prog"
    return f"{args.log_root}/{args.exp_name}/{folder}"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main(args):
    """Entry point: build the 2D Gaussian model, then render or optimize.

    ``args.eval`` selects between rendering from a checkpoint and running a
    full optimization; the log directory is derived from the configuration.
    """
    args.log_dir = get_log_dir(args)
    model = GaussianSplatting2D(args)
    if args.eval:
        model.render(render_height=args.render_height)
    else:
        model.optimize()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
    # Keep torch.hub downloads in a local directory instead of the user cache.
    torch.hub.set_dir("models/torch")
    parser = argparse.ArgumentParser()
    # load_cfg populates the parser with options from the YAML config file.
    parser = load_cfg(cfg_path="cfgs/default.yaml", parser=parser)
    arguments = parser.parse_args()
    main(arguments)
|
model.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from time import perf_counter
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fused_ssim import fused_ssim
|
| 12 |
+
from lpips import LPIPS
|
| 13 |
+
from pytorch_msssim import MS_SSIM
|
| 14 |
+
from torchvision.transforms.functional import gaussian_blur
|
| 15 |
+
|
| 16 |
+
from gsplat import (
|
| 17 |
+
project_gaussians_2d_scale_rot,
|
| 18 |
+
rasterize_gaussians_no_tiles,
|
| 19 |
+
rasterize_gaussians_sum,
|
| 20 |
+
)
|
| 21 |
+
from utils.flip import LDRFLIPLoss
|
| 22 |
+
from utils.image_utils import (
|
| 23 |
+
compute_image_gradients,
|
| 24 |
+
get_grid,
|
| 25 |
+
get_psnr,
|
| 26 |
+
load_images,
|
| 27 |
+
save_image,
|
| 28 |
+
separate_image_channels,
|
| 29 |
+
visualize_added_gaussians,
|
| 30 |
+
visualize_gaussians,
|
| 31 |
+
)
|
| 32 |
+
from utils.misc_utils import clean_dir, get_latest_ckpt_step, save_cfg, set_random_seed
|
| 33 |
+
from utils.quantization_utils import ste_quantize
|
| 34 |
+
from utils.saliency_utils import get_smap
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GaussianSplatting2D(nn.Module):
|
| 38 |
+
    def __init__(self, args):
        """Set up logging, targets, Gaussian parameters, losses, and optimizer.

        In eval mode a checkpoint is loaded instead of running the
        position/scale/feature initialization.
        """
        super(GaussianSplatting2D, self).__init__()
        self.evaluate = args.eval
        set_random_seed(seed=args.seed)
        # Ensure we're using the correct CUDA device
        if torch.cuda.is_available():
            torch.cuda.set_device(0)  # Force device 0
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        self.dtype = torch.float32
        # Order matters: targets must exist before Gaussians (feat_dim comes
        # from the input channels) and Gaussians before the optimizer.
        self._init_logging(args)
        self._init_target(args)
        self._init_bit_precision(args)
        self._init_gaussians(args)
        self._init_loss(args)
        self._init_optimization(args)
        # Initialization
        if self.evaluate:
            self.ckpt_file = args.ckpt_file
            self._load_model()
        else:
            self._init_pos_scale_feat(args)
|
| 61 |
+
|
| 62 |
+
    def _init_logging(self, args):
        """Create log/checkpoint/output directories and the run logger.

        Training wipes and recreates the log tree (exist_ok=False guards
        against partial cleanup); evaluation only ensures the eval output
        directory exists.
        """
        self.log_dir = args.log_dir
        self.log_level = args.log_level
        self.ckpt_dir = os.path.join(self.log_dir, "checkpoints")
        self.train_dir = os.path.join(self.log_dir, "train")
        self.eval_dir = os.path.join(self.log_dir, "eval")
        self.vis_gaussians = args.vis_gaussians
        self.save_image_steps = args.save_image_steps
        self.save_ckpt_steps = args.save_ckpt_steps
        self.eval_steps = args.eval_steps
        if not self.evaluate:
            # Start each training run from a clean directory tree.
            clean_dir(path=self.log_dir)
            os.makedirs(self.log_dir, exist_ok=False)
            os.makedirs(self.ckpt_dir, exist_ok=False)
            os.makedirs(self.train_dir, exist_ok=False)
        else:
            os.makedirs(self.eval_dir, exist_ok=True)
        self._gen_logger(args)
        if not self.evaluate:
            # Snapshot the full training configuration next to the logs.
            save_cfg(path=f"{self.log_dir}/cfg_train.yaml", args=args)
|
| 82 |
+
|
| 83 |
+
    def _gen_logger(self, args):
        """Configure a file handler plus colored-console handler for this run."""
        log_fname = "log_train"
        if self.evaluate:
            log_fname = "log_eval"
        # Fall back to INFO if the configured level name is not recognized.
        log_level = getattr(logging, self.log_level, logging.INFO)
        logging.basicConfig(level=log_level)
        self.worklog = logging.getLogger("Image-GS Logger")
        # Avoid duplicate records through the root logger's handlers.
        self.worklog.propagate = False
        datefmt = "%Y/%m/%d %H:%M:%S"
        fileHandler = logging.FileHandler(
            f"{self.log_dir}/{log_fname}.txt", mode="a", encoding="utf8"
        )
        fileHandler.setFormatter(
            logging.Formatter(fmt="[{asctime}] {message}", datefmt=datefmt, style="{")
        )
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setFormatter(
            logging.Formatter(
                # ANSI escape codes: green timestamp, default-color message.
                fmt="\x1b[32m[{asctime}] \x1b[0m{message}", datefmt=datefmt, style="{"
            )
        )
        # Replace (not append) handlers so re-instantiation doesn't stack them.
        self.worklog.handlers = [fileHandler, consoleHandler]
        action = "rendering" if self.evaluate else "optimizing"
        self.worklog.info(
            f"Start {action} {args.num_gaussians:d} Gaussians for '{args.input_path}'"
        )
        self.worklog.info("***********************************************")
|
| 110 |
+
|
| 111 |
+
    def _init_target(self, args):
        """Load the target image(s); optionally keep a full-resolution copy.

        When downsampling is enabled, the first load is kept as the
        "upsampled" reference and the targets are reloaded at the reduced
        resolution. Ground-truth images are written to the log directory
        during training for visual reference.
        """
        self.gamma = args.gamma
        self.downsample = args.downsample
        if self.downsample:
            self.downsample_ratio = float(args.downsample_ratio)
        self.block_h, self.block_w = (
            16,
            16,
        )  # Warning: Must match hardcoded value in CUDA kernel, modify with caution
        self._load_target_images(path=os.path.join(args.data_root, args.input_path))
        if self.downsample:
            # Preserve the original full-resolution load before reloading
            # at the downsampled resolution.
            self.gt_images_upsampled = self.gt_images
            self.img_h_upsampled, self.img_w_upsampled = self.img_h, self.img_w
            self.tile_bounds_upsampled = self.tile_bounds
            self._load_target_images(
                path=os.path.join(args.data_root, args.input_path),
                downsample_ratio=self.downsample_ratio,
            )
            if not self.evaluate:
                path = f"{self.log_dir}/gt_upsample-{self.downsample_ratio:.1f}_res-{self.img_h_upsampled:d}x{self.img_w_upsampled:d}"
                self._separate_and_save_images(
                    images=self.gt_images_upsampled,
                    channels=self.input_channels,
                    path=path,
                )
        self.num_pixels = self.img_h * self.img_w
        if not self.evaluate:
            path = f"{self.log_dir}/gt_res-{self.img_h:d}x{self.img_w:d}"
            self._separate_and_save_images(
                images=self.gt_images, channels=self.input_channels, path=path
            )
|
| 142 |
+
|
| 143 |
+
    def _load_target_images(self, path, downsample_ratio=None):
        """Load target image(s) onto the device and derive tile geometry.

        Populates ``gt_images`` as a [C, H, W] float tensor, the per-image
        channel counts, filenames, image dimensions, and the gsplat tile
        grid bounds for the 16x16 block size.
        """
        self.gt_images, self.input_channels, self.image_fnames = load_images(
            load_path=path, downsample_ratio=downsample_ratio, gamma=self.gamma
        )
        self.gt_images = torch.from_numpy(self.gt_images).to(
            dtype=self.dtype, device=self.device
        )
        self.img_h, self.img_w = self.gt_images.shape[1:]
        # Number of tiles needed to cover the image (ceil division); the
        # trailing 1 is the depth of the tile grid expected by gsplat.
        self.tile_bounds = (
            (self.img_w + self.block_w - 1) // self.block_w,
            (self.img_h + self.block_h - 1) // self.block_h,
            1,
        )
|
| 156 |
+
|
| 157 |
+
def _separate_and_save_images(self, images, channels, path):
|
| 158 |
+
images_sep = separate_image_channels(images=images, input_channels=channels)
|
| 159 |
+
for idx, image in enumerate(images_sep, 1):
|
| 160 |
+
suffix = "" if len(images_sep) == 1 else f"_{idx:d}"
|
| 161 |
+
save_image(image, f"{path}{suffix}.png", gamma=self.gamma)
|
| 162 |
+
|
| 163 |
+
def _init_bit_precision(self, args):
|
| 164 |
+
self.quantize = args.quantize
|
| 165 |
+
self.pos_bits = args.pos_bits
|
| 166 |
+
self.scale_bits = args.scale_bits
|
| 167 |
+
self.rot_bits = args.rot_bits
|
| 168 |
+
self.feat_bits = args.feat_bits
|
| 169 |
+
|
| 170 |
+
    def _init_gaussians(self, args):
        """Allocate learnable Gaussian attributes and progressive-growth state.

        With progressive optimization, training starts from a fraction
        (``initial_ratio``) of the final Gaussian budget and grows it over
        ``add_times`` additions spaced ``add_steps`` apart; ``max_steps`` is
        raised if it cannot accommodate all additions plus a post-growth
        minimum.
        """
        self.num_gaussians = args.num_gaussians
        self.total_num_gaussians = args.num_gaussians
        self.disable_prog_optim = args.disable_prog_optim
        if not self.disable_prog_optim and not self.evaluate:
            self.initial_ratio = args.initial_ratio
            self.add_times = args.add_times
            self.add_steps = args.add_steps
            self.num_gaussians = math.ceil(
                self.initial_ratio * self.total_num_gaussians
            )
            # Gaussians added per growth event (the last may add fewer).
            self.max_add_num = math.ceil(
                float(self.total_num_gaussians - self.num_gaussians) / self.add_times
            )
            min_steps = self.add_steps * self.add_times + args.post_min_steps
            if args.max_steps < min_steps:
                self.worklog.info(
                    f"Max steps ({args.max_steps:d}) is too small for progressive optimization. Resetting to {min_steps:d}"
                )
                args.max_steps = min_steps
        self.topk = (
            args.topk
        )  # Warning: Must match hardcoded value in CUDA kernel, modify with caution
        self.eps = (
            1e-7 if args.disable_tiles else 1e-4
        )  # Warning: Must match hardcoded value in CUDA kernel, modify with caution
        self.init_scale = args.init_scale
        self.disable_topk_norm = args.disable_topk_norm
        self.disable_inverse_scale = args.disable_inverse_scale
        self.disable_color_init = args.disable_color_init
        # Learnable per-Gaussian attributes: 2-D position (rand in [0, 1)),
        # 2-D scale (ones), 1-D rotation parameter (zeros), and a feature
        # vector sized to the total input channel count.
        self.xy = nn.Parameter(
            torch.rand(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.scale = nn.Parameter(
            torch.ones(self.num_gaussians, 2, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.rot = nn.Parameter(
            torch.zeros(self.num_gaussians, 1, dtype=self.dtype, device=self.device),
            requires_grad=True,
        )
        self.feat_dim = sum(self.input_channels)
        self.feat = nn.Parameter(
            torch.rand(
                self.num_gaussians, self.feat_dim, dtype=self.dtype, device=self.device
            ),
            requires_grad=True,
        )
        self.vis_feat = nn.Parameter(
            torch.rand_like(self.feat), requires_grad=False
        )  # Only used for Gaussian ID visualization
        self._log_compression_rate()
|
| 223 |
+
|
| 224 |
+
    def _log_compression_rate(self):
        """Log raw vs. compressed sizes and the resulting compression rate.

        Uncompressed assumes 8 bits per channel value; compressed counts the
        configured bit widths for position (x2), scale (x2), rotation (x1),
        and per-channel features over the full Gaussian budget.
        """
        bytes_uncompressed = float(self.gt_images.numel())
        bpp_uncompressed = float(8 * self.feat_dim)
        self.worklog.info(
            f"Uncompressed: {bytes_uncompressed / 1e3:.2f} KB | {bpp_uncompressed:.3f} bpp | 8.0 bppc"
        )
        bits_compressed = (
            2 * self.pos_bits
            + 2 * self.scale_bits
            + self.rot_bits
            + self.feat_dim * self.feat_bits
        ) * self.total_num_gaussians
        bytes_compressed = bits_compressed / 8.0
        # bpp = bits per pixel; bppc = bits per pixel per channel.
        bpp_compressed = float(bits_compressed) / self.num_pixels
        bppc_compressed = bpp_compressed / self.feat_dim
        self.num_bytes = bytes_compressed
        self.worklog.info(
            f"Compressed: {bytes_compressed / 1e3:.2f} KB | {bpp_compressed:.3f} bpp | {bppc_compressed:.3f} bppc"
        )
        self.worklog.info(
            f"Compression rate: {bpp_uncompressed / bpp_compressed:.2f}x | {100.0 * bpp_compressed / bpp_uncompressed:.2f}%"
        )
        self.worklog.info("***********************************************")
|
| 247 |
+
|
| 248 |
+
def _init_loss(self, args):
|
| 249 |
+
self.l1_loss = None
|
| 250 |
+
self.l2_loss = None
|
| 251 |
+
self.ssim_loss = None
|
| 252 |
+
self.l1_loss_ratio = args.l1_loss_ratio
|
| 253 |
+
self.l2_loss_ratio = args.l2_loss_ratio
|
| 254 |
+
self.ssim_loss_ratio = args.ssim_loss_ratio
|
| 255 |
+
|
| 256 |
+
    def _init_optimization(self, args):
        """Create the Adam optimizer and learning-rate schedule settings."""
        self.disable_tiles = args.disable_tiles
        self.start_step = 1
        self.max_steps = args.max_steps
        self.pos_lr = args.pos_lr
        self.scale_lr = args.scale_lr
        self.rot_lr = args.rot_lr
        self.feat_lr = args.feat_lr
        # One parameter group per attribute so each gets its own learning
        # rate; vis_feat is excluded (visualization only, never optimized).
        self.optimizer = torch.optim.Adam(
            [
                {"params": self.xy, "lr": self.pos_lr},
                {"params": self.scale, "lr": self.scale_lr},
                {"params": self.rot, "lr": self.rot_lr},
                {"params": self.feat, "lr": self.feat_lr},
            ]
        )
        self.disable_lr_schedule = args.disable_lr_schedule
        if not self.disable_lr_schedule:
            self.decay_ratio = args.decay_ratio
            self.check_decay_steps = args.check_decay_steps
            self.max_decay_times = args.max_decay_times
            self.decay_threshold = args.decay_threshold
|
| 278 |
+
|
| 279 |
+
    def _init_pos_scale_feat(self, args):
        """Initialize Gaussian positions, scales, and features before training.

        Positions are drawn uniformly, or importance-sampled from the image
        gradient magnitude / predicted saliency depending on ``init_mode``.
        Features are optionally initialized from the target-image colors at
        the sampled positions.
        """
        self.init_mode = args.init_mode
        self.init_random_ratio = args.init_random_ratio
        # Pixel-grid coordinates from get_grid, flattened to [H*W, 2].
        self.pixel_xy = (
            get_grid(h=self.img_h, w=self.img_w)
            .to(dtype=self.dtype, device=self.device)
            .reshape(-1, 2)
        )
        with torch.no_grad():
            # Position
            if self.init_mode == "gradient":
                self._compute_gmap()
                self.xy.copy_(self._sample_pos(prob=self.image_gradients))
            elif self.init_mode == "saliency":
                self.smap_filter_size = args.smap_filter_size
                self._compute_smap(path="models")
                self.xy.copy_(self._sample_pos(prob=self.saliency))
            else:
                # Uniform sampling of distinct pixel locations.
                selected = np.random.choice(
                    self.num_pixels, self.num_gaussians, replace=False, p=None
                )
                self.xy.copy_(self.pixel_xy.detach().clone()[selected])
            # Scale
            self.scale.fill_(
                self.init_scale if self.disable_inverse_scale else 1.0 / self.init_scale
            )
            # Feature
            if not self.disable_color_init:
                self.feat.copy_(
                    self._get_target_features(positions=self.xy).detach().clone()
                )
|
| 310 |
+
|
| 311 |
+
def _sample_pos(self, prob):
|
| 312 |
+
num_random = round(self.init_random_ratio * self.num_gaussians)
|
| 313 |
+
selected_random = np.random.choice(
|
| 314 |
+
self.num_pixels, num_random, replace=False, p=None
|
| 315 |
+
)
|
| 316 |
+
selected_other = np.random.choice(
|
| 317 |
+
self.num_pixels, self.num_gaussians - num_random, replace=False, p=prob
|
| 318 |
+
)
|
| 319 |
+
return torch.cat(
|
| 320 |
+
[
|
| 321 |
+
self.pixel_xy.detach().clone()[selected_random],
|
| 322 |
+
self.pixel_xy.detach().clone()[selected_other],
|
| 323 |
+
],
|
| 324 |
+
dim=0,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
    def _compute_gmap(self):
        """Build a position-sampling distribution from image gradients.

        Gradients are taken on gamma-corrected pixel values (x ** (1/gamma)),
        their magnitude is normalized and squared, and the flattened result
        is stored as a probability vector in ``self.image_gradients``. A PNG
        of the normalized map is saved for inspection.
        """
        gy, gx = compute_image_gradients(
            np.power(self.gt_images.detach().cpu().clone().numpy(), 1.0 / self.gamma)
        )
        g_norm = np.hypot(gy, gx).astype(np.float32)
        g_norm = g_norm / g_norm.max()
        save_image(g_norm, f"{self.log_dir}/gmap_res-{self.img_h:d}x{self.img_w:d}.png")
        # Squaring sharpens the distribution toward high-gradient pixels.
        g_norm = np.power(g_norm.reshape(-1), 2.0)
        self.image_gradients = g_norm / g_norm.sum()
        self.worklog.info("Image gradient map successfully saved")
        self.worklog.info("***********************************************")
|
| 338 |
+
|
| 339 |
+
    def _compute_smap(self, path):
        """Build a position-sampling distribution from a saliency map.

        ``path`` is forwarded to ``get_smap`` (presumably the directory of
        the saliency model weights — confirm against utils.saliency_utils).
        The map is saved as a PNG and normalized into ``self.saliency``.
        """
        smap = get_smap(
            torch.pow(self.gt_images.detach().clone(), 1.0 / self.gamma),
            path,
            self.smap_filter_size,
        )
        save_image(smap, f"{self.log_dir}/smap_res-{self.img_h:d}x{self.img_w:d}.png")
        self.saliency = (smap / smap.sum()).reshape(-1)
        self.worklog.info("Saliency map successfully saved")
        self.worklog.info("***********************************************")
|
| 349 |
+
|
| 350 |
+
def _get_target_features(self, positions):
|
| 351 |
+
with torch.no_grad():
|
| 352 |
+
# gt_images [1, C, H, W]; positions [1, 1, P, 2]; top-left [-1, -1]; bottom-right [1, 1]
|
| 353 |
+
target_features = F.grid_sample(
|
| 354 |
+
self.gt_images.unsqueeze(0),
|
| 355 |
+
positions[None, None, ...] * 2.0 - 1.0,
|
| 356 |
+
align_corners=False,
|
| 357 |
+
)
|
| 358 |
+
target_features = target_features[0, :, 0, :].permute(1, 0) # [P, C]
|
| 359 |
+
return target_features
|
| 360 |
+
|
| 361 |
+
    def _load_model(self):
        """Restore model and optimizer state from a checkpoint.

        Uses the explicitly named checkpoint file when given, otherwise the
        latest step reported by ``get_latest_ckpt_step``. Resumes training
        one step past the saved step.

        Raises:
            FileNotFoundError: if no checkpoint exists in the directory.
        """
        if self.ckpt_file != "":
            ckpt_path = os.path.join(self.ckpt_dir, self.ckpt_file)
        else:
            latest_step = get_latest_ckpt_step(self.ckpt_dir)
            if latest_step == -1:
                raise FileNotFoundError(f"No checkpoint found in '{self.ckpt_dir}'")
            ckpt_path = os.path.join(self.ckpt_dir, f"ckpt_step-{latest_step:d}.pt")
        # weights_only=False: checkpoints also carry optimizer state/metrics.
        checkpoint = torch.load(ckpt_path, weights_only=False)
        self.load_state_dict(checkpoint["state_dict"])
        self.optimizer.load_state_dict(checkpoint["optim_state_dict"])
        self.start_step = checkpoint["step"] + 1
        self.worklog.info(f"Checkpoint '{ckpt_path}' successfully loaded")
        self.worklog.info("***********************************************")
|
| 375 |
+
|
| 376 |
+
    def _save_model(self):
        """Evaluate current quality metrics and write a checkpoint.

        When quantization is on, parameters are snapped to their quantized
        values in place first, so the stored state matches what inference
        will reproduce. The checkpoint bundles metrics, timing, model, and
        optimizer state.
        """
        if self.quantize:
            self._quantize()
        psnr, ssim = self._evaluate(log=False, upsample=False)
        # Populates lpips_final / flip_final / msssim_final.
        self._evaluate_extra()
        ckpt_data = {
            "step": self.step,
            "psnr": psnr,
            "ssim": ssim,
            "lpips": self.lpips_final,
            "flip": self.flip_final,
            "msssim": self.msssim_final,
            "bytes": self.num_bytes,
            "time": self.total_time_accum,
            "state_dict": self.state_dict(),
            "optim_state_dict": self.optimizer.state_dict(),
        }
        save_path = f"{self.ckpt_dir}/ckpt_step-{self.step:d}.pt"
        torch.save(ckpt_data, save_path)
        self.worklog.info(f"Checkpoint 'ckpt_step-{self.step:d}.pt' successfully saved")
        self.worklog.info(
            f"PSNR: {psnr:.2f} | SSIM: {ssim:.4f} | LPIPS: {self.lpips_final:.4f} | FLIP: {self.flip_final:.4f} | MS-SSIM: {self.msssim_final:.4f}"
        )
        self.worklog.info("***********************************************")
|
| 400 |
+
|
| 401 |
+
def _quantize(self):
|
| 402 |
+
with torch.no_grad():
|
| 403 |
+
self.xy.copy_(ste_quantize(self.xy, self.pos_bits))
|
| 404 |
+
self.scale.copy_(ste_quantize(self.scale, self.scale_bits))
|
| 405 |
+
self.rot.copy_(ste_quantize(self.rot, self.rot_bits))
|
| 406 |
+
self.feat.copy_(ste_quantize(self.feat, self.feat_bits))
|
| 407 |
+
|
| 408 |
+
    def render(self, render_height=None):
        """Render the fitted Gaussians and save the result as PNG(s).

        ``render_height`` optionally rescales the output height; width is
        scaled to preserve the aspect ratio. Two warm-up passes are run
        first so the reported render time excludes one-time setup cost.
        """
        img_h, img_w = self.img_h, self.img_w
        if render_height is not None:
            img_h, img_w = render_height, round((float(render_height) / img_h) * img_w)
        # Tile grid covering the (possibly resized) output resolution.
        tile_bounds = (
            (img_w + self.block_w - 1) // self.block_w,
            (img_h + self.block_h - 1) // self.block_h,
            1,
        )
        upsample_ratio = float(img_h) / self.img_h
        with torch.no_grad():
            num_prep_runs = 2
            for _ in range(num_prep_runs):
                self.forward(img_h, img_w, tile_bounds, upsample_ratio, benchmark=True)
            images, render_time = self.forward(
                img_h, img_w, tile_bounds, upsample_ratio
            )
            path = f"{self.eval_dir}/render_upsample-{upsample_ratio:.1f}_res-{img_h:d}x{img_w:d}"
            self._separate_and_save_images(
                images=images, channels=self.input_channels, path=path
            )
        self.worklog.info(f"Step: {self.start_step - 1:d} | Time: {render_time:.6f} s")
        self.worklog.info(f"Rendering at resolution ({img_h:d}, {img_w:d}) completed")
        self.worklog.info("***********************************************")
|
| 432 |
+
|
| 433 |
+
    def benchmark_render_time(self, num_reps, render_height=None):
        """Time ``num_reps`` renders; return per-run times (seconds) as float32.

        Resolution/tile handling matches ``render``; two warm-up passes are
        run first and excluded from the returned measurements.
        """
        img_h, img_w = self.img_h, self.img_w
        if render_height is not None:
            img_h, img_w = render_height, round((float(render_height) / img_h) * img_w)
        tile_bounds = (
            (img_w + self.block_w - 1) // self.block_w,
            (img_h + self.block_h - 1) // self.block_h,
            1,
        )
        upsample_ratio = float(img_h) / self.img_h
        with torch.no_grad():
            render_time_all = np.zeros(num_reps, dtype=np.float32)
            num_prep_runs = 2
            for _ in range(num_prep_runs):
                self.forward(img_h, img_w, tile_bounds, upsample_ratio, benchmark=True)
            for rid in range(num_reps):
                render_time = self.forward(
                    img_h, img_w, tile_bounds, upsample_ratio, benchmark=True
                )
                render_time_all[rid] = render_time
        return render_time_all
|
| 454 |
+
|
| 455 |
+
    def forward(self, img_h, img_w, tile_bounds, upsample_ratio=None, benchmark=False):
        """Rasterize the Gaussians into an image of size (img_h, img_w).

        Returns (image [C, H, W], render_time_seconds), or just the render
        time when ``benchmark`` is True. With quantization enabled, the
        attributes are passed through ``ste_quantize`` before projection.
        Timing covers projection + rasterization only, not the quantization
        or the output reshape.
        """
        scale = self._get_scale(upsample_ratio=upsample_ratio)
        xy, rot, feat = self.xy, self.rot, self.feat
        if self.quantize:
            xy, scale, rot, feat = (
                ste_quantize(xy, self.pos_bits),
                ste_quantize(scale, self.scale_bits),
                ste_quantize(rot, self.rot_bits),
                ste_quantize(feat, self.feat_bits),
            )
        begin = perf_counter()
        tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
        xy, radii, conics, num_tiles_hit = tmp
        if not self.disable_tiles:
            # Tiled rasterization path (gsplat CUDA kernel).
            enable_topk_norm = not self.disable_topk_norm
            tmp = (
                xy,
                radii,
                conics,
                num_tiles_hit,
                feat,
                img_h,
                img_w,
                self.block_h,
                self.block_w,
                enable_topk_norm,
            )
            out_image = rasterize_gaussians_sum(*tmp)
        else:
            # Untiled fallback: every Gaussian evaluated against every pixel.
            tmp = xy, conics, feat, img_h, img_w
            out_image = rasterize_gaussians_no_tiles(*tmp)
        render_time = perf_counter() - begin
        if benchmark:
            return render_time
        # Reshape rasterizer output to [1, C, H, W], then drop the batch dim.
        out_image = (
            out_image.view(-1, img_h, img_w, self.feat_dim)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return out_image.squeeze(dim=0), render_time
|
| 495 |
+
|
| 496 |
+
def _get_scale(self, upsample_ratio=None):
|
| 497 |
+
scale = self.scale
|
| 498 |
+
if not self.disable_inverse_scale:
|
| 499 |
+
scale = 1.0 / scale
|
| 500 |
+
if upsample_ratio is not None:
|
| 501 |
+
scale = upsample_ratio * scale
|
| 502 |
+
return scale
|
| 503 |
+
|
| 504 |
+
    def _visualize_gaussian_id(self, img_h, img_w, tile_bounds, upsample_ratio=None):
        """Render with random per-Gaussian colors for identity visualization.

        Same pipeline as ``forward`` minus the timing/benchmark handling,
        but the features are replaced by the fixed random ``vis_feat``
        colors scaled by each Gaussian's feature norm, so individual
        Gaussians stand out in the output. Returns the [C, H, W] image.
        """
        scale = self._get_scale(upsample_ratio=upsample_ratio)
        xy, rot, feat = self.xy, self.rot, self.feat
        if self.quantize:
            xy, scale, rot, feat = (
                ste_quantize(xy, self.pos_bits),
                ste_quantize(scale, self.scale_bits),
                ste_quantize(rot, self.rot_bits),
                ste_quantize(feat, self.feat_bits),
            )
        # Random colors weighted by the magnitude of the real features.
        feat = self.vis_feat * feat.norm(dim=-1, keepdim=True)
        tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
        xy, radii, conics, num_tiles_hit = tmp
        if not self.disable_tiles:
            enable_topk_norm = not self.disable_topk_norm
            tmp = (
                xy,
                radii,
                conics,
                num_tiles_hit,
                feat,
                img_h,
                img_w,
                self.block_h,
                self.block_w,
                enable_topk_norm,
            )
            out_image = rasterize_gaussians_sum(*tmp)
        else:
            tmp = xy, conics, feat, img_h, img_w
            out_image = rasterize_gaussians_no_tiles(*tmp)
        out_image = (
            out_image.view(-1, img_h, img_w, self.feat_dim)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        return out_image.squeeze(dim=0)
|
| 541 |
+
|
| 542 |
+
def optimize(self):
    """Run the full optimization loop over the Gaussian parameters.

    Each step renders the image, computes the weighted loss, and applies an
    Adam update. Periodically it evaluates PSNR/SSIM, decays the learning
    rate on plateau (only once all Gaussians have been added), saves images
    and checkpoints, and progressively adds Gaussians until
    ``total_num_gaussians`` is reached.

    Returns:
        Tuple ``(psnr, ssim)`` of the last logged evaluation.
    """
    # Reset running metrics and timers for a fresh run.
    self.psnr_curr, self.ssim_curr = 0.0, 0.0
    self.best_psnr, self.best_ssim = 0.0, 0.0
    self.decay_times, self.no_improvement_steps = 0, 0
    self.render_time_accum, self.total_time_accum = 0.0, 0.0
    self.lpips_final, self.flip_final, self.msssim_final = 1.0, 1.0, 0.0

    self.step = 0
    with torch.no_grad():
        # Log the initial (pre-training) render.
        self._log_images(log_final=False, plot_gaussians=self.vis_gaussians)
    for step in range(self.start_step, self.max_steps + 1):
        self.step = step
        self.optimizer.zero_grad()
        # Rendering
        images, render_time = self.forward(self.img_h, self.img_w, self.tile_bounds)
        self.render_time_accum += render_time
        # Optimization
        begin = perf_counter()
        self._get_total_loss(images)
        self.total_loss.backward()
        self.optimizer.step()
        # Total time counts backward/step plus the render time measured above.
        self.total_time_accum += perf_counter() - begin + render_time
        # Logging
        terminate = False
        with torch.no_grad():
            if self.step % self.eval_steps == 0:
                self._evaluate(log=True, upsample=False)
                # LR decay only engages once all Gaussians are in place.
                if (
                    not self.disable_lr_schedule
                    and self.num_gaussians == self.total_num_gaussians
                ):
                    terminate = self._lr_schedule()
            if self.step % self.save_image_steps == 0:
                self._log_images(log_final=False, plot_gaussians=self.vis_gaussians)
            # Checkpoints are only saved once the Gaussian budget is full.
            if (
                self.step % self.save_ckpt_steps == 0
                and self.num_gaussians == self.total_num_gaussians
            ):
                self._save_model()
            # Progressive densification: add Gaussians until the budget is met.
            if (
                not self.disable_prog_optim
                and self.step % self.add_steps == 0
                and self.num_gaussians < self.total_num_gaussians
            ):
                self._add_gaussians(
                    self.max_add_num, plot_gaussians=self.vis_gaussians
                )
        if terminate:
            break
    with torch.no_grad():
        # Final render + checkpoint after the loop ends (or terminates early).
        self._log_images(log_final=True, plot_gaussians=self.vis_gaussians)
        self._save_model()
    self.worklog.info("Optimization completed")
    self.worklog.info("***********************************************")
    self.worklog.info(
        f"Mean scale: {self._get_scale().mean().item():.4f} (pixel) | {self.scale.mean().item():.4f} (raw)"
    )
    self.worklog.info("***********************************************")
    return self.psnr_curr, self.ssim_curr
|
| 601 |
+
|
| 602 |
+
def _get_total_loss(self, images):
    """Compute the weighted training loss and cache each term on ``self``.

    Sets ``self.l1_loss``, ``self.l2_loss`` and ``self.ssim_loss`` (``None``
    when the corresponding weight is effectively zero) and accumulates the
    active terms into ``self.total_loss``.

    Args:
        images: Rendered image tensor with the same shape as ``self.gt_images``.
    """
    # Weights at or below this threshold are treated as disabled.
    threshold = 1e-7
    self.l1_loss = (
        self.l1_loss_ratio * F.l1_loss(images, self.gt_images)
        if self.l1_loss_ratio > threshold
        else None
    )
    self.l2_loss = (
        self.l2_loss_ratio * F.mse_loss(images, self.gt_images)
        if self.l2_loss_ratio > threshold
        else None
    )
    self.ssim_loss = (
        self.ssim_loss_ratio
        * (1 - fused_ssim(images.unsqueeze(0), self.gt_images.unsqueeze(0)))
        if self.ssim_loss_ratio > threshold
        else None
    )
    self.total_loss = 0
    for term in (self.l1_loss, self.l2_loss, self.ssim_loss):
        if term is not None:
            self.total_loss = self.total_loss + term
|
| 621 |
+
|
| 622 |
+
def _evaluate(self, log=True, upsample=False):
    """Compute PSNR and SSIM of the current render against the ground truth.

    Both render and ground truth are clamped to [0, 1] and gamma-corrected
    (``x ** (1/gamma)``) before the metrics are computed.

    Args:
        log: When True, update ``self.psnr_curr``/``self.ssim_curr`` and emit
            a progress line (loss terms, accumulated timings, metrics).
        upsample: Evaluate against the upsampled ground truth; logging is
            forcibly disabled in this mode.

    Returns:
        Tuple ``(psnr, ssim)`` as Python floats.
    """
    if upsample:  # Do not log performance metrics for upsampled images
        log = False
    images = torch.pow(
        torch.clamp(self._render_images(upsample=upsample), 0.0, 1.0),
        1.0 / self.gamma,
    )
    gt_images = torch.pow(
        self.gt_images_upsampled if upsample else self.gt_images, 1.0 / self.gamma
    )
    psnr = get_psnr(images, gt_images).item()
    ssim = fused_ssim(images.unsqueeze(0), gt_images.unsqueeze(0)).item()
    if log:
        self.psnr_curr, self.ssim_curr = psnr, ssim
        # Only report the loss terms that are actually enabled (non-None).
        loss_results = f"Loss: {self.total_loss.item():.4f}"
        loss_results += (
            f", L1: {self.l1_loss.item():.4f}" if self.l1_loss is not None else ""
        )
        loss_results += (
            f", L2: {self.l2_loss.item():.4f}" if self.l2_loss is not None else ""
        )
        loss_results += (
            f", SSIM: {self.ssim_loss.item():.4f}"
            if self.ssim_loss is not None
            else ""
        )
        time_results = f"Total: {self.total_time_accum:.2f} s | Render: {self.render_time_accum:.2f} s"
        self.worklog.info(
            f"Step: {self.step:d} | {time_results} | {loss_results} | PSNR: {self.psnr_curr:.2f} | SSIM: {self.ssim_curr:.4f}"
        )
    return psnr, ssim
|
| 653 |
+
|
| 654 |
+
def _evaluate_extra(self):
    """Compute the extra perceptual metrics (MS-SSIM, LPIPS, FLIP).

    Renders at training resolution, clamps/gamma-corrects, and stores the
    results in ``self.msssim_final``, ``self.lpips_final`` and (for images
    with at least 3 channels) ``self.flip_final``. Metric modules are
    constructed on the fly on ``self.device``; this is intended for one-off
    final evaluation, not the inner training loop.
    """
    # [None, ...] adds the batch dimension the metric modules expect.
    images = torch.pow(
        torch.clamp(self._render_images(upsample=False), 0.0, 1.0), 1.0 / self.gamma
    )[None, ...]
    gt_images = torch.pow(self.gt_images, 1.0 / self.gamma)[None, ...]
    msssim_metric = (
        MS_SSIM(data_range=1.0, size_average=True, channel=self.feat_dim)
        .to(device=self.device)
        .eval()
    )
    self.msssim_final = msssim_metric(images, gt_images).item()
    lpips_metric = LPIPS(net="alex").to(device=self.device).eval()
    flip_metric = LDRFLIPLoss().to(device=self.device).eval()
    # LPIPS is fed at most 3 channels; single-channel inputs stay single.
    num_channels = 1 if self.feat_dim < 3 else 3
    self.lpips_final = lpips_metric(
        images[:, :num_channels], gt_images[:, :num_channels]
    ).item()
    # FLIP is defined on RGB only, so skip it for fewer than 3 channels.
    if self.feat_dim >= 3:
        self.flip_final = flip_metric(images[:, :3], gt_images[:, :3]).item()
|
| 673 |
+
|
| 674 |
+
def _log_images(self, log_final=False, plot_gaussians=False):
    """Render and save visualization images for the current step.

    Always saves a training-resolution render (named with current step,
    PSNR and SSIM) into ``self.train_dir``. Optionally also saves a final
    render into ``self.log_dir``, Gaussian scatter/identity visualizations,
    and an upsampled render when training on a downsampled image.

    Args:
        log_final: Additionally save the render into ``self.log_dir``.
        plot_gaussians: Also save Gaussian-overlay and Gaussian-id images.
    """
    images = self._render_images(upsample=False)
    if log_final:
        path = f"{self.log_dir}/render_res-{self.img_h:d}x{self.img_w:d}"
        self._separate_and_save_images(
            images=images, channels=self.input_channels, path=path
        )
    # Per-step snapshot, tagged with the current metrics (computed without
    # touching the logged psnr_curr/ssim_curr state).
    psnr, ssim = self._evaluate(log=False, upsample=False)
    path = f"{self.train_dir}/render_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{self.img_h:d}x{self.img_w:d}"
    self._separate_and_save_images(
        images=images, channels=self.input_channels, path=path
    )
    if plot_gaussians:
        # Scatter plot of the Gaussians themselves.
        path = f"{self.train_dir}/gaussian_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{self.img_h:d}x{self.img_w:d}"
        visualize_gaussians(
            path,
            self.xy,
            self._get_scale(),
            self.rot,
            self.feat,
            self.img_h,
            self.img_w,
            self.input_channels,
            alpha=0.8,
            gamma=self.gamma,
        )
        # Identity-color rendering showing per-Gaussian coverage.
        images = self._visualize_gaussian_id(
            self.img_h, self.img_w, self.tile_bounds
        )
        path = f"{self.train_dir}/gaussian-id_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{self.img_h:d}x{self.img_w:d}"
        self._separate_and_save_images(
            images=images, channels=self.input_channels, path=path
        )
    if self.downsample:
        # When training on a downsampled target, also save the render at the
        # original (upsampled) resolution.
        images = self._render_images(upsample=True)
        psnr, ssim = self._evaluate(log=False, upsample=True)
        img_h, img_w = self.img_h_upsampled, self.img_w_upsampled
        path = f"{self.train_dir}/render_upsample-{self.downsample_ratio:.1f}_step-{self.step:d}_psnr-{psnr:.2f}_ssim-{ssim:.4f}_res-{img_h:d}x{img_w:d}"
        self._separate_and_save_images(
            images=images, channels=self.input_channels, path=path
        )
|
| 715 |
+
|
| 716 |
+
def _render_images(self, upsample=False):
    """Render the current Gaussians and return only the image tensor.

    Args:
        upsample: When True, render at the upsampled resolution (passing the
            configured ``downsample_ratio`` to ``forward``); otherwise render
            at the training resolution.

    Returns:
        The rendered image tensor; the render time returned by ``forward``
        is discarded.
    """
    if not upsample:
        images, _render_time = self.forward(self.img_h, self.img_w, self.tile_bounds)
        return images
    images, _render_time = self.forward(
        self.img_h_upsampled,
        self.img_w_upsampled,
        self.tile_bounds_upsampled,
        upsample_ratio=self.downsample_ratio,
    )
    return images
|
| 727 |
+
|
| 728 |
+
def _lr_schedule(self):
    """Plateau-based learning-rate scheduler.

    Tracks the best PSNR/SSIM seen so far. After ``check_decay_steps``
    evaluation steps without sufficient improvement on *either* metric, the
    learning rate of every optimizer param group is divided by
    ``decay_ratio``. Note the PSNR margin is ``100 * decay_threshold``
    (PSNR lives on a dB scale, SSIM on [0, 1]).

    Returns:
        True when the decay budget (``max_decay_times``) is exhausted and
        optimization should terminate; False otherwise.
    """
    psnr_improved = self.psnr_curr > self.best_psnr + 100 * self.decay_threshold
    ssim_improved = self.ssim_curr > self.best_ssim + self.decay_threshold
    if psnr_improved and ssim_improved:
        # New best on both metrics: record it and reset the plateau counter.
        self.best_psnr = self.psnr_curr
        self.best_ssim = self.ssim_curr
        self.no_improvement_steps = 0
        return False
    self.no_improvement_steps += self.eval_steps
    if self.no_improvement_steps < self.check_decay_steps:
        return False
    # Plateau confirmed: either terminate or decay the learning rate.
    self.no_improvement_steps = 0
    self.decay_times += 1
    if self.decay_times > self.max_decay_times:
        return True
    for group in self.optimizer.param_groups:
        group["lr"] /= self.decay_ratio
    self.worklog.info(f"Learning rate decayed by {self.decay_ratio:.1f}")
    self.worklog.info("***********************************************")
    return False
|
| 749 |
+
|
| 750 |
+
def _add_gaussians(self, add_num, plot_gaussians=False):
    """Add new Gaussians at pixels where the reconstruction error is high.

    Samples ``add_num`` pixel positions (without replacement) with
    probability proportional to the squared per-pixel error between the
    (blurred) ground truth and the current render, appends new Gaussians
    there, and rebuilds the optimizer over the enlarged parameters.

    Args:
        add_num: Requested number of Gaussians; clipped to ``max_add_num``
            and the remaining budget. No-op when nothing remains to add.
        plot_gaussians: When True, save a visualization of old vs. newly
            added Gaussian positions.
    """
    add_num = min(
        add_num, self.max_add_num, self.total_num_gaussians - self.num_gaussians
    )
    if add_num <= 0:
        return
    raw_images = self._render_images(upsample=False)
    images = torch.pow(torch.clamp(raw_images, 0.0, 1.0), 1.0 / self.gamma)
    gt_images = torch.pow(self.gt_images, 1.0 / self.gamma)
    # Blur the ground truth for large images so sampling is driven by
    # structural error rather than single-pixel noise; kernel scales with
    # image size and is forced odd and at least 3.
    kernel_size = round(np.sqrt(self.img_h * self.img_w) // 400)
    if kernel_size >= 1:
        kernel_size = max(3, kernel_size)
        kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        gt_images = gaussian_blur(img=gt_images, kernel_size=kernel_size)
    diff_map = (gt_images - images).detach().clone()
    # Squared mean-absolute error per pixel, flattened -> sampling weights.
    error_map = torch.pow(torch.abs(diff_map).mean(dim=0).reshape(-1), 2.0)
    sample_prob = (error_map / error_map.sum()).cpu().numpy()
    selected = np.random.choice(
        self.num_pixels, add_num, replace=False, p=sample_prob
    )
    # New Gaussians
    new_xy = self.pixel_xy.detach().clone()[selected]
    new_scale = torch.ones(add_num, 2, dtype=self.dtype, device=self.device)
    init_scale = self.init_scale
    # Stored scale is inverse unless disable_inverse_scale (see _get_scale).
    new_scale.fill_(init_scale if self.disable_inverse_scale else 1.0 / init_scale)
    new_rot = torch.zeros(add_num, 1, dtype=self.dtype, device=self.device)
    # Initialize features with the residual so new Gaussians start by
    # correcting the local error.
    new_feat = diff_map.permute(1, 2, 0).reshape(-1, self.feat_dim)[selected]
    new_vis_feat = torch.rand_like(new_feat)
    # Old Gaussians
    old_xy = self.xy.detach().clone()
    old_scale = self.scale.detach().clone()
    old_rot = self.rot.detach().clone()
    old_feat = self.feat.detach().clone()
    old_vis_feat = self.vis_feat.detach().clone()
    # Update trainable parameters
    self.num_gaussians += add_num
    all_xy = torch.cat([old_xy, new_xy], dim=0)
    all_scale = torch.cat([old_scale, new_scale], dim=0)
    all_rot = torch.cat([old_rot, new_rot], dim=0)
    all_feat = torch.cat([old_feat, new_feat], dim=0)
    all_vis_feat = torch.cat([old_vis_feat, new_vis_feat], dim=0)
    self.xy = nn.Parameter(all_xy, requires_grad=True)
    self.scale = nn.Parameter(all_scale, requires_grad=True)
    self.rot = nn.Parameter(all_rot, requires_grad=True)
    self.feat = nn.Parameter(all_feat, requires_grad=True)
    # vis_feat is for visualization only and is never optimized.
    self.vis_feat = nn.Parameter(all_vis_feat, requires_grad=False)
    # Plot Gaussians
    if plot_gaussians:
        path = f"{self.train_dir}/add-gaussian_step-{self.step:d}_num-{self.num_gaussians:d}_res-{self.img_h:d}x{self.img_w:d}"
        every_n = max(1, self.total_num_gaussians // 2000)
        size = (self.img_h * self.img_w) / 1e4
        visualize_added_gaussians(
            path,
            raw_images,
            old_xy,
            new_xy,
            self.input_channels,
            size=size,
            every_n=every_n,
            alpha=0.8,
            gamma=self.gamma,
        )
    # Update optimizer
    # Rebuilding Adam discards its moment estimates for the old Gaussians;
    # this happens at every densification step by design.
    self.optimizer = torch.optim.Adam(
        [
            {"params": self.xy, "lr": self.pos_lr},
            {"params": self.scale, "lr": self.scale_lr},
            {"params": self.rot, "lr": self.rot_lr},
            {"params": self.feat, "lr": self.feat_lr},
        ]
    )
    self.worklog.info(
        f"Step: {self.step:d} | Adding {add_num:d} Gaussians ({self.num_gaussians - add_num:d} -> {self.num_gaussians:d})"
    )
    self.worklog.info("***********************************************")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "image-gs"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Image-GS: content-adaptive image representation with 2D Gaussian splatting (Gradio Space)"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"lpips>=0.1.4",
|
| 9 |
+
"matplotlib>=3.10.6",
|
| 10 |
+
"numpy>=2.2.6",
|
| 11 |
+
"pytorch-msssim>=1.0.0",
|
| 12 |
+
"scikit-image>=0.25.2",
|
| 13 |
+
"scipy>=1.15.3",
|
| 14 |
+
"torch>=2.6.0",
|
| 15 |
+
"torchmetrics>=1.8.2",
|
| 16 |
+
"torchvision>=0.21.0",
|
| 17 |
+
"fused_ssim",
|
| 18 |
+
"pyyaml>=6.0.2",
|
| 19 |
+
"gsplat",
|
| 20 |
+
"gradio>=4.0.0",
|
| 21 |
+
"huggingface_hub>=0.24.0",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
# We use python 3.10 and cu124
|
| 25 |
+
[tool.uv.sources]
|
| 26 |
+
fused_ssim = { git = "https://github.com/rahul-goel/fused-ssim/" }
|
| 27 |
+
torch = [
|
| 28 |
+
{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
|
| 29 |
+
]
|
| 30 |
+
torchvision = [
|
| 31 |
+
{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
[tool.uv.extra-build-dependencies]
|
| 36 |
+
fused-ssim = ["torch", "numpy"]
|
| 37 |
+
|
| 38 |
+
[[tool.uv.index]]
|
| 39 |
+
name = "pytorch-cu124"
|
| 40 |
+
url = "https://download.pytorch.org/whl/cu124"
|
| 41 |
+
explicit = true
|
| 42 |
+
|
| 43 |
+
[dependency-groups]
|
| 44 |
+
dev = [
|
| 45 |
+
"huggingface-hub[cli]>=0.34.4",
|
| 46 |
+
]
|
utils/__init__.py
ADDED
|
File without changes
|
utils/flip.py
ADDED
|
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FLIP metric functions"""
|
| 2 |
+
#################################################################################
|
| 3 |
+
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Redistribution and use in source and binary forms, with or without
|
| 6 |
+
# modification, are permitted provided that the following conditions are met:
|
| 7 |
+
#
|
| 8 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
# list of conditions and the following disclaimer.
|
| 10 |
+
#
|
| 11 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
# and/or other materials provided with the distribution.
|
| 14 |
+
#
|
| 15 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
# contributors may be used to endorse or promote products derived from
|
| 17 |
+
# this software without specific prior written permission.
|
| 18 |
+
#
|
| 19 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
#
|
| 30 |
+
# SPDX-FileCopyrightText: Copyright (c) 2020-2024 NVIDIA CORPORATION & AFFILIATES
|
| 31 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 32 |
+
#################################################################################
|
| 33 |
+
|
| 34 |
+
# Visualizing and Communicating Errors in Rendered Images
|
| 35 |
+
# Ray Tracing Gems II, 2021,
|
| 36 |
+
# by Pontus Andersson, Jim Nilsson, and Tomas Akenine-Moller.
|
| 37 |
+
# Pointer to the chapter: https://research.nvidia.com/publication/2021-08_Visualizing-and-Communicating.
|
| 38 |
+
|
| 39 |
+
# Visualizing Errors in Rendered High Dynamic Range Images
|
| 40 |
+
# Eurographics 2021,
|
| 41 |
+
# by Pontus Andersson, Jim Nilsson, Peter Shirley, and Tomas Akenine-Moller.
|
| 42 |
+
# Pointer to the paper: https://research.nvidia.com/publication/2021-05_HDR-FLIP.
|
| 43 |
+
|
| 44 |
+
# FLIP: A Difference Evaluator for Alternating Images
|
| 45 |
+
# High Performance Graphics 2020,
|
| 46 |
+
# by Pontus Andersson, Jim Nilsson, Tomas Akenine-Moller,
|
| 47 |
+
# Magnus Oskarsson, Kalle Astrom, and Mark D. Fairchild.
|
| 48 |
+
# Pointer to the paper: https://research.nvidia.com/publication/2020-07_FLIP.
|
| 49 |
+
|
| 50 |
+
# Code by Pontus Ebelin (formerly Andersson), Jim Nilsson, and Tomas Akenine-Moller.
|
| 51 |
+
|
| 52 |
+
import sys
|
| 53 |
+
import numpy as np
|
| 54 |
+
import torch
|
| 55 |
+
import torch.nn as nn
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class HDRFLIPLoss(nn.Module):
    """Class for computing HDR-FLIP.

    HDR-FLIP (Andersson et al., "Visualizing Errors in Rendered High Dynamic
    Range Images", Eurographics 2021) tone maps the HDR pair at a sweep of
    exposures, computes LDR-FLIP per exposure, and takes the per-pixel
    maximum over exposures before averaging into a scalar loss.
    """

    def __init__(self):
        """Initialize the fixed FLIP pipeline constants (see the FLIP papers)."""
        super().__init__()
        self.qc = 0.7  # q_c exponent of the color pipeline
        self.qf = 0.5  # q_f exponent of the feature pipeline
        self.pc = 0.4  # p_c exponent of the color pipeline
        self.pt = 0.95  # p_t exponent of the color pipeline
        self.tmax = 0.85  # percentile target used to derive the stop exposure
        self.tmin = 0.85  # percentile target used to derive the start exposure
        self.eps = 1e-15  # small value used to improve training stability

    def forward(
        self,
        test,
        reference,
        pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180,
        tone_mapper="aces",
        start_exposure=None,
        stop_exposure=None,
    ):
        """
        Computes the HDR-FLIP error map between two HDR images,
        assuming the images are observed at a certain number of
        pixels per degree of visual angle

        :param test: test tensor (with NxCxHxW layout with nonnegative values)
        :param reference: reference tensor (with NxCxHxW layout with nonnegative values)
        :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
                                  default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
        :param tone_mapper: (optional) string describing what tone mapper HDR-FLIP should assume
        :param start_exposure: (optional) tensor (with Nx1x1x1 layout) with start exposures corresponding to each HDR reference/test pair
        :param stop_exposure: (optional) tensor (with Nx1x1x1 layout) with stop exposures corresponding to each HDR reference/test pair
        :return: float containing the mean FLIP error (in the range [0,1]) between the HDR reference and test images in the batch
        """
        # HDR-FLIP expects nonnegative and non-NaN values in the input
        reference = torch.clamp(reference, 0, 65536.0)
        test = torch.clamp(test, 0, 65536.0)

        # Fix: follow the input tensor's device instead of hard-coding
        # .cuda(), so the loss also works on CPU-only hosts and on CUDA
        # devices other than cuda:0.
        device = reference.device

        # Compute start and stop exposures, if they are not given
        if start_exposure is None or stop_exposure is None:
            c_start, c_stop = compute_start_stop_exposures(
                reference, tone_mapper, self.tmax, self.tmin
            )
            if start_exposure is None:
                start_exposure = c_start
            if stop_exposure is None:
                stop_exposure = c_stop

        # Compute number of exposures (always at least two)
        num_exposures = torch.max(
            torch.tensor([2.0], device=device),
            torch.ceil(stop_exposure - start_exposure),
        )
        most_exposures = int(torch.amax(num_exposures, dim=0).item())

        # Compute exposure step size
        step_size = (stop_exposure - start_exposure) / torch.max(
            num_exposures - 1, torch.tensor([1.0], device=device)
        )

        # Set the depth of the error tensor to the number of exposures given by the largest exposure range any reference image yielded.
        # This allows us to do one loop for each image in our batch, while not affecting the HDR-FLIP error, as we fill up the error tensor with 0s.
        # Note that the step size still depends on num_exposures and is therefore independent of most_exposures
        dim = reference.size()
        all_errors = torch.zeros(
            size=(dim[0], most_exposures, dim[2], dim[3]), device=device
        )

        # Loop over exposures and compute LDR-FLIP for each pair of LDR reference and test
        for i in range(0, most_exposures):
            exposure = start_exposure + i * step_size

            reference_tone_mapped = tone_map(reference, tone_mapper, exposure)
            test_tone_mapped = tone_map(test, tone_mapper, exposure)

            reference_opponent = color_space_transform(
                reference_tone_mapped, "linrgb2ycxcz"
            )
            test_opponent = color_space_transform(test_tone_mapped, "linrgb2ycxcz")

            all_errors[:, i, :, :] = compute_ldrflip(
                test_opponent,
                reference_opponent,
                pixels_per_degree,
                self.qc,
                self.qf,
                self.pc,
                self.pt,
                self.eps,
            ).squeeze(1)

        # Take per-pixel maximum over all LDR-FLIP errors to get HDR-FLIP
        hdrflip_error = torch.amax(all_errors, dim=1, keepdim=True)
        return torch.mean(hdrflip_error)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class LDRFLIPLoss(nn.Module):
    """LDR-FLIP perceptual difference metric (Andersson et al., HPG 2020).

    Converts both images into the YCxCz opponent color space, evaluates the
    per-pixel FLIP error map, and reduces it to its mean.
    """

    def __init__(self):
        """Set up the fixed FLIP pipeline exponents (see the FLIP paper)."""
        super().__init__()
        self.qc = 0.7
        self.qf = 0.5
        self.pc = 0.4
        self.pt = 0.95
        self.eps = 1e-15

    def forward(
        self, test, reference, pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180
    ):
        """
        Computes the LDR-FLIP error map between two LDR images,
        assuming the images are observed at a certain number of
        pixels per degree of visual angle

        :param test: test tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
        :param reference: reference tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
        :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
                                  default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
        :return: float containing the mean FLIP error (in the range [0,1]) between the LDR reference and test images in the batch
        """
        # LDR-FLIP expects non-NaN values in [0,1] as input
        clamped_reference = torch.clamp(reference, 0, 1)
        clamped_test = torch.clamp(test, 0, 1)

        # Transform reference and test to opponent color space
        ycxcz_reference = color_space_transform(clamped_reference, "srgb2ycxcz")
        ycxcz_test = color_space_transform(clamped_test, "srgb2ycxcz")

        error_map = compute_ldrflip(
            ycxcz_test,
            ycxcz_reference,
            pixels_per_degree,
            self.qc,
            self.qf,
            self.pc,
            self.pt,
            self.eps,
        )
        return torch.mean(error_map)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def compute_ldrflip(test, reference, pixels_per_degree, qc, qf, pc, pt, eps):
    """
    Computes the LDR-FLIP error map between two LDR images,
    assuming the images are observed at a certain number of
    pixels per degree of visual angle

    :param reference: reference tensor (with NxCxHxW layout with values in the YCxCz color space)
    :param test: test tensor (with NxCxHxW layout with values in the YCxCz color space)
    :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
                              default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
    :param qc: float describing the q_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
    :param qf: float describing the q_f exponent in the LDR-FLIP feature pipeline (see FLIP paper for details)
    :param pc: float describing the p_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
    :param pt: float describing the p_t exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
    :param eps: float containing a small value used to improve training stability
    :return: tensor containing the per-pixel FLIP errors (with Nx1xHxW layout and values in the range [0, 1]) between LDR reference and test images
    """
    # --- Color pipeline ---
    # Spatial filtering: one contrast-sensitivity kernel per opponent channel,
    # all padded to the largest of the three radii.
    s_a, radius_a = generate_spatial_filter(pixels_per_degree, "A")
    s_rg, radius_rg = generate_spatial_filter(pixels_per_degree, "RG")
    s_by, radius_by = generate_spatial_filter(pixels_per_degree, "BY")
    radius = max(radius_a, radius_rg, radius_by)
    filtered_reference = spatial_filter(reference, s_a, s_rg, s_by, radius)
    filtered_test = spatial_filter(test, s_a, s_rg, s_by, radius)

    # Perceptually Uniform Color Space
    preprocessed_reference = hunt_adjustment(
        color_space_transform(filtered_reference, "linrgb2lab")
    )
    preprocessed_test = hunt_adjustment(
        color_space_transform(filtered_test, "linrgb2lab")
    )

    # Color metric
    deltaE_hyab = hyab(preprocessed_reference, preprocessed_test, eps)
    power_deltaE_hyab = torch.pow(deltaE_hyab, qc)
    # cmax is the largest possible HyAB distance, realized by the pure green /
    # pure blue pair; it is used to normalize the color error below.
    hunt_adjusted_green = hunt_adjustment(
        color_space_transform(
            torch.tensor([[[0.0]], [[1.0]], [[0.0]]]).unsqueeze(0), "linrgb2lab"
        )
    )
    hunt_adjusted_blue = hunt_adjustment(
        color_space_transform(
            torch.tensor([[[0.0]], [[0.0]], [[1.0]]]).unsqueeze(0), "linrgb2lab"
        )
    )
    cmax = torch.pow(hyab(hunt_adjusted_green, hunt_adjusted_blue, eps), qc).item()
    deltaE_c = redistribute_errors(power_deltaE_hyab, cmax, pc, pt)

    # --- Feature pipeline ---
    # Extract and normalize Yy component (inverse of the "116 * Y - 16" map
    # used by the YCxCz transform).
    ref_y = (reference[:, 0:1, :, :] + 16) / 116
    test_y = (test[:, 0:1, :, :] + 16) / 116

    # Edge and point detection
    edges_reference = feature_detection(ref_y, pixels_per_degree, "edge")
    points_reference = feature_detection(ref_y, pixels_per_degree, "point")
    edges_test = feature_detection(test_y, pixels_per_degree, "edge")
    points_test = feature_detection(test_y, pixels_per_degree, "point")

    # Feature metric: the larger of the edge-magnitude and point-magnitude
    # differences between reference and test.
    deltaE_f = torch.max(
        torch.abs(
            torch.norm(edges_reference, dim=1, keepdim=True)
            - torch.norm(edges_test, dim=1, keepdim=True)
        ),
        torch.abs(
            torch.norm(points_test, dim=1, keepdim=True)
            - torch.norm(points_reference, dim=1, keepdim=True)
        ),
    )
    deltaE_f = torch.clamp(deltaE_f, min=eps)  # clamp to stabilize training
    deltaE_f = torch.pow(((1 / np.sqrt(2)) * deltaE_f), qf)

    # --- Final error ---
    # Feature errors amplify the color error by lowering its exponent.
    return torch.pow(deltaE_c, 1 - deltaE_f)
def tone_map(img, tone_mapper, exposure):
    """
    Applies exposure compensation and tone mapping.
    Refer to the Visualizing Errors in Rendered High Dynamic Range Images
    paper for details about the formulas.

    :param img: float tensor (with NxCxHxW layout) containing nonnegative values
    :param tone_mapper: string describing the tone mapper to apply
                        ("reinhard", "hable", or anything else for the ACES approximation)
    :param exposure: float tensor (with Nx1x1x1 layout) describing the exposure compensation factor
    :return: float tensor (with NxCxHxW layout) with tone-mapped values clamped to [0, 1]
    """
    # Create helper constants on the input's device instead of hard-coding
    # CUDA, so the function also works with CPU tensors.
    device = img.device

    # Exposure compensation
    x = (2**exposure) * img

    # Set tone mapping coefficients depending on tone_mapper
    if tone_mapper == "reinhard":
        # Rec. 709 luma coefficients
        lum_coeff_r = 0.2126
        lum_coeff_g = 0.7152
        lum_coeff_b = 0.0722

        Y = (
            x[:, 0:1, :, :] * lum_coeff_r
            + x[:, 1:2, :, :] * lum_coeff_g
            + x[:, 2:3, :, :] * lum_coeff_b
        )
        return torch.clamp(torch.div(x, 1 + Y), 0.0, 1.0)

    if tone_mapper == "hable":
        # Source: https://64.github.io/tonemapping/
        A = 0.15
        B = 0.50
        C = 0.10
        D = 0.20
        E = 0.02
        F = 0.30
        k0 = A * F - A * E
        k1 = C * B * F - B * E
        k2 = 0
        k3 = A * F
        k4 = B * F
        k5 = D * F * F

        W = 11.2
        nom = k0 * torch.pow(W, torch.tensor([2.0], device=device)) + k1 * W + k2
        denom = k3 * torch.pow(W, torch.tensor([2.0], device=device)) + k4 * W + k5
        white_scale = torch.div(denom, nom)  # = 1 / (nom / denom)

        # Include white scale and exposure bias in rational polynomial coefficients
        k0 = 4 * k0 * white_scale
        k1 = 2 * k1 * white_scale
        k2 = k2 * white_scale
        k3 = 4 * k3
        k4 = 2 * k4
        # k5 = k5  # k5 is not changed
    else:
        # Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
        # Include pre-exposure cancelation in constants
        k0 = 0.6 * 0.6 * 2.51
        k1 = 0.6 * 0.03
        k2 = 0
        k3 = 0.6 * 0.6 * 2.43
        k4 = 0.6 * 0.59
        k5 = 0.14

    # Shared rational-polynomial evaluation for the hable and ACES mappers.
    x2 = torch.pow(x, 2)
    nom = k0 * x2 + k1 * x + k2
    denom = k3 * x2 + k4 * x + k5
    # If denom is inf, then so is nom => nan. Pixel is very bright. It becomes
    # inf here, but 1 after the clamp below.
    denom = torch.where(torch.isinf(denom), torch.tensor([1.0], device=device), denom)
    y = torch.div(nom, denom)
    return torch.clamp(y, 0.0, 1.0)
def compute_start_stop_exposures(reference, tone_mapper, tmax, tmin):
    """
    Computes start and stop exposure for HDR-FLIP based on given tone mapper and reference image.
    Refer to the Visualizing Errors in Rendered High Dynamic Range Images
    paper for details about the formulas.

    :param reference: float tensor (with NxCxHxW layout) containing reference images (nonnegative values)
    :param tone_mapper: string describing which tone mapper should be assumed
                        ("reinhard", "hable", or anything else for the ACES approximation)
    :param tmax: float describing the t value used to find the start exposure
    :param tmin: float describing the t value used to find the stop exposure
    :return: two float tensors (with Nx1x1x1 layout) containing start and stop exposures, respectively, to use for HDR-FLIP
    """
    # Create helper constants on the input's device instead of hard-coding
    # CUDA, so the function also works with CPU tensors.
    device = reference.device

    if tone_mapper == "reinhard":
        k0 = 0
        k1 = 1
        k2 = 0
        k3 = 0
        k4 = 1
        k5 = 1

        # Reinhard inverts analytically: t = x / (x + 1).
        x_max = tmax * k5 / (k1 - tmax * k4)
        x_min = tmin * k5 / (k1 - tmin * k4)
    elif tone_mapper == "hable":
        # Source: https://64.github.io/tonemapping/
        A = 0.15
        B = 0.50
        C = 0.10
        D = 0.20
        E = 0.02
        F = 0.30
        k0 = A * F - A * E
        k1 = C * B * F - B * E
        k2 = 0
        k3 = A * F
        k4 = B * F
        k5 = D * F * F

        W = 11.2
        nom = k0 * torch.pow(W, torch.tensor([2.0], device=device)) + k1 * W + k2
        denom = k3 * torch.pow(W, torch.tensor([2.0], device=device)) + k4 * W + k5
        white_scale = torch.div(denom, nom)  # = 1 / (nom / denom)

        # Include white scale and exposure bias in rational polynomial coefficients
        k0 = 4 * k0 * white_scale
        k1 = 2 * k1 * white_scale
        k2 = k2 * white_scale
        k3 = 4 * k3
        k4 = 2 * k4
        # k5 = k5  # k5 is not changed

        # Invert the rational polynomial via the quadratic formula.
        c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
        c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
        x_max = -0.5 * c0 + torch.sqrt(
            ((torch.tensor([0.5], device=device) * c0) ** 2) - c1
        )

        c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
        c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
        x_min = -0.5 * c0 + torch.sqrt(
            ((torch.tensor([0.5], device=device) * c0) ** 2) - c1
        )
    else:
        # Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
        # Include pre-exposure cancelation in constants
        k0 = 0.6 * 0.6 * 2.51
        k1 = 0.6 * 0.03
        k2 = 0
        k3 = 0.6 * 0.6 * 2.43
        k4 = 0.6 * 0.59
        k5 = 0.14

        # Invert the rational polynomial via the quadratic formula.
        c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
        c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
        x_max = -0.5 * c0 + torch.sqrt(
            ((torch.tensor([0.5], device=device) * c0) ** 2) - c1
        )

        c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
        c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
        x_min = -0.5 * c0 + torch.sqrt(
            ((torch.tensor([0.5], device=device) * c0) ** 2) - c1
        )

    # Convert reference to luminance (Rec. 709 coefficients)
    lum_coeff_r = 0.2126
    lum_coeff_g = 0.7152
    lum_coeff_b = 0.0722
    Y_reference = (
        reference[:, 0:1, :, :] * lum_coeff_r
        + reference[:, 1:2, :, :] * lum_coeff_g
        + reference[:, 2:3, :, :] * lum_coeff_b
    )

    # Compute start exposure from the per-image maximum luminance
    Y_hi = torch.amax(Y_reference, dim=(2, 3), keepdim=True)
    start_exposure = torch.log2(x_max / Y_hi)

    # Compute stop exposure from the per-image median luminance
    dim = Y_reference.size()
    Y_ref = Y_reference.view(dim[0], dim[1], dim[2] * dim[3])
    Y_lo = torch.median(Y_ref, dim=2).values.unsqueeze(2).unsqueeze(3)
    stop_exposure = torch.log2(x_min / Y_lo)

    return start_exposure, stop_exposure
def generate_spatial_filter(pixels_per_degree, channel, device="cuda"):
    """
    Generates spatial contrast sensitivity filters with width depending on
    the number of pixels per degree of visual angle of the observer

    :param pixels_per_degree: float indicating number of pixels per degree of visual angle
    :param channel: string describing what filter should be generated ("A", "RG", or "BY")
    :param device: torch device for the returned kernel; defaults to "cuda"
        to preserve the original behavior, pass "cpu" for CPU execution
    :return: 1x1xKxK filter kernel corresponding to the spatial contrast
        sensitivity function of the given channel, and the kernel's integer radius
    :raises ValueError: if ``channel`` is not one of "A", "RG", "BY"
    """
    # CSF fit parameters (a1, b1, a2, b2) for each opponent channel.
    a1_A = 1
    b1_A = 0.0047
    a2_A = 0
    b2_A = 1e-5  # avoid division by 0
    a1_rg = 1
    b1_rg = 0.0053
    a2_rg = 0
    b2_rg = 1e-5  # avoid division by 0
    a1_by = 34.1
    b1_by = 0.04
    a2_by = 13.5
    b2_by = 0.025
    if channel == "A":  # Achromatic CSF
        a1, b1, a2, b2 = a1_A, b1_A, a2_A, b2_A
    elif channel == "RG":  # Red-Green CSF
        a1, b1, a2, b2 = a1_rg, b1_rg, a2_rg, b2_rg
    elif channel == "BY":  # Blue-Yellow CSF
        a1, b1, a2, b2 = a1_by, b1_by, a2_by, b2_by
    else:
        # Previously an unknown channel fell through and raised a confusing
        # NameError below; fail fast with a clear message instead.
        raise ValueError(
            f"Unknown channel '{channel}'; expected 'A', 'RG' or 'BY'"
        )

    # Determine evaluation domain: same radius for all channels, driven by
    # the widest (largest scale parameter) filter.
    max_scale_parameter = max([b1_A, b2_A, b1_rg, b2_rg, b1_by, b2_by])
    r = np.ceil(3 * np.sqrt(max_scale_parameter / (2 * np.pi**2)) * pixels_per_degree)
    r = int(r)
    deltaX = 1.0 / pixels_per_degree
    x, y = np.meshgrid(range(-r, r + 1), range(-r, r + 1))
    z = (x * deltaX) ** 2 + (y * deltaX) ** 2

    # Generate weights: sum of two Gaussians per the CSF fit, normalized to
    # sum to one so filtering preserves the mean.
    g = a1 * np.sqrt(np.pi / b1) * np.exp(-(np.pi**2) * z / b1) + a2 * np.sqrt(
        np.pi / b2
    ) * np.exp(-(np.pi**2) * z / b2)
    g = g / np.sum(g)
    g = torch.tensor(g, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)

    return g, r
def spatial_filter(img, s_a, s_rg, s_by, radius):
    """
    Filters an image with channel specific spatial contrast sensitivity functions
    and clips result to the unit cube in linear RGB

    :param img: image tensor to filter (with NxCxHxW layout in the YCxCz color space)
    :param s_a: spatial filter matrix for the achromatic channel
    :param s_rg: spatial filter matrix for the red-green channel
    :param s_by: spatial filter matrix for the blue-yellow channel
    :param radius: integer padding radius (the largest of the three kernel radii)
    :return: input image (with NxCxHxW layout) transformed to linear RGB after filtering with spatial contrast sensitivity functions
    """
    dim = img.size()
    # Allocate working buffers on the input's device (was hard-coded to
    # CUDA, which broke CPU execution).
    device = img.device

    # Prepare image for convolution: replicate-pad each channel by `radius`
    # so the "valid" convolution below preserves the spatial size.
    img_pad = torch.zeros(
        (dim[0], dim[1], dim[2] + 2 * radius, dim[3] + 2 * radius), device=device
    )
    pad = (radius, radius, radius, radius)
    for c in range(3):
        img_pad[:, c : c + 1, :, :] = nn.functional.pad(
            img[:, c : c + 1, :, :], pad, mode="replicate"
        )

    # Apply the channel-specific Gaussian CSF filters.
    img_tilde_opponent = torch.zeros((dim[0], dim[1], dim[2], dim[3]), device=device)
    for c, kernel in enumerate((s_a, s_rg, s_by)):
        img_tilde_opponent[:, c : c + 1, :, :] = nn.functional.conv2d(
            img_pad[:, c : c + 1, :, :], kernel.to(device), padding=0
        )

    # Transform to linear RGB for clamp
    img_tilde_linear_rgb = color_space_transform(img_tilde_opponent, "ycxcz2linrgb")

    # Clamp to RGB box
    return torch.clamp(img_tilde_linear_rgb, 0.0, 1.0)
def hunt_adjustment(img):
    """
    Applies Hunt-adjustment to an image

    :param img: image tensor to adjust (with NxCxHxW layout in the L*a*b* color space)
    :return: Hunt-adjusted image tensor (with NxCxHxW layout in the Hunt-adjusted L*A*B* color space)
    """
    # Extract luminance component
    L = img[:, 0:1, :, :]

    # Allocate the output on the input's device (was hard-coded to CUDA,
    # which broke CPU execution).
    img_h = torch.zeros(img.size(), device=img.device)
    img_h[:, 0:1, :, :] = L
    # Scale the chromatic channels by 1% of the luminance (the Hunt effect:
    # perceived colorfulness grows with luminance).
    img_h[:, 1:2, :, :] = torch.mul((0.01 * L), img[:, 1:2, :, :])
    img_h[:, 2:3, :, :] = torch.mul((0.01 * L), img[:, 2:3, :, :])

    return img_h
def hyab(reference, test, eps):
    """
    Computes the HyAB distance between reference and test images

    :param reference: reference image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*A*B* color space)
    :param test: test image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*a*b* color space)
    :param eps: float containing a small value used to improve training stability
    :return: image tensor (with Nx1xHxW layout) containing the per-pixel HyAB distances between reference and test images
    """
    diff = reference - test
    # Lightness term: |dL| computed as sqrt(dL^2) with a clamp instead of a
    # plain abs, which keeps gradients stable during training.
    lightness_term = torch.sqrt(torch.clamp(torch.pow(diff[:, 0:1, :, :], 2), min=eps))
    # Chromatic term: Euclidean norm over the a/b channels.
    chroma_term = torch.norm(diff[:, 1:3, :, :], dim=1, keepdim=True)
    return lightness_term + chroma_term
def redistribute_errors(power_deltaE_hyab, cmax, pc, pt):
    """
    Redistributes exponentiated HyAB errors to the [0,1] range

    :param power_deltaE_hyab: float tensor (with Nx1xHxW layout) containing the exponentiated HyAb distance
    :param cmax: float containing the exponentiated, maximum HyAB difference between two colors in Hunt-adjusted L*A*B* space
    :param pc: float containing the cmax multiplier p_c (see FLIP paper)
    :param pt: float containing the target value, p_t, for p_c * cmax (see FLIP paper)
    :return: image tensor (with Nx1xHxW layout) containing redistributed per-pixel HyAB distances (in range [0,1])
    """
    # Re-map error to 0-1 range. Values between 0 and
    # pccmax are mapped to the range [0, pt],
    # while the rest are mapped to the range (pt, 1].
    # NOTE: the previous version pre-allocated a CUDA zeros tensor here that
    # was immediately overwritten by torch.where — dead code that also broke
    # CPU execution; torch.where allocates the result itself.
    pccmax = pc * cmax
    deltaE_c = torch.where(
        power_deltaE_hyab < pccmax,
        (pt / pccmax) * power_deltaE_hyab,
        pt + ((power_deltaE_hyab - pccmax) / (cmax - pccmax)) * (1.0 - pt),
    )

    return deltaE_c
def feature_detection(img_y, pixels_per_degree, feature_type):
    """
    Detects edges and points (features) in the achromatic image

    :param img_y: achromatic image tensor (with Nx1xHxW layout, containing normalized Y-values from YCxCz)
    :param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer
    :param feature_type: string indicating the type of feature to detect
        ("edge" for the edge detector, anything else selects the point detector)
    :return: image tensor (with Nx2xHxW layout, with values in range [0,1]) containing large values where features were detected
    """
    # Set peak to trough value (2x standard deviations) of human edge
    # detection filter
    w = 0.082

    # Compute filter radius
    sd = 0.5 * w * pixels_per_degree
    radius = int(np.ceil(3 * sd))

    # Compute 2D Gaussian
    [x, y] = np.meshgrid(range(-radius, radius + 1), range(-radius, radius + 1))
    g = np.exp(-(x**2 + y**2) / (2 * sd * sd))

    if feature_type == "edge":  # Edge detector
        # Compute partial derivative in x-direction
        Gx = np.multiply(-x, g)
    else:  # Point detector
        # Compute second partial derivative in x-direction
        Gx = np.multiply(x**2 / (sd * sd) - 1, g)

    # Normalize positive weights to sum to 1 and negative weights to sum to -1
    negative_weights_sum = -np.sum(Gx[Gx < 0])
    positive_weights_sum = np.sum(Gx[Gx > 0])
    Gx = torch.Tensor(Gx)
    Gx = torch.where(Gx < 0, Gx / negative_weights_sum, Gx / positive_weights_sum)
    # Move the kernel to the input's device (was hard-coded to CUDA, which
    # broke CPU execution).
    Gx = Gx.unsqueeze(0).unsqueeze(0).to(img_y.device)

    # Detect features: convolve with the kernel for the x-response and its
    # transpose for the y-response, replicate-padding to preserve size.
    pad = (radius, radius, radius, radius)
    featuresX = nn.functional.conv2d(
        nn.functional.pad(img_y, pad, mode="replicate"),
        Gx,
        padding=0,
    )
    featuresY = nn.functional.conv2d(
        nn.functional.pad(img_y, pad, mode="replicate"),
        torch.transpose(Gx, 2, 3),
        padding=0,
    )
    return torch.cat((featuresX, featuresY), dim=1)
def color_space_transform(input_color, fromSpace2toSpace):
    """
    Transforms inputs between different color spaces

    :param input_color: tensor of colors to transform (with NxCxHxW layout)
    :param fromSpace2toSpace: string describing transform
    :return: transformed tensor (with NxCxHxW layout)
    """
    dim = input_color.size()
    # All constants are created on the input's device (previously they were
    # hard-coded to CUDA, which broke CPU execution).
    device = input_color.device

    # Assume D65 standard illuminant
    reference_illuminant = torch.tensor(
        [[[0.950428545]], [[1.000000000]], [[1.088900371]]], device=device
    )
    inv_reference_illuminant = torch.tensor(
        [[[1.052156925]], [[1.000000000]], [[0.918357670]]], device=device
    )

    if fromSpace2toSpace == "srgb2linrgb":
        limit = 0.04045
        transformed_color = torch.where(
            input_color > limit,
            torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
            input_color / 12.92,
        )  # clamp to stabilize training

    elif fromSpace2toSpace == "linrgb2srgb":
        limit = 0.0031308
        transformed_color = torch.where(
            input_color > limit,
            1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
            12.92 * input_color,
        )

    elif fromSpace2toSpace in ["linrgb2xyz", "xyz2linrgb"]:
        # Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
        # Assumes D65 standard illuminant
        if fromSpace2toSpace == "linrgb2xyz":
            a11 = 10135552 / 24577794
            a12 = 8788810 / 24577794
            a13 = 4435075 / 24577794
            a21 = 2613072 / 12288897
            a22 = 8788810 / 12288897
            a23 = 887015 / 12288897
            a31 = 1425312 / 73733382
            a32 = 8788810 / 73733382
            a33 = 70074185 / 73733382
        else:
            # Constants found by taking the inverse of the matrix
            # defined by the constants for linrgb2xyz
            a11 = 3.241003275
            a12 = -1.537398934
            a13 = -0.498615861
            a21 = -0.969224334
            a22 = 1.875930071
            a23 = 0.041554224
            a31 = 0.055639423
            a32 = -0.204011202
            a33 = 1.057148933
        A = torch.tensor(
            [[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]], device=device
        )

        # Flatten the spatial dimensions so the transform is one matmul: NC(HW)
        input_color = input_color.view(dim[0], dim[1], dim[2] * dim[3])

        transformed_color = torch.matmul(A, input_color)
        transformed_color = transformed_color.view(dim[0], dim[1], dim[2], dim[3])

    elif fromSpace2toSpace == "xyz2ycxcz":
        input_color = torch.mul(input_color, inv_reference_illuminant)
        y = 116 * input_color[:, 1:2, :, :] - 16
        cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
        transformed_color = torch.cat((y, cx, cz), 1)

    elif fromSpace2toSpace == "ycxcz2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        cx = input_color[:, 1:2, :, :] / 500
        cz = input_color[:, 2:3, :, :] / 200

        x = y + cx
        z = y - cz
        transformed_color = torch.cat((x, y, z), 1)

        transformed_color = torch.mul(transformed_color, reference_illuminant)

    elif fromSpace2toSpace == "xyz2lab":
        input_color = torch.mul(input_color, inv_reference_illuminant)
        delta = 6 / 29
        delta_square = delta * delta
        delta_cube = delta * delta_square
        factor = 1 / (3 * delta_square)

        clamped_term = torch.pow(
            torch.clamp(input_color, min=delta_cube), 1.0 / 3.0
        ).to(dtype=input_color.dtype)
        div = (factor * input_color + (4 / 29)).to(dtype=input_color.dtype)
        input_color = torch.where(
            input_color > delta_cube, clamped_term, div
        )  # clamp to stabilize training

        L = 116 * input_color[:, 1:2, :, :] - 16
        a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])

        transformed_color = torch.cat((L, a, b), 1)

    elif fromSpace2toSpace == "lab2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        a = input_color[:, 1:2, :, :] / 500
        b = input_color[:, 2:3, :, :] / 200

        x = y + a
        z = y - b

        xyz = torch.cat((x, y, z), 1)
        delta = 6 / 29
        delta_square = delta * delta
        factor = 3 * delta_square
        xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))

        transformed_color = torch.mul(xyz, reference_illuminant)

    # Composite transforms are expressed recursively via the primitives above.
    elif fromSpace2toSpace == "srgb2xyz":
        transformed_color = color_space_transform(input_color, "srgb2linrgb")
        transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
    elif fromSpace2toSpace == "srgb2ycxcz":
        transformed_color = color_space_transform(input_color, "srgb2linrgb")
        transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
        transformed_color = color_space_transform(transformed_color, "xyz2ycxcz")
    elif fromSpace2toSpace == "linrgb2ycxcz":
        transformed_color = color_space_transform(input_color, "linrgb2xyz")
        transformed_color = color_space_transform(transformed_color, "xyz2ycxcz")
    elif fromSpace2toSpace == "srgb2lab":
        transformed_color = color_space_transform(input_color, "srgb2linrgb")
        transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
        transformed_color = color_space_transform(transformed_color, "xyz2lab")
    elif fromSpace2toSpace == "linrgb2lab":
        transformed_color = color_space_transform(input_color, "linrgb2xyz")
        transformed_color = color_space_transform(transformed_color, "xyz2lab")
    elif fromSpace2toSpace == "ycxcz2linrgb":
        transformed_color = color_space_transform(input_color, "ycxcz2xyz")
        transformed_color = color_space_transform(transformed_color, "xyz2linrgb")
    elif fromSpace2toSpace == "lab2srgb":
        transformed_color = color_space_transform(input_color, "lab2xyz")
        transformed_color = color_space_transform(transformed_color, "xyz2linrgb")
        transformed_color = color_space_transform(transformed_color, "linrgb2srgb")
    elif fromSpace2toSpace == "ycxcz2lab":
        transformed_color = color_space_transform(input_color, "ycxcz2xyz")
        transformed_color = color_space_transform(transformed_color, "xyz2lab")
    else:
        sys.exit("Error: The color transform %s is not defined!" % fromSpace2toSpace)

    return transformed_color
utils/image_utils.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import matplotlib
|
| 4 |
+
import matplotlib.font_manager as font_manager
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from matplotlib.patches import Ellipse
|
| 9 |
+
from numpy.linalg import norm
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from scipy.ndimage import sobel
|
| 12 |
+
|
| 13 |
+
# Preferred plotting font (Linux Libertine); path is relative to the repo root.
FONT_PATH = "assets/fonts/linux_libertine/LinLibertine_R.ttf"

# Make font loading optional for deployment environments
try:
    font_manager.fontManager.addfont(FONT_PATH)
    FONT_PROP = font_manager.FontProperties(fname=FONT_PATH).get_name()
    plt.rcParams["font.family"] = FONT_PROP
    plt.rcParams["text.usetex"] = True
except (FileNotFoundError, OSError):
    # Use default font if custom font is not available
    FONT_PROP = "DejaVu Sans"
    plt.rcParams["font.family"] = FONT_PROP
    plt.rcParams["text.usetex"] = False  # Disable LaTeX if custom font unavailable

# Shared matplotlib sizing applied to every figure produced by this module.
matplotlib.rcParams["font.size"] = 16
matplotlib.rcParams["axes.titlesize"] = 16
matplotlib.rcParams["figure.titlesize"] = 16
matplotlib.rcParams["legend.fontsize"] = 16
matplotlib.rcParams["legend.title_fontsize"] = 16
matplotlib.rcParams["xtick.labelsize"] = 14
matplotlib.rcParams["ytick.labelsize"] = 14

# File extensions accepted when loading images (compared lower-cased).
ALLOWED_IMAGE_FILE_FORMATS = [".jpeg", ".jpg", ".png"]
# PIL image modes accepted; values look like channel counts, but only the
# membership check is visible here — confirm against load_images' full body.
ALLOWED_IMAGE_TYPES = {"RGB": 3, "RGBA": 3, "L": 1}

# Plot rendering settings; usage is outside this view — presumably consumed
# by the figure/Gaussian-visualization helpers of this module.
PLOT_DPI = 72.0
GAUSSIAN_ZOOM = 5
GAUSSIAN_COLOR = "#80ed99"
| 42 |
+
def get_psnr(image1, image2, max_value=1.0):
    """Return the PSNR (in dB) between two images for the given peak value.

    Returns ``float("inf")`` when the images are numerically identical.
    """
    squared_error = (image1 - image2) ** 2
    mse = torch.mean(squared_error)
    # Treat a vanishing MSE as a perfect reconstruction.
    if mse.item() <= 1e-7:
        return float("inf")
    return 20 * torch.log10(max_value / torch.sqrt(mse))
|
| 50 |
+
def get_grid(h, w, x_lim=np.asarray([0, 1]), y_lim=np.asarray([0, 1])):
    """Return an (h, w, 2) tensor of pixel-center (x, y) coordinates.

    Each grid cell maps to the center of its pixel within the ranges given by
    `x_lim` / `y_lim`.

    Fix: the half-pixel offset was previously hard-coded as 0.5 / n, which is
    correct only for unit-span limits; it now scales with the actual span so
    non-default limits are centered correctly. Behavior with the default
    [0, 1] limits is unchanged.
    """
    half_x = 0.5 * (x_lim[1] - x_lim[0]) / w
    half_y = 0.5 * (y_lim[1] - y_lim[0]) / h
    x = torch.linspace(x_lim[0], x_lim[1], steps=w + 1)[:-1] + half_x
    y = torch.linspace(y_lim[0], y_lim[1], steps=h + 1)[:-1] + half_y
    # indexing="xy" yields (h, w)-shaped coordinate planes.
    grid_x, grid_y = torch.meshgrid(x, y, indexing="xy")
    return torch.stack([grid_x, grid_y], dim=-1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def compute_image_gradients(image):
    """Per-pixel Sobel gradient magnitudes of a multi-channel image.

    Args:
        image: array-like of shape (C, H, W).

    Returns:
        (gy, gx): float32 arrays of shape (H, W) holding the L2 norm over
        channels of the vertical and horizontal Sobel responses.
    """
    grads_y = np.stack([sobel(channel, 0) for channel in image], axis=0)
    grads_x = np.stack([sobel(channel, 1) for channel in image], axis=0)
    gy = norm(grads_y, ord=2, axis=0).astype(np.float32)
    gx = norm(grads_x, ord=2, axis=0).astype(np.float32)
    return gy, gx
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_images(load_path, downsample_ratio=None, gamma=None):
    """
    Load target images or textures from a directory or a single file.

    Args:
        load_path: Path to one image file or to a directory; files are
            filtered by ALLOWED_IMAGE_FILE_FORMATS (case-insensitive).
        downsample_ratio: If given, each image is resized by this factor
            (bilinear) before conversion.
        gamma: If given, pixel values are raised to this power after the
            [0, 1] normalization.

    Returns:
        Tuple of (images, num_channels_list, image_fname_list): `images` is a
        float32 array of shape (sum of per-image channels, H, W) with all
        images concatenated along the channel axis, `num_channels_list` holds
        the channel count kept per image, and `image_fname_list` the
        extension-less base names.

    Raises:
        FileNotFoundError: If no supported image is found at `load_path`.
        TypeError: If an image's PIL mode is not in ALLOWED_IMAGE_TYPES.
    """
    image_list = []
    image_path_list = []
    image_fname_list = []
    num_channels_list = []
    if (
        os.path.isfile(load_path)
        and os.path.splitext(load_path)[1].lower() in ALLOWED_IMAGE_FILE_FORMATS
    ):
        image_path_list.append(load_path)
    elif os.path.isdir(load_path):
        # Deterministic, case-insensitive ordering of directory contents.
        for file in sorted(os.listdir(load_path), key=str.lower):
            if os.path.splitext(file)[1].lower() in ALLOWED_IMAGE_FILE_FORMATS:
                image_path_list.append(os.path.join(load_path, file))
    if len(image_path_list) == 0:
        raise FileNotFoundError(f"No supported image file found at '{load_path}'")
    for image_path in image_path_list:
        image_fname_list.append(os.path.splitext(os.path.basename(image_path))[0])
        image = Image.open(image_path)
        # Warning: Only support images of type L, RGB, or RGBA in JPEG or PNG format
        if image.mode not in ALLOWED_IMAGE_TYPES:
            raise TypeError(
                f"Only support images of type {list(ALLOWED_IMAGE_TYPES.keys())} in JPEG or PNG format"
            )
        num_channels = ALLOWED_IMAGE_TYPES[image.mode]
        num_channels_list.append(num_channels)
        if downsample_ratio is not None:
            image = image.resize(
                (
                    round(image.width / downsample_ratio),
                    round(image.height / downsample_ratio),
                ),
                resample=Image.Resampling.BILINEAR,
            )
        # Warning: Assume 8 bit color depth
        image = np.asarray(image, dtype=np.float32) / 255.0
        if gamma is not None:
            image = np.power(image, gamma)
        # Grayscale images get an explicit channel axis so all images are HWC.
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=2)
        # HWC -> CHW.
        image = image.transpose(2, 0, 1)
        # Drop extra channels (e.g. the alpha channel of RGBA inputs).
        image = image[:num_channels]
        image_list.append(image)
    return np.concatenate(image_list, axis=0), num_channels_list, image_fname_list
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def to_output_format(image, gamma):
    """Convert a float image in [0, 1] to a uint8 array ready for saving.

    Accepts a torch tensor or numpy array, either 2-D (H, W) or 3-D in
    channel-first or channel-last layout with 1 or 3 channels. A singleton
    channel axis is collapsed; the inverse gamma curve is applied when
    `gamma` is given.
    """
    ndim = len(image.shape)
    if ndim not in [2, 3]:
        raise ValueError(f"Wrong image format: shape = {image.shape}")
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().clone().numpy()
    if ndim == 3:
        # Move channel-first data to channel-last, then re-validate.
        if image.shape[2] not in [1, 3]:
            image = image.transpose(1, 2, 0)
            if image.shape[2] not in [1, 3]:
                raise ValueError(f"Wrong image format: shape = {image.shape}")
        # Collapse a singleton channel axis to a plain grayscale image.
        if image.shape[2] == 1:
            image = image.squeeze(axis=2)
    image = np.clip(image, 0.0, 1.0)
    if gamma is not None:
        image = np.power(image, 1.0 / gamma)
    return (255.0 * image).astype(np.uint8)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def save_image(image, save_path, gamma=None, zoom=None):
    """Write `image` to `save_path`, optionally gamma-correcting and up-scaling.

    A positive `zoom` rescales with box resampling, which keeps block
    structure crisp when magnifying small renderings.
    """
    pil_image = Image.fromarray(to_output_format(image, gamma))
    if zoom is not None and zoom > 0.0:
        new_size = (round(pil_image.width * zoom), round(pil_image.height * zoom))
        pil_image = pil_image.resize(new_size, resample=Image.Resampling.BOX)
    pil_image.save(save_path)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def separate_image_channels(images, input_channels):
    """Split a channel-stacked array back into one chunk per source image.

    `images` must have sum(input_channels) leading entries; chunk i receives
    the next input_channels[i] of them, in order.
    """
    total = sum(input_channels)
    if len(images) != total:
        raise ValueError(
            f"Incompatible number of channels: {len(images):d} vs {total:d}"
        )
    # Prefix sums give the slice boundaries for each image's channels.
    offsets = [0]
    for num_channels in input_channels:
        offsets.append(offsets[-1] + num_channels)
    return [images[lo:hi] for lo, hi in zip(offsets[:-1], offsets[1:])]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def visualize_gaussians(
    filepath, xy, scale, rot, feat, img_h, img_w, input_channels, alpha=0.8, gamma=None
):
    """
    Visualize Gaussians as colored elliptical disks.

    Renders one figure per target image (one per entry of `input_channels`)
    and saves it as '<filepath>[_<i>].png' at the image's pixel resolution.

    Args:
        filepath: Output path prefix (extension appended here).
        xy: (N, 2) tensor of Gaussian centers in normalized [0, 1] coords.
        scale: (N, 2) tensor of Gaussian extents; magnified by GAUSSIAN_ZOOM
            for visibility.
        rot: (N, 1) tensor of rotations in radians.
        feat: (N, sum(input_channels)) tensor of per-Gaussian colors.
        img_h, img_w: Target image size in pixels.
        input_channels: Channels per target image; `feat` is split accordingly.
        alpha: Disk opacity.
        gamma: If given, colors are gamma-corrected (pow 1/gamma) first.

    Raises:
        ValueError: If feat's channel count does not match input_channels.
    """
    if feat.shape[1] != sum(input_channels):
        raise ValueError(
            f"Incompatible number of channels: {feat.shape[1]:d} vs {sum(input_channels):d}"
        )
    xy = xy.detach().cpu().clone().numpy()
    # Normalized coordinates -> pixel coordinates.
    y, x = xy[:, 1] * img_h, xy[:, 0] * img_w
    scale = GAUSSIAN_ZOOM * scale.detach().cpu().clone().numpy()
    rot = rot.detach().cpu().clone().numpy()
    if gamma is not None:
        feat = torch.pow(feat, 1.0 / gamma)
    feat = np.clip(feat.detach().cpu().clone().numpy(), 0.0, 1.0)

    curr_channel = 0
    for image_id, num_channels in enumerate(input_channels, 1):
        curr_feat = feat[:, curr_channel : curr_channel + num_channels]
        fig = plt.figure()
        # Match figure pixel size to the target image resolution exactly.
        fig.set_dpi(PLOT_DPI)
        fig.set_size_inches(w=img_w / PLOT_DPI, h=img_h / PLOT_DPI, forward=False)
        ax = plt.gca()
        for gid in range(len(xy)):
            ellipse = Ellipse(
                xy=(x[gid], y[gid]),
                width=scale[gid, 0],
                height=scale[gid, 1],
                angle=rot[gid, 0] * 180 / np.pi,  # radians -> degrees
                alpha=alpha,
                ec=None,
                fc=curr_feat[gid],
                lw=None,
            )
            ax.add_patch(ellipse)
        plt.xlim(0, img_w)
        # Inverted y-axis: origin at the top-left (image convention).
        plt.ylim(img_h, 0)
        plt.axis("off")
        plt.tight_layout()
        # Single image: no suffix; multiple images: 1-based index suffix.
        suffix = "" if len(input_channels) == 1 else f"_{image_id:d}"
        plt.savefig(
            f"{filepath}{suffix}.png", bbox_inches="tight", pad_inches=0, dpi=PLOT_DPI
        )
        plt.close()
        curr_channel += num_channels
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def visualize_added_gaussians(
    filepath,
    images,
    old_xy,
    new_xy,
    input_channels,
    size=500,
    every_n=5,
    alpha=0.8,
    gamma=None,
):
    """
    Visualize the positions of added Gaussians during error-guided progressive optimization.

    Draws pre-existing Gaussian centers in red and newly added ones in green
    on top of each target image, saving '<filepath>[_<i>].png' per image.
    Only every `every_n`-th Gaussian is drawn to limit clutter.

    Args:
        filepath: Output path prefix (extension appended here).
        images: (sum(input_channels), H, W) stacked target images.
        old_xy: (N, 2) tensor of existing centers in normalized coords.
        new_xy: (M, 2) tensor of newly added centers in normalized coords.
        input_channels: Channels per target image; `images` is split accordingly.
        size: Marker area for the scatter points.
        every_n: Subsampling stride over Gaussian centers.
        alpha: Marker opacity.
        gamma: Forwarded to to_output_format for display gamma correction.

    Raises:
        ValueError: If images' channel count does not match input_channels.
    """
    if len(images) != sum(input_channels):
        raise ValueError(
            f"Incompatible number of channels: {len(images):d} vs {sum(input_channels):d}"
        )
    image_height, image_width = images.shape[1:]
    # Subsample the centers and convert normalized coords to pixels.
    old_xy = old_xy.detach().cpu().clone().numpy()[::every_n]
    new_xy = new_xy.detach().cpu().clone().numpy()[::every_n]
    old_x, old_y = old_xy[:, 0] * image_width, old_xy[:, 1] * image_height
    new_x, new_y = new_xy[:, 0] * image_width, new_xy[:, 1] * image_height

    curr_channel = 0
    for image_id, num_channels in enumerate(input_channels, 1):
        image = images[curr_channel : curr_channel + num_channels]
        image = to_output_format(image, gamma)
        fig = plt.figure()
        # Match figure pixel size to the image resolution exactly.
        fig.set_dpi(PLOT_DPI)
        fig.set_size_inches(
            w=image_width / PLOT_DPI, h=image_height / PLOT_DPI, forward=False
        )
        plt.imshow(Image.fromarray(image), cmap="gray", vmin=0, vmax=255)
        plt.scatter(old_x, old_y, s=size, c="#ef476f", marker="o", alpha=alpha)  # red
        plt.scatter(new_x, new_y, s=size, c="#06d6a0", marker="o", alpha=alpha)  # green
        plt.xlim(0, image_width)
        # Inverted y-axis: origin at the top-left (image convention).
        plt.ylim(image_height, 0)
        plt.axis("off")
        plt.tight_layout()
        # Single image: no suffix; multiple images: 1-based index suffix.
        suffix = "" if len(input_channels) == 1 else f"_{image_id:d}"
        plt.savefig(
            f"{filepath}{suffix}.png", bbox_inches="tight", pad_inches=0, dpi=PLOT_DPI
        )
        plt.close()
        curr_channel += num_channels
|
utils/misc_utils.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import shutil
|
| 4 |
+
from argparse import ArgumentParser
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def clean_dir(path):
    """Remove the directory tree at `path` if it exists (no-op when absent)."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_latest_ckpt_step(load_path):
    """Return the largest step among '*-<step>.pt' files in `load_path`, or -1."""
    steps = []
    for fname in os.listdir(load_path):
        if fname.endswith(".pt"):
            # File names follow '<prefix>-<step>.pt'; parse the step suffix.
            stem = os.path.splitext(fname)[0]
            steps.append(int(stem.split("-")[-1]))
    return max(steps) if steps else -1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def set_random_seed(seed):
    """Seed Python, NumPy, and all torch RNGs; force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA seeding is deferred/no-op on CPU-only builds but needed for GPUs.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_cfg(cfg_path: str, parser: ArgumentParser) -> ArgumentParser:
    """Register every key of a YAML config file as a CLI option on `parser`.

    Boolean values become `store_true` flags; every other value's option type
    is inferred from its default. `None` values are rejected because no type
    can be inferred from them.
    """
    with open(cfg_path, "r", encoding="utf-8") as file:
        cfg: dict = yaml.safe_load(file)
    for key, value in cfg.items():
        if value is None:
            raise ValueError("'None' is not a supported value in the config file")
        flag = f"--{key}"
        if isinstance(value, bool):
            # NOTE(review): with default=True a store_true flag can never be
            # switched off from the CLI — confirm this is intended.
            parser.add_argument(flag, action="store_true", default=value)
        else:
            parser.add_argument(flag, type=type(value), default=value)
    return parser
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def save_cfg(path: str, args, mode="w"):
    """Dump the parsed arguments as YAML to `path`, preceded by a banner line."""
    with open(path, mode=mode, encoding="utf-8") as file:
        file.write("#################### Training Config ####################\n")
        yaml.dump(vars(args), file, default_flow_style=False)
|
utils/quantization_utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def ste_quantize(x: torch.Tensor, num_bits: int = 16) -> torch.Tensor:
    """
    Bit precision control of Gaussian parameters using a straight-through estimator.
    Reference: https://arxiv.org/abs/1308.3432
    """
    levels = 2**num_bits - 1
    lo, hi = x.min().item(), x.max().item()
    # Guard against a zero range (constant tensor) with a tiny floor.
    step = max((hi - lo) / levels, 1e-8)
    # Forward: snap to the nearest quantization level (non-differentiable).
    quantized = torch.round((x - lo) / step).clamp(0, levels)
    dequantized = quantized * step + lo
    # Backward: identity gradient w.r.t. x (straight-through estimator).
    return x + (dequantized - x).detach()
|
utils/saliency/decoder.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Decoder(nn.Module):
    """Fuses per-scale feature maps from two backbones into one saliency map.

    Each incoming single-channel feature map passes through its own
    conv/BN/ReLU head, is resized to `shape`, and all heads are concatenated
    and reduced by a final conv + sigmoid to one channel in [0, 1].
    Attribute names (img_model, pla_model, combined) are part of the
    checkpoint state_dict and must not be renamed.
    """

    def __init__(self, shape, num_img_feat, num_pla_feat):
        super(Decoder, self).__init__()
        self.shape = shape  # target (H, W) every branch is resized to
        self.img_model = self._make_layer(num_img_feat)
        self.pla_model = self._make_layer(num_pla_feat)
        self.combined = self._make_output(num_img_feat + num_pla_feat)
        # He-style initialization for convs; identity-like start for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, num_feat):
        # One independent single-channel refinement head per feature map.
        heads = nn.ModuleList()
        for _ in range(num_feat):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(1, 1, 3, padding=1),
                    nn.BatchNorm2d(1),
                    nn.ReLU(inplace=True),
                )
            )
        return heads

    def _make_output(self, planes, readout=1):
        # Final fusion head producing `readout` sigmoid channels.
        return nn.Sequential(
            nn.Conv2d(planes, readout, 3, stride=1, padding=1),
            nn.BatchNorm2d(readout),
            nn.Sigmoid(),
        )

    def forward(self, x):
        img_feat, pla_feat = x
        # Refine and resize every branch, image features first (order matters
        # for the channel layout expected by `combined`).
        resized = [
            F.interpolate(head(fmap), self.shape)
            for fmap, head in list(zip(img_feat, self.img_model))
            + list(zip(pla_feat, self.pla_model))
        ]
        return self.combined(torch.cat(resized, dim=1))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_decoder(model_path, *args):
    """Construct a Decoder and restore its weights from `model_path`.

    The checkpoint is expected to be a dict with a 'state_dict' entry;
    `*args` are forwarded to the Decoder constructor.
    """
    decoder = Decoder(*args)
    checkpoint = torch.load(model_path, weights_only=True)
    decoder.load_state_dict(checkpoint["state_dict"])
    return decoder
|
utils/saliency/resnet.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias (BatchNorm always follows)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BasicBlock(nn.Module):
    """Two 3x3 conv residual block (ResNet-18/34 style).

    Attribute names are part of the checkpoint state_dict and are preserved.
    """

    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection for the identity branch when the spatial size
        # or channel count changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + shortcut)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet-50+ style).

    Attribute names are part of the checkpoint state_dict and are preserved.
    """

    expansion = 4  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection for the identity branch when shape changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ResNet(nn.Module):
    """ResNet backbone with multi-scale single-channel saliency readouts.

    Besides the standard stem + 4 stages, each stage's output is reduced to
    one channel by a small readout head. `forward` either returns the raw
    per-stage readouts (decode=True, consumed by the external Decoder) or
    fuses them into a single sigmoid saliency map at the input resolution.
    """

    def __init__(self, block, layers):
        # Running channel count consumed/updated by _make_layer.
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.out_channels = 1
        # Per-scale readouts; the input widths (256/512/1024/2048) assume a
        # Bottleneck block (expansion 4), as used by resnet50 below.
        self.output0 = self._make_output(64, readout=self.out_channels)
        self.output1 = self._make_output(256, readout=self.out_channels)
        self.output2 = self._make_output(512, readout=self.out_channels)
        self.output3 = self._make_output(1024, readout=self.out_channels)
        self.output4 = self._make_output(2048, readout=self.out_channels)
        # Fuses the 5 stacked readouts into one sigmoid saliency channel.
        self.combined = self._make_output(5, sigmoid=True)
        # He-style init for convs; identity-like start for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_output(self, planes, readout=1, sigmoid=False):
        # Conv + BN readout head; sigmoid only for the final fused map.
        layers = [
            nn.Conv2d(planes, readout, kernel_size=3, padding=1),
            nn.BatchNorm2d(readout),
        ]
        if sigmoid:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        # Standard ResNet stage: the first block may downsample/project, the
        # remaining blocks keep shape.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x, decode=False):
        h, w = x.size(2), x.size(3)
        x = self.conv1(x)
        x = self.bn1(x)
        out0 = self.relu(x)
        x = self.maxpool(out0)
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        # Reduce every stage's output to a single channel.
        out0 = self.output0(out0)
        # Reference resolution for fusing all scales (stem output size).
        r, c = out0.size(2), out0.size(3)
        out1 = self.output1(out1)
        out2 = self.output2(out2)
        out3 = self.output3(out3)
        out4 = self.output4(out4)
        if decode:
            # Raw multi-scale readouts for the external Decoder.
            return [out0, out1, out2, out3, out4]
        out1 = F.interpolate(out1, (r, c))
        out2 = F.interpolate(out2, (r, c))
        out3 = F.interpolate(out3, (r, c))
        out4 = F.interpolate(out4, (r, c))
        x = torch.cat([out0, out1, out2, out3, out4], dim=1)
        x = self.combined(x)
        # Upsample the fused map back to the input resolution.
        x = F.interpolate(x, (h, w))
        return x
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def resnet50(model_path, **kwargs):
    """Build a ResNet-50 backbone, optionally loading weights from `model_path`.

    The checkpoint may be a raw state dict or wrapped in {'state_dict': ...}.
    Keys are first tried with their leading 7 characters stripped (the
    'module.' prefix added by DataParallel); if none match, raw keys are
    used. Keys absent from the checkpoint keep their random initialization.

    Fix: the model was needlessly constructed a second time inside the load
    branch, discarding the first instance; it is now built exactly once.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if model_path is not None:
        model_state = model.state_dict()
        loaded_model = torch.load(model_path, weights_only=True)
        if "state_dict" in loaded_model:
            loaded_model = loaded_model["state_dict"]
        # Try stripping the 'module.' prefix first, then fall back to raw keys.
        pretrained = {k[7:]: v for k, v in loaded_model.items() if k[7:] in model_state}
        if len(pretrained) == 0:
            pretrained = {k: v for k, v in loaded_model.items() if k in model_state}
        model_state.update(pretrained)
        model.load_state_dict(model_state)
    return model
|
utils/saliency_utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from skimage import filters
|
| 4 |
+
from torchvision.transforms.functional import resize
|
| 5 |
+
|
| 6 |
+
from utils.saliency import decoder, resnet
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_smap(image, path, filter_size=15):
    """
    Compute the saliency map of the target image using EMLNet.
    Reference: https://arxiv.org/abs/1805.01047
    Reference: https://github.com/SenJia/EML-NET-Saliency

    Args:
        image: (3, H, W) RGB tensor; assumed to already live on the CUDA
            device — TODO confirm with callers.
        path: Directory containing the 'emlnet' checkpoints
            (res_imagenet.pth, res_places.pth, res_decoder.pth).
        filter_size: Sigma of the Gaussian blur applied in post-processing.

    Returns:
        A float32 (H, W) numpy array min-max normalized to [0, 1].

    Raises:
        ValueError: If the input is not a 3-channel image.

    Note: requires a CUDA device; the three models are re-loaded on every call.
    """
    if image.shape[0] != 3:
        raise ValueError("Saliency prediction only supports RGB images")
    # EMLNet models operate at a fixed 480x640 resolution.
    sod_res = (480, 640)
    imagenet_model = resnet.resnet50(f"{path}/emlnet/res_imagenet.pth").cuda().eval()
    places_model = resnet.resnet50(f"{path}/emlnet/res_places.pth").cuda().eval()
    decoder_model = (
        decoder.build_decoder(f"{path}/emlnet/res_decoder.pth", sod_res, 5, 5)
        .cuda()
        .eval()
    )
    image_sod = resize(image, sod_res).unsqueeze(0)
    with torch.no_grad():
        # Multi-scale readouts from both backbones, fused by the decoder.
        imagenet_feat = imagenet_model(image_sod, decode=True)
        places_feat = places_model(image_sod, decode=True)
        smap = decoder_model([imagenet_feat, places_feat])
    # Back to the original image resolution, dropping batch/channel dims.
    smap = resize(smap.squeeze(0).detach().cpu(), image.shape[1:]).squeeze(0)

    def post_process(smap):
        # Smooth, then min-max normalize to [0, 1].
        smap = filters.gaussian(smap, filter_size)
        smap -= smap.min()
        smap /= smap.max()
        return smap

    return post_process(smap.numpy()).astype(np.float32)
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|