---
language:
- en
license: mit
tags:
- image-generation
- autoregressive
- next-scale-prediction
- exposure-bias
- post-training
- pytorch
- imagenet
library_name: pytorch
inference: false
model-index:
- name: ZGZzz/SAR
  results:
  - task:
      type: image-generation
      name: Image Generation
    dataset:
      name: ImageNet 256×256
      type: imagenet-1k
      config: 256x256
      split: validation
    metrics:
    - type: fid
      name: FID (FlexVAR-d16, +SAR)
      value: 2.89
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d20, +SAR)
      value: 2.35
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d24, +SAR)
      value: 2.14
      higher_is_better: false
datasets:
- ILSVRC/imagenet-1k
base_model:
- jiaosiyu1999/FlexVAR
pipeline_tag: text-to-image
---
# Rethinking Training Dynamics in Scale-wise Autoregressive Generation

Gengze Zhou<sup>1*</sup>, Chongjian Ge<sup>2</sup>, Hao Tan<sup>2</sup>, Feng Liu<sup>2</sup>, Yicong Hong<sup>2</sup>

<sup>1</sup>Australian Institute for Machine Learning, Adelaide University
<sup>2</sup>Adobe Research

[Paper](https://arxiv.org/abs/2512.06421) | [Model](https://huggingface.co/ZGZzz/SAR) | [Project Page](https://gengzezhou.github.io/SAR) | [License: MIT](https://opensource.org/licenses/MIT)

## Model Description
**Self-Autoregressive Refinement (SAR)** is a lightweight *post-training* algorithm for **scale-wise autoregressive (AR)** image generation (next-scale prediction). SAR mitigates **exposure bias** by addressing (1) train–test mismatch (teacher forcing vs. student forcing) and (2) imbalance in scale-wise learning difficulty.
SAR consists of:
- **Stagger-Scale Rollout (SSR):** a two-step rollout (teacher-forcing → student-forcing) with minimal compute overhead (one extra forward pass).
- **Contrastive Student-Forcing Loss (CSFL):** stabilizes student-forced training by aligning predictions with a teacher trajectory under self-generated contexts.
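
The two-pass structure described above can be illustrated with a toy sketch. Everything here is an assumption made for illustration: `ToyScaleAR`, `ssr_step`, and `align_weight` are hypothetical names, the model is a trivial embedding plus linear head rather than a VAR backbone, every scale has the same number of tokens, and a KL alignment term stands in for CSFL, whose actual formulation is given in the paper. The sketch only shows how one teacher-forced pass can supply self-generated context for one extra student-forced pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, D, K, L = 16, 8, 4, 5  # toy vocab size, width, number of scales, tokens per scale


class ToyScaleAR(nn.Module):
    """Hypothetical stand-in: maps the tokens of scale k to logits for scale k+1."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(V, D)
        self.head = nn.Linear(D, V)

    def forward(self, tokens):  # tokens: (B, S, L) integer ids
        return self.head(self.embed(tokens))  # (B, S, L, V)


def ssr_step(model, gt, align_weight=1.0):
    # Pass 1 (teacher forcing): ground-truth scales 0..K-2 predict scales 1..K-1.
    logits_tf = model(gt[:, :-1])
    loss_tf = F.cross_entropy(logits_tf.flatten(0, 2), gt[:, 1:].flatten())

    # Self-generated context: the model's own predictions for scales 1..K-1.
    with torch.no_grad():
        self_tokens = logits_tf.argmax(dim=-1)

    # Pass 2 (student forcing, staggered by one scale):
    # predicted scales 1..K-2 are fed back to predict scales 2..K-1.
    logits_sf = model(self_tokens[:, :-1])

    # Assumed alignment term (NOT the paper's CSFL): pull student-forced
    # predictions toward the detached teacher-forced distribution.
    target = F.softmax(logits_tf[:, 1:].detach(), dim=-1).flatten(0, 2)
    log_pred = F.log_softmax(logits_sf, dim=-1).flatten(0, 2)
    loss_align = F.kl_div(log_pred, target, reduction="batchmean")
    return loss_tf + align_weight * loss_align, loss_tf, loss_align


model = ToyScaleAR()
gt = torch.randint(0, V, (2, K, L))  # (batch, scales, tokens per scale)
loss, loss_tf, loss_align = ssr_step(model, gt)
loss.backward()
```

Note that the model is called exactly twice per step, which matches the "one extra forward pass" overhead stated above; the real method operates on multi-scale token maps of different resolutions.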
## Key Features
- **Minimal overhead:** SSR adds only a lightweight additional forward pass to train on self-generated content.
- **General post-training recipe:** applies on top of pretrained scale-wise AR models.
- **Empirical gains:** for example, a **5.2% FID reduction** on FlexVAR-d16 (3.05 → 2.89 on ImageNet 256×256) after 10 SAR epochs.
## Model Zoo (ImageNet 256×256)
| Model | Params | Base FID ↓ | SAR FID ↓ | SAR Weights |
|---|---:|---:|---:|---|
| SAR-d16 | 310M | 3.05 | **2.89** | `pretrained/SARd16-epo179.pth` |
| SAR-d20 | 600M | 2.41 | **2.35** | `pretrained/SARd20-epo249.pth` |
| SAR-d24 | 1.0B | 2.21 | **2.14** | `pretrained/SARd24-epo349.pth` |
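
The relative FID improvement for each row can be recomputed from the table. The snippet below is a small arithmetic check, not part of the SAR codebase; `model_zoo` and `relative_reduction` are illustrative names, and the numbers are copied from the table above.

```python
# (base FID, SAR FID) pairs copied from the Model Zoo table
model_zoo = {
    "SAR-d16": (3.05, 2.89),
    "SAR-d20": (2.41, 2.35),
    "SAR-d24": (2.21, 2.14),
}


def relative_reduction(base_fid, sar_fid):
    """Percentage by which FID drops relative to the base model."""
    return 100.0 * (base_fid - sar_fid) / base_fid


for name, (base_fid, sar_fid) in model_zoo.items():
    print(f"{name}: {relative_reduction(base_fid, sar_fid):.1f}% lower FID")
```

For the d16 row this gives roughly 5.2%, in line with the figure quoted for FlexVAR-d16.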
## How to Use
### Installation
```bash
git clone https://github.com/GengzeZhou/SAR.git
cd SAR
conda create -n sar python=3.10 -y
conda activate sar
pip install -r requirements.txt
# Optional: faster attention kernels
pip install flash-attn xformers
```
### Sampling / Inference (Example)
```python
import torch
from models import build_vae_var
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build VAE + VAR backbone (example: depth=16)
vae, model = build_vae_var(
    V=8912, Cvae=32, device=device,
    num_classes=1000, depth=16,
    vae_ckpt="pretrained/FlexVAE.pth",
)

# Load SAR checkpoint; unwrap the model weights if it is a trainer checkpoint
ckpt = torch.load("pretrained/SARd16-epo179.pth", map_location="cpu")
if "trainer" in ckpt:
    ckpt = ckpt["trainer"]["var_wo_ddp"]
# strict=False does not raise on missing or unexpected keys
model.load_state_dict(ckpt, strict=False)
model.eval()

with torch.no_grad():
    labels = torch.tensor([207, 88, 360, 387], device=device)  # example ImageNet classes
    images = model.autoregressive_infer_cfg(
        vqvae=vae,
        B=4,
        label_B=labels,
        cfg=2.5,
        top_k=900,
        top_p=0.95,
    )

save_image(images, "samples.png", normalize=True, value_range=(-1, 1), nrow=4)
```
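
Because `load_state_dict(..., strict=False)` does not raise on missing or unexpected keys, it can be worth comparing key sets before loading. The helpers below are a minimal sketch: `unwrap_state_dict` and `key_mismatch` are hypothetical names that are not part of the SAR repository, and the `"trainer"` / `"var_wo_ddp"` layout is assumed from the example above.

```python
def unwrap_state_dict(ckpt):
    """Return model weights from a raw state dict or a trainer checkpoint.

    Assumes trainer checkpoints nest weights under ckpt["trainer"]["var_wo_ddp"].
    """
    if "trainer" in ckpt:
        return ckpt["trainer"]["var_wo_ddp"]
    return ckpt


def key_mismatch(model_keys, ckpt_keys):
    """Return (missing, unexpected) parameter names as sorted lists."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)      # in the model, absent from the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)   # in the checkpoint, absent from the model
    return missing, unexpected
```

With a model and checkpoint loaded as in the example, a call such as `key_mismatch(model.state_dict().keys(), unwrap_state_dict(ckpt).keys())` should ideally return two empty lists. The value returned by `load_state_dict` also exposes `missing_keys` and `unexpected_keys` for the same purpose.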
## Training (SAR Post-Training)
```bash
bash scripts/train_SAR_d16.sh
bash scripts/train_SAR_d20.sh
bash scripts/train_SAR_d24.sh
```
## Evaluation
```bash
bash scripts/setup_eval.sh
bash scripts/eval_SAR_d16.sh
bash scripts/eval_SAR_d20.sh
bash scripts/eval_SAR_d24.sh
```
## Acknowledgements
This codebase builds upon **VAR** and **FlexVAR**.
## Citation
```bibtex
@article{zhou2025rethinking,
title={Rethinking Training Dynamics in Scale-wise Autoregressive Generation},
author={Zhou, Gengze and Ge, Chongjian and Tan, Hao and Liu, Feng and Hong, Yicong},
journal={arXiv preprint arXiv:2512.06421},
year={2025}
}
```