sdfafdfsdf committed on
Commit
674481b
·
verified ·
1 Parent(s): 11b66b5

Upload app code and configuration files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/13.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: QIE-2511 Rapid-AIO LoRAs Fast (Experimental)
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 6.2.0
8
+ app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ short_description: Demo of the Collection of Qwen Image Edit LoRAs
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
try:
    import spaces
except ImportError:
    # Local fallback when the Hugging Face Spaces ZeroGPU package is absent.
    # The real `spaces.GPU` can be used both bare (@spaces.GPU) and
    # parameterized (@spaces.GPU(duration=120)); the previous stub only
    # supported the bare form and crashed on keyword use.
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op stand-in for spaces.GPU supporting both decorator forms."""
            if func is None:
                # Called as @spaces.GPU(...): return a pass-through decorator.
                return lambda f: f
            # Called as @spaces.GPU: return the function unchanged.
            return func
    # Register the stub so the later `import spaces` statement resolves to it.
    sys.modules["spaces"] = sys.modules.get("spaces", spaces)
9
+ import os
10
+ from camera_control_ui import CameraControl3D, build_camera_prompt, update_prompt_with_camera
11
+ import re
12
+ import gc
13
+ import traceback
14
+ import gradio as gr
15
+ import numpy as np
16
+ import spaces
17
+ import torch
18
+ import random
19
+ from PIL import Image, ImageDraw
20
+ from typing import Iterable, Optional
21
+
22
+ from transformers import (
23
+ AutoImageProcessor,
24
+ AutoModelForDepthEstimation,
25
+ )
26
+
27
+ from huggingface_hub import hf_hub_download
28
+ from safetensors.torch import load_file as safetensors_load_file
29
+
30
+ from gradio.themes import Soft
31
+ from gradio.themes.utils import colors, fonts, sizes
32
+
33
+ # ============================================================
34
+ # Theme
35
+ # ============================================================
36
+
37
+ colors.orange_red = colors.Color(
38
+ name="orange_red",
39
+ c50="#FFF0E5",
40
+ c100="#FFE0CC",
41
+ c200="#FFC299",
42
+ c300="#FFA366",
43
+ c400="#FF8533",
44
+ c500="#FF4500",
45
+ c600="#E63E00",
46
+ c700="#CC3700",
47
+ c800="#B33000",
48
+ c900="#992900",
49
+ c950="#802200",
50
+ )
51
+
52
+
53
class OrangeRedTheme(Soft):
    """Gradio Soft-theme variant using the custom `orange_red` palette.

    Only colors/fonts/sizes are customized; all behavior comes from Soft.
    The keyword-only constructor mirrors Soft's signature so callers can
    still override any hue/font at instantiation time.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override individual CSS variables on top of the Soft defaults.
        # `*token` strings are Gradio theme references, not literal CSS.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )
+ )
107
+
108
+
109
+ orange_red_theme = OrangeRedTheme()
110
+
111
+ # ============================================================
112
+ # Device
113
+ # ============================================================
114
+
115
# Pick the main compute device once at import time; all later code reads
# this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Startup diagnostics: echo the CUDA environment so Space logs show exactly
# what hardware/runtime the app detected.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
print("Using device:", device)
126
+
127
+ # ============================================================
128
+ # AIO version (Space variable)
129
+ # ============================================================
130
+
131
+ AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
132
+ DEFAULT_AIO_VERSION = "v19"
133
+
134
+ _VER_RE = re.compile(r"^v\d+$")
135
+ _DIGITS_RE = re.compile(r"^\d+$")
136
+
137
+
138
+ def _normalize_version(raw: str) -> Optional[str]:
139
+ if raw is None:
140
+ return None
141
+ s = str(raw).strip()
142
+ if not s:
143
+ return None
144
+ if _VER_RE.fullmatch(s):
145
+ return s
146
+ # forgiving: allow "21" -> "v21"
147
+ if _DIGITS_RE.fullmatch(s):
148
+ return f"v{s}"
149
+ return None
150
+
151
+
152
+ _AIO_ENV_RAW = os.environ.get("AIO_VERSION", "")
153
+ _AIO_ENV_NORM = _normalize_version(_AIO_ENV_RAW)
154
+
155
+ AIO_VERSION = _AIO_ENV_NORM or DEFAULT_AIO_VERSION
156
+ AIO_VERSION_SOURCE = "env" if _AIO_ENV_NORM else "default(v19)"
157
+
158
+ print(f"AIO_VERSION (env raw) = {_AIO_ENV_RAW!r}")
159
+ print(f"AIO_VERSION (normalized) = {_AIO_ENV_NORM!r}")
160
+ print(f"Using AIO_VERSION = {AIO_VERSION} ({AIO_VERSION_SOURCE})")
161
+
162
+ # ============================================================
163
+ # Pipeline
164
+ # ============================================================
165
+
166
+ from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
167
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
168
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
169
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
170
+
171
+ dtype = torch.bfloat16
172
+
173
+
174
def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
    """Build the edit pipeline with the Rapid-AIO transformer for `version`.

    Loads the base Qwen/Qwen-Image-Edit-2511 pipeline but swaps its
    transformer for the one stored under `<version>/transformer` inside
    AIO_REPO_ID. Raises whatever `from_pretrained` raises (caller handles
    fallback).
    """
    sub = f"{version}/transformer"
    print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {sub}")
    p = QwenImageEditPlusPipeline.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        transformer=QwenImageTransformer2DModel.from_pretrained(
            AIO_REPO_ID,
            subfolder=sub,
            torch_dtype=dtype,  # bf16 module-level dtype
            device_map="auto",
            low_cpu_mem_usage=True,
        ),
        torch_dtype=dtype,
    )
    # Keep components on CPU and move them to GPU only while they run,
    # trading speed for lower peak VRAM.
    p.enable_model_cpu_offload()
    return p
190
+
191
+
192
# Forgiving load: try env/default version, fallback to v19 if it fails
try:
    pipe = _load_pipe_with_version(AIO_VERSION)
except Exception as e:
    # Any failure (missing subfolder, corrupt weights, network error) falls
    # back to the known-good default so the Space still boots; the module
    # globals are rewritten so downstream logs reflect what actually loaded.
    print("❌ Failed to load requested AIO_VERSION. Falling back to v19.")
    print("---- exception ----")
    print(traceback.format_exc())
    print("-------------------")
    AIO_VERSION = DEFAULT_AIO_VERSION
    AIO_VERSION_SOURCE = "fallback_to_v19"
    pipe = _load_pipe_with_version(AIO_VERSION)
203
+
204
# Apply FA3 Optimization
# NOTE(review): FA3 is deliberately NOT applied here for stability. The
# previous code additionally printed "Flash Attention 3 Processor set
# successfully.", which was misleading because no attention processor is
# ever installed; that contradictory log line has been removed.
try:
    print("Skipping FA3 optimization for stability.")
except Exception as e:
    print(f"Warning: Could not set FA3 processor: {e}")
210
+
211
+ MAX_SEED = np.iinfo(np.int32).max
212
+
213
+ # ============================================================
214
+ # VAE tiling toggle (UI-controlled; OFF by default)
215
+ # ============================================================
216
+
217
def _apply_vae_tiling(enabled: bool):
    """
    Toggle VAE tiling on the global pipeline.

    This does NOT require a Space restart; it applies to the next pipe(...) call.
    Note: this is global process state, so concurrent users could flip it between runs.
    """
    try:
        if enabled:
            # Prefer the pipeline-level API; fall back to the VAE's own method.
            if hasattr(pipe, "enable_vae_tiling"):
                pipe.enable_vae_tiling()
                print("✅ VAE tiling ENABLED (per UI).")
            elif hasattr(pipe, "vae") and hasattr(pipe.vae, "enable_tiling"):
                pipe.vae.enable_tiling()
                print("✅ VAE tiling ENABLED via pipe.vae.enable_tiling() (per UI).")
            else:
                print("⚠️ No enable_vae_tiling()/vae.enable_tiling() found; cannot enable.")
        else:
            if hasattr(pipe, "disable_vae_tiling"):
                pipe.disable_vae_tiling()
                print("🛑 VAE tiling DISABLED (per UI).")
            elif hasattr(pipe, "vae") and hasattr(pipe.vae, "disable_tiling"):
                pipe.vae.disable_tiling()
                print("🛑 VAE tiling DISABLED via pipe.vae.disable_tiling() (per UI).")
            else:
                # If no disable method exists, we leave current state unchanged.
                print("⚠️ No disable_vae_tiling()/vae.disable_tiling() found; leaving current state unchanged.")
    except Exception as e:
        # Best effort: tiling is only a memory optimization, so never fail
        # the user's request over a toggle error.
        print(f"⚠️ VAE tiling toggle failed: {e}")
246
+
247
+ # ============================================================
248
+ # Derived conditioning (Transformers): Depth
249
+ # ============================================================
250
+ # Depth uses Depth Anything V2 Small (Transformers-compatible):
251
+ # https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf
252
+
253
+ DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
254
+
255
+ # Lazy cache keyed by device string ("cpu" / "cuda")
256
+ _DEPTH_CACHE = {}
257
+
258
+ def _derived_device(use_gpu: bool) -> torch.device:
259
+ return torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")
260
+
261
def _load_depth_models(dev: torch.device):
    """Return (image_processor, depth_model) for DEPTH_MODEL_ID on `dev`.

    Results are memoized in the module-level _DEPTH_CACHE keyed by the
    device string, so each device downloads/initializes at most once.
    """
    key = str(dev)
    if key in _DEPTH_CACHE:
        return _DEPTH_CACHE[key]
    proc = AutoImageProcessor.from_pretrained(DEPTH_MODEL_ID)
    model = AutoModelForDepthEstimation.from_pretrained(DEPTH_MODEL_ID).to(dev)
    # Inference-only: disable dropout/batch-norm training behavior.
    model.eval()
    _DEPTH_CACHE[key] = (proc, model)
    return _DEPTH_CACHE[key]
270
+
271
@torch.inference_mode()
def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
    """Estimate a depth map for `img` and return it as an RGB PIL image.

    The model-resolution depth is resized back to the input size, min-max
    normalized to [0, 255], and converted to RGB so it can be used as a
    conditioning image by the edit pipeline.
    """
    dev = _derived_device(use_gpu)
    proc, model = _load_depth_models(dev)

    w, h = img.size
    inputs = proc(images=img.convert("RGB"), return_tensors="pt").to(dev)
    outputs = model(**inputs)
    predicted = outputs.predicted_depth  # [B, H, W]

    # Resize back to the original (h, w); note PIL size is (w, h) but
    # interpolate expects (h, w).
    depth = torch.nn.functional.interpolate(
        predicted.unsqueeze(1),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    ).squeeze(1)[0]

    # Min-max normalize to 8-bit; epsilon guards division by zero on a
    # perfectly flat depth map.
    depth = depth - depth.min()
    depth = depth / (depth.max() + 1e-8)
    depth = (depth * 255.0).clamp(0, 255).to(torch.uint8).cpu().numpy()
    return Image.fromarray(depth).convert("RGB")
