RPiAE: A Representation-Pivoted Autoencoder Enhancing Both Image Generation and Editing
This repository contains the PyTorch implementation of RPiAE and the corresponding latent diffusion training pipeline.
RPiAE follows a two-stage pipeline:
- Stage 1 (image reconstruction): train RPiAE to learn high-quality latent representations and to reconstruct images from them.
- Stage 2 (class-conditional image generation): train a diffusion transformer in the pretrained RPiAE latent space for image generation.
TODO
- Release the RPiAE model
- Release the training code for RPiAE
- Release pretrained weights
- Release the training code for text-to-image (T2I) generation
Environment
Dependency Setup
- Create the environment and install dependencies via uv:
conda create -n rpiae python=3.10 -y
conda activate rpiae
pip install uv
# Install PyTorch 2.8.0 with CUDA 12.9 (or adjust to your own CUDA version)
uv pip install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129
# Install other dependencies
uv pip install -r requirements.txt
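After installation, a quick check can confirm that the interpreter and the PyTorch build match the setup above. This is a minimal sketch and not part of the repository; the helper name check_environment is hypothetical, and it only reports what is installed rather than validating specific versions.

```python
import importlib.util
import sys


def check_environment() -> dict:
    """Report the Python version and, if installed, the PyTorch/CUDA setup.

    Hypothetical helper for illustration; it is not shipped with RPiAE.
    """
    report = {
        "python": "%d.%d.%d" % sys.version_info[:3],
        "torch": None,
        "cuda_available": False,
        "torch_cuda": None,
    }
    # Import torch only if it is installed, so the check itself never crashes.
    if importlib.util.find_spec("torch") is not None:
        import torch

        report["torch"] = torch.__version__          # e.g. "2.8.0+cu129"
        report["torch_cuda"] = torch.version.cuda    # CUDA version torch was built against
        report["cuda_available"] = torch.cuda.is_available()
    return report
```

Printing check_environment() inside the rpiae environment should show python 3.10.x and a torch version starting with 2.8.0; cuda_available is False if the driver does not match the chosen CUDA wheel.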
Data & Model Preparation
Download Pre-trained Models
Pretrained checkpoints are available on Hugging Face.
Download the pretrained weights to ./model_weights:
hf download qihoo360/RPiAE \
--repo-type model \
--local-dir ./model_weights
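The same download can be done from Python with huggingface_hub.snapshot_download, which takes the same repository id, repository type, and target directory as the CLI command. This is a sketch; the wrapper name download_rpiae_weights is hypothetical, and huggingface_hub is imported lazily so the function can be defined without it installed.

```python
from pathlib import Path


def download_rpiae_weights(local_dir: str = "./model_weights") -> Path:
    """Python equivalent of the `hf download` command above (a sketch).

    Requires the huggingface_hub package at call time.
    """
    from huggingface_hub import snapshot_download

    # Mirrors: hf download qihoo360/RPiAE --repo-type model --local-dir <local_dir>
    path = snapshot_download(
        repo_id="qihoo360/RPiAE",
        repo_type="model",
        local_dir=local_dir,
    )
    return Path(path)
```

Calling download_rpiae_weights() fetches the checkpoints into ./model_weights; the resulting file paths are what the ckpt fields in the configs should point to.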
Prepare Dataset
- Download ImageNet-1k.
- Point the scripts to the training or validation split via --data-path.
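Before launching a long run, it can help to confirm that --data-path points at the expected split. The sketch below assumes an ImageFolder-style layout (one subdirectory per class, such as n01440764/, containing the images); this layout is an assumption, since the README does not state which directory structure the loaders expect. The helper name summarize_imagefolder is hypothetical.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def summarize_imagefolder(data_path: str) -> tuple[int, int]:
    """Count class subdirectories and image files under an ImageFolder-style root.

    Assumes the layout <data_path>/<class_dir>/<image file>; files directly
    under the root are ignored.
    """
    root = Path(data_path)
    class_dirs = sorted(p for p in root.iterdir() if p.is_dir())
    num_images = sum(
        1
        for d in class_dirs
        for f in d.iterdir()
        if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
    )
    return len(class_dirs), num_images
```

For the full ImageNet-1k (ILSVRC2012) training split, this should report 1000 classes and 1,281,167 images.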
Config-based Initialization
All training and sampling entrypoints are driven by OmegaConf YAML files. A single config describes the Stage 1 autoencoder, the Stage 2 diffusion model, and the solver used during training or inference. A minimal example looks like:
stage_1:
target: stage1.RPiAE / stage1.RPiAE_VB
params: { ... }
ckpt: <path_to_ckpt>
stage_2:
target: stage2.models.lightningDiT.LightningDiT
params: { ... }
ckpt: <path_to_ckpt>
transport:
params:
path_type: Linear
prediction: velocity
...
sampler:
mode: ODE
params:
num_steps: 50
...
guidance:
method: cfg/autoguidance
scale: 1.0
...
misc:
latent_size: [64, 16, 16]
num_classes: 1000
training:
...
eval:
...
- stage_1 defines the RPiAE training process (reconstruction-oriented training).
- stage_2 defines the generation model (LightningDiT) in the RPiAE latent space.
- transport, sampler, and guidance control ODE/SDE solving and the guidance strategy.
- misc stores the latent shape and shared constants.
- training and eval contain optimization and online evaluation settings.
- Stage 1 training configs additionally include a top-level gan block for the discriminator and loss schedule.
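Each model block pairs a dotted import path (target) with constructor arguments (params). The sketch below shows one common way such a section can be turned into an object, in the style of the instantiate_from_config helpers found in many latent-diffusion codebases; it is an assumption about how the entrypoints work, not the repository's code, and loading the ckpt field is left out. It operates on a plain dict, such as the one returned by OmegaConf.to_container(OmegaConf.load(path), resolve=True).

```python
import importlib
from typing import Any, Mapping


def get_obj_from_str(target: str) -> Any:
    """Resolve a dotted path such as "stage1.RPiAE" to the object it names."""
    module_name, _, attr_name = target.rpartition(".")
    if not module_name:
        raise ValueError(f"target must be a dotted path, got {target!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


def instantiate_from_config(section: Mapping[str, Any]) -> Any:
    """Build the object described by a {target, params} config section.

    Hypothetical helper; the real entrypoints may also load section["ckpt"].
    """
    if "target" not in section:
        raise KeyError("config section has no 'target' key")
    cls = get_obj_from_str(section["target"])
    params = dict(section.get("params") or {})
    return cls(**params)


if __name__ == "__main__":
    # Stand-in target from the standard library, so the demo runs anywhere.
    demo = {"target": "datetime.timedelta", "params": {"days": 1, "hours": 2}}
    print(instantiate_from_config(demo))  # 1 day, 2:00:00
```

With the repository on the import path, the same call on the stage_1 section would construct stage1.RPiAE (or stage1.RPiAE_VB) from its params.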
Provided Configs
Stage1
- Pretrained:
configs/stage1/pretrained/DINOv2-B_decXL_RPiAE.yaml
- Training:
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage1.yaml
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage2.yaml
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage3.yaml
Stage2
- Training:
configs/stage2/training/ImageNet256/LightingDiT-XL_f16d64rpiae-v2b_vitxl.yaml
- Sampling:
configs/stage2/sampling/ImageNet256/LightingDiT-XL_d64rpiae-v2b_vitxl.yaml
configs/stage2/sampling/ImageNet256/LightingDiT-XL_d64rpiae-v2b_vitxl_AG.yaml
RPiAE Training & Image Reconstruction Evaluation
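Use the provided shell scripts with the corresponding configs. RPiAE itself is trained in three stages (the stage1, stage2, and stage3 configs below); stages 2 and 3 share the same launch script.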
If you use wandb logging, also configure:
EXPERIMENT_NAME=
ENTITY=
PROJECT=
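These three variables correspond naturally to the name, entity, and project arguments of wandb.init. The sketch below shows that mapping; it is an assumption for illustration, since the shell scripts may wire the variables differently, and the helper name wandb_init_kwargs is hypothetical. Unset or empty values are dropped so wandb falls back to its own defaults.

```python
import os
from typing import Mapping, Optional


def wandb_init_kwargs(env: Optional[Mapping[str, str]] = None) -> dict:
    """Map EXPERIMENT_NAME / ENTITY / PROJECT to wandb.init keyword arguments.

    Hypothetical mapping; reads os.environ when no mapping is given.
    """
    env = os.environ if env is None else env
    mapping = {"name": "EXPERIMENT_NAME", "entity": "ENTITY", "project": "PROJECT"}
    # Skip variables that are unset or left empty (e.g. "ENTITY=").
    return {key: env[var] for key, var in mapping.items() if env.get(var)}
```

Usage: import wandb; wandb.init(**wandb_init_kwargs()).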
Single-GPU Training
Train Stage 1
bash run_train_stage1_rpiae_s1.sh \
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage1.yaml
Train Stage 2
bash run_train_stage1_rpiae_s23.sh \
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage2.yaml
Train Stage 3
bash run_train_stage1_rpiae_s23.sh \
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage3.yaml
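The three stages use two scripts: stage 1 has its own *_s1.sh script, while stages 2 and 3 share *_s23.sh. The sketch below builds those exact commands and runs them in order, stopping at the first failure. The helper names are hypothetical, and it assumes each stage's config already points to the checkpoint produced by the previous stage.

```python
import subprocess
from typing import List

CONFIG_DIR = "configs/stage1/training"


def rpiae_stage_commands(multi_gpu: bool = False) -> List[List[str]]:
    """Return the launch commands for the three RPiAE training stages, in order."""
    prefix = "run_train_stage1_mult_rpiae" if multi_gpu else "run_train_stage1_rpiae"
    commands = []
    for stage in (1, 2, 3):
        # Stage 1 uses the *_s1.sh script; stages 2 and 3 share *_s23.sh.
        script = f"{prefix}_s1.sh" if stage == 1 else f"{prefix}_s23.sh"
        config = f"{CONFIG_DIR}/DINOv2-B_decXL_RPiAE_stage{stage}.yaml"
        commands.append(["bash", script, config])
    return commands


def run_all_stages(multi_gpu: bool = False) -> None:
    # check=True raises on a non-zero exit code, so a later stage never
    # starts after an earlier one has failed.
    for cmd in rpiae_stage_commands(multi_gpu):
        subprocess.run(cmd, check=True)
```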
Multi-GPU Training
For multi-GPU or multi-node training, please use the *_mult*.sh scripts instead.
Before launching, configure the distributed variables in the shell script:
RANK=
MASTER_ADDR=
GPUS_PER_NODE=
NNODES=
MASTER_PORT=
Then run the corresponding multi-GPU scripts.
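These variables typically map onto torchrun's launcher flags, with one process per GPU, so the total number of training processes is GPUS_PER_NODE × NNODES. The sketch below makes that mapping explicit; it assumes the *_mult*.sh scripts pass the variables to torchrun in this way (not confirmed here), and the helper names are hypothetical. RANK is the node rank, as in the diffusion multi-node instructions.

```python
from typing import List


def torchrun_args(rank: int, master_addr: str, gpus_per_node: int,
                  nnodes: int, master_port: int) -> List[str]:
    """Translate the distributed variables into torchrun launcher flags (a sketch)."""
    # RANK identifies the node, so it must lie in [0, NNODES).
    if not 0 <= rank < nnodes:
        raise ValueError(f"RANK must be in [0, {nnodes}), got {rank}")
    return [
        f"--nnodes={nnodes}",
        f"--nproc_per_node={gpus_per_node}",
        f"--node_rank={rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
    ]


def world_size(gpus_per_node: int, nnodes: int) -> int:
    # Total number of training processes (one per GPU) across all nodes.
    return gpus_per_node * nnodes
```

For example, two nodes with eight GPUs each give a world size of 16; every node uses the same MASTER_ADDR and MASTER_PORT and differs only in RANK.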
Train Stage 1
bash run_train_stage1_mult_rpiae_s1.sh \
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage1.yaml
Train Stage 2
bash run_train_stage1_mult_rpiae_s23.sh \
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage2.yaml
Train Stage 3
bash run_train_stage1_mult_rpiae_s23.sh \
configs/stage1/training/DINOv2-B_decXL_RPiAE_stage3.yaml
Reconstruction Evaluation
bash run_sample_reconstruction_eval.sh \
configs/stage1/pretrained/DINOv2-B_decXL_RPiAE.yaml
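This README does not list which metrics run_sample_reconstruction_eval.sh reports. As an illustration of one standard reconstruction metric, the sketch below computes PSNR (peak signal-to-noise ratio) between an original and a reconstructed image given as flat pixel sequences. It is not taken from the repository and only shows how the metric is defined.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[float], reconstruction: Sequence[float],
         max_value: float = 255.0) -> float:
    """PSNR in dB: 10 * log10(max_value**2 / MSE) over equally sized pixel arrays."""
    if len(reference) != len(reconstruction) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Higher is better. An MSE equal to 1% of 255² gives exactly 20 dB.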
Latent Diffusion Transformer Training & Class-Conditional Image Generation Evaluation
Training
torchrun --standalone --nnodes=1 --nproc_per_node=N \
src/train_diffusion_rpiae.py \
--config <training_config> \
--data-path <imagenet_train_split> \
--results-dir ckpts/diffusion \
--compile \
--precision fp32
For multi-GPU / multi-node training, use:
bash run_train_mult_diffusion.sh \
configs/stage2/training/ImageNet256/LightingDiT-XL_f16d64rpiae-v2b_vitxl.yaml
For a multi-node launch, also set the following on each node:
export RANK=<node_rank>
export MASTER_ADDR=<master_node_ip_or_hostname>
Although bf16 is supported, we recommend using fp32 for more stable training.
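The transport block in the example config selects path_type: Linear with prediction: velocity. Under a common convention for linear interpolants (assumed here: t = 0 is clean data, t = 1 is pure noise), a training pair is x_t = (1 − t)·x_0 + t·ε, and the velocity target is its time derivative, ε − x_0. The sketch below computes that pair and a mean-squared velocity loss on plain lists; the time convention, any time shifting, and the t-sampling distribution used by the repository are not stated in this README, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def linear_path_sample(x0: Sequence[float], noise: Sequence[float],
                       t: float) -> Tuple[List[float], List[float]]:
    """Return (x_t, v_target) for a linear path between data and noise.

    Assumed convention: t = 0 is clean data x0, t = 1 is pure noise.
    x_t = (1 - t) * x0 + t * noise, and v_target = d x_t / dt = noise - x0.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    x_t = [(1.0 - t) * a + t * e for a, e in zip(x0, noise)]
    v_target = [e - a for a, e in zip(x0, noise)]
    return x_t, v_target


def velocity_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    # Mean squared error between the network's velocity prediction and the target.
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)
```

In training, x0 would be an RPiAE latent of shape misc.latent_size, noise a standard Gaussian sample of the same shape, and the diffusion transformer would predict v_target from (x_t, t, class label).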
Sampling
src/sample.py uses the same config schema to draw a small batch of images on a single device and save them to sample.png:
python src/sample.py \
--config <sample_config> \
--seed 42
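With sampler.mode: ODE and num_steps: 50, sampling integrates the learned velocity field from noise back to data. The sketch below uses plain fixed-step Euler integration under the same assumed convention as a linear path (t = 1 noise, t = 0 data); the repository's sampler may use a different solver or time grid, and euler_ode_sample is a hypothetical name.

```python
from typing import Callable, List, Sequence

# velocity(x, t) -> predicted dx/dt at state x and time t
VelocityFn = Callable[[List[float], float], List[float]]


def euler_ode_sample(velocity: VelocityFn, x_noise: Sequence[float],
                     num_steps: int = 50) -> List[float]:
    """Integrate dx/dt = v(x, t) from t = 1 (noise) down to t = 0 (data).

    Fixed-step Euler with num_steps equal steps of size 1 / num_steps.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    x = list(x_noise)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = 1.0 - i * dt  # current time, moving toward 0
        v = velocity(x, t)
        x = [xi - dt * vi for xi, vi in zip(x, v)]  # step backward in t
    return x
```

If the velocity along a straight path is exactly ε − x_0, Euler integration recovers x_0 from ε regardless of num_steps; a learned velocity field is only approximately straight, which is why the step count matters.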
Distributed sampling for evaluation
src/sample_ddp.py parallelises sampling across GPUs, producing PNGs and an FID-ready .npz:
torchrun --standalone --nnodes=1 --nproc_per_node=N \
src/sample_ddp.py \
--config <sample_config> \
--sample-dir samples \
--precision <fp32|bf16> \
--label-sampling equal
--label-sampling {equal,random}: equal generates exactly 50 images per class (50 × 1000 classes = 50,000 images for FID-50k); random samples labels uniformly at random. We use equal by default. We recommend fp32 precision when the model's FID is already low.
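The difference between the two label modes can be sketched as follows: equal builds a list with every class repeated the same number of times, while random draws each label independently. Shuffling the equal list and splitting it across GPUs with a strided slice are assumptions of this sketch, not a description of sample_ddp.py; the function names are hypothetical.

```python
import random
from typing import List


def sample_labels(num_samples: int, num_classes: int = 1000,
                  mode: str = "equal", seed: int = 0) -> List[int]:
    """Draw class labels for evaluation sampling.

    "equal": each class appears num_samples // num_classes times, then shuffled.
    "random": each label is drawn uniformly from range(num_classes).
    """
    rng = random.Random(seed)
    if mode == "equal":
        if num_samples % num_classes:
            raise ValueError("num_samples must be a multiple of num_classes")
        per_class = num_samples // num_classes
        labels = [c for c in range(num_classes) for _ in range(per_class)]
        rng.shuffle(labels)
        return labels
    if mode == "random":
        return [rng.randrange(num_classes) for _ in range(num_samples)]
    raise ValueError(f"unknown mode: {mode!r}")


def shard_for_rank(labels: List[int], rank: int, world_size: int) -> List[int]:
    # Strided split: each GPU gets a disjoint share of the label list.
    return labels[rank::world_size]
```

With 50,000 samples and 1000 classes, equal mode gives exactly 50 labels per class, whereas random mode only matches that on average.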
Autoguidance and classifier-free guidance are controlled via the config’s guidance block.
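Both guidance methods named in the config (cfg and autoguidance) combine two velocity predictions with the same extrapolation rule and differ in where the reference prediction comes from. The sketch below follows the standard formulation of each method; how the repository selects the reference model, guidance interval, or scale schedule is not described here, and guided_velocity is a hypothetical name.

```python
from typing import List, Sequence


def guided_velocity(v_main: Sequence[float], v_ref: Sequence[float],
                    scale: float) -> List[float]:
    """Extrapolate from a reference prediction toward the main prediction.

    - Classifier-free guidance: v_main is the class-conditional prediction,
      v_ref the unconditional one (same network, null label).
    - Autoguidance: v_main comes from the main model, v_ref from a weaker
      (smaller or less-trained) version of it, both class-conditional.
    scale = 1.0 returns v_main unchanged, i.e. no guidance.
    """
    return [r + scale * (m - r) for m, r in zip(v_main, v_ref)]
```

Under this convention, the scale: 1.0 shown in the example config would leave predictions unguided; whether the repository uses the same convention is an assumption.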
Evaluation
ADM Suite FID setup
Use the ADM evaluation suite to score generated samples:
Clone the repo:
git clone https://github.com/openai/guided-diffusion.git
cd guided-diffusion/evaluation
Create an environment and install dependencies:
conda create -n adm-fid python=3.10
conda activate adm-fid
pip install 'tensorflow[and-cuda]'==2.19 scipy requests tqdm
Download ImageNet statistics (256×256 shown here):
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
Evaluate:
python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/samples.npz
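The ADM evaluator reads the generated images from the array stored under the key arr_0 of the sample .npz, as uint8 pixels in (N, H, W, 3) layout. A quick check before running evaluator.py can catch a wrong sample count or dtype early. This is a sketch with a hypothetical helper name; note that it loads the full array into memory (about 9.8 GB for 50,000 images at 256×256).

```python
import numpy as np


def check_adm_samples(npz_path: str, expected_count: int = 50_000,
                      resolution: int = 256) -> np.ndarray:
    """Sanity-check a sample batch before passing it to evaluator.py.

    Expects an array under "arr_0" with shape (N, H, W, 3) and dtype uint8.
    """
    with np.load(npz_path) as data:
        if "arr_0" not in data.files:
            raise KeyError(f"{npz_path} has no 'arr_0' array (found {data.files})")
        images = data["arr_0"]  # fully read into memory before the file closes
    expected_shape = (expected_count, resolution, resolution, 3)
    if images.shape != expected_shape:
        raise ValueError(f"expected shape {expected_shape}, got {images.shape}")
    if images.dtype != np.uint8:
        raise TypeError(f"expected uint8 pixels, got {images.dtype}")
    return images
```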
Acknowledgement
This code is built upon the following repositories:
Citation
If you find this repository useful, please consider citing our paper:
@misc{RPiAE,
title={RPiAE: A Representation-Pivoted Autoencoder Enhancing Both Image Generation and Editing},
author={Yue Gong and Hongyu Li and Shanyuan Liu and Bo Cheng and Yuhang Ma and Liebucha Wu and Xiaoyu Wu and Manyuan Zhang and Dawei Leng and Yuhui Yin and Lijun Zhang},
year={2026},
eprint={2603.19206},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2603.19206},
}