prithivMLmods commited on
Commit
b47b250
·
verified ·
1 Parent(s): 32154ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +632 -411
app.py CHANGED
@@ -4,22 +4,15 @@ import random
4
  import torch
5
  import spaces
6
  from PIL import Image
7
-
8
- # --- Imports ---
9
  from diffusers import FlowMatchEulerDiscreteScheduler
10
- try:
11
- from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
12
- from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
13
- from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
14
- except ImportError:
15
- raise ImportError("Please ensure the 'qwenimage' package is installed.")
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
-
19
- # --- Configuration & Model Loading ---
20
  dtype = torch.bfloat16
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
-
23
  pipe = QwenImageEditPlusPipeline.from_pretrained(
24
  "Qwen/Qwen-Image-Edit-2511",
25
  transformer=QwenImageTransformer2DModel.from_pretrained(
@@ -29,7 +22,6 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
29
  ),
30
  torch_dtype=dtype
31
  ).to(device)
32
-
33
  try:
34
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
35
  print("Flash Attention 3 Processor set successfully.")
@@ -41,404 +33,567 @@ ADAPTER_SPECS = {
41
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
42
  "weights": "多角度灯光-251116.safetensors",
43
  "adapter_name": "multi-angle-lighting"
44
- }
45
  }
46
-
47
- CURRENT_LOADED_ADAPTER = None
48
-
49
- # --- Logic: Mappings ---
50
- LIGHTING_AZIMUTH_MAP = {
51
- 0: "Light source from the Front",
52
- 45: "Light source from the Right Front",
53
- 90: "Light source from the Right",
54
- 135: "Light source from the Right Rear",
55
- 180: "Light source from the Rear",
56
- 225: "Light source from the Left Rear",
57
- 270: "Light source from the Left",
58
- 315: "Light source from the Left Front"
 
 
 
 
 
 
59
  }
60
 
61
- def snap_to_nearest_key(value, keys):
62
- return min(keys, key=lambda x: abs(x - value))
 
63
 
64
  def build_lighting_prompt(azimuth: float, elevation: float) -> str:
65
- if elevation >= 60: return "Light source from Above"
66
- if elevation <= -60: return "Light source from Below"
67
- keys = list(LIGHTING_AZIMUTH_MAP.keys())
68
- if azimuth > 337.5: azimuth = 0
69
- azimuth_snapped = snap_to_nearest_key(azimuth, keys)
70
- return LIGHTING_AZIMUTH_MAP[azimuth_snapped]
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # --- Inference ---
73
  @spaces.GPU
74
- def infer_lighting_edit(image, azimuth, elevation, seed, randomize_seed, guidance_scale, num_inference_steps, height, width):
75
- global CURRENT_LOADED_ADAPTER
76
- spec = ADAPTER_SPECS["Multi-Angle-Lighting"]
77
- if CURRENT_LOADED_ADAPTER != spec["adapter_name"]:
78
- pipe.load_lora_weights(spec["repo"], weight_name=spec["weights"], adapter_name=spec["adapter_name"])
79
- pipe.set_adapters([spec["adapter_name"]], adapter_weights=[1.0])
80
- CURRENT_LOADED_ADAPTER = spec["adapter_name"]
 
 
 
 
 
 
 
 
 
81
 
82
- prompt = build_lighting_prompt(azimuth, elevation)
83
- print(f"💡 Generated Prompt: {prompt}")
84
-
85
- if image is None: raise gr.Error("Please upload an image.")
86
- if randomize_seed: seed = random.randint(0, MAX_SEED)
 
 
 
87
 
 
 
 
 
88
  generator = torch.Generator(device=device).manual_seed(seed)
89
- pil_image = image.convert("RGB")
90
-
 
91
  result = pipe(
92
- image=[pil_image], prompt=prompt, height=height, width=width,
93
- num_inference_steps=num_inference_steps, generator=generator,
94
- guidance_scale=guidance_scale, num_images_per_prompt=1
 
 
 
 
 
95
  ).images[0]
96
  return result, seed, prompt
97
 
98
  def update_dimensions_on_upload(image):