292
+
293
+ # ============================================================
294
+ # LoRA adapters + presets
295
+ # ============================================================
296
+
297
+ NONE_LORA = "None"
298
+
299
+ ADAPTER_SPECS = {
300
+ "3D-Camera": {
301
+ "type": "single",
302
+ "repo": "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA",
303
+ "weights": "qwen-image-edit-2511-multiple-angles-lora.safetensors",
304
+ "adapter_name": "angles",
305
+ "strength": 1.0,
306
+ },
307
+
308
+ "Qwen-lora-nsfw": {
309
+ "type": "single",
310
+ "repo": "wiikoo/Qwen-lora-nsfw",
311
+ "weights": "loras/qwen_image_edit_remove-clothing_v1.0.safetensors",
312
+ "adapter_name": "qwen-lora-nsfw",
313
+ "strength": 1.0,
314
+ },
315
+
316
+ "Consistance": {
317
+ "type": "single",
318
+ "repo": "Pr0f3ssi0n4ln00b/QIE_2511_Consistency_Lora",
319
+ "weights": "qe2511_consis_alpha_patched.safetensors",
320
+ "adapter_name": "Consistency",
321
+ "strength": 0.6,
322
+ },
323
+ "Semirealistic-photo-detailer": {
324
+ "type": "single",
325
+ "repo": "rzgar/Qwen-Image-Edit-semi-realistic-detailer",
326
+ "weights": "Qwen-Image-Edit-Anime-Semi-Realistic-Detailer-v1.safetensors",
327
+ "adapter_name": "semirealistic",
328
+ "strength": 1.0,
329
+ },
330
+ "AnyPose": {
331
+ "type": "package",
332
+ "requires_two_images": True,
333
+ "image2_label": "Upload Pose Reference (Image 2)",
334
+ "parts": [
335
+ {
336
+ "repo": "lilylilith/AnyPose",
337
+ "weights": "2511-AnyPose-base-000006250.safetensors",
338
+ "adapter_name": "anypose-base",
339
+ "strength": 0.7,
340
+ },
341
+ {
342
+ "repo": "lilylilith/AnyPose",
343
+ "weights": "2511-AnyPose-helper-00006000.safetensors",
344
+ "adapter_name": "anypose-helper",
345
+ "strength": 0.7,
346
+ },
347
+ ],
348
+ },
349
+ "Any2Real_2601": {
350
+ "type": "single",
351
+ "repo": "lrzjason/Anything2Real_2601",
352
+ "weights": "anything2real_2601_A_final_patched.safetensors",
353
+ "adapter_name": "photoreal",
354
+ "strength": 1.0,
355
+ },
356
+ "Hyperrealistic-Portrait": {
357
+ "type": "single",
358
+ "repo": "prithivMLmods/Qwen-Image-Edit-2511-Hyper-Realistic-Portrait",
359
+ "weights": "HRP_20.safetensors",
360
+ "adapter_name": "HRPortrait",
361
+ "strength": 1.0,
362
+ },
363
+ "Ultrarealistic-Portrait": {
364
+ "type": "single",
365
+ "repo": "prithivMLmods/Qwen-Image-Edit-2511-Ultra-Realistic-Portrait",
366
+ "weights": "URP_20.safetensors",
367
+ "adapter_name": "URPortrait",
368
+ "strength": 1.0,
369
+ },
370
+ "BFS-Best-FaceSwap": {
371
+ "type": "single",
372
+ "requires_two_images": True,
373
+ "image2_label": "Upload Head/Face Donor (Image 2)",
374
+ "repo": "Alissonerdx/BFS-Best-Face-Swap",
375
+ "weights": "bfs_head_v5_2511_original.safetensors",
376
+ "adapter_name": "BFS-Best-Faceswap",
377
+ "strength": 1.0,
378
+ "needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
379
+ },
380
+ "BFS-Best-FaceSwap-merge": {
381
+ "type": "single",
382
+ "requires_two_images": True,
383
+ "image2_label": "Upload Head/Face Donor (Image 2)",
384
+ "repo": "Alissonerdx/BFS-Best-Face-Swap",
385
+ "weights": "bfs_head_v5_2511_merged_version_rank_32_fp32.safetensors",
386
+ "adapter_name": "BFS-Best-Faceswap-merge",
387
+ "strength": 1.1,
388
+ "needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
389
+ },
390
+ "F2P": {
391
+ "type": "single",
392
+ "repo": "DiffSynth-Studio/Qwen-Image-Edit-F2P",
393
+ "weights": "edit_0928_lora_step40000.safetensors",
394
+ "adapter_name": "F2P",
395
+ "strength": 1.0,
396
+ },
397
+ "Multiple-Angles": {
398
+ "type": "single",
399
+ "repo": "dx8152/Qwen-Edit-2509-Multiple-angles",
400
+ "weights": "镜头转换.safetensors",
401
+ "adapter_name": "multiple-angles",
402
+ "strength": 1.0,
403
+ },
404
+ "Light-Restoration": {
405
+ "type": "single",
406
+ "repo": "dx8152/Qwen-Image-Edit-2509-Light_restoration",
407
+ "weights": "移除光影.safetensors",
408
+ "adapter_name": "light-restoration",
409
+ "strength": 1.0,
410
+ },
411
+ "Relight": {
412
+ "type": "single",
413
+ "repo": "dx8152/Qwen-Image-Edit-2509-Relight",
414
+ "weights": "Qwen-Edit-Relight.safetensors",
415
+ "adapter_name": "relight",
416
+ "strength": 1.0,
417
+ },
418
+ "Multi-Angle-Lighting": {
419
+ "type": "single",
420
+ "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
421
+ "weights": "多角度灯光-251116.safetensors",
422
+ "adapter_name": "multi-angle-lighting",
423
+ "strength": 1.0,
424
+ },
425
+ "Edit-Skin": {
426
+ "type": "single",
427
+ "repo": "tlennon-ie/qwen-edit-skin",
428
+ "weights": "qwen-edit-skin_1.1_000002750.safetensors",
429
+ "adapter_name": "edit-skin",
430
+ "strength": 1.0,
431
+ },
432
+ "Next-Scene": {
433
+ "type": "single",
434
+ "repo": "lovis93/next-scene-qwen-image-lora-2509",
435
+ "weights": "next-scene_lora-v2-3000.safetensors",
436
+ "adapter_name": "next-scene",
437
+ "strength": 1.0,
438
+ },
439
+ "Flat-Log": {
440
+ "type": "single",
441
+ "repo": "tlennon-ie/QwenEdit2509-FlatLogColor",
442
+ "weights": "QwenEdit2509-FlatLogColor.safetensors",
443
+ "adapter_name": "flat-log",
444
+ "strength": 1.0,
445
+ },
446
+ "Upscale-Image": {
447
+ "type": "single",
448
+ "repo": "vafipas663/Qwen-Edit-2509-Upscale-LoRA",
449
+ "weights": "qwen-edit-enhance_64-v3_000001000.safetensors",
450
+ "adapter_name": "upscale-image",
451
+ "strength": 1.0,
452
+ },
453
+ "Upscale2K": {
454
+ "type": "single",
455
+ "repo": "valiantcat/Qwen-Image-Edit-2509-Upscale2K",
456
+ "weights": "qwen_image_edit_2509_upscale.safetensors",
457
+ "adapter_name": "upscale-2k",
458
+ "strength": 1.0,
459
+ "target_long_edge": 2048,
460
+ },
461
+ }
462
+
463
+ LORA_PRESET_PROMPTS = {
464
+ "Any2Real_2601": "change the picture 1 to realistic photograph",
465
+ "Semirealistic-photo-detailer": "transform the image to semi-realistic image",
466
+ "AnyPose": "Make the person in image 1 do the exact same pose of the person in image 2. Changing the style and background of the image of the person in image 1 is undesirable, so don't do it. The new pose should be pixel accurate to the pose we are trying to copy. The position of the arms and head and legs should be the same as the pose we are trying to copy. Change the field of view and angle to match exactly image 2. Head tilt and eye gaze pose should match the person in image 2.",
467
+ "Hyperrealistic-Portrait": "Transform the image into an ultra-realistic photorealistic portrait with strict identity preservation, facing straight to the camera. Enhance pore-level skin textures, realistic moisture effects, and natural wet hair clumping against the skin. Apply cool-toned soft-box lighting with subtle highlights and shadows, maintain realistic green-hazel eye catchlights without synthetic gloss, and preserve soft natural lip texture. Use shallow depth of field with a clean bokeh background, an 85mm macro photographic look, and raw photo grading without retouching to maintain realism and original details.",
468
+ "Ultrarealistic-Portrait": "Transform the image into an ultra-realistic glamour portrait while strictly preserving the subject’s identity. Apply a close-up composition with a slight head tilt and a hand near the face, enhance cinematic directional lighting with dramatic fashion-style highlights, and refine makeup details including glowing skin, glossy lips, luminous highlighter, and defined eyes. Increase skin realism with detailed epidermal textures such as micropores, microhairs, subtle oil sheen, natural highlights, soft wrinkles, and subsurface scattering. Maintain a luxury fashion-magazine look in a 9:16 aspect ratio, preserving realism, facial structure, and original details without over-smoothing or retouching.",
469
+ "Upscale2K": "Upscale this picture to 4K resolution.",
470
+ "BFS-Best-FaceSwap": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
471
+ "BFS-Best-FaceSwap-merge": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
472
+ }
473
+
474
+ # Track what is currently loaded in memory (adapter_name values)
475
+ LOADED_ADAPTERS = set()
476
+
477
+ # ============================================================
478
+ # Helpers: resolution
479
+ # ============================================================
480
+
481
+ # We prefer *area-based* sizing (≈ megapixels) over long-edge sizing.
482
+ # This aligns better with Qwen-Image-Edit's internal assumptions and reduces FOV drift.
483
+
484
+ def _round_to_multiple(x: int, m: int) -> int:
485
+ return max(m, (int(x) // m) * m)
486
+
487
def compute_canvas_dimensions_from_area(
    image: Image.Image,
    target_area: int,
    multiple_of: int,
) -> tuple[int, int]:
    """Compute (width, height) that matches image aspect ratio and approximates target_area.

    The result is floored to be divisible by multiple_of (typically vae_scale_factor*2).
    """
    w, h = image.size
    # Guard degenerate zero-height inputs with a square aspect.
    aspect = w / h if h else 1.0

    # Use the pipeline's own area->(w,h) helper for consistency.
    from qwenimage.pipeline_qwenimage_edit_plus import calculate_dimensions

    width, height = calculate_dimensions(int(target_area), float(aspect))
    # Snap both edges down to the required multiple (never below it).
    width = _round_to_multiple(int(width), int(multiple_of))
    height = _round_to_multiple(int(height), int(multiple_of))
    return width, height
506
+
507
def get_target_area_for_lora(
    image: Image.Image,
    lora_adapter: str,
    user_target_megapixels: float,
) -> int:
    """Return target pixel area for the canvas.

    Priority:
      1) Adapter spec: target_area (pixels) or target_megapixels
      2) Adapter spec: target_long_edge (legacy) -> converted to area using image aspect
      3) User slider target megapixels
    """
    spec = ADAPTER_SPECS.get(lora_adapter, {})

    # Each spec lookup is best-effort: a malformed value simply falls
    # through to the next priority rather than failing the request.
    if "target_area" in spec:
        try:
            return int(spec["target_area"])
        except Exception:
            pass

    if "target_megapixels" in spec:
        try:
            mp = float(spec["target_megapixels"])
            return int(mp * 1024 * 1024)
        except Exception:
            pass

    # Legacy support (e.g. Upscale2K)
    if "target_long_edge" in spec:
        try:
            long_edge = int(spec["target_long_edge"])
            w, h = image.size
            # Scale the short edge to preserve the input aspect ratio.
            if w >= h:
                new_w = long_edge
                new_h = int(round(long_edge * (h / w)))
            else:
                new_h = long_edge
                new_w = int(round(long_edge * (w / h)))
            return int(new_w * new_h)
        except Exception:
            pass

    # User default
    try:
        mp = float(user_target_megapixels)
    except Exception:
        mp = 1.0

    # Treat 0 MP as "match input area"
    if mp <= 0:
        w, h = image.size
        return int(w * h)

    return int(mp * 1024 * 1024)
561
+
562
+ # ============================================================
563
+ # Helpers: multi-input routing + gallery normalization
564
+ # ============================================================
565
+
566
+
567
def lora_requires_two_images(lora_adapter: str) -> bool:
    """True when the selected adapter needs a second reference image."""
    spec = ADAPTER_SPECS.get(lora_adapter, {})
    return bool(spec.get("requires_two_images", False))


def image2_label_for_lora(lora_adapter: str) -> str:
    """Label for the second upload box, with an adapter-specific override."""
    spec = ADAPTER_SPECS.get(lora_adapter, {})
    return str(spec.get("image2_label", "Upload Reference (Image 2)"))
573
+
574
+
575
def _to_pil_rgb(x) -> Optional[Image.Image]:
    """
    Normalize a gallery item to a PIL RGB image.

    Handles PIL images, numpy arrays, and the (image, caption) tuples that
    gr.Gallery commonly yields; returns None when nothing usable is found.
    """
    # Unwrap gr.Gallery's (image, caption) tuple form first.
    if isinstance(x, tuple):
        x = x[0] if x else None

    if x is None:
        return None

    if isinstance(x, Image.Image):
        return x.convert("RGB")

    if isinstance(x, np.ndarray):
        return Image.fromarray(x).convert("RGB")

    # Last resort: let numpy try to coerce whatever this is.
    try:
        return Image.fromarray(np.array(x)).convert("RGB")
    except Exception:
        return None
600
+
601
+
602
def build_labeled_images(
    img1: Image.Image,
    img2: Optional[Image.Image],
    extra_imgs: Optional[list[Image.Image]],
) -> dict[str, Image.Image]:
    """
    Assign sequential labels image_1, image_2, ... to the uploaded images.

    img1 is always image_1; img2 takes the next slot only when present, and
    any non-None extras continue numbering after it. Insertion order is the
    exact order the pipeline receives the images.
    """
    labeled: dict[str, Image.Image] = {"image_1": img1}
    slot = 2

    if img2 is not None:
        labeled[f"image_{slot}"] = img2
        slot += 1

    for extra in extra_imgs or []:
        if extra is None:
            continue
        labeled[f"image_{slot}"] = extra
        slot += 1

    return labeled
632
+
633
+
634
+ # ============================================================
635
+ # Helpers: BFS alpha key fix
636
+ # ============================================================
637
+
638
+
639
+ def _inject_missing_alpha_keys(state_dict: dict) -> dict:
640
+ """
641
+ Diffusers' Qwen LoRA converter expects '<module>.alpha' keys.
642
+ BFS safetensors omits them. We inject alpha = rank (neutral scaling).
643
+
644
+ IMPORTANT: diffusers may strip 'diffusion_model.' before lookup, so we
645
+ inject BOTH:
646
+ - diffusion_model.xxx.alpha
647
+ - xxx.alpha
648
+ """
649
+ bases = {}
650
+
651
+ for k, v in state_dict.items():
652
+ if not isinstance(v, torch.Tensor):
653
+ continue
654
+ if k.endswith(".lora_down.weight") and v.ndim >= 1:
655
+ base = k[: -len(".lora_down.weight")]
656
+ rank = int(v.shape[0])
657
+ bases[base] = rank
658
+
659
+ for base, rank in bases.items():
660
+ alpha_tensor = torch.tensor(float(rank), dtype=torch.float32)
661
+
662
+ full_alpha = f"{base}.alpha"
663
+ if full_alpha not in state_dict:
664
+ state_dict[full_alpha] = alpha_tensor
665
+
666
+ if base.startswith("diffusion_model."):
667
+ stripped_base = base[len("diffusion_model.") :]
668
+ stripped_alpha = f"{stripped_base}.alpha"
669
+ if stripped_alpha not in state_dict:
670
+ state_dict[stripped_alpha] = alpha_tensor
671
+
672
+ return state_dict
673
+
674
+
675
+ def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
676
+ """Return (filtered_state_dict, stats).
677
+
678
+ Some ComfyUI/Qwen safetensors (especially "merged" variants) include non-LoRA
679
+ delta/patch keys like `*.diff` and `*.diff_b` alongside real LoRA tensors.
680
+ Diffusers' internal Qwen LoRA converter is strict: any leftover keys cause an
681
+ error (`state_dict should be empty...`).
682
+
683
+ This helper keeps only the keys Diffusers can consume as a LoRA:
684
+ - `*.lora_up.weight`
685
+ - `*.lora_down.weight`
686
+ - (rare) `*.lora_mid.weight`
687
+ - alpha keys: `*.alpha` (or `*.lora_alpha` which we normalize to `*.alpha`)
688
+
689
+ It also drops known patch keys (`*.diff`, `*.diff_b`) and everything else.
690
+ """
691
+
692
+ keep_suffixes = (
693
+ ".lora_up.weight",
694
+ ".lora_down.weight",
695
+ ".lora_mid.weight",
696
+ ".alpha",
697
+ ".lora_alpha",
698
+ )
699
+
700
+ dropped_patch = 0
701
+ dropped_other = 0
702
+ kept = 0
703
+ normalized_alpha = 0
704
+
705
+ out: dict[str, torch.Tensor] = {}
706
+ for k, v in state_dict.items():
707
+ if not isinstance(v, torch.Tensor):
708
+ # Ignore non-tensor entries if any.
709
+ dropped_other += 1
710
+ continue
711
+
712
+ # Drop ComfyUI "delta" keys that Diffusers' LoRA loader will never consume.
713
+ if k.endswith(".diff") or k.endswith(".diff_b"):
714
+ dropped_patch += 1
715
+ continue
716
+
717
+ if not k.endswith(keep_suffixes):
718
+ dropped_other += 1
719
+ continue
720
+
721
+ if k.endswith(".lora_alpha"):
722
+ # Normalize common alt name to what Diffusers expects.
723
+ base = k[: -len(".lora_alpha")]
724
+ k2 = f"{base}.alpha"
725
+ out[k2] = v.float() if v.dtype != torch.float32 else v
726
+ normalized_alpha += 1
727
+ kept += 1
728
+ continue
729
+
730
+ out[k] = v
731
+ kept += 1
732
+
733
+ stats = {
734
+ "kept": kept,
735
+ "dropped_patch": dropped_patch,
736
+ "dropped_other": dropped_other,
737
+ "normalized_alpha": normalized_alpha,
738
+ }
739
+ return out, stats
740
+
741
+
742
+ def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_model.") -> dict:
743
+ """Ensure both prefixed and unprefixed variants exist for LoRA-related keys.
744
+
745
+ Diffusers' Qwen LoRA conversion may strip `diffusion_model.` when looking up
746
+ modules. Some exports only include prefixed keys. To be maximally compatible,
747
+ we duplicate LoRA keys (and alpha) in stripped form when missing.
748
+ """
749
+
750
+ out = dict(state_dict)
751
+ for k, v in list(state_dict.items()):
752
+ if not k.startswith(prefix):
753
+ continue
754
+ stripped = k[len(prefix) :]
755
+ if stripped not in out:
756
+ out[stripped] = v
757
+ return out
758
+
759
+
760
def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
    """
    Load a LoRA into the global `pipe`, with a cleanup fallback for non-standard exports.

    Normal path: pipe.load_lora_weights(repo, weight_name=..., adapter_name=...)
    BFS fallback: download safetensors, inject missing alpha keys, then load from dict.

    Args:
        repo: Hugging Face repo id hosting the LoRA.
        weight_name: safetensors filename inside the repo.
        adapter_name: adapter name to register on the pipeline.
        needs_alpha_fix: opt-in flag for the fallback path; when False the
            first-attempt error is re-raised untouched.

    Raises:
        KeyError / ValueError: propagated from the first attempt when
        `needs_alpha_fix` is False (see inline notes for their typical causes).
    """
    try:
        pipe.load_lora_weights(repo, weight_name=weight_name, adapter_name=adapter_name)
        return
    except (KeyError, ValueError) as e:
        # KeyError: missing required alpha keys (common in BFS)
        # ValueError: Diffusers Qwen converter found leftover keys (e.g. .diff/.diff_b)
        if not needs_alpha_fix:
            raise

        print(
            "⚠️ LoRA load failed (will try safe dict fallback). "
            f"Adapter={adapter_name!r} file={weight_name!r} error={type(e).__name__}: {e}"
        )

        # Fetch the raw checkpoint and rebuild a Diffusers-compatible dict.
        local_path = hf_hub_download(repo_id=repo, filename=weight_name)
        sd = safetensors_load_file(local_path)

        # 1) Inject required `<module>.alpha` keys (neutral scaling alpha=rank).
        sd = _inject_missing_alpha_keys(sd)

        # 2) Keep only LoRA + alpha keys; drop ComfyUI patch/delta keys.
        sd, stats = _filter_to_diffusers_lora_keys(sd)

        # 3) Duplicate stripped keys (remove `diffusion_model.`) for compatibility.
        sd = _duplicate_stripped_prefix_keys(sd)

        print(
            "🧹 LoRA dict cleanup stats: "
            f"kept={stats['kept']} dropped_patch={stats['dropped_patch']} "
            f"dropped_other={stats['dropped_other']} normalized_alpha={stats['normalized_alpha']}"
        )

        # Load from the in-memory dict instead of the hub file.
        pipe.load_lora_weights(sd, adapter_name=adapter_name)
        return
799
+
800
+
801
+ # ============================================================
802
+ # LoRA loader: single/package + strengths
803
+ # ============================================================
804
+
805
+
806
def _ensure_loaded_and_get_active_adapters(selected_lora: str):
    """Resolve a UI selection to loaded pipeline adapters.

    Looks up `selected_lora` in the module-level ADAPTER_SPECS, lazily loads
    any adapter file not yet in LOADED_ADAPTERS (via
    `_load_lora_weights_with_fallback`), and returns the adapter names with
    their configured strengths, ready for `pipe.set_adapters(...)`.

    Supports two spec shapes:
      - "package": several LoRA parts activated together, each with its own
        repo/file/strength.
      - single: one repo/file/strength.

    Returns:
        (adapter_names, adapter_weights): parallel lists.

    Raises:
        gr.Error: unknown selection, empty package, or a failed load.
    """
    spec = ADAPTER_SPECS.get(selected_lora)
    if not spec:
        raise gr.Error(f"Configuration not found for: {selected_lora}")

    adapter_names = []
    adapter_weights = []

    if spec.get("type") == "package":
        parts = spec.get("parts", [])
        if not parts:
            raise gr.Error(f"Package spec has no parts: {selected_lora}")

        for part in parts:
            repo = part["repo"]
            weights = part["weights"]
            adapter_name = part["adapter_name"]
            strength = float(part.get("strength", 1.0))
            needs_alpha_fix = bool(part.get("needs_alpha_fix", False))

            if adapter_name not in LOADED_ADAPTERS:
                print(f"--- Downloading and Loading Adapter Part: {selected_lora} / {adapter_name} ---")
                try:
                    _load_lora_weights_with_fallback(
                        repo=repo,
                        weight_name=weights,
                        adapter_name=adapter_name,
                        needs_alpha_fix=needs_alpha_fix,
                    )
                    # Mark as loaded only after a successful load.
                    LOADED_ADAPTERS.add(adapter_name)
                except Exception as e:
                    raise gr.Error(f"Failed to load adapter part {selected_lora}/{adapter_name}: {e}")
            else:
                print(f"--- Adapter part already loaded: {selected_lora} / {adapter_name} ---")

            adapter_names.append(adapter_name)
            adapter_weights.append(strength)

    else:
        # Single-adapter spec.
        repo = spec["repo"]
        weights = spec["weights"]
        adapter_name = spec["adapter_name"]
        strength = float(spec.get("strength", 1.0))
        needs_alpha_fix = bool(spec.get("needs_alpha_fix", False))

        if adapter_name not in LOADED_ADAPTERS:
            print(f"--- Downloading and Loading Adapter: {selected_lora} ---")
            try:
                _load_lora_weights_with_fallback(
                    repo=repo,
                    weight_name=weights,
                    adapter_name=adapter_name,
                    needs_alpha_fix=needs_alpha_fix,
                )
                LOADED_ADAPTERS.add(adapter_name)
            except Exception as e:
                raise gr.Error(f"Failed to load adapter {selected_lora}: {e}")
        else:
            print(f"--- Adapter {selected_lora} is already loaded. ---")

        adapter_names = [adapter_name]
        adapter_weights = [strength]

    return adapter_names, adapter_weights
870
+
871
+
872
+ # ============================================================
873
+ # UI handlers
874
+ # ============================================================
875
+
876
+
877
+
878
def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
    """Update dependent UI widgets when the LoRA dropdown changes.

    Returns gr.update() objects for, in order: the prompt textbox, the
    Image 2 uploader (visibility/label), the extras-conditioning checkbox,
    and the 3D-camera container (visible only for the "3D-Camera" LoRA).
    """
    prompt_val = current_prompt
    if selected_lora != NONE_LORA:
        preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
        if preset:
            prompt_val = preset
        else:
            # A LoRA is active but has no preset: clear the prompt so the
            # previous LoRA's preset text does not leak into this run.
            prompt_val = ""

    prompt_update = gr.update(value=prompt_val)
    camera_update = gr.update(visible=(selected_lora == "3D-Camera"))

    # Image2 visibility/label: two-image LoRAs show the uploader with a
    # LoRA-specific label; otherwise hide it and drop any stale image.
    if lora_requires_two_images(selected_lora):
        img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
    else:
        img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')

    # Extra references routing default: face-swap / pose LoRAs force
    # conditioning-only extras; everything else keeps the user's choice.
    if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):
        extras_update = gr.update(value=True)
    else:
        extras_update = gr.update(value=extras_condition_only)

    return prompt_update, img2_update, extras_update, camera_update
903
+ # ============================================================
904
+ # UI helpers: output routing + derived conditioning
905
+
906
+ def _append_to_gallery(existing_gallery, new_image):
907
+ if existing_gallery is None:
908
+ return [new_image]
909
+ if not isinstance(existing_gallery, list):
910
+ existing_gallery = [existing_gallery]
911
+ existing_gallery.append(new_image)
912
+ return existing_gallery
913
+
914
+ # ============================================================
915
+
916
def set_output_as_image1(last):
    """Route the most recent generated output into the Image 1 input slot."""
    if last is not None:
        return gr.update(value=last)
    raise gr.Error("No output available yet.")
920
+
921
+
922
def set_output_as_image2(last):
    """Route the most recent generated output into the Image 2 reference slot."""
    if last is not None:
        return gr.update(value=last)
    raise gr.Error("No output available yet.")
926
+
927
+
928
def set_output_as_extra(last, existing_extra):
    """Append the most recent generated output to the extra-references gallery."""
    if last is not None:
        return _append_to_gallery(existing_extra, last)
    raise gr.Error("No output available yet.")
932
+
933
+
934
@spaces.GPU
def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu):
    """Derive a conditioning image (e.g. depth map) from Image 1 and append it
    to the extras gallery.

    Returns two gr.update() objects: the updated extras gallery value and the
    derived-image preview (hidden when derived_type is "None").

    Raises:
        gr.Error: when Image 1 is missing or the derived type is unknown.
    """
    if img1 is None:
        raise gr.Error("Please upload Image 1 first.")

    if derived_type == "None":
        # Nothing to derive: keep the gallery as-is and hide the preview.
        return gr.update(value=existing_extra), gr.update(visible=False, value=None)

    base = img1.convert("RGB")

    if derived_type == "Depth (Depth Anything V2 Small)":
        derived = make_depth_map(base, use_gpu=bool(derived_use_gpu))
    else:
        raise gr.Error(f"Unknown derived type: {derived_type}")

    new_gallery = _append_to_gallery(existing_extra, derived)
    return gr.update(value=new_gallery), gr.update(visible=True, value=derived)
