---
pipeline_tag: image-feature-extraction
---
# Image Tokenizer Needs Post-Training
[Project Page](https://qiuk2.github.io/works/RobusTok/index.html)
[arXiv](https://arxiv.org/abs/2509.12474)
[Hugging Face](https://huggingface.co/qiuk6/RobusTok)
This repository contains the official implementation of the paper [Image Tokenizer Needs Post-Training](https://huggingface.co/papers/2509.12474).
Project Page: https://qiuk2.github.io/works/RobusTok/index.html
GitHub Repository: https://github.com/qiuk2/RobusTok
## Abstract
Recent image generative models typically capture the image distribution in a pre-constructed latent space, relying on a frozen image tokenizer. However, there exists a significant discrepancy between the reconstruction and generation distributions: current tokenizers only prioritize the reconstruction task that happens before generative training, without considering the generation errors that arise during sampling. In this paper, we comprehensively analyze the reason for this discrepancy in a discrete latent space and, based on this analysis, propose a novel tokenizer training scheme including both main-training and post-training, which improve latent space construction and decoding respectively. During main training, a latent perturbation strategy is proposed to simulate sampling noises, i.e., the unexpected tokens generated in generative inference. Specifically, we propose a plug-and-play tokenizer training scheme that significantly enhances the robustness of the tokenizer, boosting both generation quality and convergence speed, and a novel tokenizer evaluation metric, i.e., pFID, which successfully correlates tokenizer performance with generation quality. During post-training, we further optimize the tokenizer decoder with respect to a well-trained generative model to mitigate the distribution difference between generated and reconstructed tokens. With a ~400M generator, a discrete tokenizer trained with our proposed main training achieves a notable 1.60 gFID and further obtains 1.36 gFID with the additional post-training. Further experiments broadly validate the effectiveness of our post-training strategy on off-the-shelf discrete and continuous tokenizers, coupled with autoregressive and diffusion-based generators.
## TL;DR
We present RobusTok, a new image tokenizer with a two-stage training scheme:
- Main training → constructs a robust latent space.
- Post-training → aligns the generator's latent distribution with its image space.
## Key highlights of Post-Training
- **Better generative quality**: gFID 1.60 → 1.36.
- **Generalizability**: applicable to both autoregressive & diffusion models.
- **Efficiency**: strong results with only ~400M generative models.
## Model Zoo
| Generator \ Tokenizer | RobusTok w/o P.T. ([weights](https://huggingface.co/qiuk6/RobusTok/resolve/main/main-train.pt?download=true)) | RobusTok w/ P.T. ([weights](https://huggingface.co/qiuk6/RobusTok/resolve/main/post-train.pt?download=true)) |
|---|---:|---:|
| Base ([weights](https://huggingface.co/qiuk6/RobusTok/resolve/main/rar_b.bin?download=true)) | gFID = 1.83 | gFID = 1.60 |
| Large ([weights](https://huggingface.co/qiuk6/RobusTok/resolve/main/rar_l.bin?download=true)) | gFID = 1.60 | gFID = 1.36 |
## Updates
- (2025.09.16) Paper released on arXiv.
- (2025.09.18) Code and checkpoints released. pFID calculation code is in preparation.
## Installation
Install all packages with:
```bash
conda env create -f environment.yml
```
## Dataset
Download ImageNet2012 from the official website and arrange it as:
```
ImageNet2012
├── train
└── val
```
If you want to train or finetune on other datasets, collect them in a format that PyTorch's [ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) can recognize:
```
Dataset
├── train
│   ├── Class1
│   │   ├── 1.png
│   │   └── 2.png
│   └── Class2
│       ├── 1.png
│       └── 2.png
└── val
```
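ImageFolder infers class labels from the subdirectory names of each split. The following stdlib-only sketch (using a hypothetical temporary directory, not the real dataset) mimics that class-discovery logic so you can verify a layout before training:

```python
import os
import tempfile

# Build a tiny Dataset/train tree in the layout shown above,
# then discover classes the way torchvision's ImageFolder does:
# each subdirectory of the split folder becomes one class.
root = tempfile.mkdtemp()
for cls in ["Class1", "Class2"]:
    os.makedirs(os.path.join(root, "train", cls), exist_ok=True)
    for name in ["1.png", "2.png"]:
        open(os.path.join(root, "train", cls, name), "wb").close()

classes = sorted(
    d for d in os.listdir(os.path.join(root, "train"))
    if os.path.isdir(os.path.join(root, "train", d))
)
class_to_idx = {c: i for i, c in enumerate(classes)}
print(class_to_idx)  # {'Class1': 0, 'Class2': 1}
```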
## Main Train for tokenizer
Please log in to Weights & Biases (wandb) first:
```bash
wandb login
```
rFID will be automatically evaluated and reported to wandb. The checkpoint with the best rFID on the val set will be saved. We provide basic configurations in the `configs` folder.
Warning ⚠️: You may want to change the checkpoint-selection metric, as rFID is not closely correlated with gFID; PSNR and SSIM are also good choices.
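For reference on the metric swap, PSNR reduces to a log of the mean squared error between reconstruction and target. A stdlib-only sketch (flat pixel lists standing in for image tensors; not the repo's evaluation code):

```python
import math

def psnr(x, y, max_val=255.0):
    """Peak signal-to-noise ratio between two equal-length pixel sequences."""
    mse = sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

# Identical images give infinite PSNR; a uniform 1-level error on
# 8-bit pixels gives 20 * log10(255) ≈ 48.13 dB.
print(psnr([0, 128, 255], [0, 128, 255]))
print(psnr([0, 128, 254], [1, 129, 255]))
```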
```bash
torchrun --nproc_per_node=8 tokenizer/tokenizer_image/main_train.py --config configs/main-train.yaml
```
Please modify the configuration file as needed for your specific dataset. We list some important options here:
```
vq_ckpt: ckpt_best.pt # resume
cloud_save_path: output/exp-xx # output dir
data_path: ImageNet2012/train # training set dir
val_data_path: ImageNet2012/val # val set dir
enc_tuning_method: 'full' # ['full', 'lora', 'frozen']
dec_tuning_method: 'full' # ['full', 'lora', 'frozen']
codebook_embed_dim: 32 # codebook dim
codebook_size: 4096 # codebook size
product_quant: 1 # vanilla VQ
v_patch_nums: [16,] # latent resolution for RQ ([16,] is equivalent to vanilla VQ)
codebook_drop: 0.1 # quantizer dropout rate if RQ is applied
semantic_guide: dinov2 # ['none', 'dinov2', 'clip']
disc_epoch_start: 56 # epoch that discriminator starts
disc_type: dinodisc # discriminator type
disc_adaptive_weight: true # adaptive weight for discriminator loss
ema: true # use ema to update the model
num_latent_code: 256 # latent token number (must equal v_patch_nums[-1] ** 2)
```
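The comment on `num_latent_code` implies a consistency constraint with `v_patch_nums`. A minimal sketch of that check on a plain dict (a hypothetical validator, not the repo's config loader):

```python
# Hypothetical check mirroring the constraint in the YAML comments:
# num_latent_code must equal v_patch_nums[-1] ** 2.
config = {
    "codebook_embed_dim": 32,
    "codebook_size": 4096,
    "v_patch_nums": [16],
    "num_latent_code": 256,
}

def validate(cfg):
    expected = cfg["v_patch_nums"][-1] ** 2
    if cfg["num_latent_code"] != expected:
        raise ValueError(
            f"num_latent_code={cfg['num_latent_code']} but "
            f"v_patch_nums[-1]**2={expected}"
        )
    return True

print(validate(config))  # True
```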
## Training code for Generator
We follow [RAR](https://github.com/bytedance/1d-tokenizer) in pretokenizing the whole dataset to speed up training. We have uploaded [the pretokenized dataset](https://huggingface.co/qiuk6/RobustTok/resolve/main/RobustTok-half-pretokenized.jsonl?download=true) so you can train RobusTok-RAR directly.
```bash
# training code for rar-b
accelerate launch scripts/train_rar.py experiment.project="rar" experiment.name="rar_b" experiment.output_dir="rar_b" model.generator.hidden_size=768 model.generator.num_hidden_layers=24 model.generator.num_attention_heads=16 model.generator.intermediate_size=3072 config=configs/generator/rar.yaml dataset.params.pretokenization=/path/to/pretokenized.jsonl model.vq_ckpt=/path/to/RobustTok.pt
# training code for rar-l
accelerate launch scripts/train_rar.py experiment.project="rar" experiment.name="rar_l" experiment.output_dir="rar_l" model.generator.hidden_size=1024 model.generator.num_hidden_layers=24 model.generator.num_attention_heads=16 model.generator.intermediate_size=4096 config=configs/generator/rar.yaml dataset.params.pretokenization=/path/to/pretokenized.jsonl model.vq_ckpt=/path/to/RobustTok.pt
```
## Post-Training for Tokenizer
For post-training, we (1) prepare a paired dataset and (2) post-train the decoder to align it with the generated latent space.
### Prepare data
Use our script below with your desired dataset, σ (the `--sigma` perturbation strength), and number of samples to generate the paired data:
```bash
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 post_train_data.py config=configs/generator/rar.yaml \
experiment.output_dir="/path/to/data-folder" \
experiment.generator_checkpoint="rar_b.bin" \
model.vq_ckpt=/path/to/RobustTok.pt \
model.generator.hidden_size=768 \
model.generator.num_hidden_layers=24 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=3072 \
model.generator.randomize_temperature=1.02 \
model.generator.guidance_scale=6.0 \
model.generator.guidance_scale_pow=1.15 \
--sigma 0.7 --data-path /path/to/imagenet --num_samples /number/of/generate
```
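As a rough illustration of what σ controls (this is an assumption-laden sketch, not the repo's implementation), latent perturbation can be pictured as replacing a σ-fraction of discrete token indices with random codebook entries, simulating the unexpected tokens a generator produces at sampling time:

```python
import random

def perturb_tokens(tokens, sigma, codebook_size, seed=None):
    """Replace a sigma-fraction of token indices with random codebook
    entries, simulating sampling noise from a generator."""
    rng = random.Random(seed)
    out = list(tokens)
    n_swap = int(round(sigma * len(out)))
    for i in rng.sample(range(len(out)), n_swap):
        out[i] = rng.randrange(codebook_size)
    return out

tokens = list(range(256))  # a 16x16 latent grid, flattened
noisy = perturb_tokens(tokens, sigma=0.7, codebook_size=4096, seed=0)
# At most round(0.7 * 256) = 179 positions differ (a random draw can
# coincide with the original index).
changed = sum(a != b for a, b in zip(tokens, noisy))
print(changed)
```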
### Post-Training
```bash
torchrun --nproc_per_node=8 tokenizer/tokenizer_image/xqgan_post_train.py --config configs/post-train.yaml --data-path /path/to/data-folder --pair-set /path/to/imagenet --vq-ckpt /path/to/main-train/ckpt
```
## Inference Code
```bash
# Reproducing RAR-B
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/generator/rar.yaml \
experiment.output_dir="rar_b" \
experiment.generator_checkpoint="rar_b.bin" \
model.vq_ckpt=/path/to/RobustTok.pt \
model.generator.hidden_size=768 \
model.generator.num_hidden_layers=24 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=3072 \
model.generator.randomize_temperature=1.02 \
model.generator.guidance_scale=6.0 \
model.generator.guidance_scale_pow=1.15
# Run the evaluation script. The resulting gFID should be ~1.83 before post-training and ~1.60 after post-training.
python3 evaluator.py VIRTUAL_imagenet256_labeled.npz rar_b.npz
# Reproducing RAR-L
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/generator/rar.yaml \
experiment.output_dir="rar_l" \
experiment.generator_checkpoint="rar_l.bin" \
model.vq_ckpt=/path/to/RobustTok.pt \
model.generator.hidden_size=1024 \
model.generator.num_hidden_layers=24 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=4096 \
model.generator.randomize_temperature=1.04 \
model.generator.guidance_scale=6.75 \
model.generator.guidance_scale_pow=1.01
# Run the evaluation script. The resulting gFID should be ~1.60 before post-training and ~1.36 after post-training.
python3 evaluator.py VIRTUAL_imagenet256_labeled.npz rar_l.npz
```
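The `guidance_scale` and `guidance_scale_pow` flags above suggest a power-shaped classifier-free-guidance schedule over decoding steps. One common form (an assumption for illustration, not necessarily RAR's exact formula) ramps the scale as a power of normalized progress:

```python
def guidance_schedule(step, num_steps, guidance_scale, pow_):
    """Hypothetical power schedule: the guidance scale grows from 0 to
    guidance_scale as decoding progresses, shaped by pow_."""
    t = step / max(num_steps - 1, 1)  # normalized progress in [0, 1]
    return guidance_scale * t ** pow_

# With the RAR-B settings above (scale=6.0, pow=1.15) over 256 steps:
schedule = [guidance_schedule(s, 256, 6.0, 1.15) for s in range(256)]
print(schedule[0], schedule[-1])  # 0.0 at the first step, 6.0 at the last
```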
## Visualization
<div align="center">
<img src="assets/ft-diff.png" alt="vis" width="95%">
<p>
Visualization of 256×256 image generation before (top) and after (bottom) post-training. Three improvements are observed: (a) OOD mitigation, (b) color fidelity, (c) detail refinement.
</p>
</div>
## Citation
If our work assists your research, feel free to give us a star ⭐ or cite us:
```bibtex
@misc{qiu2025imagetokenizerneedsposttraining,
title={Image Tokenizer Needs Post-Training},
author={Kai Qiu and Xiang Li and Hao Chen and Jason Kuen and Xiaohao Xu and Jiuxiang Gu and Yinyi Luo and Bhiksha Raj and Zhe Lin and Marios Savvides},
year={2025},
eprint={2509.12474},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2509.12474},
}
```