99
- if image is None: return 1024, 1024
100
- w, h = image.size
101
- if w > h: new_w, new_h = 1024, int(1024 * (h / w))
102
- else: new_h, new_w = 1024, int(1024 * (w / h))
103
- return (new_w // 8) * 8, (new_h // 8) * 8
104
-
105
- # --- 🌟 OPTIMIZED 3D CONTROLLER 🌟 ---
106
- class LightControl3D(gr.HTML):
 
 
 
 
 
 
 
 
 
 
107
  """
108
- High-Fidelity 3D Studio Light Controller
 
 
109
  """
110
  def __init__(self, value=None, imageUrl=None, **kwargs):
111
- if value is None: value = {"azimuth": 0, "elevation": 0}
 
112
 
113
  html_template = """
114
- <div id="light-control-wrapper" style="width: 100%; height: 500px; position: relative; background: #0f0f0f; border-radius: 12px; overflow: hidden; border: 1px solid #333; box-shadow: 0 4px 20px rgba(0,0,0,0.5);">
115
- <div id="prompt-badge" style="position: absolute; top: 20px; left: 50%; transform: translateX(-50%);
116
- background: rgba(10,10,10,0.85); border: 1px solid #FFD700; color: #FFD700;
117
- padding: 10px 30px; border-radius: 8px; font-family: 'Segoe UI', monospace; font-weight: 600; font-size: 14px;
118
- z-index: 10; pointer-events: none; backdrop-filter: blur(4px); box-shadow: 0 4px 12px rgba(0,0,0,0.5);
119
- transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);">
120
- Light Source: Front
121
- </div>
122
-
123
- <div style="position: absolute; bottom: 20px; left: 20px; color: #666; font-size: 11px; font-family: sans-serif; pointer-events: none; line-height: 1.5;">
124
- <span style="color: #FFD700">●</span> Light Position<br>
125
- <span style="color: #444">■</span> Shadow Caster
126
- </div>
127
  </div>
128
  """
129
 
130
  js_on_load = """
131
  (() => {
132
- const wrapper = element.querySelector('#light-control-wrapper');
133
- const badge = element.querySelector('#prompt-badge');
134
 
 
135
  const initScene = () => {
136
- if (typeof THREE === 'undefined') { setTimeout(initScene, 100); return; }
 
 
 
137
 
138
- // --- 1. Scene & Renderer (Enable Shadows) ---
139
  const scene = new THREE.Scene();
140
- scene.background = new THREE.Color(0x0f0f0f);
141
- scene.fog = new THREE.Fog(0x0f0f0f, 4, 12);
142
 
143
- const camera = new THREE.PerspectiveCamera(45, wrapper.clientWidth / wrapper.clientHeight, 0.1, 100);
144
- camera.position.set(5, 4, 5);
145
- camera.lookAt(0, 0.5, 0);
146
 
147
  const renderer = new THREE.WebGLRenderer({ antialias: true });
148
- renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
149
  renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
150
- renderer.shadowMap.enabled = true; // CRITICAL: Enable Shadows
151
- renderer.shadowMap.type = THREE.PCFSoftShadowMap;
152
- wrapper.appendChild(renderer.domElement);
 
 
 
 
153
 
154
- // --- 2. Environment (Studio Floor) ---
155
  const CENTER = new THREE.Vector3(0, 0.75, 0);
156
- const RADIUS = 2.8;
157
-
158
- // Floor (Shadow Catcher)
159
- const planeGeo = new THREE.PlaneGeometry(20, 20);
160
- const planeMat = new THREE.ShadowMaterial({ opacity: 0.5, color: 0x000000 });
161
- const floor = new THREE.Mesh(planeGeo, planeMat);
162
- floor.rotation.x = -Math.PI / 2;
163
- floor.position.y = 0;
164
- floor.receiveShadow = true;
165
- scene.add(floor);
166
-
167
- // Grid (Visual Reference)
168
- const grid = new THREE.GridHelper(10, 20, 0x333333, 0x111111);
169
- scene.add(grid);
170
-
171
- // --- 3. The Subject (Thick Slab) ---
172
- let subjectGroup = new THREE.Group();
173
- scene.add(subjectGroup);
174
-
175
- // Base Stand
176
- const standGeo = new THREE.CylinderGeometry(0.8, 0.8, 0.05, 32);
177
- const standMat = new THREE.MeshStandardMaterial({ color: 0x222222, roughness: 0.2, metalness: 0.8 });
178
- const stand = new THREE.Mesh(standGeo, standMat);
179
- stand.position.y = 0.025;
180
- stand.receiveShadow = true;
181
- subjectGroup.add(stand);
182
-
183
- // The Image Box
184
- let imageMesh;
185
- const imageMat = new THREE.MeshStandardMaterial({
186
- color: 0xaaaaaa, roughness: 0.4, side: THREE.DoubleSide
187
- });
188
 
189
- function createSubject(width=1.5, height=1.5) {
190
- if(imageMesh) {
191
- subjectGroup.remove(imageMesh);
192
- imageMesh.geometry.dispose();
193
- }
194
- // A box with slight thickness to cast better shadows
195
- imageMesh = new THREE.Mesh(new THREE.BoxGeometry(width, height, 0.05), imageMat);
196
- imageMesh.position.set(0, height/2 + 0.1, 0);
197
- imageMesh.castShadow = true; // CAST SHADOW
198
- imageMesh.receiveShadow = true;
199
- subjectGroup.add(imageMesh);
 
 
 
 
 
200
  }
201
- createSubject();
202
-
203
- // Texture Loader
204
- function updateTexture(url) {
205
- if (!url) { imageMat.map = null; imageMat.needsUpdate = true; return; }
206
- new THREE.TextureLoader().load(url, (tex) => {
207
- tex.encoding = THREE.sRGBEncoding;
208
- imageMat.map = tex;
209
- imageMat.color.setHex(0xffffff); // Reset tint
210
- imageMat.needsUpdate = true;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- const img = tex.image;
213
- if(img && img.width && img.height) {
 
214
  const aspect = img.width / img.height;
215
- const maxDim = 1.8;
216
- if (aspect > 1) createSubject(maxDim, maxDim/aspect);
217
- else createSubject(maxDim*aspect, maxDim);
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  }
 
 
219
  });
220
  }
221
- if (props.imageUrl) updateTexture(props.imageUrl);
222
-
223
- // --- 4. The Light System (Visual + Actual Light) ---
224
- const lightPivot = new THREE.Group();
225
- scene.add(lightPivot);
226
-
227
- // Actual Light Source (for Shadows)
228
- const dirLight = new THREE.DirectionalLight(0xffffff, 1.5);
229
- dirLight.castShadow = true;
230
- dirLight.shadow.mapSize.width = 1024;
231
- dirLight.shadow.mapSize.height = 1024;
232
- dirLight.shadow.camera.near = 0.5;
233
- dirLight.shadow.camera.far = 10;
234
- lightPivot.add(dirLight);
235
-
236
- // Visual Representation (The Gizmo)
237
- const gizmoGroup = new THREE.Group();
238
- lightPivot.add(gizmoGroup);
239
-
240
- // The "Sun" Mesh
241
- const sunGeo = new THREE.SphereGeometry(0.15, 32, 32);
242
- const sunMat = new THREE.MeshBasicMaterial({ color: 0xFFD700 });
243
- const sun = new THREE.Mesh(sunGeo, sunMat);
244
- gizmoGroup.add(sun);
245
-
246
- // The "Glow" Sprite
247
- const canvas = document.createElement('canvas');
248
- canvas.width = 128; canvas.height = 128;
249
- const context = canvas.getContext('2d');
250
- const gradient = context.createRadialGradient(64, 64, 0, 64, 64, 64);
251
- gradient.addColorStop(0, 'rgba(255, 215, 0, 1)');
252
- gradient.addColorStop(0.4, 'rgba(255, 215, 0, 0.2)');
253
- gradient.addColorStop(1, 'rgba(0, 0, 0, 0)');
254
- context.fillStyle = gradient;
255
- context.fillRect(0, 0, 128, 128);
256
- const glowTex = new THREE.CanvasTexture(canvas);
257
- const glowMat = new THREE.SpriteMaterial({ map: glowTex, transparent: true, blending: THREE.AdditiveBlending });
258
- const glow = new THREE.Sprite(glowMat);
259
- glow.scale.set(1.5, 1.5, 1);
260
- gizmoGroup.add(glow);
261
-
262
- // The Volumetric Cone (Beam)
263
- const beamGeo = new THREE.ConeGeometry(0.35, RADIUS, 32, 1, true);
264
- beamGeo.translate(0, -RADIUS/2, 0);
265
- beamGeo.rotateX(-Math.PI / 2);
266
- const beamMat = new THREE.MeshBasicMaterial({
267
- color: 0xFFD700, transparent: true, opacity: 0.08,
268
- side: THREE.DoubleSide, depthWrite: false, blending: THREE.AdditiveBlending
269
- });
270
- const beam = new THREE.Mesh(beamGeo, beamMat);
271
- beam.lookAt(new THREE.Vector3(0,0,0)); // Initialize pointing center
272
- gizmoGroup.add(beam);
273
-
274
- // Dome Visual Guide (Subtle)
275
- const domeGeo = new THREE.IcosahedronGeometry(RADIUS, 1);
276
- const domeWire = new THREE.WireframeGeometry(domeGeo);
277
- const domeLine = new THREE.LineSegments(domeWire, new THREE.LineBasicMaterial({ color: 0x333333, transparent: true, opacity: 0.1 }));
278
- scene.add(domeLine);
279
-
280
- // --- 5. State & Logic ---
281
- // We keep two sets of angles: Target (where mouse is) and Current (where light is, for animation)
282
- let targetAz = props.value?.azimuth || 0;
283
- let targetEl = props.value?.elevation || 0;
284
- let currentAz = targetAz;
285
- let currentEl = targetEl;
286
-
287
- const AZ_MAP = {
288
- 0: 'Front', 45: 'Right Front', 90: 'Right', 135: 'Right Rear',
289
- 180: 'Rear', 225: 'Left Rear', 270: 'Left', 315: 'Left Front'
290
- };
291
 
292
- function getPrompt(a, e) {
293
- if (e >= 60) return "Light source from Above";
294
- if (e <= -60) return "Light source from Below";
295
- let n = a % 360; if(n < 0) n += 360;
296
- const steps = [0,45,90,135,180,225,270,315];
297
- const snapped = steps.reduce((p, c) => Math.abs(c-n) < Math.abs(p-n) ? c : p);
298
- return `Light source from the ${AZ_MAP[snapped]}`;
299
  }
300
-
301
- // --- 6. Input & Raycasting ---
302
- const raycaster = new THREE.Raycaster();
303
- const mouse = new THREE.Vector2();
304
- let isDragging = false;
305
 
306
- // Invisible interaction sphere
307
- const dragMesh = new THREE.Mesh(
308
- new THREE.SphereGeometry(RADIUS, 32, 32),
309
- new THREE.MeshBasicMaterial({ visible: false, side: THREE.BackSide })
 
 
 
 
 
 
 
 
 
310
  );
311
- scene.add(dragMesh);
312
-
313
- function updateFromIntersect(point) {
314
- const rel = new THREE.Vector3().subVectors(point, CENTER);
315
- let az = Math.atan2(rel.x, rel.z) * (180 / Math.PI);
316
- if (az < 0) az += 360;
317
- const distXZ = Math.sqrt(rel.x*rel.x + rel.z*rel.z);
318
- let el = Math.atan2(rel.y, distXZ) * (180 / Math.PI);
319
- el = Math.max(-89, Math.min(89, el)); // Clamp
320
-
321
- targetAz = az;
322
- targetEl = el;
 
 
 
 
323
  }
324
-
325
- function onDown(e) {
326
- const rect = wrapper.getBoundingClientRect();
327
- const cx = e.clientX || (e.touches ? e.touches[0].clientX : 0);
328
- const cy = e.clientY || (e.touches ? e.touches[0].clientY : 0);
329
- mouse.x = ((cx - rect.left) / rect.width) * 2 - 1;
330
- mouse.y = -((cy - rect.top) / rect.height) * 2 + 1;
 
 
 
 
 
 
 
 
 
 
 
331
 
332
- raycaster.setFromCamera(mouse, camera);
333
- const intersects = raycaster.intersectObject(dragMesh);
334
- if(intersects.length > 0) {
335
- isDragging = true;
336
- wrapper.style.cursor = 'none'; // Immersion
337
- updateFromIntersect(intersects[0].point);
338
- }
339
- }
340
-
341
- function onMove(e) {
342
- if (!isDragging) {
343
- // Hover cursor logic
344
- const rect = wrapper.getBoundingClientRect();
345
- mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
346
- mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
347
- raycaster.setFromCamera(mouse, camera);
348
- const hits = raycaster.intersectObject(dragMesh);
349
- wrapper.style.cursor = (hits.length > 0) ? 'grab' : 'default';
350
- return;
351
- }
352
 
353
- const rect = wrapper.getBoundingClientRect();
354
- const cx = e.clientX || (e.touches ? e.touches[0].clientX : 0);
355
- const cy = e.clientY || (e.touches ? e.touches[0].clientY : 0);
356
- mouse.x = ((cx - rect.left) / rect.width) * 2 - 1;
357
- mouse.y = -((cy - rect.top) / rect.height) * 2 + 1;
358
 
359
- raycaster.setFromCamera(mouse, camera);
360
- const intersects = raycaster.intersectObject(dragMesh);
361
- if (intersects.length > 0) {
362
- updateFromIntersect(intersects[0].point);
363
- }
364
- }
365
-
366
- function onUp() {
367
- if(isDragging) {
368
- isDragging = false;
369
- wrapper.style.cursor = 'default';
370
- props.value = { azimuth: targetAz, elevation: targetEl };
371
- trigger('change', props.value);
372
  }
 
373
  }
374
-
375
- wrapper.addEventListener('mousedown', onDown);
376
- window.addEventListener('mousemove', onMove);
377
- window.addEventListener('mouseup', onUp);
378
- wrapper.addEventListener('touchstart', onDown, {passive: false});
379
- window.addEventListener('touchmove', onMove, {passive: false});
380
- window.addEventListener('touchend', onUp);
381
-
382
- // --- 7. Animation Loop (Visuals) ---
383
- function animate() {
384
- requestAnimationFrame(animate);
385
 
386
- // LERPING: Smoothly move current light pos to target
387
- // This creates the "swish" feeling
388
- const lerpSpeed = 0.15;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
 
390
- // Handle Azimuth wraparound (359 -> 0)
391
- let dAz = targetAz - currentAz;
392
- if (dAz > 180) dAz -= 360;
393
- if (dAz < -180) dAz += 360;
394
- currentAz += dAz * lerpSpeed;
395
 
396
- currentEl += (targetEl - currentEl) * lerpSpeed;
397
-
398
- // Calculate Cartesian
399
- const r_az = THREE.MathUtils.degToRad(currentAz);
400
- const r_el = THREE.MathUtils.degToRad(currentEl);
401
- const x = RADIUS * Math.sin(r_az) * Math.cos(r_el);
402
- const y = RADIUS * Math.sin(r_el) + CENTER.y;
403
- const z = RADIUS * Math.cos(r_az) * Math.cos(r_el);
404
-
405
- // Update Light Position
406
- gizmoGroup.position.set(x, y, z);
407
- gizmoGroup.lookAt(CENTER);
 
 
408
 
409
- dirLight.position.set(x, y, z); // Update actual shadow-casting light
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
- // Update Text
412
- badge.textContent = getPrompt(currentAz, currentEl);
413
 
414
- // Dynamic Coloring
415
- let colorHex = 0xFFD700; // Gold
416
- if (currentEl > 55 || currentEl < -55) colorHex = 0xFF4500; // Warning Orange
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
- const col = new THREE.Color(colorHex);
419
- sunMat.color.lerp(col, 0.1);
420
- beamMat.color.lerp(col, 0.1);
421
- glowMat.color.lerp(col, 0.1);
422
- badge.style.borderColor = '#' + col.getHexString();
423
- badge.style.color = '#' + col.getHexString();
424
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  renderer.render(scene, camera);
426
  }
427
- animate();
428
-
429
- // External Updates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  setInterval(() => {
431
- if (props.value && !isDragging) {
432
- // If controls updated externally (sliders), snap target
433
- if (Math.abs(props.value.azimuth - targetAz) > 0.5 || Math.abs(props.value.elevation - targetEl) > 0.5) {
434
- targetAz = props.value.azimuth;
435
- targetEl = props.value.elevation;
 
 
 
 
 
 
 
 
436
  }
437
  }
438
  }, 100);
439
-
440
- wrapper._updateTexture = updateTexture;
441
  };
 
442
  initScene();
443
  })();
444
  """
@@ -451,88 +606,154 @@ class LightControl3D(gr.HTML):
451
  **kwargs
452
  )
