prithivMLmods commited on
Commit
6a70d3e
·
verified ·
1 Parent(s): 13ca277

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +350 -541
app.py CHANGED
@@ -4,15 +4,26 @@ import random
4
  import torch
5
  import spaces
6
  from PIL import Image
 
 
7
  from diffusers import FlowMatchEulerDiscreteScheduler
8
- from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
9
- from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
10
- from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
 
 
 
 
11
 
12
  MAX_SEED = np.iinfo(np.int32).max
13
- # --- Model Loading ---
 
14
  dtype = torch.bfloat16
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
16
  pipe = QwenImageEditPlusPipeline.from_pretrained(
17
  "Qwen/Qwen-Image-Edit-2511",
18
  transformer=QwenImageTransformer2DModel.from_pretrained(
@@ -22,63 +33,53 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
22
  ),
23
  torch_dtype=dtype
24
  ).to(device)
 
 
25
  try:
26
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
27
  print("Flash Attention 3 Processor set successfully.")
28
  except Exception as e:
29
- print(f"Warning: Could not set FA3 processor: {e}")
30
 
 
31
  ADAPTER_SPECS = {
32
  "Multi-Angle-Lighting": {
33
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
34
  "weights": "多角度灯光-251116.safetensors",
35
  "adapter_name": "multi-angle-lighting"
36
- },
37
- }
38
- loaded = False
39
-
40
- # --- Prompt Building ---
41
- # Azimuth mappings (8 positions)
42
- AZIMUTH_MAP = {
43
- 0: "Front",
44
- 45: "Right Front",
45
- 90: "Right",
46
- 135: "Right Rear",
47
- 180: "Rear",
48
- 225: "Left Rear",
49
- 270: "Left",
50
- 315: "Left Front"
51
  }
52
- # Elevation mappings (3 positions)
53
- ELEVATION_MAP = {
54
- -90: "Below",
55
- 0: "",
56
- 90: "Above"
 
 
 
 
 
 
 
 
 
 
 
57
  }
58
 
59
  def snap_to_nearest(value, options):
60
- """Snap a value to the nearest option in a list."""
61
  return min(options, key=lambda x: abs(x - value))
62
 
63
  def build_lighting_prompt(azimuth: float, elevation: float) -> str:
64
- """
65
- Build a lighting prompt from azimuth and elevation values.
66
-
67
- Args:
68
- azimuth: Horizontal rotation in degrees (0-360)
69
- elevation: Vertical angle in degrees (-90 to 90)
70
-
71
- Returns:
72
- Formatted prompt string for the LoRA
73
- """
74
- # Snap to nearest valid values
75
- azimuth_snapped = snap_to_nearest(azimuth, list(AZIMUTH_MAP.keys()))
76
- elevation_snapped = snap_to_nearest(elevation, list(ELEVATION_MAP.keys()))
77
-
78
- if elevation_snapped == 0:
79
- return f"Light source from the {AZIMUTH_MAP[azimuth_snapped]}"
80
- else:
81
- return f"Light source from {ELEVATION_MAP[elevation_snapped]}"
82
 
83
  @spaces.GPU
