---
library_name: pytorch
pipeline_tag: image-to-image
tags:
- low-light-image-enhancement
- image-enhancement
- image-to-image
- gaussian-splatting
- 3d-reconstruction
- custom-code
- pytorch
---
# Naka-guided Chroma-Correction Model
This repository hosts the **Naka-guided Chroma-correction model** used in the **Naka-GS** pipeline.
The model is designed to refine a Naka-enhanced low-light image by suppressing color distortion in bright regions while preserving edge and texture details. In the released implementation, the network predicts a single-channel multiplicative correction map and a three-channel additive correction map, and applies them only to the low-frequency component of the Naka-enhanced image before adding the preserved high-frequency details back to the final output. The model input is an 18-channel representation built from the low-light image, the Naka-enhanced image, their residual, and standardized counterparts.
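The frequency-decoupled correction described above can be sketched as follows. This is a minimal illustration, not the released implementation: a box blur stands in for the low-pass filter actually used, and the map shapes follow the description above (single-channel `mul_map`, three-channel `add_map`).

```python
import numpy as np

def frequency_decoupled_correction(naka, mul_map, add_map, kernel=5):
    """Apply mul/add maps only to the low-frequency band of the
    Naka-enhanced image (sketch; a box blur stands in for the
    low-pass filter used by the released code)."""
    # naka: (3, H, W); mul_map: (1, H, W); add_map: (3, H, W)
    pad = kernel // 2
    padded = np.pad(naka, ((0, 0), (pad, pad), (pad, pad)), mode="edge")
    low = np.zeros_like(naka)
    for dy in range(kernel):
        for dx in range(kernel):
            low += padded[:, dy:dy + naka.shape[1], dx:dx + naka.shape[2]]
    low /= kernel * kernel
    high = naka - low                      # preserved edge/texture details
    corrected_low = low * mul_map + add_map
    return corrected_low + high
```

Because only the low-frequency band is modified, high-frequency edges and textures pass through the correction unchanged.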
## Associated resources
- **Project page / code**: `https://github.com/RunyuZhu/Naka-GS`
- **Paper page**: `https://huggingface.co/papers/2604.11142`
- **ArXiv**: `https://arxiv.org/abs/2604.11142`
## What this model does
Given a low-light RGB image:
1. a Naka phototransduction transform is applied,
2. the correction network predicts `mul_map` and `add_map`,
3. the low-frequency component of the Naka image is corrected,
4. the high-frequency component is added back,
5. the final enhanced image is saved.
In the provided code, inference saves the corrected result as `<image_name>_enhanced.JPG`.
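Step 1 applies a Naka-style phototransduction curve. As a rough illustration of its effect (brightening dark pixels while compressing highlights), here is a generic Naka-Rushton form; the constants `sigma` and `n` are hypothetical, and the released `Phototransduction` module may differ:

```python
import numpy as np

def naka_rushton(img, sigma=0.18, n=0.9):
    """Generic Naka-Rushton response curve (illustrative parameters;
    the released Phototransduction implementation may differ)."""
    img = np.clip(np.asarray(img, dtype=np.float64), 0.0, 1.0)
    return img ** n / (img ** n + sigma ** n)
```

For inputs in `[0, 1]`, the curve is monotonic and lifts low intensities well above their original values, which is the property the downstream chroma-correction network compensates for.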
## Model details
### Architecture
The released model is a U-Net-style encoder-decoder with residual blocks and SE attention. The core model class is `ChromaGuidedUNet`. Its forward pass takes `low` and `naka` tensors as input, constructs an 18-channel feature tensor, predicts `mul_map` and `add_map`, and performs frequency-decoupled correction on the Naka image.
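The 18-channel input can be read as three triplets plus their standardized counterparts. The sketch below shows one plausible construction, assuming "standardized" means per-image channel-wise zero-mean/unit-variance normalization; the actual forward pass in `ChromaGuidedUNet` may build the tensor differently:

```python
import numpy as np

def standardize(x, eps=1e-6):
    # per-image, per-channel standardization (an assumption)
    mean = x.mean(axis=(1, 2), keepdims=True)
    std = x.std(axis=(1, 2), keepdims=True)
    return (x - mean) / (std + eps)

def build_input(low, naka):
    """Stack low, naka, residual, and standardized versions of each
    into an 18-channel tensor. low, naka: (3, H, W) arrays."""
    residual = naka - low
    feats = [low, naka, residual,
             standardize(low), standardize(naka), standardize(residual)]
    return np.concatenate(feats, axis=0)  # (18, H, W)
```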
### Input
- RGB low-light image
- Automatically generated Naka-enhanced intermediate image
### Output
- corrected RGB image: `enhanced`
- optional intermediate maps:
- `mul_map`
- `add_map`
### Checkpoints
Recommended checkpoint filenames:
- `best.pth`: the checkpoint with the best validation PSNR
- `latest.pth`: the most recent training state; typically the one used for inference
The training script saves `latest.pth` every epoch and updates `best.pth` whenever validation PSNR improves.
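The save policy above can be summarized as: write `latest.pth` every epoch, refresh `best.pth` only when validation PSNR improves. A minimal sketch of that bookkeeping (the actual training loop also serializes optimizer state):

```python
def best_refresh_epochs(psnr_per_epoch):
    """Return the epochs at which best.pth would be refreshed under
    the policy described above (sketch, not the real training loop)."""
    best = float("-inf")
    refreshed = []
    for epoch, psnr in enumerate(psnr_per_epoch):
        # latest.pth would be (over)written here every epoch
        if psnr > best:
            best = psnr
            refreshed.append(epoch)  # best.pth refreshed on improvement
    return refreshed
```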
## Intended use
This model is intended to be used as the **color-correction / enhancement stage** in the Naka-GS low-light 3D reconstruction pipeline, or as a standalone low-light image refinement module when a Naka-style phototransduction preprocessing step is available.
## Limitations
- This repository contains **custom PyTorch code** and is **not** a Transformers-native model.
- The script depends on a custom `Phototransduction` implementation and tries to import it from either `retina.phototransduction` or `phototransduction`. For a standalone release, place `phototransduction.py` next to `naka_color_correction.py`, or preserve the original package layout.
- The model card does not claim broad robustness outside the training setting used by the original project.
## Repository layout
A minimal Hugging Face release layout is:
```text
.
β”œβ”€β”€ README.md
β”œβ”€β”€ requirements.txt
β”œβ”€β”€ naka_color_correction.py
β”œβ”€β”€ phototransduction.py
β”œβ”€β”€ best.pth
β”œβ”€β”€ latest.pth
└── assets/
    β”œβ”€β”€ ***          # optional
    └── results.png  # optional
```
## Installation
```bash
git clone https://huggingface.co/<your-username-or-org>/<your-model-repo>
cd <your-model-repo>
pip install -r requirements.txt
```
## Requirements
Core dependencies used directly in the provided script:
- `torch`
- `torchvision`
- `numpy`
- `opencv-python`
- `Pillow`
The script also uses `torchvision.models.vgg19` for the perceptual loss branch during training.
## Quick start: inference
### 1. Prepare files
Place the following files in the same directory:
- `naka_color_correction.py`
- `phototransduction.py`
- `latest.pth` or `best.pth`
Create an input folder such as:
```text
./test_images/
```
and put your test images inside.
### 2. Run inference
```bash
python naka_color_correction.py \
--mode infer \
--input_dir ./test_images \
--output_dir ./outputs/infer_results \
--ckpt ./latest.pth
```
### 3. Inference on large images
The code supports tiled forwarding for large inputs:
```bash
python naka_color_correction.py \
--mode infer \
--input_dir ./test_images \
--output_dir ./outputs/infer_results \
--ckpt ./best.pth \
--tile_size 512 \
--tile_overlap 32
```
The command-line parser exposes `--mode`, `--input_dir`, `--output_dir`, `--ckpt`, `--tile_size`, and `--tile_overlap` for inference.
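Tiled forwarding typically slides a window of `tile_size` with `tile_overlap` pixels of overlap and blends the overlapping predictions. The sketch below uses simple overlap averaging; the released code may use a different blending scheme:

```python
import numpy as np

def tiled_forward(img, model, tile=512, overlap=32):
    """Run `model` on overlapping tiles and average the overlaps
    (a sketch of tiled inference, not the released implementation).
    img: (C, H, W); model maps a (C, h, w) tile to a same-shaped tile."""
    c, h, w = img.shape
    out = np.zeros_like(img, dtype=np.float64)
    weight = np.zeros((1, h, w))
    step = tile - overlap
    for y in range(0, max(h - overlap, 1), step):
        for x in range(0, max(w - overlap, 1), step):
            y1, x1 = min(y + tile, h), min(x + tile, w)
            out[:, y:y1, x:x1] += model(img[:, y:y1, x:x1])
            weight[:, y:y1, x:x1] += 1.0
    return out / weight
```

With an identity model the reassembled output equals the input, which is a quick sanity check that the tiling covers every pixel.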
## Training
### Dataset format
The training and validation data must follow this layout:
```text
datasets/
β”œβ”€β”€ train/
β”‚   β”œβ”€β”€ low/
β”‚   └── normal/
└── val/
    β”œβ”€β”€ low/
    └── normal/
```
Files are paired by identical filename between `low/` and `normal/`.
### Basic training command
```bash
python naka_color_correction.py \
--mode train \
--data_root ./datasets \
--output_dir ./outputs/naka_color_correction_v2 \
--epochs 200 \
--batch_size 8 \
--num_workers 4 \
--crop_size 256 \
--lr 2e-4 \
--weight_decay 1e-4 \
--base_ch 32 \
--amp
```
### Resume training
```bash
python naka_color_correction.py \
--mode train \
--data_root ./datasets \
--output_dir ./outputs/naka_color_correction_v2 \
--resume_ckpt ./outputs/naka_color_correction_v2/checkpoints/latest.pth \
--amp
```
### Initialize from a checkpoint
```bash
python naka_color_correction.py \
--mode train \
--data_root ./datasets \
--output_dir ./outputs/naka_color_correction_v2 \
--init_ckpt ./best.pth \
--amp
```
The parser defaults include `epochs=200`, `batch_size=8`, `crop_size=256`, `lr=2e-4`, `weight_decay=1e-4`, `base_ch=32`, `mul_range=0.6`, `add_range=0.25`, `hf_kernel_size=5`, and `hf_sigma=1.0`.
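For reference, the documented defaults correspond to an `argparse` configuration along these lines. This is a reconstruction from the list above, not the actual parser in `naka_color_correction.py`, which defines additional options:

```python
import argparse

def build_parser():
    """Hypothetical reconstruction of the documented defaults; the
    real parser in naka_color_correction.py has more options."""
    p = argparse.ArgumentParser()
    p.add_argument("--mode", choices=["train", "infer"], default="train")
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--crop_size", type=int, default=256)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--base_ch", type=int, default=32)
    p.add_argument("--mul_range", type=float, default=0.6)
    p.add_argument("--add_range", type=float, default=0.25)
    p.add_argument("--hf_kernel_size", type=int, default=5)
    p.add_argument("--hf_sigma", type=float, default=1.0)
    return p
```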
## Training objective
The provided implementation combines:
- RGB reconstruction loss
- YCbCr chroma/luma consistency loss
- SSIM loss
- edge loss
- VGG perceptual loss
- map regularization
- gray-edge masked loss
- bright-region masked loss
These are implemented through `NakaCorrectionLoss` and `NakaCorrectionLossWithMasks`.
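Multi-term objectives like this are usually combined as a weighted sum. A trivial sketch of that combination (the weights here are placeholders; `NakaCorrectionLoss` defines its own):

```python
def total_loss(terms, weights):
    """Weighted sum of named loss terms. `terms` maps a loss name to
    its scalar value; unlisted weights default to 1.0 (an assumption;
    the released NakaCorrectionLoss hard-codes its own weights)."""
    return sum(weights.get(name, 1.0) * value
               for name, value in terms.items())
```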
## Notes on reproducibility
- Validation uses full-resolution images with `batch_size=1`.
- Mixed precision is enabled with `--amp` on CUDA.
- Checkpoint loading is backward-compatible with older 3-channel `mul_head` weights via `adapt_mul_head_to_single_channel()`.
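One plausible reading of the `mul_head` adaptation is collapsing the old 3-output-channel convolution weight to a single output channel by averaging; the actual `adapt_mul_head_to_single_channel()` may convert differently:

```python
import numpy as np

def adapt_mul_head(weight_3ch):
    """Collapse an old (3, C_in, k, k) conv weight to the new
    single-channel head by averaging over output channels
    (a guess at adapt_mul_head_to_single_channel()'s behavior)."""
    return weight_3ch.mean(axis=0, keepdims=True)  # (1, C_in, k, k)
```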
## Suggested `requirements.txt`
```text
torch>=2.1.0
torchvision>=0.16.0
numpy>=1.24.0
opencv-python>=4.8.0
Pillow>=10.0.0
```