---
license: apache-2.0
language:
- en
tags:
- image-generation
- image-understanding
- image-editing
- multimodal
- autoregressive
- text-to-image
- unified-model
pipeline_tag: image-to-text
base_model: ShareLab-SII/UniAR-SFT
---
# UniAR: Unified Multimodal Autoregressive Modeling with Shared Context -- Visual Tokenizer is Key to Unification (ICML 2026)
**UniAR** is a unified autoregressive multimodal model for **image understanding**, **image generation**, and **image editing** in a single Transformer. UniAR-RL is obtained by reinforcement learning (GRPO) on top of [UniAR-SFT](https://huggingface.co/ShareLab-SII/UniAR-SFT), achieving state-of-the-art text rendering and instruction-following performance among unified models.
[![arXiv](https://img.shields.io/badge/arXiv-2606.18249-b31b1b.svg)](https://arxiv.org/abs/2606.18249)
[![Project Page](https://img.shields.io/badge/Project-Page-blue.svg)](https://sharelab-sii.github.io/uniar-web)
[![Code](https://img.shields.io/badge/GitHub-Code-black.svg)](https://github.com/ShareLab-SII/UniAR)
## Model Description
UniAR uses a single discrete visual tokenizer (BSQ) as the key bridge between understanding and generation, enabling a shared context where the model can directly interpret its own generated visual tokens. Key components:
- **Backbone:** Qwen3-8B
- **Visual Tokenizer:** BSQ-quantized SigLIP2-So400M ViT with DeepStack connections
- **Visual Decoder:** SD3.5-Medium DiT with SigLIP feature injection
- **Training:** Pre-training (1T tokens) → SFT → RL (GRPO with multi-reward stack)
This checkpoint (`UniAR-RL`) is the final RL-finetuned model, with improved generation quality over UniAR-SFT.
## Checkpoint Contents
This is a self-contained checkpoint with all components needed for both understanding and generation:
| Component | Path | Description |
|-----------|------|-------------|
| AR model | `*.safetensors` | Unified autoregressive model weights |
| BSQ encoder | `bsq_encoder/` | BSQ-quantized image tokenizer |
| SD3 transformer | `sd3_transformer/` | SD3 transformer with visual feature injection |
| SD3 pipeline | `sd3_pipeline/` | SD3 VAE + text encoders |
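
If you have downloaded the checkpoint to a local directory, a quick sanity check of the layout can catch an incomplete download before loading anything. The sketch below is a hypothetical helper, not part of the UniAR codebase; it only assumes the component names listed in the table above.

```python
from pathlib import Path

# Sub-directory names come from the table above; this helper is hypothetical.
EXPECTED_DIRS = ("bsq_encoder", "sd3_transformer", "sd3_pipeline")


def missing_components(checkpoint_dir):
    """Return the expected components that are absent from checkpoint_dir."""
    root = Path(checkpoint_dir)
    missing = [name for name in EXPECTED_DIRS if not (root / name).is_dir()]
    # The AR weights are one or more *.safetensors files at the top level.
    if not any(root.glob("*.safetensors")):
        missing.append("*.safetensors")
    return missing
```

An empty list means all four components from the table were found; any returned name points to a part that still needs to be downloaded.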
## Usage
### Installation
```bash
conda create -n uniar python=3.12 -y
conda activate uniar
git clone https://github.com/ShareLab-SII/UniAR.git
cd UniAR
pip install -e . # inference dependencies
```
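
The usage examples below request `attn_implementation="flash_attention_2"`, which requires the `flash_attn` package to be importable. As a minimal sketch, the hypothetical helper below picks an attention backend at runtime. Falling back to `"sdpa"` (PyTorch scaled-dot-product attention, a standard `transformers` option) is an assumption: this README does not confirm that UniAR supports it.

```python
import importlib.util


def pick_attn_implementation():
    """Return "flash_attention_2" if flash-attn is installed, else "sdpa".

    Assumption: the "sdpa" fallback is not confirmed to work with UniAR.
    """
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


print(pick_attn_implementation())
```

The returned string can be passed as the `attn_implementation` argument of `from_pretrained`.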
### Image Understanding
```python
import torch
from transformers import AutoProcessor
from uniar import UniARForConditionalGeneration

model_path = "ShareLab-SII/UniAR-RL"

# Load the unified AR model in bfloat16 with FlashAttention-2
model = UniARForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).cuda().eval()
processor = AutoProcessor.from_pretrained(model_path)

messages = [{"role": "user", "content": [
    {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"},
    {"type": "text", "text": "Describe this image in detail."},
]}]

# Tokenize the chat-formatted prompt (image + text)
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
inputs.pop("mm_token_type_ids", None)  # remove this key, if present, before generate()

with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    output_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)

# Keep only the newly generated tokens by slicing off the prompt
output_ids = [o[len(i):] for i, o in zip(inputs.input_ids, output_ids)]
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```
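
The slicing step `o[len(i):]` in the example above is needed because, for decoder-only models, `generate` returns the prompt tokens followed by the new tokens. The toy sketch below illustrates this on plain Python lists; the token IDs are made up and `strip_prompt` is a hypothetical name.

```python
def strip_prompt(input_ids, output_ids):
    """Drop the leading prompt tokens from each generated sequence."""
    return [out[len(prompt):] for prompt, out in zip(input_ids, output_ids)]


# Made-up token IDs: each output repeats its prompt, then adds new tokens.
prompts = [[101, 7, 8], [101, 9, 10]]
generated = [[101, 7, 8, 42, 43], [101, 9, 10, 44]]
print(strip_prompt(prompts, generated))  # [[42, 43], [44]]
```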
### Image Generation
```python
import torch
from transformers import AutoProcessor
from uniar import UniARForConditionalGeneration, UniARVisualDecoder
from inference.visual_inputs import prepare_visual_inputs

model_path = "ShareLab-SII/UniAR-RL"
device = torch.device("cuda")

ar_model = UniARForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to(device).eval()
processor = AutoProcessor.from_pretrained(model_path, padding_side="left")
visual_decoder = UniARVisualDecoder.from_pretrained(model_path, device=device)

# prepare inputs
visual_inputs = prepare_visual_inputs(
    ["A cute anime girl."],
    ar_model,
    processor,
    ar_height=960,
    ar_width=960,
)

# autoregressively generate visual indices
indices = ar_model.generate_visual(
    **visual_inputs,
    temperature=1.0,
    cfg=1.5,
    show_progress=True,
)

# decode visual indices into image
images = visual_decoder.decode(
    indices,
    ar_height=960,
    ar_width=960,
    upsampling_ratio=1.067,
)
images[0].save("output.png")
```
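
The `cfg` and `temperature` arguments above suggest classifier-free guidance applied to the next-token logits. This README does not describe how `generate_visual` implements it, so the following is only a sketch of one common formulation, `uncond + cfg * (cond - uncond)`, on made-up logits; `apply_cfg` and `softmax` are hypothetical helpers, not UniAR functions.

```python
import math


def apply_cfg(cond_logits, uncond_logits, cfg):
    """Blend conditional and unconditional logits (one common CFG form).

    Assumption: UniAR's actual guidance rule may differ from this formula.
    """
    return [u + cfg * (c - u) for c, u in zip(cond_logits, uncond_logits)]


def softmax(logits, temperature=1.0):
    """Turn logits into sampling probabilities at the given temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


cond = [2.0, 0.5, -1.0]   # made-up logits given the text prompt
uncond = [1.0, 1.0, 0.0]  # made-up logits without the prompt
guided = apply_cfg(cond, uncond, cfg=1.5)
print(guided)  # [2.5, 0.25, -1.5]
print(softmax(guided, temperature=1.0))
```

Under this formulation, `cfg=1.0` reproduces the conditional logits unchanged, and larger values push the distribution further toward tokens favored by the prompt.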
## Citation
```bibtex
@inproceedings{peng2026uniar,
  title={Unified Multimodal Autoregressive Modeling with Shared Context --- Visual Tokenizer is Key to Unification},
  author={Peng, Wujian and Meng, Lingchen and Cai, Yuxuan and Zhuang, Xianwei and Yang, Yuhuan and Fang, Rongyao and Wu, Chenfei and Lin, Junyang and Wu, Zuxuan and Bai, Shuai},
  booktitle={ICML},
  year={2026}
}
```