951
+
952
+
953
+ # ============================================================
954
+ # Inference
955
+ # ============================================================
956
+
957
+
958
@spaces.GPU
def infer(
    input_image_1,
    input_image_2,
    input_images_extra,  # gallery multi-image box
    prompt,
    lora_adapter,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    target_megapixels,
    extras_condition_only,
    pad_to_canvas,
    vae_tiling,  # VAE tiling toggle
    resolution_multiple,
    vae_ref_megapixels,
    decoder_vae,
    keep_decoder_2x,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one Qwen-Image-Edit generation with the selected LoRA.

    Flow: activate/deactivate adapters -> normalize input images -> compute a
    canvas size from Image 1's area -> call the pipeline -> return the result.

    Returns:
        (result_image, seed, result_image): the third element feeds the
        `last_output` state so the output-routing buttons can reuse it.

    Raises:
        gr.Error: missing Image 1, or a two-image LoRA without Image 2.
    """
    # Free as much memory as possible before starting (Spaces GPU workers
    # are shared across requests).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if input_image_1 is None:
        raise gr.Error("Please upload Image 1.")

    # Handle "None": run the base model with all adapters deactivated.
    if lora_adapter == NONE_LORA:
        try:
            pipe.set_adapters([], adapter_weights=[])
        except Exception:
            # Some diffusers versions reject an empty adapter list; fall back
            # to zeroing out every loaded adapter instead.
            if LOADED_ADAPTERS:
                pipe.set_adapters(list(LOADED_ADAPTERS), adapter_weights=[0.0] * len(LOADED_ADAPTERS))
    else:
        adapter_names, adapter_weights = _ensure_loaded_and_get_active_adapters(lora_adapter)
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator(device=device).manual_seed(seed)
    negative_prompt = (
        "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
        "extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
    )

    img1 = input_image_1.convert("RGB")
    img2 = input_image_2.convert("RGB") if input_image_2 is not None else None

    # Normalize extra images (Gallery) to PIL RGB (handles tuples from Gallery)
    extra_imgs: list[Image.Image] = []
    if input_images_extra:
        for item in input_images_extra:
            pil = _to_pil_rgb(item)
            if pil is not None:
                extra_imgs.append(pil)

    # Enforce existing 2-image LoRA behavior (image_1 + image_2 required)
    if lora_requires_two_images(lora_adapter) and img2 is None:
        raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")

    # Label images as image_1, image_2, image_3...
    labeled = build_labeled_images(img1, img2, extra_imgs)

    # Pass to pipeline in labeled order. Keep single-image call when only one is present.
    pipe_images = list(labeled.values())
    if len(pipe_images) == 1:
        pipe_images = pipe_images[0]

    # Resolution derived from Image 1 (base/body/target)
    # Use target *area* (≈ megapixels) rather than long-edge sizing to reduce FOV drift.
    target_area = get_target_area_for_lora(img1, lora_adapter, float(target_megapixels))
    width, height = compute_canvas_dimensions_from_area(
        img1,
        target_area=target_area,
        multiple_of=int(resolution_multiple),
    )

    # Decide which images participate in the VAE latent stream.
    # If enabled, extra references beyond (Img_1, Img_2) become conditioning-only.
    vae_image_indices = None
    if extras_condition_only:
        if isinstance(pipe_images, list) and len(pipe_images) > 2:
            vae_image_indices = [0, 1] if len(pipe_images) >= 2 else [0]

    try:
        print(
            "[DEBUG][infer] submitting request | "
            f"lora_adapter={lora_adapter!r} seed={seed} prompt={prompt!r}"
        )
        print(f"[DEBUG][infer] canvas={width}x{height} (~{(width*height)/1_048_576:.3f} MP) vae_tiling={bool(vae_tiling)}")

        # ✅ Apply UI toggle per-request (OFF by default)
        # Lattice multiple passed to pipeline too (anti-drift / valid size grid)
        res_mult = int(resolution_multiple) if resolution_multiple is not None else int(pipe.vae_scale_factor * 2)

        # Optional: override VAE sizing for *extra* references (beyond Image 1 / Image 2)
        # Interpreted as megapixels; 0 disables override (uses canvas).
        try:
            mp_ref = float(vae_ref_megapixels)
        except Exception:
            mp_ref = 0.0

        vae_ref_area = int(mp_ref * 1024 * 1024) if mp_ref and mp_ref > 0 else None

        # Extras start index depends on whether Image 2 exists
        base_ref_count = 2 if img2 is not None else 1

        _apply_vae_tiling(bool(vae_tiling))

        result = pipe(
            image=pipe_images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            generator=generator,
            true_cfg_scale=guidance_scale,
            vae_image_indices=vae_image_indices,
            pad_to_canvas=bool(pad_to_canvas),
            resolution_multiple=res_mult,
            vae_ref_area=vae_ref_area,
            vae_ref_start_index=base_ref_count,
            decoder_vae=str(decoder_vae).lower(),
            keep_decoder_2x=bool(keep_decoder_2x),
        ).images[0]
        return result, seed, result
    finally:
        # Always release GPU memory, even on failure.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
1092
+
1093
+
1094
@spaces.GPU
def infer_example(input_image, prompt, lora_adapter):
    """Run `infer` with fixed example defaults (used by the examples widget).

    Args:
        input_image: PIL image or None (None short-circuits to empty outputs).
        prompt: edit prompt for the example.
        lora_adapter: LoRA selection name.

    Returns:
        (result_image, seed, last_output), or (None, 0, None) without an image.
    """
    if input_image is None:
        return None, 0, None
    input_pil = input_image.convert("RGB")
    guidance_scale = 1.0
    steps = 4
    # Examples don't supply Image 2 or extra images; and example list doesn't include AnyPose/BFS.
    # Keep VAE tiling OFF in examples (matches default).
    # Bug fix: `infer` takes four more required parameters than were passed here
    # (resolution_multiple, vae_ref_megapixels, decoder_vae, keep_decoder_2x),
    # so this call raised TypeError. Supply the same defaults the UI uses.
    result, seed, last = infer(
        input_pil,
        None,     # input_image_2
        None,     # input_images_extra
        prompt,
        lora_adapter,
        0,        # seed
        True,     # randomize_seed
        guidance_scale,
        steps,
        1.0,      # target_megapixels
        True,     # extras_condition_only
        True,     # pad_to_canvas
        False,    # vae_tiling
        32,       # resolution_multiple (UI default lattice)
        0.0,      # vae_ref_megapixels (0 = use canvas)
        "qwen",   # decoder_vae (UI default)
        False,    # keep_decoder_2x
    )
    return result, seed, last
1119
+
1120
+
1121
+ # ============================================================
1122
+ # UI
1123
+ # ============================================================
1124
+
1125
# Page-level CSS: center the main column and enlarge the page title.
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
#main-title h1 {font-size: 2.1em !important;}
"""
1132
+
1133
# Markdown status line shown in the UI: which Rapid-AIO transformer build is
# active and where the choice came from (constants defined earlier in the file).
aio_status_line = (
    f"**AIO transformer version:** `{AIO_VERSION}` "
    f"({AIO_VERSION_SOURCE}; env `AIO_VERSION`={_AIO_ENV_RAW!r})"
)
1137
+
1138
+ with gr.Blocks() as demo:
1139
+ with gr.Column(elem_id="col-container"):
1140
+ gr.Markdown("# **Qwen-Image-Edit-2511-LoRAs-Fast**", elem_id="main-title")
1141
+ gr.Markdown(
1142
+ f"""This **experimental** space for [QIE-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) utilizes [extracted transformers](https://huggingface.co/Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO) of [Phr00t’s Rapid AIO merge](https://huggingface.co/Phr00t/Qwen-Image-Edit-Rapid-AIO) and FA3-optimization with [LoRA](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image-Edit-2511) support and a couple of extra features:
1143
+
1144
+ - Optional conditioning-only routing for extra reference latents
1145
+ - Uncapped canvas resolution
1146
+ - Optional VAE tiling for high resolutions
1147
+ - Optional depth mapping for conditioning
1148
+ - Optional routing of output to input for further iterations
1149
+ - Optional alternative decoder [VAE](https://huggingface.co/spacepxl/Wan2.1-VAE-upscale2x/tree/main/diffusers/Wan2.1_VAE_upscale2x_imageonly_real_v1)
1150
+
1151
+ Current environment is running **{AIO_VERSION}** of the Rapid AIO. Duplicate the space and set the **AIO_VERSION** space variable to use a different version."""
1152
+ )
1153
+ gr.Markdown(aio_status_line)
1154
+
1155
+ with gr.Row(equal_height=True):
1156
+ with gr.Column():
1157
+ input_image_1 = gr.Image(label="Upload Image 1 (Base / Target)", type="pil", )
1158
+
1159
+ input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)
1160
+
1161
+ with gr.Column(visible=False) as camera_container:
1162
+ gr.Markdown("### 🎮 3D Camera Control\n*Drag handles: 🟢 Azimuth, 🩷 Elevation, 🟠 Distance*")
1163
+ camera_3d = CameraControl3D(value={"azimuth": 0, "elevation": 0, "distance": 1.0}, elem_id="camera-3d-control")
1164
+ gr.Markdown("### 🎚️ Slider Controls")
1165
+ azimuth_slider = gr.Slider(label="Azimuth", minimum=0, maximum=315, step=45, value=0, info="0°=front, 90°=right, 180°=back, 270°=left")
1166
+ elevation_slider = gr.Slider(label="Elevation", minimum=-30, maximum=60, step=30, value=0, info="-30°=low angle, 0°=eye, 60°=high angle")
1167
+ distance_slider = gr.Slider(label="Distance", minimum=0.6, maximum=1.4, step=0.4, value=1.0, info="0.6=close, 1.0=medium, 1.4=wide")
1168
+
1169
+
1170
+ input_images_extra = gr.Gallery(
1171
+ label="Upload Additional Images (auto-indexed after Image 1/2)",
1172
+ type="pil",
1173
+ height=290,
1174
+ columns=4,
1175
+ rows=2,
1176
+ interactive=True,
1177
+ )
1178
+
1179
+ prompt = gr.Text(
1180
+ label="Edit Prompt",
1181
+ show_label=True,
1182
+ placeholder="e.g., transform into photo..",
1183
+ )
1184
+
1185
+ run_button = gr.Button("Edit Image", variant="primary")
1186
+
1187
+ with gr.Column():
1188
+ output_image = gr.Image(label="Output Image", interactive=False, format="png", height=353)
1189
+
1190
+ last_output = gr.State(value=None)
1191
+
1192
+ with gr.Row():
1193
+ btn_out_to_img1 = gr.Button("⬅️ Output → Image 1", variant="secondary")
1194
+ btn_out_to_img2 = gr.Button("⬅️ Output → Image 2", variant="secondary")
1195
+ btn_out_to_extra = gr.Button("➕ Output → Extra Ref", variant="secondary")
1196
+
1197
+ derived_preview = gr.Image(
1198
+ label="Derived Conditioning Preview",
1199
+ interactive=False,
1200
+ format="png",
1201
+ height=200,
1202
+ visible=False,
1203
+ )
1204
+
1205
+ with gr.Row():
1206
+ lora_choices = [NONE_LORA] + list(ADAPTER_SPECS.keys())
1207
+ lora_adapter = gr.Dropdown(
1208
+ label="Choose Editing Style",
1209
+ choices=lora_choices,
1210
+ value=NONE_LORA,
1211
+ )
1212
+
1213
+ with gr.Accordion("Advanced Settings", open=False, visible=True):
1214
+ with gr.Accordion("Derived Conditioning (Pose / Depth)", open=False):
1215
+ derived_type = gr.Dropdown(
1216
+ label="Derived Type (from Image 1)",
1217
+ choices=["None", "Depth (Depth Anything V2 Small)"],
1218
+ value="None",
1219
+ )
1220
+ derived_use_gpu = gr.Checkbox(label="Use GPU for derived model", value=False)
1221
+ add_derived_btn = gr.Button("➕ Add derived ref to Extras (conditioning-only recommended)")
1222
+
1223
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
1224
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
1225
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
1226
+ steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=4)
1227
+ target_megapixels = gr.Slider(
1228
+ label="Target Megapixels (canvas, 0 = match input area)",
1229
+ minimum=0.0,
1230
+ maximum=6.0,
1231
+ step=0.1,
1232
+ value=1.0,
1233
+ )
1234
+ resolution_multiple = gr.Dropdown(
1235
+ label="Resolution lattice multiple (anti-drift)",
1236
+ choices=[32, 56, 112],
1237
+ value=32,
1238
+ interactive=True,
1239
+ )
1240
+ vae_ref_megapixels = gr.Slider(
1241
+ label="Extra refs VAE megapixels override (0 = use canvas)",
1242
+ minimum=0.0,
1243
+ maximum=6.0,
1244
+ step=0.1,
1245
+ value=0.0,
1246
+ )
1247
+ decoder_vae = gr.Dropdown(
1248
+ label="Decoder VAE",
1249
+ choices=["qwen", "wan2x"],
1250
+ value="qwen",
1251
+ interactive=True,
1252
+ )
1253
+ keep_decoder_2x = gr.Checkbox(
1254
+ label="Keep 2× output (wan2x only)",
1255
+ value=False,
1256
+ )
1257
+ extras_condition_only = gr.Checkbox(
1258
+ label="Extra references are conditioning-only (exclude from VAE)",
1259
+ value=True,
1260
+ )
1261
+ pad_to_canvas = gr.Checkbox(
1262
+ label="Pad images to canvas aspect (avoid warping)",
1263
+ value=True,
1264
+ )
1265
+
1266
+ # ✅ NEW: VAE tiling toggle (OFF by default)
1267
+ vae_tiling = gr.Checkbox(
1268
+ label="VAE tiling (lower VRAM, slower)",
1269
+ value=False,
1270
+ )
1271
+
1272
+ # On LoRA selection: preset prompt + toggle Image 2
1273
+ lora_adapter.change(
1274
+ fn=on_lora_change_ui,
1275
+ inputs=[lora_adapter, prompt, extras_condition_only],
1276
+ outputs=[prompt, input_image_2, extras_condition_only, camera_container],
1277
+ )
1278
+
1279
+ # Examples removed automatically by setup_manager
1280
+
1281
+
1282
+ # --- 3D Camera Events ---
1283
    def update_prompt_from_sliders(az, el, dist, curr_prompt):
        """Rebuild the prompt's camera trigger tag from the three slider values."""
        return update_prompt_with_camera(az, el, dist, curr_prompt)
1285
+
1286
+ def sync_3d_to_sliders(cv, curr_prompt):
1287
+ if cv and isinstance(cv, dict):
1288
+ az = cv.get('azimuth', 0)
1289
+ el = cv.get('elevation', 0)
1290
+ dist = cv.get('distance', 1.0)
1291
+ return az, el, dist, update_prompt_with_camera(az, el, dist, curr_prompt)
1292
+ return gr.update(), gr.update(), gr.update(), gr.update()
1293
+
1294
+ def sync_sliders_to_3d(az, el, dist):
1295
+ return {"azimuth": az, "elevation": el, "distance": dist}
1296
+
1297
+
1298
+ def update_3d_image(img):
1299
+ if img is None: return gr.update(imageUrl=None)
1300
+ import base64
1301
+ from io import BytesIO
1302
+ buf = BytesIO()
1303
+ img.save(buf, format="PNG")
1304
+ durl = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
1305
+ return gr.update(imageUrl=durl)
1306
+
1307
+ for slider in [azimuth_slider, elevation_slider, distance_slider]:
1308
+ slider.change(fn=update_prompt_from_sliders, inputs=[azimuth_slider, elevation_slider, distance_slider, prompt], outputs=[prompt])
1309
+ slider.release(fn=sync_sliders_to_3d, inputs=[azimuth_slider, elevation_slider, distance_slider], outputs=[camera_3d])
1310
+
1311
+ camera_3d.change(fn=sync_3d_to_sliders, inputs=[camera_3d, prompt], outputs=[azimuth_slider, elevation_slider, distance_slider, prompt])
1312
+
1313
+ input_image_1.upload(fn=update_3d_image, inputs=[input_image_1], outputs=[camera_3d])
1314
+ input_image_1.clear(fn=lambda: gr.update(imageUrl=None), outputs=[camera_3d])
1315
+
1316
+ run_button.click(
1317
+ fn=infer,
1318
+ inputs=[
1319
+ input_image_1,
1320
+ input_image_2,
1321
+ input_images_extra,
1322
+ prompt,
1323
+ lora_adapter,
1324
+ seed,
1325
+ randomize_seed,
1326
+ guidance_scale,
1327
+ steps,
1328
+ target_megapixels,
1329
+ extras_condition_only,
1330
+ pad_to_canvas,
1331
+ vae_tiling,
1332
+ resolution_multiple,
1333
+ vae_ref_megapixels,
1334
+ decoder_vae,
1335
+ keep_decoder_2x,
1336
+ ],
1337
+ outputs=[output_image, seed, last_output],
1338
+ )
1339
+
1340
+ # Output routing buttons
1341
+ btn_out_to_img1.click(fn=set_output_as_image1, inputs=[last_output], outputs=[input_image_1])
1342
+ btn_out_to_img2.click(fn=set_output_as_image2, inputs=[last_output], outputs=[input_image_2])
1343
+ btn_out_to_extra.click(fn=set_output_as_extra, inputs=[last_output, input_images_extra], outputs=[input_images_extra])
1344
+
1345
+ # Derived conditioning: append pose/depth map as extra ref (UI shows preview)
1346
+ add_derived_btn.click(
1347
+ fn=add_derived_ref,
1348
+ inputs=[input_image_1, input_images_extra, derived_type, derived_use_gpu],
1349
+ outputs=[input_images_extra, derived_preview],
1350
+ )
1351
+
1352
+ if __name__ == "__main__":
1353
+ head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
1354
+ demo.queue(max_size=30).launch(head=head, server_name="0.0.0.0", share=True,
1355
+ css=css,
1356
+ theme=orange_red_theme,
1357
+ mcp_server=True,
1358
+ ssr_mode=False,
1359
+ show_error=True,
1360
+ )
1361
+
1362
# Manual Patch for missing prompts
# NOTE(review): this runs *after* the `if __name__ == "__main__"` launch above;
# when the file is executed as a script, `launch()` normally blocks, so these
# presets would only be registered after shutdown. If the intent is for them to
# be active at runtime, this block should run before `demo` is built/launched —
# TODO confirm how the hosting environment executes this file (run vs import).
try:
    LORA_PRESET_PROMPTS.update({
        "Consistance": "improve consistency and quality of the generated image",
        "F2P": "transform the image into a high-quality photo with realistic details",
        "Multiple-Angles": "change the camera angle of the image",
        "Light-Restoration": "Remove shadows and relight the image using soft lighting",
        "Relight": "Relight the image with cinematic lighting",
        "Multi-Angle-Lighting": "Change the lighting direction and intensity",
        "Edit-Skin": "Enhance skin textures and natural details",
        "Next-Scene": "Generate the next scene based on the current image",
        "Flat-Log": "Desaturate and lower contrast for a flat log look",
        "Upscale-Image": "Enhance and sharpen the image details",
        "BFS-Best-FaceSwap": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
        "BFS-Best-FaceSwap-merge": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
        "Qwen-lora-nsfw": "Convert this picture to artistic style.",
    })
except NameError:
    # LORA_PRESET_PROMPTS may not exist in trimmed deployments; best-effort patch.
    pass
camera_control_ui.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
# Azimuth mappings (8 positions)
AZIMUTH_MAP = {
    0: "front view",
    45: "front-right quarter view",
    90: "right side view",
    135: "back-right quarter view",
    180: "back view",
    225: "back-left quarter view",
    270: "left side view",
    315: "front-left quarter view",
}

# Elevation mappings (4 positions)
ELEVATION_MAP = {
    -30: "low-angle shot",
    0: "eye-level shot",
    30: "elevated shot",
    60: "high-angle shot",
}

# Distance mappings (3 positions).
# Bug fix: the app's distance slider (min 0.6, max 1.4, step 0.4) and the JS
# 3D control both use 0.6 / 1.0 / 1.4, but "wide shot" was keyed at 1.8 here.
# With 1.8, snapping the slider's 1.4 ties between 1.0 and 1.8 and resolves to
# 1.0, so the wide-shot position produced "medium shot". Key it at 1.4.
DISTANCE_MAP = {
    0.6: "close-up",
    1.0: "medium shot",
    1.4: "wide shot",
}
29
+
30
+
31
def snap_to_nearest(value, options):
    """Return the element of *options* closest to *value*.

    Ties resolve to the earliest option; an empty *options* raises ValueError
    (``min`` semantics).
    """
    def distance(option):
        return abs(option - value)

    return min(options, key=distance)
34
+
35
+
36
def build_camera_prompt(azimuth: float, elevation: float, distance: float) -> str:
    """
    Build a camera prompt from azimuth, elevation, and distance values.

    Each value is snapped to the nearest key of its mapping table
    (AZIMUTH_MAP / ELEVATION_MAP / DISTANCE_MAP defined above), so arbitrary
    inputs always resolve to one of the supported camera descriptions.

    Args:
        azimuth: Horizontal rotation in degrees (snapped to AZIMUTH_MAP keys).
        elevation: Vertical angle in degrees (snapped to ELEVATION_MAP keys).
        distance: Distance factor (snapped to DISTANCE_MAP keys).

    Returns:
        Formatted "<sks> <azimuth> <elevation> <distance>" trigger string
        for the camera LoRA.
    """
    # Snap to nearest valid values
    azimuth_snapped = snap_to_nearest(azimuth, list(AZIMUTH_MAP.keys()))
    elevation_snapped = snap_to_nearest(elevation, list(ELEVATION_MAP.keys()))
    distance_snapped = snap_to_nearest(distance, list(DISTANCE_MAP.keys()))

    azimuth_name = AZIMUTH_MAP[azimuth_snapped]
    elevation_name = ELEVATION_MAP[elevation_snapped]
    distance_name = DISTANCE_MAP[distance_snapped]

    return f"<sks> {azimuth_name} {elevation_name} {distance_name}"
58
+
59
def update_prompt_with_camera(azimuth: float, elevation: float, distance: float, current_prompt: str) -> str:
    """
    Update the existing prompt by replacing or appending the camera trigger words.

    Any previous "<sks> ... shot" tag is stripped and the freshly built camera
    string is appended at the end.
    """
    import re
    camera_str = build_camera_prompt(azimuth, elevation, distance)

    if not current_prompt:
        return camera_str

    # Remove any existing <sks> ... shot tags.
    # NOTE(review): the negative lookahead makes this match stretch from
    # "<sks>" to the *last* occurrence of "shot" in the prompt — if the user's
    # own text contains the word "shot" after the tag, that text is removed
    # too. Confirm this is intended.
    clean_prompt = re.sub(r"<sks>.*?shot(?!.*shot)", "", current_prompt).strip()

    # Clean up multiple spaces
    clean_prompt = re.sub(r"\s+", " ", clean_prompt)

    if clean_prompt:
        return f"{clean_prompt} {camera_str}"
    return camera_str
