---
language:
- en
license: mit
tags:
- image-generation
- autoregressive
- next-scale-prediction
- exposure-bias
- post-training
- pytorch
- imagenet
library_name: pytorch
inference: false
model-index:
- name: ZGZzz/SAR
  results:
  - task:
      type: image-generation
      name: Image Generation
    dataset:
      name: ImageNet 256×256
      type: imagenet-1k
      config: 256x256
      split: validation
    metrics:
    - type: fid
      name: FID (FlexVAR-d16, +SAR)
      value: 2.89
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d20, +SAR)
      value: 2.35
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d24, +SAR)
      value: 2.14
      higher_is_better: false
datasets:
- ILSVRC/imagenet-1k
base_model:
- jiaosiyu1999/FlexVAR
pipeline_tag: text-to-image
---

# Rethinking Training Dynamics in Scale-wise Autoregressive Generation

Gengze Zhou<sup>1*</sup>, Chongjian Ge<sup>2</sup>, Hao Tan<sup>2</sup>, Feng Liu<sup>2</sup>, Yicong Hong<sup>2</sup>

<sup>1</sup>Australian Institute for Machine Learning, Adelaide University&emsp;<sup>2</sup>Adobe Research

[![arXiv](https://img.shields.io/badge/arXiv-2512.06421-b31b1b.svg)](https://arxiv.org/abs/2512.06421)
[![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-SAR--ckpts-yellow)](https://huggingface.co/ZGZzz/SAR)
[![project page](https://img.shields.io/badge/Project%20Page-SAR-blue)](https://gengzezhou.github.io/SAR)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
## Model Description

**Self-Autoregressive Refinement (SAR)** is a lightweight *post-training* algorithm for **scale-wise autoregressive (AR)** image generation (next-scale prediction). SAR mitigates **exposure bias** by addressing two of its causes: (1) the train–test mismatch between teacher forcing during training and student forcing at inference, and (2) the imbalance in learning difficulty across scales.

SAR consists of:

- **Stagger-Scale Rollout (SSR):** a two-step rollout (teacher forcing → student forcing) with minimal compute overhead (one extra forward pass).
- **Contrastive Student-Forcing Loss (CSFL):** stabilizes student-forced training by aligning predictions with a teacher trajectory under self-generated contexts.

## Key Features

- **Minimal overhead:** SSR needs only one additional forward pass to train on self-generated content.
- **General post-training recipe:** applies on top of pretrained scale-wise AR models.
- **Empirical gains:** e.g., a reported **5.2% FID reduction** on FlexVAR-d16 after 10 SAR epochs.

## Model Zoo (ImageNet 256×256)

| Model | Params | Base FID (FlexVAR) ↓ | SAR FID ↓ | SAR Weights |
|---|---:|---:|---:|---|
| SAR-d16 | 310M | 3.05 | **2.89** | `pretrained/SARd16-epo179.pth` |
| SAR-d20 | 600M | 2.41 | **2.35** | `pretrained/SARd20-epo249.pth` |
| SAR-d24 | 1.0B | 2.21 | **2.14** | `pretrained/SARd24-epo349.pth` |

## How to Use

### Installation

```bash
git clone https://github.com/GengzeZhou/SAR.git
cd SAR
conda create -n sar python=3.10 -y
conda activate sar
pip install -r requirements.txt

# optional
pip install flash-attn xformers
```

### Sampling / Inference (Example)

```python
# Run from the repository root so that `models` is importable.
import torch
from models import build_vae_var
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build VAE + VAR backbone (example: depth=16)
vae, model = build_vae_var(
    V=8912, Cvae=32, device=device, num_classes=1000, depth=16,
    vae_ckpt="pretrained/FlexVAE.pth",
)

# Load SAR checkpoint; training checkpoints nest the model weights under
# ckpt["trainer"]["var_wo_ddp"], so unwrap them when present.
ckpt = torch.load("pretrained/SARd16-epo179.pth", map_location="cpu")
if "trainer" in ckpt:
    ckpt = ckpt["trainer"]["var_wo_ddp"]
model.load_state_dict(ckpt, strict=False)  # strict=False ignores missing/unexpected keys
model.eval()

with torch.no_grad():
    # Example ImageNet classes: golden retriever, macaw, otter, red panda
    labels = torch.tensor([207, 88, 360, 387], device=device)
    images = model.autoregressive_infer_cfg(
        vqvae=vae, B=4, label_B=labels, cfg=2.5, top_k=900, top_p=0.95,
    )

save_image(images, "samples.png", normalize=True, value_range=(-1, 1), nrow=4)
```

## Training (SAR Post-Training)

```bash
bash scripts/train_SAR_d16.sh
bash scripts/train_SAR_d20.sh
bash scripts/train_SAR_d24.sh
```

## Evaluation

```bash
bash scripts/setup_eval.sh
bash scripts/eval_SAR_d16.sh
bash scripts/eval_SAR_d20.sh
bash scripts/eval_SAR_d24.sh
```

## Acknowledgements

This codebase builds upon **VAR** and **FlexVAR**.

## Citation

```bibtex
@article{zhou2025rethinking,
  title={Rethinking Training Dynamics in Scale-wise Autoregressive Generation},
  author={Zhou, Gengze and Ge, Chongjian and Tan, Hao and Liu, Feng and Hong, Yicong},
  journal={arXiv preprint arXiv:2512.06421},
  year={2025}
}
```
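As a quick sanity check on the Model Zoo table, the relative FID reduction for each depth is (base FID - SAR FID) / base FID. The short Python sketch below applies that formula to the table values, which are copied by hand, so treat it as illustrative. The names `MODEL_ZOO` and `relative_fid_reduction` are hypothetical helpers, not part of the SAR codebase. For SAR-d16 it reproduces the 5.2% figure quoted under Key Features.

```python
# Illustrative only: base (FlexVAR) and SAR FID values copied by hand from the
# Model Zoo table (ImageNet 256x256). Hypothetical helper, not part of the SAR repo.
MODEL_ZOO = {
    "SAR-d16": (3.05, 2.89),
    "SAR-d20": (2.41, 2.35),
    "SAR-d24": (2.21, 2.14),
}


def relative_fid_reduction(base_fid: float, sar_fid: float) -> float:
    """Return how much lower SAR's FID is than the base model's, in percent."""
    return (base_fid - sar_fid) / base_fid * 100.0


if __name__ == "__main__":
    for name, (base, sar) in MODEL_ZOO.items():
        pct = relative_fid_reduction(base, sar)
        print(f"{name}: {base:.2f} -> {sar:.2f} ({pct:.1f}% lower FID)")
```

Running it prints one line per model, e.g. `SAR-d16: 3.05 -> 2.89 (5.2% lower FID)`; the d20 and d24 variants come out at roughly 2.5% and 3.2%.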