OrlandoHugBot commited on
Commit
1cb89a8
·
verified ·
1 Parent(s): 6b0f5ff

Upload 4 files

Browse files
Files changed (4) hide show
  1. SETUP.md +52 -0
  2. app.py +313 -131
  3. pipeline_qwenimage_edit.py +910 -0
  4. requirements.txt +21 -6
SETUP.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Space Setup Guide
2
+
3
+ ## Quick Start
4
+
5
+ 1. **Copy necessary files to this directory:**
6
+
7
+ ```bash
8
+ # Copy pipeline file (if QwenImageEditPipeline is not in diffusers yet)
9
+ cp ../UniPic/UniPic-3/qwen_image_edit_fast/pipeline_qwenimage_edit.py .
10
+ ```
11
+
12
+ 2. **Upload to Hugging Face Space:**
13
+
14
+ ```bash
15
+ # Install huggingface_hub if needed
16
+ pip install huggingface_hub
17
+
18
+ # Login
19
+ huggingface-cli login
20
+
21
+ # Create a new Space
22
+ huggingface-cli repo create your-space-name --type space
23
+
24
+ # Upload files
25
+ cd /path/to/gradio
26
+ huggingface-cli upload your-space-name app.py requirements.txt README.md
27
+ # If needed, also upload pipeline_qwenimage_edit.py
28
+ ```
29
+
30
+ ## Files Structure
31
+
32
+ ```
33
+ gradio_demo/
34
+ ├── app.py # Main Gradio application (required)
35
+ ├── requirements.txt # Python dependencies (required)
36
+ ├── README.md # Space description (required)
37
+ └── pipeline_qwenimage_edit.py # Pipeline code (if not in diffusers)
38
+ ```
39
+
40
+ ## Environment Variables (Optional)
41
+
42
+ You can set these in HF Space settings:
43
+
44
+ - `MODEL_NAME`: Base model name (default: "Qwen-Image-Edit")
45
+ - `TRANSFORMER_PATH`: DMD model path (default: "Skywork/Unipic3-DMD/ema_transformer")
46
+
47
+ ## Notes
48
+
49
+ - The pipeline will try to import from `diffusers` first
50
+ - If not available, it will fallback to local `pipeline_qwenimage_edit.py`
51
+ - Make sure to copy the pipeline file if QwenImageEditPipeline is not in your diffusers version
52
+
app.py CHANGED
@@ -1,154 +1,336 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
 
 
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
 
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
 
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
 
 
126
  )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
 
 
 
 
130
  minimum=1,
131
- maximum=50,
 
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
 
134
  )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
 
 
153
  if __name__ == "__main__":
154
  demo.launch()
 
 
1
+ """
2
+ Gradio Demo for UniPic-3 DMD Multi-Image Composition
3
+ Hugging Face Space compatible version
4
+
5
+ Upload up to 6 images and generate a composed result using DMD model with 4-step inference.
6
+ """
7
 
8
+ import gradio as gr
 
9
  import torch
10
+ from PIL import Image
11
+ import os
12
 