79
+
80
+
81
+
82
+ # --- 3D Camera Control Component ---
83
+ class CameraControl3D(gr.HTML):
84
+ """
85
+ A 3D camera control component using Three.js.
86
+ Outputs: { azimuth: number, elevation: number, distance: number }
87
+ Accepts imageUrl prop to display user's uploaded image on the plane.
88
+ """
89
+ def __init__(self, value=None, imageUrl=None, **kwargs):
90
+ if value is None:
91
+ value = {"azimuth": 0, "elevation": 0, "distance": 1.0}
92
+
93
+ html_template = """
94
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
95
+ <div id="camera-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
96
+ <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 12px; color: #00ff88; white-space: nowrap; z-index: 10;"></div>
97
+ </div>
98
+ """
99
+
100
+ js_on_load = """
101
+ (() => {
102
+ const wrapper = element.querySelector('#camera-control-wrapper');
103
+ const promptOverlay = element.querySelector('#prompt-overlay');
104
+
105
+ // Wait for THREE to load
106
+ const initScene = () => {
107
+ if (typeof THREE === 'undefined') {
108
+ setTimeout(initScene, 100);
109
+ return;
110
+ }
111
+
112
+ // Scene setup
113
+ const scene = new THREE.Scene();
114
+ scene.background = new THREE.Color(0x1a1a1a);
115
+
116
+ const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
117
+ camera.position.set(4.5, 3, 4.5);
118
+ camera.lookAt(0, 0.75, 0);
119
+
120
+ const renderer = new THREE.WebGLRenderer({ antialias: true });
121
+ renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
122
+ renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
123
+ wrapper.insertBefore(renderer.domElement, promptOverlay);
124
+
125
+ // Lighting
126
+ scene.add(new THREE.AmbientLight(0xffffff, 0.6));
127
+ const dirLight = new THREE.DirectionalLight(0xffffff, 0.6);
128
+ dirLight.position.set(5, 10, 5);
129
+ scene.add(dirLight);
130
+
131
+ // Grid
132
+ scene.add(new THREE.GridHelper(8, 16, 0x333333, 0x222222));
133
+
134
+ // Constants - reduced distances for tighter framing
135
+ const CENTER = new THREE.Vector3(0, 0.75, 0);
136
+ const BASE_DISTANCE = 1.6;
137
+ const AZIMUTH_RADIUS = 2.4;
138
+ const ELEVATION_RADIUS = 1.8;
139
+
140
+ // State
141
+ let azimuthAngle = props.value?.azimuth || 0;
142
+ let elevationAngle = props.value?.elevation || 0;
143
+ let distanceFactor = props.value?.distance || 1.0;
144
+
145
+ // Mappings - reduced wide shot multiplier
146
+ const azimuthSteps = [0, 45, 90, 135, 180, 225, 270, 315];
147
+ const elevationSteps = [-30, 0, 30, 60];
148
+ const distanceSteps = [0.6, 1.0, 1.4];
149
+
150
+ const azimuthNames = {
151
+ 0: 'front view', 45: 'front-right quarter view', 90: 'right side view',
152
+ 135: 'back-right quarter view', 180: 'back view', 225: 'back-left quarter view',
153
+ 270: 'left side view', 315: 'front-left quarter view'
154
+ };
155
+ const elevationNames = { '-30': 'low-angle shot', '0': 'eye-level shot', '30': 'elevated shot', '60': 'high-angle shot' };
156
+ const distanceNames = { '0.6': 'close-up', '1': 'medium shot', '1.4': 'wide shot' };
157
+
158
+ function snapToNearest(value, steps) {
159
+ return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
160
+ }
161
+
162
+ // Create placeholder texture (smiley face)
163
+ function createPlaceholderTexture() {
164
+ const canvas = document.createElement('canvas');
165
+ canvas.width = 256;
166
+ canvas.height = 256;
167
+ const ctx = canvas.getContext('2d');
168
+ ctx.fillStyle = '#3a3a4a';
169
+ ctx.fillRect(0, 0, 256, 256);
170
+ ctx.fillStyle = '#ffcc99';
171
+ ctx.beginPath();
172
+ ctx.arc(128, 128, 80, 0, Math.PI * 2);
173
+ ctx.fill();
174
+ ctx.fillStyle = '#333';
175
+ ctx.beginPath();
176
+ ctx.arc(100, 110, 10, 0, Math.PI * 2);
177
+ ctx.arc(156, 110, 10, 0, Math.PI * 2);
178
+ ctx.fill();
179
+ ctx.strokeStyle = '#333';
180
+ ctx.lineWidth = 3;
181
+ ctx.beginPath();
182
+ ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
183
+ ctx.stroke();
184
+ return new THREE.CanvasTexture(canvas);
185
+ }
186
+
187
+ // Target image plane
188
+ let currentTexture = createPlaceholderTexture();
189
+ const planeMaterial = new THREE.MeshBasicMaterial({ map: currentTexture, side: THREE.DoubleSide });
190
+ let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
191
+ targetPlane.position.copy(CENTER);
192
+ scene.add(targetPlane);
193
+
194
+ // Function to update texture from image URL
195
+ function updateTextureFromUrl(url) {
196
+ if (!url) {
197
+ // Reset to placeholder
198
+ planeMaterial.map = createPlaceholderTexture();
199
+ planeMaterial.needsUpdate = true;
200
+ // Reset plane to square
201
+ scene.remove(targetPlane);
202
+ targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
203
+ targetPlane.position.copy(CENTER);
204
+ scene.add(targetPlane);
205
+ return;
206
+ }
207
+
208
+ const loader = new THREE.TextureLoader();
209
+ loader.crossOrigin = 'anonymous';
210
+ loader.load(url, (texture) => {
211
+ texture.minFilter = THREE.LinearFilter;
212
+ texture.magFilter = THREE.LinearFilter;
213
+ planeMaterial.map = texture;
214
+ planeMaterial.needsUpdate = true;
215
+
216
+ // Adjust plane aspect ratio to match image
217
+ const img = texture.image;
218
+ if (img && img.width && img.height) {
219
+ const aspect = img.width / img.height;
220
+ const maxSize = 1.5;
221
+ let planeWidth, planeHeight;
222
+ if (aspect > 1) {
223
+ planeWidth = maxSize;
224
+ planeHeight = maxSize / aspect;
225
+ } else {
226
+ planeHeight = maxSize;
227
+ planeWidth = maxSize * aspect;
228
+ }
229
+ scene.remove(targetPlane);
230
+ targetPlane = new THREE.Mesh(
231
+ new THREE.PlaneGeometry(planeWidth, planeHeight),
232
+ planeMaterial
233
+ );
234
+ targetPlane.position.copy(CENTER);
235
+ scene.add(targetPlane);
236
+ }
237
+ }, undefined, (err) => {
238
+ console.error('Failed to load texture:', err);
239
+ });
240
+ }
241
+
242
+ // Check for initial imageUrl
243
+ if (props.imageUrl) {
244
+ updateTextureFromUrl(props.imageUrl);
245
+ }
246
+
247
+ // Camera model
248
+ const cameraGroup = new THREE.Group();
249
+ const bodyMat = new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 });
250
+ const body = new THREE.Mesh(new THREE.BoxGeometry(0.3, 0.22, 0.38), bodyMat);
251
+ cameraGroup.add(body);
252
+ const lens = new THREE.Mesh(
253
+ new THREE.CylinderGeometry(0.09, 0.11, 0.18, 16),
254
+ new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 })
255
+ );
256
+ lens.rotation.x = Math.PI / 2;
257
+ lens.position.z = 0.26;
258
+ cameraGroup.add(lens);
259
+ scene.add(cameraGroup);
260
+
261
+ // GREEN: Azimuth ring
262
+ const azimuthRing = new THREE.Mesh(
263
+ new THREE.TorusGeometry(AZIMUTH_RADIUS, 0.04, 16, 64),
264
+ new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
265
+ );
266
+ azimuthRing.rotation.x = Math.PI / 2;
267
+ azimuthRing.position.y = 0.05;
268
+ scene.add(azimuthRing);
269
+
270
+ const azimuthHandle = new THREE.Mesh(
271
+ new THREE.SphereGeometry(0.18, 16, 16),
272
+ new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
273
+ );
274
+ azimuthHandle.userData.type = 'azimuth';
275
+ scene.add(azimuthHandle);
276
+
277
+ // PINK: Elevation arc
278
+ const arcPoints = [];
279
+ for (let i = 0; i <= 32; i++) {
280
+ const angle = THREE.MathUtils.degToRad(-30 + (90 * i / 32));
281
+ arcPoints.push(new THREE.Vector3(-0.8, ELEVATION_RADIUS * Math.sin(angle) + CENTER.y, ELEVATION_RADIUS * Math.cos(angle)));
282
+ }
283
+ const arcCurve = new THREE.CatmullRomCurve3(arcPoints);
284
+ const elevationArc = new THREE.Mesh(
285
+ new THREE.TubeGeometry(arcCurve, 32, 0.04, 8, false),
286
+ new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
287
+ );
288
+ scene.add(elevationArc);
289
+
290
+ const elevationHandle = new THREE.Mesh(
291
+ new THREE.SphereGeometry(0.18, 16, 16),
292
+ new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
293
+ );
294
+ elevationHandle.userData.type = 'elevation';
295
+ scene.add(elevationHandle);
296
+
297
+ // ORANGE: Distance line & handle
298
+ const distanceLineGeo = new THREE.BufferGeometry();
299
+ const distanceLine = new THREE.Line(distanceLineGeo, new THREE.LineBasicMaterial({ color: 0xffa500 }));
300
+ scene.add(distanceLine);
301
+
302
+ const distanceHandle = new THREE.Mesh(
303
+ new THREE.SphereGeometry(0.18, 16, 16),
304
+ new THREE.MeshStandardMaterial({ color: 0xffa500, emissive: 0xffa500, emissiveIntensity: 0.5 })
305
+ );
306
+ distanceHandle.userData.type = 'distance';
307
+ scene.add(distanceHandle);
308
+
309
+ function updatePositions() {
310
+ const distance = BASE_DISTANCE * distanceFactor;
311
+ const azRad = THREE.MathUtils.degToRad(azimuthAngle);
312
+ const elRad = THREE.MathUtils.degToRad(elevationAngle);
313
+
314
+ const camX = distance * Math.sin(azRad) * Math.cos(elRad);
315
+ const camY = distance * Math.sin(elRad) + CENTER.y;
316
+ const camZ = distance * Math.cos(azRad) * Math.cos(elRad);
317
+
318
+ cameraGroup.position.set(camX, camY, camZ);
319
+ cameraGroup.lookAt(CENTER);
320
+
321
+ azimuthHandle.position.set(AZIMUTH_RADIUS * Math.sin(azRad), 0.05, AZIMUTH_RADIUS * Math.cos(azRad));
322
+ elevationHandle.position.set(-0.8, ELEVATION_RADIUS * Math.sin(elRad) + CENTER.y, ELEVATION_RADIUS * Math.cos(elRad));
323
+
324
+ const orangeDist = distance - 0.5;
325
+ distanceHandle.position.set(
326
+ orangeDist * Math.sin(azRad) * Math.cos(elRad),
327
+ orangeDist * Math.sin(elRad) + CENTER.y,
328
+ orangeDist * Math.cos(azRad) * Math.cos(elRad)
329
+ );
330
+ distanceLineGeo.setFromPoints([cameraGroup.position.clone(), CENTER.clone()]);
331
+
332
+ // Update prompt
333
+ const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
334
+ const elSnap = snapToNearest(elevationAngle, elevationSteps);
335
+ const distSnap = snapToNearest(distanceFactor, distanceSteps);
336
+ const distKey = distSnap === 1 ? '1' : distSnap.toFixed(1);
337
+ const prompt = '<sks> ' + azimuthNames[azSnap] + ' ' + elevationNames[String(elSnap)] + ' ' + distanceNames[distKey];
338
+ promptOverlay.textContent = prompt;
339
+ }
340
+
341
+ function updatePropsAndTrigger() {
342
+ const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
343
+ const elSnap = snapToNearest(elevationAngle, elevationSteps);
344
+ const distSnap = snapToNearest(distanceFactor, distanceSteps);
345
+
346
+ props.value = { azimuth: azSnap, elevation: elSnap, distance: distSnap };
347
+ trigger('change', props.value);
348
+ }
349
+
350
+ // Raycasting
351
+ const raycaster = new THREE.Raycaster();
352
+ const mouse = new THREE.Vector2();
353
+ let isDragging = false;
354
+ let dragTarget = null;
355
+ let dragStartMouse = new THREE.Vector2();
356
+ let dragStartDistance = 1.0;
357
+ const intersection = new THREE.Vector3();
358
+
359
+ const canvas = renderer.domElement;
360
+
361
+ canvas.addEventListener('mousedown', (e) => {
362
+ const rect = canvas.getBoundingClientRect();
363
+ mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
364
+ mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
365
+
366
+ raycaster.setFromCamera(mouse, camera);
367
+ const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
368
+
369
+ if (intersects.length > 0) {
370
+ isDragging = true;
371
+ dragTarget = intersects[0].object;
372
+ dragTarget.material.emissiveIntensity = 1.0;
373
+ dragTarget.scale.setScalar(1.3);
374
+ dragStartMouse.copy(mouse);
375
+ dragStartDistance = distanceFactor;
376
+ canvas.style.cursor = 'grabbing';
377
+ }
378
+ });
379
+
380
+ canvas.addEventListener('mousemove', (e) => {
381
+ const rect = canvas.getBoundingClientRect();
382
+ mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
383
+ mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
384
+
385
+ if (isDragging && dragTarget) {
386
+ raycaster.setFromCamera(mouse, camera);
387
+
388
+ if (dragTarget.userData.type === 'azimuth') {
389
+ const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
390
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
391
+ azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
392
+ if (azimuthAngle < 0) azimuthAngle += 360;
393
+ }
394
+ } else if (dragTarget.userData.type === 'elevation') {
395
+ const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
396
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
397
+ const relY = intersection.y - CENTER.y;
398
+ const relZ = intersection.z;
399
+ elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
400
+ }
401
+ } else if (dragTarget.userData.type === 'distance') {
402
+ const deltaY = mouse.y - dragStartMouse.y;
403
+ distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
404
+ }
405
+ updatePositions();
406
+ } else {
407
+ raycaster.setFromCamera(mouse, camera);
408
+ const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
409
+ [azimuthHandle, elevationHandle, distanceHandle].forEach(h => {
410
+ h.material.emissiveIntensity = 0.5;
411
+ h.scale.setScalar(1);
412
+ });
413
+ if (intersects.length > 0) {
414
+ intersects[0].object.material.emissiveIntensity = 0.8;
415
+ intersects[0].object.scale.setScalar(1.1);
416
+ canvas.style.cursor = 'grab';
417
+ } else {
418
+ canvas.style.cursor = 'default';
419
+ }
420
+ }
421
+ });
422
+
423
+ const onMouseUp = () => {
424
+ if (dragTarget) {
425
+ dragTarget.material.emissiveIntensity = 0.5;
426
+ dragTarget.scale.setScalar(1);
427
+
428
+ // Snap and animate
429
+ const targetAz = snapToNearest(azimuthAngle, azimuthSteps);
430
+ const targetEl = snapToNearest(elevationAngle, elevationSteps);
431
+ const targetDist = snapToNearest(distanceFactor, distanceSteps);
432
+
433
+ const startAz = azimuthAngle, startEl = elevationAngle, startDist = distanceFactor;
434
+ const startTime = Date.now();
435
+
436
+ function animateSnap() {
437
+ const t = Math.min((Date.now() - startTime) / 200, 1);
438
+ const ease = 1 - Math.pow(1 - t, 3);
439
+
440
+ let azDiff = targetAz - startAz;
441
+ if (azDiff > 180) azDiff -= 360;
442
+ if (azDiff < -180) azDiff += 360;
443
+ azimuthAngle = startAz + azDiff * ease;
444
+ if (azimuthAngle < 0) azimuthAngle += 360;
445
+ if (azimuthAngle >= 360) azimuthAngle -= 360;
446
+
447
+ elevationAngle = startEl + (targetEl - startEl) * ease;
448
+ distanceFactor = startDist + (targetDist - startDist) * ease;
449
+
450
+ updatePositions();
451
+ if (t < 1) requestAnimationFrame(animateSnap);
452
+ else updatePropsAndTrigger();
453
+ }
454
+ animateSnap();
455
+ }
456
+ isDragging = false;
457
+ dragTarget = null;
458
+ canvas.style.cursor = 'default';
459
+ };
460
+
461
+ canvas.addEventListener('mouseup', onMouseUp);
462
+ canvas.addEventListener('mouseleave', onMouseUp);
463
+
464
+ // Touch support for mobile
465
+ canvas.addEventListener('touchstart', (e) => {
466
+ e.preventDefault();
467
+ const touch = e.touches[0];
468
+ const rect = canvas.getBoundingClientRect();
469
+ mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
470
+ mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
471
+
472
+ raycaster.setFromCamera(mouse, camera);
473
+ const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
474
+
475
+ if (intersects.length > 0) {
476
+ isDragging = true;
477
+ dragTarget = intersects[0].object;
478
+ dragTarget.material.emissiveIntensity = 1.0;
479
+ dragTarget.scale.setScalar(1.3);
480
+ dragStartMouse.copy(mouse);
481
+ dragStartDistance = distanceFactor;
482
+ }
483
+ }, { passive: false });
484
+
485
+ canvas.addEventListener('touchmove', (e) => {
486
+ e.preventDefault();
487
+ const touch = e.touches[0];
488
+ const rect = canvas.getBoundingClientRect();
489
+ mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
490
+ mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
491
+
492
+ if (isDragging && dragTarget) {
493
+ raycaster.setFromCamera(mouse, camera);
494
+
495
+ if (dragTarget.userData.type === 'azimuth') {
496
+ const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
497
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
498
+ azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
499
+ if (azimuthAngle < 0) azimuthAngle += 360;
500
+ }
501
+ } else if (dragTarget.userData.type === 'elevation') {
502
+ const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
503
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
504
+ const relY = intersection.y - CENTER.y;
505
+ const relZ = intersection.z;
506
+ elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
507
+ }
508
+ } else if (dragTarget.userData.type === 'distance') {
509
+ const deltaY = mouse.y - dragStartMouse.y;
510
+ distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
511
+ }
512
+ updatePositions();
513
+ }
514
+ }, { passive: false });
515
+
516
+ canvas.addEventListener('touchend', (e) => {
517
+ e.preventDefault();
518
+ onMouseUp();
519
+ }, { passive: false });
520
+
521
+ canvas.addEventListener('touchcancel', (e) => {
522
+ e.preventDefault();
523
+ onMouseUp();
524
+ }, { passive: false });
525
+
526
+ // Initial update
527
+ updatePositions();
528
+
529
+ // Render loop
530
+ function render() {
531
+ requestAnimationFrame(render);
532
+ renderer.render(scene, camera);
533
+ }
534
+ render();
535
+
536
+ // Handle resize
537
+ new ResizeObserver(() => {
538
+ camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
539
+ camera.updateProjectionMatrix();
540
+ renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
541
+ }).observe(wrapper);
542
+
543
+ // Store update functions for external calls
544
+ wrapper._updateFromProps = (newVal) => {
545
+ if (newVal && typeof newVal === 'object') {
546
+ azimuthAngle = newVal.azimuth ?? azimuthAngle;
547
+ elevationAngle = newVal.elevation ?? elevationAngle;
548
+ distanceFactor = newVal.distance ?? distanceFactor;
549
+ updatePositions();
550
+ }
551
+ };
552
+
553
+ wrapper._updateTexture = updateTextureFromUrl;
554
+
555
+ // Watch for prop changes (imageUrl and value)
556
+ let lastImageUrl = props.imageUrl;
557
+ let lastValue = JSON.stringify(props.value);
558
+ setInterval(() => {
559
+ // Check imageUrl changes
560
+ if (props.imageUrl !== lastImageUrl) {
561
+ lastImageUrl = props.imageUrl;
562
+ updateTextureFromUrl(props.imageUrl);
563
+ }
564
+ // Check value changes (from sliders)
565
+ const currentValue = JSON.stringify(props.value);
566
+ if (currentValue !== lastValue) {
567
+ lastValue = currentValue;
568
+ if (props.value && typeof props.value === 'object') {
569
+ azimuthAngle = props.value.azimuth ?? azimuthAngle;
570
+ elevationAngle = props.value.elevation ?? elevationAngle;
571
+ distanceFactor = props.value.distance ?? distanceFactor;
572
+ updatePositions();
573
+ }
574
+ }
575
+ }, 100);
576
+ };
577
+
578
+ initScene();
579
+ })();
580
+ """
581
+
582
+ super().__init__(
583
+ value=value,
584
+ html_template=html_template,
585
+ js_on_load=js_on_load,
586
+ imageUrl=imageUrl,
587
+ **kwargs
588
+ )
589
+
camera_control_ui.pyi ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
# Azimuth mappings (8 positions): horizontal camera rotation in degrees
# -> LoRA vocabulary term. Mirrors `azimuthNames` in the embedded JS
# component. Key order matters: build_camera_prompt snaps via
# snap_to_nearest(list(AZIMUTH_MAP.keys())), and min() resolves
# equidistant ties to the earlier key.
AZIMUTH_MAP = {
    0: "front view",
    45: "front-right quarter view",
    90: "right side view",
    135: "back-right quarter view",
    180: "back view",
    225: "back-left quarter view",
    270: "left side view",
    315: "front-left quarter view"
}
14
+
15
# Elevation mappings (4 positions): vertical camera angle in degrees
# -> LoRA vocabulary term. Mirrors `elevationNames` in the embedded JS
# component (range -30..60, matching the JS elevation clamp).
ELEVATION_MAP = {
    -30: "low-angle shot",
    0: "eye-level shot",
    30: "elevated shot",
    60: "high-angle shot"
}
22
+
23
# Distance mappings (3 positions): distance factor -> LoRA vocabulary term.
# NOTE: keys must match the JS component's `distanceSteps = [0.6, 1.0, 1.4]`
# and `distanceNames` ('1.4' -> 'wide shot'). The previous key 1.8 was out
# of sync with the JS (which was reduced to 1.4): a UI value of 1.4 is
# equidistant from 1.0 and 1.8, so snap_to_nearest (min, first-wins)
# resolved it to 1.0 and the prompt said "medium shot" while the 3D
# overlay said "wide shot".
DISTANCE_MAP = {
    0.6: "close-up",
    1.0: "medium shot",
    1.4: "wide shot"
}
29
+
30
+
31
def snap_to_nearest(value, options):
    """Pick the entry of *options* with the smallest absolute distance to *value*.

    Equidistant candidates resolve to the one appearing first in *options*
    (min() keeps the earliest of equal keys). Raises ValueError when
    *options* is empty, as min() does.
    """
    def gap(option):
        return abs(option - value)

    return min(options, key=gap)
34
+
35
+
36
def build_camera_prompt(azimuth: float, elevation: float, distance: float) -> str:
    """Compose the LoRA camera prompt for a given camera pose.

    Each input is snapped to the nearest key of its mapping table
    (AZIMUTH_MAP / ELEVATION_MAP / DISTANCE_MAP) and the corresponding
    vocabulary terms are joined after the trigger token.

    Args:
        azimuth: Horizontal rotation in degrees (0-360).
        elevation: Vertical angle in degrees (-30 to 60).
        distance: Distance factor; snapped to a DISTANCE_MAP key.

    Returns:
        Prompt string of the form "<sks> <azimuth> <elevation> <distance>".
    """
    terms = []
    for raw_value, table in (
        (azimuth, AZIMUTH_MAP),
        (elevation, ELEVATION_MAP),
        (distance, DISTANCE_MAP),
    ):
        snapped = snap_to_nearest(raw_value, list(table.keys()))
        terms.append(table[snapped])

    return "<sks> " + " ".join(terms)