84
  def infer_lighting_edit(
@@ -92,29 +93,33 @@ def infer_lighting_edit(
92
  height: int = 1024,
93
  width: int = 1024,
94
  ):
95
- """
96
- Edit the lighting of an image using Qwen Image Edit 2511 with multi-angle lighting LoRA.
97
- """
98
- global loaded
99
- progress = gr.Progress(track_tqdm=True)
100
 
101
- if not loaded:
 
 
 
102
  pipe.load_lora_weights(
103
- ADAPTER_SPECS["Multi-Angle-Lighting"]["repo"],
104
- weight_name=ADAPTER_SPECS["Multi-Angle-Lighting"]["weights"],
105
- adapter_name=ADAPTER_SPECS["Multi-Angle-Lighting"]["adapter_name"]
106
  )
107
- pipe.set_adapters([ADAPTER_SPECS["Multi-Angle-Lighting"]["adapter_name"]], adapter_weights=[1.0])
108
- loaded = True
109
 
 
110
  prompt = build_lighting_prompt(azimuth, elevation)
111
- print(f"Generated Prompt: {prompt}")
 
112
  if randomize_seed:
113
  seed = random.randint(0, MAX_SEED)
114
  generator = torch.Generator(device=device).manual_seed(seed)
 
115
  if image is None:
116
  raise gr.Error("Please upload an image first.")
 
117
  pil_image = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
 
118
  result = pipe(
119
  image=[pil_image],
120
  prompt=prompt,
@@ -125,10 +130,10 @@ def infer_lighting_edit(
125
  guidance_scale=guidance_scale,
126
  num_images_per_prompt=1,
127
  ).images[0]
 
128
  return result, seed, prompt
129
 
130
  def update_dimensions_on_upload(image):
131
- """Compute recommended dimensions preserving aspect ratio."""
132
  if image is None:
133
  return 1024, 1024
134
  original_width, original_height = image.size
@@ -145,455 +150,318 @@ def update_dimensions_on_upload(image):
145
  return new_width, new_height
146
 
147
  # --- 3D Lighting Control Component ---
148
- class LightingControl3D(gr.HTML):
149
  """
150
- A 3D lighting control component using Three.js.
151
- Outputs: { azimuth: number, elevation: number }
152
- Accepts imageUrl prop to display user's uploaded image on the plane.
153
  """
154
  def __init__(self, value=None, imageUrl=None, **kwargs):
155
  if value is None:
156
  value = {"azimuth": 0, "elevation": 0}
157
 
158
  html_template = """
159
- <div id="lighting-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
160
- <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 12px; color: #00ff88; white-space: nowrap; z-index: 10;"></div>
 
 
 
 
 
161
  </div>
162
  """
163
 
164
  js_on_load = """
165
  (() => {
166
- const wrapper = element.querySelector('#lighting-control-wrapper');
167
  const promptOverlay = element.querySelector('#prompt-overlay');
168
 
169
- // Wait for THREE to load
170
  const initScene = () => {
171
- if (typeof THREE === 'undefined') {
172
  setTimeout(initScene, 100);
173
  return;
174
  }
175
 
176
- // Scene setup
177
  const scene = new THREE.Scene();
178
- scene.background = new THREE.Color(0x1a1a1a);
179
 
180
- const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
181
- camera.position.set(4.5, 3, 4.5);
182
- camera.lookAt(0, 0.75, 0);
183
 
184
- const renderer = new THREE.WebGLRenderer({ antialias: true });
185
- renderer.setSize(wrapper.clientWidth / wrapper.clientHeight);
186
  renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
 
 
187
  wrapper.insertBefore(renderer.domElement, promptOverlay);
188
 
189
- // Lighting
190
- scene.add(new THREE.AmbientLight(0xffffff, 0.2));
191
-
192
- // Grid
193
- scene.add(new THREE.GridHelper(8, 16, 0x333333, 0x222222));
194
-
195
- // Constants
 
 
196
  const CENTER = new THREE.Vector3(0, 0.75, 0);
197
- const BASE_DISTANCE = 2.5;
198
- const AZIMUTH_RADIUS = 2.4;
199
- const ELEVATION_RADIUS = 1.8;
200
-
201
- // State
202
- let azimuthAngle = props.value?.azimuth || 0;
203
- let elevationAngle = props.value?.elevation || 0;
204
-
205
- // Mappings
206
- const azimuthSteps = [0, 45, 90, 135, 180, 225, 270, 315];
207
- const elevationSteps = [-90, 0, 90];
208
- const azimuthNames = {
209
- 0: 'Front', 45: 'Right Front', 90: 'Right',
210
- 135: 'Right Rear', 180: 'Rear', 225: 'Left Rear',
211
- 270: 'Left', 315: 'Left Front'
212
- };
213
- const elevationNames = { '-90': 'Below', '0': '', '90': 'Above' };
214
 
215
- function snapToNearest(value, steps) {
216
- return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
217
- }
218
-
219
- // Create placeholder texture (smiley face)
220
- function createPlaceholderTexture() {
221
- const canvas = document.createElement('canvas');
222
- canvas.width = 256;
223
- canvas.height = 256;
224
- const ctx = canvas.getContext('2d');
225
- ctx.fillStyle = '#3a3a4a';
226
- ctx.fillRect(0, 0, 256, 256);
227
- ctx.fillStyle = '#ffcc99';
228
- ctx.beginPath();
229
- ctx.arc(128, 128, 80, 0, Math.PI * 2);
230
- ctx.fill();
231
- ctx.fillStyle = '#333';
232
- ctx.beginPath();
233
- ctx.arc(100, 110, 10, 0, Math.PI * 2);
234
- ctx.arc(156, 110, 10, 0, Math.PI * 2);
235
- ctx.fill();
236
- ctx.strokeStyle = '#333';
237
- ctx.lineWidth = 3;
238
- ctx.beginPath();
239
- ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
240
- ctx.stroke();
241
- return new THREE.CanvasTexture(canvas);
242
- }
243
 
244
- // Target image plane
245
- let currentTexture = createPlaceholderTexture();
246
- const planeMaterial = new THREE.MeshStandardMaterial({ map: currentTexture, side: THREE.DoubleSide, roughness: 0.5, metalness: 0 });
247
- let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
248
- targetPlane.position.copy(CENTER);
249
- scene.add(targetPlane);
250
-
251
- // Function to update texture from image URL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  function updateTextureFromUrl(url) {
253
  if (!url) {
254
- // Reset to placeholder
255
- planeMaterial.map = createPlaceholderTexture();
256
- planeMaterial.needsUpdate = true;
257
- // Reset plane to square
258
- scene.remove(targetPlane);
259
- targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
260
- targetPlane.position.copy(CENTER);
261
- scene.add(targetPlane);
262
  return;
263
  }
264
-
265
- const loader = new THREE.TextureLoader();
266
- loader.crossOrigin = 'anonymous';
267
- loader.load(url, (texture) => {
268
- texture.minFilter = THREE.LinearFilter;
269
- texture.magFilter = THREE.LinearFilter;
270
- planeMaterial.map = texture;
271
- planeMaterial.needsUpdate = true;
272
-
273
- // Adjust plane aspect ratio to match image
274
- const img = texture.image;
275
- if (img && img.width && img.height) {
276
  const aspect = img.width / img.height;
277
- const maxSize = 1.5;
278
- let planeWidth, planeHeight;
279
- if (aspect > 1) {
280
- planeWidth = maxSize;
281
- planeHeight = maxSize / aspect;
282
- } else {
283
- planeHeight = maxSize;
284
- planeWidth = maxSize * aspect;
285
- }
286
- scene.remove(targetPlane);
287
- targetPlane = new THREE.Mesh(
288
- new THREE.PlaneGeometry(planeWidth, planeHeight),
289
- planeMaterial
290
  );
291
- targetPlane.position.copy(CENTER);
292
- scene.add(targetPlane);
293
  }
294
- }, undefined, (err) => {
295
- console.error('Failed to load texture:', err);
296
  });
297
  }
298
-
299
- // Check for initial imageUrl
300
- if (props.imageUrl) {
301
- updateTextureFromUrl(props.imageUrl);
302
- }
303
-
304
- // Light model
305
  const lightGroup = new THREE.Group();
306
- const bulbMat = new THREE.MeshStandardMaterial({ color: 0xffff00, emissive: 0xffff00, emissiveIntensity: 1 });
307
- const bulb = new THREE.Mesh(new THREE.SphereGeometry(0.15, 16, 16), bulbMat);
308
- lightGroup.add(bulb);
309
- const pointLight = new THREE.PointLight(0xffffff, 5, 0);
310
- lightGroup.add(pointLight);
311
  scene.add(lightGroup);
312
 
313
- // GREEN: Azimuth ring
314
- const azimuthRing = new THREE.Mesh(
315
- new THREE.TorusGeometry(AZIMUTH_RADIUS, 0.04, 16, 64),
316
- new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
317
  );
318
- azimuthRing.rotation.x = Math.PI / 2;
319
- azimuthRing.position.y = 0.05;
320
- scene.add(azimuthRing);
321
 
322
- const azimuthHandle = new THREE.Mesh(
323
- new THREE.SphereGeometry(0.18, 16, 16),
324
- new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
 
325
  );
326
- azimuthHandle.userData.type = 'azimuth';
327
- scene.add(azimuthHandle);
328
-
329
- // PINK: Elevation arc
330
- const arcPoints = [];
331
- for (let i = 0; i <= 32; i++) {
332
- const angle = THREE.MathUtils.degToRad(-90 + (180 * i / 32));
333
- arcPoints.push(new THREE.Vector3(-0.8, ELEVATION_RADIUS * Math.sin(angle) + CENTER.y, ELEVATION_RADIUS * Math.cos(angle)));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  }
335
- const arcCurve = new THREE.CatmullRomCurve3(arcPoints);
336
- const elevationArc = new THREE.Mesh(
337
- new THREE.TubeGeometry(arcCurve, 32, 0.04, 8, false),
338
- new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
339
- );
340
- scene.add(elevationArc);
341
-
342
- const elevationHandle = new THREE.Mesh(
343
- new THREE.SphereGeometry(0.18, 16, 16),
344
- new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
345
- );
346
- elevationHandle.userData.type = 'elevation';
347
- scene.add(elevationHandle);
348
-
349
  function updatePositions() {
350
- const distance = BASE_DISTANCE;
351
  const azRad = THREE.MathUtils.degToRad(azimuthAngle);
352
  const elRad = THREE.MathUtils.degToRad(elevationAngle);
353
 
354
- const lightX = distance * Math.sin(azRad) * Math.cos(elRad);
355
- const lightY = distance * Math.sin(elRad) + CENTER.y;
356
- const lightZ = distance * Math.cos(azRad) * Math.cos(elRad);
357
 
358
- lightGroup.position.set(lightX, lightY, lightZ);
 
359
 
360
- azimuthHandle.position.set(AZIMUTH_RADIUS * Math.sin(azRad), 0.05, AZIMUTH_RADIUS * Math.cos(azRad));
361
- elevationHandle.position.set(-0.8, ELEVATION_RADIUS * Math.sin(elRad) + CENTER.y, ELEVATION_RADIUS * Math.cos(elRad));
362
-
363
- // Update prompt
364
- const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
365
- const elSnap = snapToNearest(elevationAngle, elevationSteps);
366
- let prompt = 'Light source from';
367
- if (elSnap !== 0) {
368
- prompt += ' ' + elevationNames[String(elSnap)];
369
- } else {
370
- prompt += ' the ' + azimuthNames[azSnap];
371
- }
372
- promptOverlay.textContent = prompt;
373
- }
374
-
375
- function updatePropsAndTrigger() {
376
- const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
377
- const elSnap = snapToNearest(elevationAngle, elevationSteps);
378
 
379
- props.value = { azimuth: azSnap, elevation: elSnap };
380
- trigger('change', props.value);
 
 
 
 
381
  }
 
 
 
382
 
383
- // Raycasting
384
  const raycaster = new THREE.Raycaster();
385
  const mouse = new THREE.Vector2();
386
- let isDragging = false;
387
- let dragTarget = null;
388
- let dragStartMouse = new THREE.Vector2();
389
- const intersection = new THREE.Vector3();
390
 
391
- const canvas = renderer.domElement;
 
 
 
 
 
 
392
 
393
- canvas.addEventListener('mousedown', (e) => {
394
- const rect = canvas.getBoundingClientRect();
395
- mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
396
- mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
 
 
 
 
397
 
398
  raycaster.setFromCamera(mouse, camera);
399
- const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle]);
400
 
401
- if (intersects.length > 0) {
402
- isDragging = true;
403
- dragTarget = intersects[0].object;
404
- dragTarget.material.emissiveIntensity = 1.0;
405
- dragTarget.scale.setScalar(1.3);
406
- dragStartMouse.copy(mouse);
407
- canvas.style.cursor = 'grabbing';
408
- }
409
- });
410
-
411
- canvas.addEventListener('mousemove', (e) => {
412
- const rect = canvas.getBoundingClientRect();
413
- mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
414
- mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
415
 
416
- if (isDragging && dragTarget) {
417
- raycaster.setFromCamera(mouse, camera);
418
-
419
- if (dragTarget.userData.type === 'azimuth') {
420
- const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
421
- if (raycaster.ray.intersectPlane(plane, intersection)) {
422
- azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
423
- if (azimuthAngle < 0) azimuthAngle += 360;
424
- }
425
- } else if (dragTarget.userData.type === 'elevation') {
426
- const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
427
- if (raycaster.ray.intersectPlane(plane, intersection)) {
428
- const relY = intersection.y - CENTER.y;
429
- const relZ = intersection.z;
430
- elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -90, 90);
431
- }
432
- }
433
- updatePositions();
434
- } else {
435
- raycaster.setFromCamera(mouse, camera);
436
- const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle]);
437
- [azimuthHandle, elevationHandle].forEach(h => {
438
- h.material.emissiveIntensity = 0.5;
439
- h.scale.setScalar(1);
440
- });
441
- if (intersects.length > 0) {
442
- intersects[0].object.material.emissiveIntensity = 0.8;
443
- intersects[0].object.scale.setScalar(1.1);
444
- canvas.style.cursor = 'grab';
445
- } else {
446
- canvas.style.cursor = 'default';
447
- }
448
  }
449
- });
450
-
451
- const onMouseUp = () => {
452
- if (dragTarget) {
453
- dragTarget.material.emissiveIntensity = 0.5;
454
- dragTarget.scale.setScalar(1);
455
-
456
- // Snap and animate
457
- const targetAz = snapToNearest(azimuthAngle, azimuthSteps);
458
- const targetEl = snapToNearest(elevationAngle, elevationSteps);
459
-
460
- const startAz = azimuthAngle, startEl = elevationAngle;
461
- const startTime = Date.now();
462
-
463
- function animateSnap() {
464
- const t = Math.min((Date.now() - startTime) / 200, 1);
465
- const ease = 1 - Math.pow(1 - t, 3);
466
 
467
- let azDiff = targetAz - startAz;
468
- if (azDiff > 180) azDiff -= 360;
469
- if (azDiff < -180) azDiff += 360;
470
- azimuthAngle = startAz + azDiff * ease;
471
- if (azimuthAngle < 0) azimuthAngle += 360;
472
- if (azimuthAngle >= 360) azimuthAngle -= 360;
473
 
474
- elevationAngle = startEl + (targetEl - startEl) * ease;
 
 
475
 
 
 
476
  updatePositions();
477
- if (t < 1) requestAnimationFrame(animateSnap);
478
- else updatePropsAndTrigger();
479
  }
480
- animateSnap();
481
- }
482
- isDragging = false;
483
- dragTarget = null;
484
- canvas.style.cursor = 'default';
485
- };
486
-
487
- canvas.addEventListener('mouseup', onMouseUp);
488
- canvas.addEventListener('mouseleave', onMouseUp);
489
- // Touch support for mobile
490
- canvas.addEventListener('touchstart', (e) => {
491
- e.preventDefault();
492
- const touch = e.touches[0];
493
- const rect = canvas.getBoundingClientRect();
494
- mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
495
- mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
496
-
497
- raycaster.setFromCamera(mouse, camera);
498
- const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle]);
499
-
500
- if (intersects.length > 0) {
501
- isDragging = true;
502
- dragTarget = intersects[0].object;
503
- dragTarget.material.emissiveIntensity = 1.0;
504
- dragTarget.scale.setScalar(1.3);
505
- dragStartMouse.copy(mouse);
506
  }
507
- }, { passive: false });
508
-
509
- canvas.addEventListener('touchmove', (e) => {
510
- e.preventDefault();
511
- const touch = e.touches[0];
512
- const rect = canvas.getBoundingClientRect();
513
- mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
514
- mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
515
-
516
- if (isDragging && dragTarget) {
517
- raycaster.setFromCamera(mouse, camera);
518
 
519
- if (dragTarget.userData.type === 'azimuth') {
520
- const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
521
- if (raycaster.ray.intersectPlane(plane, intersection)) {
522
- azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
523
- if (azimuthAngle < 0) azimuthAngle += 360;
524
- }
525
- } else if (dragTarget.userData.type === 'elevation') {
526
- const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
527
- if (raycaster.ray.intersectPlane(plane, intersection)) {
528
- const relY = intersection.y - CENTER.y;
529
- const relZ = intersection.z;
530
- elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -90, 90);
531
- }
532
- }
533
- updatePositions();
534
  }
535
- }, { passive: false });
536
-
537
- canvas.addEventListener('touchend', (e) => {
538
- e.preventDefault();
539
- onMouseUp();
540
- }, { passive: false });
541
-
542
- canvas.addEventListener('touchcancel', (e) => {
543
- e.preventDefault();
544
- onMouseUp();
545
- }, { passive: false });
546
-
547
- // Initial update
548
  updatePositions();
549
-
550
- // Render loop
551
- function render() {
552
- requestAnimationFrame(render);
553
  renderer.render(scene, camera);
554
  }
555
- render();
556
-
557
- // Handle resize
558
- new ResizeObserver(() => {
559
- camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
560
- camera.updateProjectionMatrix();
561
- renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
562
- }).observe(wrapper);
563
-
564
- // Store update functions for external calls
565
- wrapper._updateFromProps = (newVal) => {
566
- if (newVal && typeof newVal === 'object') {
567
- azimuthAngle = newVal.azimuth ?? azimuthAngle;
568
- elevationAngle = newVal.elevation ?? elevationAngle;
569
- updatePositions();
570
- }
571
- };
572
-
573
- wrapper._updateTexture = updateTextureFromUrl;
574
-
575
- // Watch for prop changes (imageUrl and value)
576
- let lastImageUrl = props.imageUrl;
577
- let lastValue = JSON.stringify(props.value);
578
  setInterval(() => {
579
- // Check imageUrl changes
580
- if (props.imageUrl !== lastImageUrl) {
581
- lastImageUrl = props.imageUrl;
582
- updateTextureFromUrl(props.imageUrl);
583
- }
584
- // Check value changes (from sliders)
585
- const currentValue = JSON.stringify(props.value);
586
- if (currentValue !== lastValue) {
587
- lastValue = currentValue;
588
- if (props.value && typeof props.value === 'object') {
589
- azimuthAngle = props.value.azimuth ?? azimuthAngle;
590
- elevationAngle = props.value.elevation ?? elevationAngle;
591
  updatePositions();
592
  }
593
  }
594
  }, 100);
595
- };
596
-
 
597
  initScene();
598
  })();
599
  """
@@ -606,154 +474,95 @@ class LightingControl3D(gr.HTML):
606
  **kwargs
607
  )
608
 
609
- # --- UI ---
610
  css = '''
611
  #col-container { max-width: 1200px; margin: 0 auto; }
612
  .dark .progress-text { color: white !important; }
613
- #lighting-3d-control { min-height: 450px; }
614
- .slider-row { display: flex; gap: 10px; align-items: center; }
615
  '''
 
616
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
617
  gr.Markdown("""
618
- # 🎬 Qwen Image Edit 25113D Lighting Control
619
 
620
- Control lighting directions using the **3D viewport** or **sliders**.
621
- Using [dx8152/Qwen-Edit-2509-Multi-Angle-Lighting] for precise lighting control.
 
622
  """)
623
 
624
  with gr.Row():
625
- # Left column: Input image and controls
626
  with gr.Column(scale=1):
627
  image = gr.Image(label="Input Image", type="pil", height=300)
628
 
629
- gr.Markdown("### 🎮 3D Lighting Control")
630
- gr.Markdown("*Drag the colored handles: 🟢 Azimuth (Direction), 🩷 Elevation (Height)*")
631
 
632
- lighting_3d = LightingControl3D(
633
  value={"azimuth": 0, "elevation": 0},
634
- elem_id="lighting-3d-control"
635
- )
636
- run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
637
-
638
- gr.Markdown("### 🎚️ Slider Controls")
639
-
640
- azimuth_slider = gr.Slider(
641
- label="Azimuth (Horizontal Rotation)",
642
- minimum=0,
643
- maximum=315,
644
- step=45,
645
- value=0,
646
- info="0°=front, 90°=right, 180°=rear, 270°=left"
647
  )
 
 
648
 
649
- elevation_slider = gr.Slider(
650
- label="Elevation (Vertical Angle)",
651
- minimum=-90,
652
- maximum=90,
653
- step=90,
654
- value=0,
655
- info="-90°=from below, 0°=horizontal, 90°=from above"
656
- )
657
 
658
  prompt_preview = gr.Textbox(
659
- label="Generated Prompt",
660
  value="Light source from the Front",
661
  interactive=False
662
  )
663
 
664
- # Right column: Output
665
  with gr.Column(scale=1):
666
- result = gr.Image(label="Output Image", height=500)
667
 
668
  with gr.Accordion("⚙️ Advanced Settings", open=False):
669
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
670
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
671
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
672
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=20, step=1, value=4)
673
  height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