13
+ # Use local pipeline to ensure compatibility
14
+ import sys
15
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16
+ try:
17
+ from pipeline_qwenimage_edit import QwenImageEditPipeline
18
+ except ImportError:
19
+ # Fallback to diffusers if local not available
20
+ try:
21
+ from diffusers import QwenImageEditPipeline
22
+ except ImportError:
23
+ raise ImportError(
24
+ "QwenImageEditPipeline not found. Please ensure pipeline_qwenimage_edit.py "
25
+ "is in the same directory or diffusers is installed."
26
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageTransformer2DModel, AutoencoderKLQwenImage
29
+ from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
 
30
 
 
 
 
 
 
 
 
 
31
 
32
+ # Global pipeline
33
+ pipe = None
34
 
35
+ # Model paths (can be set via environment variables)
36
+ # export HF_ENDPOINT=https://hf-mirror.com
37
 
38
+ MODEL_NAME = os.environ.get("MODEL_NAME", "/data_genie/genie/chris/Qwen-Image-Edit")
39
+ # Default to local path if exists, otherwise use HuggingFace
40
+ default_transformer = "/data_genie/genie/chris/unipic3_ckpt/dmd/ema_transformer" if os.path.exists("/data_genie/genie/chris/unipic3_ckpt/dmd/ema_transformer") else "Skywork/Unipic3-DMD"
41
+ TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", default_transformer)
 
 
 
42
 
43
+
44
+ def load_model():
45
+ """Load the DMD model and pipeline"""
46
+ global pipe
47
+
48
+ if pipe is not None:
49
+ return pipe
50
+
51
+ print(f"Loading model from {TRANSFORMER_PATH}...")
52
+
53
+ # Load scheduler
54
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
55
+ pretrained_model_name_or_path=MODEL_NAME, subfolder='scheduler'
56
+ )
57
+
58
+ # Load text encoder
59
+ text_encoder = AutoModel.from_pretrained(
60
+ pretrained_model_name_or_path=MODEL_NAME, subfolder='text_encoder',
61
+ device_map='auto', torch_dtype=torch.bfloat16
62
+ )
63
+
64
+ # Load tokenizer and processor
65
+ tokenizer = AutoTokenizer.from_pretrained(
66
+ pretrained_model_name_or_path=MODEL_NAME, subfolder='tokenizer',
67
+ )
68
+ processor = Qwen2VLProcessor.from_pretrained(
69
+ pretrained_model_name_or_path=MODEL_NAME, subfolder='processor',
70
+ )
71
+
72
+ # Load transformer (DMD model)
73
+ # Handle both local paths and HuggingFace repo paths
74
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
+
76
+ if os.path.exists(TRANSFORMER_PATH):
77
+ # Local path - load directly to device (avoid device_map issues with .bin files)
78
+ if os.path.isdir(TRANSFORMER_PATH):
79
+ # Check if it's a direct transformer directory or has subfolder
80
+ if os.path.exists(os.path.join(TRANSFORMER_PATH, "config.json")):
81
+ transformer = QwenImageTransformer2DModel.from_pretrained(
82
+ pretrained_model_name_or_path=TRANSFORMER_PATH,
83
+ torch_dtype=torch.bfloat16,
84
+ use_safetensors=False # Use .bin file
85
+ ).to(device)
86
+ else:
87
+ transformer = QwenImageTransformer2DModel.from_pretrained(
88
+ pretrained_model_name_or_path=TRANSFORMER_PATH,
89
+ subfolder='transformer',
90
+ torch_dtype=torch.bfloat16,
91
+ use_safetensors=False
92
+ ).to(device)
93
+ else:
94
+ raise ValueError(f"Transformer path does not exist: {TRANSFORMER_PATH}")
95
+ else:
96
+ # HuggingFace repo path
97
+ # Handle paths like "Skywork/Unipic3-DMD/ema_transformer"
98
+ path_parts = TRANSFORMER_PATH.split('/')
99
+ if len(path_parts) >= 3:
100
+ # Has subfolder: "Skywork/Unipic3-DMD/ema_transformer"
101
+ repo_id = '/'.join(path_parts[:2]) # "Skywork/Unipic3-DMD"
102
+ subfolder = path_parts[2] # "ema_transformer"
103
+ transformer = QwenImageTransformer2DModel.from_pretrained(
104
+ pretrained_model_name_or_path=repo_id,
105
+ subfolder=subfolder,
106
+ device_map='auto',
107
+ torch_dtype=torch.bfloat16
108
+ )
109
+ elif len(path_parts) == 2:
110
+ # Just repo: "Skywork/Unipic3-DMD"
111
+ transformer = QwenImageTransformer2DModel.from_pretrained(
112
+ pretrained_model_name_or_path=TRANSFORMER_PATH,
113
+ subfolder='transformer',
114
+ device_map='auto',
115
+ torch_dtype=torch.bfloat16
116
  )
117
+ else:
118
+ # Single name, assume it's a repo ID
119
+ transformer = QwenImageTransformer2DModel.from_pretrained(
120
+ pretrained_model_name_or_path=TRANSFORMER_PATH,
121
+ subfolder='transformer',
122
+ device_map='auto',
123
+ torch_dtype=torch.bfloat16
124
+ )
125
+
126
+ # Load VAE
127
+ # Get device from transformer (handle both .device and device_map cases)
128
+ if hasattr(transformer, 'device'):
129
+ vae_device = transformer.device
130
+ elif hasattr(transformer, 'hf_device_map'):
131
+ # If using device_map, get the first device
132
+ vae_device = device
133
+ else:
134
+ vae_device = device
135
+
136
+ vae = AutoencoderKLQwenImage.from_pretrained(
137
+ pretrained_model_name_or_path=MODEL_NAME,
138
+ subfolder='vae',
139
+ torch_dtype=torch.bfloat16,
140
+ ).to(vae_device)
141
+
142
+ # Create pipeline
143
+ pipe = QwenImageEditPipeline(
144
+ scheduler=scheduler,
145
+ vae=vae,
146
+ text_encoder=text_encoder,
147
+ tokenizer=tokenizer,
148
+ processor=processor,
149
+ transformer=transformer
150
+ )
151
+
152
+ print("Model loaded successfully!")
153
+ # Don't return pipe for demo.load() - it expects no return value
154
 
 
155
 