58
+
59
+ from gradio.events import Dependency
60
+
61
+ # --- 3D Camera Control Component ---
62
+ class CameraControl3D(gr.HTML):
63
+ """
64
+ A 3D camera control component using Three.js.
65
+ Outputs: { azimuth: number, elevation: number, distance: number }
66
+ Accepts imageUrl prop to display user's uploaded image on the plane.
67
+ """
68
+ def __init__(self, value=None, imageUrl=None, **kwargs):
69
+ if value is None:
70
+ value = {"azimuth": 0, "elevation": 0, "distance": 1.0}
71
+
72
+ html_template = """
73
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
74
+ <div id="camera-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
75
+ <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 12px; color: #00ff88; white-space: nowrap; z-index: 10;"></div>
76
+ </div>
77
+ """
78
+
79
+ js_on_load = """
80
+ (() => {
81
+ const wrapper = element.querySelector('#camera-control-wrapper');
82
+ const promptOverlay = element.querySelector('#prompt-overlay');
83
+
84
+ // Wait for THREE to load
85
+ const initScene = () => {
86
+ if (typeof THREE === 'undefined') {
87
+ setTimeout(initScene, 100);
88
+ return;
89
+ }
90
+
91
+ // Scene setup
92
+ const scene = new THREE.Scene();
93
+ scene.background = new THREE.Color(0x1a1a1a);
94
+
95
+ const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
96
+ camera.position.set(4.5, 3, 4.5);
97
+ camera.lookAt(0, 0.75, 0);
98
+
99
+ const renderer = new THREE.WebGLRenderer({ antialias: true });
100
+ renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
101
+ renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
102
+ wrapper.insertBefore(renderer.domElement, promptOverlay);
103
+
104
+ // Lighting
105
+ scene.add(new THREE.AmbientLight(0xffffff, 0.6));
106
+ const dirLight = new THREE.DirectionalLight(0xffffff, 0.6);
107
+ dirLight.position.set(5, 10, 5);
108
+ scene.add(dirLight);
109
+
110
+ // Grid
111
+ scene.add(new THREE.GridHelper(8, 16, 0x333333, 0x222222));
112
+
113
+ // Constants - reduced distances for tighter framing
114
+ const CENTER = new THREE.Vector3(0, 0.75, 0);
115
+ const BASE_DISTANCE = 1.6;
116
+ const AZIMUTH_RADIUS = 2.4;
117
+ const ELEVATION_RADIUS = 1.8;
118
+
119
+ // State
120
+ let azimuthAngle = props.value?.azimuth || 0;
121
+ let elevationAngle = props.value?.elevation || 0;
122
+ let distanceFactor = props.value?.distance || 1.0;
123
+
124
+ // Mappings - reduced wide shot multiplier
125
+ const azimuthSteps = [0, 45, 90, 135, 180, 225, 270, 315];
126
+ const elevationSteps = [-30, 0, 30, 60];
127
+ const distanceSteps = [0.6, 1.0, 1.4];
128
+
129
+ const azimuthNames = {
130
+ 0: 'front view', 45: 'front-right quarter view', 90: 'right side view',
131
+ 135: 'back-right quarter view', 180: 'back view', 225: 'back-left quarter view',
132
+ 270: 'left side view', 315: 'front-left quarter view'
133
+ };
134
+ const elevationNames = { '-30': 'low-angle shot', '0': 'eye-level shot', '30': 'elevated shot', '60': 'high-angle shot' };
135
+ const distanceNames = { '0.6': 'close-up', '1': 'medium shot', '1.4': 'wide shot' };
136
+
137
+ function snapToNearest(value, steps) {
138
+ return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
139
+ }
140
+
141
+ // Create placeholder texture (smiley face)
142
+ function createPlaceholderTexture() {
143
+ const canvas = document.createElement('canvas');
144
+ canvas.width = 256;
145
+ canvas.height = 256;
146
+ const ctx = canvas.getContext('2d');
147
+ ctx.fillStyle = '#3a3a4a';
148
+ ctx.fillRect(0, 0, 256, 256);
149
+ ctx.fillStyle = '#ffcc99';
150
+ ctx.beginPath();
151
+ ctx.arc(128, 128, 80, 0, Math.PI * 2);
152
+ ctx.fill();
153
+ ctx.fillStyle = '#333';
154
+ ctx.beginPath();
155
+ ctx.arc(100, 110, 10, 0, Math.PI * 2);
156
+ ctx.arc(156, 110, 10, 0, Math.PI * 2);
157
+ ctx.fill();
158
+ ctx.strokeStyle = '#333';
159
+ ctx.lineWidth = 3;
160
+ ctx.beginPath();
161
+ ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
162
+ ctx.stroke();
163
+ return new THREE.CanvasTexture(canvas);
164
+ }
165
+
166
+ // Target image plane
167
+ let currentTexture = createPlaceholderTexture();
168
+ const planeMaterial = new THREE.MeshBasicMaterial({ map: currentTexture, side: THREE.DoubleSide });
169
+ let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
170
+ targetPlane.position.copy(CENTER);
171
+ scene.add(targetPlane);
172
+
173
+ // Function to update texture from image URL
174
+ function updateTextureFromUrl(url) {
175
+ if (!url) {
176
+ // Reset to placeholder
177
+ planeMaterial.map = createPlaceholderTexture();
178
+ planeMaterial.needsUpdate = true;
179
+ // Reset plane to square
180
+ scene.remove(targetPlane);
181
+ targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
182
+ targetPlane.position.copy(CENTER);
183
+ scene.add(targetPlane);
184
+ return;
185
+ }
186
+
187
+ const loader = new THREE.TextureLoader();
188
+ loader.crossOrigin = 'anonymous';
189
+ loader.load(url, (texture) => {
190
+ texture.minFilter = THREE.LinearFilter;
191
+ texture.magFilter = THREE.LinearFilter;
192
+ planeMaterial.map = texture;
193
+ planeMaterial.needsUpdate = true;
194
+
195
+ // Adjust plane aspect ratio to match image
196
+ const img = texture.image;
197
+ if (img && img.width && img.height) {
198
+ const aspect = img.width / img.height;
199
+ const maxSize = 1.5;
200
+ let planeWidth, planeHeight;
201
+ if (aspect > 1) {
202
+ planeWidth = maxSize;
203
+ planeHeight = maxSize / aspect;
204
+ } else {
205
+ planeHeight = maxSize;
206
+ planeWidth = maxSize * aspect;
207
+ }
208
+ scene.remove(targetPlane);
209
+ targetPlane = new THREE.Mesh(
210
+ new THREE.PlaneGeometry(planeWidth, planeHeight),
211
+ planeMaterial
212
+ );
213
+ targetPlane.position.copy(CENTER);
214
+ scene.add(targetPlane);
215
+ }
216
+ }, undefined, (err) => {
217
+ console.error('Failed to load texture:', err);
218
+ });
219
+ }
220
+
221
+ // Check for initial imageUrl
222
+ if (props.imageUrl) {
223
+ updateTextureFromUrl(props.imageUrl);
224
+ }
225
+
226
+ // Camera model
227
+ const cameraGroup = new THREE.Group();
228
+ const bodyMat = new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 });
229
+ const body = new THREE.Mesh(new THREE.BoxGeometry(0.3, 0.22, 0.38), bodyMat);
230
+ cameraGroup.add(body);
231
+ const lens = new THREE.Mesh(
232
+ new THREE.CylinderGeometry(0.09, 0.11, 0.18, 16),
233
+ new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 })
234
+ );
235
+ lens.rotation.x = Math.PI / 2;
236
+ lens.position.z = 0.26;
237
+ cameraGroup.add(lens);
238
+ scene.add(cameraGroup);
239
+
240
+ // GREEN: Azimuth ring
241
+ const azimuthRing = new THREE.Mesh(
242
+ new THREE.TorusGeometry(AZIMUTH_RADIUS, 0.04, 16, 64),
243
+ new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
244
+ );
245
+ azimuthRing.rotation.x = Math.PI / 2;
246
+ azimuthRing.position.y = 0.05;
247
+ scene.add(azimuthRing);
248
+
249
+ const azimuthHandle = new THREE.Mesh(
250
+ new THREE.SphereGeometry(0.18, 16, 16),
251
+ new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
252
+ );
253
+ azimuthHandle.userData.type = 'azimuth';
254
+ scene.add(azimuthHandle);
255
+
256
+ // PINK: Elevation arc
257
+ const arcPoints = [];
258
+ for (let i = 0; i <= 32; i++) {
259
+ const angle = THREE.MathUtils.degToRad(-30 + (90 * i / 32));
260
+ arcPoints.push(new THREE.Vector3(-0.8, ELEVATION_RADIUS * Math.sin(angle) + CENTER.y, ELEVATION_RADIUS * Math.cos(angle)));
261
+ }
262
+ const arcCurve = new THREE.CatmullRomCurve3(arcPoints);
263
+ const elevationArc = new THREE.Mesh(
264
+ new THREE.TubeGeometry(arcCurve, 32, 0.04, 8, false),
265
+ new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
266
+ );
267
+ scene.add(elevationArc);
268
+
269
+ const elevationHandle = new THREE.Mesh(
270
+ new THREE.SphereGeometry(0.18, 16, 16),
271
+ new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
272
+ );
273
+ elevationHandle.userData.type = 'elevation';
274
+ scene.add(elevationHandle);
275
+
276
+ // ORANGE: Distance line & handle
277
+ const distanceLineGeo = new THREE.BufferGeometry();
278
+ const distanceLine = new THREE.Line(distanceLineGeo, new THREE.LineBasicMaterial({ color: 0xffa500 }));
279
+ scene.add(distanceLine);
280
+
281
+ const distanceHandle = new THREE.Mesh(
282
+ new THREE.SphereGeometry(0.18, 16, 16),
283
+ new THREE.MeshStandardMaterial({ color: 0xffa500, emissive: 0xffa500, emissiveIntensity: 0.5 })
284
+ );
285
+ distanceHandle.userData.type = 'distance';
286
+ scene.add(distanceHandle);
287
+
288
+ function updatePositions() {
289
+ const distance = BASE_DISTANCE * distanceFactor;
290
+ const azRad = THREE.MathUtils.degToRad(azimuthAngle);
291
+ const elRad = THREE.MathUtils.degToRad(elevationAngle);
292
+
293
+ const camX = distance * Math.sin(azRad) * Math.cos(elRad);
294
+ const camY = distance * Math.sin(elRad) + CENTER.y;
295
+ const camZ = distance * Math.cos(azRad) * Math.cos(elRad);
296
+
297
+ cameraGroup.position.set(camX, camY, camZ);
298
+ cameraGroup.lookAt(CENTER);
299
+
300
+ azimuthHandle.position.set(AZIMUTH_RADIUS * Math.sin(azRad), 0.05, AZIMUTH_RADIUS * Math.cos(azRad));
301
+ elevationHandle.position.set(-0.8, ELEVATION_RADIUS * Math.sin(elRad) + CENTER.y, ELEVATION_RADIUS * Math.cos(elRad));
302
+
303
+ const orangeDist = distance - 0.5;
304
+ distanceHandle.position.set(
305
+ orangeDist * Math.sin(azRad) * Math.cos(elRad),
306
+ orangeDist * Math.sin(elRad) + CENTER.y,
307
+ orangeDist * Math.cos(azRad) * Math.cos(elRad)
308
+ );
309
+ distanceLineGeo.setFromPoints([cameraGroup.position.clone(), CENTER.clone()]);
310
+
311
+ // Update prompt
312
+ const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
313
+ const elSnap = snapToNearest(elevationAngle, elevationSteps);
314
+ const distSnap = snapToNearest(distanceFactor, distanceSteps);
315
+ const distKey = distSnap === 1 ? '1' : distSnap.toFixed(1);
316
+ const prompt = '<sks> ' + azimuthNames[azSnap] + ' ' + elevationNames[String(elSnap)] + ' ' + distanceNames[distKey];
317
+ promptOverlay.textContent = prompt;
318
+ }
319
+
320
+ function updatePropsAndTrigger() {
321
+ const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
322
+ const elSnap = snapToNearest(elevationAngle, elevationSteps);
323
+ const distSnap = snapToNearest(distanceFactor, distanceSteps);
324
+
325
+ props.value = { azimuth: azSnap, elevation: elSnap, distance: distSnap };
326
+ trigger('change', props.value);
327
+ }
328
+
329
+ // Raycasting
330
+ const raycaster = new THREE.Raycaster();
331
+ const mouse = new THREE.Vector2();
332
+ let isDragging = false;
333
+ let dragTarget = null;
334
+ let dragStartMouse = new THREE.Vector2();
335
+ let dragStartDistance = 1.0;
336
+ const intersection = new THREE.Vector3();
337
+
338
+ const canvas = renderer.domElement;
339
+
340
+ canvas.addEventListener('mousedown', (e) => {
341
+ const rect = canvas.getBoundingClientRect();
342
+ mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
343
+ mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
344
+
345
+ raycaster.setFromCamera(mouse, camera);
346
+ const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
347
+
348
+ if (intersects.length > 0) {
349
+ isDragging = true;
350
+ dragTarget = intersects[0].object;
351
+ dragTarget.material.emissiveIntensity = 1.0;
352
+ dragTarget.scale.setScalar(1.3);
353
+ dragStartMouse.copy(mouse);
354
+ dragStartDistance = distanceFactor;
355
+ canvas.style.cursor = 'grabbing';
356
+ }
357
+ });
358
+
359
+ canvas.addEventListener('mousemove', (e) => {
360
+ const rect = canvas.getBoundingClientRect();
361
+ mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
362
+ mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
363
+
364
+ if (isDragging && dragTarget) {
365
+ raycaster.setFromCamera(mouse, camera);
366
+
367
+ if (dragTarget.userData.type === 'azimuth') {
368
+ const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
369
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
370
+ azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
371
+ if (azimuthAngle < 0) azimuthAngle += 360;
372
+ }
373
+ } else if (dragTarget.userData.type === 'elevation') {
374
+ const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
375
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
376
+ const relY = intersection.y - CENTER.y;
377
+ const relZ = intersection.z;
378
+ elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
379
+ }
380
+ } else if (dragTarget.userData.type === 'distance') {
381
+ const deltaY = mouse.y - dragStartMouse.y;
382
+ distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
383
+ }
384
+ updatePositions();
385
+ } else {
386
+ raycaster.setFromCamera(mouse, camera);
387
+ const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
388
+ [azimuthHandle, elevationHandle, distanceHandle].forEach(h => {
389
+ h.material.emissiveIntensity = 0.5;
390
+ h.scale.setScalar(1);
391
+ });
392
+ if (intersects.length > 0) {
393
+ intersects[0].object.material.emissiveIntensity = 0.8;
394
+ intersects[0].object.scale.setScalar(1.1);
395
+ canvas.style.cursor = 'grab';
396
+ } else {
397
+ canvas.style.cursor = 'default';
398
+ }
399
+ }
400
+ });
401
+
402
+ const onMouseUp = () => {
403
+ if (dragTarget) {
404
+ dragTarget.material.emissiveIntensity = 0.5;
405
+ dragTarget.scale.setScalar(1);
406
+
407
+ // Snap and animate
408
+ const targetAz = snapToNearest(azimuthAngle, azimuthSteps);
409
+ const targetEl = snapToNearest(elevationAngle, elevationSteps);
410
+ const targetDist = snapToNearest(distanceFactor, distanceSteps);
411
+
412
+ const startAz = azimuthAngle, startEl = elevationAngle, startDist = distanceFactor;
413
+ const startTime = Date.now();
414
+
415
+ function animateSnap() {
416
+ const t = Math.min((Date.now() - startTime) / 200, 1);
417
+ const ease = 1 - Math.pow(1 - t, 3);
418
+
419
+ let azDiff = targetAz - startAz;
420
+ if (azDiff > 180) azDiff -= 360;
421
+ if (azDiff < -180) azDiff += 360;
422
+ azimuthAngle = startAz + azDiff * ease;
423
+ if (azimuthAngle < 0) azimuthAngle += 360;
424
+ if (azimuthAngle >= 360) azimuthAngle -= 360;
425
+
426
+ elevationAngle = startEl + (targetEl - startEl) * ease;
427
+ distanceFactor = startDist + (targetDist - startDist) * ease;
428
+
429
+ updatePositions();
430
+ if (t < 1) requestAnimationFrame(animateSnap);
431
+ else updatePropsAndTrigger();
432
+ }
433
+ animateSnap();
434
+ }
435
+ isDragging = false;
436
+ dragTarget = null;
437
+ canvas.style.cursor = 'default';
438
+ };
439
+
440
+ canvas.addEventListener('mouseup', onMouseUp);
441
+ canvas.addEventListener('mouseleave', onMouseUp);
442
+
443
+ // Touch support for mobile
444
+ canvas.addEventListener('touchstart', (e) => {
445
+ e.preventDefault();
446
+ const touch = e.touches[0];
447
+ const rect = canvas.getBoundingClientRect();
448
+ mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
449
+ mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
450
+
451
+ raycaster.setFromCamera(mouse, camera);
452
+ const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
453
+
454
+ if (intersects.length > 0) {
455
+ isDragging = true;
456
+ dragTarget = intersects[0].object;
457
+ dragTarget.material.emissiveIntensity = 1.0;
458
+ dragTarget.scale.setScalar(1.3);
459
+ dragStartMouse.copy(mouse);
460
+ dragStartDistance = distanceFactor;
461
+ }
462
+ }, { passive: false });
463
+
464
+ canvas.addEventListener('touchmove', (e) => {
465
+ e.preventDefault();
466
+ const touch = e.touches[0];
467
+ const rect = canvas.getBoundingClientRect();
468
+ mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
469
+ mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
470
+
471
+ if (isDragging && dragTarget) {
472
+ raycaster.setFromCamera(mouse, camera);
473
+
474
+ if (dragTarget.userData.type === 'azimuth') {
475
+ const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
476
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
477
+ azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
478
+ if (azimuthAngle < 0) azimuthAngle += 360;
479
+ }
480
+ } else if (dragTarget.userData.type === 'elevation') {
481
+ const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
482
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
483
+ const relY = intersection.y - CENTER.y;
484
+ const relZ = intersection.z;
485
+ elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
486
+ }
487
+ } else if (dragTarget.userData.type === 'distance') {
488
+ const deltaY = mouse.y - dragStartMouse.y;
489
+ distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
490
+ }
491
+ updatePositions();
492
+ }
493
+ }, { passive: false });
494
+
495
+ canvas.addEventListener('touchend', (e) => {
496
+ e.preventDefault();
497
+ onMouseUp();
498
+ }, { passive: false });
499
+
500
+ canvas.addEventListener('touchcancel', (e) => {
501
+ e.preventDefault();
502
+ onMouseUp();
503
+ }, { passive: false });
504
+
505
+ // Initial update
506
+ updatePositions();
507
+
508
+ // Render loop
509
+ function render() {
510
+ requestAnimationFrame(render);
511
+ renderer.render(scene, camera);
512
+ }
513
+ render();
514
+
515
+ // Handle resize
516
+ new ResizeObserver(() => {
517
+ camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
518
+ camera.updateProjectionMatrix();
519
+ renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
520
+ }).observe(wrapper);
521
+
522
+ // Store update functions for external calls
523
+ wrapper._updateFromProps = (newVal) => {
524
+ if (newVal && typeof newVal === 'object') {
525
+ azimuthAngle = newVal.azimuth ?? azimuthAngle;
526
+ elevationAngle = newVal.elevation ?? elevationAngle;
527
+ distanceFactor = newVal.distance ?? distanceFactor;
528
+ updatePositions();
529
+ }
530
+ };
531
+
532
+ wrapper._updateTexture = updateTextureFromUrl;
533
+
534
+ // Watch for prop changes (imageUrl and value)
535
+ let lastImageUrl = props.imageUrl;
536
+ let lastValue = JSON.stringify(props.value);
537
+ setInterval(() => {
538
+ // Check imageUrl changes
539
+ if (props.imageUrl !== lastImageUrl) {
540
+ lastImageUrl = props.imageUrl;
541
+ updateTextureFromUrl(props.imageUrl);
542
+ }
543
+ // Check value changes (from sliders)
544
+ const currentValue = JSON.stringify(props.value);
545
+ if (currentValue !== lastValue) {
546
+ lastValue = currentValue;
547
+ if (props.value && typeof props.value === 'object') {
548
+ azimuthAngle = props.value.azimuth ?? azimuthAngle;
549
+ elevationAngle = props.value.elevation ?? elevationAngle;
550
+ distanceFactor = props.value.distance ?? distanceFactor;
551
+ updatePositions();
552
+ }
553
+ }
554
+ }, 100);
555
+ };
556
+
557
+ initScene();
558
+ })();
559
+ """
560
+
561
+ super().__init__(
562
+ value=value,
563
+ html_template=html_template,
564
+ js_on_load=js_on_load,
565
+ imageUrl=imageUrl,
566
+ **kwargs
567
+ )
568
+ from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
569
+ from gradio.blocks import Block
570
+ if TYPE_CHECKING:
571
+ from gradio.components import Timer
572
+ from gradio.components.base import Component
examples/1.jpg ADDED
examples/10.jpeg ADDED
examples/11.jpg ADDED
examples/12.jpg ADDED
examples/13.jpg ADDED

Git LFS Details

  • SHA256: d54e023ee72ab14ca3180c3f0c1707234845cc4886adcbc7aa3039914ed4759e
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
examples/14.jpg ADDED
examples/2.jpeg ADDED
examples/4.jpg ADDED
examples/5.jpg ADDED
examples/6.jpg ADDED
examples/7.jpg ADDED
examples/8.jpg ADDED
examples/9.jpg ADDED
examples/ELS.jpg ADDED
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pip>=23.0.0
qwenimage/__init__.py ADDED
File without changes
qwenimage/pipeline_qwenimage_edit_plus.py ADDED
@@ -0,0 +1,900 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import math
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from PIL import Image, ImageOps
23
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
24
+
25
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
+ from diffusers.loaders import QwenImageLoraLoaderMixin
27
+ from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ EXAMPLE_DOC_STRING = """
44
+ Examples:
45
+ ```py
46
+ >>> import torch
47
+ >>> from diffusers import QwenImageEditPlusPipeline
48
+ >>> from diffusers.utils import load_image
49
+
50
+ >>> pipe = QwenImageEditPlusPipeline.from_pretrained(
51
+ ... "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
52
+ ... ).to("cuda")
53
+
54
+ >>> image = load_image(
55
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
56
+ ... ).convert("RGB")
57
+
58
+ >>> prompt = "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
59
+
60
+ >>> out = pipe(image=image, prompt=prompt, num_inference_steps=50).images[0]
61
+ >>> out.save("qwenimage_edit_plus.png")
62
+ ```
63
+ """
64
+
65
+ CONDITION_IMAGE_SIZE = 384 * 384
66
+ VAE_IMAGE_SIZE = 1024 * 1024
67
+
68
+
69
def pad_to_aspect(img: Image.Image, target_w: int, target_h: int) -> Image.Image:
    """Letterbox *img* onto a (target_w, target_h) canvas without distortion.

    The image is converted to RGB, scaled to fit inside the target box using
    LANCZOS resampling, and centered on a black background.
    """
    rgb = img.convert("RGB")
    box = (int(target_w), int(target_h))
    return ImageOps.pad(
        rgb,
        box,
        method=Image.Resampling.LANCZOS,
        color=(0, 0, 0),
        centering=(0.5, 0.5),
    )
78
+
79
+
80
def choose_condition_area(canvas_area: int, base_area: int = CONDITION_IMAGE_SIZE) -> int:
    """Pick a conditioning target area proportional to the canvas area.

    The canvas area is scaled by ``base_area / 1024**2`` and the result is
    clamped to the inclusive range [256*256, base_area].
    """
    proportional = int(canvas_area * (base_area / (1024 * 1024)))
    lower_bounded = max(256 * 256, proportional)
    return int(min(base_area, lower_bounded))