674
  width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)
 
 
 
 
 
 
675
 
676
- # --- Event Handlers ---
677
-
678
- def update_prompt_from_sliders(azimuth, elevation):
679
- """Update prompt preview when sliders change."""
680
- prompt = build_lighting_prompt(azimuth, elevation)
681
- return prompt
682
-
683
- def sync_3d_to_sliders(lighting_value):
684
- """Sync 3D control changes to sliders."""
685
- if lighting_value and isinstance(lighting_value, dict):
686
- az = lighting_value.get('azimuth', 0)
687
- el = lighting_value.get('elevation', 0)
688
- prompt = build_lighting_prompt(az, el)
689
- return az, el, prompt
690
- return gr.update(), gr.update(), gr.update()
691
-
692
- def sync_sliders_to_3d(azimuth, elevation):
693
- """Sync slider changes to 3D control."""
694
- return {"azimuth": azimuth, "elevation": elevation}
695
-
696
- def update_3d_image(image):
697
- """Update the 3D component with the uploaded image."""
698
- if image is None:
699
- return gr.update(imageUrl=None)
700
- # Convert PIL image to base64 data URL
701
  import base64
702
  from io import BytesIO
703
  buffered = BytesIO()
704
- image.save(buffered, format="PNG")
705
  img_str = base64.b64encode(buffered.getvalue()).decode()
