ATD v2 XLarge SIDD Medium x4 Super-Resolution

JAX/Flax checkpoints for real-world x4 image super-resolution and restoration. The model is not TPU-only: inference can run on CPU, NVIDIA GPU, or TPU as long as JAX is installed with the matching backend.

Source code: https://github.com/BitIntx/sr-tpu

Model

  • Architecture: ATD v2-style restoration transformer
  • Preset: atd_v2_xlarge
  • Parameters: 57,364,563
  • Scale: x4
  • Input mode: LR RGB input, HR RGB x4 output
  • Training precision/backend: bfloat16 on Cloud TPU
  • Base training run: atd_v2_xlarge_real_v3_sidd_medium_long_x4
  • W&B run: https://wandb.ai/jwheo/sr-tpu/runs/s25zi8zw
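
The parameter count gives a quick estimate of raw weight size, and the x4 scale fixes the output resolution. The following is plain back-of-envelope arithmetic. It assumes dense storage at 2 bytes per bfloat16 value and 4 bytes per float32 value; the on-disk dtype of the checkpoints is not stated on this card. It also ignores activations and any optimizer state. `weight_megabytes` and `output_size` are illustrative helper names.

```python
from __future__ import annotations

# Back-of-envelope sizing for the atd_v2_xlarge preset (plain arithmetic, no JAX needed).
NUM_PARAMS = 57_364_563  # parameter count from the model card
SCALE = 4                # x4 super-resolution


def weight_megabytes(num_params: int, bytes_per_param: int) -> float:
    """Raw weight size in MB (10**6 bytes); ignores activations and optimizer state."""
    return num_params * bytes_per_param / 1e6


def output_size(height: int, width: int, scale: int = SCALE) -> tuple[int, int]:
    """An x4 model maps an H x W LR image to a (scale*H) x (scale*W) HR image."""
    return height * scale, width * scale


if __name__ == "__main__":
    print(f"bfloat16 weights: {weight_megabytes(NUM_PARAMS, 2):.1f} MB")  # 2 bytes per value
    print(f"float32 weights:  {weight_megabytes(NUM_PARAMS, 4):.1f} MB")  # 4 bytes per value
    print("480x640 LR ->", output_size(480, 640))
```

This prints about 114.7 MB for bfloat16 weights, 229.5 MB for float32 weights, and (1920, 2560) as the output size for a 480x640 input.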

Included Files

  • checkpoint_390000_best_lpips/: recommended checkpoint. It is the saved checkpoint nearest to the best fixed-sample LPIPS point, which was logged at step 389400 with LPIPS 0.3944 (lower is better).
  • checkpoint_400000_latest/: latest saved base checkpoint, also used as the frozen SR base for residual-refiner experiments.
  • infer.py: standalone x4 inference script for the base SR checkpoint.
  • infer_refiner.py: optional base + residual-refiner inference script.
  • sr_tpu/, train.py: runtime modules required by the inference scripts.
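
infer_refiner.py combines the frozen base with a residual refiner. The card does not describe the refiner's inputs or architecture. The following is therefore only a minimal sketch of the generic residual-refinement pattern. It assumes the refiner sees the base output and predicts an additive correction. `refine`, `nearest_x4`, and `zero_refiner` are hypothetical names, and NumPy stand-ins replace the real Flax models.

```python
import numpy as np


def nearest_x4(lr: np.ndarray) -> np.ndarray:
    """Stand-in for the frozen base SR model: nearest-neighbour x4 upsampling of (H, W, C)."""
    return np.repeat(np.repeat(lr, 4, axis=0), 4, axis=1)


def zero_refiner(sr: np.ndarray) -> np.ndarray:
    """Stand-in for an untrained refiner: predicts no correction."""
    return np.zeros_like(sr)


def refine(lr, base_sr, refiner):
    """Generic base + residual refinement (hypothetical; infer_refiner.py may differ).

    The base prediction is computed first and left unchanged; the refiner only
    predicts an additive correction, which is added back and clipped to [0, 1].
    """
    sr = base_sr(lr)
    residual = refiner(sr)
    return np.clip(sr + residual, 0.0, 1.0)


if __name__ == "__main__":
    lr = np.linspace(0.0, 1.0, 8 * 8 * 3).reshape(8, 8, 3)
    out = refine(lr, nearest_x4, zero_refiner)
    print(out.shape)  # (32, 32, 3)
```

With a zero refiner the result reduces exactly to the base prediction, which is a useful sanity check for any residual setup.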
  • requirements-infer.txt: shared non-JAX inference dependencies.

In the raw training run, the best random-eval PSNR gain over bicubic occurred at step 321000. For practical visual restoration, however, checkpoint_390000_best_lpips is the recommended first choice.
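
Picking the saved checkpoint nearest to a logged metric step is a minimum over absolute step distance. This sketch uses only the two step numbers included in this repository; the training run may have saved others. `nearest_saved_step` is a hypothetical helper.

```python
from __future__ import annotations


def nearest_saved_step(target_step: int, saved_steps: list[int]) -> int:
    """Return the saved step closest to target_step; ties resolve to the earlier step."""
    return min(saved_steps, key=lambda s: (abs(s - target_step), s))


INCLUDED_STEPS = [390_000, 400_000]  # checkpoints shipped in this repository
print(nearest_saved_step(389_400, INCLUDED_STEPS))  # 390000
```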

Install

Create an environment, install the correct JAX backend, then install shared dependencies.

NVIDIA GPU, CUDA 13

python3 -m venv ~/venvs/sr-tpu-infer
source ~/venvs/sr-tpu-infer/bin/activate
pip install --upgrade pip
pip install --upgrade "jax[cuda13]"
pip install -r requirements-infer.txt

CUDA 12 users can use pip install --upgrade "jax[cuda12]" instead. See the official JAX installation guide for current platform support and driver requirements: https://docs.jax.dev/en/latest/installation.html

TPU VM

python3 -m venv ~/venvs/sr-tpu-infer
source ~/venvs/sr-tpu-infer/bin/activate
pip install --upgrade pip
pip install --upgrade "jax[tpu]"
pip install -r requirements-infer.txt

CPU

CPU is useful for smoke tests or small images, but inference will be slow for this XLarge model.

python3 -m venv ~/venvs/sr-tpu-infer
source ~/venvs/sr-tpu-infer/bin/activate
pip install --upgrade pip
pip install --upgrade jax
pip install -r requirements-infer.txt

Check the backend:

python - <<'PY'
import jax
print(jax.default_backend())
print(jax.devices())
PY
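
To avoid passing a --platform value that does not match the installed backend, you can derive it from jax.default_backend(). This is a hypothetical helper, not part of infer.py. It assumes infer.py accepts the values cpu, gpu, and tpu used in the examples on this card, and it builds commands from the flags shown there only.

```python
import shlex

# --platform values used in this card's infer.py examples, keyed by the string
# jax.default_backend() returns ("gpu" is expected on NVIDIA builds; "cuda" is
# accepted defensively as an alias).
PLATFORM_FOR_BACKEND = {"cpu": "cpu", "gpu": "gpu", "cuda": "gpu", "tpu": "tpu"}


def platform_flag(backend: str) -> str:
    """Map a JAX backend name to an infer.py --platform value."""
    try:
        return PLATFORM_FOR_BACKEND[backend.lower()]
    except KeyError:
        raise ValueError(f"no --platform value for JAX backend {backend!r}") from None


def infer_command(backend: str, checkpoint: str, input_path: str, output_path: str) -> str:
    """Build a shell-quoted infer.py command using only flags shown on this card."""
    args = [
        "python", "infer.py",
        "--checkpoint", checkpoint,
        "--input", input_path,
        "--output", output_path,
        "--platform", platform_flag(backend),
    ]
    return shlex.join(args)


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("JAX is not installed; install a backend first")
    else:
        print(infer_command(jax.default_backend(), "checkpoint_390000_best_lpips",
                            "path/to/input.jpg", "samples/input_x4.png"))
```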

Run Inference

Single image on GPU:

python infer.py \
  --checkpoint checkpoint_390000_best_lpips \
  --input path/to/input.jpg \
  --output samples/input_x4.png \
  --platform gpu \
  --save-bicubic \
  --compare

Folder with tiling, recommended for larger photos:

python infer.py \
  --checkpoint checkpoint_390000_best_lpips \
  --input path/to/images \
  --output samples/atd_v2_xlarge_x4 \
  --platform gpu \
  --tile-size 128 \
  --tile-overlap 16 \
  --save-bicubic \
  --compare
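
Tiling splits a large LR image into overlapping 128x128 crops to bound memory per forward pass. At x4, each tile becomes 512x512 and a 16-pixel LR overlap becomes 64 output pixels. How infer.py places and blends its tiles is not documented here. This is a generic sketch of one common layout: a fixed stride of tile size minus overlap, with the last tile snapped to the border. `tile_starts` and `tile_grid` are hypothetical names.

```python
from __future__ import annotations


def tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    """Start offsets of overlapping tiles along one axis; the last tile is flush with the edge."""
    if tile >= length:
        return [0]  # the whole axis fits in one tile (a real pipeline may pad here)
    stride = tile - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than tile size")
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # snap the final tile to the border
    return starts


def tile_grid(height: int, width: int, tile: int = 128, overlap: int = 16) -> list[tuple[int, int]]:
    """(y, x) top-left corners of every LR tile; multiply by 4 for HR coordinates."""
    return [(y, x)
            for y in tile_starts(height, tile, overlap)
            for x in tile_starts(width, tile, overlap)]


print(tile_starts(480, 128, 16))  # [0, 112, 224, 336, 352]
print(len(tile_grid(480, 640)))   # 30
```

For a 480x640 input this yields 5 x 6 = 30 tiles. When stitching, overlapping regions are typically blended (averaged or feathered) to hide seams.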

TPU VM inference:

python infer.py \
  --checkpoint checkpoint_390000_best_lpips \
  --input path/to/images \
  --output samples/atd_v2_xlarge_x4_tpu \
  --platform tpu \
  --tile-size 128 \
  --tile-overlap 16

CPU smoke test:

python infer.py \
  --checkpoint checkpoint_390000_best_lpips \
  --input path/to/small.jpg \
  --output samples/small_x4.png \
  --platform cpu \
  --max-side 128
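
The card does not define --max-side precisely. Assuming it downscales the input, preserving aspect ratio, so that the longer side is at most the given value (a hypothetical reading; check infer.py to confirm), the resulting sizes work out as follows. `capped_size` is an illustrative helper.

```python
from __future__ import annotations


def capped_size(height: int, width: int, max_side: int) -> tuple[int, int]:
    """Hypothetical --max-side semantics: shrink (never enlarge) so that
    max(H, W) <= max_side, keeping the aspect ratio."""
    longest = max(height, width)
    if longest <= max_side:
        return height, width
    scale = max_side / longest
    return max(1, round(height * scale)), max(1, round(width * scale))


lr_h, lr_w = capped_size(3024, 4032, 128)
print((lr_h, lr_w), "->", (4 * lr_h, 4 * lr_w))  # (96, 128) -> (384, 512)
```

Under this reading, a 3024x4032 phone photo becomes 96x128 before inference, so the x4 output is 384x512.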

infer.py writes the restored images and, for folder runs, a metrics.csv file. With --compare, comparison sheets are saved under output/compare/, where output is the path passed to --output.
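
metrics.csv can be summarized with the standard csv module. Its column names are not documented on this card, so this sketch assumes nothing about them. It averages every column whose values all parse as finite numbers and skips the rest (for example, file names). `summarize_rows` and `summarize_metrics` are hypothetical helpers.

```python
from __future__ import annotations

import csv
import math


def summarize_rows(rows: list[dict]) -> dict[str, float]:
    """Mean of every column whose values all parse as finite floats; other columns are skipped."""
    means: dict[str, float] = {}
    if not rows:
        return means
    for column in rows[0]:
        try:
            values = [float(row[column]) for row in rows]
        except (KeyError, TypeError, ValueError):
            continue  # e.g. a filename column, or a missing cell
        if all(math.isfinite(v) for v in values):
            means[column] = sum(values) / len(values)
    return means


def summarize_metrics(path: str) -> dict[str, float]:
    """Read metrics.csv (header row expected) and summarize its numeric columns."""
    with open(path, newline="") as f:
        return summarize_rows(list(csv.DictReader(f)))
```

For the tiling example, and assuming metrics.csv is written directly into the --output folder, you would call summarize_metrics("samples/atd_v2_xlarge_x4/metrics.csv").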

Training Data

Prepared pair-manifest dataset:

  • SIDD Medium sRGB noisy/GT pairs
  • DIV2K HR clean self-pairs
  • OST HR clean self-pairs
  • Prepared dataset name: sr_real_x4_v3_sidd_medium
  • Train rows: 14,159
  • Validation rows: 505

The model was trained with mixed-denoise degradation, combining phone-like noise and compression artifacts with synthetic degradation of the clean self-pairs.
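
The clean self-pairs (DIV2K and OST) need synthetic degradation to produce LR inputs. The actual mixed-denoise pipeline is not described here, so this is only a minimal NumPy sketch of the idea. It assumes area (box) downsampling by 4 followed by additive Gaussian noise. The training degradation described above also includes compression artifacts, which would need an image codec (for example, Pillow's JPEG encoder) and are omitted. `degrade_clean_pair` and `noise_sigma` are hypothetical.

```python
import numpy as np


def degrade_clean_pair(hr: np.ndarray, rng: np.random.Generator,
                       scale: int = 4, noise_sigma: float = 0.02):
    """Toy LR/HR pair from a clean HR image (hypothetical, not the training pipeline).

    hr: float array in [0, 1] with shape (H, W, C); H and W must be divisible by `scale`.
    Returns (lr, hr) where lr has shape (H // scale, W // scale, C).
    """
    h, w, c = hr.shape
    if h % scale or w % scale:
        raise ValueError("H and W must be divisible by scale")
    # Area (box) downsampling: average each scale x scale block.
    lr = hr.reshape(h // scale, scale, w // scale, scale, c).mean(axis=(1, 3))
    # Additive Gaussian noise as a crude stand-in for sensor noise.
    lr = lr + rng.normal(0.0, noise_sigma, size=lr.shape)
    return np.clip(lr, 0.0, 1.0).astype(np.float32), hr


rng = np.random.default_rng(0)
hr = rng.random((64, 64, 3))
lr, _ = degrade_clean_pair(hr, rng)
print(lr.shape)  # (16, 16, 3)
```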

Notes

These checkpoints are experimental research artifacts, not a polished production model. They are useful for TPU/JAX restoration experiments, GPU/CPU inference experiments, and as a base for residual-refiner training.