84
+
85
+
86
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
87
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the scheduler shift ``mu`` from the sequence length.

    Maps ``base_seq_len`` -> ``base_shift`` and ``max_seq_len`` -> ``max_shift``;
    lengths outside that range extrapolate along the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
101
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Call ``scheduler.set_timesteps`` and return the resulting schedule.

    At most one of ``timesteps`` or ``sigmas`` may be supplied; either one
    overrides ``num_inference_steps``.  Raises ``ValueError`` if both are
    given, or if the scheduler's ``set_timesteps`` does not accept the
    supplied keyword.

    Returns:
        ``(timesteps, num_inference_steps)`` — the scheduler's timestep
        tensor and its length.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one.")

    if timesteps is not None:
        # Only schedulers whose set_timesteps signature exposes `timesteps`
        # support a fully custom schedule.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timesteps."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)

    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)

    else:
        # Default path: let the scheduler build its own schedule for the
        # requested number of steps.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps
137
+
138
+
139
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
140
def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
    """Extract latents from a VAE encoder output.

    Supports outputs exposing a ``latent_dist`` (sampled stochastically for
    ``sample_mode="sample"`` or deterministically via the distribution mode
    for ``"argmax"``) as well as outputs exposing raw ``latents``.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
148
+
149
+
150
def calculate_dimensions(target_area: int, ratio: float, multiple: int = 32):
    """
    Compute a (width, height) pair holding roughly ``target_area`` pixels at
    aspect ``ratio``, snapping both sides to the nearest lattice ``multiple``.
    Used for canvas sizing AND conditioning sizing (anti-drift).
    """
    # A falsy multiple (0/None) falls back to the default lattice of 32.
    step = max(1, int(multiple) if multiple else 32)

    raw_w = math.sqrt(float(target_area) * float(ratio))
    raw_h = raw_w / float(ratio)

    snapped_w = round(raw_w / step) * step
    snapped_h = round(raw_h / step) * step
    return int(snapped_w), int(snapped_h)
164
+
165
+
166
+ # Optional: decoder VAE (Wan2x)
167
+ _ALT_VAE_WAN2X = None
168
+
169
+ # Track desired tiling state for the optional decoder VAE, so it stays consistent across lazy loads.
170
+ _ALT_VAE_WAN2X_TILING_ENABLED = False
171
+
172
+
173
+ def _set_vae_tiling(model: Any, enabled: bool) -> bool:
174
+ """
175
+ Best-effort tiling toggle for a VAE-like module.
176
+ Returns True if a tiling method existed and was called, False otherwise.
177
+ """
178
+ if model is None:
179
+ return False
180
+ try:
181
+ if enabled:
182
+ if hasattr(model, "enable_tiling"):
183
+ model.enable_tiling()
184
+ return True
185
+ if hasattr(model, "enable_vae_tiling"):
186
+ model.enable_vae_tiling()
187
+ return True
188
+ else:
189
+ if hasattr(model, "disable_tiling"):
190
+ model.disable_tiling()
191
+ return True
192
+ if hasattr(model, "disable_vae_tiling"):
193
+ model.disable_vae_tiling()
194
+ return True
195
+ except Exception as e:
196
+ # Don't hard-fail inference if tiling toggle fails for an alt decoder.
197
+ logger.warning(f"VAE tiling toggle failed on {type(model)}: {e}")
198
+ return False
199
+ return False
200
+
201
+
202
+
203
def _get_wan2x_vae(device: torch.device, dtype: torch.dtype):
    """
    Decoder-only finetune that outputs 2x resolution via pixel-shuffle.
    Lazy-loaded so it doesn't impact startup unless used.
    """
    global _ALT_VAE_WAN2X, _ALT_VAE_WAN2X_TILING_ENABLED
    if _ALT_VAE_WAN2X is None:
        # Local import: diffusers' Wan VAE is only needed when this decoder
        # is actually selected.
        from diffusers import AutoencoderKLWan

        _ALT_VAE_WAN2X = AutoencoderKLWan.from_pretrained(
            "spacepxl/Wan2.1-VAE-upscale2x",
            subfolder="diffusers/Wan2.1_VAE_upscale2x_imageonly_real_v1",
            torch_dtype=dtype,
        )
        _ALT_VAE_WAN2X.eval()

        # Apply last requested tiling immediately on first load (if supported).
        _set_vae_tiling(_ALT_VAE_WAN2X, _ALT_VAE_WAN2X_TILING_ENABLED)

    # NOTE(review): .to() appears to run on every call, so a cached instance
    # tracks the most recently requested device/dtype — confirm intended.
    _ALT_VAE_WAN2X = _ALT_VAE_WAN2X.to(device=device, dtype=dtype)

    # Re-apply after moving to device, just in case.
    _set_vae_tiling(_ALT_VAE_WAN2X, _ALT_VAE_WAN2X_TILING_ENABLED)

    return _ALT_VAE_WAN2X
228
+
229
+
230
+ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
231
+ r"""
232
+ The Qwen-Image-Edit pipeline for image editing.
233
+ """
234
+
235
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
236
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
237
+
238
def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKLQwenImage,
    text_encoder: Qwen2_5_VLForConditionalGeneration,
    tokenizer: Qwen2Tokenizer,
    processor: Qwen2VLProcessor,
    transformer: QwenImageTransformer2DModel,
):
    """Register the pipeline's sub-models and derive sizing constants."""
    super().__init__()
    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        processor=processor,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Spatial downsampling = 2 ** number of temporal-downsample stages; the
    # fallbacks (8 / 16) cover a pipeline constructed without a VAE.
    self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
    self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16

    # QwenImage latents are turned into 2x2 patches and packed; multiply scale-factor by patch size
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
    self.tokenizer_max_length = 1024

    # Track tiling state (applies to both primary VAE and optional decoder VAE)
    self._vae_tiling_enabled = False

    # Chat-style template wrapped around the user's edit instruction before it
    # is fed to the Qwen2.5-VL text encoder.
    self.prompt_template_encode = (
        "<|im_start|>system\n"
        "Describe the key features of the input image (color, shape, size, texture, objects, background), "
        "then explain how the user's text instruction should alter or modify the image.\n"
        "Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
        "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
    )
    # Number of leading template tokens dropped from the encoder hidden states
    # in _get_qwen_prompt_embeds.
    self.prompt_template_encode_start_idx = 64
    self.default_sample_size = 128
277
+
278
+ # ------------------------------------------------------------
279
+ # VAE tiling control (applies to both primary VAE and optional decoder VAE)
280
+ # ------------------------------------------------------------
281
+ # Expose a stable API so app.py can call pipe.enable_vae_tiling()/disable_vae_tiling()
282
+ # regardless of which decoder VAE is selected at runtime.
283
+
284
def set_vae_tiling(self, enabled: bool) -> None:
    """Enable/disable tiling on the primary VAE and the optional Wan2x decoder.

    The desired state is also stored in the module-level flag so a
    lazily-loaded Wan2x decoder picks it up on first load.
    """
    global _ALT_VAE_WAN2X_TILING_ENABLED, _ALT_VAE_WAN2X

    enabled = bool(enabled)
    self._vae_tiling_enabled = enabled

    # 1) Primary VAE (Qwen)
    _set_vae_tiling(getattr(self, "vae", None), enabled)

    # 2) Optional decoder VAE (Wan2x): store desired global state; apply now if already loaded.
    _ALT_VAE_WAN2X_TILING_ENABLED = enabled
    if _ALT_VAE_WAN2X is not None:
        _set_vae_tiling(_ALT_VAE_WAN2X, enabled)
297
+
298
def enable_vae_tiling(self) -> None:
    """Convenience wrapper for ``set_vae_tiling(True)``."""
    self.set_vae_tiling(True)
300
+
301
def disable_vae_tiling(self) -> None:
    """Convenience wrapper for ``set_vae_tiling(False)``."""
    self.set_vae_tiling(False)
303
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
304
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
    """Split padded hidden states back into per-sample unpadded sequences.

    Rows where ``mask`` is truthy are gathered from ``hidden_states`` and
    regrouped per batch element, so each returned tensor holds only that
    sample's valid positions.
    """
    keep = mask.bool()
    per_sample_counts = keep.sum(dim=1).tolist()
    flat_valid = hidden_states[keep]
    return torch.split(flat_valid, per_sample_counts, dim=0)
310
+
311
def _get_qwen_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode prompt (plus optional reference image(s)) with Qwen2.5-VL.

    Returns:
        ``(prompt_embeds, encoder_attention_mask)``: last-layer hidden
        states with the fixed template prefix dropped and re-padded to the
        longest sequence in the batch, plus the matching 0/1 mask.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"

    # One "Picture N:" vision placeholder per conditioning image.
    if isinstance(image, list):
        base_img_prompt = ""
        for i, _ in enumerate(image):
            base_img_prompt += img_prompt_template.format(i + 1)
    elif image is not None:
        base_img_prompt = img_prompt_template.format(1)
    else:
        base_img_prompt = ""

    template = self.prompt_template_encode
    drop_idx = self.prompt_template_encode_start_idx
    txt = [template.format(base_img_prompt + e) for e in prompt]

    model_inputs = self.processor(text=txt, images=image, padding=True, return_tensors="pt").to(device)

    outputs = self.text_encoder(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        pixel_values=model_inputs.pixel_values,
        image_grid_thw=model_inputs.image_grid_thw,
        output_hidden_states=True,
    )

    # Strip padding per sample, then drop the fixed-length template prefix.
    hidden_states = outputs.hidden_states[-1]
    split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]

    # Re-pad every sample with zeros up to the longest remaining sequence.
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    max_seq_len = max([e.size(0) for e in split_hidden_states])

    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    return prompt_embeds, encoder_attention_mask
363
+
364
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
365
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
):
    """Compute (or reuse) prompt embeddings and duplicate them per image.

    If ``prompt_embeds`` is supplied it is used as-is; otherwise embeddings
    come from :meth:`_get_qwen_prompt_embeds`.  Embeddings and mask are then
    repeated ``num_images_per_prompt`` times along the batch dimension.

    Note: ``max_sequence_length`` is accepted for API parity with the other
    Qwen-Image pipelines but is not consumed in this method body.
    """
    device = device or self._execution_device
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

    _, seq_len, _ = prompt_embeds.shape

    # repeat along dim=1 then reshape: equivalent to repeating each sample
    # num_images_per_prompt times in the batch dimension.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
    prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

    return prompt_embeds, prompt_embeds_mask
391
+
392
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
393
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_embeds_mask=None,
    negative_prompt_embeds_mask=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """Validate user-facing call arguments, raising ValueError on bad combinations."""
    divisor = self.vae_scale_factor * 2
    if height % divisor != 0 or width % divisor != 0:
        # Non-fatal: the pipeline resizes, so only warn.
        logger.warning(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. "
            "Dimensions will be resized accordingly."
        )

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
                f"{unknown}"
            )

    # Exactly one of prompt / prompt_embeds must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")

    # Pre-computed embeddings must be accompanied by their attention masks.
    if prompt_embeds is not None and prompt_embeds_mask is None:
        raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` must also be passed.")

    if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
        raise ValueError("If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` must also be passed.")

    if max_sequence_length is not None and max_sequence_length > 1024:
        raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
438
+
439
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold (B, C, H, W) latents into packed 2x2 patch tokens of shape
    (B, H/2 * W/2, C*4)."""
    patched = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    patched = patched.permute(0, 2, 4, 1, 3, 5)
    return patched.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
445
+
446
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
    """Inverse of ``_pack_latents``: tokens (B, H/2*W/2, C*4) back to
    (B, C, 1, lat_h, lat_w), with a singleton frame axis for the video VAE."""
    batch_size, _, channels = latents.shape
    # Latent-space spatial size derived from the pixel-space target, forced even.
    lat_h = 2 * (int(height) // (vae_scale_factor * 2))
    lat_w = 2 * (int(width) // (vae_scale_factor * 2))
    tokens = latents.view(batch_size, lat_h // 2, lat_w // 2, channels // 4, 2, 2)
    tokens = tokens.permute(0, 3, 1, 4, 2, 5)
    return tokens.reshape(batch_size, channels // 4, 1, lat_h, lat_w)
455
+
456
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode pixel-space images to normalized VAE latents.

    Uses deterministic ("argmax") encoding, then standardizes with the VAE's
    configured per-channel ``latents_mean`` / ``latents_std``.
    """
    if isinstance(generator, list):
        # One generator per batch element: encode sample-by-sample.
        image_latents = [
            retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
            for i in range(image.shape[0])
        ]
        image_latents = torch.cat(image_latents, dim=0)
    else:
        image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")

    latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.latent_channels, 1, 1, 1).to(
        image_latents.device, image_latents.dtype
    )
    latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.latent_channels, 1, 1, 1).to(
        image_latents.device, image_latents.dtype
    )
    image_latents = (image_latents - latents_mean) / latents_std
    return image_latents
