---
language:
- en
license: mit
tags:
- image-generation
- autoregressive
- next-scale-prediction
- exposure-bias
- post-training
- pytorch
- imagenet
library_name: pytorch
inference: false
model-index:
- name: ZGZzz/SAR
  results:
  - task:
      type: image-generation
      name: Image Generation
    dataset:
      name: ImageNet 256×256
      type: imagenet-1k
      config: 256x256
      split: validation
    metrics:
    - type: fid
      name: FID (FlexVAR-d16, +SAR)
      value: 2.89
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d20, +SAR)
      value: 2.35
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d24, +SAR)
      value: 2.14
      higher_is_better: false
datasets:
- ILSVRC/imagenet-1k
base_model:
- jiaosiyu1999/FlexVAR
pipeline_tag: text-to-image
---
|
|
|
|
|
<div align="center"> |
|
|
<h1>Rethinking Training Dynamics in Scale-wise Autoregressive Generation</h1> |
|
|
|
|
|
<a href="https://gengzezhou.github.io/" target="_blank">Gengze Zhou</a><sup>1*</sup>, |
|
|
<a href="https://chongjiange.github.io/" target="_blank">Chongjian Ge</a><sup>2</sup>, |
|
|
<a href="https://www.cs.unc.edu/~airsplay/" target="_blank">Hao Tan</a><sup>2</sup>, |
|
|
<a href="https://pages.cs.wisc.edu/~fliu/" target="_blank">Feng Liu</a><sup>2</sup>, |
|
|
<a href="https://yiconghong.me" target="_blank">Yicong Hong</a><sup>2</sup> |
|
|
|
|
|
<sup>1</sup>Australian Institute for Machine Learning, The University of Adelaide
|
|
<sup>2</sup>Adobe Research |
|
|
|
|
|
[](https://arxiv.org/abs/2512.06421) |
|
|
[](https://huggingface.co/ZGZzz/SAR) |
|
|
[](https://gengzezhou.github.io/SAR) |
|
|
[](https://opensource.org/licenses/MIT) |
|
|
|
|
|
</div> |
|
|
|
|
|
## Model Description |
|
|
|
|
|
**Self-Autoregressive Refinement (SAR)** is a lightweight *post-training* algorithm for **scale-wise autoregressive (AR)** image generation (next-scale prediction). SAR mitigates **exposure bias** by addressing (1) train–test mismatch (teacher forcing vs. student forcing) and (2) imbalance in scale-wise learning difficulty. |
|
|
|
|
|
SAR consists of: |
|
|
- **Stagger-Scale Rollout (SSR):** a two-step rollout (teacher-forcing → student-forcing) with minimal compute overhead (one extra forward pass). |
|
|
- **Contrastive Student-Forcing Loss (CSFL):** stabilizes student-forced training by aligning predictions with a teacher trajectory under self-generated contexts. |
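The two components can be illustrated with a toy sketch. Everything below is a hypothetical stand-in, not the repository's API: `predict_next` replaces the AR transformer, scales are single floats instead of token maps, and `ssr_step` only shows the control flow of the two-step rollout and the contrastive alignment.

```python
def predict_next(context):
    """Stand-in for the AR model: 'predicts' the next scale from its context.
    A real model would run a transformer forward pass over token maps."""
    return sum(context) / len(context)


def ssr_step(gt_scales, k):
    """One toy SSR step around scale k.

    Step 1 (teacher forcing): predict scale k from ground-truth scales 0..k-1.
    Step 2 (student forcing): predict scale k+1 from a context whose last
    entry is the model's own scale-k prediction -- the single extra forward
    pass that exposes the model to self-generated content.
    """
    pred_k = predict_next(gt_scales[:k])           # teacher-forced prediction
    student_ctx = gt_scales[:k] + [pred_k]         # scale k replaced by model output
    pred_next_student = predict_next(student_ctx)  # student-forced prediction

    # CSFL idea: align the student-forced prediction with the teacher
    # trajectory (prediction from a fully ground-truth context, which
    # would be detached from the graph in practice).
    pred_next_teacher = predict_next(gt_scales[:k + 1])
    csfl = (pred_next_student - pred_next_teacher) ** 2
    return pred_next_student, pred_next_teacher, csfl
```

For example, with `gt_scales = [1.0, 2.0, 3.0, 4.0]` and `k = 2`, the student-forced context becomes `[1.0, 2.0, 1.5]`, and the loss penalizes the gap between that rollout's prediction and the teacher's.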
|
|
|
|
|
## Key Features |
|
|
|
|
|
- **Minimal overhead:** SSR adds only a lightweight additional forward pass to train on self-generated content. |
|
|
- **General post-training recipe:** applies on top of pretrained scale-wise AR models. |
|
|
- **Empirical gains:** e.g., reported **5.2% FID reduction** on FlexVAR-d16 with 10 SAR epochs. |
|
|
|
|
|
## Model Zoo (ImageNet 256×256) |
|
|
|
|
|
| Model | Params | FlexVAR FID ↓ | +SAR FID ↓ | SAR Weights |
|
|
|---|---:|---:|---:|---| |
|
|
| SAR-d16 | 310M | 3.05 | **2.89** | `pretrained/SARd16-epo179.pth` | |
|
|
| SAR-d20 | 600M | 2.41 | **2.35** | `pretrained/SARd20-epo249.pth` | |
|
|
| SAR-d24 | 1.0B | 2.21 | **2.14** | `pretrained/SARd24-epo349.pth` | |
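As a sanity check, the relative FID reductions implied by the table can be recomputed directly (a throwaway script, not part of the repository); the d16 row matches the ~5.2% figure quoted above:

```python
# (base FID, SAR FID) pairs from the model zoo table
pairs = {"SAR-d16": (3.05, 2.89), "SAR-d20": (2.41, 2.35), "SAR-d24": (2.21, 2.14)}
for name, (base, sar) in pairs.items():
    # relative reduction = (base - sar) / base, as a percentage
    print(f"{name}: {100 * (base - sar) / base:.1f}% FID reduction")
# SAR-d16: 5.2%, SAR-d20: 2.5%, SAR-d24: 3.2%
```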
|
|
|
|
|
## How to Use |
|
|
|
|
|
### Installation |
|
|
|
|
|
```bash |
|
|
git clone https://github.com/GengzeZhou/SAR.git |
|
|
conda create -n sar python=3.10 -y |
|
|
conda activate sar |
|
|
pip install -r requirements.txt |
|
|
# optional |
|
|
pip install flash-attn xformers |
|
|
``` |
|
|
|
|
|
### Sampling / Inference (Example) |
|
|
|
|
|
```python |
|
|
import torch |
|
|
from models import build_vae_var |
|
|
from torchvision.utils import save_image |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
# Build VAE + VAR backbone (example: depth=16) |
|
|
vae, model = build_vae_var( |
|
|
    V=8192, Cvae=32, device=device,  # example codebook size/width
|
|
num_classes=1000, depth=16, |
|
|
vae_ckpt="pretrained/FlexVAE.pth", |
|
|
) |
|
|
|
|
|
# Load SAR checkpoint |
|
|
ckpt = torch.load("pretrained/SARd16-epo179.pth", map_location="cpu") |
|
|
if "trainer" in ckpt: |
|
|
ckpt = ckpt["trainer"]["var_wo_ddp"] |
|
|
model.load_state_dict(ckpt, strict=False) |
|
|
model.eval() |
|
|
|
|
|
with torch.no_grad(): |
|
|
labels = torch.tensor([207, 88, 360, 387], device=device) # example ImageNet classes |
|
|
images = model.autoregressive_infer_cfg( |
|
|
vqvae=vae, |
|
|
B=4, |
|
|
label_B=labels, |
|
|
cfg=2.5, |
|
|
top_k=900, |
|
|
top_p=0.95, |
|
|
) |
|
|
|
|
|
save_image(images, "samples.png", normalize=True, value_range=(-1, 1), nrow=4) |
|
|
``` |
|
|
|
|
|
## Training (SAR Post-Training) |
|
|
|
|
|
```bash |
|
|
bash scripts/train_SAR_d16.sh |
|
|
bash scripts/train_SAR_d20.sh |
|
|
bash scripts/train_SAR_d24.sh |
|
|
``` |
|
|
|
|
|
## Evaluation |
|
|
|
|
|
```bash |
|
|
bash scripts/setup_eval.sh |
|
|
bash scripts/eval_SAR_d16.sh |
|
|
bash scripts/eval_SAR_d20.sh |
|
|
bash scripts/eval_SAR_d24.sh |
|
|
``` |
|
|
|
|
|
## Acknowledgements |
|
|
|
|
|
This codebase builds upon **VAR** and **FlexVAR**. |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex |
|
|
@article{zhou2025rethinking, |
|
|
title={Rethinking Training Dynamics in Scale-wise Autoregressive Generation}, |
|
|
author={Zhou, Gengze and Ge, Chongjian and Tan, Hao and Liu, Feng and Hong, Yicong}, |
|
|
journal={arXiv preprint arXiv:2512.06421}, |
|
|
year={2025} |
|
|
} |
|
|
``` |