706
- data_url = f"data:image/png;base64,{img_str}"
707
- return gr.update(imageUrl=data_url)
708
-
709
- # Slider -> Prompt preview
710
- for slider in [azimuth_slider, elevation_slider]:
711
- slider.change(
712
- fn=update_prompt_from_sliders,
713
- inputs=[azimuth_slider, elevation_slider],
714
- outputs=[prompt_preview]
715
- )
716
-
717
- # 3D control -> Sliders + Prompt
718
- lighting_3d.change(
719
- fn=sync_3d_to_sliders,
720
- inputs=[lighting_3d],
721
- outputs=[azimuth_slider, elevation_slider, prompt_preview]
722
- )
723
-
724
- # Sliders -> 3D control
725
- for slider in [azimuth_slider, elevation_slider]:
726
- slider.release(
727
- fn=sync_sliders_to_3d,
728
- inputs=[azimuth_slider, elevation_slider],
729
- outputs=[lighting_3d]
730
- )
731
 
732
- # Generate button
733
  run_btn.click(
734
- fn=infer_lighting_edit,
735
- inputs=[image, azimuth_slider, elevation_slider, seed, randomize_seed, guidance_scale, num_inference_steps, height, width],
736
- outputs=[result, seed, prompt_preview]
737
- )
738
-
739
- # Image upload -> update dimensions AND update 3D preview
740
- image.upload(
741
- fn=update_dimensions_on_upload,
742
- inputs=[image],
743
- outputs=[width, height]
744
- ).then(
745
- fn=update_3d_image,
746
- inputs=[image],
747
- outputs=[lighting_3d]
748
  )