453
 
454
- # --- UI Layout ---
455
- css = """
456
- #col-container { max-width: 1400px; margin: 0 auto; }
457
- #3d-container { border-radius: 12px; overflow: hidden; }
458
- .range-slider { accent-color: #FFD700 !important; }
459
- """
460
-
461
- with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="yellow")) as demo:
462
  gr.Markdown("""
463
- # 💡 Qwen Edit 2509Studio Light Controller
464
- **Interactive Relighting:** Drag the light in the 3D studio to cast realistic shadows and generate lighting prompts.
 
 
465
  """)
466
 
467
  with gr.Row():
468
- # --- Left Column: Controls ---
469
- with gr.Column(scale=5):
470
- gr.Markdown("### 📸 Input & Control")
471
- with gr.Row():
472
- image_input = gr.Image(label="Input", type="pil", height=280)
473
- result_output = gr.Image(label="Result", height=280)
474
 
475
- gr.Markdown("### 🎮 Studio Controller")
476
- light_controller = LightControl3D(
 
 
477
  value={"azimuth": 0, "elevation": 0},
478
- elem_id="3d-container"
479
  )
 
480
 
481
- run_btn = gr.Button(" Relight Image", variant="primary", size="lg")
482
 
483
- with gr.Accordion("🎚️ Fine-Tune Parameters", open=False):
484
- with gr.Row():
485
- az_slider = gr.Slider(0, 359, value=0, label="Azimuth", step=1)
486
- el_slider = gr.Slider(-90, 90, value=0, label="Elevation", step=1)
487
- with gr.Row():
488
- seed = gr.Slider(0, MAX_SEED, value=42, label="Seed", step=1)
489
- randomize_seed = gr.Checkbox(True, label="Randomize")
490
- with gr.Row():
491
- cfg = gr.Slider(1.0, 10.0, value=5.0, label="Guidance (CFG)")
492
- steps = gr.Slider(1, 20, value=4, step=1, label="Steps")
493
- prompt_display = gr.Textbox(label="Active Prompt", interactive=False)
494
-
495
- # --- Event Wiring ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
 
