prithivMLmods commited on
Commit
3d07170
·
verified ·
1 Parent(s): a81b7e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +304 -354
app.py CHANGED
@@ -7,60 +7,94 @@ import json
7
  import base64
8
  from io import BytesIO
9
  from PIL import Image
 
10
 
11
- # NOTE: Ensure QwenImageEditPlusPipeline is available in your environment.
12
- # If using a local file, uncomment the local import. If using a custom Diffusers build, keep as is.
13
  try:
14
- from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
 
 
 
 
15
  except ImportError:
16
- # Fallback/Placeholder if specific pipeline isn't installed, purely to allow UI testing
17
- print("Warning: QwenImageEditPlusPipeline not found. UI will load, but generation will fail.")
18
- class QwenImageEditPlusPipeline:
19
- @classmethod
20
- def from_pretrained(cls, *args, **kwargs):
21
- return cls()
22
- def to(self, device): return self
23
- def load_lora_weights(self, *args, **kwargs): pass
24
- def set_adapters(self, *args, **kwargs): pass
25
- def __call__(self, *args, **kwargs):
26
- class Result: images = [Image.new("RGB", (512, 512), color="gray")]
27
- return Result()
28
 
 
29
  MAX_SEED = np.iinfo(np.int32).max
30
-
31
- # --- Model Loading ---
32
  dtype = torch.bfloat16
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
 
35
- try:
36
- pipe = QwenImageEditPlusPipeline.from_pretrained(
37
- "Qwen/Qwen-Image-Edit-2511",
38
- torch_dtype=dtype
39
- ).to(device)
40
-
41
- # Load the lightning LoRA for fast inference
42
- pipe.load_lora_weights(
43
- "lightx2v/Qwen-Image-Edit-2511-Lightning",
44
- weight_name="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors",
45
- adapter_name="lightning"
46
- )
47
 
48
- # Load the Lighting LoRA
49
- pipe.load_lora_weights(
50
- "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
51
- weight_name="qwen-edit-2509-multi-angle-lighting.safetensors",
52
- adapter_name="lighting"
53
- )
54
 
55
- pipe.set_adapters(["lightning", "lighting"], adapter_weights=[1.0, 1.0])
56
- except Exception as e:
57
- print(f"Model loading failed (ignorable if just testing UI): {e}")
58
- pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # --- Prompt Building ---
61
 
