RunyuZhu commited on
Commit
07c970e
·
verified ·
1 Parent(s): 7842505

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +182 -181
  2. naka_color_correction.py +1096 -0
  3. phototransduction.py +240 -0
  4. requirements.txt +5 -0
README.md CHANGED
@@ -1,269 +1,270 @@
1
- # NAKA-GS
2
- This pipeline was built on [VGGT](https://github.com/facebookresearch/vggt) and [gsplat](https://github.com/nerfstudio-project/gsplat); thanks for their excellent work.
3
 
4
- The paper is available at https://arxiv.org/abs/2604.11142 (PDF: https://arxiv.org/pdf/2604.11142).
5
 
6
- NAKA-GS is an end-to-end pipeline for low-light 3D scene reconstruction and novel-view synthesis:
7
 
8
- 1. `Naka` enhances low-light training images.
9
- 2. `VGGT` reconstructs sparse cameras and geometry from the enhanced images.
10
- 3. `gsplat` performs Gaussian Splatting training, with optional `PPM` dense-point preprocessing.
11
 
12
- The qualitative results (visual comparison on RealX3D) can be found in the `asset` folder.
13
 
14
- ## 1. What The Pipeline Expects
 
 
15
 
16
- Each scene directory should look like this before the first run:
17
 
18
- ```text
19
- data/
20
- └── Scene1/
21
-     ├── train/                   # low-light training images
22
-     ├── transforms_train.json    # training camera poses
23
-     ├── transforms_test.json     # render trajectory / test poses
24
-     └── test/                    # optional GT test images for metrics
25
- ```
26
 
27
- After the pipeline runs, it will automatically create:
 
 
 
 
28
 
29
- ```text
30
- data/
31
- └── Scene/
32
-     ├── images/                  # Naka-enhanced images
33
-     ├── sparse/                  # VGGT reconstruction outputs
34
-     │   ├── cameras.bin
35
-     │   ├── images.bin
36
-     │   ├── points3D.bin
37
-     │   └── points.ply
38
-     └── gsplat_results/          # rendering results, stats, checkpoints
39
- ```
40
 
41
- Notes:
42
 
43
- - `images/`, `sparse/`, and `gsplat_results/` do not need to exist before the first run.
44
- - `sparse/points.ply` is produced by the VGGT stage and then reused by the PPM stage.
45
- - If a scene does not contain ground-truth test images, the pipeline still renders novel views but skips reference-image metrics.
46
 
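The layout requirements above can be checked before the first run with a small helper script. This is an illustrative sketch, not part of the repository; the function name `validate_scene_dir` is hypothetical:

```python
import os

def validate_scene_dir(scene_dir: str) -> list:
    """Report deviations from the scene layout the pipeline expects.

    Returns a list of human-readable problems; an empty list means the
    scene directory looks ready for a first run.
    """
    problems = []
    if not os.path.isdir(os.path.join(scene_dir, "train")):
        problems.append("missing required train/ directory")
    for name in ("transforms_train.json", "transforms_test.json"):
        if not os.path.isfile(os.path.join(scene_dir, name)):
            problems.append(f"missing {name}")
    if not os.path.isdir(os.path.join(scene_dir, "test")):
        # test/ is optional: without it, reference-image metrics are skipped.
        problems.append("no test/ directory: reference-image metrics will be skipped")
    return problems
```

Note that `images/`, `sparse/`, and `gsplat_results/` are intentionally not checked: the pipeline creates them.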
47
- ## 2. System Requirements
48
 
49
- - Linux
50
- - NVIDIA GPU
51
- - CUDA-compatible PyTorch environment
52
- - A working CUDA toolkit / `nvcc` visible to the environment for `gsplat` extension compilation
53
 
54
- All experiments and internal validation for this repository were run on an NVIDIA RTX A6000 GPU.
 
55
 
56
- ## 3. Install The Environment
57
 
58
- We recommend Conda for reproducibility.
 
 
 
59
 
60
- If the unified environment in this README does not solve cleanly on your machine, use the original environment setup procedures from the two upstream components instead:
61
 
62
- - `vggt/README.md`
63
- - `gsplat/README.md`
64
 
65
- In that fallback workflow, configure the `VGGT` and `gsplat` environments separately first, then return to this repository and run the unified pipeline script.
 
66
 
67
- ### Option A: Conda
68
 
69
- From the repository root:
70
 
71
- ```bash
72
- conda env create -f environment.yaml
73
- conda activate naka-gs
74
- pip install git+https://github.com/rahul-goel/fused-ssim@328dc9836f513d00c4b5bc38fe30478b4435cbb5
75
- pip install git+https://github.com/harry7557558/fused-bilagrid@90f9788e57d3545e3a033c1038bb9986549632fe
76
- pip install git+https://github.com/nerfstudio-project/nerfview@4538024fe0d15fd1a0e4d760f3695fc44ca72787
77
- pip install "ppisp @ git+https://github.com/nv-tlabs/ppisp@v1.0.0"
78
- ```
79
 
80
- If your Conda solver is slow, you can use:
81
 
82
- ```bash
83
- conda env create -f environment.yaml --solver=libmamba
84
- ```
85
 
86
- ### Option B: Pip
87
 
88
- If you already have a matching CUDA PyTorch installation:
89
 
90
- ```bash
91
- pip install -r requirements.txt
92
  ```
93
 
94
- ## 4. Download The VGGT Checkpoint
95
-
96
- The repository does not include the `VGGT` model weights. Download the official checkpoint and place it at:
97
 
98
- ```text
99
- vggt/checkpoint/model.pt
 
 
100
  ```
101
 
102
- Official model page:
103
 
104
- - https://huggingface.co/facebook/VGGT-1B
105
 
106
- Direct checkpoint URL:
 
 
 
 
107
 
108
- - https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt
109
 
110
- Example:
111
 
112
- ```bash
113
- mkdir -p vggt/checkpoint
114
- wget -O vggt/checkpoint/model.pt \
115
- https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt
116
- ```
117
 
118
- ## 5. Naka Checkpoint
 
 
119
 
120
- By default, the pipeline looks for the Naka checkpoint at:
121
 
122
  ```text
123
- outputs/naka/checkpoints/latest.pth
124
  ```
125
 
126
- ## 6. Prepare The Scene
127
 
128
- Put your scene under `data/` or any other location you prefer. The important part is that `--scene_dir` points to the scene root.
129
 
130
- Example:
131
-
132
- ```text
133
- /path/to/naka-gs/data/Scene/
134
- ├── train/
135
- ├── transforms_train.json
136
- ├── transforms_test.json
137
- └── test/ # optional
138
  ```
139
 
140
- `train/` is required.
141
- `transforms_train.json` is required when using `--pose-source replace`.
142
- `transforms_test.json` is required when using `--render-traj-path testjson`.
143
 
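The two `transforms_*.json` files are assumed here to follow the common NeRF-style convention: a top-level `frames` list whose entries carry a `file_path` and a 4x4 `transform_matrix` (this schema is an assumption; the authoritative definition is whatever the upstream tooling wrote). A quick sanity check under that assumption:

```python
import json

def count_valid_frames(path: str) -> int:
    """Validate a NeRF-style transforms file and return its frame count."""
    with open(path) as f:
        meta = json.load(f)
    frames = meta.get("frames", [])
    if not frames:
        raise ValueError(f"{path}: no frames found")
    for frame in frames:
        matrix = frame["transform_matrix"]
        if len(matrix) != 4 or any(len(row) != 4 for row in matrix):
            raise ValueError(f"{path}: pose for {frame.get('file_path')} is not 4x4")
    return len(frames)
```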
144
- ## 7. Reproduce The Unified Pipeline Command
145
-
146
- From the repository root, run:
147
 
148
  ```bash
149
- python run_lowlight_reconstruction.py \
150
- --scene_dir /path/to/naka-gs/data/Your_Scene \
151
- --pose-source replace \
152
- --render-traj-path testjson \
153
- --disable-viewer \
154
- --ppm-enable \
155
- --ppm-dense-points-path sparse/points.ply \
156
- --ppm-align-mode none \
157
- --ppm-voxel-size 0.01 \
158
- --ppm-tau0 0.005 \
159
- --ppm-beta 0.01 \
160
- --ppm-iters 6
161
  ```
162
 
163
- This command runs the full pipeline:
164
 
165
- 1. Low-light `train/` images are enhanced into `images/`.
166
- 2. `VGGT` reconstructs the scene and writes `sparse/` plus `sparse/points.ply`.
167
- 3. `gsplat` uses `PPM` to preprocess `sparse/points.ply`, then trains and renders the target trajectory from `transforms_test.json`.
168
 
169
- ## 8. Example With A Local Conda Python Path
170
 
171
- If you want to use a specific Python interpreter from a Conda environment, an equivalent invocation is:
172
 
173
- ```bash
174
- /path/to/conda/env/bin/python /path/to/naka-gs/run_lowlight_reconstruction.py \
175
- --scene_dir /path/to/naka-gs/data/Your_Scene \
176
- --pose-source replace \
177
- --render-traj-path testjson \
178
- --disable-viewer \
179
- --ppm-enable \
180
- --ppm-dense-points-path sparse/points.ply \
181
- --ppm-align-mode none \
182
- --ppm-voxel-size 0.01 \
183
- --ppm-tau0 0.005 \
184
- --ppm-beta 0.01 \
185
- --ppm-iters 6
186
  ```
187
 
188
- ## 9. Main Outputs
189
-
190
- After a successful run, check:
191
-
192
- - `data/Your_Scene/images/` for enhanced images
193
- - `data/Your_Scene/sparse/` for the VGGT sparse reconstruction
194
- - `data/Your_Scene/gsplat_results/` for rendered views, metrics, checkpoints, and logs
195
- - `data/Your_Scene/gsplat_results/pipeline_summary.json` for a stage-by-stage summary
196
 
197
- ## 10. Useful Variants
198
-
199
- ### Reuse Existing Enhanced Images
200
 
201
  ```bash
202
- python run_lowlight_reconstruction.py \
203
- --scene_dir /path/to/scene \
204
- --skip_naka
205
  ```
206
 
207
- ### Reuse Existing Sparse Reconstruction
208
 
209
  ```bash
210
- python run_lowlight_reconstruction.py \
211
- --scene_dir /path/to/scene \
212
- --skip_naka \
213
- --skip_vggt
 
 
214
  ```
215
 
216
- ### Disable PPM
217
 
218
  ```bash
219
- python run_lowlight_reconstruction.py \
220
- --scene_dir /path/to/scene \
221
- --ppm-enable false
 
 
 
222
  ```
223
 
224
- ## 11. Common Issues
225
 
226
- ### `FileNotFoundError: Naka checkpoint is required`
227
 
228
- Provide `--naka_ckpt /path/to/latest.pth`, or place the checkpoint at the default path shown above.
229
 
230
- ### `No enhanced images found`
 
 
 
 
 
 
 
231
 
232
- Make sure `train/` contains valid image files and the Naka stage finished successfully.
233
 
234
- ### `PPM dense point cloud is missing: .../sparse/points.ply`
235
 
236
- This usually means the VGGT stage did not finish successfully, so `sparse/points.ply` was not generated.
 
 
237
 
238
- ### `torch.cuda.is_available() is False`
239
 
240
- The `gsplat` stage requires a visible CUDA GPU.
 
 
 
 
 
 
241
 
242
- ### `gsplat` spends a long time on the first run
243
 
244
- This is expected when the CUDA extension is compiled for the first time.
245
 
246
- ## 12. Minimal Checklist Before Running
 
 
 
 
 
247
 
248
- - Environment created successfully
249
- - `vggt/checkpoint/model.pt` downloaded
250
- - Naka checkpoint available, either at the default path or via `--naka_ckpt`
251
- - Scene directory contains `train/`
252
- - `transforms_train.json` exists for `--pose-source replace`
253
- - `transforms_test.json` exists for `--render-traj-path testjson`
254
 
255
- ## 13. Citation
256
- If you find this code useful for your research, please use the following BibTeX entry.
257
 
258
- ```text
259
- @misc{zhu2026nakagsbionicsinspireddualbranchnaka,
260
-   title={Naka-GS: A Bionics-inspired Dual-Branch Naka Correction and Progressive Point Pruning for Low-Light 3DGS},
261
-   author={Runyu Zhu and SiXun Dong and Zhiqiang Zhang and Qingxia Ye and Zhihua Xu},
262
-   year={2026},
263
-   eprint={2604.11142},
264
-   archivePrefix={arXiv},
265
-   primaryClass={cs.CV},
266
-   url={https://arxiv.org/abs/2604.11142},
267
  }
268
  ```
269
 
 
 
1
+ ---
2
+ library_name: pytorch
3
+ pipeline_tag: image-to-image
4
+ tags:
5
+ - low-light-image-enhancement
6
+ - image-enhancement
7
+ - image-to-image
8
+ - gaussian-splatting
9
+ - 3d-reconstruction
10
+ - custom-code
11
+ - pytorch
12
+ ---
13
 
14
+ # Naka-guided Chroma-Correction Model
15
 
16
+ This repository hosts the **Naka-guided Chroma-correction model** used in the **Naka-GS** pipeline.
17
 
18
+ The model is designed to refine a Naka-enhanced low-light image by suppressing color distortion in bright regions while preserving edge and texture details. In the released implementation, the network predicts a **single-channel multiplicative correction map** and a **three-channel additive correction map**, and applies them only to the **low-frequency component** of the Naka-enhanced image before adding the preserved high-frequency details back to the final output. The model input is an 18-channel representation built from the low-light image, the Naka-enhanced image, their residual, and standardized counterparts.
 
 
19
 
20
+ ## Associated resources
21
 
22
+ - **Project page / code**: `https://github.com/RunyuZhu/Naka-GS`
23
+ - **Paper page**: `https://huggingface.co/papers/2604.11142`
24
+ - **ArXiv**: `https://arxiv.org/abs/2604.11142`
25
 
26
+ ## What this model does
27
 
28
+ Given a low-light RGB image:
 
 
 
 
 
 
 
29
 
30
+ 1. a Naka phototransduction transform is applied,
31
+ 2. the correction network predicts `mul_map` and `add_map`,
32
+ 3. the low-frequency component of the Naka image is corrected,
33
+ 4. the high-frequency component is added back,
34
+ 5. the final enhanced image is saved.
35
 
36
+ In the provided code, inference saves the corrected result as `<image_name>_enhanced.JPG`.
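
Steps 2-4 can be sketched numerically. The snippet below is a simplified NumPy illustration, not the released network: a separable Gaussian blur stands in for the low-pass filter, and `mul_map` / `add_map` are assumed to be already predicted (single-channel and three-channel, as described above):

```python
import numpy as np

def gaussian_kernel1d(ksize: int = 5, sigma: float = 1.0) -> np.ndarray:
    # 1-D Gaussian weights, normalized to sum to 1.
    x = np.arange(ksize) - ksize // 2
    k = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return k / k.sum()

def blur(img: np.ndarray, ksize: int = 5, sigma: float = 1.0) -> np.ndarray:
    # Separable Gaussian blur over H and W of an (H, W, C) image.
    k = gaussian_kernel1d(ksize, sigma)
    pad = ksize // 2
    tmp = np.pad(img, ((pad, pad), (0, 0), (0, 0)), mode="edge")
    tmp = sum(k[i] * tmp[i:i + img.shape[0]] for i in range(ksize))
    tmp = np.pad(tmp, ((0, 0), (pad, pad), (0, 0)), mode="edge")
    tmp = sum(k[i] * tmp[:, i:i + img.shape[1]] for i in range(ksize))
    return tmp

def frequency_decoupled_correction(naka, mul_map, add_map):
    """Apply multiplicative/additive maps to the low-frequency band only."""
    low_freq = blur(naka)            # low-pass component
    high_freq = naka - low_freq      # preserved edge/texture details
    corrected_low = low_freq * mul_map + add_map
    return np.clip(corrected_low + high_freq, 0.0, 1.0)
```

With identity maps (`mul_map = 1`, `add_map = 0`) the correction is a no-op, which is why only the learned deviation from identity alters the image.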
37
 
38
+ ## Model details
39
 
40
+ ### Architecture
 
 
41
 
42
+ The released model is a U-Net-style encoder-decoder with residual blocks and SE attention. The core model class is `ChromaGuidedUNet`. Its forward pass takes `low` and `naka` tensors as input, constructs an 18-channel feature tensor, predicts `mul_map` and `add_map`, and performs frequency-decoupled correction on the Naka image.
43
 
44
+ ### Input
 
 
 
45
 
46
+ - RGB low-light image
47
+ - Automatically generated Naka-enhanced intermediate image
48
 
49
+ ### Output
50
 
51
+ - corrected RGB image: `enhanced`
52
+ - optional intermediate maps:
53
+ - `mul_map`
54
+ - `add_map`
55
 
56
+ ### Checkpoints
57
 
58
+ Recommended checkpoint filenames:
 
59
 
60
+ - `best.pth`: recommended for inference
61
+ - `latest.pth`: latest training state
62
 
63
+ The training script saves `latest.pth` every epoch and updates `best.pth` whenever validation PSNR improves.
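
That checkpoint policy amounts to: write `latest.pth` unconditionally each epoch, and overwrite `best.pth` only on a PSNR improvement. A minimal, framework-agnostic sketch of the bookkeeping (the real script saves full model/optimizer state; `save_fn` abstracts `torch.save` so the logic is illustrated without a framework dependency):

```python
import os

def save_checkpoints(state, val_psnr, best_psnr, out_dir, save_fn):
    """Save latest every epoch; refresh best only when validation PSNR improves.

    Returns the (possibly updated) best PSNR to carry into the next epoch.
    """
    os.makedirs(out_dir, exist_ok=True)
    save_fn(state, os.path.join(out_dir, "latest.pth"))   # always refreshed
    if val_psnr > best_psnr:
        save_fn(state, os.path.join(out_dir, "best.pth"))  # only on improvement
        return val_psnr
    return best_psnr
```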
64
 
65
+ ## Intended use
66
 
67
+ This model is intended to be used as the **color-correction / enhancement stage** in the Naka-GS low-light 3D reconstruction pipeline, or as a standalone low-light image refinement module when a Naka-style phototransduction preprocessing step is available.
 
 
 
 
 
 
 
68
 
69
+ ## Limitations
70
 
71
+ - This repository contains **custom PyTorch code** and is **not** a Transformers-native model.
72
+ - The script depends on a custom `Phototransduction` implementation and tries to import it from either `retina.phototransduction` or `phototransduction`. For a standalone release, place `phototransduction.py` next to `naka_color_correction.py`, or preserve the original package layout.
73
+ - The model card does not claim broad robustness outside the training setting used by the original project.
74
 
75
+ ## Repository layout
76
 
77
+ A minimal Hugging Face release layout is:
78
 
79
+ ```text
80
+ .
81
+ ├── README.md
82
+ ├── requirements.txt
83
+ ├── naka_color_correction.py
84
+ ├── phototransduction.py
85
+ ├── best.pth
86
+ ├── latest.pth               # optional
87
+ └── assets/
88
+     ├── teaser.png           # optional
89
+     └── results.png          # optional
90
  ```
91
 
92
+ ## Installation
 
 
93
 
94
+ ```bash
95
+ git clone https://huggingface.co/<your-username-or-org>/<your-model-repo>
96
+ cd <your-model-repo>
97
+ pip install -r requirements.txt
98
  ```
99
 
100
+ ## Requirements
101
 
102
+ Core dependencies used directly in the provided script:
103
 
104
+ - `torch`
105
+ - `torchvision`
106
+ - `numpy`
107
+ - `opencv-python`
108
+ - `Pillow`
109
 
110
+ The script also uses `torchvision.models.vgg19` for the perceptual loss branch during training.
111
 
112
+ ## Quick start: inference
113
 
114
+ ### 1. Prepare files
115
+
116
+ Place the following files in the same directory:
 
 
117
 
118
+ - `naka_color_correction.py`
119
+ - `phototransduction.py`
120
+ - `best.pth`
121
 
122
+ Create an input folder such as:
123
 
124
  ```text
125
+ ./test_images/
126
  ```
127
 
128
+ and put your test images inside.
129
 
130
+ ### 2. Run inference
131
 
132
+ ```bash
133
+ python naka_color_correction.py \
134
+ --mode infer \
135
+ --input_dir ./test_images \
136
+ --output_dir ./outputs/infer_results \
137
+ --ckpt ./best.pth
 
 
138
  ```
139
 
140
+ ### 3. Inference on large images
 
 
141
 
142
+ The code supports tiled forwarding for large inputs:
 
 
143
 
144
  ```bash
145
+ python naka_color_correction.py \
146
+ --mode infer \
147
+ --input_dir ./test_images \
148
+ --output_dir ./outputs/infer_results \
149
+ --ckpt ./best.pth \
150
+ --tile_size 512 \
151
+ --tile_overlap 32
 
 
 
 
 
152
  ```
153
 
154
+ The command-line parser exposes `--mode`, `--input_dir`, `--output_dir`, `--ckpt`, `--tile_size`, and `--tile_overlap` for inference.
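
The general tiling scheme works like this: cover the image with overlapping tiles, run the model on each tile, and average the accumulated outputs wherever tiles overlap. The NumPy sketch below illustrates that technique under stated assumptions (the script's actual blending may differ; `model` is any function mapping an HxWxC array to an array of the same shape):

```python
import numpy as np

def _tile_starts(length: int, tile: int, step: int) -> list:
    # Start offsets covering [0, length), with the last tile flush to the end.
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile, step))
    starts.append(length - tile)
    return starts

def tiled_forward(img, model, tile_size: int = 512, overlap: int = 32):
    """Run `model` over overlapping tiles and average overlapping outputs."""
    h, w, _ = img.shape
    th, tw = min(tile_size, h), min(tile_size, w)
    step = max(1, tile_size - overlap)
    out = np.zeros(img.shape, dtype=np.float64)
    weight = np.zeros((h, w, 1), dtype=np.float64)
    for top in _tile_starts(h, th, step):
        for left in _tile_starts(w, tw, step):
            patch = img[top:top + th, left:left + tw]
            out[top:top + th, left:left + tw] += model(patch)
            weight[top:top + th, left:left + tw] += 1.0
    return out / weight  # every pixel is covered by at least one tile
```

Averaging the overlaps suppresses seams at tile boundaries at the cost of running the model on some pixels more than once.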
155
 
156
+ ## Training
 
 
157
 
158
+ ### Dataset format
159
 
160
+ The training and validation data must follow this layout:
161
 
162
+ ```text
163
+ datasets/LOLv1/
164
+ ├── train/
165
+ │   ├── low/
166
+ │   └── normal/
167
+ └── val/
168
+     ├── low/
169
+     └── normal/
 
 
 
 
 
170
  ```
171
 
172
+ Files are paired by identical filename between `low/` and `normal/`.
173
 
174
+ ### Basic training command
 
 
175
 
176
  ```bash
177
+ python naka_color_correction.py \
178
+ --mode train \
179
+ --data_root ./datasets/LOLv1 \
180
+ --output_dir ./outputs/naka_color_correction_v2 \
181
+ --epochs 200 \
182
+ --batch_size 8 \
183
+ --num_workers 4 \
184
+ --crop_size 256 \
185
+ --lr 2e-4 \
186
+ --weight_decay 1e-4 \
187
+ --base_ch 32 \
188
+ --amp
189
  ```
190
 
191
+ ### Resume training
192
 
193
  ```bash
194
+ python naka_color_correction.py \
195
+ --mode train \
196
+ --data_root ./datasets/LOLv1 \
197
+ --output_dir ./outputs/naka_color_correction_v2 \
198
+ --resume_ckpt ./outputs/naka_color_correction_v2/checkpoints/latest.pth \
199
+ --amp
200
  ```
201
 
202
+ ### Initialize from a checkpoint
203
 
204
  ```bash
205
+ python naka_color_correction.py \
206
+ --mode train \
207
+ --data_root ./datasets/LOLv1 \
208
+ --output_dir ./outputs/naka_color_correction_v2 \
209
+ --init_ckpt ./best.pth \
210
+ --amp
211
  ```
212
 
213
+ The parser defaults include `epochs=200`, `batch_size=8`, `crop_size=256`, `lr=2e-4`, `weight_decay=1e-4`, `base_ch=32`, `mul_range=0.6`, `add_range=0.25`, `hf_kernel_size=5`, and `hf_sigma=1.0`.
214
 
215
+ ## Training objective
216
 
217
+ The provided implementation combines:
218
 
219
+ - RGB reconstruction loss
220
+ - YCbCr chroma/luma consistency loss
221
+ - SSIM loss
222
+ - edge loss
223
+ - VGG perceptual loss
224
+ - map regularization
225
+ - gray-edge masked loss
226
+ - bright-region masked loss
227
 
228
+ These are implemented through `NakaCorrectionLoss` and `NakaCorrectionLossWithMasks`.
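
A minimal picture of how such a composite objective is assembled is a weighted sum over named terms. The sketch below is illustrative only: the term weights are placeholders, not the values used by `NakaCorrectionLoss`, and the Charbonnier penalty shown is a common choice for the RGB reconstruction term:

```python
import numpy as np

def charbonnier(pred, target, eps: float = 1e-3) -> float:
    # Smooth L1-style reconstruction penalty, robust near zero error.
    diff = pred - target
    return float(np.mean(np.sqrt(diff * diff + eps * eps)))

def total_loss(terms: dict, weights: dict) -> float:
    """Weighted sum over named loss terms, e.g. rgb / ssim / edge / perceptual."""
    return sum(weights[name] * value for name, value in terms.items())
```

Each named term would be computed from the network output before being combined.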
229
 
230
+ ## Notes on reproducibility
231
 
232
+ - Validation uses full-resolution images with `batch_size=1`.
233
+ - Mixed precision is enabled with `--amp` on CUDA.
234
+ - Checkpoint loading is backward-compatible with older 3-channel `mul_head` weights via `adapt_mul_head_to_single_channel()`.
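
That compatibility shim can be pictured as collapsing the legacy 3-output-channel convolution into a single channel, for instance by averaging the three filters. The NumPy sketch below is a hedged illustration of that idea; the repository's `adapt_mul_head_to_single_channel()` may use a different reduction:

```python
import numpy as np

def adapt_mul_head(weight, bias):
    """Collapse an old (3, C_in, k, k) mul-head conv into (1, C_in, k, k).

    Averaging the three per-channel filters keeps the mean multiplicative
    response of the legacy checkpoint.
    """
    if weight.shape[0] != 3:
        raise ValueError(f"expected a legacy 3-channel head, got {weight.shape[0]}")
    return weight.mean(axis=0, keepdims=True), bias.mean(axis=0, keepdims=True)
```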
235
 
236
+ ## Suggested `requirements.txt`
237
 
238
+ ```text
239
+ torch>=2.1.0
240
+ torchvision>=0.16.0
241
+ numpy>=1.24.0
242
+ opencv-python>=4.8.0
243
+ Pillow>=10.0.0
244
+ ```
245
 
246
+ ## Example release contents
247
 
248
+ For a clean Hugging Face release, upload:
249
 
250
+ - `README.md`
251
+ - `requirements.txt`
252
+ - `naka_color_correction.py`
253
+ - `phototransduction.py`
254
+ - `best.pth`
255
+ - optional visual examples in `assets/`
256
 
257
+ ## Citation
 
 
 
 
 
258
 
259
+ If you use this model, please cite the associated Naka-GS paper.
 
260
 
261
+ ```bibtex
262
+ @article{zhu2026nakags,
263
+   title={Naka-GS},
264
+   author={Zhu, Runyu and others},
265
+   journal={arXiv preprint arXiv:2604.11142},
266
+   year={2026}
267
  }
268
  ```
269
 
270
+ If your final BibTeX entry differs, replace the placeholder entry above with the official version from your paper page.
naka_color_correction.py ADDED
@@ -0,0 +1,1096 @@
1
+ import sys
2
+ import os
3
+ import math
4
+ import glob
5
+ import random
6
+ import argparse
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import cv2
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.utils.data import Dataset, DataLoader
17
+
18
+ try:
19
+ from torchvision.models import vgg19, VGG19_Weights
20
+ except Exception:
21
+ vgg19 = None
22
+ VGG19_Weights = None
23
+
24
+ # -----------------------------------------------------------------------------
25
+ # Try importing the provided Naka function.
26
+ # Supports either:
27
+ # 1) retina.phototransduction.Phototransduction
28
+ # 2) phototransduction.Phototransduction
29
+ # -----------------------------------------------------------------------------
30
+ try:
31
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
32
+ from retina.phototransduction import Phototransduction
33
+ except Exception:
34
+ from phototransduction import Phototransduction
35
+
36
+
37
+ # -----------------------------------------------------------------------------
38
+ # Utilities
39
+ # -----------------------------------------------------------------------------
40
+ def seed_everything(seed: int = 42) -> None:
41
+ random.seed(seed)
42
+ np.random.seed(seed)
43
+ torch.manual_seed(seed)
44
+ torch.cuda.manual_seed_all(seed)
45
+
46
+
47
+ def load_rgb(path: str) -> np.ndarray:
48
+ img = Image.open(path).convert("RGB")
49
+ return np.array(img)
50
+
51
+
52
+ def save_rgb_tensor(tensor: torch.Tensor, path: str) -> None:
53
+ arr = tensor.detach().cpu().clamp(0, 1)
54
+ if arr.ndim != 3:
55
+ raise ValueError(f"Expected CHW tensor, got shape: {tuple(arr.shape)}")
56
+
57
+ if arr.shape[0] == 1:
58
+ arr = (arr.squeeze(0).numpy() * 255.0).round().astype(np.uint8)
59
+ Image.fromarray(arr, mode="L").save(path)
60
+ elif arr.shape[0] == 3:
61
+ arr = (arr.permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8)
62
+ Image.fromarray(arr).save(path)
63
+ else:
64
+ raise ValueError(f"save_rgb_tensor only supports 1 or 3 channels, got: {arr.shape[0]}")
65
+
66
+
67
+ def load_torch_checkpoint(path: str, map_location) -> Dict:
68
+ try:
69
+ return torch.load(path, map_location=map_location, weights_only=True)
70
+ except TypeError:
71
+ return torch.load(path, map_location=map_location)
72
+
73
+
74
+ def to_tensor(img: np.ndarray) -> torch.Tensor:
75
+ if img.dtype != np.float32:
76
+ img = img.astype(np.float32) / 255.0
77
+ if img.max() > 1.0:
78
+ img = img / 255.0
79
+ img = np.ascontiguousarray(img)
80
+ return torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
81
+
82
+
83
+ def list_image_files(folder: str) -> List[str]:
84
+ exts = ["*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tif", "*.tiff", "*.PNG", "*.JPG", "*.JPEG"]
85
+ files: List[str] = []
86
+ for ext in exts:
87
+ files.extend(glob.glob(os.path.join(folder, ext)))
88
+ return sorted(files)
89
+
90
+
91
+ def paired_paths(low_dir: str, normal_dir: str) -> List[Tuple[str, str]]:
92
+ low_files = list_image_files(low_dir)
93
+ normal_files = list_image_files(normal_dir)
94
+ normal_map = {os.path.basename(p): p for p in normal_files}
95
+ pairs = []
96
+ for low_path in low_files:
97
+ name = os.path.basename(low_path)
98
+ if name in normal_map:
99
+ pairs.append((low_path, normal_map[name]))
100
+ if not pairs:
101
+ raise RuntimeError(f"No paired files found between {low_dir} and {normal_dir}")
102
+ return pairs
103
+
104
+
105
+ def ensure_min_size_pair(a: np.ndarray, b: np.ndarray, min_size: int) -> Tuple[np.ndarray, np.ndarray]:
106
+ h, w = a.shape[:2]
107
+ if h >= min_size and w >= min_size:
108
+ return a, b
109
+ scale = max(min_size / max(h, 1), min_size / max(w, 1))
110
+ nh, nw = int(math.ceil(h * scale)), int(math.ceil(w * scale))
111
+ a = cv2.resize(a, (nw, nh), interpolation=cv2.INTER_LINEAR)
112
+ b = cv2.resize(b, (nw, nh), interpolation=cv2.INTER_LINEAR)
113
+ return a, b
114
+
115
+
116
+ def random_rescale_pair(
117
+ a: np.ndarray,
118
+ b: np.ndarray,
119
+ min_scale: float = 0.7,
120
+ max_scale: float = 1.4,
121
+ min_after_scale: int = 32,
122
+ ) -> Tuple[np.ndarray, np.ndarray]:
123
+ scale = random.uniform(min_scale, max_scale)
124
+ h, w = a.shape[:2]
125
+ nh = max(min_after_scale, int(round(h * scale)))
126
+ nw = max(min_after_scale, int(round(w * scale)))
127
+ a = cv2.resize(a, (nw, nh), interpolation=cv2.INTER_LINEAR)
128
+ b = cv2.resize(b, (nw, nh), interpolation=cv2.INTER_LINEAR)
129
+ return a, b
130
+
131
+
132
+ def random_crop_pair(a: np.ndarray, b: np.ndarray, crop_size: int) -> Tuple[np.ndarray, np.ndarray]:
133
+ a, b = ensure_min_size_pair(a, b, crop_size)
134
+ h, w = a.shape[:2]
135
+ top = random.randint(0, h - crop_size)
136
+ left = random.randint(0, w - crop_size)
137
+ a = a[top:top + crop_size, left:left + crop_size]
138
+ b = b[top:top + crop_size, left:left + crop_size]
139
+ return a, b
140
+
141
+
142
+ def rgb_to_ycbcr(x: torch.Tensor) -> torch.Tensor:
143
+ r, g, b = x[:, 0:1], x[:, 1:2], x[:, 2:3]
144
+ y = 0.299 * r + 0.587 * g + 0.114 * b
145
+ cb = -0.168736 * r - 0.331264 * g + 0.5 * b + 0.5
146
+ cr = 0.5 * r - 0.418688 * g - 0.081312 * b + 0.5
147
+ return torch.cat([y, cb, cr], dim=1)
148
+
149
+
150
+ def charbonnier_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
151
+ diff = pred - target
152
+ return torch.mean(torch.sqrt(diff * diff + eps * eps))
153
+
154
+
155
+ def edge_map(x: torch.Tensor) -> torch.Tensor:
156
+ c = x.shape[1]
157
+ sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
158
+ sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
159
+ sobel_x = sobel_x.repeat(c, 1, 1, 1)
160
+ sobel_y = sobel_y.repeat(c, 1, 1, 1)
161
+ gx = F.conv2d(x, sobel_x, padding=1, groups=c)
162
+ gy = F.conv2d(x, sobel_y, padding=1, groups=c)
163
+ return torch.sqrt(gx * gx + gy * gy + 1e-6)
164
+
165
+
166
+ def gaussian_window(
167
+ window_size: int = 11,
168
+ sigma: float = 1.5,
169
+ channels: int = 3,
170
+ device: Optional[torch.device] = None,
171
+ dtype: torch.dtype = torch.float32,
172
+ ) -> torch.Tensor:
173
+ coords = torch.arange(window_size, dtype=dtype, device=device) - window_size // 2
174
+ g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
175
+ g = g / g.sum()
176
+ w = torch.outer(g, g)
177
+ w = w.view(1, 1, window_size, window_size)
178
+ return w.repeat(channels, 1, 1, 1)
179
+
180
+
181
+
182
+ def gaussian_blur_tensor(x: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
183
+ if kernel_size % 2 == 0:
184
+ raise ValueError(f"kernel_size must be odd, got {kernel_size}")
185
+ c = x.shape[1]
186
+ window = gaussian_window(kernel_size, sigma, c, x.device, x.dtype)
187
+ return F.conv2d(x, window, padding=kernel_size // 2, groups=c)
188
+
189
+
190
+ def ssim_loss(x: torch.Tensor, y: torch.Tensor, window_size: int = 11) -> torch.Tensor:
191
+ c = x.shape[1]
192
+ window = gaussian_window(window_size, 1.5, c, x.device, x.dtype)
193
+ mu_x = F.conv2d(x, window, padding=window_size // 2, groups=c)
194
+ mu_y = F.conv2d(y, window, padding=window_size // 2, groups=c)
195
+
196
+ mu_x2 = mu_x * mu_x
197
+ mu_y2 = mu_y * mu_y
198
+ mu_xy = mu_x * mu_y
199
+
200
+ sigma_x2 = F.conv2d(x * x, window, padding=window_size // 2, groups=c) - mu_x2
201
+ sigma_y2 = F.conv2d(y * y, window, padding=window_size // 2, groups=c) - mu_y2
202
+ sigma_xy = F.conv2d(x * y, window, padding=window_size // 2, groups=c) - mu_xy
203
+
204
+ c1 = 0.01 ** 2
205
+ c2 = 0.03 ** 2
206
+ ssim_n = (2 * mu_xy + c1) * (2 * sigma_xy + c2)
207
+ ssim_d = (mu_x2 + mu_y2 + c1) * (sigma_x2 + sigma_y2 + c2)
208
+ ssim_map = ssim_n / (ssim_d + 1e-8)
209
+ return 1.0 - ssim_map.mean()
210
+
211
+
212
+ # -----------------------------------------------------------------------------
213
+ # Dataset
214
+ # -----------------------------------------------------------------------------
215
+ class NakaPairDataset(Dataset):
216
+ """
217
+ Directory layout:
218
+ root/
219
+ train/
220
+ low/
221
+ normal/
222
+ val/
223
+ low/
224
+ normal/
225
+
226
+ Files are paired by the same filename.
227
+ """
228
+ def __init__(
229
+ self,
230
+ root: str,
231
+ split: str = "train",
232
+ crop_size: int = 256,
233
+ is_train: bool = True,
234
+ cache_naka: bool = False,
235
+ min_scale: float = 0.7,
236
+ max_scale: float = 1.4,
237
+ ) -> None:
238
+ super().__init__()
239
+ low_dir = os.path.join(root, split, "low")
240
+ normal_dir = os.path.join(root, split, "normal")
241
+ self.pairs = paired_paths(low_dir, normal_dir)
242
+ self.crop_size = crop_size
243
+ self.is_train = is_train
244
+ self.cache_naka = cache_naka and (not is_train)
245
+ self.min_scale = min_scale
246
+ self.max_scale = max_scale
247
+ self.naka_cache: Dict[str, np.ndarray] = {}
248
+
249
+ self.naka_processor = Phototransduction(
250
+ mode="naka",
251
+ per_channel=True,
252
+ naka_sigma=0.05,
253
+ clip_percentile=99.9,
254
+ out_mode="0_1",
255
+ out_method="linear",
256
+ )
257
+
258
+ def __len__(self) -> int:
259
+ return len(self.pairs)
260
+
261
+ def _apply_naka(self, low_rgb: np.ndarray, key: str) -> np.ndarray:
262
+ if self.cache_naka and key in self.naka_cache:
263
+ return self.naka_cache[key]
264
+
265
+ low_bgr = cv2.cvtColor(low_rgb, cv2.COLOR_RGB2BGR)
266
+ naka_bgr = self.naka_processor(low_bgr)
267
+ naka_rgb = cv2.cvtColor(naka_bgr.astype(np.float32), cv2.COLOR_BGR2RGB)
268
+ naka_rgb = np.clip(naka_rgb, 0.0, 1.0).astype(np.float32)
269
+
270
+ if self.cache_naka:
271
+ self.naka_cache[key] = naka_rgb
272
+ return naka_rgb
273
+
274
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
275
+ low_path, gt_path = self.pairs[idx]
276
+ low = load_rgb(low_path)
277
+ gt = load_rgb(gt_path)
278
+
279
+ if self.is_train:
280
+ low, gt = random_rescale_pair(low, gt, self.min_scale, self.max_scale, min_after_scale=self.crop_size)
281
+ low, gt = random_crop_pair(low, gt, self.crop_size)
282
+ if random.random() < 0.5:
283
+ low = np.ascontiguousarray(np.fliplr(low))
284
+ gt = np.ascontiguousarray(np.fliplr(gt))
285
+ if random.random() < 0.5:
286
+ low = np.ascontiguousarray(np.flipud(low))
287
+ gt = np.ascontiguousarray(np.flipud(gt))
288
+ if random.random() < 0.5:
289
+ low = np.ascontiguousarray(np.rot90(low))
290
+ gt = np.ascontiguousarray(np.rot90(gt))
291
+ cache_key = f"{low_path}_train_no_cache"
292
+ else:
293
+ cache_key = low_path
294
+
295
+ naka = self._apply_naka(low, cache_key)
296
+ low_t = to_tensor(low)
297
+ gt_t = to_tensor(gt)
298
+ naka_t = to_tensor(naka)
299
+
300
+ return {
301
+ "low": low_t,
302
+ "naka": naka_t,
303
+ "gt": gt_t,
304
+ "name": os.path.basename(low_path),
305
+ "hw": torch.tensor([low_t.shape[1], low_t.shape[2]], dtype=torch.int32),
306
+ }
307
+
308
+
309
+# -----------------------------------------------------------------------------
+# Model blocks
+# -----------------------------------------------------------------------------
+class InputStandardizer(nn.Module):
+    def __init__(self, eps: float = 1e-4):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        mean = x.mean(dim=(2, 3), keepdim=True)
+        std = x.std(dim=(2, 3), keepdim=True, unbiased=False).clamp_min(self.eps)
+        return (x - mean) / std
+
+
+class ConvAct(nn.Module):
+    def __init__(
+        self,
+        in_ch: int,
+        out_ch: int,
+        k: int = 3,
+        s: int = 1,
+        p: Optional[int] = None,
+        act: bool = True,
+        use_norm: bool = True,
+    ):
+        super().__init__()
+        if p is None:
+            p = k // 2
+
+        layers = [nn.Conv2d(in_ch, out_ch, k, s, p, bias=not use_norm)]
+
+        if use_norm:
+            groups = min(8, out_ch)
+            while groups > 1 and out_ch % groups != 0:
+                groups -= 1
+            layers.append(nn.GroupNorm(groups, out_ch))
+
+        if act:
+            layers.append(nn.GELU())
+
+        self.block = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.block(x)
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, ch: int):
+        super().__init__()
+        self.conv1 = ConvAct(ch, ch, 3, 1)
+        self.conv2 = ConvAct(ch, ch, 3, 1, act=False)
+        self.act = nn.GELU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.conv2(self.conv1(x))
+        return self.act(out + x)
+
+
+class SEBlock(nn.Module):
+    def __init__(self, ch: int, r: int = 8):
+        super().__init__()
+        mid = max(8, ch // r)
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Conv2d(ch, mid, 1),
+            nn.GELU(),
+            nn.Conv2d(mid, ch, 1),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        w = self.fc(self.pool(x))
+        return x * w
+
+
+class DownBlock(nn.Module):
+    def __init__(self, in_ch: int, out_ch: int, num_res: int = 2):
+        super().__init__()
+        blocks = [ConvAct(in_ch, out_ch, 3, 1)]
+        for _ in range(num_res):
+            blocks.append(ResidualBlock(out_ch))
+        blocks.append(SEBlock(out_ch))
+        self.block = nn.Sequential(*blocks)
+        self.down = ConvAct(out_ch, out_ch, 3, 2)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        feat = self.block(x)
+        down = self.down(feat)
+        return feat, down
+
+
+class UpBlock(nn.Module):
+    def __init__(self, in_ch: int, skip_ch: int, out_ch: int, num_res: int = 2):
+        super().__init__()
+        self.reduce = ConvAct(in_ch + skip_ch, out_ch, 3, 1)
+        blocks = []
+        for _ in range(num_res):
+            blocks.append(ResidualBlock(out_ch))
+        blocks.append(SEBlock(out_ch))
+        self.block = nn.Sequential(*blocks)
+
+    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
+        x = torch.cat([x, skip], dim=1)
+        x = self.reduce(x)
+        return self.block(x)
+
+
+class ChromaGuidedUNet(nn.Module):
+    """
+    Input features:
+        raw branch:  [low(3), naka(3), delta(3)]
+        norm branch: [low_norm(3), naka_norm(3), delta_norm(3)]
+        total input channels = 18
+
+    Output:
+        mul_map: [B,1,H,W], single-channel multiplicative correction
+        add_map: [B,3,H,W], additive correction
+
+    naka is decomposed into low/high-frequency parts:
+        naka = naka_lf + naka_hf
+    The correction is applied only to the low-frequency content:
+        base = naka_lf * mul_map + add_map
+        enhanced = clamp(base + naka_hf, 0, 1)
+    """
+    def __init__(
+        self,
+        base_ch: int = 32,
+        mul_range: float = 0.6,
+        add_range: float = 0.25,
+        hf_kernel_size: int = 5,
+        hf_sigma: float = 1.0,
+    ):
+        super().__init__()
+        self.mul_range = mul_range
+        self.add_range = add_range
+        self.hf_kernel_size = hf_kernel_size
+        self.hf_sigma = hf_sigma
+        self.input_std = InputStandardizer()
+
+        self.stem = ConvAct(18, base_ch, 3, 1)
+        self.down1 = DownBlock(base_ch, base_ch, num_res=2)
+        self.down2 = DownBlock(base_ch, base_ch * 2, num_res=2)
+        self.down3 = DownBlock(base_ch * 2, base_ch * 4, num_res=3)
+
+        self.bottleneck = nn.Sequential(
+            ConvAct(base_ch * 4, base_ch * 8, 3, 1),
+            ResidualBlock(base_ch * 8),
+            ResidualBlock(base_ch * 8),
+            SEBlock(base_ch * 8),
+        )
+
+        self.up3 = UpBlock(base_ch * 8, base_ch * 4, base_ch * 4, num_res=2)
+        self.up2 = UpBlock(base_ch * 4, base_ch * 2, base_ch * 2, num_res=2)
+        self.up1 = UpBlock(base_ch * 2, base_ch, base_ch, num_res=2)
+
+        self.fuse = nn.Sequential(
+            ConvAct(base_ch, base_ch, 3, 1),
+            ResidualBlock(base_ch),
+        )
+
+        self.mul_head = nn.Conv2d(base_ch, 1, 3, 1, 1)
+        self.add_head = nn.Conv2d(base_ch, 3, 3, 1, 1)
+
+    def forward(self, low: torch.Tensor, naka: torch.Tensor) -> Dict[str, torch.Tensor]:
+        delta = naka - low
+
+        low_n = self.input_std(low)
+        naka_n = self.input_std(naka)
+        delta_n = self.input_std(delta)
+        x = torch.cat([low, naka, delta, low_n, naka_n, delta_n], dim=1)
+
+        x0 = self.stem(x)
+        s1, d1 = self.down1(x0)
+        s2, d2 = self.down2(d1)
+        s3, d3 = self.down3(d2)
+
+        b = self.bottleneck(d3)
+        u3 = self.up3(b, s3)
+        u2 = self.up2(u3, s2)
+        u1 = self.up1(u2, s1)
+        feat = self.fuse(u1)
+
+        mul_res = torch.tanh(self.mul_head(feat)) * self.mul_range
+        add_map = torch.tanh(self.add_head(feat)) * self.add_range
+        mul_map = 1.0 + mul_res
+
+        naka_lf = gaussian_blur_tensor(naka, kernel_size=self.hf_kernel_size, sigma=self.hf_sigma)
+        naka_hf = naka - naka_lf
+
+        base = naka_lf * mul_map + add_map
+        enhanced = torch.clamp(base + naka_hf, 0.0, 1.0)
+        return {
+            "enhanced": enhanced,
+            "mul_map": mul_map,
+            "add_map": add_map,
+        }
+
+
+def adapt_mul_head_to_single_channel(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Backward compatibility:
+    convert an old 3-channel mul_head checkpoint to the new single-channel mul_head
+    by averaging the three output filters/biases.
+    """
+    adapted = dict(state_dict)
+
+    if "mul_head.weight" in adapted and adapted["mul_head.weight"].ndim == 4 and adapted["mul_head.weight"].shape[0] == 3:
+        adapted["mul_head.weight"] = adapted["mul_head.weight"].mean(dim=0, keepdim=True)
+
+    if "mul_head.bias" in adapted and adapted["mul_head.bias"].ndim == 1 and adapted["mul_head.bias"].shape[0] == 3:
+        adapted["mul_head.bias"] = adapted["mul_head.bias"].mean(dim=0, keepdim=True)
+
+    return adapted
+
+
+def load_model_state_flexible(model: nn.Module, ckpt_obj: Dict[str, torch.Tensor]) -> None:
+    state_dict = ckpt_obj["model"] if "model" in ckpt_obj else ckpt_obj
+    state_dict = adapt_mul_head_to_single_channel(state_dict)
+    model.load_state_dict(state_dict, strict=True)
+
+
+# -----------------------------------------------------------------------------
+# Perceptual feature extractor
+# -----------------------------------------------------------------------------
+class VGGFeatureExtractor(nn.Module):
+    def __init__(self, layer_ids: Tuple[int, ...] = (3, 8, 17, 26)):
+        super().__init__()
+        self.enabled = vgg19 is not None
+        self.layer_ids = layer_ids
+        if not self.enabled:
+            self.features = None
+            self.mean = None
+            self.std = None
+            return
+
+        try:
+            model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
+        except Exception:
+            model = vgg19(weights=None)
+        self.features = model.features.eval()
+        for p in self.features.parameters():
+            p.requires_grad = False
+
+        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        if self.features is None:
+            return []
+        x = (x - self.mean) / self.std
+        feats = []
+        for i, layer in enumerate(self.features):
+            x = layer(x)
+            if i in self.layer_ids:
+                feats.append(x)
+        return feats
+
+
+# -----------------------------------------------------------------------------
+# Losses
+# -----------------------------------------------------------------------------
+class NakaCorrectionLoss(nn.Module):
+    def __init__(
+        self,
+        lambda_rgb: float = 1.0,
+        lambda_chroma: float = 0.5,
+        lambda_ssim: float = 0.3,
+        lambda_edge: float = 0.2,
+        lambda_feat: float = 0.15,
+        lambda_reg: float = 0.02,
+        lambda_mse: float = 0.0,
+        mse_on: str = "rgb",
+    ):
+        super().__init__()
+        self.lambda_rgb = lambda_rgb
+        self.lambda_chroma = lambda_chroma
+        self.lambda_ssim = lambda_ssim
+        self.lambda_edge = lambda_edge
+        self.lambda_feat = lambda_feat
+        self.lambda_reg = lambda_reg
+        self.lambda_mse = lambda_mse
+        self.mse_on = mse_on
+        self.vgg = VGGFeatureExtractor()
+
+    def forward(self, pred_dict: Dict[str, torch.Tensor], gt: torch.Tensor, naka: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
+        pred = pred_dict["enhanced"]
+        mul_map = pred_dict["mul_map"]
+        add_map = pred_dict["add_map"]
+
+        loss_rgb = charbonnier_loss(pred, gt) + 0.5 * F.l1_loss(pred, gt)
+
+        pred_ycc = rgb_to_ycbcr(pred)
+        gt_ycc = rgb_to_ycbcr(gt)
+        loss_chroma = F.l1_loss(pred_ycc[:, 1:], gt_ycc[:, 1:]) + 0.2 * F.l1_loss(pred_ycc[:, :1], gt_ycc[:, :1])
+
+        loss_ssim = ssim_loss(pred, gt)
+        loss_edge = F.l1_loss(edge_map(pred), edge_map(gt))
+
+        loss_feat = pred.new_tensor(0.0)
+        pred_feats = self.vgg(pred)
+        gt_feats = self.vgg(gt)
+        if len(pred_feats) == len(gt_feats) and len(pred_feats) > 0:
+            for pf, gf in zip(pred_feats, gt_feats):
+                loss_feat = loss_feat + F.l1_loss(pf, gf)
+            loss_feat = loss_feat / len(pred_feats)
+
+        id_mul = F.l1_loss(mul_map, torch.ones_like(mul_map))
+        id_add = F.l1_loss(add_map, torch.zeros_like(add_map))
+        smooth_mul = F.l1_loss(mul_map[:, :, :, 1:], mul_map[:, :, :, :-1]) + F.l1_loss(mul_map[:, :, 1:, :], mul_map[:, :, :-1, :])
+        smooth_add = F.l1_loss(add_map[:, :, :, 1:], add_map[:, :, :, :-1]) + F.l1_loss(add_map[:, :, 1:, :], add_map[:, :, :-1, :])
+        improve_consistency = 0.1 * torch.relu(F.l1_loss(pred, gt) - F.l1_loss(naka, gt))
+        loss_reg = id_mul + id_add + 0.5 * (smooth_mul + smooth_add) + improve_consistency
+
+        if self.mse_on == "rgb":
+            loss_mse = F.mse_loss(pred, gt)
+        elif self.mse_on == "chroma":
+            loss_mse = F.mse_loss(pred_ycc[:, 1:], gt_ycc[:, 1:])
+        elif self.mse_on == "y":
+            loss_mse = F.mse_loss(pred_ycc[:, :1], gt_ycc[:, :1])
+        else:
+            raise ValueError(f"Unsupported mse_on: {self.mse_on}")
+
+        total = (
+            self.lambda_rgb * loss_rgb
+            + self.lambda_chroma * loss_chroma
+            + self.lambda_ssim * loss_ssim
+            + self.lambda_edge * loss_edge
+            + self.lambda_feat * loss_feat
+            + self.lambda_reg * loss_reg
+            + self.lambda_mse * loss_mse
+        )
+
+        metrics = {
+            "loss": float(total.detach().item()),
+            "rgb": float(loss_rgb.detach().item()),
+            "chroma": float(loss_chroma.detach().item()),
+            "ssim": float(loss_ssim.detach().item()),
+            "edge": float(loss_edge.detach().item()),
+            "feat": float(loss_feat.detach().item()),
+            "reg": float(loss_reg.detach().item()),
+            "mse": float(loss_mse.detach().item()),
+        }
+        return total, metrics
+
+
+# -----------------------------------------------------------------------------
+# Validation / inference helpers
+# -----------------------------------------------------------------------------
+def psnr(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    mse = F.mse_loss(pred, target)
+    return 10.0 * torch.log10(1.0 / (mse + 1e-8))
+
+
+def make_naka_processor() -> Phototransduction:
+    return Phototransduction(
+        mode="naka",
+        per_channel=True,
+        naka_sigma=0.05,
+        clip_percentile=99.9,
+        out_mode="0_1",
+        out_method="linear",
+    )
+
+
+@torch.no_grad()
+def forward_full_or_tiled(
+    model: nn.Module,
+    low: torch.Tensor,
+    naka: torch.Tensor,
+    tile_size: int = 0,
+    tile_overlap: int = 32,
+) -> Dict[str, torch.Tensor]:
+    _, _, h, w = low.shape
+    if tile_size <= 0 or (h <= tile_size and w <= tile_size):
+        return model(low, naka)
+
+    step = max(tile_size - tile_overlap, 1)
+    enhanced_acc = torch.zeros_like(naka)
+    add_acc = torch.zeros_like(naka)
+    weight_acc = torch.zeros_like(naka)
+
+    b = low.shape[0]
+    mul_acc = low.new_zeros((b, 1, h, w))
+    mul_weight_acc = low.new_zeros((b, 1, h, w))
+
+    for top in range(0, h, step):
+        for left in range(0, w, step):
+            bottom = min(top + tile_size, h)
+            right = min(left + tile_size, w)
+            # Shift border tiles inward so every tile keeps the full tile_size.
+            top = max(0, bottom - tile_size)
+            left = max(0, right - tile_size)
+
+            low_tile = low[:, :, top:bottom, left:right]
+            naka_tile = naka[:, :, top:bottom, left:right]
+            pred = model(low_tile, naka_tile)
+
+            # Uniform weights: overlapping tile regions are simply averaged.
+            weight = torch.ones_like(pred["enhanced"])
+            mul_weight = torch.ones_like(pred["mul_map"])
+
+            enhanced_acc[:, :, top:bottom, left:right] += pred["enhanced"] * weight
+            mul_acc[:, :, top:bottom, left:right] += pred["mul_map"] * mul_weight
+            add_acc[:, :, top:bottom, left:right] += pred["add_map"] * weight
+            weight_acc[:, :, top:bottom, left:right] += weight
+            mul_weight_acc[:, :, top:bottom, left:right] += mul_weight
+
+    enhanced = enhanced_acc / weight_acc.clamp_min(1e-6)
+    mul_map = mul_acc / mul_weight_acc.clamp_min(1e-6)
+    add_map = add_acc / weight_acc.clamp_min(1e-6)
+    enhanced = torch.clamp(enhanced, 0.0, 1.0)
+    return {"enhanced": enhanced, "mul_map": mul_map, "add_map": add_map}
+
+
+@torch.no_grad()
+def validate(
+    model: nn.Module,
+    criterion: NakaCorrectionLoss,
+    loader: DataLoader,
+    device: torch.device,
+    save_dir: Optional[str] = None,
+    max_save: int = 8,
+    tile_size: int = 0,
+    tile_overlap: int = 32,
+) -> Dict[str, float]:
+    model.eval()
+    loss_sum = 0.0
+    psnr_sum = 0.0
+    count = 0
+    saved = 0
+
+    if save_dir is not None:
+        os.makedirs(save_dir, exist_ok=True)
+
+    for batch in loader:
+        low = batch["low"].to(device, non_blocking=True)
+        naka = batch["naka"].to(device, non_blocking=True)
+        gt = batch["gt"].to(device, non_blocking=True)
+        names = batch["name"]
+
+        pred_dict = forward_full_or_tiled(model, low, naka, tile_size=tile_size, tile_overlap=tile_overlap)
+        loss, _ = criterion(pred_dict, gt, naka)
+
+        bs = low.size(0)
+        loss_sum += float(loss.item()) * bs
+        psnr_sum += float(psnr(pred_dict["enhanced"], gt).item()) * bs
+        count += bs
+
+        if save_dir is not None and saved < max_save:
+            for i in range(bs):
+                if saved >= max_save:
+                    break
+                stem, _ = os.path.splitext(names[i])
+                sample_dir = os.path.join(save_dir, stem)
+                os.makedirs(sample_dir, exist_ok=True)
+                save_rgb_tensor(low[i], os.path.join(sample_dir, f"{stem}_low.JPG"))
+                save_rgb_tensor(naka[i], os.path.join(sample_dir, f"{stem}_naka.JPG"))
+                save_rgb_tensor(pred_dict["enhanced"][i], os.path.join(sample_dir, f"{stem}_enhanced.JPG"))
+                save_rgb_tensor(gt[i], os.path.join(sample_dir, f"{stem}_gt.JPG"))
+                save_rgb_tensor(pred_dict["mul_map"][i].clamp(0, 2) / 2.0, os.path.join(sample_dir, f"{stem}_mul_map_vis.JPG"))
+                save_rgb_tensor((pred_dict["add_map"][i] + 0.25) / 0.5, os.path.join(sample_dir, f"{stem}_add_map_vis.JPG"))
+                saved += 1
+
+    return {
+        "val_loss": loss_sum / max(count, 1),
+        "val_psnr": psnr_sum / max(count, 1),
+    }
+
+
+class NakaCorrectionLossWithMasks(nn.Module):
+    def __init__(self, base_loss: nn.Module, lambda_gray_edge: float = 0.5, lambda_bright: float = 0.8):
+        """
+        base_loss: the original NakaCorrectionLoss
+        lambda_gray_edge: weight of the gray-level edge-mask term
+        lambda_bright: weight of the bright-region mask term
+        """
+        super().__init__()
+        self.base_loss = base_loss
+        self.lambda_gray_edge = lambda_gray_edge
+        self.lambda_bright = lambda_bright
+
+    @staticmethod
+    def compute_gray_laplacian_mask(img: torch.Tensor) -> torch.Tensor:
+        """B x C x H x W -> gray edge mask B x 1 x H x W"""
+        img_np = img.permute(0, 2, 3, 1).cpu().numpy()
+        lap_masks = []
+        for i in range(img.shape[0]):
+            gray = 0.299 * img_np[i, :, :, 0] + 0.587 * img_np[i, :, :, 1] + 0.114 * img_np[i, :, :, 2]
+            lap = cv2.Laplacian(gray, cv2.CV_32F, ksize=3)
+            lap = np.abs(lap)
+            lap /= (lap.max() + 1e-8)
+            lap = np.sqrt(lap)  # compress extreme values
+            lap_masks.append(lap)
+        lap_masks = np.stack(lap_masks, axis=0)
+        lap_masks = torch.from_numpy(lap_masks).float().unsqueeze(1).to(img.device)
+        return lap_masks
+
+    @staticmethod
+    def compute_bright_mask(img: torch.Tensor, percentile: float = 0.85) -> torch.Tensor:
+        """B x C x H x W -> bright mask B x 1 x H x W"""
+        img_gray = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
+        threshold = torch.quantile(img_gray.view(img.shape[0], -1), percentile, dim=1).view(-1, 1, 1, 1)
+        mask = (img_gray >= threshold).float()
+        return mask
+
+    def forward(self, pred_dict: Dict[str, torch.Tensor], gt: torch.Tensor, naka: torch.Tensor):
+        # original base loss
+        total_loss, metrics = self.base_loss(pred_dict, gt, naka)
+
+        pred = pred_dict["enhanced"]
+
+        # gray-level edge mask
+        gray_mask = self.compute_gray_laplacian_mask(gt)
+        loss_gray = (gray_mask * torch.abs(pred - gt)).mean()
+
+        # bright-region mask
+        bright_mask = self.compute_bright_mask(pred)
+        loss_bright = (bright_mask * torch.abs(pred - gt)).mean()
+
+        # total loss
+        total_loss = total_loss + self.lambda_gray_edge * loss_gray + self.lambda_bright * loss_bright
+
+        # update metrics
+        metrics["gray_edge"] = float(loss_gray.detach().item())
+        metrics["bright_mask"] = float(loss_bright.detach().item())
+        metrics["loss"] = float(total_loss.detach().item())
+
+        return total_loss, metrics
+
+
+# -----------------------------------------------------------------------------
+# Training / inference
+# -----------------------------------------------------------------------------
+def train(args: argparse.Namespace) -> None:
+    seed_everything(args.seed)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    os.makedirs(args.output_dir, exist_ok=True)
+    ckpt_dir = os.path.join(args.output_dir, "checkpoints")
+    vis_dir = os.path.join(args.output_dir, "val_vis")
+    os.makedirs(ckpt_dir, exist_ok=True)
+    os.makedirs(vis_dir, exist_ok=True)
+
+    train_set = NakaPairDataset(
+        root=args.data_root,
+        split="train",
+        crop_size=args.crop_size,
+        is_train=True,
+        cache_naka=False,
+        min_scale=args.train_min_scale,
+        max_scale=args.train_max_scale,
+    )
+    val_set = NakaPairDataset(
+        root=args.data_root,
+        split="val",
+        crop_size=args.crop_size,
+        is_train=False,
+        cache_naka=args.cache_naka,
+    )
+
+    train_loader = DataLoader(
+        train_set,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        pin_memory=True,
+        drop_last=True,
+    )
+    # Validation uses full-resolution images, so batch_size must stay at 1.
+    val_loader = DataLoader(
+        val_set,
+        batch_size=1,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=True,
+    )
+
+    model = ChromaGuidedUNet(base_ch=args.base_ch, mul_range=args.mul_range, add_range=args.add_range, hf_kernel_size=args.hf_kernel_size, hf_sigma=args.hf_sigma).to(device)
+    base_loss = NakaCorrectionLoss(
+        lambda_rgb=args.lambda_rgb,
+        lambda_chroma=args.lambda_chroma,
+        lambda_ssim=args.lambda_ssim,
+        lambda_edge=args.lambda_edge,
+        lambda_feat=args.lambda_feat,
+        lambda_reg=args.lambda_reg,
+        lambda_mse=args.lambda_mse,
+        mse_on=args.mse_on,
+    ).to(device)
+
+    criterion = NakaCorrectionLossWithMasks(
+        base_loss=base_loss,
+        lambda_gray_edge=1.0,  # tunable
+        lambda_bright=0.8,  # tunable
+    ).to(device)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+    scaler = torch.amp.GradScaler("cuda", enabled=args.amp and device.type == "cuda")
+
+    start_epoch = 1
+    best_psnr = -1e9
+
+    if args.resume_ckpt:
+        ckpt = load_torch_checkpoint(args.resume_ckpt, map_location=device)
+        load_model_state_flexible(model, ckpt)
+        if not args.reset_optimizer:
+            if "optimizer" in ckpt:
+                optimizer.load_state_dict(ckpt["optimizer"])
+            if "scheduler" in ckpt:
+                try:
+                    scheduler.load_state_dict(ckpt["scheduler"])
+                except Exception as e:
+                    print(f"[Warning] Failed to load scheduler state: {e}. Scheduler will be reinitialized.")
+        start_epoch = int(ckpt.get("epoch", 0)) + 1
+        best_psnr = float(ckpt.get("best_psnr", -1e9))
+        print(f"Loaded resume checkpoint: {args.resume_ckpt}")
+    elif args.init_ckpt:
+        ckpt = load_torch_checkpoint(args.init_ckpt, map_location=device)
+        load_model_state_flexible(model, ckpt)
+        print(f"Loaded init checkpoint: {args.init_ckpt}")
+
+    end_epoch = start_epoch + args.epochs - 1
+    for epoch in range(start_epoch, end_epoch + 1):
+        model.train()
+        running_loss = 0.0
+        running_psnr = 0.0
+        count = 0
+
+        for batch in train_loader:
+            low = batch["low"].to(device, non_blocking=True)
+            naka = batch["naka"].to(device, non_blocking=True)
+            gt = batch["gt"].to(device, non_blocking=True)
+
+            optimizer.zero_grad(set_to_none=True)
+            with torch.amp.autocast("cuda", enabled=args.amp and device.type == "cuda"):
+                pred_dict = model(low, naka)
+                loss, _ = criterion(pred_dict, gt, naka)
+
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            scaler.step(optimizer)
+            scaler.update()
+
+            batch_psnr = psnr(pred_dict["enhanced"], gt)
+            running_loss += float(loss.item()) * low.size(0)
+            running_psnr += float(batch_psnr.item()) * low.size(0)
+            count += low.size(0)
+
+        scheduler.step()
+
+        train_log = {
+            "train_loss": running_loss / max(count, 1),
+            "train_psnr": running_psnr / max(count, 1),
+        }
+        val_log = validate(
+            model,
+            criterion,
+            val_loader,
+            device,
+            save_dir=os.path.join(vis_dir, f"epoch_{epoch:03d}"),
+            max_save=4,
+            tile_size=args.val_tile_size,
+            tile_overlap=args.tile_overlap,
+        )
+
+        print(
+            f"Epoch [{epoch:03d}/{end_epoch:03d}] "
+            f"train_loss={train_log['train_loss']:.4f} "
+            f"train_psnr={train_log['train_psnr']:.2f} "
+            f"val_loss={val_log['val_loss']:.4f} "
+            f"val_psnr={val_log['val_psnr']:.2f}"
+        )
+
+        latest_path = os.path.join(ckpt_dir, "latest.pth")
+        torch.save(
+            {
+                "epoch": epoch,
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "args": vars(args),
+                "best_psnr": best_psnr,
+            },
+            latest_path,
+        )
+
+        if val_log["val_psnr"] > best_psnr:
+            best_psnr = val_log["val_psnr"]
+            best_path = os.path.join(ckpt_dir, "best.pth")
+            torch.save(
+                {
+                    "epoch": epoch,
+                    "model": model.state_dict(),
+                    "optimizer": optimizer.state_dict(),
+                    "scheduler": scheduler.state_dict(),
+                    "args": vars(args),
+                    "best_psnr": best_psnr,
+                },
+                best_path,
+            )
+            print(f"Saved best checkpoint to: {best_path}")
+
+
+@torch.no_grad()
+def inference(args: argparse.Namespace) -> None:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    model = ChromaGuidedUNet(base_ch=args.base_ch, mul_range=args.mul_range, add_range=args.add_range, hf_kernel_size=args.hf_kernel_size, hf_sigma=args.hf_sigma).to(device)
+    ckpt = load_torch_checkpoint(args.ckpt, map_location=device)
+    load_model_state_flexible(model, ckpt)
+    model.eval()
+
+    naka_processor = make_naka_processor()
+    paths = list_image_files(args.input_dir)
+
+    for path in paths:
+        low_rgb = load_rgb(path)
+        low_float = low_rgb.astype(np.float32) / 255.0
+        low_bgr = cv2.cvtColor(low_rgb, cv2.COLOR_RGB2BGR)
+        naka_bgr = naka_processor(low_bgr)
+        naka_rgb = cv2.cvtColor(naka_bgr.astype(np.float32), cv2.COLOR_BGR2RGB)
+        naka_rgb = np.clip(naka_rgb, 0.0, 1.0).astype(np.float32)
+
+        low_t = torch.from_numpy(np.ascontiguousarray(low_float)).permute(2, 0, 1).unsqueeze(0).float().to(device)
+        naka_t = torch.from_numpy(np.ascontiguousarray(naka_rgb)).permute(2, 0, 1).unsqueeze(0).float().to(device)
+        pred_dict = forward_full_or_tiled(
+            model,
+            low_t,
+            naka_t,
+            tile_size=args.tile_size,
+            tile_overlap=args.tile_overlap,
+        )
+
+        name = os.path.splitext(os.path.basename(path))[0]
+        save_rgb_tensor(pred_dict["enhanced"][0], os.path.join(args.output_dir, f"{name}_enhanced.JPG"))
+        # save_rgb_tensor(pred_dict["mul_map"][0].clamp(0, 2) / 2.0, os.path.join(args.output_dir, f"{name}_mul_vis.JPG"))
+        # save_rgb_tensor((pred_dict["add_map"][0] + 0.25) / 0.5, os.path.join(args.output_dir, f"{name}_add_vis.JPG"))
+
+
+# -----------------------------------------------------------------------------
+# Main
+# -----------------------------------------------------------------------------
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser("Naka-guided color-correction network (multi-scale + adaptive input standardization)")
+    parser.add_argument("--mode", type=str, default="train", choices=["train", "infer"])
+    parser.add_argument("--data_root", type=str, default="./datasets/LOLv1")
+    parser.add_argument("--input_dir", type=str, default="./test_images")
+    parser.add_argument("--output_dir", type=str, default="./outputs/naka_color_correction_v2")
+    parser.add_argument("--ckpt", type=str, default="./outputs/naka_color_correction_v2/checkpoints/best.pth")
+    parser.add_argument("--resume_ckpt", type=str, default="", help="Resume training from a saved checkpoint and continue the epoch count.")
+    parser.add_argument("--init_ckpt", type=str, default="", help="Initialize model weights from a checkpoint and start a fresh optimization run.")
+    parser.add_argument("--reset_optimizer", action="store_true", help="When used with --resume_ckpt, only load model weights and reset the optimizer/scheduler.")
+
+    parser.add_argument("--epochs", type=int, default=200)
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--num_workers", type=int, default=4)
+    parser.add_argument("--crop_size", type=int, default=256)
+    parser.add_argument("--lr", type=float, default=2e-4)
+    parser.add_argument("--weight_decay", type=float, default=1e-4)
+    parser.add_argument("--base_ch", type=int, default=32)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--cache_naka", action="store_true")
+    parser.add_argument("--amp", action="store_true")
+
+    parser.add_argument("--mul_range", type=float, default=0.6)
+    parser.add_argument("--add_range", type=float, default=0.25)
+    parser.add_argument("--hf_kernel_size", type=int, default=5, help="Odd Gaussian kernel size for the low/high-frequency decomposition.")
+    parser.add_argument("--hf_sigma", type=float, default=1.0, help="Gaussian sigma for the low/high-frequency decomposition.")
+    parser.add_argument("--train_min_scale", type=float, default=0.7)
+    parser.add_argument("--train_max_scale", type=float, default=1.4)
+
+    parser.add_argument("--val_tile_size", type=int, default=0, help="0 means full-resolution validation without tiling.")
+    parser.add_argument("--tile_size", type=int, default=0, help="0 means full-resolution inference without tiling.")
+    parser.add_argument("--tile_overlap", type=int, default=32)
+
+    parser.add_argument("--lambda_rgb", type=float, default=1.0)
+    parser.add_argument("--lambda_chroma", type=float, default=0.5)
+    parser.add_argument("--lambda_ssim", type=float, default=0.3)
+    parser.add_argument("--lambda_edge", type=float, default=0.2)
+    parser.add_argument("--lambda_feat", type=float, default=0.15)
+    parser.add_argument("--lambda_reg", type=float, default=0.02)
+    parser.add_argument("--lambda_mse", type=float, default=0.0, help="Weight for the extra MSE loss term. Keep it small to avoid oversmoothing.")
+    parser.add_argument("--mse_on", type=str, default="rgb", choices=["rgb", "chroma", "y"], help="Where to apply the extra MSE term.")
+    return parser
+
+
+if __name__ == "__main__":
+    args = build_parser().parse_args()
+    if args.mode == "train":
+        train(args)
+    else:
+        inference(args)
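
The correction step in `ChromaGuidedUNet.forward` scales and shifts only the low-pass part of the Naka image, then restores the high frequencies (`base = naka_lf * mul_map + add_map`, `enhanced = clamp(base + naka_hf, 0, 1)`). A minimal NumPy sketch of that recombination for a single-channel image, purely illustrative: `box_blur` here is a stand-in for the model's `gaussian_blur_tensor` low-pass, and `apply_correction` is a hypothetical helper, not repository code.

```python
import numpy as np

def box_blur(img: np.ndarray, k: int = 5) -> np.ndarray:
    """Separable box blur with edge padding (stand-in for a Gaussian low-pass)."""
    pad = k // 2
    padded = np.pad(img, pad, mode="edge")
    kernel = np.ones(k) / k
    # Blur rows, then columns; "valid" convolution undoes the padding exactly.
    rows = np.apply_along_axis(lambda r: np.convolve(r, kernel, mode="valid"), 1, padded)
    return np.apply_along_axis(lambda c: np.convolve(c, kernel, mode="valid"), 0, rows)

def apply_correction(naka: np.ndarray, mul_map, add_map, k: int = 5) -> np.ndarray:
    """Correct only the low-frequency content, then add the high frequencies back."""
    lf = box_blur(naka, k)   # naka_lf
    hf = naka - lf           # naka_hf (exact residual, so lf + hf == naka)
    return np.clip(lf * mul_map + add_map + hf, 0.0, 1.0)
```

With `mul_map = 1` and `add_map = 0` this is an identity map on images already in [0, 1] (since `lf + hf` reconstructs the input exactly), which is the behaviour the identity terms in `NakaCorrectionLoss`'s `lambda_reg` regularizer pull the predicted maps towards.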
phototransduction.py ADDED
@@ -0,0 +1,240 @@
import numpy as np
from typing import Literal, Optional
import cv2


Array = np.ndarray
_Mode = Literal["log", "naka"]
_Out = Literal["zero_mean", "0_1"]


class Phototransduction:
    """Retina-inspired tone compression for low-light images.

    ``mode="log"`` applies normalized logarithmic compression;
    ``mode="naka"`` applies the Naka-Rushton saturation curve.
    """

    def __init__(
        self,
        log_sigma: Optional[float] = None,
        mode: _Mode = "log",
        naka_n: float = 1.0,
        naka_sigma: Optional[float] = None,
        clip_percentile: Optional[float] = 99.9,
        local_radius: int = 5,
        per_channel: bool = True,
        eps: float = 1e-3,
        out_mode: _Out = "zero_mean",
        sym_clip_tau: float = 3.0,
        out_method: str = "symmetric",
        out_dtype=np.float32,
    ):
        self.log_sigma = log_sigma
        self.mode = mode
        self.naka_n = float(naka_n)
        self.naka_sigma = naka_sigma
        self.clip_percentile = clip_percentile
        self.local_radius = int(local_radius)
        self.per_channel = bool(per_channel)
        self.eps = float(eps)
        self.out_mode = out_mode
        self.sym_clip_tau = float(sym_clip_tau)
        self.out_method = out_method
        self.out_dtype = out_dtype

    # ---------- public API ----------
    def __call__(self, I: Array) -> Array:
        x = self._to_float01(I)

        # Tone compression: sigma is estimated from image brightness when not given.
        if self.mode == "log":
            effective_log_sigma = self._auto_log_sigma(x) if self.log_sigma is None else self.log_sigma
            x = self._log_compress(x, effective_log_sigma)
        elif self.mode == "naka":
            effective_naka_sigma = self._auto_naka_sigma(x) if self.naka_sigma is None else self.naka_sigma
            x = self._naka_rushton(x, n=self.naka_n, sigma=effective_naka_sigma)
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

        # Output normalization.
        if self.out_mode == "0_1":
            if self.out_method == "symmetric":
                x = self._to_01_symmetric(x, tau=self.sym_clip_tau)
            elif self.out_method == "percentile":
                x = self._to_01_percentile(x, lower_pct=2.5, upper_pct=97.5)
            elif self.out_method == "linear":
                x = self._to_01_linear(x)
            elif self.out_method == "histogram":
                x = self._to_01_histogram(x)
            else:
                raise ValueError(f"Unknown out_method: {self.out_method}")
        elif self.out_mode == "zero_mean":
            pass  # leave the compressed signal as-is
        else:
            raise ValueError(f"Unknown out_mode: {self.out_mode}")

        return x.astype(self.out_dtype, copy=False)

    # ---------- output normalizations ----------
    @staticmethod
    def _to_01_symmetric(x: Array, tau: float = 3.0) -> Array:
        # Clip to [-tau, tau], then map linearly onto [0, 1].
        x_clip = np.clip(x, -tau, tau)
        return (x_clip + tau) / (2.0 * tau)

    @staticmethod
    def _to_01_percentile(x: Array, lower_pct: float = 1.0, upper_pct: float = 99.0) -> Array:
        # Robust rescaling: clip to the given percentiles before normalizing.
        lower = np.percentile(x, lower_pct)
        upper = np.percentile(x, upper_pct)
        x_clip = np.clip(x, lower, upper)
        return (x_clip - lower) / (upper - lower + 1e-12)

    @staticmethod
    def _to_01_linear(x: Array) -> Array:
        # Min-max rescaling; constant images map to zero.
        x_min = x.min()
        x_max = x.max()
        if x_max - x_min < 1e-6:
            return np.zeros_like(x)
        return (x - x_min) / (x_max - x_min)

    @staticmethod
    def _to_01_histogram(x: Array) -> Array:
        # Histogram equalization on the luma (Y) channel for color inputs.
        x_min = x.min()
        x_max = x.max()
        if x_max - x_min < 1e-6:
            return np.zeros_like(x)

        x_norm = (x - x_min) / (x_max - x_min)
        x_uint8 = (x_norm * 255).astype(np.uint8)

        if x.ndim == 3:
            x_yuv = cv2.cvtColor(x_uint8, cv2.COLOR_RGB2YUV)
            x_yuv[:, :, 0] = cv2.equalizeHist(x_yuv[:, :, 0])
            x_eq = cv2.cvtColor(x_yuv, cv2.COLOR_YUV2RGB)
        else:
            x_eq = cv2.equalizeHist(x_uint8)

        return x_eq.astype(np.float32) / 255.0

    # ---------- input normalization ----------
    def _to_float01(self, I: Array) -> Array:
        # Integer inputs: divide by the dtype maximum.
        if np.issubdtype(I.dtype, np.integer):
            maxv = np.iinfo(I.dtype).max
            x = I.astype(np.float32) / float(maxv)
            return np.clip(x, 0.0, 1.0)
        # Float inputs: normalize by the max, or by a high percentile
        # (to suppress hot pixels) when clip_percentile is set.
        x = I.astype(np.float32, copy=False)
        if self.clip_percentile is None:
            maxv = float(np.max(x)) if x.size else 1.0
            if maxv <= 1.0 + 1e-6:
                return np.clip(x, 0.0, 1.0)
            return np.clip(x / (maxv + 1e-12), 0.0, 1.0)

        hi = np.percentile(x, self.clip_percentile)
        if hi <= 1e-12:
            return np.zeros_like(x, dtype=np.float32)
        return np.clip(x / hi, 0.0, 1.0)

    # ---------- automatic sigma estimation ----------
    def _auto_log_sigma(self, x: Array) -> float:
        # Tie the compression knee to the median scene brightness.
        brightness = np.mean(x, axis=2) if x.ndim == 3 else x
        median_brightness = np.clip(np.median(brightness), 0.05, 0.95)
        auto_sigma = np.clip(median_brightness * 0.4, 0.02, 0.5)
        return float(auto_sigma)

    def _auto_naka_sigma(self, x: Array) -> float:
        # Half-saturation constant proportional to median brightness,
        # with a floor for very dark scenes.
        brightness = np.mean(x, axis=2) if x.ndim == 3 else x
        median_brightness = np.median(brightness)
        auto_sigma = np.clip(median_brightness * 0.25, 0.01, 0.8)
        if median_brightness < 0.05:
            auto_sigma = max(auto_sigma, 0.05)
        return float(auto_sigma)

    # ---------- compression curves ----------
    @staticmethod
    def _log_compress(x: Array, sigma: float) -> Array:
        # log1p(x / sigma), normalized so that x = 1 maps to 1.
        denom = np.log1p(1.0 / (sigma + 1e-12))
        return np.log1p(x / (sigma + 1e-12)) / (denom + 1e-12)

    @staticmethod
    def _naka_rushton(x: Array, n: float, sigma: float) -> Array:
        # Naka-Rushton response: x^n / (x^n + sigma^n), saturating toward 1.
        xn = np.power(np.clip(x, 0.0, None), n)
        sig = np.power(max(sigma, 1e-8), n)
        return xn / (xn + sig)

    # ---------- local adaptation helpers (not called from __call__) ----------
    @staticmethod
    def _zero_center(x: Array) -> Array:
        if x.ndim == 3:
            mu = np.mean(x, axis=(0, 1), keepdims=True)
        else:
            mu = np.mean(x, keepdims=True)
        return x - mu

    @staticmethod
    def _to_01_from_zero_mean(x: Array, tau: float = 3.0) -> Array:
        x_clip = np.clip(x, -tau, tau)
        return (x_clip + tau) / (2.0 * tau)

    def _gaussian_blur(self, x: Array, radius: int, per_channel: bool) -> Array:
        if radius <= 0:
            return x.copy()

        xx = x[..., None] if x.ndim == 2 else x

        if not per_channel and xx.shape[2] > 1:
            # Blur the channel mean once and broadcast it to every channel.
            mean_ch = np.mean(xx, axis=2, keepdims=True)
            sm = self._gauss_sep(mean_ch, radius)
            sm = np.repeat(sm, xx.shape[2], axis=2)
            return sm.squeeze() if x.ndim == 2 else sm

        sm = self._gauss_sep(xx, radius)
        return sm.squeeze() if x.ndim == 2 else sm

    @staticmethod
    def _gauss_kernel1d(radius: int) -> Array:
        # Normalized 1D Gaussian with sigma = radius / 3.
        sigma = max(radius / 3.0, 1e-6)
        ax = np.arange(-radius, radius + 1, dtype=np.float32)
        k = np.exp(-0.5 * (ax / sigma) ** 2)
        k /= np.sum(k)
        return k.astype(np.float32)

    def _gauss_sep(self, x: Array, radius: int) -> Array:
        # Separable Gaussian: horizontal pass followed by vertical pass.
        k = self._gauss_kernel1d(radius)
        y = self._conv1d_h(x, k)
        y = self._conv1d_v(y, k)
        return y

    @staticmethod
    def _pad_reflect(x: Array, pad: int, axis: int) -> Array:
        pad_width = [(0, 0)] * x.ndim
        pad_width[axis] = (pad, pad)
        return np.pad(x, pad_width, mode="reflect")

    def _conv1d_h(self, x: Array, k: Array) -> Array:
        pad = k.size // 2
        xp = self._pad_reflect(x, pad, axis=1)
        out = np.empty_like(xp[:, pad:-pad, :])
        for c in range(x.shape[2]):
            out[..., c] = np.apply_along_axis(lambda r: np.convolve(r, k, mode="valid"), 1, xp[..., c])
        return out

    def _conv1d_v(self, x: Array, k: Array) -> Array:
        pad = k.size // 2
        xp = self._pad_reflect(x, pad, axis=0)
        out = np.empty_like(xp[pad:-pad, :, :])
        for c in range(x.shape[2]):
            out[..., c] = np.apply_along_axis(lambda r: np.convolve(r, k, mode="valid"), 0, xp[..., c])
        return out
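The Naka-Rushton curve at the core of `_naka_rushton` is easiest to understand in isolation. The standalone sketch below (the free-function name is hypothetical, not part of the module) reproduces the same formula and demonstrates the half-saturation property: an input equal to `sigma` maps to exactly 0.5, and the response saturates toward 1 for bright inputs.

```python
import numpy as np

def naka_rushton(x, n=1.0, sigma=0.25):
    """Naka-Rushton response x^n / (x^n + sigma^n), mapping [0, inf) into [0, 1)."""
    xn = np.power(np.clip(x, 0.0, None), n)
    return xn / (xn + np.power(max(sigma, 1e-8), n))

x = np.linspace(0.0, 1.0, 5)
y = naka_rushton(x, n=1.0, sigma=0.25)
# y is monotone increasing, starts at 0, and stays below 1;
# naka_rushton(sigma) == 0.5, so sigma sets the half-saturation point
```

Because the curve is steepest near zero when `sigma` is small, `_auto_naka_sigma` shrinks `sigma` for dark scenes, which is what lifts low-light detail.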
requirements.txt ADDED
@@ -0,0 +1,5 @@
numpy>=1.23
opencv-python>=4.8
Pillow>=9.0
torch>=2.1
torchvision>=0.16