749
-
750
- # Also handle image clear
751
- image.clear(
752
- fn=lambda: gr.update(imageUrl=None),
753
- outputs=[lighting_3d]
754
- )
755
-
756
  if __name__ == "__main__":
757
- head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
758
- css = '.fillable{max-width: 1200px !important}'
759
- demo.launch(head=head, css=css)
 
 
 
 
4
  import torch
5
  import spaces
6
  from PIL import Image
7
+
8
+ # --- Updated Imports ---
9
  from diffusers import FlowMatchEulerDiscreteScheduler
10
+ try:
11
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
12
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
13
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
14
+ except ImportError:
15
+ # Fallback for standard environments or if package missing (user should install)
16
+ print("Warning: 'qwenimage' package not found. Ensure it is installed.")
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
+
20
+ # --- Model Configuration ---
21
  dtype = torch.bfloat16
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ print(f"Loading Model on {device}...")
25
+
26
+ # Load the pipeline with the custom transformer
27
  pipe = QwenImageEditPlusPipeline.from_pretrained(
28
  "Qwen/Qwen-Image-Edit-2511",
29
  transformer=QwenImageTransformer2DModel.from_pretrained(
 
33
  ),
34
  torch_dtype=dtype
35
  ).to(device)
36
+
37
+ # Attempt to set Flash Attention 3
38
  try:
39
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
40
  print("Flash Attention 3 Processor set successfully.")
41
  except Exception as e:
42
+ print(f"Standard Attention enabled (FA3 skipped: {e})")
43
 
44
+ # Adapter Specs
45
  ADAPTER_SPECS = {
46
  "Multi-Angle-Lighting": {
47
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
48
  "weights": "多角度灯光-251116.safetensors",
49
  "adapter_name": "multi-angle-lighting"
50
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  }
52
+
53
+ # State to track loaded adapter
54
+ CURRENT_LOADED_ADAPTER = None
55
+
56
+ # --- Prompt Building & Logic ---
57
+
58
+ # 0 degrees = Front
59
+ LIGHTING_AZIMUTH_MAP = {
60
+ 0: "Light source from the Front",
61
+ 45: "Light source from the Right Front",
62
+ 90: "Light source from the Right",
63
+ 135: "Light source from the Right Rear",
64
+ 180: "Light source from the Rear",
65
+ 225: "Light source from the Left Rear",
66
+ 270: "Light source from the Left",
67
+ 315: "Light source from the Left Front"
68
  }
69
 
70
  def snap_to_nearest(value, options):
 
71
  return min(options, key=lambda x: abs(x - value))
72
 
73
  def build_lighting_prompt(azimuth: float, elevation: float) -> str:
74
+ # Check Vertical extremes
75
+ if elevation >= 60:
76
+ return "Light source from Above"
77
+ if elevation <= -60:
78
+ return "Light source from Below"
79
+
80
+ # Snap Horizontal
81
+ azimuth_snapped = snap_to_nearest(azimuth, list(LIGHTING_AZIMUTH_MAP.keys()))
82
+ return LIGHTING_AZIMUTH_MAP[azimuth_snapped]
 
 
 
 
 
 
 
 
 
83
 
84
  @spaces.GPU
85
  def infer_lighting_edit(
 
93
  height: int = 1024,
94
  width: int = 1024,
95
  ):
96
+ global CURRENT_LOADED_ADAPTER
 
 
 
 
97
 
98
+ # --- Lazy Loading Logic ---
99
+ spec = ADAPTER_SPECS["Multi-Angle-Lighting"]
100
+ if CURRENT_LOADED_ADAPTER != spec["adapter_name"]:
101
+ print(f"Lazy loading adapter: {spec['adapter_name']}...")
102
  pipe.load_lora_weights(
103
+ spec["repo"],
104
+ weight_name=spec["weights"],
105
+ adapter_name=spec["adapter_name"]
106
  )
107
+ pipe.set_adapters([spec["adapter_name"]], adapter_weights=[1.0])
108
+ CURRENT_LOADED_ADAPTER = spec["adapter_name"]
109
 
110
+ # --- Generation ---
111
  prompt = build_lighting_prompt(azimuth, elevation)
112
+ print(f"Generated Lighting Prompt: {prompt}")
113
+
114
  if randomize_seed:
115
  seed = random.randint(0, MAX_SEED)
116
  generator = torch.Generator(device=device).manual_seed(seed)
117
+
118
  if image is None:
119
  raise gr.Error("Please upload an image first.")
120
+
121
  pil_image = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
122
+
123
  result = pipe(
124
  image=[pil_image],
125
  prompt=prompt,
 
130
  guidance_scale=guidance_scale,
131
  num_images_per_prompt=1,
132
  ).images[0]
133
+
134
  return result, seed, prompt
135
 
136
  def update_dimensions_on_upload(image):
 
137
  if image is None:
138
  return 1024, 1024
139
  original_width, original_height = image.size
 
150
  return new_width, new_height
151
 
152
  # --- 3D Lighting Control Component ---
153
+ class LightControl3D(gr.HTML):
154
  """
155
+ Advanced 3D Lighting control component using Three.js + OrbitControls.
 
 
156
  """
157
  def __init__(self, value=None, imageUrl=None, **kwargs):
158
  if value is None:
159
  value = {"azimuth": 0, "elevation": 0}
160
 
161
  html_template = """
162
+ <div id="light-control-wrapper" style="width: 100%; height: 500px; position: relative; background: radial-gradient(circle at center, #1a1a1a 0%, #000000 100%); border-radius: 12px; overflow: hidden; border: 1px solid #333; box-shadow: inset 0 0 20px #000;">
163
+ <div id="prompt-overlay" style="position: absolute; top: 15px; left: 50%; transform: translateX(-50%); background: rgba(0, 0, 0, 0.85); padding: 8px 20px; border-radius: 30px; font-family: 'Segoe UI', sans-serif; font-weight: 600; font-size: 14px; color: #FFD700; white-space: nowrap; z-index: 10; border: 1px solid #FFD700; box-shadow: 0 0 10px rgba(255, 215, 0, 0.3);">
164
+ Light Source: Front
165
+ </div>
166
+ <div style="position: absolute; bottom: 15px; left: 15px; color: #666; font-family: monospace; font-size: 11px; z-index: 10;">
167
+ LEFT CLICK: Move Light • RIGHT CLICK: Rotate View
168
+ </div>
169
  </div>
170
  """
171
 
172
  js_on_load = """
173
  (() => {
174
+ const wrapper = element.querySelector('#light-control-wrapper');
175
  const promptOverlay = element.querySelector('#prompt-overlay');
176
 
 
177
  const initScene = () => {
178
+ if (typeof THREE === 'undefined' || typeof THREE.OrbitControls === 'undefined') {
179
  setTimeout(initScene, 100);
180
  return;
181
  }
182
 
183
+ // 1. Setup Scene
184
  const scene = new THREE.Scene();
185
+ // No background color, letting CSS gradient show through
186
 
187
+ const camera = new THREE.PerspectiveCamera(45, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
188
+ camera.position.set(4, 3, 4);
 
189
 
190
+ const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
191
+ renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
192
  renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
193
+ renderer.shadowMap.enabled = true; // Enable Shadows
194
+ renderer.shadowMap.type = THREE.PCFSoftShadowMap;
195
  wrapper.insertBefore(renderer.domElement, promptOverlay);
196
 
197
+ // 2. Controls (Orbit for Camera)
198
+ const controls = new THREE.OrbitControls(camera, renderer.domElement);
199
+ controls.enableDamping = true;
200
+ controls.dampingFactor = 0.05;
201
+ controls.maxDistance = 10;
202
+ controls.minDistance = 2;
203
+ controls.target.set(0, 0.75, 0);
204
+
205
+ // 3. Environment
206
  const CENTER = new THREE.Vector3(0, 0.75, 0);
207
+ const ORBIT_RADIUS = 2.2;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
+ // Floor (Shadow Receiver)
210
+ const planeGeo = new THREE.PlaneGeometry(10, 10);
211
+ const planeMat = new THREE.ShadowMaterial({ opacity: 0.5, color: 0x000000 });
212
+ const floor = new THREE.Mesh(planeGeo, planeMat);
213
+ floor.rotation.x = -Math.PI / 2;
214
+ floor.receiveShadow = true;
215
+ scene.add(floor);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ // Grid
218
+ const grid = new THREE.GridHelper(10, 20, 0x444444, 0x111111);
219
+ scene.add(grid);
220
+
221
+ // Dome Helper (Wireframe)
222
+ const domeGeo = new THREE.SphereGeometry(ORBIT_RADIUS, 16, 12, 0, Math.PI*2, 0, Math.PI/2);
223
+ const domeMat = new THREE.MeshBasicMaterial({ color: 0x333333, wireframe: true, transparent: true, opacity: 0.15 });
224
+ const dome = new THREE.Mesh(domeGeo, domeMat);
225
+ dome.position.copy(CENTER);
226
+ dome.position.y -= 0.75; // Adjust to base
227
+ scene.add(dome);
228
+
229
+ // 4. The Subject (Image Box)
230
+ let currentTexture = null;
231
+ const boxGeo = new THREE.BoxGeometry(1.2, 1.2, 0.1); // Thin box instead of plane
232
+ const boxMat = new THREE.MeshStandardMaterial({ color: 0x222222, roughness: 0.7, metalness: 0.1 });
233
+ let subjectMesh = new THREE.Mesh(boxGeo, boxMat);
234
+ subjectMesh.position.copy(CENTER);
235
+ subjectMesh.castShadow = true;
236
+ subjectMesh.receiveShadow = true;
237
+ scene.add(subjectMesh);
238
+
239
+ // Hidden Reference Sphere (To cast better shadows if image is flat)
240
+ const refSphere = new THREE.Mesh(
241
+ new THREE.SphereGeometry(0.5, 32, 32),
242
+ new THREE.MeshBasicMaterial({ visible: false }) // Invisible but casts shadow? No, must be visible.
243
+ );
244
+ // Actually, let's keep the box, it casts decent shadows.
245
+
246
  function updateTextureFromUrl(url) {
247
  if (!url) {
248
+ boxMat.map = null;
249
+ boxMat.needsUpdate = true;
 
 
 
 
 
 
250
  return;
251
  }
252
+ new THREE.TextureLoader().load(url, (tex) => {
253
+ boxMat.map = tex;
254
+ boxMat.color.setHex(0xffffff); // Reset color if texture present
255
+ boxMat.needsUpdate = true;
256
+ const img = tex.image;
257
+ if(img.width && img.height) {
 
 
 
 
 
 
258
  const aspect = img.width / img.height;
259
+ const scale = aspect > 1 ? 1 : aspect;
260
+ // Adjust box dimensions
261
+ const newGeo = new THREE.BoxGeometry(
262
+ aspect > 1 ? 1.5 : 1.5 * aspect,
263
+ aspect > 1 ? 1.5 / aspect : 1.5,
264
+ 0.05
 
 
 
 
 
 
 
265
  );
266
+ subjectMesh.geometry.dispose();
267
+ subjectMesh.geometry = newGeo;
268
  }
 
 
269
  });
270
  }
271
+ if (props.imageUrl) updateTextureFromUrl(props.imageUrl);
272
+
273
+ // 5. Light Source (The Sun) & The Actual Light
 
 
 
 
274
  const lightGroup = new THREE.Group();
 
 
 
 
 
275
  scene.add(lightGroup);
276
 
277
+ // The Visual Sun
278
+ const sunMesh = new THREE.Mesh(
279
+ new THREE.SphereGeometry(0.2, 32, 32),
280
+ new THREE.MeshBasicMaterial({ color: 0xFFD700 })
281
  );
282
+ lightGroup.add(sunMesh);
 
 
283
 
284
+ // Glow
285
+ const glowMesh = new THREE.Mesh(
286
+ new THREE.SphereGeometry(0.35, 32, 32),
287
+ new THREE.MeshBasicMaterial({ color: 0xFFD700, transparent: true, opacity: 0.3 })
288
  );
289
+ sunMesh.add(glowMesh);
290
+
291
+ // The Actual Light Caster
292
+ const spotLight = new THREE.SpotLight(0xffffff, 2);
293
+ spotLight.angle = Math.PI / 4;
294
+ spotLight.penumbra = 0.5;
295
+ spotLight.decay = 2;
296
+ spotLight.distance = 10;
297
+ spotLight.castShadow = true;
298
+ spotLight.shadow.mapSize.width = 1024;
299
+ spotLight.shadow.mapSize.height = 1024;
300
+ lightGroup.add(spotLight);
301
+
302
+ // Ambient fill
303
+ scene.add(new THREE.AmbientLight(0xffffff, 0.15));
304
+
305
+ // Target Line
306
+ const lineGeo = new THREE.BufferGeometry().setFromPoints([new THREE.Vector3(0,0,0), new THREE.Vector3(0,0,-1)]);
307
+ const lineMat = new THREE.LineBasicMaterial({ color: 0xFFD700, transparent: true, opacity: 0.5 });
308
+ const line = new THREE.Line(lineGeo, lineMat);
309
+ lightGroup.add(line);
310
+
311
+ // State
312
+ let azimuthAngle = props.value?.azimuth || 0;
313
+ let elevationAngle = props.value?.elevation || 0;
314
+
315
+ const azimuthNames = {
316
+ 0: 'Front', 45: 'Right Front', 90: 'Right',
317
+ 135: 'Right Rear', 180: 'Rear', 225: 'Left Rear',
318
+ 270: 'Left', 315: 'Left Front'
319
+ };
320
+
321
+ function getPromptText(az, el) {
322
+ if (el >= 60) return "Light source from Above";
323
+ if (el <= -60) return "Light source from Below";
324
+ const steps = [0, 45, 90, 135, 180, 225, 270, 315];
325
+ const snapped = steps.reduce((prev, curr) => Math.abs(curr - az) < Math.abs(prev - az) ? curr : prev);
326
+ return "Light source from the " + azimuthNames[snapped];
327
  }
328
+
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  function updatePositions() {
 
330
  const azRad = THREE.MathUtils.degToRad(azimuthAngle);
331
  const elRad = THREE.MathUtils.degToRad(elevationAngle);
332
 
333
+ const x = ORBIT_RADIUS * Math.sin(azRad) * Math.cos(elRad);
334
+ const y = ORBIT_RADIUS * Math.sin(elRad) + CENTER.y;
335
+ const z = ORBIT_RADIUS * Math.cos(azRad) * Math.cos(elRad);
336
 
337
+ lightGroup.position.set(x, y, z);
338
+ lightGroup.lookAt(CENTER);
339
 
340
+ // Line length
341
+ const dist = lightGroup.position.distanceTo(CENTER);
342
+ line.scale.z = dist;
343
+
344
+ // Update Prompt UI
345
+ promptOverlay.textContent = getPromptText(azimuthAngle, elevationAngle);
 
 
 
 
 
 
 
 
 
 
 
 
346
 
347
+ // Color updates for warning zones
348
+ const isExtreme = elevationAngle >= 60 || elevationAngle <= -60;
349
+ const color = isExtreme ? 0xFF4500 : 0xFFD700;
350
+ sunMesh.material.color.setHex(color);
351
+ promptOverlay.style.color = isExtreme ? '#FF4500' : '#FFD700';
352
+ promptOverlay.style.borderColor = isExtreme ? '#FF4500' : '#FFD700';
353
  }
354
+
355
+ // --- Interaction Logic ---
356
+ // We need to distinguish between Orbiting (Background drag) and Light Moving (Sun drag)
357
 
 
358
  const raycaster = new THREE.Raycaster();
359
  const mouse = new THREE.Vector2();
 
 
 
 
360
 
361
+ // Drag Helper Sphere (Invisible large sphere for raycasting light position)
362
+ const dragSphere = new THREE.Mesh(
363
+ new THREE.SphereGeometry(ORBIT_RADIUS, 32, 32),
364
+ new THREE.MeshBasicMaterial({ visible: false, side: THREE.DoubleSide })
365
+ );
366
+ dragSphere.position.copy(CENTER);
367
+ scene.add(dragSphere);
368
 
369
+ let isDraggingLight = false;
370
+
371
+ function onDown(e) {
372
+ const rect = wrapper.getBoundingClientRect();
373
+ const x = e.clientX || e.touches[0].clientX;
374
+ const y = e.clientY || e.touches[0].clientY;
375
+ mouse.x = ((x - rect.left) / rect.width) * 2 - 1;
376
+ mouse.y = -((y - rect.top) / rect.height) * 2 + 1;
377
 
378
  raycaster.setFromCamera(mouse, camera);
 
379
 
380
+ // Check if clicking the Sun
381
+ const sunIntersects = raycaster.intersectObject(sunMesh);
382
+ const glowIntersects = raycaster.intersectObject(glowMesh);
 
 
 
 
 
 
 
 
 
 
 
383
 
384
+ if (sunIntersects.length > 0 || glowIntersects.length > 0) {
385
+ isDraggingLight = true;
386
+ controls.enabled = false; // Disable Orbit
387
+ wrapper.style.cursor = 'grabbing';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  }
389
+ }
390
+
391
+ function onMove(e) {
392
+ const rect = wrapper.getBoundingClientRect();
393
+ const x = e.clientX || (e.touches ? e.touches[0].clientX : 0);
394
+ const y = e.clientY || (e.touches ? e.touches[0].clientY : 0);
395
+ mouse.x = ((x - rect.left) / rect.width) * 2 - 1;
396
+ mouse.y = -((y - rect.top) / rect.height) * 2 + 1;
397
+
398
+ raycaster.setFromCamera(mouse, camera);
399
+
400
+ if (isDraggingLight) {
401
+ const intersects = raycaster.intersectObject(dragSphere);
402
+ if(intersects.length > 0) {
403
+ const point = intersects[0].point;
404
+ const rel = new THREE.Vector3().subVectors(point, CENTER);
 
405
 
406
+ let az = Math.atan2(rel.x, rel.z) * (180/Math.PI);
407
+ if(az < 0) az += 360;
 
 
 
 
408
 
409
+ const dist = Math.sqrt(rel.x*rel.x + rel.z*rel.z);
410
+ let el = Math.atan2(rel.y, dist) * (180/Math.PI);
411
+ el = Math.max(-89, Math.min(89, el)); // Clamp
412
 
413
+ azimuthAngle = az;
414
+ elevationAngle = el;
415
  updatePositions();
 
 
416
  }
417
+ } else {
418
+ // Hover state
419
+ const sunIntersects = raycaster.intersectObject(sunMesh);
420
+ if (sunIntersects.length > 0) wrapper.style.cursor = 'grab';
421
+ else wrapper.style.cursor = 'default';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  }
423
+ }
424
+
425
+ function onUp() {
426
+ if (isDraggingLight) {
427
+ isDraggingLight = false;
428
+ controls.enabled = true; // Re-enable Orbit
429
+ wrapper.style.cursor = 'default';
 
 
 
 
430
 
431
+ props.value = { azimuth: azimuthAngle, elevation: elevationAngle };
432
+ trigger('change', props.value);
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  }
434
+ }
435
+
436
+ wrapper.addEventListener('mousedown', onDown);
437
+ window.addEventListener('mousemove', onMove);
438
+ window.addEventListener('mouseup', onUp);
439
+ wrapper.addEventListener('touchstart', onDown, {passive: false});
440
+ window.addEventListener('touchmove', onMove, {passive: false});
441
+ window.addEventListener('touchend', onUp);
442
+
 
 
 
 
443
  updatePositions();
444
+
445
+ function animate() {
446
+ requestAnimationFrame(animate);
447
+ controls.update();
448
  renderer.render(scene, camera);
449
  }
450
+ animate();
451
+
452
+ // Prop watchers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  setInterval(() => {
454
+ if (props.value && (props.value.azimuth !== azimuthAngle || props.value.elevation !== elevationAngle)) {
455
+ if(!isDraggingLight) {
456
+ azimuthAngle = props.value.azimuth;
457
+ elevationAngle = props.value.elevation;
 
 
 
 
 
 
 
 
458
  updatePositions();
459
  }
460
  }
461
  }, 100);
462
+
463
+ wrapper._updateTexture = updateTextureFromUrl;
464
+ }
465
  initScene();
466
  })();
467
  """
 
474
  **kwargs
475
  )