497
- # 3D -> Sliders
498
- def on_3d_change(val):
499
- az, el = val.get('azimuth', 0), val.get('elevation', 0)
500
- return az, el, build_lighting_prompt(az, el)
501
-
502
- light_controller.change(on_3d_change, inputs=[light_controller], outputs=[az_slider, el_slider, prompt_display])
503
 
504
- # Sliders -> 3D
505
- def on_slider_change(az, el):
506
- return {"azimuth": az, "elevation": el}, build_lighting_prompt(az, el)
507
-
508
- az_slider.change(on_slider_change, inputs=[az_slider, el_slider], outputs=[light_controller, prompt_display])
509
- el_slider.change(on_slider_change, inputs=[az_slider, el_slider], outputs=[light_controller, prompt_display])
510
-
511
- # Upload
512
- def on_upload(img):
513
- w, h = update_dimensions_on_upload(img)
514
- if img is None: return w, h, gr.update(imageUrl=None)
 
 
 
 
 
 
 
 
 
 
 
 
515
  import base64
516
  from io import BytesIO
517
  buffered = BytesIO()
518
- img.save(buffered, format="PNG")
519
  img_str = base64.b64encode(buffered.getvalue()).decode()
520
- return w, h, gr.update(imageUrl=f"data:image/png;base64,{img_str}")
521
-
522
- image_input.upload(on_upload, inputs=[image_input], outputs=[gr.State(), gr.State(), light_controller])
523
-
524
- # Run
525
- def run_inference_wrapper(img, az, el, seed, rand, cfg, steps):
526
- w, h = update_dimensions_on_upload(img)
527
- res, _, _ = infer_lighting_edit(img, az, el, seed, rand, cfg, steps, h, w)
528
- return res
529
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  run_btn.click(
531
- run_inference_wrapper,
532
- inputs=[image_input, az_slider, el_slider, seed, randomize_seed, cfg, steps],
533
- outputs=[result_output]
534
  )
535
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  if __name__ == "__main__":
537
- head_js = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
538
- demo.launch(head=head_js)
 
 
4
  import torch
5
  import spaces
6
  from PIL import Image
 
 
7
  from diffusers import FlowMatchEulerDiscreteScheduler
8
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
9
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
10
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
 
 
 
11
 
12
  MAX_SEED = np.iinfo(np.int32).max
13
+ # --- Model Loading ---
 
14
  dtype = torch.bfloat16
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
16
  pipe = QwenImageEditPlusPipeline.from_pretrained(
17
  "Qwen/Qwen-Image-Edit-2511",
18
  transformer=QwenImageTransformer2DModel.from_pretrained(
 
22
  ),
23
  torch_dtype=dtype
24
  ).to(device)
 
25
  try:
26
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
27
  print("Flash Attention 3 Processor set successfully.")
 
33
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
34
  "weights": "多角度灯光-251116.safetensors",
35
  "adapter_name": "multi-angle-lighting"
36
+ },
37
  }
38
+ loaded = False
39
+
40
+ # --- Prompt Building ---
41
+ # Azimuth mappings (8 positions)
42
+ AZIMUTH_MAP = {
43
+ 0: "Front",
44
+ 45: "Right Front",
45
+ 90: "Right",
46
+ 135: "Right Rear",
47
+ 180: "Rear",
48
+ 225: "Left Rear",
49
+ 270: "Left",
50
+ 315: "Left Front"
51
+ }
52
+ # Elevation mappings (3 positions)
53
+ ELEVATION_MAP = {
54
+ -90: "Below",
55
+ 0: "",
56
+ 90: "Above"
57
  }