156
+ def process_images(
157
+ img1, img2, img3, img4, img5, img6,
158
+ prompt: str,
159
+ true_cfg_scale: float = 4.0,
160
+ seed: int = 42,
161
+ num_steps: int = 4
162
+ ) -> tuple:
163
+ """Process multiple images and generate composed result"""
164
+ global pipe
165
+
166
+ # Ensure model is loaded (should be loaded by demo.load() on startup)
167
+ if pipe is None:
168
+ return None, "⏳ Model is still loading, please wait a moment and try again..."
169
+
170
+ # Filter out None images
171
+ images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None]
172
+
173
+ # Validate inputs
174
+ if len(images) == 0:
175
+ return None, "❌ Error: Please upload at least one image."
176
+
177
+ if len(images) > 6:
178
+ return None, f"❌ Error: Maximum 6 images allowed. You uploaded {len(images)} images."
179
+
180
+ if not prompt or prompt.strip() == "":
181
+ return None, "❌ Error: Please enter an editing instruction."
182
+
183
+ try:
184
+ # Convert to RGB
185
+ images = [img.convert("RGB") for img in images]
186
+
187
+ print(f"Processing {len(images)} images with prompt: '{prompt}'")
188
+ print(f"Steps: {num_steps}, CFG Scale: {true_cfg_scale}, Seed: {seed}")
189
+
190
+ # Generate image
191
+ # Note: images can be passed as first positional argument or as keyword argument
192
+ with torch.no_grad():
193
+ # Try positional argument first (as shown in pipeline examples)
194
+ if len(images) == 1:
195
+ # Single image: pass as first positional argument
196
+ result = pipe(
197
+ images[0],
198
+ prompt=prompt,
199
+ height=1024,
200
+ width=1024,
201
+ negative_prompt=' ',
202
+ num_inference_steps=num_steps,
203
+ true_cfg_scale=true_cfg_scale,
204
+ generator=torch.manual_seed(int(seed))
205
+ ).images[0]
206
+ else:
207
+ # Multiple images: pass as keyword argument
208
+ result = pipe(
209
+ images=images,
210
+ prompt=prompt,
211
+ height=1024,
212
+ width=1024,
213
+ negative_prompt=' ',
214
+ num_inference_steps=num_steps,
215
+ true_cfg_scale=true_cfg_scale,
216
+ generator=torch.manual_seed(int(seed))
217
+ ).images[0]
218
+
219
+ return result, f"✅ Success! Generated from {len(images)} image(s) in {num_steps} steps."
220
+
221
+ except Exception as e:
222
+ error_msg = f"❌ Error: {str(e)}"
223
+ print(error_msg)
224
+ import traceback
225
+ traceback.print_exc()
226
+ return None, error_msg
227
 
 
 
 
 
 
 
 
228
 