476
 
477
+ # --- UI Layout ---
478
  css = '''
479
  #col-container { max-width: 1200px; margin: 0 auto; }
480
  .dark .progress-text { color: white !important; }
 
 
481
  '''
482
+
483
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
484
gr.Markdown("""
# 💡 Qwen Edit 2509 — Studio Lighting Control

**Interactive 3D Studio:**
* **Left Drag** the 🟡 **Sun** to change light direction.
* **Right Drag** background to rotate the camera view.
* The shadow on the floor previews the light direction.
""")
491
 
492
  with gr.Row():
 
493
  with gr.Column(scale=1):
494
  image = gr.Image(label="Input Image", type="pil", height=300)
495
 
496
+ gr.Markdown("### 🎮 3D Studio Controller")
 
497
 
498
+ light_3d = LightControl3D(
499
  value={"azimuth": 0, "elevation": 0},
500
+ elem_id="light-3d-control"
 
 
 
 
 
 
 
 
 
 
 
 
501
  )
502
+
503
+ run_btn = gr.Button("✨ Render Lighting", variant="primary", size="lg")
504
 
505
+ with gr.Accordion("🎚️ Manual Sliders", open=False):
506
+ azimuth_slider = gr.Slider(
507
+ label="Azimuth", minimum=0, maximum=359, step=45, value=0
508
+ )
509
+ elevation_slider = gr.Slider(
510
+ label="Elevation", minimum=-90, maximum=90, step=10, value=0
511
+ )
 