62
- # Horizontal mappings (Azimuth)
63
- AZIMUTH_MAP = {
 
64
  0: "Light source from the Front",
65
  45: "Light source from the Right Front",
66
  90: "Light source from the Right",
@@ -71,27 +105,44 @@ AZIMUTH_MAP = {
71
  315: "Light source from the Left Front"
72
  }
73
 
74
- def snap_to_nearest(value, options):
75
- return min(options, key=lambda x: abs(x - value))
76
-
77
- def build_lighting_prompt(azimuth: float, elevation: float) -> str:
78
  """
79
- Constructs the prompt based on Azimuth (horizontal) and Elevation (vertical).
80
- Priority: If elevation is extreme (Above/Below), that takes precedence.
81
- Otherwise, use horizontal direction.
82
  """
83
- # Vertical Thresholds
84
- if elevation >= 45:
85
- return "<sks> Light source from Above"
86
- if elevation <= -45:
87
- return "<sks> Light source from Below"
88
-
89
- # Horizontal Logic
90
- # Normalize azimuth to 0-360
91
- azimuth = azimuth % 360
92
- az_snap = snap_to_nearest(azimuth, list(AZIMUTH_MAP.keys()))
 
 
93
 
94
- return f"<sks> {AZIMUTH_MAP[az_snap]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  @spaces.GPU
97
  def infer_lighting_edit(
@@ -100,26 +151,29 @@ def infer_lighting_edit(
100
  elevation: float = 0.0,
101
  seed: int = 0,
102
  randomize_seed: bool = True,
103
- guidance_scale: float = 1.0,
104
- num_inference_steps: int = 4,
105
  height: int = 1024,
106
  width: int = 1024,
107
  ):
108
- if pipe is None:
109
- raise gr.Error("Model not initialized.")
110
-
111
- prompt = build_lighting_prompt(azimuth, elevation)
112
- print(f"Generated Prompt: {prompt}")
 
 
 
 
113
 
 
114
  if randomize_seed:
115
  seed = random.randint(0, MAX_SEED)
116
  generator = torch.Generator(device=device).manual_seed(seed)
117
-
118
- if image is None:
119
- raise gr.Error("Please upload an image first.")
120
-
121
  pil_image = image.convert("RGB")
122
 
 
123
  result = pipe(
124
  image=[pil_image],
125
  prompt=prompt,
@@ -133,10 +187,11 @@ def infer_lighting_edit(
133
 
134
  return result, seed, prompt
135
 
 
 
136
  def update_dimensions_on_upload(image):
137
  if image is None: return 1024, 1024
138
  w, h = image.size
139
- # Resize logic to keep aspect ratio but snap to multiples of 8 within reasonable bounds
140
  if w > h:
141
  new_w, new_h = 1024, int(1024 * (h / w))
142
  else:
@@ -150,402 +205,297 @@ def get_image_base64(image):
150
  img_str = base64.b64encode(buffered.getvalue()).decode()
151
  return f"data:image/png;base64,{img_str}"
152
 
153
- # --- 3D Lighting Control HTML Logic ---
154
  THREE_JS_LOGIC = """
155
- <div id="light-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #111; border-radius: 12px; overflow: hidden;">
156
- <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 14px; color: #ffcc00; white-space: nowrap; z-index: 10; border: 1px solid #ffcc00;">Initializing...</div>
157
- <div style="position: absolute; top: 10px; left: 10px; color: #666; font-family: sans-serif; font-size: 11px;">Drag the Yellow Orb to move light</div>
158
  </div>
159
  <script>
160
  (function() {
161
  const wrapper = document.getElementById('light-control-wrapper');
162
  const promptOverlay = document.getElementById('prompt-overlay');
163
 
164
- // Global Access for Python Bridge
165
- window.light3D = {
166
- updateState: null,
167
- updateTexture: null
168
- };
169
 
170
  const initScene = () => {
171
- if (typeof THREE === 'undefined') {
172
- setTimeout(initScene, 100);
173
- return;
174
- }
175
 
176
- // --- Setup ---
177
  const scene = new THREE.Scene();
178
- scene.background = new THREE.Color(0x111111);
179
 
180
- // Static Camera looking at the scene
181
- const camera = new THREE.PerspectiveCamera(45, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
182
- camera.position.set(0, 1.5, 5); // Slightly elevated front view
183
  camera.lookAt(0, 0, 0);
184
 
185
  const renderer = new THREE.WebGLRenderer({ antialias: true });
186
  renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
187
- renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
188
- renderer.shadowMap.enabled = true; // Enable shadows for visual feedback
189
- renderer.shadowMap.type = THREE.PCFSoftShadowMap;
190
  wrapper.appendChild(renderer.domElement);
191
 
192
- // --- Helpers ---
193
- scene.add(new THREE.GridHelper(8, 16, 0x333333, 0x222222));
194
-
195
- // --- Objects ---
196
- const CENTER = new THREE.Vector3(0, 0, 0);
197
- const ORBIT_RADIUS = 2.5;
198
 
199
- // 1. The Subject (Central Image Plane + Sphere for shading ref)
200
- const group = new THREE.Group();
201
- scene.add(group);
202
-
203
- // Placeholder Texture
204
  function createPlaceholderTexture() {
205
- const canvas = document.createElement('canvas');
206
- canvas.width = 256; canvas.height = 256;
207
- const ctx = canvas.getContext('2d');
208
- ctx.fillStyle = '#222'; ctx.fillRect(0, 0, 256, 256);
209
- ctx.fillStyle = '#444';
210
- ctx.font = '30px Arial'; ctx.textAlign = 'center'; ctx.textBaseline = 'middle';
211
- ctx.fillText("Upload Image", 128, 128);
212
- return new THREE.CanvasTexture(canvas);
213
  }
214
 
215
- let planeMaterial = new THREE.MeshStandardMaterial({
216
  map: createPlaceholderTexture(),
217
  side: THREE.DoubleSide,
218
- roughness: 0.8,
219
  metalness: 0.1
220
  });
 
 
 
221
 
222
- let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.5, 1.5), planeMaterial);
223
- targetPlane.castShadow = true;
224
- targetPlane.receiveShadow = true;
225
- group.add(targetPlane);
226
-
227
- // Reference Sphere (Hidden behind plane usually, or useful for seeing pure shading)
228
- const refSphere = new THREE.Mesh(
229
- new THREE.SphereGeometry(0.5, 32, 32),
230
- new THREE.MeshStandardMaterial({ color: 0xffffff, roughness: 1.0 })
231
- );
232
- refSphere.position.z = -0.5;
233
- refSphere.castShadow = true;
234
- group.add(refSphere);
235
-
236
- // 2. The Light Source (The "Sun")
237
- const lightGroup = new THREE.Group();
238
- scene.add(lightGroup);
239
-
240
- // Actual Light
241
- const dirLight = new THREE.DirectionalLight(0xffffff, 2.0);
242
- dirLight.castShadow = true;
243
- dirLight.shadow.mapSize.width = 1024;
244
- dirLight.shadow.mapSize.height = 1024;
245
- lightGroup.add(dirLight);
246
-
247
- // Visual Representation (Yellow Orb)
248
- const lightMesh = new THREE.Mesh(
249
- new THREE.SphereGeometry(0.2, 16, 16),
250
- new THREE.MeshBasicMaterial({ color: 0xffcc00 })
251
- );
252
- // Add glow
253
- const glow = new THREE.Mesh(
254
- new THREE.SphereGeometry(0.3, 16, 16),
255
- new THREE.MeshBasicMaterial({ color: 0xffcc00, transparent: true, opacity: 0.3 })
256
- );
257
- lightMesh.add(glow);
258
- lightMesh.userData.type = 'lightSource';
259
- lightGroup.add(lightMesh);
260
-
261
- // Ambient light to fill shadows slightly
262
  scene.add(new THREE.AmbientLight(0xffffff, 0.2));
263
 
264
- // --- State ---
265
- let azimuthAngle = 0; // 0 = Front, 90 = Right, 180 = Back
266
- let elevationAngle = 0; // 90 = Top, -90 = Bottom
267
-
268
- // --- Prompt Mapping Logic (JS Side for preview) ---
269
- const azMap = {
270
- 0: "Front", 45: "Right Front", 90: "Right", 135: "Right Rear",
271
- 180: "Rear", 225: "Left Rear", 270: "Left", 315: "Left Front"
272
- };
273
- const azSteps = [0, 45, 90, 135, 180, 225, 270, 315];
274
 
275
- function snapToNearest(value, steps) {
276
- let norm = value % 360;
277
- if (norm < 0) norm += 360;
278
- return steps.reduce((prev, curr) => Math.abs(curr - norm) < Math.abs(prev - norm) ? curr : prev);
279
- }
280
 
281
- function updatePositions() {
282
- // Convert Azimuth/Elevation to spherical coordinates
283
- // In ThreeJS: Y is Up. 0 Azimuth should be +Z (Front)
284
- // But usually Front is +Z. Let's calculate standard spherical.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
- const rAz = THREE.MathUtils.degToRad(azimuthAngle);
287
- const rEl = THREE.MathUtils.degToRad(elevationAngle);
288
-
289
- // Calculate position on sphere
290
- // x = r * sin(az) * cos(el)
291
  // y = r * sin(el)
292
- // z = r * cos(az) * cos(el)
 
293
 
294
- const x = ORBIT_RADIUS * Math.sin(rAz) * Math.cos(rEl);
295
- const y = ORBIT_RADIUS * Math.sin(rEl);
296
- const z = ORBIT_RADIUS * Math.cos(rAz) * Math.cos(rEl);
297
 
298
  lightGroup.position.set(x, y, z);
299
- lightGroup.lookAt(CENTER); // Light points to center
300
-
301
- // Update UI Text
302
  let text = "";
303
- if (elevationAngle >= 45) text = "Light source from Above";
304
- else if (elevationAngle <= -45) text = "Light source from Below";
305
  else {
306
- const snap = snapToNearest(azimuthAngle, azSteps);
307
- text = "Light source from the " + azMap[snap];
 
 
 
308
  }
309
- promptOverlay.innerText = text;
310
  }
311
 
312
  // --- Interaction ---
313
- const raycaster = new THREE.Raycaster();
314
- const mouse = new THREE.Vector2();
315
  let isDragging = false;
316
-
317
  const canvas = renderer.domElement;
318
 
319
- function getMouse(e) {
320
- const rect = canvas.getBoundingClientRect();
321
- return {
322
- x: ((e.clientX - rect.left) / rect.width) * 2 - 1,
323
- y: -((e.clientY - rect.top) / rect.height) * 2 + 1
324
- };
325
- }
326
-
327
- canvas.addEventListener('mousedown', (e) => {
328
- const m = getMouse(e);
329
- mouse.set(m.x, m.y);
330
- raycaster.setFromCamera(mouse, camera);
331
-
332
- // Allow clicking anywhere to move light, or specifically the orb
333
- // To make it easy, let's just project mouse to a virtual sphere
334
- isDragging = true;
335
- handleDrag(e);
336
- });
337
-
338
- function handleDrag(e) {
339
- if (!isDragging) return;
340
-
341
- const m = getMouse(e);
342
- mouse.set(m.x, m.y);
343
-
344
- // Logic: Raycast to a virtual sphere at center
345
- // Or simpler: Map mouse X to Azimuth, Mouse Y to Elevation
346
- // Let's use Mouse movement to delta
347
-
348
- // Robust approach: Project mouse onto a virtual sphere
349
- // But simpler UI: Mouse X = Rotation, Mouse Y = Elevation
350
- // This feels like "OrbitControls" but for the light
351
- }
352
-
353
- // Let's use a simpler drag logic: standard delta movement
354
- let previousMouse = { x: 0, y: 0 };
355
-
356
- canvas.addEventListener('mousedown', (e) => {
357
- isDragging = true;
358
- previousMouse = { x: e.clientX, y: e.clientY };
359
-
360
- // Check if clicked on orb (visual feedback)
361
- const m = getMouse(e);
362
- mouse.set(m.x, m.y);
363
- raycaster.setFromCamera(mouse, camera);
364
- const intersects = raycaster.intersectObject(lightMesh);
365
- if(intersects.length > 0) {
366
- lightMesh.scale.setScalar(1.2);
367
- }
368
- canvas.style.cursor = 'grabbing';
369
- });
370
-
371
- window.addEventListener('mousemove', (e) => {
372
- if (isDragging) {
373
- const deltaX = e.clientX - previousMouse.x;
374
- const deltaY = e.clientY - previousMouse.y;
375
- previousMouse = { x: e.clientX, y: e.clientY };
376
-
377
- // Adjust sensitivity
378
- azimuthAngle -= deltaX * 0.5;
379
- elevationAngle += deltaY * 0.5;
380
-
381
- // Clamp Elevation
382
- elevationAngle = Math.max(-89, Math.min(89, elevationAngle));
383
-
384
- updatePositions();
385
- } else {
386
- // Hover effect
387
- const m = getMouse(e);
388
- mouse.set(m.x, m.y);
389
- raycaster.setFromCamera(mouse, camera);
390
- const intersects = raycaster.intersectObject(lightMesh);
391
- canvas.style.cursor = intersects.length > 0 ? 'grab' : 'default';
392
- }
393
- });
394
-
395
- window.addEventListener('mouseup', () => {
396
- if (isDragging) {
397
- isDragging = false;
398
- lightMesh.scale.setScalar(1.0);
399
  canvas.style.cursor = 'default';
 
 
 
 
 
 
400
 
401
- // Snap for the bridge output, but keep visual smooth?
402
- // No, let's output exact values, python snaps them.
 
 
403
 
404
  // Send to Python
405
- // Normalize Azimuth for output
406
- let outAz = azimuthAngle % 360;
407
- if(outAz < 0) outAz += 360;
408
-
409
- const data = { azimuth: outAz, elevation: elevationAngle };
410
-
411
  const bridge = document.querySelector("#bridge-output textarea");
412
- if (bridge) {
413
- bridge.value = JSON.stringify(data);
414
  bridge.dispatchEvent(new Event("input", { bubbles: true }));
415
  }
416
  }
417
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
- // --- Render Loop ---
420
  function animate() {
421
  requestAnimationFrame(animate);
 
 
 
422
  renderer.render(scene, camera);
423
  }
424
  animate();
425
- updatePositions();
426
 
427
- // --- Exposed Methods for Python ---
428
  window.light3D.updateState = (data) => {
429
- if (!data) return;
430
- if (typeof data === 'string') data = JSON.parse(data);
431
- azimuthAngle = data.azimuth !== undefined ? data.azimuth : azimuthAngle;
432
- elevationAngle = data.elevation !== undefined ? data.elevation : elevationAngle;
433
- updatePositions();
 
434
  };
435
 
436
  window.light3D.updateTexture = (url) => {
437
- if (!url) {
438
- planeMaterial.map = createPlaceholderTexture();
439
- planeMaterial.needsUpdate = true;
440
- return;
441
- }
442
  new THREE.TextureLoader().load(url, (tex) => {
443
  tex.colorSpace = THREE.SRGBColorSpace;
444
- tex.minFilter = THREE.LinearFilter;
445
- planeMaterial.map = tex;
446
 
447
  const img = tex.image;
448
  const aspect = img.width / img.height;
449
- const scale = 1.5;
450
  if (aspect > 1) targetPlane.scale.set(scale, scale / aspect, 1);
451
  else targetPlane.scale.set(scale * aspect, scale, 1);
452
 
453
- planeMaterial.needsUpdate = true;
454
  });
455
  };
456
  };
457
-
458
  initScene();
459
  })();
460
  </script>
461
  """
462
 
463
- # --- UI Setup ---
464
  css = """
465
  #col-container { max-width: 1200px; margin: 0 auto; }
466
- #light-control-wrapper { box-shadow: 0 4px 12px rgba(255, 204, 0, 0.2); border: 1px solid #333; }
467
  .gradio-container { overflow: visible !important; }
 
468
  """
469
 
470
  with gr.Blocks() as demo:
471
  gr.HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>')
472
 
473
- gr.Markdown("# 💡 Qwen Edit 2509 Multi-Angle Lighting Control")
474
- gr.Markdown("Control the **direction of the light source** using the 3D visualizer or sliders.")
475
-
476
  with gr.Row(elem_id="col-container"):
477
- # Left: Controls
478
  with gr.Column(scale=1):
479
- image = gr.Image(label="Input Image", type="pil", height=250)
480
 
481
- # The 3D Viewport
482
  gr.HTML(THREE_JS_LOGIC)
483
 
484
- # Hidden Bridges
485
- bridge_output = gr.Textbox(elem_id="bridge-output", visible=False, label="Bridge Output")
486
- bridge_input = gr.JSON(value={}, visible=False, label="Bridge Input")
487
 
488
  with gr.Group():
489
- gr.Markdown("### Light Position")
490
- azimuth_slider = gr.Slider(0, 360, step=45, label="Horizontal Direction (Azimuth)", value=0, info="0=Front, 90=Right, 180=Rear, 270=Left")
491
- elevation_slider = gr.Slider(-90, 90, step=15, label="Vertical Angle (Elevation)", value=0, info="+90=Above, -90=Below")
 
492
 
493
- run_btn = gr.Button("🚀 Relight Image", variant="primary", size="lg")
494
-
495
- prompt_preview = gr.Textbox(label="Generated Prompt", interactive=False, value="<sks> Light source from the Front")
496
 
497
- # Right: Result
498
  with gr.Column(scale=1):
499
- result = gr.Image(label="Output Image")
500
 
501
- with gr.Accordion("Advanced", open=False):
502
- seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
503
  randomize_seed = gr.Checkbox(True, label="Randomize Seed")
504
- guidance_scale = gr.Slider(1, 10, 1.0, step=0.1, label="Guidance")
505
- steps = gr.Slider(1, 50, 4, step=1, label="Steps")
506
  width = gr.Slider(256, 2048, 1024, step=8, label="Width")
507
  height = gr.Slider(256, 2048, 1024, step=8, label="Height")
508
 
509
- # --- Event Wiring ---
510
-
511
- # 1. Helper to sync Textbox (Prompt)
512
- def update_prompt(az, el):
513
- return build_lighting_prompt(az, el)
514
-
515
- # 2. Image Upload
516
- def handle_image_upload(img):
517
  w, h = update_dimensions_on_upload(img)
518
- b64 = get_image_base64(img)
519
- return w, h, b64
520
-
521
- image.upload(handle_image_upload, inputs=[image], outputs=[width, height, bridge_input]) \
522
- .then(None, [image], None, js="(img) => { if(img) window.light3D.updateTexture(img); }")
523
 
524
- # 3. Sliders -> Update Bridge -> Update 3D
525
- def sync_sliders_to_bridge(az, el):
526
- return {"azimuth": az, "elevation": el}
527
 
 
 
 
528
  for s in [azimuth_slider, elevation_slider]:
529
- s.change(sync_sliders_to_bridge, [azimuth_slider, elevation_slider], bridge_input) \
530
  .then(update_prompt, [azimuth_slider, elevation_slider], prompt_preview)
 
 
531
 
532
- # Trigger JS update when bridge_input changes
533
- bridge_input.change(None, [bridge_input], None, js="(val) => window.light3D.updateState(val)")
534
-
535
- # 4. 3D Interaction (Bridge Output) -> Update Sliders
536
- def sync_bridge_to_sliders(data_str):
537
  try:
538
- data = json.loads(data_str)
539
- return data.get('azimuth', 0), data.get('elevation', 0)
540
- except:
541
- return 0, 0
542
-
543
- bridge_output.change(sync_bridge_to_sliders, bridge_output, [azimuth_slider, elevation_slider])
544
 
545
- # 5. Generation
546
  run_btn.click(
547
  infer_lighting_edit,
548
- inputs=[image, azimuth_slider, elevation_slider, seed, randomize_seed, guidance_scale, steps, height, width],
549
  outputs=[result, seed, prompt_preview]
550
  )
551
 
 
7
  import base64
8
  from io import BytesIO
9
  from PIL import Image
10
+ import os
11
 
12
+ # --- Imports (Custom Structure) ---
 
13
  try:
14
+ from diffusers import FlowMatchEulerDiscreteScheduler
15
+ # Assuming these modules exist in your environment as requested
16
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
17
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
18
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
19
  except ImportError:
20
+ print("⚠️ Custom modules (qwenimage) not found. Using standard Diffusers classes for UI testing.")
21
+ from diffusers import Qwen2VLForConditionalGeneration as QwenImageEditPlusPipeline # Fallback
22
+ from diffusers import Transformer2DModel as QwenImageTransformer2DModel # Fallback
23
+ # Dummy class for processor if missing
24
+ class QwenDoubleStreamAttnProcessorFA3: pass
 
 
 
 
 
 
 
25
 
26
+ # --- Configuration ---
27
  MAX_SEED = np.iinfo(np.int32).max
 
 
28
  dtype = torch.bfloat16
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
+ ADAPTER_SPECS = {
32
+ "Multi-Angle-Lighting": {
33
+ "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
34
+ "weights": "多角度灯光-251116.safetensors",
35
+ "adapter_name": "multi-angle-lighting"
36
+ }
37
+ }
 
 
 
 
 
38
 
39
+ # --- Model Loading ---
40
+ # Global variable for the pipe
41
+ pipe = None
 
 
 
42
 
43
+ def initialize_model():
44
+ global pipe
45
+ print(" Initializing Base Model...")
46
+
47
+ try:
48
+ # Load Transformer first
49
+ transformer = QwenImageTransformer2DModel.from_pretrained(
50
+ "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
51
+ torch_dtype=dtype,
52
+ device_map='cuda'
53
+ )
54
+
55
+ # Load Pipeline with injected transformer
56
+ pipe = QwenImageEditPlusPipeline.from_pretrained(
57
+ "Qwen/Qwen-Image-Edit-2511",
58
+ transformer=transformer,
59
+ torch_dtype=dtype
60
+ ).to(device)
61
+
62
+ # Set Processor
63
+ try:
64
+ pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
65
+ print("✅ Flash Attention 3 Processor set successfully.")
66
+ except Exception as e:
67
+ print(f"⚠️ Warning: Could not set FA3 processor: {e}")
68
+
69
+ except Exception as e:
70
+ print(f"❌ Model Loading Failed: {e}")
71
+ # Create dummy pipe for UI testing if GPU/Weights missing
72
+ class DummyPipe:
73
+ def __call__(self, *args, **kwargs):
74
+ class Res: images = [Image.new("RGB", (1024, 1024), "gray")]
75
+ return Res()
76
+ def load_lora_weights(self, *args, **kwargs): pass
77
+ def set_adapters(self, *args, **kwargs): pass
78
+ def get_active_adapters(self): return []
79
+ pipe = DummyPipe()
80
+
81
+ # Initialize base model immediately (or lazy load this too if preferred, but usually base model loads on startup)
82
+ if torch.cuda.is_available():
83
+ initialize_model()
84
+ else:
85
+ print("⚠️ CUDA not available. Skipping model load for UI rendering.")
86
+ class DummyPipe:
87
+ def __call__(self, *args, **kwargs): return type('obj', (object,), {'images': [Image.new("RGB", (512, 512), "black")]})
88
+ def load_lora_weights(self, *args, **kwargs): pass
89
+ def set_adapters(self, *args, **kwargs): pass
90
+ def get_active_adapters(self): return []
91
+ pipe = DummyPipe()
92
 
93
+ # --- Prompt Logic ---
94
 
95
+ # Mappings based on the requested list
96
+ # Azimuth: 0=Front, 90=Right, 180=Rear, 270=Left
97
+ HORIZONTAL_MAP = {
98
  0: "Light source from the Front",
99
  45: "Light source from the Right Front",
100
  90: "Light source from the Right",
 
105
  315: "Light source from the Left Front"
106
  }
107
 
108
+ def get_lighting_prompt(azimuth: float, elevation: float) -> str:
 
 
 
109
  """
110
+ Determines the prompt based on azimuth (0-360) and elevation (-90 to 90).
111
+ Prioritizes Above/Below if elevation is significant.
 
112
  """
113
+ # 1. Check Vertical Extremes (Elevation)
114
+ # If elevation is > 45 degrees, treat as "Above"
115
+ if elevation > 45:
116
+ return "Light source from Above"
117
+ # If elevation is < -45 degrees, treat as "Below"
118
+ if elevation < -45:
119
+ return "Light source from Below"
120
+
121
+ # 2. Check Horizontal (Azimuth)
122
+ # Snap to nearest 45 degree increment
123
+ az_options = list(HORIZONTAL_MAP.keys())
124
+ closest_az = min(az_options, key=lambda x: abs(x - azimuth))
125
 
126
+ return HORIZONTAL_MAP[closest_az]
127
+
128
+ # --- Inference ---
129
+
130
+ def load_lora_lazy():
131
+ """Checks if LoRA is loaded, if not, loads it."""
132
+ spec = ADAPTER_SPECS["Multi-Angle-Lighting"]
133
+ try:
134
+ active = pipe.get_active_adapters()
135
+ if spec["adapter_name"] not in active:
136
+ print(f"♻️ Lazy Loading LoRA: {spec['repo']}...")
137
+ pipe.load_lora_weights(
138
+ spec["repo"],
139
+ weight_name=spec["weights"],
140
+ adapter_name=spec["adapter_name"]
141
+ )
142
+ pipe.set_adapters([spec["adapter_name"]], adapter_weights=[1.0])
143
+ print("✅ LoRA Loaded.")
144
+ except Exception as e:
145
+ print(f"⚠️ LoRA Load Error: {e}")
146
 
147
  @spaces.GPU
148
  def infer_lighting_edit(
 
151
  elevation: float = 0.0,
152
  seed: int = 0,
153
  randomize_seed: bool = True,
154
+ guidance_scale: float = 3.5, # Slightly higher for lighting typically
155
+ num_inference_steps: int = 20,
156
  height: int = 1024,
157
  width: int = 1024,
158
  ):
159
+ if image is None:
160
+ raise gr.Error("Please upload an image first.")
161
+
162
+ # 1. Lazy Load LoRA
163
+ load_lora_lazy()
164
+
165
+ # 2. Build Prompt
166
+ prompt = get_lighting_prompt(azimuth, elevation)
167
+ print(f"💡 Generated Prompt: {prompt}")
168
 
169
+ # 3. Prepare Inputs
170
  if randomize_seed:
171
  seed = random.randint(0, MAX_SEED)
172
  generator = torch.Generator(device=device).manual_seed(seed)
173
+
 
 
 
174
  pil_image = image.convert("RGB")
175
 
176
+ # 4. Generate
177
  result = pipe(
178
  image=[pil_image],
179
  prompt=prompt,
 
187
 
188
  return result, seed, prompt
189
 
190
+ # --- Helpers ---
191
+
192
  def update_dimensions_on_upload(image):
193
  if image is None: return 1024, 1024
194
  w, h = image.size
 
195
  if w > h:
196
  new_w, new_h = 1024, int(1024 * (h / w))
197
  else:
 
205
  img_str = base64.b64encode(buffered.getvalue()).decode()
206
  return f"data:image/png;base64,{img_str}"
207
 
208
+ # --- 3D Lighting Controller (Three.js) ---
209
  THREE_JS_LOGIC = """
210
+ <div id="light-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #0f0f0f; border-radius: 12px; overflow: hidden; box-shadow: inset 0 0 20px #000;">
211
+ <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: sans-serif; font-size: 14px; color: #fbff00; white-space: nowrap; z-index: 10; border: 1px solid #fbff00;">Initializing...</div>
 
212
  </div>
213
  <script>
214
  (function() {
215
  const wrapper = document.getElementById('light-control-wrapper');
216
  const promptOverlay = document.getElementById('prompt-overlay');
217
 
218
+ window.light3D = { updateState: null, updateTexture: null };
 
 
 
 
219
 
220
  const initScene = () => {
221
+ if (typeof THREE === 'undefined') { setTimeout(initScene, 100); return; }
 
 
 
222
 
223
+ // Setup
224
  const scene = new THREE.Scene();
225
+ scene.background = new THREE.Color(0x0f0f0f);
226
 
227
+ const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
228
+ camera.position.set(4.0, 2.5, 4.0);
 
229
  camera.lookAt(0, 0, 0);
230
 
231
  const renderer = new THREE.WebGLRenderer({ antialias: true });
232
  renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
 
 
 
233
  wrapper.appendChild(renderer.domElement);
234
 
235
+ // --- Environment ---
236
+ const grid = new THREE.GridHelper(6, 12, 0x333333, 0x111111);
237
+ scene.add(grid);
 
 
 
238
 
239
+ // --- Image Plane (The Subject) ---
 
 
 
 
240
  function createPlaceholderTexture() {
241
+ const cvs = document.createElement('canvas');
242
+ cvs.width = 256; cvs.height = 256;
243
+ const ctx = cvs.getContext('2d');
244
+ ctx.fillStyle = '#222'; ctx.fillRect(0,0,256,256);
245
+ ctx.strokeStyle = '#444'; ctx.lineWidth=5; ctx.strokeRect(20,20,216,216);
246
+ ctx.font = '30px Arial'; ctx.fillStyle='#555'; ctx.textAlign='center';
247
+ ctx.fillText('SUBJECT', 128, 138);
248
+ return new THREE.CanvasTexture(cvs);
249
  }
250
 
251
+ const planeMat = new THREE.MeshStandardMaterial({
252
  map: createPlaceholderTexture(),
253
  side: THREE.DoubleSide,
254
+ roughness: 0.5,
255
  metalness: 0.1
256
  });
257
+ const targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.5, 1.5), planeMat);
258
+ targetPlane.position.y = 0.75;
259
+ scene.add(targetPlane);
260
 
261
+ // Base Ambient Light (Dim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  scene.add(new THREE.AmbientLight(0xffffff, 0.2));
263
 
264
+ // --- The Controlled Light Source ---
265
+ const RADIUS = 2.2;
266
+ let azimuth = 0; // 0 - 360
267
+ let elevation = 0; // -90 - 90
 
 
 
 
 
 
268
 
269
+ const lightGroup = new THREE.Group();
270
+ scene.add(lightGroup);
 
 
 
271
 
272
+ // The Physical Bulb Representation
273
+ const bulbGeo = new THREE.SphereGeometry(0.15, 32, 32);
274
+ const bulbMat = new THREE.MeshBasicMaterial({ color: 0xffaa00 });
275
+ const bulb = new THREE.Mesh(bulbGeo, bulbMat);
276
+ lightGroup.add(bulb);
277
+
278
+ // The Glow Halo
279
+ const glowGeo = new THREE.SphereGeometry(0.25, 32, 32);
280
+ const glowMat = new THREE.MeshBasicMaterial({ color: 0xffaa00, transparent: true, opacity: 0.3 });
281
+ const glow = new THREE.Mesh(glowGeo, glowMat);
282
+ lightGroup.add(glow);
283
+
284
+ // The Actual Light Caster
285
+ const spotLight = new THREE.PointLight(0xffaa00, 2, 10);
286
+ lightGroup.add(spotLight);
287
+
288
+ // Visual Guides (Orbit Rings)
289
+ const azRing = new THREE.Mesh(new THREE.TorusGeometry(RADIUS, 0.02, 16, 64), new THREE.MeshBasicMaterial({ color: 0x444444, transparent:true, opacity:0.5 }));
290
+ azRing.rotation.x = Math.PI/2;
291
+ azRing.position.y = 0.75; // Center of image
292
+ scene.add(azRing);
293
+
294
+ const elRing = new THREE.Mesh(new THREE.TorusGeometry(RADIUS, 0.02, 16, 64), new THREE.MeshBasicMaterial({ color: 0x444444, transparent:true, opacity:0.5 }));
295
+ elRing.rotation.y = Math.PI/2;
296
+ elRing.position.y = 0.75;
297
+ scene.add(elRing);
298
+
299
+ // --- Logic ---
300
+ function updatePosition() {
301
+ const rAz = THREE.MathUtils.degToRad(azimuth);
302
+ const rEl = THREE.MathUtils.degToRad(elevation);
303
 
304
+ // Standard Spherical: Y is up.
305
+ // x = r * cos(el) * sin(az)
 
 
 
306
  // y = r * sin(el)
307
+ // z = r * cos(el) * cos(az)
308
+ // Offset Y by center height (0.75)
309
 
310
+ const x = RADIUS * Math.cos(rEl) * Math.sin(rAz);
311
+ const y = (RADIUS * Math.sin(rEl)) + 0.75;
312
+ const z = RADIUS * Math.cos(rEl) * Math.cos(rAz);
313
 
314
  lightGroup.position.set(x, y, z);
315
+
316
+ // Text Update
 
317
  let text = "";
318
+ if (elevation > 45) text = "Above";
319
+ else if (elevation < -45) text = "Below";
320
  else {
321
+ // Snap Azimuth
322
+ const snaps = {0:'Front', 45:'Right Front', 90:'Right', 135:'Right Rear', 180:'Rear', 225:'Left Rear', 270:'Left', 315:'Left Front'};
323
+ const snapKeys = Object.keys(snaps).map(Number);
324
+ const closest = snapKeys.reduce((p, c) => Math.abs(c - azimuth) < Math.abs(p - azimuth) ? c : p);
325
+ text = snaps[closest];
326
  }
327
+ promptOverlay.innerText = `Light Source: ${text}`;
328
  }
329
 
330
// --- Interaction ---
// Drag-to-orbit for the light source, with snap-on-release. The snapped
// state is pushed to Python through the hidden #bridge-output textbox.
let isDragging = false;
const canvas = renderer.domElement;

// Simple orbit-control logic for the light.
canvas.addEventListener('mousedown', () => { isDragging = true; canvas.style.cursor = 'grabbing'; });

window.addEventListener('mouseup', () => {
    if (isDragging) {
        isDragging = false;
        canvas.style.cursor = 'default';

        // Snap on release: azimuth to 45° increments, elevation to {-60, 0, 60}.
        const snapAz = Math.round(azimuth / 45) * 45;
        let snapEl = 0;
        if (elevation > 45) snapEl = 60;        // bias to high
        else if (elevation < -45) snapEl = -60;

        // Animate snap (simplified): jump directly to the snapped angles.
        azimuth = (snapAz % 360 + 360) % 360;   // normalize 360 -> 0
        elevation = snapEl;
        updatePosition();

        // Send to Python via the hidden bridge textbox (input event triggers
        // the Gradio .change handler).
        const bridge = document.querySelector("#bridge-output textarea");
        if (bridge) {
            bridge.value = JSON.stringify({ azimuth: azimuth, elevation: elevation });
            bridge.dispatchEvent(new Event("input", { bubbles: true }));
        }
    }
});

canvas.addEventListener('mousemove', (e) => {
    if (isDragging) {
        // FIX: dropped the unused getBoundingClientRect() call —
        // movementX/movementY are already per-event deltas.
        azimuth -= e.movementX * 0.5;
        if (azimuth < 0) azimuth += 360;
        if (azimuth >= 360) azimuth -= 360;

        elevation += e.movementY * 0.5;
        elevation = Math.max(-80, Math.min(80, elevation));

        updatePosition();
    }
});
377
 
378
// Render loop: gently pulse the bulb's glow sphere and redraw each frame.
const PULSE_SPEED = 0.002;
function animate() {
    requestAnimationFrame(animate);
    const phase = Math.sin(Date.now() * PULSE_SPEED);
    glow.scale.setScalar(1.0 + 0.1 * phase);
    renderer.render(scene, camera);
}
animate();
updatePosition(); // paint the initial light position and label
388
 
389
// --- External Interface ---
// Accepts either a JSON string or a plain object {azimuth, elevation};
// missing fields keep their current values.
window.light3D.updateState = (data) => {
    const state = (typeof data === 'string') ? JSON.parse(data) : data;
    if (!state) return;
    azimuth = state.azimuth ?? azimuth;
    elevation = state.elevation ?? elevation;
    updatePosition();
};
398
 
399
// Swap the target plane's texture for a newly uploaded image (data URL),
// resizing the plane so the longer image edge spans 1.5 world units.
window.light3D.updateTexture = (url) => {
    if (!url) return;
    const loader = new THREE.TextureLoader();
    loader.load(url, (tex) => {
        tex.colorSpace = THREE.SRGBColorSpace;
        planeMat.map = tex;

        const { width, height } = tex.image;
        const aspect = width / height;
        const scale = 1.5;
        const sx = aspect > 1 ? scale : scale * aspect;
        const sy = aspect > 1 ? scale / aspect : scale;
        targetPlane.scale.set(sx, sy, 1);

        planeMat.needsUpdate = true;
    });
};
  };
 
415
  initScene();
416
  })();
417
  </script>
418
  """
419
 
420
+ # --- UI Layout ---
421
  css = """
422
  #col-container { max-width: 1200px; margin: 0 auto; }
 
423
  .gradio-container { overflow: visible !important; }
424
+ #light-control-wrapper { cursor: grab; }
425
  """
426
 
427
  with gr.Blocks() as demo:
428
  gr.HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>')
429
 
430
+ gr.Markdown("# 🎮 3D Light Camera Control (Qwen-Edit-2509-Multi-Angle-Lighting)")
431
+ gr.Markdown("Control the **Light Source** direction by dragging the 3D scene or using sliders.")
432
+
433
  with gr.Row(elem_id="col-container"):
 
434
  with gr.Column(scale=1):
435
+ image = gr.Image(label="Input Image", type="pil", height=300)
436
 
437
+ # 3D Viewport
438
  gr.HTML(THREE_JS_LOGIC)
439
 
440
+ # Communication Bridges
441
+ bridge_output = gr.Textbox(elem_id="bridge-output", visible=False)
442
+ bridge_input = gr.JSON(value={}, visible=False)
443
 
444
  with gr.Group():
445
+ azimuth_slider = gr.Slider(0, 315, step=45, label="Light Azimuth (Horizontal)", value=0,
446
+ info="0=Front, 90=Right, 180=Rear, 270=Left")
447
+ elevation_slider = gr.Slider(-60, 60, step=30, label="Light Elevation (Vertical)", value=0,
448
+ info=">45 = Above, < -45 = Below")
449
 
450
+ run_btn = gr.Button("🚀 Generate Lighting", variant="primary", size="lg")
451
+ prompt_preview = gr.Textbox(label="Active Prompt", interactive=False)
 
452
 
 
453
  with gr.Column(scale=1):
454
+ result = gr.Image(label="Output Result")
455
 
456
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
457
+ seed = gr.Slider(0, MAX_SEED, value=0, step=1, label="Seed")
458
  randomize_seed = gr.Checkbox(True, label="Randomize Seed")
459
+ guidance = gr.Slider(1, 10, 3.5, step=0.1, label="Guidance Scale")
460
+ steps = gr.Slider(1, 50, 20, step=1, label="Inference Steps")
461
  width = gr.Slider(256, 2048, 1024, step=8, label="Width")
462
  height = gr.Slider(256, 2048, 1024, step=8, label="Height")
463
 
464
+ # --- Events ---
465
+
466
# 1. Prompt preview helper
def update_prompt(az, el):
    """Regenerate the lighting prompt text from the current angles."""
    return get_lighting_prompt(az, el)
468
+
469
# 2. Image upload -> suggested output size + base64 texture for the 3D view
def handle_upload(img):
    """Return (width, height, base64_data_url) for an uploaded image."""
    new_w, new_h = update_dimensions_on_upload(img)
    encoded = get_image_base64(img)
    return new_w, new_h, encoded
 
 
 
 
473
 
474
+ image.upload(handle_upload, inputs=image, outputs=[width, height, bridge_input]) \
475
+ .then(None, [image], None, js="(img) => { if(img) window.light3D.updateTexture(img); }")
 
476
 
477
# 3. Sliders -> Bridge (Input) -> 3D View
def slider_to_bridge(az, el):
    """Bundle the two slider values into the state dict the 3D view consumes."""
    return dict(azimuth=az, elevation=el)
479
+
480
  for s in [azimuth_slider, elevation_slider]:
481
+ s.change(slider_to_bridge, [azimuth_slider, elevation_slider], bridge_input) \
482
  .then(update_prompt, [azimuth_slider, elevation_slider], prompt_preview)
483
+
484
+ bridge_input.change(None, [bridge_input], None, js="(v) => window.light3D.updateState(v)")
485
 
486
# 4. 3D View (Bridge Output) -> Sliders
def bridge_to_slider(data_str):
    """Parse the JSON state pushed from the 3D view into slider values.

    Args:
        data_str: JSON string such as '{"azimuth": 90, "elevation": 30}'.

    Returns:
        (azimuth, elevation) tuple; (0, 0) on missing or malformed input.
    """
    try:
        d = json.loads(data_str)
        return d.get('azimuth', 0), d.get('elevation', 0)
    except (TypeError, ValueError, AttributeError):
        # FIX: was a bare `except:` (would also swallow KeyboardInterrupt).
        # TypeError: data_str is None / not str; ValueError: invalid JSON
        # (json.JSONDecodeError subclasses it); AttributeError: valid JSON
        # that is not an object (e.g. a list) has no .get().
        return 0, 0
492
+
493
+ bridge_output.change(bridge_to_slider, bridge_output, [azimuth_slider, elevation_slider])
 
494
 
495
+ # 5. Generate
496
  run_btn.click(
497
  infer_lighting_edit,
498
+ inputs=[image, azimuth_slider, elevation_slider, seed, randomize_seed, guidance, steps, height, width],
499
  outputs=[result, seed, prompt_preview]
500
  )
501