229
+ # Create Gradio interface
230
+ with gr.Blocks(title="UniPic-3 DMD Multi-Image Composition", theme=gr.themes.Soft()) as demo:
231
+ gr.Markdown("""
232
+ # 🔥 UniPic-3 DMD Multi-Image Composition
233
+
234
+ Upload up to **6 images** and provide an editing instruction to generate a composed result.
235
+
236
+ **Model**: DMD (Distribution-Matching Distillation) - **4-step fast inference (12.5× speedup)**
237
+
238
+ **Features**:
239
+ - Support 1-6 input images
240
+ - Fast 4-step inference
241
+ - High-quality multi-image composition
242
+ """)
243
+
244
+ with gr.Row():
245
+ with gr.Column(scale=1):
246
+ gr.Markdown("### 📸 Upload Images (1-6 images)")
247
+ image_inputs = [
248
+ gr.Image(type="pil", label=f"Image {i+1}", visible=(i < 2))
249
+ for i in range(6)
250
+ ]
251
+
252
+ num_images = gr.Slider(
253
+ minimum=1,
254
+ maximum=6,
255
+ value=2,
256
+ step=1,
257
+ label="Number of Images",
258
+ info="Select how many images you want to upload"
259
+ )
260
+
261
+ def update_image_visibility(num):
262
+ return [gr.update(visible=(i < num)) for i in range(6)]
263
+
264
+ num_images.change(
265
+ fn=update_image_visibility,
266
+ inputs=num_images,
267
+ outputs=image_inputs
268
+ )
269
+
270
+ gr.Markdown("### ✍️ Editing Instruction")
271
+ prompt_input = gr.Textbox(
272
+ label="Prompt",
273
+ placeholder="e.g., A man from Image1 is standing on a surfboard from Image2, riding the ocean waves under a bright blue sky.",
274
+ lines=3,
275
+ value="Combine the reference images to generate the final result."
276
+ )
277
+
278
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
279
+ cfg_scale = gr.Slider(
280
+ minimum=1.0,
281
  maximum=10.0,
282
+ value=4.0,
283
+ step=0.5,
284
+ label="CFG Scale",
285
+ info="Higher values make the output more aligned with the prompt"
286
  )
287
+ seed = gr.Number(
288
+ value=42,
289
+ label="Seed",
290
+ info="Random seed for reproducibility",
291
+ precision=0
292
+ )
293
+ num_steps = gr.Slider(
294
  minimum=1,
295
+ maximum=8,
296
+ value=8,
297
  step=1,
298
+ label="Inference Steps",
299
+ info="Number of denoising steps (8 is recommended for DMD)"
300
  )
301
+
302
+ generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
303
+
304
+ with gr.Column(scale=1):
305
+ gr.Markdown("### 🎨 Generated Result")
306
+ output_image = gr.Image(type="pil", label="Output Image")
307
+ status_text = gr.Textbox(
308
+ label="Status",
309
+ value="Ready. Upload images and enter a prompt, then click Generate.",
310
+ interactive=False
311
+ )
312
+
313
+ # Load model on startup
314
+ def load_model_wrapper():
315
+ """Wrapper to load model without returning value"""
316
+ load_model()
317
+ return None
318
+
319
+ demo.load(
320
+ fn=load_model_wrapper,
321
+ inputs=[],
322
+ outputs=[],
323
+ show_progress=True
324
+ )
325
+
326
+ # Generate button
327
+ generate_btn.click(
328
+ fn=process_images,
329
+ inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps],
330
+ outputs=[output_image, status_text]
331
  )
332
 
333
+
334
  if __name__ == "__main__":
335
  demo.launch()
