README.md CHANGED
@@ -1,3 +1,115 @@
- ---
- license: apache-2.0
- ---
+ # TAG-MoE: Task-Aware Gating for Unified Generative Mixture-of-Experts
+
+ > **TAG-MoE: Task-Aware Gating for Unified Generative Mixture-of-Experts**<br>
+ > Yu Xu<sup>1,2†</sup>, Hongbin Yan<sup>1</sup>, Juan Cao<sup>1</sup>, Yiji Cheng<sup>2</sup>, Tiankai Hang<sup>2</sup>, Runze He<sup>2</sup>, Zijin Yin<sup>2</sup>, Shiyi Zhang<sup>2</sup>, Yuxin Zhang<sup>1</sup>, Jintao Li<sup>1</sup>, Chunyu Wang<sup>2‡</sup>, Qinglin Lu<sup>2</sup>, Tong-Yee Lee<sup>3</sup>, Fan Tang<sup>1§</sup><br>
+ > <sup>1</sup>University of Chinese Academy of Sciences, <sup>2</sup>Tencent Hunyuan, <sup>3</sup>National Cheng-Kung University
+
+ <a href='https://arxiv.org/abs/2601.08881'><img src='https://img.shields.io/badge/ArXiv-2601.08881-red'></a>
+ <a href='https://yuci-gpt.github.io/TAG-MoE/'><img src='https://img.shields.io/badge/Project%20Page-Homepage-green'></a>
+ <a href='https://github.com/ICTMCG/TAG-MoE'><img src='https://img.shields.io/badge/github-repo-blue?logo=github'></a>
+
+ ![](https://raw.githubusercontent.com/yuci-gpt/TAG-MoE/refs/heads/master/static/images/teaser.png)
+
+ > **Abstract**:<br>
+ > Unified image generation and editing models suffer from severe task interference in dense diffusion transformer architectures, where a shared parameter space must compromise between conflicting objectives (e.g., local editing vs. subject-driven generation). While the sparse Mixture-of-Experts (MoE) paradigm is a promising solution, its gating networks remain task-agnostic: they route on local features alone, unaware of global task intent. This prevents meaningful specialization and leaves the underlying task interference unresolved. In this paper, we propose a novel framework that injects semantic intent into MoE routing. We introduce a Hierarchical Task Semantic Annotation scheme to create structured task descriptors (e.g., scope, type, preservation), and we design a Predictive Alignment Regularization that aligns internal routing decisions with the task's high-level semantics. This regularization turns the gating network from a task-agnostic executor into a task-aware dispatch center. Our model effectively mitigates task interference, outperforming dense baselines in fidelity and quality, and our analysis shows that experts naturally develop clear, semantically correlated specializations.
+
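The task-aware routing idea from the abstract can be sketched with a toy gate: each expert's logit combines a per-token feature with a global task-descriptor embedding, so identical tokens can be dispatched to different experts under different tasks. Everything below (weight values, dimensions, top-1 routing) is illustrative, not the released implementation.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of logits.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def route(token_feat, task_embed, w_local, w_task, top_k=1):
    """Toy task-aware gate: logit(e) = w_local[e]·token + w_task[e]·task."""
    n_experts = len(w_local)
    logits = []
    for e in range(n_experts):
        local = sum(w * x for w, x in zip(w_local[e], token_feat))
        task = sum(w * t for w, t in zip(w_task[e], task_embed))
        logits.append(local + task)
    probs = softmax(logits)
    ranked = sorted(range(n_experts), key=lambda e: -probs[e])
    return ranked[:top_k], probs

# Toy weights for 4 experts (matching moe_num_experts=4 in the released config).
w_local = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
w_task = [[2.0], [0.0], [0.0], [0.0]]
experts, probs = route([0.5, 0.2], [1.0], w_local, w_task)  # task term steers routing
```

With a zero task embedding the same token would be scored on local features only; the task term is what lets a "dispatch center" bias routing by task intent.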
+ ---
+
+ ## 🔧 Environment Setup
+
+ We recommend using [uv](https://docs.astral.sh/uv/) with the provided `pyproject.toml` / `uv.lock`.
+
+ ### 1. Install uv
+
+ ```bash
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ ```
+
+ ### 2. Create and activate virtual environment
+
+ ```bash
+ git clone https://github.com/ICTMCG/TAG-MoE.git && cd TAG-MoE
+ uv venv --python 3.12
+ source .venv/bin/activate
+ ```
+
+ ### 3. Install dependencies
+
+ ```bash
+ UV_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu126 uv sync
+ ```
+
+ ---
+
+ ## 📦 Model Weights
+
+ - **Base model**: [Qwen/Qwen-Image](https://huggingface.co/Qwen/Qwen-Image)
+ - **TAG-MoE weights**: [YUXU915/TAG-MoE](https://huggingface.co/YUXU915/TAG-MoE)
+
+ ---
+
+ ## 🚀 Inference
+
+ > **Note**: TAG-MoE inference requires **60GB+ available VRAM**.
+
+ ### Image Generation
+
+ ```bash
+ uv run python infer.py \
+   --pretrained_model_path Qwen/Qwen-Image \
+   --transformer_model_path YUXU915/TAG-MoE \
+   --device 0,1 \
+   --prompt "A cinematic portrait of a futuristic astronaut, ultra-detailed." \
+   --width 1024 \
+   --height 1024 \
+   --output result.png
+ ```
+
+ ### Image Editing
+
+ ```bash
+ uv run python infer.py \
+   --pretrained_model_path Qwen/Qwen-Image \
+   --transformer_model_path YUXU915/TAG-MoE \
+   --device 0,1 \
+   --image input.jpg \
+   --prompt "Add the red and teal-colored text 'TAG-MoE' onto the airship." \
+   --output result.png
+ ```
+
+ ### WebUI
+
+ ```bash
+ uv run python run_gradio.py \
+   --pretrained_model_path Qwen/Qwen-Image \
+   --transformer_model_path YUXU915/TAG-MoE \
+   --device 0,1
+ ```
+
+ ---
+
+ ## 📄 Citation
+
+ If you find this work useful, please consider citing:
+
+ ```bibtex
+ @misc{xu2026tagmoetaskawaregatingunified,
+   title={TAG-MoE: Task-Aware Gating for Unified Generative Mixture-of-Experts},
+   author={Yu Xu and Hongbin Yan and Juan Cao and Yiji Cheng and Tiankai Hang and Runze He and Zijin Yin and Shiyi Zhang and Yuxin Zhang and Jintao Li and Chunyu Wang and Qinglin Lu and Tong-Yee Lee and Fan Tang},
+   year={2026},
+   eprint={2601.08881},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV},
+   url={https://arxiv.org/abs/2601.08881},
+ }
+ ```
+
+ ---
+
+ ## 🙏 Acknowledgements
+
+ This project builds upon the following excellent open-source works:
+
+ - **Diffusers** — https://github.com/huggingface/diffusers
+ - **MegaBlocks** — https://github.com/databricks/megablocks
transformer/config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "_class_name": "QwenImageTransformer2DModel",
+   "_diffusers_version": "0.35.1",
+   "attention_head_dim": 128,
+   "axes_dims_rope": [
+     16,
+     56,
+     56
+   ],
+   "guidance_embeds": false,
+   "in_channels": 64,
+   "joint_attention_dim": 3584,
+   "num_attention_heads": 24,
+   "num_layers": 60,
+   "out_channels": 16,
+   "patch_size": 2,
+   "pooled_projection_dim": 768,
+   "scale_rope": true,
+   "tag_embedding_dim": 512,
+   "tag_vocab_size": 18,
+   "router_hidden_dim": 256,
+   "moe_num_experts": 4,
+   "runtime_moe_replace_from_layer": 50,
+   "runtime_moe_replace_num_layers": 10,
+   "runtime_moe_target": "transformer_blocks.*.img_mlp"
+ }
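Reading the runtime-MoE fields in this config literally, MoE replacement covers the last ten of the sixty transformer blocks. The few lines below derive that from the config values; the interpretation of the field names (and the per-layer expansion of the `*` glob) is an assumption, not taken from the actual loader code.

```python
# Runtime-MoE fields copied from transformer/config.json above.
config = {
    "num_layers": 60,
    "moe_num_experts": 4,
    "runtime_moe_replace_from_layer": 50,
    "runtime_moe_replace_num_layers": 10,
    "runtime_moe_target": "transformer_blocks.*.img_mlp",
}

start = config["runtime_moe_replace_from_layer"]
count = config["runtime_moe_replace_num_layers"]
moe_layers = list(range(start, start + count))  # layers 50..59

# Module names that would be swapped for MoE blocks, assuming the glob
# expands once per replaced layer (hypothetical expansion).
targets = [config["runtime_moe_target"].replace("*", str(i)) for i in moe_layers]

assert moe_layers[-1] < config["num_layers"]  # replacement stays in range
```

Under this reading, `transformer_blocks.50.img_mlp` through `transformer_blocks.59.img_mlp` become 4-expert MoE layers while the first fifty blocks stay dense.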
transformer/diffusion_pytorch_model-00001-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3fd12225bb30fd0cdd813d18b649bf7027ee4bfeee471d31b7bf74afcdc1d491
+ size 4817101176
transformer/diffusion_pytorch_model-00002-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4dd2be866d18a5c6271bef5a9df88284979bb251a76941264f101bf5933d2d21
+ size 4814308040
transformer/diffusion_pytorch_model-00003-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24e87d5ca9c5ba65b47bc4d6f4a1cea9bbee0cd0792436eb3f5d689e32beb096
+ size 4814306584
transformer/diffusion_pytorch_model-00004-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fc559053efbfc5e8308d38ff4071a99d67f48c1870342e7a4ae9de31546b308
+ size 4795444264
transformer/diffusion_pytorch_model-00005-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:624b9c59e47d75188b10cc028e38dbca4aab4ab95cf0648404268493f01ed010
+ size 4757664352
transformer/diffusion_pytorch_model-00006-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1275dad160dbd5d8cc29b10773cddf5f39c26bd8f7c7a5c54af1da3364e35f5a
+ size 4757664384
transformer/diffusion_pytorch_model-00007-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae292c74b963876bc9ba73d615ac8eb1bdf4003cb6f9928a5e0ea4a8e746185e
+ size 4719649200
transformer/diffusion_pytorch_model-00008-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:14a5eb07f6232c340a13cf0c83fbc059c14d66820bd39d9b98a7886be31b82a5
+ size 4794904752
transformer/diffusion_pytorch_model-00009-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eae85d1ca09ecf0288a6d7e546ef848be7059b0ef867c177a5767b624f5b06f7
+ size 4794932600
transformer/diffusion_pytorch_model-00010-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d20439b2a51b1177689708432bbfcb81e84e6ea0723dab0c8ea3d9b59a1b1f9e
+ size 2325317296
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff