LewisYao commited on
Commit
9d7a4dd
·
verified ·
1 Parent(s): fa82be2

Sync README from GitHub repo

Browse files
Files changed (1) hide show
  1. README.md +397 -10
README.md CHANGED
@@ -1,19 +1,406 @@
1
- # Prompt Reinjection Rotations
2
 
3
- Released Procrustes rotation files for the ICML 2026 paper
4
- `Prompt Reinjection: Alleviating Prompt Forgetting in Multimodal Diffusion Transformers`.
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- ## Files
7
 
8
- - `sd3_coco5k_o1.pt`: SD3 rotation computed on COCO 5k, with origin layer 1 and target layers 2-23.
9
- - `flux_coco5k_o2.pt`: FLUX rotation computed on COCO 5k, with origin layer 2 and target layers 3-56.
10
 
11
- ## Usage
12
 
13
- Download the files and place them under `prompt_reinjection/rotations/` in the main repository:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  ```bash
16
- hf download LewisYao/PromptReinjection sd3_coco5k_o1.pt flux_coco5k_o2.pt --local-dir prompt_reinjection/rotations
 
 
 
 
 
 
 
 
17
  ```
18
 
19
- The default helper script in the main repository will pick them up automatically.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align='center'>[ICML 2026] Prompt Reinjection: Alleviating Prompt Forgetting in Multimodal Diffusion Transformers</h1>
2
 
3
+ <div align='center'>
4
+ <a href='https://yuxuan0919.github.io/' target='_blank'>Yuxuan Yao</a><sup>1,2,*</sup> &emsp;
5
+ <a href='https://scholar.google.com/scholar?q=Yuxuan+Chen' target='_blank'>Yuxuan Chen</a><sup>1,*</sup> &emsp;
6
+ <a href='https://openreview.net/profile?id=%7EHui_Li43' target='_blank'>Hui Li</a><sup>1</sup> &emsp;
7
+ <a href='https://scholar.google.com/scholar?q=Kaihui+Cheng' target='_blank'>Kaihui Cheng</a><sup>1</sup> &emsp;
8
+ <a href='https://scholar.google.com/scholar?q=Qipeng+Guo' target='_blank'>Qipeng Guo</a><sup>3</sup> &emsp;
9
+ <a href='https://scholar.google.com/scholar?q=Yuwei+Sun' target='_blank'>Yuwei Sun</a><sup>4</sup> &emsp;
10
+ </div>
11
+ <div align='center'>
12
+ <a href='https://scholar.google.com/scholar?q=Zilong+Dong' target='_blank'>Zilong Dong</a><sup>5</sup> &emsp;
13
+ <a href='https://scholar.google.cz/citations?user=z5SPCmgAAAAJ&hl=zh-CN' target='_blank'>Jingdong Wang</a><sup>6</sup> &emsp;
14
+ <a href='https://sites.google.com/site/zhusiyucs/home' target='_blank'>Siyu Zhu</a><sup>1,2</sup> &emsp;
15
+ </div>
16
 
 
17
 
 
 
18
 
19
+ <br>
20
 