336
+
pipeline_qwenimage_edit.py ADDED
@@ -0,0 +1,910 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import math
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ from PIL import Image
20
+ import numpy as np
21
+ import torch
22
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
23
+
24
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
25
+ from diffusers.loaders import QwenImageLoraLoaderMixin
26
+ from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
27
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
28
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
32
+
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```py
47
+ >>> import torch
48
+ >>> from PIL import Image
49
+ >>> from diffusers import QwenImageEditPipeline
50
+ >>> from diffusers.utils import load_image
51
+
52
+ >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
53
+ >>> pipe.to("cuda")
54
+ >>> image = load_image(
55
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
56
+ ... ).convert("RGB")
57
+ >>> prompt = (
58
+ ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
59
+ ... )
60
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
61
+ >>> # Refer to the pipeline documentation for more details.
62
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
63
+ >>> image.save("qwenimage_edit.png")
64
+ ```
65
+ """
66
+
67
+
68
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
69
+ def calculate_shift(
70
+ image_seq_len,
71
+ base_seq_len: int = 256,
72
+ max_seq_len: int = 4096,
73
+ base_shift: float = 0.5,
74
+ max_shift: float = 1.15,
75
+ ):
76
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
77
+ b = base_shift - m * base_seq_len
78
+ mu = image_seq_len * m + b
79
+ return mu
80
+
81
+
82
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
83
+ def retrieve_timesteps(
84
+ scheduler,
85
+ num_inference_steps: Optional[int] = None,
86
+ device: Optional[Union[str, torch.device]] = None,
87
+ timesteps: Optional[List[int]] = None,
88
+ sigmas: Optional[List[float]] = None,
89
+ **kwargs,
90
+ ):
91
+ r"""
92
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
93
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
94
+
95
+ Args:
96
+ scheduler (`SchedulerMixin`):
97
+ The scheduler to get timesteps from.
98
+ num_inference_steps (`int`):
99
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
100
+ must be `None`.
101
+ device (`str` or `torch.device`, *optional*):
102
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
103
+ timesteps (`List[int]`, *optional*):
104
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
105
+ `num_inference_steps` and `sigmas` must be `None`.
106
+ sigmas (`List[float]`, *optional*):
107
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
108
+ `num_inference_steps` and `timesteps` must be `None`.
109
+
110
+ Returns:
111
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
112
+ second element is the number of inference steps.
113
+ """
114
+ if timesteps is not None and sigmas is not None:
115
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
116
+ if timesteps is not None:
117
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
118
+ if not accepts_timesteps:
119
+ raise ValueError(
120
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
121
+ f" timestep schedules. Please check whether you are using the correct scheduler."
122
+ )
123
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
124
+ timesteps = scheduler.timesteps
125
+ num_inference_steps = len(timesteps)
126
+ elif sigmas is not None:
127
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
128
+ if not accept_sigmas:
129
+ raise ValueError(
130
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
131
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
132
+ )
133
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
134
+ timesteps = scheduler.timesteps
135
+ num_inference_steps = len(timesteps)
136
+ else:
137
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
138
+ timesteps = scheduler.timesteps
139
+ return timesteps, num_inference_steps
140
+
141
+
142
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
143
+ def retrieve_latents(
144
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
145
+ ):
146
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
147
+ return encoder_output.latent_dist.sample(generator)
148
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
149
+ return encoder_output.latent_dist.mode()
150
+ elif hasattr(encoder_output, "latents"):
151
+ return encoder_output.latents
152
+ else:
153
+ raise AttributeError("Could not access latents of provided encoder_output")
154
+
155
+
156
+ def calculate_dimensions(target_area, ratio):
157
+ width = math.sqrt(target_area * ratio)
158
+ height = width / ratio
159
+
160
+ width = round(width / 32) * 32
161
+ height = round(height / 32) * 32
162
+
163
+ return width, height, None
164
+
165
+
166
+ def resize_to_multiple_of(image, multiple_of=32):
167
+ width, height = image.size
168
+ width = round(width / multiple_of) * multiple_of
169
+ height = round(height / multiple_of) * multiple_of
170
+
171
+ image = image.resize((width, height))
172
+
173
+ return image
174
+
175
+
176
+
177
+ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
178
+ r"""
179
+ The Qwen-Image-Edit pipeline for image editing.
180
+
181
+ Args:
182
+ transformer ([`QwenImageTransformer2DModel`]):
183
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
184
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
185
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
186
+ vae ([`AutoencoderKL`]):
187
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
188
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
189
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
190
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
191
+ tokenizer (`QwenTokenizer`):
192
+ Tokenizer of class
193
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
194
+ """
195
+
196
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
197
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
198
+
199
+ def __init__(
200
+ self,
201
+ scheduler: FlowMatchEulerDiscreteScheduler,
202
+ vae: AutoencoderKLQwenImage,
203
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
204
+ tokenizer: Qwen2Tokenizer,
205
+ processor: Qwen2VLProcessor,
206
+ transformer: QwenImageTransformer2DModel,
207
+ ):
208
+ super().__init__()
209
+
210
+ self.register_modules(
211
+ vae=vae,
212
+ text_encoder=text_encoder,
213
+ tokenizer=tokenizer,
214
+ processor=processor,
215
+ transformer=transformer,
216
+ scheduler=scheduler,
217
+ )
218
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
219
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
220
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
221
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
222
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
223
+ self.vl_processor = processor
224
+ self.tokenizer_max_length = 1024
225
+
226
+ self.system_message = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
227
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
228
+ self.prompt_template_encode_start_idx = 64
229
+ self.default_sample_size = 128
230
+
231
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
232
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
233
+ bool_mask = mask.bool()
234
+ valid_lengths = bool_mask.sum(dim=1)
235
+ selected = hidden_states[bool_mask]
236
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
237
+
238
+ return split_result
239
+
240
def _get_qwen_prompt_embeds(
    self,
    prompts: Union[str, List[str]] = None,
    images: List[List[Image.Image]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode image-conditioned prompts with the Qwen-VL text encoder.

    Each prompt is rendered through the processor's chat template (system
    message + interleaved images + text), run through the text encoder, and
    the per-sample last hidden states are re-padded to a common length.

    Args:
        prompts: A single prompt or one prompt per sample.
        images: Conditioning image(s); a single image, a flat list (treated
            as one sample), or a list of per-sample image lists.
        device: Target device; defaults to the pipeline's execution device.
        dtype: Target dtype; defaults to the text encoder's dtype.

    Returns:
        Tuple `(prompt_embeds, encoder_attention_mask)` with shapes
        `(batch, max_seq_len, hidden)` and `(batch, max_seq_len)`.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompts = [prompts] if isinstance(prompts, str) else prompts

    # Normalize `images` to one list of images per prompt.
    if isinstance(images, Image.Image):
        images = [[images], ]
    elif isinstance(images[0], Image.Image):
        images = [images, ]
    assert len(prompts) == len(images)

    texts = []

    for prompt, image_list in zip(prompts, images):
        messages = [
            {
                "role": "system",
                "content": self.system_message,},
            {
                "role": "user",
                "content": [{"type": "image", "image": image} for image in image_list]
                + [{"type": "text", "text": prompt}, ],
            },
        ]

        # Apply chat template
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        texts.append(text)

    # Process inputs
    model_inputs = self.processor(
        text=texts,
        images=images,
        do_resize=False,  # already resized
        padding=True,
        return_tensors="pt"
    ).to(self.device)

    # Number of leading chat-template tokens to drop from the encoder output.
    drop_idx = self.prompt_template_encode_start_idx

    outputs = self.text_encoder(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        pixel_values=model_inputs.pixel_values,
        image_grid_thw=model_inputs.image_grid_thw,
        output_hidden_states=True,
    )

    # Strip padding and the template prefix, then re-pad to the batch max length.
    hidden_states = outputs.hidden_states[-1]
    split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    max_seq_len = max([e.size(0) for e in split_hidden_states])
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    return prompt_embeds, encoder_attention_mask
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    images: List[Image.Image] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
):
    r"""Encode `prompt` (with conditioning `images`) into transformer embeddings.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        images (`List[Image.Image]`, *optional*):
            images to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask matching `prompt_embeds`; required if `prompt_embeds` is given.
        max_sequence_length (`int`, defaults to 1024):
            Upper bound on the encoded sequence length (validated in `check_inputs`).

    Returns:
        Tuple `(prompt_embeds, prompt_embeds_mask)` expanded to
        `batch_size * num_images_per_prompt` rows.
    """
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, images, device)

    _, seq_len, _ = prompt_embeds.shape
    # Duplicate embeddings (and mask) once per image generated for each prompt.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
    prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

    return prompt_embeds, prompt_embeds_mask
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_embeds_mask=None,
    negative_prompt_embeds_mask=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """Validate call arguments, raising `ValueError` on invalid combinations.

    Dimension problems only produce a warning (the pipeline resizes), while
    mutually exclusive or incomplete prompt/embedding arguments raise.
    """
    divisor = self.vae_scale_factor * 2
    if height % divisor != 0 or width % divisor != 0:
        logger.warning(
            f"`height` and `width` have to be divisible by {divisor} but are {height} and {width}. Dimensions will be resized accordingly"
        )

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    # `prompt` and `prompt_embeds` are mutually exclusive, but one is required.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Pre-computed embeddings must come with their attention masks.
    if prompt_embeds is not None and prompt_embeds_mask is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
        )
    if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    if max_sequence_length is not None and max_sequence_length > 1024:
        raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+ @staticmethod
422
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
423
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
424
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
425
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
426
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
427
+
428
+ return latents
429
+
430
+ @staticmethod
431
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
432
+ def _unpack_latents(latents, height, width, vae_scale_factor):
433
+ batch_size, num_patches, channels = latents.shape
434
+
435
+ # VAE applies 8x compression on images but we must also account for packing which requires
436
+ # latent height and width to be divisible by 2.
437
+ height = 2 * (int(height) // (vae_scale_factor * 2))
438
+ width = 2 * (int(width) // (vae_scale_factor * 2))
439
+
440
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
441
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
442
+
443
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
444
+
445
+ return latents
446
+
447
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode an image tensor into normalized VAE latent space.

    Uses the "argmax" (distribution mode) sample of the VAE posterior for
    determinism, then normalizes with the VAE's configured per-channel
    latent mean/std.
    """
    if isinstance(generator, list):
        # One generator per batch element: encode each slice separately.
        image_latents = [
            retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
            for i in range(image.shape[0])
        ]
        image_latents = torch.cat(image_latents, dim=0)
    else:
        image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
    latents_mean = (
        torch.tensor(self.vae.config.latents_mean)
        .view(1, self.latent_channels, 1, 1, 1)
        .to(image_latents.device, image_latents.dtype)
    )
    latents_std = (
        torch.tensor(self.vae.config.latents_std)
        .view(1, self.latent_channels, 1, 1, 1)
        .to(image_latents.device, image_latents.dtype)
    )
    # Normalize to the unit-variance space the transformer was trained on.
    image_latents = (image_latents - latents_mean) / latents_std

    return image_latents