58
 
59
+ def snap_to_nearest(value, options):
60
+ """Snap a value to the nearest option in a list."""
61
+ return min(options, key=lambda x: abs(x - value))
62
 
63
  def build_lighting_prompt(azimuth: float, elevation: float) -> str:
64
+ """
65
+ Build a lighting prompt from azimuth and elevation values.
66
+
67
+ Args:
68
+ azimuth: Horizontal rotation in degrees (0-360)
69
+ elevation: Vertical angle in degrees (-90 to 90)
70
+
71
+ Returns:
72
+ Formatted prompt string for the LoRA
73
+ """
74
+ # Snap to nearest valid values
75
+ azimuth_snapped = snap_to_nearest(azimuth, list(AZIMUTH_MAP.keys()))
76
+ elevation_snapped = snap_to_nearest(elevation, list(ELEVATION_MAP.keys()))
77
+
78
+ if elevation_snapped == 0:
79
+ return f"Light source from the {AZIMUTH_MAP[azimuth_snapped]}"
80
+ else:
81
+ return f"Light source from {ELEVATION_MAP[elevation_snapped]}"
82
 
 
83
  @spaces.GPU
84
+ def infer_lighting_edit(
85
+ image: Image.Image,
86
+ azimuth: float = 0.0,
87
+ elevation: float = 0.0,
88
+ seed: int = 0,
89
+ randomize_seed: bool = True,
90
+ guidance_scale: float = 1.0,
91
+ num_inference_steps: int = 4,
92
+ height: int = 1024,
93
+ width: int = 1024,
94
+ ):
95
+ """
96
+ Edit the lighting of an image using Qwen Image Edit 2511 with multi-angle lighting LoRA.
97
+ """
98
+ global loaded
99
+ progress = gr.Progress(track_tqdm=True)
100
 
101
+ if not loaded:
102
+ pipe.load_lora_weights(
103
+ ADAPTER_SPECS["Multi-Angle-Lighting"]["repo"],
104
+ weight_name=ADAPTER_SPECS["Multi-Angle-Lighting"]["weights"],
105
+ adapter_name=ADAPTER_SPECS["Multi-Angle-Lighting"]["adapter_name"]
106
+ )
107
+ pipe.set_adapters([ADAPTER_SPECS["Multi-Angle-Lighting"]["adapter_name"]], adapter_weights=[1.0])
108
+ loaded = True
109
 
110
+ prompt = build_lighting_prompt(azimuth, elevation)
111
+ print(f"Generated Prompt: {prompt}")
112
+ if randomize_seed:
113
+ seed = random.randint(0, MAX_SEED)
114
  generator = torch.Generator(device=device).manual_seed(seed)
115
+ if image is None:
116
+ raise gr.Error("Please upload an image first.")
117
+ pil_image = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
118
  result = pipe(
119
+ image=[pil_image],
120
+ prompt=prompt,
121
+ height=height if height != 0 else None,
122
+ width=width if width != 0 else None,
123
+ num_inference_steps=num_inference_steps,
124
+ generator=generator,
125
+ guidance_scale=guidance_scale,
126
+ num_images_per_prompt=1,
127
  ).images[0]
128
  return result, seed, prompt
129
 
130
  def update_dimensions_on_upload(image):
