wjpoom commited on
Commit
7149f77
·
verified ·
1 Parent(s): 05d7bab

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +151 -0
README.md CHANGED
@@ -1,3 +1,154 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - image-generation
7
+ - image-understanding
8
+ - image-editing
9
+ - multimodal
10
+ - autoregressive
11
+ - text-to-image
12
+ - unified-model
13
+ pipeline_tag: image-to-text
14
+ base_model: ShareLab-SII/UniAR-SFT
15
  ---
16
+
17
+ # UniAR: Unified Multimodal Autoregressive Modeling with Shared Context--Visual Tokenizer is Key to Unification (ICML2026)
18
+
19
+ **UniAR** is a unified autoregressive multimodal model for **image understanding**, **image generation**, and **image editing** in a single Transformer. UniAR-RL is obtained by reinforcement learning (GRPO) on top of [UniAR-SFT](https://huggingface.co/ShareLab-SII/UniAR-SFT), achieving state-of-the-art text rendering and instruction-following performance among unified models.
20
+
21
+ [![arXiv](https://img.shields.io/badge/arXiv-TODO-b31b1b.svg)](https://arxiv.org/abs/TODO)
22
+ [![Project Page](https://img.shields.io/badge/Project-Page-blue.svg)](https://sharelab-sii.github.io/uniar-web)
23
+ [![Code](https://img.shields.io/badge/GitHub-Code-black.svg)](https://github.com/ShareLab-SII/UniAR)
24
+
25
+ ## Model Description
26
+
27
+ UniAR uses a single discrete visual tokenizer (BSQ) as the key bridge between understanding and generation, enabling a shared context where the model can directly interpret its own generated visual tokens. Key components:
28
+
29
+ - **Backbone:** Qwen3-8B
30
+ - **Visual Tokenizer:** BSQ-quantized SigLiP2-So400M ViT with DeepStack connections
31
+ - **Visual Decoder:** SD3.5-Medium DiT with SigLIP feature injection
32
+ - **Training:** Pre-training (1T tokens) → SFT → RL (GRPO with multi-reward stack)
33
+
34
+ This checkpoint (`UniAR-RL`) is the final RL-finetuned model with improved generation quality.
35
+
36
+ ## Checkpoint Contents
37
+
38
+ This is a self-contained checkpoint with all components needed for both understanding and generation:
39
+
40
+ | Component | Path | Description |
41
+ |-----------|------|-------------|
42
+ | AR model | `*.safetensors` | Unified autoregressive model weights |
43
+ | BSQ encoder | `bsq_encoder/` | BSQ quantized image tokenizer |
44
+ | SD3 transformer | `sd3_transformer/` | SD3 transformer with visual feature injection |
45
+ | SD3 pipeline | `sd3_pipeline/` | SD3 VAE + text encoders |
46
+
47
+ ## Usage
48
+
49
+ ### Installation
50
+
51
+ ```bash
52
+ conda create -n uniar python=3.12 -y
53
+ conda activate uniar
54
+
55
+ git clone https://github.com/ShareLab-SII/UniAR.git
56
+ cd UniAR
57
+ pip install -e . # inference dependencies
58
+ ```
59
+
60
+ ### Image Understanding
61
+
62
+ ```python
63
+ import torch
64
+ from transformers import AutoProcessor
65
+ from uniar import UniARForConditionalGeneration
66
+
67
+ model_path = "ShareLab-SII/UniAR-RL"
68
+ model = UniARForConditionalGeneration.from_pretrained(
69
+ model_path,
70
+ torch_dtype=torch.bfloat16,
71
+ attn_implementation="flash_attention_2",
72
+ ).cuda().eval()
73
+ processor = AutoProcessor.from_pretrained(model_path)
74
+
75
+ messages = [{"role": "user", "content": [
76
+ {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"},
77
+ {"type": "text", "text": "Describe this image in detail."},
78
+ ]}]
79
+
80
+ inputs = processor.apply_chat_template(
81
+ messages,
82
+ tokenize=True,
83
+ add_generation_prompt=True,
84
+ return_dict=True,
85
+ return_tensors="pt",
86
+ ).to(model.device)
87
+ inputs.pop("mm_token_type_ids", None)
88
+
89
+ with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
90
+ output_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
91
+ output_ids = [o[len(i):] for i, o in zip(inputs.input_ids, output_ids)]
92
+
93
+ print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
94
+
95
+ ```
96
+
97
+ ### Image Generation
98
+
99
+ ```python
100
+ import torch
101
+ from transformers import AutoProcessor
102
+ from uniar import UniARForConditionalGeneration, UniARVisualDecoder
103
+ from inference.visual_inputs import prepare_visual_inputs
104
+
105
+ model_path = "ShareLab-SII/UniAR-RL"
106
+ device = torch.device("cuda")
107
+
108
+ ar_model = UniARForConditionalGeneration.from_pretrained(
109
+ model_path,
110
+ torch_dtype=torch.bfloat16,
111
+ attn_implementation="flash_attention_2",
112
+ ).to(device).eval()
113
+ processor = AutoProcessor.from_pretrained(model_path, padding_side="left")
114
+ visual_decoder = UniARVisualDecoder.from_pretrained(model_path, device=device)
115
+
116
+ # prepare inputs
117
+ visual_inputs = prepare_visual_inputs(
118
+ ["A cute anime girl."],
119
+ ar_model,
120
+ processor,
121
+ ar_height=960,
122
+ ar_width=960,
123
+ )
124
+
125
+ # autogressively generate visual indices
126
+ indices = ar_model.generate_visual(
127
+ **visual_inputs,
128
+ temperature=1.0,
129
+ cfg=1.5,
130
+ show_progress=True,
131
+ )
132
+
133
+ # decode visual indices into image
134
+ images = visual_decoder.decode(
135
+ indices,
136
+ ar_height=960,
137
+ ar_width=960,
138
+ upsampling_ratio=1.067,
139
+ )
140
+
141
+ images[0].save("output.png")
142
+
143
+ ```
144
+
145
+ ## Citation
146
+
147
+ ```bibtex
148
+ @inproceedings{peng2026uniar,
149
+ title={Unified Multimodal Autoregressive Modeling with Shared Context --- Visual Tokenizer is Key to Unification},
150
+ author={Peng, Wujian and Meng, Lingchen and Cai, Yuxuan and Zhuang, Xianwei and Yang, Yuhuan and Fang, Rongyao and Wu, Chenfei and Lin, Junyang and Wu, Zuxuan and Bai, Shuai},
151
+ booktitle={ICML},
152
+ year={2026}
153
+ }
154
+ ```