def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
    """
    # Delegates to the underlying VAE's slicing switch.
    self.vae.enable_slicing()
def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
    computing decoding in one step.
    """
    # Delegates to the underlying VAE's slicing switch.
    self.vae.disable_slicing()
def enable_vae_tiling(self):
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.
    """
    # Delegates to the underlying VAE's tiling switch.
    self.vae.enable_tiling()
def disable_vae_tiling(self):
    r"""
    Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
    computing decoding in one step.
    """
    # Delegates to the underlying VAE's tiling switch.
    self.vae.disable_tiling()
def prepare_latents(
    self,
    images,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Create the initial noise latents and the packed conditioning image latents.

    Args:
        images: List of preprocessed image tensors (or pre-computed latents
            when the channel count already matches `self.latent_channels`).
        batch_size: Effective batch size (prompts x images-per-prompt).
        num_channels_latents: Latent channel count expected by the transformer.
        height, width: Target output size in pixels.
        dtype, device: Dtype/device for the created tensors.
        generator: RNG (or list of RNGs, one per batch element) for noise
            sampling and VAE encoding.
        latents: Optional pre-sampled, already-packed noise latents.

    Returns:
        Tuple `(latents, image_latents_list)`, both in packed
        `(B, tokens, C*4)` form.

    Raises:
        ValueError: if the image batch cannot be broadcast to `batch_size`,
            or a generator list length mismatches `batch_size`.
    """
    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (self.vae_scale_factor * 2))
    width = 2 * (int(width) // (self.vae_scale_factor * 2))

    # Singleton frame axis matches the video-style VAE layout.
    shape = (batch_size, 1, num_channels_latents, height, width)

    image_latents_list = []
    for image in images:
        image = image.to(device=device, dtype=dtype)
        # Inputs already in latent space (matching channel count) skip VAE encoding.
        if image.shape[1] != self.latent_channels:
            image_latents = self._encode_vae_image(image=image, generator=generator)
        else:
            image_latents = image
        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)

        image_latent_height, image_latent_width = image_latents.shape[3:]
        image_latents = self._pack_latents(
            image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
        )
        image_latents_list.append(image_latents)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
    else:
        # Caller-supplied latents are assumed to be packed already.
        latents = latents.to(device=device, dtype=dtype)

    return latents, image_latents_list
