IAT: Illumination Adaptive Transformer (ONNX)

First public ONNX export of the Illumination Adaptive Transformer (IAT) by Cui et al.

This repo contains three ONNX variants exported from the official PyTorch checkpoints, covering both low-light enhancement and exposure correction tasks.

Variants

| File | Checkpoint | Training Data | Use Case |
|------|------------|---------------|----------|
| `onnx/iat_exposure.onnx` | `best_Epoch_exposure.pth` | Exposure Errors | Over-/under-exposure correction |
| `onnx/iat_lol_v1.onnx` | `best_Epoch_lol_v1.pth` | LOL-V1 | Low-light enhancement |
| `onnx/iat_lol_v2.onnx` | `best_Epoch_lol.pth` | LOL-V2 | Low-light enhancement (improved) |

Model Specs

| Property | Value |
|----------|-------|
| Parameters | ~90K |
| File size | ~0.1 MB per variant |
| Input shape | `(1, 3, H, W)` float32, values in [0, 1] |
| Normalization | None; just rescale to [0, 1], no ImageNet mean/std |
| Output names | `mul`, `add`, `enhanced` |
| Output to use | `enhanced` (index 2) |
| Dynamic axes | batch, height, width |
| ONNX opset | 17 |

Preprocessing

```python
import numpy as np
from PIL import Image

img = Image.open("dark_photo.jpg").convert("RGB")
img_np = np.array(img).astype(np.float32) / 255.0   # scale to [0, 1]
# Transpose HWC -> CHW and add a batch dimension
input_tensor = img_np.transpose(2, 0, 1)[np.newaxis, ...]  # (1, 3, H, W)
```

**Important:** do **not** apply ImageNet normalization. The model expects raw [0, 1] pixel values.
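As a sanity check, the preprocessing above can be wrapped in a small helper and verified on a synthetic image (the function name `preprocess` is illustrative, not from the repo):

```python
import numpy as np

def preprocess(img_hwc: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) uint8 RGB array to a (1, 3, H, W) float32
    tensor in [0, 1] -- no mean/std normalization, per the model card."""
    x = img_hwc.astype(np.float32) / 255.0        # rescale to [0, 1]
    return x.transpose(2, 0, 1)[np.newaxis, ...]  # HWC -> NCHW

# Check on a random synthetic image instead of a file on disk
dummy = np.random.randint(0, 256, (64, 48, 3), dtype=np.uint8)
tensor = preprocess(dummy)
print(tensor.shape, tensor.dtype)  # (1, 3, 64, 48) float32
```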

Usage with ONNX Runtime

```python
import numpy as np
import onnxruntime as ort
from PIL import Image

# Load the model
session = ort.InferenceSession("onnx/iat_lol_v2.onnx", providers=["CPUExecutionProvider"])

# Preprocess: float32 in [0, 1], NCHW layout
img = Image.open("dark_photo.jpg").convert("RGB")
img_np = np.array(img).astype(np.float32) / 255.0
input_tensor = img_np.transpose(2, 0, 1)[np.newaxis, ...]  # (1, 3, H, W)

# Run inference; the final result is "enhanced" (index 2)
mul, add, enhanced = session.run(None, {"input": input_tensor})

# Post-process back to an 8-bit RGB image
enhanced = np.clip(enhanced[0], 0, 1)                            # (3, H, W)
enhanced = (enhanced.transpose(1, 2, 0) * 255).astype(np.uint8)  # (H, W, 3)
result = Image.fromarray(enhanced)
result.save("enhanced.jpg")
```
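The post-processing step works for any output size thanks to the dynamic axes; a reusable helper (the name `postprocess` is mine, and I round rather than truncate, a minor deviation from the snippet above) can be checked on synthetic data:

```python
import numpy as np

def postprocess(enhanced_nchw: np.ndarray) -> np.ndarray:
    """Convert a (1, 3, H, W) float32 model output back to an
    (H, W, 3) uint8 RGB array, clipping to [0, 1] first."""
    chw = np.clip(enhanced_nchw[0], 0.0, 1.0)
    return (chw.transpose(1, 2, 0) * 255.0).round().astype(np.uint8)

# Values outside [0, 1] are clipped before quantization
fake_output = np.random.uniform(-0.2, 1.2, (1, 3, 16, 16)).astype(np.float32)
out = postprocess(fake_output)
print(out.shape, out.dtype)  # (16, 16, 3) uint8
```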

ONNX Export Fixes

The original PyTorch code required three monkey-patches for clean ONNX tracing:

1. `IAT.apply_color`: replaced `torch.tensordot(image, ccm, dims=[[-1], [-1]])` with `torch.matmul(image, ccm.T)`; `tensordot` with negative dimension indices is not supported by the ONNX exporter.

2. `IAT.forward`: replaced the Python for-loop over the batch dimension (`for i in range(b)`) with a vectorized `torch.bmm` for the color-matrix multiply and a broadcast `**` for gamma correction. At trace time, Python loops are unrolled into a static graph, which breaks with dynamic batch sizes.

3. `Aff_channel.forward`: same `tensordot`-to-`matmul` fix as patch 1, applied to the channel affinity block in the local branch.
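Patches 1 and 3 rely on the identity that contracting the last axes with `tensordot` is the same operation as a matmul against the transposed matrix. A quick numpy check (shapes chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.standard_normal((2, 32, 32, 3)).astype(np.float32)  # NHWC batch
ccm = rng.standard_normal((3, 3)).astype(np.float32)            # color matrix

# tensordot over the last axes: result[..., i] = sum_k image[..., k] * ccm[i, k]
a = np.tensordot(image, ccm, axes=[[-1], [-1]])
# ...which is exactly a matmul against ccm.T
b = image @ ccm.T

print(np.allclose(a, b, atol=1e-5))  # True
```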

See export_iat_onnx.py in this repo for the full export script with patches.

Architecture

IAT is a lightweight image enhancement model with two branches:

- **Local branch:** learns per-pixel multiplicative (`mul`) and additive (`add`) adjustment maps via a shallow transformer: `enhanced_local = input * mul + add`
- **Global branch:** learns a 3×3 color correction matrix (CCM) and a scalar gamma value, applied after local enhancement: `enhanced = (enhanced_local @ CCM^T) ** gamma`
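The two-branch combination can be sketched in numpy. This is a shape-level illustration only: all array values here are random placeholders, and the clip before the gamma step (to avoid NaNs from negative intermediates) is my addition, not necessarily what the model does internally:

```python
import numpy as np

rng = np.random.default_rng(1)
H, W = 8, 8
inp = rng.uniform(0.0, 1.0, (3, H, W)).astype(np.float32)    # input image, CHW
mul = rng.uniform(0.5, 1.5, (3, H, W)).astype(np.float32)    # per-pixel gain
add = rng.uniform(-0.1, 0.1, (3, H, W)).astype(np.float32)   # per-pixel offset
ccm = np.eye(3, dtype=np.float32)                            # identity CCM for the sketch
gamma = np.float32(0.8)                                      # global gamma

# Local branch: per-pixel affine adjustment
local = inp * mul + add

# Global branch: color matrix applied in channel-last layout, then gamma
hwc = local.transpose(1, 2, 0)                  # CHW -> HWC
enhanced = np.clip(hwc @ ccm.T, 0.0, None) ** gamma
print(enhanced.shape)  # (8, 8, 3)
```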

The combination of local pixel-wise adjustments and global color/tone correction makes it effective for both low-light enhancement and exposure correction, while keeping the model extremely small (~90K parameters).

Benchmark Results

Results from the original paper:

| Dataset | PSNR | SSIM |
|---------|------|------|
| LOL-V1 | 23.38 | 0.809 |
| LOL-V2 | 23.50 | 0.824 |

Citation

@InProceedings{Cui_2022_BMVC,
    title     = {Illumination Adaptive Transformer},
    author    = {Cui, Ziteng and Li, Kunchang and Gu, Lin and Su, Shenghan and Gao, Peng and Jiang, Zhengkai and Qiao, Yu and Harada, Tatsuya},
    booktitle = {British Machine Vision Conference (BMVC)},
    year      = {2022}
}

License

Apache-2.0, the same license as the original IAT repository.