131
+ """Compute recommended dimensions preserving aspect ratio."""
132
+ if image is None:
133
+ return 1024, 1024
134
+ original_width, original_height = image.size
135
+ if original_width > original_height:
136
+ new_width = 1024
137
+ aspect_ratio = original_height / original_width
138
+ new_height = int(new_width * aspect_ratio)
139
+ else:
140
+ new_height = 1024
141
+ aspect_ratio = original_width / original_height
142
+ new_width = int(new_height * aspect_ratio)
143
+ new_width = (new_width // 8) * 8
144
+ new_height = (new_height // 8) * 8
145
+ return new_width, new_height
146
+
147
+ # --- 3D Lighting Control Component ---
148
+ class LightingControl3D(gr.HTML):
149
  """
150
+ A 3D lighting control component using Three.js.
151
+ Outputs: { azimuth: number, elevation: number }
152
+ Accepts imageUrl prop to display user's uploaded image on the plane.
153
  """
154
  def __init__(self, value=None, imageUrl=None, **kwargs):
155
+ if value is None:
156
+ value = {"azimuth": 0, "elevation": 0}
157
 
158
  html_template = """
159
+ <div id="lighting-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
160
+ <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 12px; color: #00ff88; white-space: nowrap; z-index: 10;"></div>
 
 
 
 
 
 
 
 
 
 
 
161
  </div>
162
  """
163
 
164
  js_on_load = """
165
  (() => {
166
+ const wrapper = element.querySelector('#lighting-control-wrapper');
167
+ const promptOverlay = element.querySelector('#prompt-overlay');
168
 
169
+ // Wait for THREE to load
170
  const initScene = () => {
171
+ if (typeof THREE === 'undefined') {
172
+ setTimeout(initScene, 100);
173
+ return;
174
+ }
175
 
176
+ // Scene setup
177
  const scene = new THREE.Scene();
178
+ scene.background = new THREE.Color(0x1a1a1a);
 
179
 
180
+ const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
181
+ camera.position.set(4.5, 3, 4.5);
182
+ camera.lookAt(0, 0.75, 0);
183
 
184
  const renderer = new THREE.WebGLRenderer({ antialias: true });
185
+ renderer.setSize(wrapper.clientWidth / wrapper.clientHeight);
186
  renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
187
+ wrapper.insertBefore(renderer.domElement, promptOverlay);
188
+
189
+ // Lighting
190
+ scene.add(new THREE.AmbientLight(0xffffff, 0.2));
191
+
192
+ // Grid
193
+ scene.add(new THREE.GridHelper(8, 16, 0x333333, 0x222222));
194
 
195
+ // Constants
196
  const CENTER = new THREE.Vector3(0, 0.75, 0);
197
+ const BASE_DISTANCE = 2.5;
198
+ const AZIMUTH_RADIUS = 2.4;
199
+ const ELEVATION_RADIUS = 1.8;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ // State
202
+ let azimuthAngle = props.value?.azimuth || 0;
203
+ let elevationAngle = props.value?.elevation || 0;
204
+
205
+ // Mappings
206
+ const azimuthSteps = [0, 45, 90, 135, 180, 225, 270, 315];
207
+ const elevationSteps = [-90, 0, 90];
208
+ const azimuthNames = {
209
+ 0: 'Front', 45: 'Right Front', 90: 'Right',
210
+ 135: 'Right Rear', 180: 'Rear', 225: 'Left Rear',
211
+ 270: 'Left', 315: 'Left Front'
212
+ };
213
+ const elevationNames = { '-90': 'Below', '0': '', '90': 'Above' };
214
+
215
+ function snapToNearest(value, steps) {
216
+ return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
217
  }
218
+
219
+ // Create placeholder texture (smiley face)
220
+ function createPlaceholderTexture() {
221
+ const canvas = document.createElement('canvas');
222
+ canvas.width = 256;
223
+ canvas.height = 256;
224
+ const ctx = canvas.getContext('2d');
225
+ ctx.fillStyle = '#3a3a4a';
226
+ ctx.fillRect(0, 0, 256, 256);
227
+ ctx.fillStyle = '#ffcc99';
228
+ ctx.beginPath();
229
+ ctx.arc(128, 128, 80, 0, Math.PI * 2);
230
+ ctx.fill();
231
+ ctx.fillStyle = '#333';
232
+ ctx.beginPath();
233
+ ctx.arc(100, 110, 10, 0, Math.PI * 2);
234
+ ctx.arc(156, 110, 10, 0, Math.PI * 2);
235
+ ctx.fill();
236
+ ctx.strokeStyle = '#333';
237
+ ctx.lineWidth = 3;
238
+ ctx.beginPath();
239
+ ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
240
+ ctx.stroke();
241
+ return new THREE.CanvasTexture(canvas);
242
+ }
243
+
244
+ // Target image plane
245
+ let currentTexture = createPlaceholderTexture();
246
+ const planeMaterial = new THREE.MeshStandardMaterial({ map: currentTexture, side: THREE.DoubleSide, roughness: 0.5, metalness: 0 });
247
+ let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
248
+ targetPlane.position.copy(CENTER);
249
+ scene.add(targetPlane);
250
+
251
+ // Function to update texture from image URL
252
+ function updateTextureFromUrl(url) {
253
+ if (!url) {
254
+ // Reset to placeholder
255
+ planeMaterial.map = createPlaceholderTexture();
256
+ planeMaterial.needsUpdate = true;
257
+ // Reset plane to square
258
+ scene.remove(targetPlane);
259
+ targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
260
+ targetPlane.position.copy(CENTER);
261
+ scene.add(targetPlane);
262
+ return;
263
+ }
264
+
265
+ const loader = new THREE.TextureLoader();
266
+ loader.crossOrigin = 'anonymous';
267
+ loader.load(url, (texture) => {
268
+ texture.minFilter = THREE.LinearFilter;
269
+ texture.magFilter = THREE.LinearFilter;
270
+ planeMaterial.map = texture;
271
+ planeMaterial.needsUpdate = true;
272
 
273
+ // Adjust plane aspect ratio to match image
274
+ const img = texture.image;
275
+ if (img && img.width && img.height) {
276
  const aspect = img.width / img.height;
277
+ const maxSize = 1.5;
278
+ let planeWidth, planeHeight;
279
+ if (aspect > 1) {
280
+ planeWidth = maxSize;
281
+ planeHeight = maxSize / aspect;
282
+ } else {
283
+ planeHeight = maxSize;
284
+ planeWidth = maxSize * aspect;
285
+ }
286
+ scene.remove(targetPlane);
287
+ targetPlane = new THREE.Mesh(
288
+ new THREE.PlaneGeometry(planeWidth, planeHeight),
289
+ planeMaterial
290
+ );
291
+ targetPlane.position.copy(CENTER);
292
+ scene.add(targetPlane);
293
  }
294
+ }, undefined, (err) => {
295
+ console.error('Failed to load texture:', err);
296
  });
297
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
// Apply an initial image if one was provided through props.
if (props.imageUrl) updateTextureFromUrl(props.imageUrl);
 
 
 
 
 
303
 
// Draggable light: a glowing bulb plus the actual point light, moved as one group.
const lightGroup = new THREE.Group();
const bulb = new THREE.Mesh(
    new THREE.SphereGeometry(0.15, 16, 16),
    new THREE.MeshStandardMaterial({ color: 0xffff00, emissive: 0xffff00, emissiveIntensity: 1 })
);
const pointLight = new THREE.PointLight(0xffffff, 5, 0);
lightGroup.add(bulb);
lightGroup.add(pointLight);
scene.add(lightGroup);

// Green horizontal ring: the azimuth track, slightly above the ground plane.
const azimuthRing = new THREE.Mesh(
    new THREE.TorusGeometry(AZIMUTH_RADIUS, 0.04, 16, 64),
    new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
);
azimuthRing.rotation.x = Math.PI / 2;
azimuthRing.position.y = 0.05;
scene.add(azimuthRing);

// Draggable sphere that rides the azimuth ring.
const azimuthHandle = new THREE.Mesh(
    new THREE.SphereGeometry(0.18, 16, 16),
    new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
);
azimuthHandle.userData.type = 'azimuth';
scene.add(azimuthHandle);

// Pink vertical arc: the elevation track, sampled from -90° to +90°.
const arcPoints = [];
for (let i = 0; i <= 32; i++) {
    const angle = THREE.MathUtils.degToRad(-90 + (180 * i / 32));
    arcPoints.push(new THREE.Vector3(
        -0.8,
        ELEVATION_RADIUS * Math.sin(angle) + CENTER.y,
        ELEVATION_RADIUS * Math.cos(angle)
    ));
}
const elevationArc = new THREE.Mesh(
    new THREE.TubeGeometry(new THREE.CatmullRomCurve3(arcPoints), 32, 0.04, 8, false),
    new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
);
scene.add(elevationArc);

// Draggable sphere that rides the elevation arc.
const elevationHandle = new THREE.Mesh(
    new THREE.SphereGeometry(0.18, 16, 16),
    new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
);
elevationHandle.userData.type = 'elevation';
scene.add(elevationHandle);
// Reposition the light and both gizmo handles from the current angles, then
// refresh the on-screen prompt text.
function updatePositions() {
    const azRad = THREE.MathUtils.degToRad(azimuthAngle);
    const elRad = THREE.MathUtils.degToRad(elevationAngle);

    // Spherical -> cartesian around CENTER (vertical offset), radius BASE_DISTANCE.
    lightGroup.position.set(
        BASE_DISTANCE * Math.sin(azRad) * Math.cos(elRad),
        BASE_DISTANCE * Math.sin(elRad) + CENTER.y,
        BASE_DISTANCE * Math.cos(azRad) * Math.cos(elRad)
    );

    azimuthHandle.position.set(AZIMUTH_RADIUS * Math.sin(azRad), 0.05, AZIMUTH_RADIUS * Math.cos(azRad));
    elevationHandle.position.set(-0.8, ELEVATION_RADIUS * Math.sin(elRad) + CENTER.y, ELEVATION_RADIUS * Math.cos(elRad));

    // Prompt text: elevation wording takes precedence whenever it is non-zero.
    const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
    const elSnap = snapToNearest(elevationAngle, elevationSteps);
    promptOverlay.textContent = (elSnap !== 0)
        ? 'Light source from ' + elevationNames[String(elSnap)]
        : 'Light source from the ' + azimuthNames[azSnap];
}
// Snap both angles to their step grids, publish them as the component value,
// and emit a 'change' event toward Python.
function updatePropsAndTrigger() {
    props.value = {
        azimuth: snapToNearest(azimuthAngle, azimuthSteps),
        elevation: snapToNearest(elevationAngle, elevationSteps)
    };
    trigger('change', props.value);
}
// Shared picking / drag state used by both the mouse and touch handlers.
const raycaster = new THREE.Raycaster();
const mouse = new THREE.Vector2();
const dragStartMouse = new THREE.Vector2();
const intersection = new THREE.Vector3();
let isDragging = false;
let dragTarget = null;

const canvas = renderer.domElement;
canvas.addEventListener('mousedown', (e) => {
    // Convert the click position to NDC and pick a handle, if any.
    const rect = canvas.getBoundingClientRect();
    mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
    mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;

    raycaster.setFromCamera(mouse, camera);
    const hits = raycaster.intersectObjects([azimuthHandle, elevationHandle]);
    if (hits.length === 0) return;

    // Begin a drag: brighten and enlarge the grabbed handle.
    isDragging = true;
    dragTarget = hits[0].object;
    dragTarget.material.emissiveIntensity = 1.0;
    dragTarget.scale.setScalar(1.3);
    dragStartMouse.copy(mouse);
    canvas.style.cursor = 'grabbing';
});
canvas.addEventListener('mousemove', (e) => {
    const rect = canvas.getBoundingClientRect();
    mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
    mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
    raycaster.setFromCamera(mouse, camera);

    if (isDragging && dragTarget) {
        if (dragTarget.userData.type === 'azimuth') {
            // Intersect the horizontal ring plane (y = 0.05) and read the angle.
            const ringPlane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
            if (raycaster.ray.intersectPlane(ringPlane, intersection)) {
                azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
                if (azimuthAngle < 0) azimuthAngle += 360;
            }
        } else if (dragTarget.userData.type === 'elevation') {
            // Intersect the vertical arc plane (x = -0.8), clamped to [-90, 90].
            const arcPlane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
            if (raycaster.ray.intersectPlane(arcPlane, intersection)) {
                const relY = intersection.y - CENTER.y;
                const relZ = intersection.z;
                elevationAngle = THREE.MathUtils.clamp(
                    THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -90, 90);
            }
        }
        updatePositions();
        return;
    }

    // Not dragging: hover feedback on whichever handle sits under the cursor.
    const hits = raycaster.intersectObjects([azimuthHandle, elevationHandle]);
    for (const h of [azimuthHandle, elevationHandle]) {
        h.material.emissiveIntensity = 0.5;
        h.scale.setScalar(1);
    }
    if (hits.length > 0) {
        hits[0].object.material.emissiveIntensity = 0.8;
        hits[0].object.scale.setScalar(1.1);
        canvas.style.cursor = 'grab';
    } else {
        canvas.style.cursor = 'default';
    }
});
// Finish a drag: restore the handle's look, ease both angles onto their snap
// points over ~200 ms, then publish the snapped value to Python.
const onMouseUp = () => {
    if (dragTarget) {
        dragTarget.material.emissiveIntensity = 0.5;
        dragTarget.scale.setScalar(1);

        const targetAz = snapToNearest(azimuthAngle, azimuthSteps);
        const targetEl = snapToNearest(elevationAngle, elevationSteps);
        const startAz = azimuthAngle;
        const startEl = elevationAngle;
        const startTime = Date.now();

        const animateSnap = () => {
            const t = Math.min((Date.now() - startTime) / 200, 1);
            const ease = 1 - Math.pow(1 - t, 3);  // cubic ease-out

            // Take the short way around the circle for azimuth.
            let azDiff = targetAz - startAz;
            if (azDiff > 180) azDiff -= 360;
            if (azDiff < -180) azDiff += 360;
            azimuthAngle = startAz + azDiff * ease;
            if (azimuthAngle < 0) azimuthAngle += 360;
            if (azimuthAngle >= 360) azimuthAngle -= 360;

            elevationAngle = startEl + (targetEl - startEl) * ease;
            updatePositions();

            if (t < 1) requestAnimationFrame(animateSnap);
            else updatePropsAndTrigger();
        };
        animateSnap();
    }
    isDragging = false;
    dragTarget = null;
    canvas.style.cursor = 'default';
};

canvas.addEventListener('mouseup', onMouseUp);
canvas.addEventListener('mouseleave', onMouseUp);
// Touch support mirrors the mouse handlers (single-touch only).
canvas.addEventListener('touchstart', (e) => {
    e.preventDefault();
    const touch = e.touches[0];
    const rect = canvas.getBoundingClientRect();
    mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
    mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;

    raycaster.setFromCamera(mouse, camera);
    const hits = raycaster.intersectObjects([azimuthHandle, elevationHandle]);
    if (hits.length === 0) return;

    isDragging = true;
    dragTarget = hits[0].object;
    dragTarget.material.emissiveIntensity = 1.0;
    dragTarget.scale.setScalar(1.3);
    dragStartMouse.copy(mouse);
}, { passive: false });
canvas.addEventListener('touchmove', (e) => {
    e.preventDefault();
    const touch = e.touches[0];
    const rect = canvas.getBoundingClientRect();
    mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
    mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;

    if (isDragging && dragTarget) {
        raycaster.setFromCamera(mouse, camera);

        if (dragTarget.userData.type === 'azimuth') {
            // Same ring-plane intersection as the mouse path.
            const ringPlane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
            if (raycaster.ray.intersectPlane(ringPlane, intersection)) {
                azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
                if (azimuthAngle < 0) azimuthAngle += 360;
            }
        } else if (dragTarget.userData.type === 'elevation') {
            const arcPlane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
            if (raycaster.ray.intersectPlane(arcPlane, intersection)) {
                const relY = intersection.y - CENTER.y;
                const relZ = intersection.z;
                elevationAngle = THREE.MathUtils.clamp(
                    THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -90, 90);
            }
        }
        updatePositions();
    }
}, { passive: false });
// Both end-of-gesture events reuse the mouse-up snap logic.
for (const evt of ['touchend', 'touchcancel']) {
    canvas.addEventListener(evt, (e) => {
        e.preventDefault();
        onMouseUp();
    }, { passive: false });
}
// Place everything according to the initial angles, then start rendering.
updatePositions();

(function render() {
    requestAnimationFrame(render);
    renderer.render(scene, camera);
})();
// Keep the camera projection and canvas size in sync with the wrapper element.
new ResizeObserver(() => {
    camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
    camera.updateProjectionMatrix();
    renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
}).observe(wrapper);
// Expose imperative hooks so code outside this closure can push updates in.
wrapper._updateFromProps = (value) => {
    if (value && typeof value === 'object') {
        azimuthAngle = value.azimuth ?? azimuthAngle;
        elevationAngle = value.elevation ?? elevationAngle;
        updatePositions();
    }
};

wrapper._updateTexture = updateTextureFromUrl;
// Poll props every 100 ms: Gradio mutates props in place, so watch for
// imageUrl changes and for externally-set values (e.g. from the sliders).
let lastImageUrl = props.imageUrl;
let lastValue = JSON.stringify(props.value);
setInterval(() => {
    if (props.imageUrl !== lastImageUrl) {
        lastImageUrl = props.imageUrl;
        updateTextureFromUrl(props.imageUrl);
    }
    const serialized = JSON.stringify(props.value);
    if (serialized !== lastValue) {
        lastValue = serialized;
        if (props.value && typeof props.value === 'object') {
            azimuthAngle = props.value.azimuth ?? azimuthAngle;
            elevationAngle = props.value.elevation ?? elevationAngle;
            updatePositions();
        }
    }
}, 100);
 
 
595
  };
596
+
597
  initScene();
598
  })();
599
  """
 
606
  **kwargs
607
  )
608
 
609
# --- UI ---
# Fixes: the page title had lost its separator in extraction
# ("25113D" -> "2511 — 3D"), and the model mention was a bare bracket
# span rather than a markdown link.
css = '''
#col-container { max-width: 1200px; margin: 0 auto; }
.dark .progress-text { color: white !important; }
#lighting-3d-control { min-height: 450px; }
.slider-row { display: flex; gap: 10px; align-items: center; }
'''

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎬 Qwen Image Edit 2511 — 3D Lighting Control

    Control lighting directions using the **3D viewport** or **sliders**.
    Using [dx8152/Qwen-Edit-2509-Multi-Angle-Lighting](https://huggingface.co/dx8152/Qwen-Edit-2509-Multi-Angle-Lighting) for precise lighting control.
    """)

    with gr.Row():
        # Left column: input image plus both control styles (3D + sliders).
        with gr.Column(scale=1):
            image = gr.Image(label="Input Image", type="pil", height=300)

            gr.Markdown("### 🎮 3D Lighting Control")
            gr.Markdown("*Drag the colored handles: 🟢 Azimuth (Direction), 🩷 Elevation (Height)*")

            lighting_3d = LightingControl3D(
                value={"azimuth": 0, "elevation": 0},
                elem_id="lighting-3d-control"
            )
            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            gr.Markdown("### 🎚️ Slider Controls")

            azimuth_slider = gr.Slider(
                label="Azimuth (Horizontal Rotation)",
                minimum=0,
                maximum=315,
                step=45,
                value=0,
                info="0°=front, 90°=right, 180°=rear, 270°=left"
            )
            elevation_slider = gr.Slider(
                label="Elevation (Vertical Angle)",
                minimum=-90,
                maximum=90,
                step=90,
                value=0,
                info="-90°=from below, 0°=horizontal, 90°=from above"
            )
            prompt_preview = gr.Textbox(
                label="Generated Prompt",
                value="Light source from the Front",
                interactive=False
            )

        # Right column: generated output and advanced sampling knobs.
        with gr.Column(scale=1):
            result = gr.Image(label="Output Image", height=500)

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=20, step=1, value=4)
                height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
                width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)

    # --- Event Handlers ---
def update_prompt_from_sliders(azimuth, elevation):
    """Refresh the prompt preview whenever either slider moves."""
    return build_lighting_prompt(azimuth, elevation)
def sync_3d_to_sliders(lighting_value):
    """Mirror a 3D-control change onto the sliders and the prompt preview.

    Returns (azimuth, elevation, prompt); emits no-op updates when the
    payload is empty or not a dict.
    """
    if not lighting_value or not isinstance(lighting_value, dict):
        return gr.update(), gr.update(), gr.update()
    az = lighting_value.get('azimuth', 0)
    el = lighting_value.get('elevation', 0)
    return az, el, build_lighting_prompt(az, el)
def sync_sliders_to_3d(azimuth, elevation):
    """Package the slider values into the dict shape the 3D component expects."""
    return {"azimuth": azimuth, "elevation": elevation}
def update_3d_image(image):
    """Push the uploaded PIL image into the 3D plane as a base64 data URL.

    A None image resets the 3D component's plane to its placeholder.
    """
    import base64
    from io import BytesIO

    if image is None:
        return gr.update(imageUrl=None)
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return gr.update(imageUrl=f"data:image/png;base64,{encoded}")
# --- Event wiring (runs inside the Blocks context) ---

# Sliders -> prompt preview text.
for _slider in (azimuth_slider, elevation_slider):
    _slider.change(
        fn=update_prompt_from_sliders,
        inputs=[azimuth_slider, elevation_slider],
        outputs=[prompt_preview],
    )

# 3D control -> sliders + prompt.
lighting_3d.change(
    fn=sync_3d_to_sliders,
    inputs=[lighting_3d],
    outputs=[azimuth_slider, elevation_slider, prompt_preview],
)

# Sliders -> 3D control (on release, so the viewport isn't flooded mid-drag).
for _slider in (azimuth_slider, elevation_slider):
    _slider.release(
        fn=sync_sliders_to_3d,
        inputs=[azimuth_slider, elevation_slider],
        outputs=[lighting_3d],
    )

# Generate button.
run_btn.click(
    fn=infer_lighting_edit,
    inputs=[image, azimuth_slider, elevation_slider, seed, randomize_seed,
            guidance_scale, num_inference_steps, height, width],
    outputs=[result, seed, prompt_preview],
)

# Image upload: fit the width/height sliders, then mirror into the 3D preview.
image.upload(
    fn=update_dimensions_on_upload,
    inputs=[image],
    outputs=[width, height],
).then(
    fn=update_3d_image,
    inputs=[image],
    outputs=[lighting_3d],
)

# Clearing the image resets the 3D plane to its placeholder.
image.clear(
    fn=lambda: gr.update(imageUrl=None),
    outputs=[lighting_3d],
)
756
  if __name__ == "__main__":
757
+ head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
758
+ css = '.fillable{max-width: 1200px !important}'
759
+ demo.launch(head=head, css=css)