ATD v2 XLarge SIDD Medium x4 Super-Resolution
JAX/Flax checkpoints for real-world x4 image super-resolution and restoration. The model is not TPU-only: inference can run on CPU, NVIDIA GPU, or TPU as long as JAX is installed with the matching backend.
Source code: https://github.com/BitIntx/sr-tpu
Model
- Architecture: ATD v2-style restoration transformer
- Preset: atd_v2_xlarge
- Parameters: 57,364,563
- Scale: x4
- Input mode: LR RGB input, HR RGB x4 output
- Training precision/backend: bfloat16 on Cloud TPU
- Base training run: atd_v2_xlarge_real_v3_sidd_medium_long_x4
- W&B run: https://wandb.ai/jwheo/sr-tpu/runs/s25zi8zw
Included Files
- checkpoint_390000_best_lpips/: recommended checkpoint. It is the nearest saved checkpoint to the best fixed-sample LPIPS point, logged at step 389400 with LPIPS 0.3944.
- checkpoint_400000_latest/: latest saved base checkpoint, also used as the frozen SR base for residual-refiner experiments.
- infer.py: standalone x4 inference script for the base SR checkpoint.
- infer_refiner.py: optional base + residual-refiner inference script.
- sr_tpu/, train.py: runtime modules required by the inference scripts.
- requirements-infer.txt: shared non-JAX inference dependencies.
The best random-eval PSNR gain over bicubic in the raw training run was logged at step 321000, but checkpoint_390000_best_lpips is the recommended first pick for practical visual restoration.
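For readers unfamiliar with the "PSNR gain vs bicubic" metric mentioned above, here is a minimal sketch of how such a number is typically computed. The repository's actual evaluation code is not shown here, so the function names and the [0, 1] pixel range are assumptions for illustration only.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[float], estimate: Sequence[float],
         max_value: float = 1.0) -> float:
    """PSNR in dB between two equally sized, flattened pixel sequences."""
    if len(reference) != len(estimate) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


def psnr_gain_vs_bicubic(hr: Sequence[float], sr: Sequence[float],
                         bicubic: Sequence[float]) -> float:
    """Positive values mean the SR output is closer to HR than bicubic is."""
    return psnr(hr, sr) - psnr(hr, bicubic)
```

PSNR rewards pixel-wise fidelity, while LPIPS tracks perceptual similarity, which is why the two metrics can peak at different training steps.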
Install
Create an environment, install the correct JAX backend, then install shared dependencies.
NVIDIA GPU, CUDA 13
python3 -m venv ~/venvs/sr-tpu-infer
source ~/venvs/sr-tpu-infer/bin/activate
pip install --upgrade pip
pip install --upgrade "jax[cuda13]"
pip install -r requirements-infer.txt
CUDA 12 users can use pip install --upgrade "jax[cuda12]" instead. See the official JAX installation guide for current platform support and driver requirements: https://docs.jax.dev/en/latest/installation.html
TPU VM
python3 -m venv ~/venvs/sr-tpu-infer
source ~/venvs/sr-tpu-infer/bin/activate
pip install --upgrade pip
pip install --upgrade "jax[tpu]"
pip install -r requirements-infer.txt
CPU
CPU is useful for smoke tests or small images, but will be slow for this xlarge model.
python3 -m venv ~/venvs/sr-tpu-infer
source ~/venvs/sr-tpu-infer/bin/activate
pip install --upgrade pip
pip install --upgrade jax
pip install -r requirements-infer.txt
Check the backend:
python - <<'PY'
import jax
print(jax.default_backend())
print(jax.devices())
PY
Run Inference
Single image on GPU:
python infer.py \
--checkpoint checkpoint_390000_best_lpips \
--input path/to/input.jpg \
--output samples/input_x4.png \
--platform gpu \
--save-bicubic \
--compare
Folder with tiling, recommended for larger photos:
python infer.py \
--checkpoint checkpoint_390000_best_lpips \
--input path/to/images \
--output samples/atd_v2_xlarge_x4 \
--platform gpu \
--tile-size 128 \
--tile-overlap 16 \
--save-bicubic \
--compare
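To give a feel for what --tile-size 128 and --tile-overlap 16 imply, the sketch below lays out overlapping tiles over a low-resolution image. It is not taken from infer.py: it assumes a stride of tile_size - tile_overlap with the last tile shifted back to end at the image edge, and the function names are hypothetical.

```python
def tile_starts(length: int, tile_size: int = 128, overlap: int = 16) -> list[int]:
    """Start offsets along one axis so that tiles cover [0, length)."""
    if tile_size <= 0 or not 0 <= overlap < tile_size:
        raise ValueError("need tile_size > 0 and 0 <= overlap < tile_size")
    if length <= tile_size:
        return [0]  # a single (possibly smaller) tile covers the axis
    stride = tile_size - overlap
    starts = list(range(0, length - tile_size, stride))
    starts.append(length - tile_size)  # final tile ends exactly at the edge
    return starts


def tile_boxes(width: int, height: int, tile_size: int = 128,
               overlap: int = 16) -> list[tuple[int, int, int, int]]:
    """(x0, y0, x1, y1) boxes in LR pixel coordinates, row by row."""
    return [
        (x, y, min(x + tile_size, width), min(y + tile_size, height))
        for y in tile_starts(height, tile_size, overlap)
        for x in tile_starts(width, tile_size, overlap)
    ]


def to_hr_box(box: tuple[int, int, int, int], scale: int = 4) -> tuple[int, int, int, int]:
    """Map an LR tile box to the matching region of the x4 output."""
    x0, y0, x1, y1 = box
    return (x0 * scale, y0 * scale, x1 * scale, y1 * scale)
```

Under these assumptions a 300x200 input produces 3x2 = 6 tiles, each processed independently, which keeps peak memory bounded by the tile size rather than the full image.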
TPU VM inference:
python infer.py \
--checkpoint checkpoint_390000_best_lpips \
--input path/to/images \
--output samples/atd_v2_xlarge_x4_tpu \
--platform tpu \
--tile-size 128 \
--tile-overlap 16
CPU smoke test:
python infer.py \
--checkpoint checkpoint_390000_best_lpips \
--input path/to/small.jpg \
--output samples/small_x4.png \
--platform cpu \
--max-side 128
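The behavior of --max-side is not documented here. A plausible reading is that it caps the longer side of the input before super-resolution while preserving aspect ratio; the sketch below illustrates that assumption only, and fit_max_side is a hypothetical helper, not a function from infer.py.

```python
def fit_max_side(width: int, height: int, max_side: int = 128) -> tuple[int, int]:
    """Shrink (width, height) so the longer side is at most max_side.

    Assumed behavior: aspect ratio is preserved and smaller images
    are left untouched.
    """
    longest = max(width, height)
    if longest <= max_side:
        return width, height
    scale = max_side / longest
    return max(1, round(width * scale)), max(1, round(height * scale))


def output_size(width: int, height: int, scale: int = 4) -> tuple[int, int]:
    """Size of the x4 result for a given LR input size."""
    return width * scale, height * scale
```

With this interpretation, a 512x256 input would be reduced to 128x64 and then upscaled to 512x256, which keeps a CPU smoke test short.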
infer.py writes restored images and, for folder runs, metrics.csv. With --compare, comparison sheets are saved under output/compare/.
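For scripted or batch runs, the commands above can be assembled in Python. This is a minimal sketch that uses only the flags shown in this README; build_infer_command is a hypothetical helper, and it assumes infer.py is in the current working directory.

```python
from typing import Optional


def build_infer_command(
    checkpoint: str,
    input_path: str,
    output_path: str,
    platform: str = "gpu",
    tile_size: Optional[int] = None,
    tile_overlap: Optional[int] = None,
    save_bicubic: bool = False,
    compare: bool = False,
) -> list[str]:
    """Build an argv list for infer.py from the options used in this README."""
    cmd = [
        "python", "infer.py",
        "--checkpoint", checkpoint,
        "--input", input_path,
        "--output", output_path,
        "--platform", platform,
    ]
    # Tiling flags are optional; omit them to let infer.py use its defaults.
    if tile_size is not None:
        cmd += ["--tile-size", str(tile_size)]
    if tile_overlap is not None:
        cmd += ["--tile-overlap", str(tile_overlap)]
    if save_bicubic:
        cmd.append("--save-bicubic")
    if compare:
        cmd.append("--compare")
    return cmd
```

The resulting list can be passed to subprocess.run(cmd, check=True), which avoids shell quoting problems with paths that contain spaces.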
Training Data
Prepared pair-manifest dataset:
- SIDD Medium sRGB noisy/GT pairs
- DIV2K HR clean self-pairs
- OST HR clean self-pairs
- Prepared dataset name: sr_real_x4_v3_sidd_medium
- Train rows: 14,159
- Validation rows: 505
The model was trained with mixed-denoise degradation, including phone-like noise/compression artifacts and clean self-pair synthetic degradation.
Notes
These checkpoints are experimental research artifacts, not a polished production model. They are useful for TPU/JAX restoration experiments, GPU/CPU inference experiments, and as a base for residual-refiner training.