474
+
475
def prepare_latents(
    self,
    images,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Create initial noise latents and packed conditioning-image latents.

    Returns:
        ``(latents, image_latents)``: packed noise tokens, and — when
        ``images`` is given — per-image latents concatenated along the token
        axis; otherwise ``image_latents`` is None.
    """
    # Latent-space spatial size, forced even for 2x2 patch packing.
    height = 2 * (int(height) // (self.vae_scale_factor * 2))
    width = 2 * (int(width) // (self.vae_scale_factor * 2))
    shape = (batch_size, 1, num_channels_latents, height, width)

    image_latents = None
    if images is not None:
        if not isinstance(images, list):
            images = [images]
        all_image_latents = []

        for image in images:
            image = image.to(device=device, dtype=dtype)
            # Accept raw pixels (encode) or pre-encoded latents (pass through),
            # distinguished by the channel count.
            if image.shape[1] != self.latent_channels:
                image_latents = self._encode_vae_image(image=image, generator=generator)
            else:
                image_latents = image

            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
                # Tile image latents to match the effective prompt batch.
                additional_image_per_prompt = batch_size // image_latents.shape[0]
                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
                )

            image_latent_height, image_latent_width = image_latents.shape[3:]
            image_latents = self._pack_latents(
                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
            )
            all_image_latents.append(image_latents)

        # All conditioning images are concatenated along the token axis.
        image_latents = torch.cat(all_image_latents, dim=1)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
    else:
        # Caller-supplied latents are assumed to be already packed.
        latents = latents.to(device=device, dtype=dtype)

    return latents, image_latents
532
+
533
@property
def guidance_scale(self):
    # Read-only accessor for the internal `_guidance_scale`.
    return self._guidance_scale

@property
def attention_kwargs(self):
    # Read-only accessor for the internal `_attention_kwargs`.
    return self._attention_kwargs

@property
def num_timesteps(self):
    # Read-only accessor for the internal `_num_timesteps`.
    return self._num_timesteps

@property
def current_timestep(self):
    # Read-only accessor for the internal `_current_timestep`.
    return self._current_timestep

@property
def interrupt(self):
    # Read-only accessor for the internal `_interrupt` flag.
    return self._interrupt
552
+
553
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: Optional[PipelineImageInput] = None,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        true_cfg_scale: float = 4.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        condition_area: Optional[int] = None,
        vae_image_indices: Optional[List[int]] = None,
        pad_to_canvas: bool = True,
        # NEW: lattice + VAE reference-size overrides (non-standard extensions of
        # the upstream diffusers pipeline).
        resolution_multiple: Optional[int] = None,
        vae_ref_area: Optional[int] = None,
        vae_ref_start_index: int = 2,
        # Optional: swap the decoding VAE ("qwen" default, "wan2x" alternate 2x decoder).
        decoder_vae: str = "qwen",  # "qwen" | "wan2x"
        keep_decoder_2x: bool = False,
        # standard args
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: Optional[float] = None,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        # NOTE(review): mutable default argument. It is only read, never mutated,
        # in this body, so the shared-list pitfall does not bite — but a tuple
        # default would be safer.
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Run Qwen-Image-Edit inference.

        Args:
            image: A single PIL image (or list of them) to edit. Required in this
                Space variant; pre-latent tensors are rejected below.
            prompt / negative_prompt: Edit instruction(s). True CFG is used only
                when ``true_cfg_scale > 1`` and a negative prompt (or its embeds)
                is given.
            height / width: Output size; defaults to a ~1 MP canvas matching the
                first image's aspect ratio, snapped down to ``multiple_of``.
            condition_area: Pixel area for the text-encoder conditioning images;
                when None it is chosen from the canvas area.
            vae_image_indices: Which input images also go through the VAE stream
                (all of them when None).
            pad_to_canvas: Letterbox inputs to the target aspect before resizing.
            resolution_multiple: Overrides the lattice multiple (default
                ``vae_scale_factor * 2``) used for all dimension rounding.
            vae_ref_area / vae_ref_start_index: Optional smaller VAE size applied
                only to reference images with index >= ``vae_ref_start_index``.
            decoder_vae / keep_decoder_2x: ``"wan2x"`` decodes with an alternate
                VAE whose 12-channel output is pixel-shuffled to 2x resolution;
                ``keep_decoder_2x`` keeps that 2x image instead of area-resizing
                back to (height, width).
            guidance_scale: Required iff the transformer is guidance-distilled
                (``config.guidance_embeds``); otherwise ignored with a warning.

        Returns:
            `QwenImagePipelineOutput` when ``return_dict`` else a 1-tuple of images.

        Examples:
        """
        # ---- determine input size (first image drives the default canvas) ----
        if isinstance(image, list):
            image_size = image[0].size
        else:
            image_size = image.size

        # Lattice multiple used throughout (canvas sizing + condition sizing).
        multiple_of = int(resolution_multiple) if resolution_multiple is not None else (self.vae_scale_factor * 2)
        multiple_of = max(1, multiple_of)

        calculated_width, calculated_height = calculate_dimensions(
            1024 * 1024, float(image_size[0]) / float(image_size[1]), multiple=multiple_of
        )
        height = height or calculated_height
        width = width or calculated_width

        # Snap user-supplied sizes down onto the lattice.
        width = (int(width) // multiple_of) * multiple_of
        height = (int(height) // multiple_of) * multiple_of

        # ---- validate ----
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        # Mirror call arguments into the read-only properties exposed on self.
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # ---- call params ----
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # ---- preprocess ----
        condition_images = None
        vae_images = None
        vae_image_sizes: List[tuple[int, int]] = []

        # Support pre-latent tensors check kept for interface compatibility, but the
        # else-branch below rejects them explicitly for this Space.
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
            if not isinstance(image, list):
                image = [image]

            canvas_area = int(width) * int(height)
            cond_area = int(condition_area) if condition_area is not None else choose_condition_area(canvas_area)

            cond_w, cond_h = calculate_dimensions(cond_area, float(width) / float(height), multiple=multiple_of)

            # Optional VAE ref override sizing (applied only to indices >= vae_ref_start_index).
            # Best-effort: any failure falls back to full-canvas sizing.
            ref_w = ref_h = None
            if vae_ref_area is not None:
                try:
                    ref_w, ref_h = calculate_dimensions(
                        int(vae_ref_area),
                        float(width) / float(height),
                        multiple=multiple_of,
                    )
                except Exception:
                    ref_w = ref_h = None

            condition_images = []
            vae_images = []

            if vae_image_indices is None:
                vae_image_indices = list(range(len(image)))
            vae_set = set(int(i) for i in vae_image_indices)

            for idx, img in enumerate(image):
                pil = img.convert("RGB") if isinstance(img, Image.Image) else img

                if pad_to_canvas and isinstance(pil, Image.Image):
                    pil = pad_to_aspect(pil, int(width), int(height))

                # Conditioning stream (always): low-res copy fed to the text encoder.
                condition_images.append(self.image_processor.resize(pil, cond_h, cond_w))

                # VAE stream (selective): only images listed in vae_image_indices.
                if idx in vae_set:
                    if (ref_w is not None) and (ref_h is not None) and (int(idx) >= int(vae_ref_start_index)):
                        vw, vh = int(ref_w), int(ref_h)
                    else:
                        vw, vh = int(width), int(height)

                    vae_image_sizes.append((vw, vh))
                    # unsqueeze(2) adds the frame axis expected by the video-style VAE.
                    vae_images.append(self.image_processor.preprocess(pil, int(vh), int(vw)).unsqueeze(2))

            has_neg_prompt = negative_prompt is not None or (
                negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
            )
            if true_cfg_scale > 1 and not has_neg_prompt:
                logger.warning(
                    f"true_cfg_scale={true_cfg_scale} but CFG disabled because no negative prompt was provided."
                )
            if true_cfg_scale <= 1 and has_neg_prompt:
                logger.warning("negative_prompt provided but CFG disabled because true_cfg_scale <= 1")

            do_true_cfg = (true_cfg_scale > 1) and has_neg_prompt

            prompt_embeds, prompt_embeds_mask = self.encode_prompt(
                image=condition_images,
                prompt=prompt,
                prompt_embeds=prompt_embeds,
                prompt_embeds_mask=prompt_embeds_mask,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
            )

            if do_true_cfg:
                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
                    image=condition_images,
                    prompt=negative_prompt,
                    prompt_embeds=negative_prompt_embeds,
                    prompt_embeds_mask=negative_prompt_embeds_mask,
                    device=device,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                )

            # ---- prepare latents ----
            # in_channels // 4 because latents are packed 2x2 spatially downstream.
            num_channels_latents = self.transformer.config.in_channels // 4
            latents, image_latents = self.prepare_latents(
                vae_images,
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # Token-grid shapes for RoPE: target canvas first, then each VAE image.
            img_shapes = [
                [
                    (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
                    *[
                        (1, vae_h // self.vae_scale_factor // 2, vae_w // self.vae_scale_factor // 2)
                        for (vae_w, vae_h) in vae_image_sizes
                    ],
                ]
            ] * batch_size

        else:
            raise ValueError(
                "This Space pipeline expects `image` as PIL/np inputs (not pre-latents) in this setup."
            )

        # ---- timesteps ----
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas

        # Dynamic shifting (mu) scales the schedule with the image token count.
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # Guidance-distilled models need explicit guidance input.
        if self.transformer.config.guidance_embeds and guidance_scale is None:
            raise ValueError("guidance_scale is required for guidance-distilled model.")
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0])
        else:
            if guidance_scale is not None:
                logger.warning("guidance_scale passed but ignored since model is not guidance-distilled.")
            guidance = None

        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        # Rotary embeddings are computed once outside the loop; shapes are static per call.
        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

        # Re-derive do_true_cfg from the embeds actually present (encode_prompt may
        # have been skipped when embeds were passed in directly).
        do_true_cfg = (
            (true_cfg_scale > 1)
            and (negative_prompt_embeds is not None)
            and (negative_prompt_embeds_mask is not None)
        )
        if do_true_cfg:
            negative_txt_seq_lens = negative_prompt_embeds_mask.sum(dim=1).tolist()
            uncond_image_rotary_emb = self.transformer.pos_embed(img_shapes, negative_txt_seq_lens, device=latents.device)
        else:
            uncond_image_rotary_emb = None

        # ---- denoise ----
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # `continue` (not break) keeps the progress bar/scheduler indices consistent.
                if self.interrupt:
                    continue
                self._current_timestep = t

                # Condition latents are concatenated along the sequence axis; the
                # prediction for those positions is sliced off after the transformer.
                latent_model_input = latents
                if image_latents is not None:
                    latent_model_input = torch.cat([latents, image_latents], dim=1)

                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                with self.transformer.cache_context("cond"):
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        encoder_hidden_states_mask=prompt_embeds_mask,
                        encoder_hidden_states=prompt_embeds,
                        image_rotary_emb=image_rotary_emb,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_pred[:, : latents.size(1)]

                if do_true_cfg:
                    with self.transformer.cache_context("uncond"):
                        neg_noise_pred = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep / 1000,
                            guidance=guidance,
                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
                            encoder_hidden_states=negative_prompt_embeds,
                            image_rotary_emb=uncond_image_rotary_emb,
                            attention_kwargs=self.attention_kwargs,
                            return_dict=False,
                        )[0]
                        neg_noise_pred = neg_noise_pred[:, : latents.size(1)]

                    # True CFG with norm rescaling: keep the combined prediction's
                    # magnitude matched to the conditional one (eps avoids div-by-zero).
                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
                    noise_pred = comb_pred * (cond_norm / (noise_norm + 1e-8))

                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                # MPS backend can silently upcast; restore the working dtype.
                if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                    latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        # ---- decode ----
        if output_type == "latent":
            image_out = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = latents.to(self.vae.dtype)

            # Undo the VAE's normalization before decoding.
            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean

            if decoder_vae == "wan2x":
                # Alternate decoder emits 12 channels; pixel-shuffle by 2 turns
                # them into an RGB image at twice the latent-decoded resolution.
                alt_vae = _get_wan2x_vae(latents.device, self.vae.dtype)
                decoder_out = alt_vae.decode(latents, return_dict=False)[0]  # [B, 12, F, H, W]
                img_2x = F.pixel_shuffle(decoder_out[:, :, 0], upscale_factor=2)  # [B, 3, 2H, 2W]
                if keep_decoder_2x:
                    decoded = img_2x
                else:
                    decoded = F.interpolate(img_2x, size=(int(height), int(width)), mode="area")
            else:
                # [:, :, 0] drops the single frame axis of the video-style VAE output.
                decoded = self.vae.decode(latents, return_dict=False)[0][:, :, 0]

            image_out = self.image_processor.postprocess(decoded, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image_out,)
        return QwenImagePipelineOutput(images=image_out)
qwenimage/qwen_fa3_processor.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Paired with a good language model. Thanks!
3
+ """
4
+
5
+ import torch
6
+ from typing import Optional, Tuple
7
+ from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
8
+
9
# Load the vLLM FlashAttention-3 kernel from the Hugging Face `kernels` hub at
# import time. On any failure (package missing, no network, unsupported platform)
# record the exception and defer the hard error to `_ensure_fa3_available()`, so
# importing this module never crashes on machines without FA3.
try:
    from kernels import get_kernel
    _k = get_kernel("kernels-community/vllm-flash-attn3")
    _flash_attn_func = _k.flash_attn_func
except Exception as e:
    _flash_attn_func = None  # sentinel meaning "FA3 unavailable"
    # NOTE: _kernels_err only exists on the failure path; it is read solely when
    # _flash_attn_func is None, so the success path never touches it.
    _kernels_err = e
16
+
17
+
18
def _ensure_fa3_available():
    """Raise a descriptive ImportError when the FA3 kernel failed to load at import."""
    if _flash_attn_func is not None:
        return
    raise ImportError(
        "FlashAttention-3 via Hugging Face `kernels` is required. "
        "Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
        f"{_kernels_err}"
    )
25
+
26
# Register the FA3 kernel as a PyTorch custom op so it composes with torch.compile
# and the dispatcher. `mutates_args=()` declares the op functional (no in-place
# writes to its inputs).
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
) -> torch.Tensor:
    # The underlying kernel returns (output, softmax_lse); the log-sum-exp is
    # discarded because only the attention output is consumed downstream.
    outputs, lse = _flash_attn_func(q, k, v, causal=causal)
    return outputs
32
+
33
@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
    # Fake (meta) implementation used by torch.compile / fake-tensor tracing: it
    # only has to produce an output with the correct shape/dtype/device, not real
    # values. The raw kernel also returns softmax_lse of shape
    # (batch, num_heads, seq_len) in float32, but the custom op above exposes a
    # single tensor, so only that one is faked here.
    meta_q = torch.empty_like(q).contiguous()
    return meta_q  # , q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
40
+
41
+
42
class QwenDoubleStreamAttnProcessorFA3:
    """
    FA3-based attention processor for the Qwen double-stream architecture.

    Computes joint attention over the concatenated [text, image] streams using
    vLLM FlashAttention-3 accessed via Hugging Face `kernels`.

    Notes / limitations:
    - General attention masks are not supported on this FA3 path: `is_causal=False`
      and no arbitrary mask (an explicit NotImplementedError guards this).
    - Optional windowed attention / sink tokens / softcap would need extra plumbing.
    - Relies on `apply_rotary_emb_qwen` being importable, same as the non-FA3
      processor.
    """

    # Kept for parity with the other processors' class attribute; not used internally.
    _attention_backend = "fa3"

    def __init__(self):
        # Fail at construction time (not mid-forward) if the kernel never loaded.
        _ensure_fa3_available()

    @torch.no_grad()
    def __call__(
        self,
        attn,  # Attention module with to_q/to_k/to_v/add_*_proj, norms, to_out, to_add_out, and .heads
        hidden_states: torch.FloatTensor,  # (B, S_img, D_model) image stream
        encoder_hidden_states: torch.FloatTensor = None,  # (B, S_txt, D_model) text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,  # unused in FA3 path
        attention_mask: Optional[torch.FloatTensor] = None,  # unused in FA3 path
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (img_freqs, txt_freqs)
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return (img_attn_out, txt_attn_out) after joint [text, image] attention."""
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
        if attention_mask is not None:
            # FA3 kernel path here does not consume arbitrary masks; fail fast to
            # avoid silent correctness issues.
            raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")

        # Re-check in case the processor object was deserialized without __init__.
        _ensure_fa3_available()

        B, S_img, _ = hidden_states.shape
        S_txt = encoder_hidden_states.shape[1]

        # ---- QKV projections (image/sample stream) ----
        img_q = attn.to_q(hidden_states)  # (B, S_img, D)
        img_k = attn.to_k(hidden_states)
        img_v = attn.to_v(hidden_states)

        # ---- QKV projections (text/context stream) ----
        txt_q = attn.add_q_proj(encoder_hidden_states)  # (B, S_txt, D)
        txt_k = attn.add_k_proj(encoder_hidden_states)
        txt_v = attn.add_v_proj(encoder_hidden_states)

        # ---- Reshape to (B, S, H, D_h) ----
        H = attn.heads
        img_q = img_q.unflatten(-1, (H, -1))
        img_k = img_k.unflatten(-1, (H, -1))
        img_v = img_v.unflatten(-1, (H, -1))

        txt_q = txt_q.unflatten(-1, (H, -1))
        txt_k = txt_k.unflatten(-1, (H, -1))
        txt_v = txt_v.unflatten(-1, (H, -1))

        # ---- Q/K normalization (optional, per the Attention module contract) ----
        if getattr(attn, "norm_q", None) is not None:
            img_q = attn.norm_q(img_q)
        if getattr(attn, "norm_k", None) is not None:
            img_k = attn.norm_k(img_k)
        if getattr(attn, "norm_added_q", None) is not None:
            txt_q = attn.norm_added_q(txt_q)
        if getattr(attn, "norm_added_k", None) is not None:
            txt_k = attn.norm_added_k(txt_k)

        # ---- RoPE (Qwen complex-multiplication variant) ----
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            # expects tensors shaped (B, S, H, D_h)
            img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
            img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)

        # ---- Joint attention over [text, image] along the sequence axis ----
        # Text tokens come first; the output is split back in the same order below.
        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)

        # FlashAttention-3 expects (B, S, H, D_h); the custom op returns only the output.
        out = flash_attn_func(q, k, v, causal=False)  # (B, S_total, H, D_h)

        # ---- Back to (B, S, D_model) ----
        out = out.flatten(2, 3).to(q.dtype)

        # Split back into text / image segments.
        txt_attn_out = out[:, :S_txt, :]
        img_attn_out = out[:, S_txt:, :]

        # ---- Output projections ----
        img_attn_out = attn.to_out[0](img_attn_out)
        if len(attn.to_out) > 1:
            img_attn_out = attn.to_out[1](img_attn_out)  # dropout, if present

        txt_attn_out = attn.to_add_out(txt_attn_out)

        return img_attn_out, txt_attn_out
qwenimage/transformer_qwenimage.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
26
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
27
+ from diffusers.models.attention import FeedForward, AttentionMixin
28
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
29
+ from diffusers.models.attention_processor import Attention
30
+ from diffusers.models.cache_utils import CacheMixin
31
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
35
+
36
+
37
# Module-level logger, one per module as is conventional in diffusers.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
38
+
39
+
40
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, matching the DDPM reference implementation.

    Args:
        timesteps (torch.Tensor): 1-D tensor of N indices, one per batch element
            (fractional values are allowed).
        embedding_dim (int): dimension of the output embedding.
        flip_sin_to_cos (bool): emit `[cos, sin]` order when True, `[sin, cos]` when False.
        downscale_freq_shift (float): controls the frequency spacing across dimensions.
        scale (float): multiplier applied to the angle before sin/cos.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x embedding_dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Geometrically spaced frequencies: exp(-log(max_period) * i / (half_dim - shift)).
    freq_index = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = -math.log(max_period) * freq_index / (half_dim - downscale_freq_shift)
    frequencies = torch.exp(exponent).to(timesteps.dtype)

    # Outer product of timesteps with frequencies, then optional angle scaling.
    angles = timesteps[:, None].float() * frequencies[None, :]
    angles = scale * angles

    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    # Swap the sine and cosine halves when requested.
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # Odd target dimension: zero-pad the last column.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
92
+
93
+
94
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to a query or key tensor.

    Args:
        x (`torch.Tensor`): query or key tensor of shape [B, S, H, D].
        freqs_cis: either a complex frequency tensor (when `use_real=False`) or a
            `(cos, sin)` pair of real tensors shaped [S, D] (when `use_real=True`).
        use_real (bool): select the real cos/sin formulation vs. complex multiplication.
        use_real_unbind_dim (int): -1 for interleaved (flux-style) pairs, -2 for
            split-half (OmniGen-style) pairs; anything else raises ValueError.

    Returns:
        The rotated tensor, same shape and dtype as `x`.
    """
    if not use_real:
        # Complex path: view adjacent channel pairs as complex numbers and rotate
        # by elementwise complex multiplication.
        as_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rotated = as_complex * freqs_cis.unsqueeze(1)
        return torch.view_as_real(rotated).flatten(3).type_as(x)

    cos, sin = freqs_cis  # each [S, D]
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    if use_real_unbind_dim == -1:
        # Interleaved layout (flux, cogvideox, hunyuan-dit).
        real_part, imag_part = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-imag_part, real_part], dim=-1).flatten(3)
    elif use_real_unbind_dim == -2:
        # Split-half layout (Stable Audio, OmniGen, CogView4, Cosmos).
        real_part, imag_part = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
        x_rotated = torch.cat([-imag_part, real_part], dim=-1)
    else:
        raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
140
+
141
+
142
class QwenTimestepProjEmbeddings(nn.Module):
    """Project scalar diffusion timesteps into the transformer's conditioning space.

    A 256-channel sinusoidal projection (cos-first, scaled by 1000) is followed by
    a learned MLP mapping to `embedding_dim`.
    """

    def __init__(self, embedding_dim):
        super().__init__()

        # NOTE: attribute names are part of the checkpoint state-dict layout; do not rename.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, hidden_states):
        """Return the (N, D) timestep conditioning, computed in the dtype of `hidden_states`."""
        sinusoidal = self.time_proj(timestep)
        return self.timestep_embedder(sinusoidal.to(dtype=hidden_states.dtype))
156
+
157
+
158
class QwenEmbedRope(nn.Module):
    """Three-axis (frame, height, width) rotary position embedding for Qwen-Image.

    Precomputes complex frequency tables for positive positions (image/video
    tokens and text) and negative positions (used to center height/width when
    `scale_rope` is enabled), each covering 4096 positions per axis.
    """

    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        # Negative positions -1, -2, ..., -4096 in descending order.
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        # One frequency table per axis, concatenated along the channel dim.
        self.pos_freqs = torch.cat(
            [
                self.rope_params(pos_index, self.axes_dim[0], self.theta),
                self.rope_params(pos_index, self.axes_dim[1], self.theta),
                self.rope_params(pos_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [
                self.rope_params(neg_index, self.axes_dim[0], self.theta),
                self.rope_params(neg_index, self.axes_dim[1], self.theta),
                self.rope_params(neg_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        # Per-(idx, height, width) cache used on the non-compiled path in forward().
        self.rope_cache = {}

        # Deliberately plain attributes, NOT register_buffer: buffers round-trip
        # through state_dict/dtype casts, which drops the imaginary part of these
        # complex tensors.
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Build a complex rotary frequency table for one axis.

        Args:
            index: 1-D tensor of token position indices (may be negative).
            dim: number of channels for this axis (must be even).
            theta: RoPE base frequency.

        Returns:
            Complex tensor of shape [len(index), dim // 2] with unit magnitude.
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
        # polar(1, angle) yields e^{i*angle}; magnitude is always 1.
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def forward(self, video_fhw, txt_seq_lens, device):
        """
        Args:
            video_fhw: (frame, height, width) tuple — or a list of such tuples,
                one per image stream; a nested per-batch list uses only entry 0.
            txt_seq_lens: list of per-sample text lengths; only the max is used.
            device: target device for the returned frequency tensors.

        Returns:
            (vid_freqs, txt_freqs): concatenated video-token frequencies and a
            text-frequency slice offset past the largest video position.
        """
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        # A nested per-batch list collapses to its first entry: all batch items
        # are assumed to share one shape layout.
        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        if not isinstance(video_fhw, list):
            video_fhw = [video_fhw]

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            # NOTE(review): cache key omits `frame` — assumes frame count is fixed
            # per stream index; confirm if variable-frame inputs are ever used.
            rope_key = f"{idx}_{height}_{width}"

            # The dict cache is skipped under torch.compile to keep the graph
            # free of Python-side caching.
            if not torch.compiler.is_compiling():
                if rope_key not in self.rope_cache:
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
                video_freq = self.rope_cache[rope_key]
            else:
                video_freq = self._compute_video_freqs(frame, height, width, idx)
            video_freq = video_freq.to(device)
            vid_freqs.append(video_freq)

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        # Text positions start just past the largest video position so the two
        # modalities never overlap in RoPE space.
        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs

    # NOTE(review): lru_cache on an instance method keys on `self` and keeps the
    # module alive for the cache's lifetime (ruff B019); it is also unbounded and
    # duplicates `self.rope_cache`. Left as-is to preserve behavior.
    @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        """Build the flattened [frame*height*width, C] frequency grid for one stream."""
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        # Frame axis is offset by the stream index so each image gets distinct positions.
        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            # Centered coordinates: negative half then positive half per spatial axis.
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()
254
+
255
+
256
class QwenDoubleStreamAttnProcessor2_0:
    """
    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
    implements joint attention computation where text and image streams are processed together.
    """

    # Backend selector forwarded to dispatch_attention_fn; None selects the default path.
    _attention_backend = None

    def __init__(self):
        # F.scaled_dot_product_attention only exists on PyTorch >= 2.0.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,  # accepted but not used below
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Compute joint attention over concatenated [text, image] tokens.

        Args:
            attn: The `Attention` module providing the projection/norm layers.
            hidden_states: Image-stream tokens `(B, img_seq, dim)`.
            encoder_hidden_states: Text-stream tokens `(B, txt_seq, dim)`; required.
            encoder_hidden_states_mask: Unused here; kept for interface compatibility.
            attention_mask: Optional mask passed straight to the attention kernel.
            image_rotary_emb: Optional `(img_freqs, txt_freqs)` rotary embeddings.

        Returns:
            Tuple `(img_attn_output, txt_attn_output)` after the output projections.

        Raises:
            ValueError: If `encoder_hidden_states` is None.
        """
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        # Text length is needed later to split the joint output back apart.
        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention: (B, seq, heads, head_dim)
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization (per-stream norms may be absent)
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE separately per stream with its own frequency table
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        # Compute joint attention
        joint_hidden_states = dispatch_attention_fn(
            joint_query,
            joint_key,
            joint_value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )

        # Reshape back: merge (heads, head_dim) -> dim
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

        # Split attention outputs back (same [text, image] order as above)
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout

        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output