512
 
513
  prompt_preview = gr.Textbox(
514
+ label="Active Prompt",
515
  value="Light source from the Front",
516
  interactive=False
517
  )
518
 
 
519
  with gr.Column(scale=1):
520
+ result = gr.Image(label="Relighted Result", height=600)
521
 
522
  with gr.Accordion("⚙️ Advanced Settings", open=False):
523
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
524
+ randomize_seed = gr.Checkbox(label="Randomize", value=True)
525
+ guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
526
+ num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=20, step=1, value=4)
527
  height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
528
  width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)
529
+
530
+ # --- Wiring ---
531
def sync_3d_to_sliders(val):
    """Mirror the 3D widget's state onto the manual sliders and refresh the prompt preview.

    `val` is the widget's value dict (expected keys: 'azimuth', 'elevation';
    missing keys fall back to 0).
    """
    azimuth = val.get('azimuth', 0)
    elevation = val.get('elevation', 0)
    prompt = build_lighting_prompt(azimuth, elevation)
    return azimuth, elevation, prompt
535
 
536
def sync_sliders_to_3d(az, el):
    """Push the manual slider values into the 3D widget and refresh the prompt preview."""
    widget_state = {"azimuth": az, "elevation": el}
    return widget_state, build_lighting_prompt(az, el)
538
+
539
# Two-way sync: 3D widget -> sliders, and each slider -> 3D widget.
light_3d.change(sync_3d_to_sliders, [light_3d], [azimuth_slider, elevation_slider, prompt_preview])
for _slider in (azimuth_slider, elevation_slider):
    _slider.change(sync_sliders_to_3d, [azimuth_slider, elevation_slider], [light_3d, prompt_preview])
542
+
543
def update_on_upload(img):
    """On image upload: derive output dimensions and push a preview texture to the 3D widget.

    Returns (width, height, widget_update). When `img` is None the widget's
    texture is cleared instead.
    """
    new_w, new_h = update_dimensions_on_upload(img)
    if img is None:
        return new_w, new_h, gr.update(imageUrl=None)
    # Embed the uploaded image as a base64 data URL so the in-browser
    # Three.js scene can load it directly as a texture.
    import base64
    from io import BytesIO
    buf = BytesIO()
    img.save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode()
    return new_w, new_h, gr.update(imageUrl=f"data:image/png;base64,{encoded}")

image.upload(update_on_upload, [image], [width, height, light_3d])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
 
 
556
# Main generation trigger: runs the relighting pipeline with the current
# slider state and advanced settings.
run_btn.click(
    fn=infer_lighting_edit,
    inputs=[image, azimuth_slider, elevation_slider, seed, randomize_seed,
            guidance_scale, num_inference_steps, height, width],
    outputs=[result, seed, prompt_preview],
)
561
+
 
 
 
 
 
 
562
if __name__ == "__main__":
    # The custom LightControl3D component needs Three.js and its matching
    # OrbitControls build injected into the page <head>.
    cdn_head = '''
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
'''
    demo.launch(head=cdn_head)