Yana-Hangabina committed
Commit 201d119 · 1 Parent(s): 15ebffa

Release online demo (#1)


- Add Space-only TAG-MoE demo files (5fd79e9e5e8c19cf4331bd7032513e54449736dd)
- Remove cached pyc from Space branch (e7ff53c37003024226f12a82aae43b181739dc64)
- Add Space README metadata header (6428707556418f3ebf89879804f2e61e892cf6f1)


Co-authored-by: Yana-Hangabina <Yana-Hangabina@users.noreply.huggingface.co>

README.md CHANGED
@@ -1,14 +1,15 @@
 ---
-title: TAG MoE
-emoji: 📚
+title: TAG-MoE
+emoji: 🎨
 colorFrom: red
 colorTo: gray
 sdk: gradio
-sdk_version: 6.1.0
+sdk_version: 5.49.1
+python_version: 3.10
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: TAG-MoE:Task-Aware Gating for Unified Generative Mixture-of-
+short_description: Task-Aware Gating for Unified Generative Mixture-of-Experts
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+TAG-MoE Space demo.
app.py ADDED
@@ -0,0 +1,315 @@
+import os
+import threading
+
+import gradio as gr
+
+from src.utils.device_utils import resolve_device_ids
+from src.utils.inference_config import (
+    DEFAULT_HEIGHT,
+    DEFAULT_NEGATIVE_PROMPT,
+    DEFAULT_NUM_INFERENCE_STEPS,
+    DEFAULT_SEED,
+    DEFAULT_TRUE_CFG_SCALE,
+    DEFAULT_WIDTH,
+    generate_random_seed,
+)
+
+try:
+    import spaces
+except ImportError:
+    spaces = None
+
+
+def _env_bool(name: str, default: bool = False) -> bool:
+    value = os.getenv(name)
+    if value is None:
+        return default
+    return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _env_int(name: str, default: int) -> int:
+    value = os.getenv(name)
+    if value is None or not value.strip():
+        return default
+    return int(value.strip())
+
+
+PRETRAINED_MODEL_PATH = os.getenv("PRETRAINED_MODEL_PATH", "Qwen/Qwen-Image")
+TRANSFORMER_MODEL_PATH = os.getenv("TRANSFORMER_MODEL_PATH", "YUXU915/TAG-MoE")
+TRANSFORMER_WEIGHT_NAME = os.getenv("TRANSFORMER_WEIGHT_NAME", "diffusion_pytorch_model.safetensors")
+TRANSFORMER_SUBFOLDER = os.getenv("TRANSFORMER_SUBFOLDER", "transformer")
+TRANSFORMER_REVISION = os.getenv("TRANSFORMER_REVISION", "").strip() or None
+LOCAL_FILES_ONLY = _env_bool("LOCAL_FILES_ONLY", default=False)
+TAGMOE_DEVICE = os.getenv("TAGMOE_DEVICE", "auto").strip().lower()
+ZERO_GPU_DURATION = _env_int("ZERO_GPU_DURATION", default=300)
+
+LINKS_HTML = """
+<div class="tagmoe-links">
+  <a href="https://yuci-gpt.github.io/TAG-MoE/" target="_blank" rel="noopener noreferrer">Project Homepage</a>
+  <a href="https://arxiv.org/abs/2601.08881" target="_blank" rel="noopener noreferrer">Paper (arXiv)</a>
+  <a href="https://github.com/ICTMCG/TAG-MoE" target="_blank" rel="noopener noreferrer">GitHub Repo</a>
+  <a href="https://huggingface.co/YUXU915/TAG-MoE" target="_blank" rel="noopener noreferrer">Model Weights</a>
+</div>
+"""
+
+_RUNTIME_LOCK = threading.Lock()
+_PIPELINE = None
+_BASE64_TO_IMAGE_FN = None
+
+
+def _resolve_runtime_device_ids():
+    if TAGMOE_DEVICE in {"", "auto", "default"}:
+        import torch
+
+        return [0] if torch.cuda.is_available() else []
+    if TAGMOE_DEVICE in {"none", "framework"}:
+        return None
+    return resolve_device_ids(TAGMOE_DEVICE)
+
+
+def _ensure_runtime_loaded():
+    global _PIPELINE, _BASE64_TO_IMAGE_FN
+
+    if _PIPELINE is not None and _BASE64_TO_IMAGE_FN is not None:
+        return _PIPELINE, _BASE64_TO_IMAGE_FN
+
+    with _RUNTIME_LOCK:
+        if _PIPELINE is not None and _BASE64_TO_IMAGE_FN is not None:
+            return _PIPELINE, _BASE64_TO_IMAGE_FN
+
+        from src.infer_tagmoe import End2End, base64_to_image
+
+        device_ids = _resolve_runtime_device_ids()
+        _PIPELINE = End2End(
+            pretrained_model_path=PRETRAINED_MODEL_PATH,
+            transformer_model_path=TRANSFORMER_MODEL_PATH,
+            device_ids=device_ids,
+            transformer_weight_name=TRANSFORMER_WEIGHT_NAME,
+            transformer_subfolder=TRANSFORMER_SUBFOLDER,
+            transformer_revision=TRANSFORMER_REVISION,
+            local_files_only=LOCAL_FILES_ONLY,
+        )
+        _BASE64_TO_IMAGE_FN = base64_to_image
+        return _PIPELINE, _BASE64_TO_IMAGE_FN
+
+
+class LazyPipelineProxy:
+    def predict(self, input_dict):
+        pipeline, _ = _ensure_runtime_loaded()
+        return pipeline.predict(input_dict)
+
+
+def _lazy_base64_to_image(data):
+    _, base64_to_image_fn = _ensure_runtime_loaded()
+    return base64_to_image_fn(data)
+
+
+def _infer_decorator():
+    if spaces is None:
+        return lambda fn: fn
+    return spaces.GPU(duration=ZERO_GPU_DURATION)
+
+
+def build_demo(gr, pipeline, base64_to_image_fn):
+    def infer(
+        image,
+        prompt,
+        negative_prompt,
+        seed,
+        gen_width,
+        gen_height,
+        cfg_scale,
+        inference_steps,
+    ):
+        if prompt is None or not str(prompt).strip():
+            raise gr.Error("Prompt cannot be empty.")
+        if image is None:
+            raise gr.Error("Image is required.")
+
+        width_value = int(gen_width) if gen_width is not None else int(image.size[0])
+        height_value = int(gen_height) if gen_height is not None else int(image.size[1])
+        input_dict = {
+            "image": image.convert("RGB"),
+            "prompt": str(prompt).strip(),
+            "negative_prompt": str(negative_prompt or DEFAULT_NEGATIVE_PROMPT),
+            "seed": int(seed if seed is not None else DEFAULT_SEED),
+            "target_width": width_value,
+            "target_height": height_value,
+            "true_cfg_scale": float(cfg_scale),
+            "num_inference_steps": int(inference_steps),
+            "keep_original_size": False,
+        }
+        result = pipeline.predict(input_dict)
+        out_image = base64_to_image_fn(result["generate_imgs_buffer"][0])
+        return out_image, int(result["seed"])
+
+    def randomize_seed():
+        return generate_random_seed()
+
+    def on_image_upload(image):
+        if image is None:
+            return gr.update(), gr.update()
+        return int(image.size[0]), int(image.size[1])
+
+    title_html = """
+    <div class="tagmoe-header">
+      <picture>
+        <source srcset="https://raw.githubusercontent.com/yuci-gpt/TAG-MoE/refs/heads/master/static/images/logo_dark.png" media="(prefers-color-scheme: dark)">
+        <img src="https://raw.githubusercontent.com/yuci-gpt/TAG-MoE/refs/heads/master/static/images/logo_light.png" alt="TAG-MoE logo">
+      </picture>
+      <div>
+        <h1>TAG-MoE</h1>
+        <p>Task-Aware Gating for Unified Generative Mixture-of-Experts</p>
+      </div>
+    </div>
+    """
+
+    custom_css = """
+    .tagmoe-header {
+        display: flex;
+        align-items: center;
+        gap: 12px;
+        margin-bottom: 8px;
+    }
+    .tagmoe-header img {
+        width: 48px;
+        height: 48px;
+        object-fit: contain;
+    }
+    .tagmoe-header h1 {
+        margin: 0;
+        font-size: 1.8rem;
+    }
+    .tagmoe-header p {
+        margin: 0;
+        opacity: 0.85;
+        font-size: 0.95rem;
+    }
+    .param-card {
+        border: 1px solid var(--border-color-primary);
+        border-radius: 12px;
+        padding: 14px 14px 10px;
+        margin-bottom: 10px;
+    }
+    .param-card .gradio-textbox textarea {
+        min-height: 110px !important;
+    }
+    .run-btn button {
+        height: 46px !important;
+        font-weight: 600;
+    }
+    .image-panel {
+        border: 1px solid var(--border-color-primary);
+        border-radius: 12px;
+        padding: 10px;
+    }
+    .tool-btn {
+        margin-top: 28px !important;
+        min-width: 42px !important;
+        height: 42px !important;
+        padding: 0 !important;
+        display: flex;
+        align-items: center;
+        justify-content: center;
+        flex-shrink: 0;
+    }
+    .tagmoe-links {
+        margin: 6px 0 14px 0;
+        display: flex;
+        flex-wrap: wrap;
+        gap: 12px;
+        font-size: 0.95rem;
+    }
+    .tagmoe-links a {
+        text-decoration: none;
+    }
+    """
+
+    infer_fn = _infer_decorator()(infer)
+    with gr.Blocks(title="TAG-MoE Space Demo", css=custom_css) as demo:
+        gr.HTML(title_html)
+        gr.HTML(LINKS_HTML)
+
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=1, elem_classes=["image-panel"]):
+                image_input = gr.Image(type="pil", label="Input Image", height=520)
+            with gr.Column(scale=1, elem_classes=["image-panel"]):
+                image_output = gr.Image(type="pil", label="Output Image", height=520)
+
+        with gr.Group(elem_classes=["param-card"]):
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                placeholder="Describe the instruction",
+                lines=3,
+            )
+            negative_prompt_input = gr.Textbox(
+                label="Negative Prompt",
+                value=DEFAULT_NEGATIVE_PROMPT,
+                lines=2,
+                placeholder="Optional negative prompt",
+            )
+            with gr.Row():
+                gen_width_input = gr.Slider(minimum=64, maximum=4096, step=1, value=DEFAULT_WIDTH, label="Width")
+                gen_height_input = gr.Slider(minimum=64, maximum=4096, step=1, value=DEFAULT_HEIGHT, label="Height")
+            with gr.Row():
+                cfg_scale_input = gr.Slider(
+                    minimum=1.0,
+                    maximum=10.0,
+                    step=0.1,
+                    value=DEFAULT_TRUE_CFG_SCALE,
+                    label="CFG Scale",
+                )
+                inference_steps_input = gr.Slider(
+                    minimum=10,
+                    maximum=100,
+                    step=1,
+                    value=DEFAULT_NUM_INFERENCE_STEPS,
+                    label="Inference Steps",
+                )
+                with gr.Column(scale=1, min_width=200):
+                    with gr.Row():
+                        seed_input = gr.Number(
+                            label="Seed",
+                            value=generate_random_seed(),
+                            precision=0,
+                            scale=1,
+                        )
+                        random_seed_btn = gr.Button(
+                            "🎲",
+                            elem_classes=["tool-btn"],
+                            scale=0,
+                            min_width=42,
+                            variant="secondary",
+                        )
+                    run_btn = gr.Button("Run Inference", variant="primary", elem_classes=["run-btn"])
+
+        run_btn.click(
+            fn=infer_fn,
+            inputs=[
+                image_input,
+                prompt_input,
+                negative_prompt_input,
+                seed_input,
+                gen_width_input,
+                gen_height_input,
+                cfg_scale_input,
+                inference_steps_input,
+            ],
+            outputs=[image_output, seed_input],
+        )
+        image_input.change(
+            fn=on_image_upload,
+            inputs=[image_input],
+            outputs=[gen_width_input, gen_height_input],
+        )
+        random_seed_btn.click(fn=randomize_seed, outputs=[seed_input])
+
+    return demo
+
+
+demo = build_demo(gr, LazyPipelineProxy(), _lazy_base64_to_image)
+demo.queue(default_concurrency_limit=1, max_size=8)
+
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
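
The app above defers all heavy imports and model construction until the first request: `_ensure_runtime_loaded` uses double-checked locking around module-level singletons, so the Gradio UI starts instantly and the `End2End` pipeline is built exactly once even under concurrent requests. A minimal standalone sketch of that pattern (the `_load_resource` stand-in is hypothetical):

```python
import threading

_LOCK = threading.Lock()
_RESOURCE = None


def _load_resource():
    # Hypothetical stand-in for the expensive End2End construction in app.py.
    return object()


def get_resource():
    global _RESOURCE
    # Fast path: already initialized, skip the lock entirely.
    if _RESOURCE is not None:
        return _RESOURCE
    with _LOCK:
        # Re-check under the lock so only one thread performs the load.
        if _RESOURCE is None:
            _RESOURCE = _load_resource()
    return _RESOURCE


assert get_resource() is get_resource()  # the singleton is created exactly once
```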
requirements.txt ADDED
@@ -0,0 +1,17 @@
+--extra-index-url https://download.pytorch.org/whl/cu126
+
+accelerate==1.10.1
+diffusers @ git+https://github.com/huggingface/diffusers.git@0e12ba74542c6ecb02719ec3e5c6e993b85556e3
+gradio>=5.49.1,<6
+grouped-gemm==0.3.0
+loguru>=0.7.3
+megablocks==0.10.0
+numpy<2.1.0
+pillow>=12.1.1
+qwen-vl-utils>=0.0.14
+safetensors>=0.7.0
+spaces>=0.35.0
+torch==2.7.0
+torchvision==0.22.0
+transformers==4.56.2
+triton>=3.3.0
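
These pins target CUDA 12.6 wheels (hence the extra index URL), with megablocks and grouped-gemm supplying the MoE kernels. A small sanity-check sketch, assuming the environment was installed from this file:

```python
# Sanity-check sketch for the pins above; run after `pip install -r requirements.txt`.
import importlib.metadata as md

import torch

for pkg, pin in [("torch", "2.7.0"), ("torchvision", "0.22.0"), ("transformers", "4.56.2")]:
    installed = md.version(pkg)
    assert installed == pin, f"{pkg}: expected {pin}, got {installed}"

# The cu126 extra index only matters when a CUDA device is actually present.
print("CUDA available:", torch.cuda.is_available())
```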
src/infer_tagmoe.py ADDED
@@ -0,0 +1,325 @@
+import base64
+import io
+import os
+import time
+from functools import partial
+
+from loguru import logger
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+
+from src.utils.device_utils import build_accelerate_max_memory_map
+from src.utils.inference_config import (
+    DEFAULT_NEGATIVE_PROMPT,
+    DEFAULT_NUM_INFERENCE_STEPS,
+    DEFAULT_SEED,
+    DEFAULT_TRUE_CFG_SCALE,
+    generate_random_seed,
+    normalize_negative_prompt,
+)
+from src.models.transformer_qwenimage_tagmoe import QwenImageTransformer2DModel, TRANSFORMER_NUM_LAYERS, MOE_NUM_EXPERTS
+from src.pipelines.pipeline_qwenimage_tagmoe import QwenImagePipeline
+
+
+def image_to_byte_array(image: Image.Image) -> bytes:
+    img_byte_arr = io.BytesIO()
+    image.save(img_byte_arr, format="PNG")
+    return img_byte_arr.getvalue()
+
+
+def image_to_base64(image: Image.Image) -> str:
+    return base64.b64encode(image_to_byte_array(image)).decode()
+
+
+def base64_to_image(base64_str: str) -> Image.Image:
+    return Image.open(io.BytesIO(base64.b64decode(base64_str))).convert("RGB")
+
+
+PREFERRED_QWENIMAGE_RESOLUTIONS = [
+    (512, 2048),
+    (512, 1984),
+    (512, 1920),
+    (512, 1856),
+    (512, 1792),
+    (512, 1728),
+    (512, 1664),
+    (512, 1600),
+    (512, 1536),
+    (576, 1472),
+    (640, 1408),
+    (704, 1344),
+    (768, 1280),
+    (832, 1216),
+    (896, 1152),
+    (960, 1088),
+    (1024, 1024),
+    (1088, 960),
+    (1152, 896),
+    (1216, 832),
+    (1280, 768),
+    (1344, 704),
+    (1408, 640),
+    (1472, 576),
+    (1536, 512),
+    (1600, 512),
+    (1664, 512),
+    (1728, 512),
+    (1792, 512),
+    (1856, 512),
+    (1920, 512),
+    (1984, 512),
+    (2048, 512),
+]
+
+
+QWEN_IMAGE_TRANSFORMER_BLOCK_DIM = 3072
+SEMANTIC_DIM = 512
+TAG_DICT = {
+    "local editing": 0,
+    "global editing": 1,
+    "multi region editing": 2,
+    "viewpoint editing": 3,
+    "content customization": 4,
+    "style customization": 5,
+    "object editing": 6,
+    "attribute editing": 7,
+    "style transfer": 8,
+    "pose editing": 9,
+    "background editing": 10,
+    "illumination editing": 11,
+    "structure preservation": 12,
+    "background preservation": 13,
+    "identity preservation": 14,
+    "face preservation": 15,
+    "style preservation": 16,
+    "image generation": 17,
+}
+
+
+class PredictionHead(nn.Module):
+    def __init__(self, gating_dim: int = 4, semantic_dim: int = 512, hidden_dim: int = 256):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(gating_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, semantic_dim),
+        )
+
+    def forward(self, g: torch.Tensor) -> torch.Tensor:
+        return self.net(g)
+
+
+class End2End:
+    def __init__(
+        self,
+        pretrained_model_path,
+        transformer_model_path=None,
+        rank=0,
+        device_ids=None,
+        transformer_weight_name: str = "diffusion_pytorch_model.safetensors",
+        transformer_subfolder: str | None = "transformer",
+        transformer_revision: str | None = None,
+        local_files_only: bool = False,
+    ):
+        self.device_ids = self._resolve_device_ids(rank, device_ids)
+        self.is_multi_gpu = len(self.device_ids) > 1
+
+        self.device, self.generator_device, torch_dtype = self._resolve_runtime_device()
+        transformer = self._build_runtime_transformer(pretrained_model_path, torch_dtype)
+
+        self.pipe = QwenImagePipeline.from_pretrained(
+            pretrained_model_path,
+            transformer=transformer,
+            torch_dtype=torch_dtype,
+        )
+
+        self.pipe.init_custom(
+            transformer_model_path,
+            weight_name=transformer_weight_name,
+            subfolder=transformer_subfolder,
+            revision=transformer_revision,
+            local_files_only=local_files_only,
+        )
+        if self.is_multi_gpu:
+            self._enable_multi_gpu_dispatch(torch_dtype=torch_dtype)
+        else:
+            self.pipe = self.pipe.to(self.device)
+
+    @staticmethod
+    def _resolve_device_ids(rank, device_ids):
+        if device_ids is None:
+            return [rank] if torch.cuda.is_available() else []
+        return list(device_ids)
+
+    def _resolve_runtime_device(self):
+        if len(self.device_ids) > 0 and torch.cuda.is_available():
+            primary_gpu = self.device_ids[0]
+            torch.cuda.set_device(primary_gpu)
+            device = f"cuda:{primary_gpu}"
+            return device, device, torch.bfloat16
+        return "cpu", "cpu", torch.float32
+
+    def _build_runtime_transformer(self, pretrained_model_path, torch_dtype):
+        transformer = QwenImageTransformer2DModel.from_pretrained(
+            pretrained_model_path,
+            subfolder="transformer",
+            torch_dtype=torch_dtype,
+        )
+        self._replace_mlp_with_runtime_moe(transformer)
+        self._attach_tag_modules(transformer)
+        return transformer
+
+    def _build_moe_args(self):
+        from megablocks.layers.arguments import Arguments
+
+        return Arguments(
+            hidden_size=QWEN_IMAGE_TRANSFORMER_BLOCK_DIM,
+            ffn_hidden_size=QWEN_IMAGE_TRANSFORMER_BLOCK_DIM * 4,
+            num_layers=TRANSFORMER_NUM_LAYERS,
+            bias=True,
+            activation_fn=partial(F.gelu, approximate="tanh"),
+            moe_num_experts=MOE_NUM_EXPERTS,
+            moe_top_k=1,
+            moe_loss_weight=0.01,
+            moe_capacity_factor=1.25,
+            mlp_type="mlp",
+            shared_expert=False,
+            mlp_impl="grouped",
+            init_method=nn.init.xavier_uniform_,
+            moe_expert_model_parallelism=False,
+            expert_parallel_group=None,
+            fp16=False,
+            bf16=True,
+            device=self.device,
+        )
+
+    def _replace_mlp_with_runtime_moe(self, transformer):
+        from megablocks.layers.dmoe import dMoE
+
+        moe_args = self._build_moe_args()
+        replace_from_layer = 60 - TRANSFORMER_NUM_LAYERS
+        replace_paths = []
+        for name, _ in transformer.named_modules():
+            if not name.startswith("transformer_blocks.") or not name.endswith("img_mlp"):
+                continue
+            block_idx = int(name.split(".")[1])
+            if block_idx >= replace_from_layer:
+                replace_paths.append(name)
+
+        for path in replace_paths:
+            parent_name, child_name = path.rsplit(".", 1)
+            parent_module = transformer.get_submodule(parent_name)
+            setattr(parent_module, child_name, dMoE(moe_args))
+
+    def _attach_tag_modules(self, transformer):
+        transformer.tag_embedding = nn.Embedding(len(TAG_DICT), SEMANTIC_DIM)
+        transformer.router_head = PredictionHead(
+            gating_dim=MOE_NUM_EXPERTS,
+            semantic_dim=SEMANTIC_DIM,
+            hidden_dim=256,
+        )
+
+    def _enable_multi_gpu_dispatch(self, torch_dtype):
+        from accelerate import dispatch_model, infer_auto_device_map
+
+        free_bytes_by_device = {}
+        for device_id in self.device_ids:
+            free_bytes, _ = torch.cuda.mem_get_info(device_id)
+            free_bytes_by_device[device_id] = free_bytes
+        max_memory = build_accelerate_max_memory_map(self.device_ids, free_bytes_by_device)
+
+        transformer_device_map = infer_auto_device_map(
+            self.pipe.transformer,
+            max_memory=max_memory,
+            no_split_module_classes=["QwenImageTransformerBlock"],
+            dtype=torch_dtype,
+        )
+
+        offload_dir = None
+        if any(device == "disk" for device in transformer_device_map.values()):
+            offload_dir = os.path.join("/tmp", "tag_moe_offload")
+            os.makedirs(offload_dir, exist_ok=True)
+
+        self.pipe.transformer = dispatch_model(
+            self.pipe.transformer,
+            device_map=transformer_device_map,
+            offload_dir=offload_dir,
+        )
+
+        text_encoder_device = f"cuda:{self.device_ids[-1]}"
+        self.pipe.text_encoder = self.pipe.text_encoder.to(text_encoder_device)
+        self.pipe.vae = self.pipe.vae.to(self.device)
+
+    def predict(self, input_dict):
+        out_dict = {}
+
+        start_time = time.time()
+        image = input_dict.get("image")
+        if image is None:
+            raise ValueError("Input image is required.")
+        seed = int(input_dict.get("seed", DEFAULT_SEED))
+        prompt = input_dict.get("prompt", "")
+        negative_prompt = normalize_negative_prompt(
+            input_dict.get("negative_prompt", DEFAULT_NEGATIVE_PROMPT)
+        )
+        num_inference_steps = int(
+            input_dict.get("num_inference_steps", DEFAULT_NUM_INFERENCE_STEPS)
+        )
+        true_cfg_scale = float(
+            input_dict.get("true_cfg_scale", DEFAULT_TRUE_CFG_SCALE)
+        )
+        target_height = input_dict.get("target_height", None)
+        target_width = input_dict.get("target_width", None)
+        keep_original_size = bool(input_dict.get("keep_original_size", False))
+        has_custom_target = target_height is not None or target_width is not None
+
+        if seed < 0:
+            seed = generate_random_seed()
+        out_dict["seed"] = seed
+
+        cond_image = image
+        w_ori, h_ori = cond_image.size
+        original_size = (w_ori, h_ori)
+
+        white_bg = Image.new("RGB", cond_image.size, (255, 255, 255))
+        if cond_image.mode == "RGBA":
+            result = Image.alpha_composite(white_bg.convert("RGBA"), cond_image)
+            cond_image = result.convert("RGB")
+        else:
+            cond_image = cond_image.convert("RGB")
+
+        aspect_ratio = w_ori / h_ori
+        _, snap_width, snap_height = min(
+            (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
+        )
+        cond_image = cond_image.resize((snap_width, snap_height), Image.LANCZOS)
+
+        if target_height is None:
+            target_height = snap_height
+        if target_width is None:
+            target_width = snap_width
+
+        out_image_pil = self.pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            width=target_width,
+            height=target_height,
+            num_inference_steps=num_inference_steps,
+            true_cfg_scale=true_cfg_scale,
+            generator=torch.Generator(device=self.generator_device).manual_seed(seed),
+            cond_image=cond_image,
+        ).images[0]
+
+        if keep_original_size and original_size is not None and not has_custom_target:
+            out_image_pil = out_image_pil.resize(original_size, Image.LANCZOS)
+
+        out_dict["generate_imgs_buffer"] = [image_to_base64(out_image_pil)]
+        logger.info(f"Generation time: {time.time() - start_time:.2f}s")
+        return out_dict
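
Before generation, `predict` snaps the conditioning image to the entry of `PREFERRED_QWENIMAGE_RESOLUTIONS` whose aspect ratio is closest to the input's, by minimizing over `(|ratio difference|, w, h)` tuples. A standalone sketch of that selection step (the 1600x900 input size is a hypothetical example and only a subset of the buckets is listed):

```python
# Standalone sketch of the resolution-snapping step in End2End.predict.
PREFERRED = [(512, 2048), (768, 1280), (1024, 1024), (1280, 768), (2048, 512)]

w_ori, h_ori = 1600, 900               # hypothetical input image size (16:9)
aspect_ratio = w_ori / h_ori           # ~1.778

# min() over (distance, w, h) tuples picks the bucket whose w/h is closest;
# exact ties fall back to comparing w, then h, by normal tuple ordering.
_, snap_w, snap_h = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED)
print(snap_w, snap_h)                  # -> 1280 768 for this subset
```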
src/models/transformer_qwenimage_tagmoe.py ADDED
@@ -0,0 +1,761 @@
+# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import functools
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import FeedForward
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.attention_processor import Attention
+from diffusers.models.cache_utils import CacheMixin
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
+
+from megablocks.layers.moe import MoE
+from megablocks.layers.dmoe import dMoE
+from megablocks.layers.arguments import Arguments
+
+from src.utils.device_utils import maybe_set_cuda_device_from_tensor
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+TRANSFORMER_NUM_LAYERS = 10
+TRANSFORMER_BLOCK_BAR = 60 - TRANSFORMER_NUM_LAYERS
+MOE_NUM_EXPERTS = 4
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+) -> torch.Tensor:
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+    Args:
+        timesteps (torch.Tensor):
+            a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        embedding_dim (int):
+            the dimension of the output.
+        flip_sin_to_cos (bool):
+            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
+        downscale_freq_shift (float):
+            Controls the delta between frequencies between dimensions.
+        scale (float):
+            Scaling factor applied to the embeddings.
+        max_period (int):
+            Controls the maximum frequency of the embeddings.
+
+    Returns:
+        torch.Tensor: an [N x dim] Tensor of positional embeddings.
+    """
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+    )
+    exponent = exponent / (half_dim - downscale_freq_shift)
+
+    emb = torch.exp(exponent).to(timesteps.dtype)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+def apply_rotary_emb_qwen(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensors contain rotary embeddings and are returned as real tensors.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
+        freqs_cis (`Tuple[torch.Tensor]`):
+            Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        cos = cos[None, None]
+        sin = sin[None, None]
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Used for flux, cogvideox, hunyuan-dit
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+        return out
+    else:
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(1)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+        return x_out.type_as(x)
+
+
+class QwenTimestepProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+    def forward(self, timestep, hidden_states):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)
+
+        conditioning = timesteps_emb
+
+        return conditioning
+
+
+class QwenEmbedRope(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+        self.neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+        self.rope_cache = {}
+        self.cond_rope_cache = {}
+
+        # Whether to use scaled RoPE
+        self.scale_rope = scale_rope
+
+    def rope_params(self, index, dim, theta=10000):
+        """
+        Args:
+            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
+        """
+        assert dim % 2 == 0
+        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs
+
+    def forward(self, video_fhw, txt_seq_lens, device):
+        """
+        Args:
+            video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
+            txt_seq_lens: [bs] a list of integers representing the length of the text
+        """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        frame, height, width = video_fhw
+        rope_key = f"{frame}_{height}_{width}"
+
+        if rope_key not in self.rope_cache:
+            seq_lens = frame * height * width
+            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+            if self.scale_rope:
+                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2):], freqs_pos[1][: height // 2]], dim=0)
+                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2):], freqs_pos[2][: width // 2]], dim=0)
+                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+            else:
+                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+            self.rope_cache[rope_key] = freqs.clone().contiguous()
+        vid_freqs = self.rope_cache[rope_key]
+
+        if self.scale_rope:
+            max_vid_index = max(height // 2, width // 2)
+        else:
+            max_vid_index = max(height, width)
+
+        max_len = max(txt_seq_lens)
+        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+
+        return vid_freqs, txt_freqs
+
+    def get_img_rope(self, video_fhw, device, frame_idx=0):
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        frame, height, width = video_fhw
+        rope_key = f"{frame}_{height}_{width}_{frame_idx}"
+
+        assert frame == 1
+
+        if rope_key not in self.cond_rope_cache:
+            seq_lens = frame * height * width
+            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+            freqs_frame = freqs_pos[0][frame_idx : frame_idx + 1].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+            if self.scale_rope:
+                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2):], freqs_pos[1][: height // 2]], dim=0)
+                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2):], freqs_pos[2][: width // 2]], dim=0)
+                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+            else:
+                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+            self.cond_rope_cache[rope_key] = freqs.clone().contiguous()
+        vid_freqs = self.cond_rope_cache[rope_key]
+
+        return vid_freqs
+
+    def get_img_rope_by_bbox(self, video_fhw, bbox, device):
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        frame, height, width = video_fhw
+
+        x1, y1, x2, y2 = bbox
+
+        x0 = -(width - width // 2)
+        y0 = -(height - height // 2)
+
+        seq_lens = frame * ((y2 - y1) + 1) * ((x2 - x1) + 1)
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, (y2 - y1) + 1, (x2 - x1) + 1, -1)
+
+        index_height_neg = [y + y0 for y in range(y1, y2 + 1, 1) if (y + y0) < 0]
+        index_height_pos = [y + y0 for y in range(y1, y2 + 1, 1) if (y + y0) >= 0]
+        freqs_height = torch.cat([freqs_neg[1][index_height_neg], freqs_pos[1][index_height_pos]], dim=0)
+        freqs_height = freqs_height.view(1, (y2 - y1) + 1, 1, -1).expand(frame, (y2 - y1) + 1, (x2 - x1) + 1, -1)
+
+        index_width_neg = [x + x0 for x in range(x1, x2 + 1, 1) if (x + x0) < 0]
+        index_width_pos = [x + x0 for x in range(x1, x2 + 1, 1) if (x + x0) >= 0]
+        freqs_width = torch.cat([freqs_neg[2][index_width_neg], freqs_pos[2][index_width_pos]], dim=0)
+        freqs_width = freqs_width.view(1, 1, (x2 - x1) + 1, -1).expand(frame, (y2 - y1) + 1, (x2 - x1) + 1, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        vid_freqs = freqs
+
+        return vid_freqs
+
+
+class QwenDoubleStreamAttnProcessor2_0:
+    """
+    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
+    implements joint attention computation where text and image streams are processed together.
+    """
+
+    _attention_backend = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,  # Image stream
+        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
+        encoder_hidden_states_mask: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        if encoder_hidden_states is None:
+            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
+
+        seq_txt = encoder_hidden_states.shape[1]
+
+        # Compute QKV for image stream (sample projections)
+        img_query = attn.to_q(hidden_states)
+        img_key = attn.to_k(hidden_states)
+        img_value = attn.to_v(hidden_states)
+
+        # Compute QKV for text stream (context projections)
+        txt_query = attn.add_q_proj(encoder_hidden_states)
+        txt_key = attn.add_k_proj(encoder_hidden_states)
+        txt_value = attn.add_v_proj(encoder_hidden_states)
+
+        # Reshape for multi-head attention
+        img_query = img_query.unflatten(-1, (attn.heads, -1))
+        img_key = img_key.unflatten(-1, (attn.heads, -1))
+        img_value = img_value.unflatten(-1, (attn.heads, -1))
+
+        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+        txt_value = txt_value.unflatten(-1, (attn.heads, -1))
+
+        # Apply QK normalization
+        if attn.norm_q is not None:
+            img_query = attn.norm_q(img_query)
+        if attn.norm_k is not None:
+            img_key = attn.norm_k(img_key)
+        if attn.norm_added_q is not None:
+            txt_query = attn.norm_added_q(txt_query)
+        if attn.norm_added_k is not None:
+            txt_key = attn.norm_added_k(txt_key)
+
+        # Apply RoPE
+        if image_rotary_emb is not None:
+            img_freqs, txt_freqs = image_rotary_emb
+            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
+            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
+            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
+            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+
+        # Concatenate for joint attention
+        # Order: [text, image]
+        joint_query = torch.cat([txt_query, img_query], dim=1)
+        joint_key = torch.cat([txt_key, img_key], dim=1)
+        joint_value = torch.cat([txt_value, img_value], dim=1)
+
+        # Compute joint attention
+        joint_hidden_states = dispatch_attention_fn(
+            joint_query,
+            joint_key,
+            joint_value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+        )
+
+        # Reshape back
+        joint_hidden_states = joint_hidden_states.flatten(2, 3)
+        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
+
+        # Split attention outputs back
+        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
+        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+
+        # Apply output projections
+        img_attn_output = attn.to_out[0](img_attn_output)
+        if len(attn.to_out) > 1:
+            img_attn_output = attn.to_out[1](img_attn_output)  # dropout
+
+        txt_attn_output = attn.to_add_out(txt_attn_output)
+
+        return img_attn_output, txt_attn_output
+
+
+@maybe_allow_in_graph
+class QwenImageTransformerBlock(nn.Module):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, block_index: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+
+        # Image processing modules
+        self.img_mod = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
+        )
+        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,  # Enable cross attention for joint computation
+            added_kv_proj_dim=dim,  # Enable added KV projections for text stream
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            context_pre_only=False,
+            bias=True,
+            processor=QwenDoubleStreamAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=eps,
+        )
+        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        self.block_index = block_index
+        if block_index < TRANSFORMER_BLOCK_BAR:  # Replace the last part of the layers with MoE
+            self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+        else:
+            self.moe_args = Arguments(
+                hidden_size=dim,
+                ffn_hidden_size=dim * 4,  # Keep ffn_hidden_size consistent with FeedForward mult=4
+                num_layers=TRANSFORMER_NUM_LAYERS,  # Number of MoE layers
+                bias=True,
+                activation_fn=partial(F.gelu, approximate="tanh"),  # Keep consistent with FeedForward
+                moe_num_experts=MOE_NUM_EXPERTS,  # Number of experts; adjust as needed
+                moe_top_k=1,  # Top-k experts per token (1 means top-1)
+                moe_loss_weight=0.01,  # Load balancing loss weight
+                moe_capacity_factor=1.25,  # Capacity factor for handling load imbalance
+                mlp_type="mlp",
+                shared_expert=False,  # Do not use shared experts
+                mlp_impl="grouped",  # Use 'grouped' implementation
+                init_method=nn.init.xavier_uniform_,
+                memory_optimized_mlp=True,  # Optimize MLP activation memory
+                moe_expert_model_parallelism=False,
+                expert_parallel_group=None,
+                fp16=False,
+                bf16=True,
+            )
+            self.img_mlp = dMoE(self.moe_args)
+
+        # Text processing modules
+        self.txt_mod = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
+        )
+        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        # Text doesn't need separate attention - it's handled by the joint attention computation
+        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def _modulate(self, x, mod_params):
+        """Apply modulation to input tensor"""
+        shift, scale, gate = mod_params.chunk(3, dim=-1)
+        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_mask: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        img_shapes=None,
+        timestep=None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        maybe_set_cuda_device_from_tensor(hidden_states)
+
+        # Get modulation parameters for both streams
+        img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+        txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+        # Split modulation parameters for norm1 and norm2
+        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+
+        # Process image stream - norm1 + modulation
+        img_normed = self.img_norm1(hidden_states)
+        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+
+        # Process text stream - norm1 + modulation
+        txt_normed = self.txt_norm1(encoder_hidden_states)
+        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+
+        # Use QwenDoubleStreamAttnProcessor2_0 for joint attention computation
+        # This directly implements the DoubleStreamLayerMegatron logic:
+        # 1. Computes QKV for both streams
+        # 2. Applies QK normalization and RoPE
+        # 3. Concatenates and runs joint attention
+        # 4. Splits results back to separate streams
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        attn_output = self.attn(
+            hidden_states=img_modulated,  # Image stream (will be processed as "sample")
+            encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
+            encoder_hidden_states_mask=encoder_hidden_states_mask,
+            image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+        )
+
+        # The processor returns (img_output, txt_output) when encoder_hidden_states is provided
+        img_attn_output, txt_attn_output = attn_output
+
+        # Apply attention gates and add residual (like in Megatron)
+        hidden_states = hidden_states + img_gate1 * img_attn_output
+        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+
+        # Process image stream - norm2 + MLP
+        img_normed2 = self.img_norm2(hidden_states)
+        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+
+        if self.block_index < TRANSFORMER_BLOCK_BAR:
+            img_mlp_output = self.img_mlp(img_modulated2)
+        else:
+            # dMoE.forward returns (output, bias) due to return_bias=True default
+            img_mlp_output = self.img_mlp(img_modulated2)[0]
+        hidden_states = hidden_states + img_gate2 * img_mlp_output
+
+        # Process text stream - norm2 + MLP
+        txt_normed2 = self.txt_norm2(encoder_hidden_states)
+        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+        txt_mlp_output = self.txt_mlp(txt_modulated2)
+        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
+
+        # Clip to prevent overflow for fp16
+        if encoder_hidden_states.dtype == torch.float16:
+            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return encoder_hidden_states, hidden_states
+
+
+class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+    """
+    The Transformer model introduced in Qwen.
+
+    Args:
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `60`):
+            The number of layers of dual stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `3584`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for the guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["QwenImageTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 64,
+        out_channels: Optional[int] = 16,
+        num_layers: int = 60,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        joint_attention_dim: int = 3584,
+        guidance_embeds: bool = False,  # TODO: this should probably be removed
+        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+    ):
+        super().__init__()
+        self.out_channels = out_channels or in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+
+        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
+
+        self.img_in = nn.Linear(in_channels, self.inner_dim)
+        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QwenImageTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    block_index=block_index,
+                )
+                for block_index in range(num_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_mask: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+        txt_seq_lens: Optional[List[int]] = None,
+        guidance: torch.Tensor = None,  # TODO: this should probably be removed
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        cond_hidden_states=None,
+        cond_rope=None,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        """
+        The [`QwenImageTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
+                Mask of the input conditions.
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        if cond_hidden_states is not None:
+            length_raw_hidden_states = hidden_states.shape[1]
+            hidden_states = torch.cat([hidden_states, cond_hidden_states], dim=1)
+
+        hidden_states = self.img_in(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype)
+        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+        encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+
+        temb = (
+            self.time_text_embed(timestep, hidden_states)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, hidden_states)
+        )
+
+        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        if cond_rope is not None:
+            img_freqs, txt_freqs = image_rotary_emb
+            img_freqs = torch.cat([img_freqs, cond_rope], dim=0)
+            image_rotary_emb = img_freqs, txt_freqs
+
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    encoder_hidden_states_mask,
+                    temb,
+                    image_rotary_emb,
+                    attention_kwargs,
+                )
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    img_shapes=img_shapes,
+                    timestep=timestep,
+                    joint_attention_kwargs=attention_kwargs,
+                )
+
+        if cond_hidden_states is not None:
+            hidden_states = hidden_states[:, :length_raw_hidden_states]
+
+        # Use only the image part (hidden_states) from the dual-stream blocks
+        hidden_states = self.norm_out(hidden_states, temb)
+        output = self.proj_out(hidden_states)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
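
With `TRANSFORMER_NUM_LAYERS = 10` and the model's default `num_layers = 60`, `TRANSFORMER_BLOCK_BAR = 50`: blocks 0-49 keep the dense `FeedForward` as their `img_mlp`, while blocks 50-59 construct a megablocks `dMoE` with `MOE_NUM_EXPERTS = 4` experts and top-1 routing. A tiny sketch of the split the constructor applies:

```python
# Sketch of the dense-vs-MoE split implied by TRANSFORMER_BLOCK_BAR above.
NUM_LAYERS = 60
TRANSFORMER_NUM_LAYERS = 10                                   # layers converted to MoE
TRANSFORMER_BLOCK_BAR = NUM_LAYERS - TRANSFORMER_NUM_LAYERS   # = 50

layer_kind = [
    "dense" if block_index < TRANSFORMER_BLOCK_BAR else "dMoE"
    for block_index in range(NUM_LAYERS)
]
assert layer_kind.count("dMoE") == TRANSFORMER_NUM_LAYERS
assert layer_kind[49] == "dense" and layer_kind[50] == "dMoE"
```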
src/pipelines/pipeline_qwenimage_tagmoe.py ADDED
@@ -0,0 +1,1068 @@
1
+ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import os
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+ from PIL import Image
19
+
20
+ import numpy as np
21
+ import torch
22
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoProcessor
23
+
24
+ from diffusers.image_processor import VaeImageProcessor
25
+ from diffusers.loaders import QwenImageLoraLoaderMixin
26
+ from diffusers.models import AutoencoderKLQwenImage
27
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
28
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
32
+ from qwen_vl_utils import process_vision_info
33
+
34
+ from src.models.transformer_qwenimage_tagmoe import QwenImageTransformer2DModel
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```py
49
+ >>> import torch
50
+ >>> from diffusers import QwenImagePipeline
51
+
52
+ >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
53
+ >>> pipe.to("cuda")
54
+ >>> prompt = "A cat holding a sign that says hello world"
55
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
56
+ >>> # Refer to the pipeline documentation for more details.
57
+ >>> image = pipe(prompt, num_inference_steps=50).images[0]
58
+ >>> image.save("qwenimage.png")
59
+ ```
60
+ """
61
+
62
+
63
+ def calculate_shift(
64
+ image_seq_len,
65
+ base_seq_len: int = 256,
66
+ max_seq_len: int = 4096,
67
+ base_shift: float = 0.5,
68
+ max_shift: float = 1.15,
69
+ ):
70
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
71
+ b = base_shift - m * base_seq_len
72
+ mu = image_seq_len * m + b
73
+ return mu
74
+
75
+
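
`calculate_shift` linearly interpolates the flow-matching timestep shift `mu` between `base_shift` and `max_shift` as the packed image sequence length grows; longer sequences get a larger shift. A worked check with the defaults (a 1024x1024 image at `vae_scale_factor` 8 with 2x2 packing yields `(1024 // 16) ** 2 = 4096` tokens):

```py
assert abs(calculate_shift(256) - 0.5) < 1e-9    # base_seq_len  -> base_shift
assert abs(calculate_shift(4096) - 1.15) < 1e-9  # max_seq_len   -> max_shift
assert abs(calculate_shift(1024) - 0.63) < 1e-9  # in-between lengths interpolate linearly
```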
76
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
77
+ def retrieve_timesteps(
78
+ scheduler,
79
+ num_inference_steps: Optional[int] = None,
80
+ device: Optional[Union[str, torch.device]] = None,
81
+ timesteps: Optional[List[int]] = None,
82
+ sigmas: Optional[List[float]] = None,
83
+ **kwargs,
84
+ ):
85
+ r"""
86
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
87
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
88
+
89
+ Args:
90
+ scheduler (`SchedulerMixin`):
91
+ The scheduler to get timesteps from.
92
+ num_inference_steps (`int`):
93
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
94
+ must be `None`.
95
+ device (`str` or `torch.device`, *optional*):
96
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
97
+ timesteps (`List[int]`, *optional*):
98
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
99
+ `num_inference_steps` and `sigmas` must be `None`.
100
+ sigmas (`List[float]`, *optional*):
101
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
102
+ `num_inference_steps` and `timesteps` must be `None`.
103
+
104
+ Returns:
105
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
106
+ second element is the number of inference steps.
107
+ """
108
+ if timesteps is not None and sigmas is not None:
109
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
110
+ if timesteps is not None:
111
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
112
+ if not accepts_timesteps:
113
+ raise ValueError(
114
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
115
+ f" timestep schedules. Please check whether you are using the correct scheduler."
116
+ )
117
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
118
+ timesteps = scheduler.timesteps
119
+ num_inference_steps = len(timesteps)
120
+ elif sigmas is not None:
121
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
122
+ if not accept_sigmas:
123
+ raise ValueError(
124
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
125
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
126
+ )
127
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
128
+ timesteps = scheduler.timesteps
129
+ num_inference_steps = len(timesteps)
130
+ else:
131
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ return timesteps, num_inference_steps
134
+
135
+
136
+ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
137
+ r"""
138
+ The QwenImage pipeline for text-to-image generation.
139
+
140
+ Args:
141
+ transformer ([`QwenImageTransformer2DModel`]):
142
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
143
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
144
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
145
+ vae ([`AutoencoderKL`]):
146
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
147
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
148
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), a multimodal large
149
+ language model used here as the text encoder.
150
+ tokenizer (`Qwen2Tokenizer`):
151
+ Tokenizer of class
152
+ [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2).
153
+ """
154
+
155
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
156
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
157
+
158
+ def __init__(
159
+ self,
160
+ scheduler: FlowMatchEulerDiscreteScheduler,
161
+ vae: AutoencoderKLQwenImage,
162
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
163
+ tokenizer: Qwen2Tokenizer,
164
+ transformer: QwenImageTransformer2DModel,
165
+ vlm_processor: AutoProcessor = None,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ tokenizer=tokenizer,
173
+ transformer=transformer,
174
+ scheduler=scheduler,
175
+ vlm_processor=vlm_processor,
176
+ )
177
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
178
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
179
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
181
+ self.tokenizer_max_length = 1024
182
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
183
+ self.prompt_template_encode_start_idx = 34
184
+ self.default_sample_size = 128
185
+
186
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
187
+ bool_mask = mask.bool()
188
+ valid_lengths = bool_mask.sum(dim=1)
189
+ selected = hidden_states[bool_mask]
190
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
191
+
192
+ return split_result
193
+
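
`_extract_masked_hidden` undoes batch padding: it keeps only the positions where the attention mask is 1 and returns one variable-length tensor per sample. A small illustration (invoking the method unbound, since it never touches `self`):

```py
import torch

h = torch.arange(12.0).reshape(2, 3, 2)         # [batch=2, seq=3, dim=2]
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
a, b = QwenImagePipeline._extract_masked_hidden(None, h, mask)
assert a.shape == (2, 2) and b.shape == (1, 2)  # 2 valid tokens vs. 1
```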
194
+ @staticmethod
195
+ def _get_module_input_device(module):
196
+ device_map = getattr(module, "hf_device_map", None)
197
+ if device_map is not None:
198
+ for mapped_device in device_map.values():
199
+ if mapped_device in ("cpu", "disk"):
200
+ continue
201
+ if isinstance(mapped_device, int):
202
+ return torch.device("cuda", mapped_device)
203
+ return torch.device(mapped_device)
204
+ return next(module.parameters()).device
205
+
206
+ def _get_qwen_prompt_embeds(
207
+ self,
208
+ prompt: Union[str, List[str]] = None,
209
+ device: Optional[torch.device] = None,
210
+ dtype: Optional[torch.dtype] = None,
211
+ ):
212
+ device = device or self._execution_device
213
+ dtype = dtype or self.text_encoder.dtype
214
+
215
+ prompt = [prompt] if isinstance(prompt, str) else prompt
216
+
217
+ template = self.prompt_template_encode
218
+ drop_idx = self.prompt_template_encode_start_idx
219
+ txt = [template.format(e) for e in prompt]
220
+ text_encoder_device = self._get_module_input_device(self.text_encoder)
221
+ txt_tokens = self.tokenizer(
222
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
223
+ ).to(text_encoder_device)
224
+ encoder_hidden_states = self.text_encoder(
225
+ input_ids=txt_tokens.input_ids,
226
+ attention_mask=txt_tokens.attention_mask,
227
+ output_hidden_states=True,
228
+ )
229
+ hidden_states = encoder_hidden_states.hidden_states[-1]
230
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
231
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
232
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
233
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
234
+ prompt_embeds = torch.stack(
235
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
236
+ )
237
+ encoder_attention_mask = torch.stack(
238
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
239
+ )
240
+
241
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
242
+
243
+ return prompt_embeds, encoder_attention_mask
244
+
245
+ def _get_qwenvl_prompt_embeds(
246
+ self,
247
+ prompt: Union[str, List[str]] = None,
248
+ device: Optional[torch.device] = None,
249
+ dtype: Optional[torch.dtype] = None,
250
+ image: Optional[Image.Image] = None,
251
+ ):
252
+ device = device or self._execution_device
253
+ dtype = dtype or self.text_encoder.dtype
254
+
255
+ prompt = [prompt] if isinstance(prompt, str) else prompt
256
+ assert len(prompt) == 1
257
+
258
+ template = self.prompt_template_encode
259
+ drop_idx = self.prompt_template_encode_start_idx
260
+
261
+ messages = [
262
+ {
263
+ "role": "system",
264
+ "content": [{"type": "text", "text": f"{template}"}],
265
+ },
266
+ {
267
+ "role": "user",
268
+ "content": []
269
+ }
270
+ ]
271
+
272
+ # 先添加所有的 image
273
+ # messages[0]["content"].extend([{"type": "image", "image": img} for img in image_list])
274
+ messages[1]['content'].append({"type": "image", "image": image})
275
+ # print(text)
276
+ # 再添加 text
277
+ messages[1]["content"].append({"type": "text", "text": f"{prompt[0]}"})
278
+
279
+ # Preparation for inference
280
+ text = self.vlm_processor.apply_chat_template(
281
+ messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
282
+ )
283
+
284
+ image_inputs, video_inputs = process_vision_info(messages)
285
+
286
+ kwargs = dict(truncation=True, padding=True, max_length=self.tokenizer_max_length + drop_idx + 374, return_tensors="pt")
287
+ txt_tokens = self.vlm_processor(
288
+ text=[text],
289
+ images=image_inputs,
290
+ **kwargs,
291
+ )
292
+
293
+ text_encoder_device = self._get_module_input_device(self.text_encoder)
294
+ encoder_hidden_states = self.text_encoder(
295
+ input_ids=txt_tokens.input_ids.to(text_encoder_device),
296
+ attention_mask=txt_tokens.attention_mask.to(text_encoder_device),
297
+ pixel_values=txt_tokens.pixel_values.to(text_encoder_device),
298
+ image_grid_thw=txt_tokens.image_grid_thw.to(text_encoder_device),
299
+ output_hidden_states=True
300
+ )
301
+
304
+ hidden_states = encoder_hidden_states.hidden_states[-1]
305
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
306
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
307
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
308
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
309
+ prompt_embeds = torch.stack(
310
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
311
+ )
312
+ encoder_attention_mask = torch.stack(
313
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
314
+ )
315
+
316
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
317
+
318
+ return prompt_embeds, encoder_attention_mask
319
+
320
+ def encode_prompt(
321
+ self,
322
+ prompt: Union[str, List[str]],
323
+ device: Optional[torch.device] = None,
324
+ num_images_per_prompt: int = 1,
325
+ prompt_embeds: Optional[torch.Tensor] = None,
326
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
327
+ max_sequence_length: int = 1024,
328
+ image=None,
329
+ ):
330
+ r"""
331
+
332
+ Args:
333
+ prompt (`str` or `List[str]`, *optional*):
334
+ prompt to be encoded
335
+ device: (`torch.device`):
336
+ torch device
337
+ num_images_per_prompt (`int`):
338
+ number of images that should be generated per prompt
339
+ prompt_embeds (`torch.Tensor`, *optional*):
340
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
341
+ provided, text embeddings will be generated from `prompt` input argument.
342
+ """
343
+ device = device or self._execution_device
344
+
345
+ prompt = [prompt] if isinstance(prompt, str) else prompt
346
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
347
+
348
+ if image is not None:
349
+ if self.vlm_processor is None:
350
+ raise ValueError(
351
+ "VLM processor is not initialized. Please make sure to pass a valid VLM processor to the pipeline."
352
+ )
353
+ prompt_embeds, prompt_embeds_mask = self._get_qwenvl_prompt_embeds(
354
+ prompt=prompt, device=device, dtype=self.text_encoder.dtype, image=image
355
+ )
356
+ elif prompt_embeds is None:
357
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
358
+
359
+ _, seq_len, _ = prompt_embeds.shape
360
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
361
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
362
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
363
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
364
+
365
+ return prompt_embeds, prompt_embeds_mask
366
+
367
+ def check_inputs(
368
+ self,
369
+ prompt,
370
+ height,
371
+ width,
372
+ negative_prompt=None,
373
+ prompt_embeds=None,
374
+ negative_prompt_embeds=None,
375
+ prompt_embeds_mask=None,
376
+ negative_prompt_embeds_mask=None,
377
+ callback_on_step_end_tensor_inputs=None,
378
+ max_sequence_length=None,
379
+ ):
380
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
381
+ logger.warning(
382
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
383
+ )
384
+
385
+ if callback_on_step_end_tensor_inputs is not None and not all(
386
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
387
+ ):
388
+ raise ValueError(
389
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
390
+ )
391
+
392
+ if prompt is not None and prompt_embeds is not None:
393
+ raise ValueError(
394
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
395
+ " only forward one of the two."
396
+ )
397
+ elif prompt is None and prompt_embeds is None:
398
+ raise ValueError(
399
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
400
+ )
401
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
402
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
403
+
404
+ if negative_prompt is not None and negative_prompt_embeds is not None:
405
+ raise ValueError(
406
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
407
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
408
+ )
409
+
410
+ if prompt_embeds is not None and prompt_embeds_mask is None:
411
+ raise ValueError(
412
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also has to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
413
+ )
414
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
415
+ raise ValueError(
416
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also has to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
417
+ )
418
+
419
+ if max_sequence_length is not None and max_sequence_length > 1024:
420
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
421
+
422
+ @staticmethod
423
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
424
+ latent_image_ids = torch.zeros(height, width, 3)
425
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
426
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
427
+
428
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
429
+
430
+ latent_image_ids = latent_image_ids.reshape(
431
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
432
+ )
433
+
434
+ return latent_image_ids.to(device=device, dtype=dtype)
435
+
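
`_prepare_latent_image_ids` builds a flat table of 3-channel position ids for the packed latent grid: the first channel stays zero, the other two hold the row and column index. For a 2x2 grid:

```py
import torch

ids = QwenImagePipeline._prepare_latent_image_ids(1, 2, 2, "cpu", torch.float32)
assert ids.tolist() == [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]]
```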
436
+ @staticmethod
437
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
438
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
439
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
440
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
441
+
442
+ return latents
443
+
444
+ @staticmethod
445
+ def _unpack_latents(latents, height, width, vae_scale_factor):
446
+ batch_size, num_patches, channels = latents.shape
447
+
448
+ # VAE applies 8x compression on images but we must also account for packing which requires
449
+ # latent height and width to be divisible by 2.
450
+ height = 2 * (int(height) // (vae_scale_factor * 2))
451
+ width = 2 * (int(width) // (vae_scale_factor * 2))
452
+
453
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
454
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
455
+
456
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
457
+
458
+ return latents
459
+
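
`_pack_latents` folds every 2x2 latent patch into the channel dimension and flattens the grid into a token sequence; `_unpack_latents` is its exact inverse (plus the singleton frame axis the video VAE expects). A round-trip shape check with illustrative sizes (16 latent channels, a 1024x1024 image, `vae_scale_factor` 8, hence a 128x128 latent grid):

```py
import torch

lat = torch.randn(1, 16, 128, 128)
packed = QwenImagePipeline._pack_latents(lat, 1, 16, 128, 128)
assert packed.shape == (1, 4096, 64)           # (128/2)*(128/2) tokens of 16*4 channels

unpacked = QwenImagePipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=8)
assert unpacked.shape == (1, 16, 1, 128, 128)
assert torch.equal(unpacked[:, :, 0], lat)     # exact round trip
```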
460
+ def enable_vae_slicing(self):
461
+ r"""
462
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
463
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
464
+ """
465
+ self.vae.enable_slicing()
466
+
467
+ def disable_vae_slicing(self):
468
+ r"""
469
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
470
+ computing decoding in one step.
471
+ """
472
+ self.vae.disable_slicing()
473
+
474
+ def enable_vae_tiling(self):
475
+ r"""
476
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
477
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
478
+ processing larger images.
479
+ """
480
+ self.vae.enable_tiling()
481
+
482
+ def disable_vae_tiling(self):
483
+ r"""
484
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
485
+ computing decoding in one step.
486
+ """
487
+ self.vae.disable_tiling()
488
+
489
+ def prepare_latents(
490
+ self,
491
+ batch_size,
492
+ num_channels_latents,
493
+ height,
494
+ width,
495
+ dtype,
496
+ device,
497
+ generator,
498
+ latents=None,
499
+ ):
500
+ # VAE applies 8x compression on images but we must also account for packing which requires
501
+ # latent height and width to be divisible by 2.
502
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
503
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
504
+
505
+ shape = (batch_size, 1, num_channels_latents, height, width)
506
+
507
+ if latents is not None:
508
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
509
+ return latents.to(device=device, dtype=dtype), latent_image_ids
510
+
511
+ if isinstance(generator, list) and len(generator) != batch_size:
512
+ raise ValueError(
513
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
514
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
515
+ )
516
+
517
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
518
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
519
+
520
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
521
+
522
+ return latents, latent_image_ids
523
+
524
+ @staticmethod
525
+ def _candidate_index_names(weight_name: Optional[str]) -> List[str]:
526
+ candidate_names = []
527
+ if weight_name:
528
+ if weight_name.endswith(".index.json"):
529
+ candidate_names.append(weight_name)
530
+ else:
531
+ candidate_names.append(f"{weight_name}.index.json")
532
+
533
+ for default_name in (
534
+ "diffusion_pytorch_model.safetensors.index.json",
535
+ "diffusion_pytorch_model.bin.index.json",
536
+ ):
537
+ if default_name not in candidate_names:
538
+ candidate_names.append(default_name)
539
+ return candidate_names
540
+
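
`_candidate_index_names` expands a weight file name into the list of sharded-checkpoint index files to probe, always appending the diffusers defaults as fallbacks:

```py
names = QwenImagePipeline._candidate_index_names("model.safetensors")
assert names == [
    "model.safetensors.index.json",
    "diffusion_pytorch_model.safetensors.index.json",
    "diffusion_pytorch_model.bin.index.json",
]
```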
541
+ @staticmethod
542
+ def _dedupe_paths(paths: List[str]) -> List[str]:
543
+ deduped_paths = []
544
+ seen = set()
545
+ for path in paths:
546
+ normalized = os.path.normpath(path)
547
+ if normalized in seen:
548
+ continue
549
+ deduped_paths.append(path)
550
+ seen.add(normalized)
551
+ return deduped_paths
552
+
553
+ def _resolve_custom_weights_files(
554
+ self,
555
+ weight_source: str,
556
+ weight_name: str = "diffusion_pytorch_model.safetensors",
557
+ subfolder: Optional[str] = "transformer",
558
+ cache_dir: Optional[str] = None,
559
+ revision: Optional[str] = None,
560
+ local_files_only: bool = False,
561
+ ) -> tuple[List[str], Optional[str]]:
562
+ from diffusers.utils.hub_utils import _get_checkpoint_shard_files, _get_model_file
563
+
564
+ if os.path.isfile(weight_source):
565
+ return [weight_source], None
566
+
567
+ index_name_candidates = self._candidate_index_names(weight_name)
568
+ normalized_subfolder = subfolder or ""
569
+
570
+ if os.path.isdir(weight_source):
571
+ candidate_paths: List[str] = []
572
+ if weight_name:
573
+ candidate_paths.append(os.path.join(weight_source, weight_name))
574
+ if subfolder and weight_name:
575
+ candidate_paths.append(os.path.join(weight_source, subfolder, weight_name))
576
+ candidate_paths = self._dedupe_paths(candidate_paths)
577
+ for candidate in candidate_paths:
578
+ if os.path.isfile(candidate):
579
+ return [candidate], None
580
+
581
+ candidate_index_paths: List[str] = []
582
+ for index_name in index_name_candidates:
583
+ candidate_index_paths.append(os.path.join(weight_source, index_name))
584
+ if subfolder:
585
+ candidate_index_paths.append(os.path.join(weight_source, subfolder, index_name))
586
+ candidate_index_paths = self._dedupe_paths(candidate_index_paths)
587
+
588
+ for index_path in candidate_index_paths:
589
+ if not os.path.isfile(index_path):
590
+ continue
591
+ shard_subfolder = os.path.relpath(os.path.dirname(index_path), weight_source)
592
+ if shard_subfolder == ".":
593
+ shard_subfolder = ""
594
+ shard_files, _ = _get_checkpoint_shard_files(
595
+ pretrained_model_name_or_path=weight_source,
596
+ index_filename=index_path,
597
+ subfolder=shard_subfolder,
598
+ local_files_only=True,
599
+ )
600
+ return shard_files, index_path
601
+
602
+ raise FileNotFoundError(
603
+ f"Cannot find transformer weights under directory '{weight_source}'. "
604
+ f"Tried files: {candidate_paths}. Tried index files: {candidate_index_paths}"
605
+ )
606
+
607
+ try:
608
+ resolved_file = _get_model_file(
609
+ pretrained_model_name_or_path=weight_source,
610
+ weights_name=weight_name,
611
+ subfolder=subfolder,
612
+ cache_dir=cache_dir,
613
+ local_files_only=local_files_only,
614
+ revision=revision,
615
+ )
616
+ return [resolved_file], None
617
+ except EnvironmentError as single_file_error:
618
+ for index_name in index_name_candidates:
619
+ try:
620
+ index_file = _get_model_file(
621
+ pretrained_model_name_or_path=weight_source,
622
+ weights_name=index_name,
623
+ subfolder=subfolder,
624
+ cache_dir=cache_dir,
625
+ local_files_only=local_files_only,
626
+ revision=revision,
627
+ )
628
+ except EnvironmentError:
629
+ continue
630
+
631
+ shard_files, _ = _get_checkpoint_shard_files(
632
+ pretrained_model_name_or_path=weight_source,
633
+ index_filename=index_file,
634
+ cache_dir=cache_dir,
635
+ local_files_only=local_files_only,
636
+ revision=revision,
637
+ subfolder=normalized_subfolder,
638
+ )
639
+ return shard_files, index_file
640
+
641
+ raise single_file_error
642
+
643
+ @staticmethod
644
+ def _unwrap_state_dict(checkpoint: Any) -> Dict[str, torch.Tensor]:
645
+ if not isinstance(checkpoint, dict):
646
+ return checkpoint
647
+
648
+ for key in ("model", "state_dict", "transformer"):
649
+ value = checkpoint.get(key)
650
+ if isinstance(value, dict):
651
+ return value
652
+ return checkpoint
653
+
654
+ def init_custom(
655
+ self,
656
+ weight_source: Optional[str],
657
+ weight_name: str = "diffusion_pytorch_model.safetensors",
658
+ subfolder: Optional[str] = "transformer",
659
+ cache_dir: Optional[str] = None,
660
+ revision: Optional[str] = None,
661
+ local_files_only: bool = False,
662
+ ):
663
+ if weight_source is None:
664
+ return
665
+
666
+ weights_files, index_file = self._resolve_custom_weights_files(
667
+ weight_source=weight_source,
668
+ weight_name=weight_name,
669
+ subfolder=subfolder,
670
+ cache_dir=cache_dir,
671
+ revision=revision,
672
+ local_files_only=local_files_only,
673
+ )
674
+
675
+ from safetensors.torch import load_file
676
+
677
+ all_unexpected_keys = []
678
+ for weights_file in weights_files:
679
+ if weights_file.endswith(".safetensors"):
680
+ model_weights = load_file(weights_file)
681
+ else:
682
+ try:
683
+ checkpoint = torch.load(weights_file, weights_only=True, map_location="cpu")
684
+ except TypeError:
685
+ checkpoint = torch.load(weights_file, map_location="cpu")
686
+ model_weights = self._unwrap_state_dict(checkpoint)
687
+
688
+ load_result = self.transformer.load_state_dict(model_weights, strict=False, assign=True)
689
+ if len(load_result.unexpected_keys) > 0:
690
+ all_unexpected_keys.extend(load_result.unexpected_keys)
691
+ del model_weights
692
+
693
+ if index_file is not None:
694
+ logger.info(f"Loaded transformer weights from {len(weights_files)} shards via index: {index_file}")
695
+
696
+ if len(all_unexpected_keys) > 0:
697
+ unique_unexpected_keys = list(dict.fromkeys(all_unexpected_keys))
698
+ logger.warning(f"Unexpected keys while loading transformer weights: {unique_unexpected_keys[:20]}")
699
+
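
A hedged loading sketch for `init_custom`. The repo id and defaults below match the Space's environment defaults; whether the custom transformer can be built straight from the base config is an assumption here, and the app may wire things up differently (e.g. with accelerate device maps):

```py
import torch

from src.models.transformer_qwenimage_tagmoe import QwenImageTransformer2DModel
from src.pipelines.pipeline_qwenimage_tagmoe import QwenImagePipeline

# Build the TAG-MoE transformer from the base config (assumption: the custom
# class accepts it), then let `init_custom` pull the checkpoint over it.
config = QwenImageTransformer2DModel.load_config("Qwen/Qwen-Image", subfolder="transformer")
transformer = QwenImageTransformer2DModel.from_config(config)

pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.init_custom(weight_source="YUXU915/TAG-MoE")  # single file or sharded index
pipe.to("cuda")
```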
700
+ @property
701
+ def guidance_scale(self):
702
+ return self._guidance_scale
703
+
704
+ @property
705
+ def attention_kwargs(self):
706
+ return self._attention_kwargs
707
+
708
+ @property
709
+ def num_timesteps(self):
710
+ return self._num_timesteps
711
+
712
+ @property
713
+ def current_timestep(self):
714
+ return self._current_timestep
715
+
716
+ @property
717
+ def interrupt(self):
718
+ return self._interrupt
719
+
720
+ @torch.no_grad()
721
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
722
+ def __call__(
723
+ self,
724
+ prompt: Union[str, List[str]] = None,
725
+ negative_prompt: Union[str, List[str]] = None,
726
+ true_cfg_scale: float = 4.0,
727
+ height: Optional[int] = None,
728
+ width: Optional[int] = None,
729
+ num_inference_steps: int = 50,
730
+ sigmas: Optional[List[float]] = None,
731
+ guidance_scale: float = 1.0,
732
+ num_images_per_prompt: int = 1,
733
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
734
+ latents: Optional[torch.Tensor] = None,
735
+ prompt_embeds: Optional[torch.Tensor] = None,
736
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
737
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
738
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
739
+ output_type: Optional[str] = "pil",
740
+ return_dict: bool = True,
741
+ attention_kwargs: Optional[Dict[str, Any]] = None,
742
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
743
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
744
+ max_sequence_length: int = 512,
745
+ cond_image = None,
746
+ cond_bbox = None,
747
+ use_vlm = False,
748
+ tag_embedding = None,
749
+ ):
750
+ r"""
751
+ Function invoked when calling the pipeline for generation.
752
+
753
+ Args:
754
+ prompt (`str` or `List[str]`, *optional*):
755
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
756
+ instead.
757
+ negative_prompt (`str` or `List[str]`, *optional*):
758
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
759
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
760
+ not greater than `1`).
761
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
762
+ When greater than 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
763
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
764
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
765
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
766
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
767
+ num_inference_steps (`int`, *optional*, defaults to 50):
768
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
769
+ expense of slower inference.
770
+ sigmas (`List[float]`, *optional*):
771
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
772
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
773
+ will be used.
774
+ guidance_scale (`float`, *optional*, defaults to 1.0):
775
+ Guidance scale as defined in [Classifier-Free Diffusion
776
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` in equation 2
777
+ of the [Imagen Paper](https://huggingface.co/papers/2205.11487). It is only used when the transformer
778
+ was trained with guidance embedding (`guidance_embeds=True`). A higher guidance scale encourages the
779
+ model to generate images closely linked to the text `prompt`, usually at the expense of image quality.
780
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
781
+ The number of images to generate per prompt.
782
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
783
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
784
+ to make generation deterministic.
785
+ latents (`torch.Tensor`, *optional*):
786
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
787
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
788
+ tensor will be generated by sampling using the supplied random `generator`.
789
+ prompt_embeds (`torch.Tensor`, *optional*):
790
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
791
+ provided, text embeddings will be generated from `prompt` input argument.
792
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
793
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
794
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
795
+ argument.
796
+ output_type (`str`, *optional*, defaults to `"pil"`):
797
+ The output format of the generated image. Choose between
798
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
799
+ return_dict (`bool`, *optional*, defaults to `True`):
800
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
801
+ attention_kwargs (`dict`, *optional*):
802
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
803
+ `self.processor` in
804
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
805
+ callback_on_step_end (`Callable`, *optional*):
806
+ A function that is called at the end of each denoising step during inference. The function is called
807
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
808
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
809
+ `callback_on_step_end_tensor_inputs`.
810
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
811
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
812
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
813
+ `._callback_tensor_inputs` attribute of your pipeline class.
814
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
815
+
816
+ Examples:
817
+
818
+ Returns:
819
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
820
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
821
+ returning a tuple, the first element is a list with the generated images.
822
+ """
823
+
824
+ height = height or self.default_sample_size * self.vae_scale_factor
825
+ width = width or self.default_sample_size * self.vae_scale_factor
826
+
827
+ # 1. Check inputs. Raise error if not correct
828
+ self.check_inputs(
829
+ prompt,
830
+ height,
831
+ width,
832
+ negative_prompt=negative_prompt,
833
+ prompt_embeds=prompt_embeds,
834
+ negative_prompt_embeds=negative_prompt_embeds,
835
+ prompt_embeds_mask=prompt_embeds_mask,
836
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
837
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
838
+ max_sequence_length=max_sequence_length,
839
+ )
840
+
841
+ self._guidance_scale = guidance_scale
842
+ self._attention_kwargs = attention_kwargs
843
+ self._current_timestep = None
844
+ self._interrupt = False
845
+
846
+ # 2. Define call parameters
847
+ if prompt is not None and isinstance(prompt, str):
848
+ batch_size = 1
849
+ elif prompt is not None and isinstance(prompt, list):
850
+ batch_size = len(prompt)
851
+ else:
852
+ batch_size = prompt_embeds.shape[0]
853
+
854
+ device = self._get_module_input_device(self.transformer)
855
+ dtype = self.transformer.dtype
856
+
857
+ has_neg_prompt = negative_prompt is not None or (
858
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
859
+ )
860
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
861
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
862
+ prompt=prompt,
863
+ prompt_embeds=prompt_embeds,
864
+ prompt_embeds_mask=prompt_embeds_mask,
865
+ device=device,
866
+ num_images_per_prompt=num_images_per_prompt,
867
+ max_sequence_length=max_sequence_length,
868
+ image=cond_image if use_vlm else None,
869
+ )
870
+ if do_true_cfg:
871
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
872
+ prompt=negative_prompt,
873
+ prompt_embeds=negative_prompt_embeds,
874
+ prompt_embeds_mask=negative_prompt_embeds_mask,
875
+ device=device,
876
+ num_images_per_prompt=num_images_per_prompt,
877
+ max_sequence_length=max_sequence_length,
878
+ image=cond_image if use_vlm else None,
879
+ )
880
+
881
+ # 4. Prepare latent variables
882
+ num_channels_latents = self.transformer.config.in_channels // 4
883
+ latents, latent_image_ids = self.prepare_latents(
884
+ batch_size * num_images_per_prompt,
885
+ num_channels_latents,
886
+ height,
887
+ width,
888
+ prompt_embeds.dtype,
889
+ device,
890
+ generator,
891
+ latents,
892
+ )
897
+ img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
898
+
899
+ # 4.1 cond_image
900
+ if cond_image is not None:
901
+ cond_image_latent = self.image_processor.preprocess(cond_image, height, width)
902
+ cond_image_latent = cond_image_latent.to(device, dtype=dtype)
903
+
904
+ cond_image_latent = self.vae.encode(cond_image_latent.to(dtype=self.vae.dtype)[:, :, None]).latent_dist.sample()
905
+ latents_mean = (
906
+ torch.tensor(self.vae.config.latents_mean)
907
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
908
+ .to(latents.device, latents.dtype)
909
+ )
910
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
911
+ latents.device, latents.dtype
912
+ )
913
+ cond_image_latent = (cond_image_latent - latents_mean) * latents_std
914
+ cond_image_latent = cond_image_latent.to(dtype=dtype)
915
+ height_cond_image_latent, width_cond_image_latent = cond_image_latent.shape[-2:]
916
+ if cond_bbox is None:
917
+ cond_image_latent = self._pack_latents(cond_image_latent, 1, 16, height_cond_image_latent, width_cond_image_latent)
918
+ else:
919
+ cond_image_latent = cond_image_latent.view(1, 16, height_cond_image_latent // 2, 2, width_cond_image_latent // 2, 2)
920
+ cond_image_latent = cond_image_latent.permute(0, 2, 4, 1, 3, 5)
921
+ x1, y1, x2, y2 = cond_bbox
922
+ cond_image_latent = cond_image_latent[:, y1:y2+1, x1:x2+1]
923
+ cond_image_latent = cond_image_latent.reshape(1, -1, 64)
924
+ else:
925
+ cond_image_latent = None
926
+
927
+ # 5. Prepare timesteps
928
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
929
+ image_seq_len = latents.shape[1]
930
+ mu = calculate_shift(
931
+ image_seq_len,
932
+ self.scheduler.config.get("base_image_seq_len", 256),
933
+ self.scheduler.config.get("max_image_seq_len", 4096),
934
+ self.scheduler.config.get("base_shift", 0.5),
935
+ self.scheduler.config.get("max_shift", 1.15),
936
+ )
937
+ timesteps, num_inference_steps = retrieve_timesteps(
938
+ self.scheduler,
939
+ num_inference_steps,
940
+ device,
941
+ sigmas=sigmas,
942
+ mu=mu,
943
+ )
944
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
945
+ self._num_timesteps = len(timesteps)
946
+
947
+ # handle guidance
948
+ if self.transformer.config.guidance_embeds:
949
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
950
+ guidance = guidance.expand(latents.shape[0])
951
+ else:
952
+ guidance = None
953
+
954
+ if self.attention_kwargs is None:
955
+ self._attention_kwargs = {}
956
+
957
+ # 6. Denoising loop
958
+ self.scheduler.set_begin_index(0)
959
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
960
+ for i, t in enumerate(timesteps):
961
+ if self.interrupt:
962
+ continue
963
+
964
+ self._current_timestep = t
965
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
966
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
967
+ condition_rotary_emb = None
968
+
969
+ if cond_image is not None:
970
+ if cond_bbox is None:
971
+ condition_rotary_emb = self.transformer.pos_embed.get_img_rope(
972
+ [(1, height_cond_image_latent // 2, width_cond_image_latent // 2)],
973
+ device=device,
974
+ frame_idx=1,
975
+ )
976
+
977
+ else:
978
+ condition_rotary_emb = self.transformer.pos_embed.get_img_rope_by_bbox([(1, height // 16, width // 16)], cond_bbox, device)
979
+
980
+ joint_attention_kwargs = dict()
981
+ else:
982
+ joint_attention_kwargs = dict()
983
+
984
+ with self.transformer.cache_context("cond"):
985
+ noise_pred = self.transformer(
986
+ hidden_states=latents,
987
+ timestep=timestep / 1000,
988
+ guidance=guidance,
989
+ encoder_hidden_states_mask=prompt_embeds_mask,
990
+ encoder_hidden_states=prompt_embeds,
991
+ img_shapes=img_shapes,
992
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
993
+ attention_kwargs=joint_attention_kwargs,
994
+ cond_hidden_states=cond_image_latent,
995
+ cond_rope=condition_rotary_emb,
996
+ ).sample
997
+
998
+ if do_true_cfg:
999
+ with self.transformer.cache_context("uncond"):
1000
+ neg_noise_pred = self.transformer(
1001
+ hidden_states=latents,
1002
+ timestep=timestep / 1000,
1003
+ guidance=guidance,
1004
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
1005
+ encoder_hidden_states=negative_prompt_embeds,
1006
+ img_shapes=img_shapes,
1007
+ txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
1008
+ attention_kwargs=joint_attention_kwargs,
1009
+ cond_hidden_states=cond_image_latent,
1010
+ cond_rope=condition_rotary_emb,
1011
+ ).sample
1012
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1013
+
1014
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
1015
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
1016
+ noise_pred = comb_pred * (cond_norm / noise_norm)
1017
+
1018
+ # compute the previous noisy sample x_t -> x_t-1
1019
+ latents_dtype = latents.dtype
1020
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1021
+
1022
+ if latents.dtype != latents_dtype:
1023
+ if torch.backends.mps.is_available():
1024
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1025
+ latents = latents.to(latents_dtype)
1026
+
1027
+ if callback_on_step_end is not None:
1028
+ callback_kwargs = {}
1029
+ for k in callback_on_step_end_tensor_inputs:
1030
+ callback_kwargs[k] = locals()[k]
1031
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1032
+
1033
+ latents = callback_outputs.pop("latents", latents)
1034
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1035
+
1036
+ # call the callback, if provided
1037
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1038
+ progress_bar.update()
1039
+
1040
+ if XLA_AVAILABLE:
1041
+ xm.mark_step()
1042
+
1044
+ self._current_timestep = None
1045
+ if output_type == "latent":
1046
+ image = latents
1047
+ else:
1048
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1049
+ latents = latents.to(self.vae.dtype)
1050
+ latents_mean = (
1051
+ torch.tensor(self.vae.config.latents_mean)
1052
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
1053
+ .to(latents.device, latents.dtype)
1054
+ )
1055
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
1056
+ latents.device, latents.dtype
1057
+ )
1058
+ latents = latents / latents_std + latents_mean
1059
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
1060
+ image = self.image_processor.postprocess(image, output_type=output_type)
1061
+
1062
+ # Offload all models
1063
+ self.maybe_free_model_hooks()
1064
+
1065
+ if not return_dict:
1066
+ return (image,)
1067
+
1068
+ return QwenImagePipelineOutput(images=image)
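
Beyond the upstream `QwenImagePipeline.__call__`, the new keyword arguments are: `cond_image` (a reference image whose VAE latents are appended as condition tokens), `cond_bbox` (an optional `(x1, y1, x2, y2)` crop in latent-patch coordinates that places the condition via its RoPE positions), `use_vlm` (encode the prompt together with `cond_image` through the VLM processor; requires a `vlm_processor`), and `tag_embedding` (accepted but not referenced in this function body). A call sketch, assuming a `pipe` built as in the earlier loading sketch and a hypothetical input file:

```py
from PIL import Image

ref = Image.open("reference.png").convert("RGB")  # hypothetical reference image
result = pipe(
    prompt="Repaint the scene as a watercolor illustration",
    negative_prompt=" ",
    true_cfg_scale=4.0,
    height=1024,
    width=1024,
    num_inference_steps=30,
    cond_image=ref,
)
result.images[0].save("tagmoe_output.png")
```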
src/utils/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ from src.utils.device_utils import (
2
+ build_accelerate_max_memory_map,
3
+ maybe_set_cuda_device_from_tensor,
4
+ parse_device_ids,
5
+ resolve_device_ids,
6
+ )
7
+ from src.utils.inference_config import (
8
+ DEFAULT_HEIGHT,
9
+ DEFAULT_NEGATIVE_PROMPT,
10
+ DEFAULT_NUM_INFERENCE_STEPS,
11
+ DEFAULT_SEED,
12
+ DEFAULT_TRUE_CFG_SCALE,
13
+ DEFAULT_WIDTH,
14
+ generate_random_seed,
15
+ normalize_negative_prompt,
16
+ )
17
+
18
+ __all__ = [
19
+ "DEFAULT_HEIGHT",
20
+ "DEFAULT_NEGATIVE_PROMPT",
21
+ "DEFAULT_NUM_INFERENCE_STEPS",
22
+ "DEFAULT_SEED",
23
+ "DEFAULT_TRUE_CFG_SCALE",
24
+ "DEFAULT_WIDTH",
25
+ "build_accelerate_max_memory_map",
26
+ "generate_random_seed",
27
+ "maybe_set_cuda_device_from_tensor",
28
+ "normalize_negative_prompt",
29
+ "parse_device_ids",
30
+ "resolve_device_ids",
31
+ ]
src/utils/device_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Iterable, List, Mapping
4
+
5
+ import torch
6
+
7
+
8
+ def resolve_device_ids(device_arg: str | None) -> list[int] | None:
9
+ """Validate a user-provided device spec and return a list of GPU ids.
10
+
11
+ Returns ``None`` when *device_arg* is ``None`` (meaning "use framework
12
+ default"), an empty list for CPU, or a list of validated GPU indices.
13
+ """
14
+ if device_arg is None:
15
+ return None
16
+
17
+ device_ids = parse_device_ids(device_arg)
18
+
19
+ if len(device_ids) > 0 and not torch.cuda.is_available():
20
+ raise ValueError("CUDA is not available, but GPU device ids were provided.")
21
+ if len(device_ids) == 0:
22
+ return []
23
+
24
+ device_count = torch.cuda.device_count()
27
+ invalid_ids = [idx for idx in device_ids if idx < 0 or idx >= device_count]
28
+ if invalid_ids:
29
+ raise ValueError(
30
+ f"Invalid GPU ids {invalid_ids}. Available GPU ids: 0..{device_count - 1}."
31
+ )
32
+ return device_ids
33
+
34
+
35
+ def parse_device_ids(device_arg: str) -> List[int]:
36
+ value = device_arg.strip().lower()
37
+ if not value:
38
+ raise ValueError("Device argument is empty.")
39
+ if value in {"cpu", "-1"}:
40
+ return []
41
+
42
+ device_ids = []
43
+ for part in value.split(","):
44
+ token = part.strip()
45
+ if not token:
46
+ raise ValueError(f"Invalid device list: {device_arg!r}")
47
+ device_ids.append(int(token))
48
+ return device_ids
49
+
50
+
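
Expected parsing semantics: `cpu` and `-1` map to an empty id list, while `None` defers to the framework default:

```py
assert parse_device_ids("0,1") == [0, 1]
assert parse_device_ids("cpu") == []
assert parse_device_ids("-1") == []
assert resolve_device_ids(None) is None
```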
51
+ def build_accelerate_max_memory_map(
52
+ device_ids: Iterable[int],
53
+ free_bytes_by_device: Mapping[int, int],
54
+ reserve_bytes: int = 2 * 1024**3,
55
+ ) -> Dict[int, str]:
56
+ max_memory: Dict[int, str] = {}
57
+ for device_id in device_ids:
58
+ if device_id not in free_bytes_by_device:
59
+ raise ValueError(f"Missing free memory info for device {device_id}.")
60
+ free_bytes = free_bytes_by_device[device_id]
61
+ usable_gib = max(int((free_bytes - reserve_bytes) / (1024**3)), 4)
62
+ max_memory[device_id] = f"{usable_gib}GiB"
63
+ return max_memory
64
+
65
+
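
`build_accelerate_max_memory_map` converts per-GPU free byte counts into the `max_memory` strings accelerate understands, holding back a reserve (2 GiB by default) and flooring at 4 GiB. For example:

```py
free = {0: 80 * 1024**3, 1: 24 * 1024**3}
assert build_accelerate_max_memory_map([0, 1], free) == {0: "78GiB", 1: "22GiB"}
```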
66
+ def maybe_set_cuda_device_from_tensor(tensor) -> None:
67
+ if tensor is None:
68
+ return
69
+ if not torch.cuda.is_available():
70
+ return
71
+ if not getattr(tensor, "is_cuda", False):
72
+ return
73
+
74
+ device = getattr(tensor, "device", None)
75
+ device_index = getattr(device, "index", None)
76
+ if device_index is None:
77
+ return
78
+ if torch.cuda.current_device() == device_index:
79
+ return
80
+ torch.cuda.set_device(device_index)
src/utils/inference_config.py ADDED
@@ -0,0 +1,19 @@
1
+ import random
2
+
3
+
4
+ DEFAULT_WIDTH = 1024
5
+ DEFAULT_HEIGHT = 1024
6
+ DEFAULT_SEED = -1
7
+ DEFAULT_TRUE_CFG_SCALE = 4.0
8
+ DEFAULT_NUM_INFERENCE_STEPS = 30
9
+ DEFAULT_NEGATIVE_PROMPT = ""
10
+
11
+
12
+ def normalize_negative_prompt(value: str | None) -> str:
13
+ if value is None or not str(value).strip():
14
+ return " "
15
+ return str(value)
16
+
17
+
18
+ def generate_random_seed() -> int:
19
+ return random.randint(0, 2**32 - 1)
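
Behavior notes for the helpers above: an empty or whitespace-only negative prompt is normalized to a single space so the text encoder still receives a non-empty string, and seeds are drawn from the full unsigned 32-bit range:

```py
assert normalize_negative_prompt(None) == " "
assert normalize_negative_prompt("   ") == " "
assert normalize_negative_prompt("blurry, low quality") == "blurry, low quality"

seed = generate_random_seed()
assert 0 <= seed <= 2**32 - 1
```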