354
+
355
+
356
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
    """Dual-stream (image + text) DiT block with AdaLN-style modulation and joint attention."""

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        """
        Args:
            dim: Hidden size of both streams.
            num_attention_heads: Number of attention heads.
            attention_head_dim: Per-head dimension.
            qk_norm: Normalization applied to Q/K projections (default "rms_norm").
            eps: Epsilon for LayerNorm/RMSNorm layers.
        """
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        # Image processing modules
        self.img_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,  # Enable cross attention for joint computation
            added_kv_proj_dim=dim,  # Enable added KV projections for text stream
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=QwenDoubleStreamAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Text processing modules
        self.txt_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        # Text doesn't need separate attention - it's handled by img_attn joint computation
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def _modulate(self, x, mod_params):
        """Apply modulation to input tensor"""
        # mod_params is (B, 3*dim): split into shift/scale (applied here) and gate
        # (returned for the caller to weight the residual branch).
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one dual-stream block.

        Args:
            hidden_states: Image-stream tokens `(B, img_seq, dim)`.
            encoder_hidden_states: Text-stream tokens `(B, txt_seq, dim)`.
            encoder_hidden_states_mask: Text mask forwarded to the attention processor.
            temb: Timestep embedding `(B, dim)` driving the modulation MLPs.
            image_rotary_emb: Optional `(img_freqs, txt_freqs)` rotary embeddings.
            joint_attention_kwargs: Extra kwargs forwarded to the attention call.

        Returns:
            Tuple `(encoder_hidden_states, hidden_states)` — note text first.
        """
        # Get modulation parameters for both streams
        img_mod_params = self.img_mod(temb)  # [B, 6*dim]
        txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]

        # Split modulation parameters for norm1 and norm2
        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]

        # Process image stream - norm1 + modulation
        img_normed = self.img_norm1(hidden_states)
        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)

        # Process text stream - norm1 + modulation
        txt_normed = self.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)

        # Use QwenAttnProcessor2_0 for joint attention computation
        # This directly implements the DoubleStreamLayerMegatron logic:
        # 1. Computes QKV for both streams
        # 2. Applies QK normalization and RoPE
        # 3. Concatenates and runs joint attention
        # 4. Splits results back to separate streams
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=img_modulated,  # Image stream (will be processed as "sample")
            encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
        img_attn_output, txt_attn_output = attn_output

        # Apply attention gates and add residual (like in Megatron)
        hidden_states = hidden_states + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

        # Process image stream - norm2 + MLP
        img_normed2 = self.img_norm2(hidden_states)
        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
        img_mlp_output = self.img_mlp(img_modulated2)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        # Process text stream - norm2 + MLP
        txt_normed2 = self.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
        txt_mlp_output = self.txt_mlp(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        # Clip to prevent overflow for fp16 (65504 is the fp16 max finite value)
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
470
+
471
+
472
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
    """
    The Transformer model introduced in Qwen.

    Args:
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `60`):
            The number of layers of dual stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `3584`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["QwenImageTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["QwenImageTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: Optional[int] = 16,
        num_layers: int = 60,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        guidance_embeds: bool = False,  # TODO: this should probably be removed
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # Rotary positional embedding over (frame, height, width) axes.
        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

        # Input projections into the shared inner dimension.
        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`QwenTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
                Mask of the input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.img_in(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            # Scale guidance to match the timestep-embedding magnitude convention.
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # NOTE(review): the checkpointing path does not forward
                # `attention_kwargs` (the non-checkpointed call below does) — confirm
                # this asymmetry is intended.
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=attention_kwargs,
                )

        # Use only the image part (hidden_states) from the dual-stream blocks
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git@v4.57.3
2
+ git+https://github.com/huggingface/accelerate.git
3
+ git+https://github.com/huggingface/diffusers.git
4
+ git+https://github.com/huggingface/peft.git
5
+ huggingface_hub
6
+ sentencepiece
7
+ torchvision
8
+ supervision
9
+ kernels
10
+ spaces
11
+ hf_xet
12
+ torch==2.9.1
13
+ numpy
14
+ av
setup_manager.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess
import sys

# Configuration
WORKSPACE_DIR = "/workspace"  # root of the persistent volume; everything lives below it
VENV_DIR = os.path.join(WORKSPACE_DIR, "venv")
APPS_DIR = os.path.join(WORKSPACE_DIR, "apps")
REPO_DIR = os.path.join(WORKSPACE_DIR, "Qwen-Image-Edit")
# NOTE(review): placeholder token committed in source — prefer reading from the
# environment (e.g. os.environ["HF_TOKEN"]) so a real token is never committed.
HF_TOKEN = "YOUR_HF_TOKEN_HERE"

# Cache and Temp Directories (Strictly on persistent drive)
CACHE_BASE = os.path.join(WORKSPACE_DIR, "cache")
TMP_DIR = os.path.join(WORKSPACE_DIR, "tmp")
PIP_CACHE = os.path.join(CACHE_BASE, "pip")
HF_HOME = os.path.join(CACHE_BASE, "huggingface")
17
+
18
def ensure_dirs():
    """Ensures all necessary persistent directories exist."""
    required = (APPS_DIR, REPO_DIR, CACHE_BASE, TMP_DIR, PIP_CACHE, HF_HOME)
    for path in required:
        if os.path.exists(path):
            continue
        os.makedirs(path)
        print(f"Created directory: {path}")
25
+
26
def run_command(command, cwd=None, env=None):
    """Run *command* in a shell, stream its combined output, and return the exit code."""
    print(f"Running: {command}")

    # Start from the current environment, pin caches/temp to the persistent
    # drive, then layer any caller-supplied overrides on top.
    merged_env = os.environ.copy()
    merged_env["TMPDIR"] = TMP_DIR
    merged_env["PIP_CACHE_DIR"] = PIP_CACHE
    merged_env["HF_HOME"] = HF_HOME
    if env:
        merged_env.update(env)

    proc = subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout for live logging
        text=True,
        cwd=cwd,
        env=merged_env,
    )
    # Echo output line-by-line as the child produces it.
    for line in proc.stdout:
        print(line, end="")
    proc.wait()

    if proc.returncode != 0:
        print(f"Command failed with return code {proc.returncode}")
    return proc.returncode
54
+
55
def setup_venv():
    """Sets up a persistent virtual environment in /workspace."""
    if os.path.exists(VENV_DIR):
        print("Virtual environment already exists.")
        return
    print(f"Creating virtual environment in {VENV_DIR}...")
    run_command(f"python3 -m venv {VENV_DIR}")
62
+
63
def install_package(package_name):
    """Installs a pip package into the persistent venv."""
    pip_bin = os.path.join(VENV_DIR, "bin", "pip")
    run_command(f"{pip_bin} install {package_name}")
67
+
68
def install_git_xet():
    """Installs git-xet using the huggingface script."""
    print("Installing git-xet...")
    # NOTE: `curl | bash` executes a remote script; acceptable only because the
    # source host (huggingface.co) is trusted here.
    run_command("curl -LsSf https://huggingface.co/install-git-xet.sh | bash")
    run_command("git xet install")
73
+
74
def install_hf_cli():
    """Installs Hugging Face CLI."""
    print("Installing Hugging Face CLI...")
    # NOTE: remote install script piped to bash — trusted source (hf.co).
    run_command("curl -LsSf https://hf.co/cli/install.sh | bash")
78
+
79
def download_space():
    """Downloads the Qwen Space using hf cli."""
    if not os.path.exists(REPO_DIR):
        os.makedirs(REPO_DIR)

    print(f"Downloading Space to {REPO_DIR}...")
    # Prefer the user-local install location; fall back to whatever is on PATH.
    hf_bin = os.path.expanduser("~/.local/bin/hf")
    if not os.path.exists(hf_bin):
        hf_bin = "hf"  # fallback to PATH

    run_command(
        f"{hf_bin} download Pr0f3ssi0n4ln00b/Qwen-Image-Edit-Rapid-AIO-Loras-Experimental "
        f"--repo-type=space --local-dir {REPO_DIR}",
        env={"HF_TOKEN": HF_TOKEN},
    )
92
+
93
def create_app_file(filename, content):
    """Create or overwrite a file in the apps directory.

    Args:
        filename: Bare file name to write under ``APPS_DIR``.
        content: Full text content to write.
    """
    # exist_ok avoids the check-then-create race of the previous version.
    os.makedirs(APPS_DIR, exist_ok=True)

    filepath = os.path.join(APPS_DIR, filename)
    # Explicit UTF-8 so the written app files are identical regardless of the
    # container locale (the default encoding is platform-dependent).
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
    print(f"Created/Updated: {filepath}")
102
+
103
+ def patch_app():
104
+ """Patches app.py to optimize for VRAM and fix OOM issues."""
105
+ app_path = os.path.join(REPO_DIR, "app.py")
106
+ if not os.path.exists(app_path):
107
+ print(f"Warning: {app_path} not found, cannot patch.")
108
+ return
109
+
110
+ print("Patching app.py for memory optimization...")
111
+ with open(app_path, "r") as f:
112
+ content = f.read()
113
+
114
+ # 1. Update transformer loading to use device_map="auto" and low_cpu_mem_usage
115
+ content = content.replace(
116
+ 'device_map="cuda",',
117
+ 'device_map="auto",\n low_cpu_mem_usage=True,'
118
+ )
119
+
120
+ # 2. Remove redundant .to(device) which causes OOM
121
+ content = content.replace(').to(device)', ')')
122
+
123
+ # 3. Enable model CPU offload to save VRAM
124
+ if "p.enable_model_cpu_offload()" not in content:
125
+ content = content.replace(
126
+ 'return p',
127
+ 'p.enable_model_cpu_offload()\n return p'
128
+ )
129
+
130
+ # 4. Disable FA3 Processor (to avoid hangs/compilation issues)
131
+ content = content.replace(
132
+ 'pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())',
133
+ 'print("Skipping FA3 optimization for stability.")'
134
+ )
135
+
136
+ # 5. Fix launch parameters for visibility and accessibility
137
+ content = content.replace(
138
+ 'demo.queue(max_size=30).launch(',
139
+ 'demo.queue(max_size=30).launch(server_name="0.0.0.0", share=True, '
140
+ )
141
+
142
+ # 6. Ensure spaces.GPU is handled (if it blocks)
143
+ # Usually it's fine, but let's be safe and mock it if env isn't right
144
+ if 'import spaces' in content and 'class spaces:' not in content:
145
+ content = 'import sys\ntry:\n import spaces\nexcept ImportError:\n class spaces:\n @staticmethod\n def GPU(f): return f\nsys.modules["spaces"] = sys.modules.get("spaces", spaces)\n' + content
146
+
147
+ # 7. Add missing LORA_PRESET_PROMPTS (Robust append)
148
+ additional_prompts_map = {
149
+ "Consistance": "improve consistency and quality of the generated image",
150
+ "F2P": "transform the image into a high-quality photo with realistic details",
151
+ "Multiple-Angles": "change the camera angle of the image",
152
+ "Light-Restoration": "Remove shadows and relight the image using soft lighting",
153
+ "Relight": "Relight the image with cinematic lighting",
154
+ "Multi-Angle-Lighting": "Change the lighting direction and intensity",
155
+ "Edit-Skin": "Enhance skin textures and natural details",
156
+ "Next-Scene": "Generate the next scene based on the current image",
157
+ "Flat-Log": "Desaturate and lower contrast for a flat log look",
158
+ "Upscale-Image": "Enhance and sharpen the image details",
159
+ "BFS-Best-FaceSwap": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
160
+ "BFS-Best-FaceSwap-merge": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
161
+ "Qwen-lora-nsfw": "Convert this picture to artistic style.", # Default prompt
162
+ }
163
+
164
+ # 9. Add new LoRA to ADAPTER_SPECS
165
+ new_lora_config = """
166
+ "Qwen-lora-nsfw": {
167
+ "type": "single",
168
+ "repo": "wiikoo/Qwen-lora-nsfw",
169
+ "weights": "loras/qwen_image_edit_remove-clothing_v1.0.safetensors",
170
+ "adapter_name": "qwen-lora-nsfw",
171
+ "strength": 1.0,
172
+ },
173
+ """
174
+ if '"Qwen-lora-nsfw":' not in content:
175
+ content = content.replace(
176
+ 'ADAPTER_SPECS = {',
177
+ 'ADAPTER_SPECS = {' + new_lora_config
178
+ )
179
+
180
+ if "Manual Patch for missing prompts" not in content:
181
+ content += "\n\n# Manual Patch for missing prompts\ntry:\n LORA_PRESET_PROMPTS.update({\n"
182
+ for key, val in additional_prompts_map.items():
183
+ content += f' "{key}": "{val}",\n'
184
+ content += " })\nexcept NameError:\n pass\n"
185
+
186
+ # 8. Modify on_lora_change_ui to ALWAYS update the prompt if a style is picked
187
+ # (or at least be more aggressive)
188
+ new_ui_logic = """
189
+ def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
190
+ # Always provide the preset if selected
191
+ prompt_val = current_prompt
192
+ if selected_lora != NONE_LORA:
193
+ preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
194
+ if preset:
195
+ prompt_val = preset
196
+
197
+ prompt_update = gr.update(value=prompt_val)
198
+ """
199
+ # Find the old function and replace it
200
+ start_marker = "def on_lora_change_ui"
201
+ end_marker = "return prompt_update, img2_update, extras_update"
202
+
203
+ if start_marker in content and end_marker in content:
204
+ import re
205
+ content = re.sub(
206
+ r"def on_lora_change_ui\(.*?\):.*?return prompt_update, img2_update, extras_update",
207
+ new_ui_logic + "\n # Image2 visibility/label\n if lora_requires_two_images(selected_lora):\n img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))\n else:\n img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')\n\n # Extra references routing default\n if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):\n extras_update = gr.update(value=True)\n else:\n extras_update = gr.update(value=extras_condition_only)\n\n return prompt_update, img2_update, extras_update",
208
+ content,
209
+ flags=re.DOTALL
210
+ )
211
+
212
+ with open(app_path, "w") as f:
213
+ f.write(content)
214
+
215
+ # --- NEW UI PATCHES ---
216
+ with open(app_path, "r") as f:
217
+ content = f.read()
218
+
219
+ # 10. Implement missing _append_to_gallery function
220
+ append_fn = """
221
+ def _append_to_gallery(existing_gallery, new_image):
222
+ if existing_gallery is None:
223
+ return [new_image]
224
+ if not isinstance(existing_gallery, list):
225
+ existing_gallery = [existing_gallery]
226
+ existing_gallery.append(new_image)
227
+ return existing_gallery
228
+ """
229
+ if "def _append_to_gallery" not in content:
230
+ content = content.replace(
231
+ '# UI helpers: output routing + derived conditioning',
232
+ '# UI helpers: output routing + derived conditioning\n' + append_fn
233
+ )
234
+
235
+ # 11. Remove height constraints from main image components
236
+ content = content.replace('height=290)', ')')
237
+ content = content.replace('height=350)', ')')
238
+
239
+ # 12. Strip out gr.Examples block to declutter UI
240
+ # We find the start of gr.Examples and the end of its call
241
+ if "gr.Examples(" in content:
242
+ import re
243
+ content = re.sub(
244
+ r"gr\.Examples\([\s\S]*?label=\"Examples\"[\s\S]*?\)",
245
+ "# Examples removed automatically by setup_manager",
246
+ content
247
+ )
248
+
249
+ with open(app_path, "w") as f:
250
+ f.write(content)
251
+ # --- END NEW UI PATCHES ---
252
+
253
+ # --- 3D CAMERA AND PROMPT CLEARING PATCHES ---
254
+ with open(app_path, "r") as f:
255
+ content = f.read()
256
+
257
+ # Import the custom 3D Camera control safely at the top
258
+ if "update_prompt_with_camera" not in content:
259
+ content = content.replace("import os", "import os\nfrom camera_control_ui import CameraControl3D, build_camera_prompt, update_prompt_with_camera")
260
+
261
+ # Add the 3D Camera LoRA to ADAPTER_SPECS
262
+ camera_lora_config = """
263
+ "3D-Camera": {
264
+ "type": "single",
265
+ "repo": "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA",
266
+ "weights": "qwen-image-edit-2511-multiple-angles-lora.safetensors",
267
+ "adapter_name": "angles",
268
+ "strength": 1.0,
269
+ },
270
+ """
271
+ if '"3D-Camera":' not in content:
272
+ content = content.replace(
273
+ 'ADAPTER_SPECS = {',
274
+ 'ADAPTER_SPECS = {' + camera_lora_config
275
+ )
276
+
277
+ # Patch on_lora_change_ui to clear prompt if no preset exists and toggle 3D camera visibility
278
+ prompt_clear_logic = """
279
+ def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
280
+ prompt_val = current_prompt
281
+ if selected_lora != NONE_LORA:
282
+ preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
283
+ if preset:
284
+ prompt_val = preset
285
+ else:
286
+ prompt_val = "" # CLEAR THE PROMPT IF ACTIVE BUT NO PRESET
287
+
288
+ prompt_update = gr.update(value=prompt_val)
289
+ camera_update = gr.update(visible=(selected_lora == "3D-Camera"))
290
+
291
+ # Image2 visibility/label
292
+ if lora_requires_two_images(selected_lora):
293
+ img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
294
+ else:
295
+ img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')
296
+
297
+ # Extra references routing default
298
+ if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):
299
+ extras_update = gr.update(value=True)
300
+ else:
301
+ extras_update = gr.update(value=extras_condition_only)
302
+
303
+ return prompt_update, img2_update, extras_update, camera_update
304
+ """
305
+ old_on_lora = """
306
+ def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
307
+ # Always provide the preset if selected
308
+ prompt_val = current_prompt
309
+ if selected_lora != NONE_LORA:
310
+ preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
311
+ if preset:
312
+ prompt_val = preset
313
+
314
+ prompt_update = gr.update(value=prompt_val)
315
+
316
+ # Image2 visibility/label
317
+ if lora_requires_two_images(selected_lora):
318
+ img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
319
+ else:
320
+ img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')
321
+
322
+ # Extra references routing default
323
+ if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):
324
+ extras_update = gr.update(value=True)
325
+ else:
326
+ extras_update = gr.update(value=extras_condition_only)
327
+
328
+ return prompt_update, img2_update, extras_update
329
+ """
330
+ if "camera_update = gr.update(visible" not in content:
331
+ content = content.replace(old_on_lora.strip(), prompt_clear_logic.strip())
332
+
333
+ # We also need to update the caller
334
+ content = content.replace(
335
+ "outputs=[prompt, input_image_2, extras_condition_only],",
336
+ "outputs=[prompt, input_image_2, extras_condition_only, camera_container],"
337
+ )
338
+
339
+ # Inject the 3D Camera UI Block right below input_image_2 definition
340
+ camera_ui_block = """
341
+ input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)
342
+
343
+ with gr.Column(visible=False) as camera_container:
344
+ gr.Markdown("### 🎮 3D Camera Control\\n*Drag handles: 🟢 Azimuth, 🩷 Elevation, 🟠 Distance*")
345
+ camera_3d = CameraControl3D(value={"azimuth": 0, "elevation": 0, "distance": 1.0}, elem_id="camera-3d-control")
346
+ gr.Markdown("### 🎚️ Slider Controls")
347
+ azimuth_slider = gr.Slider(label="Azimuth", minimum=0, maximum=315, step=45, value=0, info="0°=front, 90°=right, 180°=back, 270°=left")
348
+ elevation_slider = gr.Slider(label="Elevation", minimum=-30, maximum=60, step=30, value=0, info="-30°=low angle, 0°=eye, 60°=high angle")
349
+ distance_slider = gr.Slider(label="Distance", minimum=0.6, maximum=1.4, step=0.4, value=1.0, info="0.6=close, 1.0=medium, 1.4=wide")
350
+ """
351
+ if "camera_container:" not in content:
352
+ content = content.replace(
353
+ ' input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)',
354
+ camera_ui_block.strip("\\n")
355
+ )
356
+
357
+ # Inject the Events. We place them right before "run_button.click("
358
+ camera_events = """
359
+ # --- 3D Camera Events ---
360
+ def update_prompt_from_sliders(az, el, dist, curr_prompt):
361
+ return update_prompt_with_camera(az, el, dist, curr_prompt)
362
+
363
+ def sync_3d_to_sliders(cv, curr_prompt):
364
+ if cv and isinstance(cv, dict):
365
+ az = cv.get('azimuth', 0)
366
+ el = cv.get('elevation', 0)
367
+ dist = cv.get('distance', 1.0)
368
+ return az, el, dist, update_prompt_with_camera(az, el, dist, curr_prompt)
369
+ return gr.update(), gr.update(), gr.update(), gr.update()
370
+
371
+ def sync_sliders_to_3d(az, el, dist):
372
+ return {"azimuth": az, "elevation": el, "distance": dist}
373
+
374
+
375
+ def update_3d_image(img):
376
+ if img is None: return gr.update(imageUrl=None)
377
+ import base64
378
+ from io import BytesIO
379
+ buf = BytesIO()
380
+ img.save(buf, format="PNG")
381
+ durl = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
382
+ return gr.update(imageUrl=durl)
383
+
384
+ for slider in [azimuth_slider, elevation_slider, distance_slider]:
385
+ slider.change(fn=update_prompt_from_sliders, inputs=[azimuth_slider, elevation_slider, distance_slider, prompt], outputs=[prompt])
386
+ slider.release(fn=sync_sliders_to_3d, inputs=[azimuth_slider, elevation_slider, distance_slider], outputs=[camera_3d])
387
+
388
+ camera_3d.change(fn=sync_3d_to_sliders, inputs=[camera_3d, prompt], outputs=[azimuth_slider, elevation_slider, distance_slider, prompt])
389
+
390
+ input_image_1.upload(fn=update_3d_image, inputs=[input_image_1], outputs=[camera_3d])
391
+ input_image_1.clear(fn=lambda: gr.update(imageUrl=None), outputs=[camera_3d])
392
+
393
+ run_button.click(
394
+ """
395
+ if "def sync_3d_to_sliders" not in content:
396
+ content = content.replace(" run_button.click(\n", camera_events)
397
+
398
+ # Clear any bad \\n literals if they exist
399
+ content = content.replace("\\n demo.queue", "\n demo.queue")
400
+
401
+ if "head=" not in content:
402
+ content = content.replace(
403
+ "demo.queue(max_size=30).launch(",
404
+ """head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
405
+ demo.queue(max_size=30).launch(head=head, """
406
+ )
407
+
408
+ with open(app_path, "w") as f:
409
+ f.write(content)
410
+ # --- END 3D CAMERA PATCHES ---
411
+
412
+ print("Successfully patched app.py.")
413
+
414
def install_dependencies():
    """Install the app's requirements.txt into the persistent virtualenv.

    Looks for ``requirements.txt`` inside REPO_DIR and feeds it to the
    venv's own pip so packages land on the persistent volume.
    """
    pip_bin = os.path.join(VENV_DIR, "bin", "pip")
    req_file = os.path.join(REPO_DIR, "requirements.txt")

    # Nothing to do when the checkout ships no requirements file.
    if not os.path.exists(req_file):
        print(f"No requirements.txt found in {REPO_DIR}")
        return

    print("Installing dependencies from requirements.txt...")
    # NOTE(review): a torch pin like 2.9.1 might not exist on PyPI and may
    # need an --extra-index-url; on an L40S we typically want the latest
    # stable torch built against CUDA 12.x.
    run_command(f"{pip_bin} install -r {req_file}")
426
+
427
def run_app():
    """Start the Gradio app with the persistent venv's interpreter.

    Runs ``REPO_DIR/app.py`` via ``run_command`` (which presumably blocks
    until the child exits — confirm in its definition) with REPO_DIR as
    both working directory and PYTHONPATH.
    """
    python_path = os.path.join(VENV_DIR, "bin", "python")
    app_path = os.path.join(REPO_DIR, "app.py")

    if os.path.exists(app_path):
        print(f"Starting app: {app_path}")
        # Gradio apps often need to be bound to 0.0.0.0 for external access
        # We'll run it and see if it requires specific environment variables
        #
        # Merge with the current environment instead of replacing it:
        # passing only {"PYTHONPATH": ...} would strip PATH/HOME/HF_HOME/
        # CUDA_* from the child if run_command hands env to subprocess as-is.
        env = {**os.environ, "PYTHONPATH": REPO_DIR}
        run_command(f"{python_path} {app_path}", cwd=REPO_DIR, env=env)
    else:
        print(f"App file not found: {app_path}")
440
+
441
def main():
    """Run the full setup pipeline: venv, tooling, checkout, patches, deps."""
    # Bail out early if the persistent volume is missing entirely.
    if not os.path.exists(WORKSPACE_DIR):
        print(f"Error: {WORKSPACE_DIR} not found. Ensure this is a RunPod with persistent storage.")
        return

    setup_steps = (
        ensure_dirs,
        setup_venv,
        install_git_xet,
        install_hf_cli,
        download_space,
        patch_app,
        install_dependencies,
    )
    for step in setup_steps:
        step()

    # Launching the app is deliberately left to the 'run' argument so this
    # script can be re-run to refresh the setup without starting the server.
    print("Setup tasks completed. Run with 'run' argument to start the app.")
457
+
458
if __name__ == "__main__":
    # "script.py run" starts the app; any other invocation performs setup only.
    wants_run = len(sys.argv) > 1 and sys.argv[1] == "run"
    if wants_run:
        run_app()
    else:
        main()
start_app.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch the Qwen-Image-Edit Gradio app from the persistent /workspace volume.
set -eu  # abort instead of exec-ing from the wrong place if any step fails

APP_DIR=/workspace/Qwen-Image-Edit

export PYTHONPATH="$APP_DIR"
export TMPDIR=/workspace/tmp
export HF_HOME=/workspace/cache/huggingface
export PYTHONUNBUFFERED=1  # stream logs immediately (paired with python -u)

# TMPDIR must exist or tempfile-using libraries fail at startup.
mkdir -p "$TMPDIR"

cd "$APP_DIR"
# exec replaces the shell so SIGTERM reaches the python process directly.
exec /workspace/venv/bin/python -u "$APP_DIR/app.py"