@property
def guidance_scale(self):
    # Guidance value forwarded to guidance-distilled transformers (not true CFG).
    return self._guidance_scale

@property
def attention_kwargs(self):
    # Extra kwargs forwarded to the transformer's attention processors.
    return self._attention_kwargs

@property
def num_timesteps(self):
    # Number of denoising timesteps of the current/last run.
    return self._num_timesteps

@property
def current_timestep(self):
    # Timestep currently being denoised (None outside the denoise loop).
    return self._current_timestep

@property
def interrupt(self):
    # When True, remaining denoising steps are skipped.
    return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    images: List[PipelineImageInput] = None,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    true_cfg_scale: float = 4.0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 1.0,
    num_images_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        images (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
            The conditioning image(s) to edit. They are rescaled so that their combined
            pixel count is roughly 1024x1024, then snapped to multiples of 32 before
            being encoded as VAE latents and as visual tokens for the text encoder.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
            not greater than `1`).
        true_cfg_scale (`float`, *optional*, defaults to 4.0):
            When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
        height (`int`, *optional*, defaults to the (rounded) first image's height):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to the (rounded) first image's width):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 1.0):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). Only consumed when the transformer was trained
            with guidance embeddings (`guidance_embeds`); to enable classifier-free guidance in this pipeline,
            pass `true_cfg_scale` and a `negative_prompt` (even an empty negative prompt like " ") instead.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask matching `prompt_embeds`; required when `prompt_embeds` is provided.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask matching `negative_prompt_embeds`; required when they are provided.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
            [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images.
    """
    if not isinstance(images, (list, tuple)):
        images = [images]

    # Rescale conditioning images so their combined area is ~1024^2 pixels,
    # then snap each to a multiple of 32 as required by the VAE + packing.
    total_number_of_pixels = sum([math.prod(image.size) for image in images])
    ratio = 1024 / total_number_of_pixels ** 0.5
    images = [image.resize(size=(round(image.width*ratio), round(image.height*ratio))) for image in images]
    images = [resize_to_multiple_of(image=image, multiple_of=32) for image in images]

    # Default output size: the (resized) first conditioning image.
    if height is None or width is None:
        width, height = images[0].size

    multiple_of = self.vae_scale_factor * 2
    width = width // multiple_of * multiple_of
    height = height // multiple_of * multiple_of

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        negative_prompt_embeds_mask=negative_prompt_embeds_mask,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Preprocess images. The text encoder sees a 28/32-scaled copy
    # (presumably to match its vision patch grid — TODO confirm), while the
    # VAE consumes the full-resolution tensors with a singleton frame axis.
    prompt_images = [image.resize((round(image.width * 28 / 32), round(image.height * 28 / 32))) for image in images]
    images = [self.image_processor.preprocess(image, image.height, image.width).unsqueeze(2) for image in images]

    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
    )
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    prompt_embeds, prompt_embeds_mask = self.encode_prompt(
        images=prompt_images,
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
    )
    if do_true_cfg:
        negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
            images=prompt_images,
            prompt=negative_prompt,
            prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=negative_prompt_embeds_mask,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, image_latents = self.prepare_latents(
        images,
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    # Per-sample token-grid shapes: the target latent grid first, then one
    # entry per conditioning image (consumed by the transformer, e.g. for
    # positional embedding — TODO confirm exact use).
    img_shapes = [
        [
            (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
        ] +
        [
            (1, image.shape[-2] // self.vae_scale_factor // 2, image.shape[-1] // self.vae_scale_factor // 2)
            for image in images
        ]
    ] * batch_size

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latents.shape[1]
    # Resolution-dependent timestep shift for flow-matching schedulers.
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("max_image_seq_len", 4096),
        self.scheduler.config.get("base_shift", 0.5),
        self.scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # handle guidance (only for guidance-distilled transformer variants)
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    if self.attention_kwargs is None:
        self._attention_kwargs = {}

    txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
    negative_txt_seq_lens = (
        negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
    )

    # 6. Denoising loop
    self.scheduler.set_begin_index(0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t

            # Conditioning image latents are concatenated along the token axis;
            # only the leading `latents.size(1)` tokens are denoised.
            latent_model_input = torch.cat([latents] + image_latents, dim=1)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            with self.transformer.cache_context("cond"):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
            # Drop predictions for the conditioning-image tokens.
            noise_pred = noise_pred[:, : latents.size(1)]

            if do_true_cfg:
                with self.transformer.cache_context("uncond"):
                    neg_noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]
                neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
                comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # Rescale so the guided prediction keeps the conditional norm.
                cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
                noise_pred = comb_pred * (cond_norm / noise_norm)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None
    if output_type == "latent":
        image = latents
    else:
        # Unpack, de-normalize, and decode through the VAE (dropping the
        # singleton frame axis), then convert to the requested output type.
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return QwenImagePipelineOutput(images=image)
requirements.txt CHANGED
@@ -1,6 +1,21 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio and UI
2
+ gradio>=4.0.0
3
+
4
+ # PyTorch
5
+ torch>=2.0.0
6
+ torchvision
7
+ torchaudio
8
+
9
+ # Transformers and Diffusers
10
+ transformers>=4.56.0
11
+ diffusers>=0.36.0
12
+ accelerate>=1.10.0
13
+
14
+ # Image processing
15
+ pillow>=10.0.0
16
+ numpy
17
+
18
+ # Utilities
19
+ huggingface-hub
20
+ safetensors
21
+