21
+ <div align='center'>
22
+ <sup>1</sup>Fudan University &emsp;
23
+ <sup>2</sup>Shanghai Innovation Institute
24
+ </div>
25
+ <div align='center'>
26
+ <sup>3</sup>Shanghai AI Laboratory &emsp;
27
+ <sup>4</sup>Shanghai Academy of AI for Science &emsp;
28
+ <sup>5</sup>Alibaba Group &emsp;
29
+ <sup>6</sup>Baidu
30
+ </div>
31
+
32
+ <br>
33
+
34
+ <div align="center">
35
+
36
+ [![Paper](https://img.shields.io/badge/arXiv-2604.23632-b31b1b.svg)](https://arxiv.org/abs/2602.06886)
37
+ [![arXiv](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-yellow)](https://huggingface.co/fudan-generative-ai/PromptReinjection)
38
+ [![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
39
+
40
+ </div>
41
+
42
+
43
+ ## 📖 Introduction
44
+
45
+ Prompt Reinjection is a training-free inference method for multimodal diffusion transformers that mitigates prompt forgetting by reinjecting early-layer prompt features into deeper text layers, improving GenEval overall scores by `6.48%` on SD3.5-large and `7.75%` on HunyuanImage-2.1, while adding only about `0.00002x` block-level FLOPs for base reinjection and `0.088x` for the full aligned variant.
46
+
47
+ ## 📘 Overview
48
+
49
+ Prompt Reinjection starts from a simple observation: in multimodal diffusion transformers (MMDiTs) such as SD3-medium, SD3.5-large, FLUX.1, HunyuanImage-2.1, and Qwen-Image, prompt information fades as depth increases. That makes instruction following weaker, especially on position, attributes, counting, and long prompts.
50
+
51
+ ### Prompt Forgetting
52
+
53
+ Unlike traditional DiTs, where text serves as a relatively stable conditioning signal, MMDiTs jointly update text and image tokens throughout denoising, even though the text tokens receive no direct supervision. The paper shows that deeper text features gradually lose fine-grained prompt semantics, a phenomenon we call `prompt forgetting`.
54
+
55
+ <p align="center">
56
+ <img src="assets/prompt-forgetting.png" alt="Prompt Forgetting" width="320">
57
+ </p>
58
+
59
+ The figure above captures the core trend: for SD3, SD3.5, and FLUX, prompt information becomes less recoverable in deeper layers. This helps explain why base MMDiT models often miss spatial relations, attributes, and numeracy constraints in generation.
60
+
61
+ ### Prompt Reinjection
62
+
63
+ Prompt Reinjection fixes this at inference time. It takes semantically stronger text features from an early layer, aligns them to the deeper feature space, and reinjects them into later blocks so prompt constraints stay active through the full denoising stack.
64
+
65
+ <p align="center">
66
+ <img src="assets/prompt-reinjection.png" alt="Prompt Reinjection" width="320">
67
+ </p>
68
+
69
+ The method is training-free, lightweight, and easy to plug into the original MMDiT forward process. No retraining is required.
70
+
71
+ <p align="center">
72
+ <img src="assets/qualitative-compare.png" alt="Qualitative Comparison" width="640">
73
+ </p>
74
+
75
+ As shown above, Prompt Reinjection makes SD3.5, FLUX, and Qwen-Image follow prompt constraints more consistently across position, attribute, counting, and complex prompts. The paper reports consistent gains on GenEval, DPG-Bench, and T2I-CompBench++ while preserving overall generation quality.
76
+
77
+
78
+ ## 🚀 Quick Start
79
+
80
+ ### 1. Create an environment
81
+
82
+ We recommend Python `3.10+` and one environment per model family.
83
+
84
+ ```bash
85
+ python3.10 -m venv .venv
86
+ source .venv/bin/activate
87
+ pip install --upgrade pip
88
+ ```
89
+
90
+ Install dependencies with either an editable package install or the pinned root requirements:
91
+
92
+ ```bash
93
+ pip install -e .
94
+ ```
95
+
96
+ ```bash
97
+ pip install -r requirements.txt
98
+ ```
99
+
100
+ If you want a narrower per-model environment, install one model-specific requirement file instead:
101
+
102
+ ```bash
103
+ pip install -r requirements/sd3.txt
104
+ pip install -r requirements/sd3.5.txt
105
+ pip install -r requirements/flux.txt
106
+ pip install -r requirements/qwen.txt
107
+ pip install -r requirements/hunyuanimage.txt
108
+ ```
109
+
110
+ If you need a specific CUDA build, install `torch` and `torchvision` first from the official PyTorch channel, then rerun one of the commands above.
111
+
112
+ For HunyuanImage, install Tencent's official runtime before `requirements/hunyuanimage.txt`:
113
+
114
+ ```bash
115
+ git clone https://github.com/Tencent-Hunyuan/HunyuanImage-2.1.git
116
+ pip install -r HunyuanImage-2.1/requirements.txt
117
+ pip install flash-attn==2.7.3 --no-build-isolation
118
+ pip install -r requirements/hunyuanimage.txt
119
+ ```
120
+
121
+ ### 2. Prepare checkpoints
122
+
123
+ Pass model paths explicitly with `--model-path` in open-source usage.
124
+
125
+ - `sd3`, `sd3.5`, `flux`, `qwen`: `--model-path /path/to/model`
126
+ - `hunyuanimage`: `--model-path /path/to/HunyuanImage-2.1` and optional `--model-name`
127
+
128
+ For `hunyuanimage`, `--model-path` can point either to the HunyuanImage runtime root or to its `ckpts` directory.
129
+
130
+ ### 3. Run one prompt
131
+
132
+ The default helper script reads the released per-model inference and Prompt Reinjection settings from [prompt_reinjection/reinjection_configs.json](prompt_reinjection/reinjection_configs.json), so standard inference does not need manual residual arguments.
133
+
134
+ By default, all models now run without memory-saving inference shortcuts such as CPU offload, Hunyuan runtime offload, VAE slicing, VAE tiling, or attention slicing. This keeps the default path as close as possible to the original plain inference flow. These options are enabled only when you pass them explicitly.
135
+
136
+ ```bash
137
+ bash prompt_reinjection/test_reinjection.sh \
138
+ --model sd3 \
139
+ --model-path /path/to/SD3 \
140
+ --prompt "A photo of a couch below a potted plant."
141
+ ```
142
+
143
+ Supported `--model` values: `sd3`, `sd3.5`, `flux`, `qwen`, `hunyuanimage`.
144
+
145
+ To change the default released settings for a model, edit [prompt_reinjection/reinjection_configs.json](prompt_reinjection/reinjection_configs.json). The helper script and the benchmark entrypoints all read from the same file.
146
+
147
+ To run the plain base model instead of Prompt Reinjection:
148
+
149
+ ```bash
150
+ bash prompt_reinjection/test_reinjection.sh \
151
+ --model sd3 \
152
+ --model-path /path/to/SD3 \
153
+ --prompt "A photo of a couch below a potted plant." \
154
+ --reinjection off
155
+ ```
156
+
157
+ To enable memory-saving options explicitly:
158
+
159
+ - `flux` and `qwen`: pass `--cpu-offload model` or `--cpu-offload sequential`
160
+ - `flux` and `qwen`: optionally add `--vae-slicing`, `--vae-tiling`, or `--attention-slicing auto`
161
+ - `hunyuanimage`: pass `--enable-offload`
162
+
163
+ Example:
164
+
165
+ ```bash
166
+ bash prompt_reinjection/test_reinjection.sh \
167
+ --model flux \
168
+ --model-path /path/to/FLUX.1-dev \
169
+ --prompt "A photo of a couch below a potted plant." \
170
+ --cpu-offload model \
171
+ --vae-slicing
172
+ ```
173
+
174
+ ## 🧭 Manual SD3 Example
175
+
176
+ If you want to bypass the helper script and set the reinjection parameters manually, an SD3 example is:
177
+
178
+ ```bash
179
+ python -m prompt_reinjection.run_sample \
180
+ --model sd3 \
181
+ --model-path /path/to/SD3 \
182
+ --prompt "A photo of a couch below a potted plant." \
183
+ --output outputs/manual_sd3.png \
184
+ --steps 28 \
185
+ --cfg 7.0 \
186
+ --residual_origin_layer 1 \
187
+ --residual_target_layers $(seq 2 23) \
188
+ --residual_weights 0.025 \
189
+ --residual_use_anchoring 1 \
190
+ --residual_procrustes_path prompt_reinjection/rotations/sd3_coco5k_o1.pt
191
+ ```
192
+
193
+ The same default memory policy also applies to the Python entrypoints. If you do not pass a memory-saving flag explicitly, the run stays on the plain inference path. For example, FLUX and Qwen only enable Diffusers offload when you pass `--cpu-offload model` or `--cpu-offload sequential`, and HunyuanImage only enables its runtime offload when you pass `--enable-offload`.
194
+
195
+ ## 🧮 Procrustes Precomputation
196
+
197
+ We use COCO 5k for the released Procrustes statistics.
198
+
199
+ For open-source usage, we recommend:
200
+
201
+ - `sd3` and `flux`: use the released Procrustes-aligned Prompt Reinjection settings.
202
+ - `sd3.5`, `qwen`, and `hunyuanimage`: use the most basic Prompt Reinjection variant without anchoring and without rotation. It already works well and adds almost zero inference cost.
203
+
204
+ ### SD3
205
 
206
  ```bash
207
+ python SD3/compute.py \
208
+ --model /path/to/SD3 \
209
+ --dataset coco \
210
+ --datadir data \
211
+ --num-samples 5000 \
212
+ --origin-layer 1 \
213
+ --target-layer-start 2 \
214
+ --col-center \
215
+ --output outputs/procrustes_rotations/sd3_coco5k_o1.pt
216
  ```
217
 
218
+ ### FLUX
219
+
220
+ ```bash
221
+ python FLUX/compute.py \
222
+ --model /path/to/FLUX.1-dev \
223
+ --dataset coco \
224
+ --datadir data \
225
+ --num-samples 5000 \
226
+ --origin-layer 2 \
227
+ --target-layer-start 3 \
228
+ --col-center \
229
+ --output outputs/procrustes_rotations/flux_coco5k_o2.pt
230
+ ```
231
+
232
+ ## 🤗 Released Rotations
233
+
234
+ The released Procrustes rotations are hosted at [LewisYao/PromptReinjection](https://huggingface.co/fudan-generative-ai/PromptReinjection):
235
+
236
+ - [`sd3_coco5k_o1.pt`](https://huggingface.co/fudan-generative-ai/PromptReinjection/blob/main/sd3_coco5k_o1.pt)
237
+ - [`flux_coco5k_o2.pt`](https://huggingface.co/fudan-generative-ai/PromptReinjection/blob/main/flux_coco5k_o2.pt)
238
+
239
+ Download them to `prompt_reinjection/rotations/` with:
240
+
241
+ ```bash
242
+ hf download LewisYao/PromptReinjection \
243
+ sd3_coco5k_o1.pt \
244
+ flux_coco5k_o2.pt \
245
+ --local-dir prompt_reinjection/rotations
246
+ ```
247
+
248
+ The default helper script will pick them up automatically once they are placed under `prompt_reinjection/rotations/`.
249
+
250
+ ## 📊 Evaluation
251
+
252
+ The benchmark scripts below also read the released per-model defaults from [prompt_reinjection/reinjection_configs.json](prompt_reinjection/reinjection_configs.json). By default, they run with `--reinjection on`. Use `--reinjection off` for the plain base model, or edit the config file if you want to change the released settings globally.
253
+
254
+ Like the helper script, these benchmark entrypoints do not enable CPU offload, Hunyuan runtime offload, VAE slicing, VAE tiling, or attention slicing unless you pass those flags explicitly.
255
+
256
+ ### GenEval
257
+
258
+ ```bash
259
+ python -m prompt_reinjection.run_geneval \
260
+ --model sd3 \
261
+ --model-path /path/to/SD3 \
262
+ --metadata_file /path/to/geneval/metadata.jsonl \
263
+ --outdir outputs/geneval_sd3
264
+ ```
265
+
266
+ Base-model variant:
267
+
268
+ ```bash
269
+ python -m prompt_reinjection.run_geneval \
270
+ --model sd3 \
271
+ --model-path /path/to/SD3 \
272
+ --metadata_file /path/to/geneval/metadata.jsonl \
273
+ --outdir outputs/geneval_sd3_base \
274
+ --reinjection off
275
+ ```
276
+
277
+ ### DPG-Bench
278
+
279
+ ```bash
280
+ python -m prompt_reinjection.run_dpg \
281
+ --model sd3.5 \
282
+ --model-path /path/to/SD3.5-large \
283
+ --prompt_dir /path/to/dpg/prompts \
284
+ --save_dir outputs/dpg_sd35
285
+ ```
286
+
287
+ ### T2I-CompBench++
288
+
289
+ ```bash
290
+ python -m prompt_reinjection.run_t2i \
291
+ --model qwen \
292
+ --model-path /path/to/Qwen-Image \
293
+ --dataset_dir /path/to/t2i-compbench/prompts \
294
+ --outdir_base outputs/t2i_qwen
295
+ ```
296
+
297
+ These scripts generate benchmark-format images. Final scoring should still be done with the official benchmark evaluators.
298
+
299
+ ## 🧩 Apply to New Models
300
+
301
+ Prompt Reinjection is designed for MMDiT-style open-source models where text features evolve inside the denoising transformer together with visual features. If your new model follows this pattern, you can usually add a basic reinjection version with only a small amount of integration work.
302
+
303
+ To plug a new model into this framework, the minimum steps are:
304
+
305
+ - Add a new model folder with an `adapter.py` that implements the adapter interface used in [prompt_reinjection/adapter_api.py](prompt_reinjection/adapter_api.py) and follows the existing examples in [SD3/adapter.py](SD3/adapter.py), [FLUX/adapter.py](FLUX/adapter.py), and [Qwen/adapter.py](Qwen/adapter.py).
306
+ - Make the model pipeline or transformer expose `set_residual_config(...)` so it can receive `residual_origin_layer`, `residual_target_layers`, `residual_weights`, `residual_use_anchoring`, and `residual_rotation_matrices`, as shown in [SD3/pipeline.py](SD3/pipeline.py), [FLUX/pipeline.py](FLUX/pipeline.py), and [Qwen/pipeline.py](Qwen/pipeline.py).
307
+ - Register the new adapter in [prompt_reinjection/registry.py](prompt_reinjection/registry.py) so it becomes available to `run_sample`, `run_geneval`, `run_dpg`, `run_t2i`, and `compute_procrustes`.
308
+ - If you want rotation-based alignment later, also provide a model-specific `compute.py` and expose it as the adapter’s `compute_script`, so `python -m prompt_reinjection.compute_procrustes --model YOUR_MODEL ...` can dispatch correctly.
309
+
310
+ For a new model, we recommend starting from the most basic reinjection setting first:
311
+
312
+ - `origin = 1`
313
+ - `target = 2-last`
314
+ - `weight = 0.025`
315
+ - `no anchoring`
316
+ - `no rotation`
317
+
318
+ In practice, that means using the shallowest stable MMDiT block as the source, reinjecting into all later blocks, and keeping the setup as lightweight as possible. If your model has `L` blocks indexed from `0` to `L-1`, the default starting rule is:
319
+
320
+ ```text
321
+ residual_origin_layer = 1
322
+ residual_target_layers = [2, 3, ..., L-1]
323
+ residual_weights = 0.025
324
+ residual_use_anchoring = 0
325
+ residual_procrustes_path = ""
326
+ ```
327
+
328
+ A typical manual run looks like this after the model has been integrated into the registry:
329
+
330
+ ```bash
331
+ python -m prompt_reinjection.run_sample \
332
+ --model your_model \
333
+ --model-path /path/to/your/model \
334
+ --prompt "A photo of a couch below a potted plant." \
335
+ --output outputs/your_model_base_reinjection.png \
336
+ --residual_origin_layer 1 \
337
+ --residual_target_layers $(seq 2 LAST_LAYER) \
338
+ --residual_weights 0.025 \
339
+ --residual_use_anchoring 0 \
340
+ --residual_procrustes_path ""
341
+ ```
342
+
343
+ Replace `LAST_LAYER` with the last text-processing block index of your model. For example, if the model has 24 blocks indexed from `0` to `23`, use `$(seq 2 23)`.
344
+
345
+ Once this basic version runs and already improves instruction following, the next recommended upgrades are:
346
+
347
+ - Turn on anchoring first by setting `--residual_use_anchoring 1`. This is usually the safest first upgrade when the model shows cross-layer scale or shift mismatch.
348
+ - If you want further gains, add rotation-based alignment by computing a Procrustes file on a prompt set such as COCO-5K and passing it through `--residual_procrustes_path`. This is useful when shallow and deep text features differ not only in scale, but also in feature geometry.
349
+
350
+ In short, the recommended order for a new model is: first make basic reinjection work, then try anchoring, and only then add rotation if you want the strongest alignment.
351
+
352
+ ## 🔐 Security
353
+
354
+ Run a secret scan before any public push. `gitleaks` does not ship on PyPI, so install it from the official release or your system package manager instead of `pip install gitleaks`.
355
+
356
+ For a full Git history scan:
357
+
358
+ ```bash
359
+ gitleaks detect --source . -v
360
+ ```
361
+
362
+ For a working-tree-only scan:
363
+
364
+ ```bash
365
+ gitleaks detect --source . --no-git -v
366
+ ```
367
+
368
+ This repository also includes a `.pre-commit-config.yaml` entry for `gitleaks`:
369
+
370
+ ```bash
371
+ pip install pre-commit
372
+ pre-commit install
373
+ pre-commit run --all-files
374
+ ```
375
+
376
+ In addition to automated secret scanning, avoid committing local absolute paths, internal hostnames, or private checkpoint locations.
377
+
378
+ ## ⚖️ License
379
+
380
+ The code in this repository is released under the [MIT License](LICENSE).
381
+
382
+ Prompt Reinjection is an inference-time code intervention that plugs into external model pipelines. This repository does not include, redistribute, or relicense any third-party model weights. Upstream weights and runtimes remain subject to their original licenses, including:
383
+
384
+ - SD3 and SD3.5: Stability AI Community License
385
+ - FLUX.1-dev: FLUX.1 `[dev]` Non-Commercial License
386
+ - HunyuanImage-2.1: Tencent Hunyuan Community License Agreement
387
+ - Qwen-Image: Apache License 2.0
388
+
389
+ Users are responsible for ensuring that their checkpoint download, local use, fine-tuning, serving, and redistribution flows comply with the corresponding upstream terms. The MIT License in this repository applies only to this repository's code and included assets.
390
+
391
+ ## 📝 Citation
392
+
393
+ If you find this project useful, please cite the ICML 2026 paper:
394
+
395
+ ```bibtex
396
+ @inproceedings{yao2026prompt,
397
+ title={Prompt Reinjection: Alleviating Prompt Forgetting in Multimodal Diffusion Transformers},
398
+ author={Yao, Yuxuan and Chen, Yuxuan and Li, Hui and Cheng, Kaihui and Guo, Qipeng and Sun, Yuwei and Dong, Zilong and Wang, Jingdong and Zhu, Siyu},
399
+ booktitle={International Conference on Machine Learning (ICML)},
400
+ year={2026}
401
+ }
402
+ ```
403
+
404
+ ## 🤗 Acknowledgements
405
+
406
+ We thank the open-source communities behind SD3, SD3.5, FLUX, Qwen-Image, and HunyuanImage. This release builds on their public model and runtime ecosystems to make Prompt Reinjection reproducible in open source.