prithivMLmods commited on
Commit
ba87cb1
·
verified ·
1 Parent(s): 81bbf62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +513 -498
app.py CHANGED
@@ -3,43 +3,42 @@ import numpy as np
3
  import random
4
  import torch
5
  import spaces
6
- import json
7
- import base64
8
- from io import BytesIO
9
  from PIL import Image
10
 
11
- # --- Imports based on your request ---
 
12
  try:
13
- from diffusers import FlowMatchEulerDiscreteScheduler
14
- # Attempting to import the custom pipeline/model classes
15
- # Assuming these files exist in your PYTHONPATH or local directory
16
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
17
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
18
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
19
- except ImportError as e:
20
- print(f"⚠️ Import Error: {e}")
21
- print("Falling back to standard classes for UI demonstration (Generate will fail without local files).")
22
- # Mocks for UI testing if files are missing
23
- from diffusers import DiffusionPipeline
24
- class QwenImageEditPlusPipeline(DiffusionPipeline):
25
- @classmethod
26
- def from_pretrained(cls, *args, **kwargs): return cls()
27
- def to(self, device): return self
28
- def load_lora_weights(self, *args, **kwargs): pass
29
- def set_adapters(self, *args, **kwargs): pass
30
- def get_active_adapters(self): return []
31
- def __call__(self, *args, **kwargs):
32
- class R: images=[Image.new("RGB", (512,512), "gray")]
33
- return R()
34
- class QwenImageTransformer2DModel:
35
- @classmethod
36
- def from_pretrained(cls, *args, **kwargs): return cls()
37
-
38
- # --- Configuration ---
39
  MAX_SEED = np.iinfo(np.int32).max
 
 
40
  dtype = torch.bfloat16
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ADAPTER_SPECS = {
44
  "Multi-Angle-Lighting": {
45
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
@@ -48,147 +47,94 @@ ADAPTER_SPECS = {
48
  }
49
  }
50
 
51
- # --- Model Loading ---
52
- print("⏳ Initializing Model...")
53
- try:
54
- transformer_model = QwenImageTransformer2DModel.from_pretrained(
55
- "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
56
- torch_dtype=dtype,
57
- device_map=device
58
- )
59
-
60
- pipe = QwenImageEditPlusPipeline.from_pretrained(
61
- "Qwen/Qwen-Image-Edit-2511",
62
- transformer=transformer_model,
63
- torch_dtype=dtype
64
- ).to(device)
65
-
66
- try:
67
- pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
68
- print("✅ Flash Attention 3 Processor set successfully.")
69
- except Exception as e:
70
- print(f"⚠️ Warning: Could not set FA3 processor: {e}")
71
-
72
- except Exception as e:
73
- print(f"❌ Critical Model Load Error: {e}")
74
- pipe = None
75
 
76
- # --- Prompt Logic ---
 
 
77
 
78
- def get_lighting_prompt(azimuth: float, elevation: float) -> str:
79
  """
80
- Maps Azimuth (0-360) and Elevation (-90 to 90) to the specific prompt list.
 
 
 
 
81
  """
82
- # 1. Handle Vertical Extremes
83
- if elevation >= 55:
 
84
  return "Light source from Above"
85
- if elevation <= -55:
86
  return "Light source from Below"
 
 
 
 
87
 
88
- # 2. Handle Horizontal (Azimuth)
89
- # Normalize azimuth to 0-360
90
- az = azimuth % 360
91
-
92
- # 8-way split (45 degrees per sector).
93
- # Center of "Front" (0) is at 0. Sector is 337.5 to 22.5.
94
-
95
- if az >= 337.5 or az < 22.5:
96
- return "Light source from the Front"
97
- elif 22.5 <= az < 67.5:
98
- return "Light source from the Right Front"
99
- elif 67.5 <= az < 112.5:
100
- return "Light source from the Right"
101
- elif 112.5 <= az < 157.5:
102
- return "Light source from the Right Rear"
103
- elif 157.5 <= az < 202.5:
104
- return "Light source from the Rear"
105
- elif 202.5 <= az < 247.5:
106
- return "Light source from the Left Rear"
107
- elif 247.5 <= az < 292.5:
108
- return "Light source from the Left"
109
- elif 292.5 <= az < 337.5:
110
- return "Light source from the Left Front"
111
-
112
- return "Light source from the Front" # Fallback
113
-
114
- # --- Helper Functions ---
115
-
116
- def load_adapter_if_needed():
117
- """Lazy loads the LoRA only when generation is requested."""
118
- if pipe is None: return
119
 
 
120
  spec = ADAPTER_SPECS["Multi-Angle-Lighting"]
121
- active = pipe.get_active_adapters()
122
-
123
- if spec["adapter_name"] not in active:
124
- print(f"📥 Lazy Loading LoRA: {spec['repo']}...")
125
  pipe.load_lora_weights(
126
  spec["repo"],
127
  weight_name=spec["weights"],
128
  adapter_name=spec["adapter_name"]
129
  )
130
  pipe.set_adapters([spec["adapter_name"]], adapter_weights=[1.0])
131
- print("✅ LoRA Loaded and Active.")
132
-
133
- def update_dimensions_on_upload(image):
134
- if image is None: return 1024, 1024
135
- w, h = image.size
136
- # Aspect ratio logic to keep near 1024
137
- if w > h:
138
- new_w, new_h = 1024, int(1024 * (h / w))
139
- else:
140
- new_h, new_w = 1024, int(1024 * (w / h))
141
- # Snap to multiples of 8
142
- return (new_w // 8) * 8, (new_h // 8) * 8
143
-
144
- def get_image_base64(image):
145
- if image is None: return None
146
- buffered = BytesIO()
147
- image.save(buffered, format="PNG")
148
- img_str = base64.b64encode(buffered.getvalue()).decode()
149
- return f"data:image/png;base64,{img_str}"
150
-
151
- # --- Inference ---
152
-
153
- @spaces.GPU
154
- def infer_lighting(
155
- image: Image.Image,
156
- azimuth: float,
157
- elevation: float,
158
- seed: int,
159
- randomize_seed: bool,
160
- guidance_scale: float,
161
- steps: int,
162
- height: int,
163
- width: int,
164
- ):
165
- if pipe is None:
166
- raise gr.Error("Model failed to initialize. Check console logs.")
167
 
168
- if image is None:
169
- raise gr.Error("Please upload an image first.")
170
-
171
- # 1. Ensure Adapter is Loaded
172
- load_adapter_if_needed()
173
-
174
- # 2. Build Prompt
175
- prompt = get_lighting_prompt(azimuth, elevation)
176
- print(f"💡 Generated Prompt: {prompt}")
177
 
178
- # 3. Setup Generator
179
  if randomize_seed:
180
  seed = random.randint(0, MAX_SEED)
181
  generator = torch.Generator(device=device).manual_seed(seed)
182
-
183
- pil_image = image.convert("RGB")
184
 
185
- # 4. Run Pipeline
 
 
 
 
186
  result = pipe(
187
  image=[pil_image],
188
- prompt=f"<sks> {prompt}", # Assuming <sks> triggers the LoRA + prompt
189
- height=height,
190
- width=width,
191
- num_inference_steps=steps,
192
  generator=generator,
193
  guidance_scale=guidance_scale,
194
  num_images_per_prompt=1,
@@ -196,388 +142,457 @@ def infer_lighting(
196
 
197
  return result, seed, prompt
198
 
199
- # --- 3D HTML Logic (Updated for Light Controller) ---
200
- THREE_JS_LOGIC = """
201
- <div id="light-control-wrapper" style="width: 100%; height: 500px; position: relative; background: radial-gradient(circle at center, #2a2a2a 0%, #111 100%); border-radius: 12px; overflow: hidden; border: 1px solid #333;">
202
- <div id="prompt-overlay" style="position: absolute; top: 15px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.7); padding: 8px 16px; border-radius: 20px; font-family: sans-serif; font-size: 14px; font-weight: bold; color: #ffeb3b; border: 1px solid #ffeb3b; white-space: nowrap; z-index: 10; pointer-events: none;">
203
- Initializing Light Studio...
204
- </div>
205
- <div style="position: absolute; bottom: 10px; right: 10px; color: #666; font-family: sans-serif; font-size: 10px; pointer-events: none;">
206
- Drag the Sun Right Click to Pan
207
- </div>
208
- </div>
209
- <script>
210
- (function() {
211
- const wrapper = document.getElementById('light-control-wrapper');
212
- const promptOverlay = document.getElementById('prompt-overlay');
213
-
214
- window.lightControl = { updateState: null, updateTexture: null };
215
-
216
- const initScene = () => {
217
- if (typeof THREE === 'undefined') { setTimeout(initScene, 100); return; }
218
-
219
- // 1. Scene Setup
220
- const scene = new THREE.Scene();
221
- // Fog to blend floor into background
222
- scene.fog = new THREE.FogExp2(0x111111, 0.05);
223
-
224
- const camera = new THREE.PerspectiveCamera(45, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
225
- camera.position.set(0, 3, 6.5);
226
- camera.lookAt(0, 0.5, 0);
227
-
228
- const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
229
- renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
230
- renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
231
- renderer.shadowMap.enabled = true;
232
- wrapper.appendChild(renderer.domElement);
233
-
234
- // 2. Objects
235
- const CENTER = new THREE.Vector3(0, 0.8, 0);
236
- const ORBIT_RADIUS = 2.5;
237
-
238
- // Grid / Floor
239
- const grid = new THREE.GridHelper(10, 20, 0x444444, 0x222222);
240
- scene.add(grid);
241
-
242
- // Subject Plane (The Image)
243
- function createPlaceholder() {
244
- const cvs = document.createElement('canvas'); cvs.width=256; cvs.height=256;
245
- const ctx = cvs.getContext('2d');
246
- ctx.fillStyle = '#333'; ctx.fillRect(0,0,256,256);
247
- ctx.fillStyle = '#555'; ctx.font = '30px Arial'; ctx.textAlign='center';
248
- ctx.fillText("SUBJECT", 128, 128);
249
- return new THREE.CanvasTexture(cvs);
250
- }
251
 
252
- let planeMat = new THREE.MeshStandardMaterial({
253
- map: createPlaceholder(),
254
- side: THREE.DoubleSide,
255
- roughness: 0.5, metalness: 0.1
256
- });
257
- let subjectPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.5, 1.5), planeMat);
258
- subjectPlane.position.copy(CENTER);
259
- subjectPlane.castShadow = true;
260
- subjectPlane.receiveShadow = true;
261
- scene.add(subjectPlane);
262
-
263
- // The "Sun" (Light Source Control)
264
- const lightGroup = new THREE.Group();
265
- scene.add(lightGroup);
266
-
267
- // Visual Sun
268
- const sunGeo = new THREE.SphereGeometry(0.2, 32, 32);
269
- const sunMat = new THREE.MeshBasicMaterial({ color: 0xffeb3b });
270
- const sunMesh = new THREE.Mesh(sunGeo, sunMat);
271
- lightGroup.add(sunMesh);
272
-
273
- // Glow effect (larger translucent sphere)
274
- const glowMesh = new THREE.Mesh(
275
- new THREE.SphereGeometry(0.35, 16, 16),
276
- new THREE.MeshBasicMaterial({ color: 0xffeb3b, transparent: true, opacity: 0.3 })
277
- );
278
- lightGroup.add(glowMesh);
279
-
280
- // Actual Light Source (for casting shadows in 3D view)
281
- const pointLight = new THREE.PointLight(0xffffff, 1.5, 100);
282
- lightGroup.add(pointLight);
283
-
284
- // Guide Rings (To visualize sphere of influence)
285
- const ringMat = new THREE.LineBasicMaterial({ color: 0x444444, transparent: true, opacity: 0.3 });
286
- const eqRing = new THREE.LineLoop(new THREE.CircleGeometry(ORBIT_RADIUS, 64), ringMat);
287
- eqRing.rotation.x = Math.PI/2; eqRing.position.y = CENTER.y;
288
- scene.add(eqRing);
289
 
290
- const merRing = new THREE.LineLoop(new THREE.CircleGeometry(ORBIT_RADIUS, 64), ringMat);
291
- merRing.rotation.y = Math.PI/2; merRing.position.y = CENTER.y;
292
- scene.add(merRing);
293
-
294
- // Connector line
295
- const connLine = new THREE.Line(new THREE.BufferGeometry(), new THREE.LineBasicMaterial({ color: 0xffeb3b, opacity: 0.5, transparent: true }));
296
- scene.add(connLine);
297
-
298
- // 3. Logic & State
299
- let azimuth = 0; // Degrees
300
- let elevation = 0; // Degrees
301
-
302
- const promptMap = [
303
- { max: 55, min: -999, label: "Front" }, // Placeholder logic
304
- ];
305
-
306
- function getLabel(az, el) {
307
- if (el >= 55) return "LIGHT: ABOVE";
308
- if (el <= -55) return "LIGHT: BELOW";
309
-
310
- const a = (az % 360 + 360) % 360;
311
- if (a >= 337.5 || a < 22.5) return "LIGHT: FRONT";
312
- if (a >= 22.5 && a < 67.5) return "LIGHT: RIGHT FRONT";
313
- if (a >= 67.5 && a < 112.5) return "LIGHT: RIGHT";
314
- if (a >= 112.5 && a < 157.5) return "LIGHT: RIGHT REAR";
315
- if (a >= 157.5 && a < 202.5) return "LIGHT: REAR";
316
- if (a >= 202.5 && a < 247.5) return "LIGHT: LEFT REAR";
317
- if (a >= 247.5 && a < 292.5) return "LIGHT: LEFT";
318
- if (a >= 292.5 && a < 337.5) return "LIGHT: LEFT FRONT";
319
- return "LIGHT";
320
- }
321
-
322
- function updatePositions() {
323
- // Convert Spherical to Cartesian
324
- const phi = THREE.MathUtils.degToRad(90 - elevation); // Elevation from Y axis
325
- const theta = THREE.MathUtils.degToRad(azimuth); // Azimuth around Y axis
326
-
327
- const x = ORBIT_RADIUS * Math.sin(phi) * Math.sin(theta);
328
- const y = ORBIT_RADIUS * Math.cos(phi) + CENTER.y;
329
- const z = ORBIT_RADIUS * Math.sin(phi) * Math.cos(theta);
330
-
331
- lightGroup.position.set(x, y, z);
332
- lightGroup.lookAt(CENTER);
333
-
334
- // Update Connector
335
- connLine.geometry.setFromPoints([CENTER, lightGroup.position]);
336
-
337
- // Update Text
338
- promptOverlay.innerText = getLabel(azimuth, elevation);
339
- promptOverlay.style.boxShadow = `0 0 15px rgba(255, 235, 59, ${Math.max(0.2, (elevation+90)/180)})`;
340
- }
341
-
342
- // 4. Interaction (Drag the Sun)
343
- const raycaster = new THREE.Raycaster();
344
- const mouse = new THREE.Vector2();
345
- let isDragging = false;
346
-
347
- // Invisible sphere for raycasting drag
348
- const dragSphere = new THREE.Mesh(
349
- new THREE.SphereGeometry(ORBIT_RADIUS, 32, 32),
350
- new THREE.MeshBasicMaterial({ visible: false })
351
- );
352
- dragSphere.position.copy(CENTER);
353
- scene.add(dragSphere);
354
-
355
- function handleInput(e) {
356
- if (!isDragging) return;
357
 
358
- const rect = wrapper.getBoundingClientRect();
359
- // Handle touch or mouse
360
- const clientX = e.touches ? e.touches[0].clientX : e.clientX;
361
- const clientY = e.touches ? e.touches[0].clientY : e.clientY;
362
-
363
- mouse.x = ((clientX - rect.left) / rect.width) * 2 - 1;
364
- mouse.y = -((clientY - rect.top) / rect.height) * 2 + 1;
365
-
366
- raycaster.setFromCamera(mouse, camera);
367
- const intersects = raycaster.intersectObject(dragSphere);
368
-
369
- if (intersects.length > 0) {
370
- const point = intersects[0].point;
371
- const rel = new THREE.Vector3().subVectors(point, CENTER);
372
 
373
- // Cartesian to Spherical
374
- const r = rel.length();
375
- // Elevation (Lat)
376
- const lat = 90 - THREE.MathUtils.radToDeg(Math.acos(rel.y / r));
377
- // Azimuth (Lon)
378
- const lon = THREE.MathUtils.radToDeg(Math.atan2(rel.x, rel.z));
379
-
380
- elevation = Math.max(-90, Math.min(90, lat));
381
- azimuth = (lon + 360) % 360; // Normalize
382
-
383
- updatePositions();
384
- }
385
- }
386
-
387
- // Event Listeners
388
- const getMouse = (e) => {
389
- const rect = wrapper.getBoundingClientRect();
390
- return {
391
- x: ((e.clientX - rect.left) / rect.width) * 2 - 1,
392
- y: -((e.clientY - rect.top) / rect.height) * 2 + 1
393
- };
394
- };
395
-
396
- wrapper.addEventListener('mousedown', (e) => {
397
- const m = getMouse(e);
398
- raycaster.setFromCamera(m, camera);
399
- const intersects = raycaster.intersectObject(sunMesh); // Hit the sun?
400
- if(intersects.length > 0 || raycaster.intersectObject(glowMesh).length > 0) {
401
- isDragging = true;
402
- wrapper.style.cursor = 'grabbing';
403
- // Disable OrbitControls usually attached to camera if we want pure sun control
404
- // But for now, we assume user hits the sun specifically
405
- }
406
- });
407
-
408
- window.addEventListener('mousemove', (e) => {
409
- if(isDragging) {
410
- handleInput(e);
411
- } else {
412
- // Hover state
413
- const m = getMouse(e);
414
- raycaster.setFromCamera(m, camera);
415
- if(raycaster.intersectObject(sunMesh).length > 0) {
416
- wrapper.style.cursor = 'grab';
417
- sunMesh.scale.setScalar(1.2);
418
- } else {
419
- wrapper.style.cursor = 'default';
420
- sunMesh.scale.setScalar(1.0);
421
  }
422
- }
423
- });
424
 
425
- window.addEventListener('mouseup', () => {
426
- if(isDragging) {
427
- isDragging = false;
428
- wrapper.style.cursor = 'default';
429
 
430
- // Snap for cleaner output?
431
- // Let's snap to nearest integer for clean sliders
432
- azimuth = Math.round(azimuth);
433
- elevation = Math.round(elevation);
434
-
435
- // Send to Python
436
- const bridge = document.querySelector("#bridge-output textarea");
437
- if (bridge) {
438
- bridge.value = JSON.stringify({ azimuth, elevation });
439
- bridge.dispatchEvent(new Event("input", { bubbles: true }));
440
  }
441
- }
442
- });
443
-
444
- // 5. Render Loop
445
- function animate() {
446
- requestAnimationFrame(animate);
447
- renderer.render(scene, camera);
448
- }
449
- animate();
450
- updatePositions(); // Initial draw
451
-
452
- // 6. External API
453
- window.lightControl.updateState = (data) => {
454
- if(typeof data === 'string') data = JSON.parse(data);
455
- if(data) {
456
- azimuth = data.azimuth !== undefined ? data.azimuth : azimuth;
457
- elevation = data.elevation !== undefined ? data.elevation : elevation;
458
- updatePositions();
459
- }
460
- };
 
 
 
 
461
 
462
- window.lightControl.updateTexture = (url) => {
463
- if (!url) {
464
- planeMat.map = createPlaceholder();
465
- planeMat.needsUpdate = true;
466
- return;
467
- }
468
- new THREE.TextureLoader().load(url, (tex) => {
469
- tex.colorSpace = THREE.SRGBColorSpace;
470
- planeMat.map = tex;
471
 
472
- const img = tex.image;
473
- const aspect = img.width / img.height;
474
- const scale = 1.5;
475
- if (aspect > 1) subjectPlane.scale.set(scale, scale / aspect, 1);
476
- else subjectPlane.scale.set(scale * aspect, scale, 1);
477
 
478
- planeMat.needsUpdate = true;
479
- });
480
- };
481
- };
482
-
483
- initScene();
484
- })();
485
- </script>
486
- """
487
-
488
- # --- UI Setup ---
489
- css = """
490
- #col-container { max-width: 1400px; margin: 0 auto; }
491
- #light-control-wrapper { box-shadow: 0 8px 32px rgba(0,0,0,0.8); border: 2px solid #222; }
492
- .gradio-container { background-color: #0b0f19 !important; }
493
- h1, p { color: #e2e8f0 !important; }
494
- .block.svelte-1t38q2d { background: #1a202c !important; border-color: #2d3748 !important; }
495
- """
496
-
497
- with gr.Blocks() as demo:
498
- gr.HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>')
499
-
500
- with gr.Column(elem_id="col-container"):
501
- gr.Markdown("# 💡 Qwen-Edit-2509 — Multi-Angle Lighting Studio")
502
- gr.Markdown("Control the lighting direction using the **3D Sun Controller** or sliders below. The system automatically selects the correct lighting prompt based on the sun's position.")
503
-
504
- with gr.Row():
505
- # Left: Controls
506
- with gr.Column(scale=4):
507
- # 3D Viewport
508
- gr.HTML(THREE_JS_LOGIC)
509
 
510
- # Hidden Bridges
511
- bridge_output = gr.Textbox(elem_id="bridge-output", visible=False) # JS -> Python
512
- bridge_input = gr.JSON(value={}, visible=False) # Python -> JS
513
-
514
- with gr.Row():
515
- azimuth_slider = gr.Slider(0, 360, label="Sun Azimuth (Rotation)", value=0)
516
- elevation_slider = gr.Slider(-90, 90, label="Sun Elevation (Height)", value=0)
 
 
 
 
 
517
 
518
- prompt_preview = gr.Textbox(label="Active Lighting Prompt", value="Light source from the Front", interactive=False)
519
 
520
- with gr.Row():
521
- image = gr.Image(label="Subject Image", type="pil", height=200)
522
- run_btn = gr.Button("✨ Render Light", variant="primary", scale=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
- # Right: Result & Settings
525
- with gr.Column(scale=3):
526
- result = gr.Image(label="Relit Result")
 
527
 
528
- with gr.Accordion("🛠️ Advanced Settings", open=False):
529
- seed = gr.Slider(0, MAX_SEED, value=42, label="Seed")
530
- randomize_seed = gr.Checkbox(True, label="Randomize Seed")
531
- guidance_scale = gr.Slider(1, 10, 5.0, step=0.1, label="Guidance Scale")
532
- steps = gr.Slider(1, 50, 20, step=1, label="Inference Steps")
533
- width = gr.Slider(256, 2048, 1024, step=8, label="Width")
534
- height = gr.Slider(256, 2048, 1024, step=8, label="Height")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
 
536
- # --- Event Wiring ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
- # 1. Update Prompt Display
539
- def update_prompt_ui(az, el):
540
- return get_lighting_prompt(az, el)
 
 
 
 
 
 
 
 
 
 
 
541
 
542
- # 2. Image Handling
543
- def handle_image_upload(img):
544
- w, h = update_dimensions_on_upload(img)
545
- b64 = get_image_base64(img)
546
- return w, h, b64
 
 
547
 
548
- image.upload(handle_image_upload, inputs=[image], outputs=[width, height, bridge_input]) \
549
- .then(None, [image], None, js="(img) => { if(img) window.lightControl.updateTexture(img); }")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
- # 3. Sliders -> Bridge -> 3D View
552
- def sync_sliders_to_bridge(az, el):
553
- return {"azimuth": az, "elevation": el}
 
 
 
554
 
555
- azimuth_slider.change(sync_sliders_to_bridge, [azimuth_slider, elevation_slider], bridge_input)
556
- elevation_slider.change(sync_sliders_to_bridge, [azimuth_slider, elevation_slider], bridge_input)
 
557
 
558
- # Trigger JS update when bridge input changes
559
- bridge_input.change(None, [bridge_input], None, js="(val) => window.lightControl.updateState(val)")
 
 
 
 
 
 
 
 
 
 
 
 
 
560
 
561
- # Update Prompt when sliders move
562
- azimuth_slider.change(update_prompt_ui, [azimuth_slider, elevation_slider], prompt_preview)
563
- elevation_slider.change(update_prompt_ui, [azimuth_slider, elevation_slider], prompt_preview)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
- # 4. 3D View (Bridge Output) -> Sliders
566
- def sync_bridge_to_sliders(data_str):
567
- try:
568
- data = json.loads(data_str)
569
- return data.get('azimuth', 0), data.get('elevation', 0)
570
- except:
571
- return 0, 0
572
 
573
- bridge_output.change(sync_bridge_to_sliders, bridge_output, [azimuth_slider, elevation_slider])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
- # 5. Generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  run_btn.click(
577
- infer_lighting,
578
- inputs=[image, azimuth_slider, elevation_slider, seed, randomize_seed, guidance_scale, steps, height, width],
579
  outputs=[result, seed, prompt_preview]
580
  )
581
 
582
  if __name__ == "__main__":
583
- demo.launch(css=css, theme=gr.themes.Base())
 
 
3
  import random
4
  import torch
5
  import spaces
 
 
 
6
  from PIL import Image
7
 
8
+ # --- Updated Imports as requested ---
9
+ from diffusers import FlowMatchEulerDiscreteScheduler
10
  try:
 
 
 
11
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
12
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
13
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
14
+ except ImportError:
15
+ raise ImportError("Please ensure the 'qwenimage' package is installed or in your python path.")
16
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
+
19
+ # --- Model Configuration ---
20
  dtype = torch.bfloat16
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
+ # Load the pipeline with the custom transformer
24
+ pipe = QwenImageEditPlusPipeline.from_pretrained(
25
+ "Qwen/Qwen-Image-Edit-2511",
26
+ transformer=QwenImageTransformer2DModel.from_pretrained(
27
+ "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
28
+ torch_dtype=dtype,
29
+ device_map='cuda'
30
+ ),
31
+ torch_dtype=dtype
32
+ ).to(device)
33
+
34
+ # Attempt to set Flash Attention 3
35
+ try:
36
+ pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
37
+ print("Flash Attention 3 Processor set successfully.")
38
+ except Exception as e:
39
+ print(f"Warning: Could not set FA3 processor: {e}")
40
+
41
+ # Adapter Specs
42
  ADAPTER_SPECS = {
43
  "Multi-Angle-Lighting": {
44
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
 
47
  }
48
  }
49
 
50
+ # State to track loaded adapter
51
+ CURRENT_LOADED_ADAPTER = None
52
+
53
+ # --- Prompt Building & Logic ---
54
+
55
+ # Lighting mappings based on Azimuth (Horizontal)
56
+ # 0 degrees = Front
57
+ LIGHTING_AZIMUTH_MAP = {
58
+ 0: "Light source from the Front",
59
+ 45: "Light source from the Right Front",
60
+ 90: "Light source from the Right",
61
+ 135: "Light source from the Right Rear",
62
+ 180: "Light source from the Rear",
63
+ 225: "Light source from the Left Rear",
64
+ 270: "Light source from the Left",
65
+ 315: "Light source from the Left Front"
66
+ }
 
 
 
 
 
 
 
67
 
68
+ def snap_to_nearest(value, options):
69
+ """Snap a value to the nearest option in a list."""
70
+ return min(options, key=lambda x: abs(x - value))
71
 
72
+ def build_lighting_prompt(azimuth: float, elevation: float) -> str:
73
  """
74
+ Build a lighting prompt based on the specific 10 options provided.
75
+ Priority:
76
+ 1. If Elevation is very high (> 60) -> Above
77
+ 2. If Elevation is very low (< -60) -> Below
78
+ 3. Else -> Use Azimuth mapping
79
  """
80
+
81
+ # Check Vertical extremes first
82
+ if elevation >= 60:
83
  return "Light source from Above"
84
+ if elevation <= -60:
85
  return "Light source from Below"
86
+
87
+ # Snap Horizontal
88
+ azimuth_snapped = snap_to_nearest(azimuth, list(LIGHTING_AZIMUTH_MAP.keys()))
89
+ return LIGHTING_AZIMUTH_MAP[azimuth_snapped]
90
 
91
+ @spaces.GPU
92
+ def infer_lighting_edit(
93
+ image: Image.Image,
94
+ azimuth: float = 0.0,
95
+ elevation: float = 0.0,
96
+ seed: int = 0,
97
+ randomize_seed: bool = True,
98
+ guidance_scale: float = 1.0,
99
+ num_inference_steps: int = 4,
100
+ height: int = 1024,
101
+ width: int = 1024,
102
+ ):
103
+ global CURRENT_LOADED_ADAPTER
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ # --- Lazy Loading Logic ---
106
  spec = ADAPTER_SPECS["Multi-Angle-Lighting"]
107
+ if CURRENT_LOADED_ADAPTER != spec["adapter_name"]:
108
+ print(f"Lazy loading adapter: {spec['adapter_name']}...")
 
 
109
  pipe.load_lora_weights(
110
  spec["repo"],
111
  weight_name=spec["weights"],
112
  adapter_name=spec["adapter_name"]
113
  )
114
  pipe.set_adapters([spec["adapter_name"]], adapter_weights=[1.0])
115
+ CURRENT_LOADED_ADAPTER = spec["adapter_name"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ # --- Generation ---
118
+ progress = gr.Progress(track_tqdm=True)
119
+
120
+ prompt = build_lighting_prompt(azimuth, elevation)
121
+ print(f"Generated Lighting Prompt: {prompt}")
 
 
 
 
122
 
 
123
  if randomize_seed:
124
  seed = random.randint(0, MAX_SEED)
125
  generator = torch.Generator(device=device).manual_seed(seed)
 
 
126
 
127
+ if image is None:
128
+ raise gr.Error("Please upload an image first.")
129
+
130
+ pil_image = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
131
+
132
  result = pipe(
133
  image=[pil_image],
134
+ prompt=prompt,
135
+ height=height if height != 0 else None,
136
+ width=width if width != 0 else None,
137
+ num_inference_steps=num_inference_steps,
138
  generator=generator,
139
  guidance_scale=guidance_scale,
140
  num_images_per_prompt=1,
 
142
 
143
  return result, seed, prompt
144
 
145
+ def update_dimensions_on_upload(image):
146
+ if image is None:
147
+ return 1024, 1024
148
+ original_width, original_height = image.size
149
+ if original_width > original_height:
150
+ new_width = 1024
151
+ aspect_ratio = original_height / original_width
152
+ new_height = int(new_width * aspect_ratio)
153
+ else:
154
+ new_height = 1024
155
+ aspect_ratio = original_width / original_height
156
+ new_width = int(new_height * aspect_ratio)
157
+ new_width = (new_width // 8) * 8
158
+ new_height = (new_height // 8) * 8
159
+ return new_width, new_height
160
+
161
+ # --- 3D Lighting Control Component ---
162
+ class LightControl3D(gr.HTML):
163
+ """
164
+ A 3D Lighting control component using Three.js.
165
+ Visualizes a Light Source (Sun/Bulb) relative to the subject.
166
+ """
167
+ def __init__(self, value=None, imageUrl=None, **kwargs):
168
+ if value is None:
169
+ value = {"azimuth": 0, "elevation": 0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
+ html_template = """
172
+ <div id="light-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #0b0b0b; border-radius: 12px; overflow: hidden; border: 1px solid #333;">
173
+ <div id="prompt-overlay" style="position: absolute; top: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.7); padding: 8px 16px; border-radius: 20px; font-family: sans-serif; font-weight: bold; font-size: 14px; color: #FFD700; white-space: nowrap; z-index: 10; border: 1px solid #FFD700;">Light Source: Front</div>
174
+ </div>
175
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
+ js_on_load = """
178
+ (() => {
179
+ const wrapper = element.querySelector('#light-control-wrapper');
180
+ const promptOverlay = element.querySelector('#prompt-overlay');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ const initScene = () => {
183
+ if (typeof THREE === 'undefined') {
184
+ setTimeout(initScene, 100);
185
+ return;
186
+ }
 
 
 
 
 
 
 
 
 
187
 
188
+ // Scene setup
189
+ const scene = new THREE.Scene();
190
+ scene.background = new THREE.Color(0x0b0b0b);
191
+ scene.fog = new THREE.FogExp2(0x0b0b0b, 0.1);
192
+
193
+ const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
194
+ camera.position.set(3.5, 2.5, 3.5);
195
+ camera.lookAt(0, 0.5, 0);
196
+
197
+ const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
198
+ renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
199
+ renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
200
+ wrapper.insertBefore(renderer.domElement, promptOverlay);
201
+
202
+ // Ambient Light (dim)
203
+ scene.add(new THREE.AmbientLight(0xffffff, 0.2));
204
+
205
+ // The Grid
206
+ const grid = new THREE.GridHelper(6, 12, 0x444444, 0x222222);
207
+ scene.add(grid);
208
+
209
+ // Constants
210
+ const CENTER = new THREE.Vector3(0, 0.75, 0);
211
+ const ORBIT_RADIUS = 2.2;
212
+
213
+ // State
214
+ let azimuthAngle = props.value?.azimuth || 0;
215
+ let elevationAngle = props.value?.elevation || 0;
216
+
217
+ // --- Mappings for UI ---
218
+ const azimuthNames = {
219
+ 0: 'Front', 45: 'Right Front', 90: 'Right',
220
+ 135: 'Right Rear', 180: 'Rear', 225: 'Left Rear',
221
+ 270: 'Left', 315: 'Left Front'
222
+ };
223
+
224
+ function getPromptText(az, el) {
225
+ if (el >= 60) return "Light source from Above";
226
+ if (el <= -60) return "Light source from Below";
227
+
228
+ // Snap azimuth
229
+ const steps = [0, 45, 90, 135, 180, 225, 270, 315];
230
+ const snapped = steps.reduce((prev, curr) => Math.abs(curr - az) < Math.abs(prev - az) ? curr : prev);
231
+ return "Light source from the " + azimuthNames[snapped];
 
 
 
 
232
  }
 
 
233
 
234
+ // --- Objects ---
 
 
 
235
 
236
+ // 1. The Subject (Plane with image)
237
+ function createPlaceholderTexture() {
238
+ const canvas = document.createElement('canvas');
239
+ canvas.width = 256; canvas.height = 256;
240
+ const ctx = canvas.getContext('2d');
241
+ ctx.fillStyle = '#222'; ctx.fillRect(0,0,256,256);
242
+ ctx.fillStyle = '#444'; ctx.beginPath(); ctx.arc(128,128, 80, 0, Math.PI*2); ctx.fill();
243
+ ctx.strokeStyle = '#666'; ctx.lineWidth=4; ctx.beginPath(); ctx.moveTo(64,64); ctx.lineTo(192,192); ctx.moveTo(192,64); ctx.lineTo(64,192); ctx.stroke();
244
+ return new THREE.CanvasTexture(canvas);
 
245
  }
246
+
247
let currentTexture = createPlaceholderTexture();
const planeMaterial = new THREE.MeshBasicMaterial({ map: currentTexture, side: THREE.DoubleSide });
let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
targetPlane.position.copy(CENTER);
scene.add(targetPlane);

// Swap the subject texture. Falls back to the placeholder when `url` is
// empty/null. Once loaded, the plane is rescaled to the image aspect
// ratio (longest side normalized to 1).
function updateTextureFromUrl(url) {
    if (!url) {
        // BUG FIX: also reset the plane scale (previous image aspect would
        // otherwise stick) and keep `currentTexture` in sync.
        currentTexture = createPlaceholderTexture();
        planeMaterial.map = currentTexture;
        planeMaterial.needsUpdate = true;
        targetPlane.scale.set(1, 1, 1);
        return;
    }
    new THREE.TextureLoader().load(url, (tex) => {
        currentTexture = tex;
        planeMaterial.map = tex;
        planeMaterial.needsUpdate = true;
        const img = tex.image;
        if (img && img.width && img.height) {   // guard: image may be missing
            const aspect = img.width / img.height;
            targetPlane.scale.set(aspect > 1 ? 1 : aspect, aspect > 1 ? 1 / aspect : 1, 1);
        }
    });
}
if (props.imageUrl) updateTextureFromUrl(props.imageUrl);
270
 
271
// 2. The light-source visualizer (yellow "sun" + glow + dashed ray)
const lightGroup = new THREE.Group();
scene.add(lightGroup);

// Solid core sphere representing the light source.
const sunMat = new THREE.MeshBasicMaterial({ color: 0xFFD700 });
const sunMesh = new THREE.Mesh(new THREE.SphereGeometry(0.25, 32, 32), sunMat);

// Soft halo: a larger, semi-transparent sphere parented to the sun.
const glowMat = new THREE.MeshBasicMaterial({ color: 0xFFD700, transparent: true, opacity: 0.3 });
sunMesh.add(new THREE.Mesh(new THREE.SphereGeometry(0.4, 32, 32), glowMat));

lightGroup.add(sunMesh);

// Dashed ray from the sun toward the subject. Built unit-length along -Z;
// updatePositions() stretches it (scale.z) to span the full distance.
const rayPoints = [new THREE.Vector3(0, 0, 0), new THREE.Vector3(0, 0, -1)];
const lightRay = new THREE.Line(
    new THREE.BufferGeometry().setFromPoints(rayPoints),
    new THREE.LineDashedMaterial({ color: 0xFFD700, dashSize: 0.2, gapSize: 0.1 })
);
lightRay.computeLineDistances();   // required for dashed materials
lightGroup.add(lightRay);

// Actual scene light so 3D meshes are shaded from the chosen direction.
const sceneLight = new THREE.DirectionalLight(0xffffff, 2.0);
lightGroup.add(sceneLight);
 
299
// 3. UI handles

// Static azimuth reference ring lying in the horizontal plane.
const azRing = new THREE.Mesh(
    new THREE.TorusGeometry(ORBIT_RADIUS, 0.03, 16, 64),
    new THREE.MeshBasicMaterial({ color: 0x444444, transparent: true, opacity: 0.5 })
);
azRing.rotation.x = Math.PI / 2;   // torus is built in XY; lay it flat
azRing.position.y = 0.1;
scene.add(azRing);

// --- Update logic ---
// Re-derives the light position from (azimuthAngle, elevationAngle),
// stretches the ray to the subject, and refreshes the prompt overlay.
function updatePositions() {
    const azRad = THREE.MathUtils.degToRad(azimuthAngle);
    const elRad = THREE.MathUtils.degToRad(elevationAngle);

    // Spherical -> Cartesian around CENTER (y offset only).
    const cosEl = Math.cos(elRad);
    lightGroup.position.set(
        ORBIT_RADIUS * Math.sin(azRad) * cosEl,
        ORBIT_RADIUS * Math.sin(elRad) + CENTER.y,
        ORBIT_RADIUS * Math.cos(azRad) * cosEl
    );
    lightGroup.lookAt(CENTER);

    // The ray is unit-length along -Z; scale it to reach the subject.
    lightRay.scale.z = lightGroup.position.distanceTo(CENTER);

    // Refresh the text overlay.
    promptOverlay.textContent = getPromptText(azimuthAngle, elevationAngle);

    // Color feedback: orange-red at the vertical extremes, gold otherwise.
    const extreme = elevationAngle >= 60 || elevationAngle <= -60;
    const cssColor = extreme ? "#FF4500" : "#FFD700";
    sunMat.color.setHex(extreme ? 0xFF4500 : 0xFFD700);
    promptOverlay.style.color = cssColor;
    promptOverlay.style.borderColor = cssColor;
}
341
 
342
// --- Interaction (dragging the sun) ---
const raycaster = new THREE.Raycaster();
const mouse = new THREE.Vector2();
let isDragging = false;

// Invisible sphere at orbit radius used as the drag surface: while
// dragging we raycast against it to get a point on the orbit shell.
const dragSphere = new THREE.Mesh(
    new THREE.SphereGeometry(ORBIT_RADIUS, 32, 32),
    new THREE.MeshBasicMaterial({ visible: false, side: THREE.DoubleSide })
);
dragSphere.position.copy(CENTER);
scene.add(dragSphere);

// Pointer-down: start a drag only when the sun itself is hit.
function onDown(e) {
    const rect = wrapper.getBoundingClientRect();
    // BUG FIX: `e.clientX || e.touches[0].clientX` crashed on mouse events
    // with clientX === 0 (falsy -> undefined touches). Branch explicitly.
    const touch = e.touches && e.touches.length ? e.touches[0] : null;
    const x = touch ? touch.clientX : e.clientX;
    const y = touch ? touch.clientY : e.clientY;
    mouse.x = ((x - rect.left) / rect.width) * 2 - 1;
    mouse.y = -((y - rect.top) / rect.height) * 2 + 1;

    raycaster.setFromCamera(mouse, camera);
    // Click strictly on the sun (not the ring or the subject plane).
    if (raycaster.intersectObject(sunMesh).length > 0) {
        isDragging = true;
        wrapper.style.cursor = 'grabbing';
    }
}
+ }
370
 
371
// Pointer-move: hover feedback when idle, spherical drag when active.
function onMove(e) {
    const rect = wrapper.getBoundingClientRect();
    // Explicit touch/mouse branch (a clientX of 0 is a valid coordinate).
    const touch = e.touches && e.touches.length ? e.touches[0] : null;
    const x = touch ? touch.clientX : e.clientX;
    const y = touch ? touch.clientY : e.clientY;

    if (!isDragging) {
        // Hover effect (mouse only; a touchmove outside a drag is ignored).
        if (x === undefined || y === undefined) return;
        mouse.x = ((x - rect.left) / rect.width) * 2 - 1;
        mouse.y = -((y - rect.top) / rect.height) * 2 + 1;
        raycaster.setFromCamera(mouse, camera);
        wrapper.style.cursor = raycaster.intersectObject(sunMesh).length > 0 ? 'grab' : 'default';
        return;
    }

    // BUG FIX: the touch listeners are registered {passive: false} precisely
    // so the drag can stop the page from scrolling — actually do it.
    if (e.cancelable) e.preventDefault();

    mouse.x = ((x - rect.left) / rect.width) * 2 - 1;
    mouse.y = -((y - rect.top) / rect.height) * 2 + 1;

    raycaster.setFromCamera(mouse, camera);
    const hits = raycaster.intersectObject(dragSphere);
    if (hits.length === 0) return;

    const rel = new THREE.Vector3().subVectors(hits[0].point, CENTER);

    // Azimuth: angle around Y, normalized to [0, 360).
    let az = Math.atan2(rel.x, rel.z) * (180 / Math.PI);
    if (az < 0) az += 360;

    // Elevation: angle above the horizontal plane.
    const horiz = Math.sqrt(rel.x * rel.x + rel.z * rel.z);
    let el = Math.atan2(rel.y, horiz) * (180 / Math.PI);

    // Clamp to avoid gimbal-flip visual issues at the poles.
    el = Math.max(-89, Math.min(89, el));

    azimuthAngle = az;
    elevationAngle = el;
    updatePositions();
}
412
 
413
// Pointer-up: finish the drag and publish the chosen angles to Python.
function onUp() {
    if (!isDragging) return;
    isDragging = false;
    wrapper.style.cursor = 'default';

    // Publish the final (unsnapped) angles; snapping to named directions
    // happens when the prompt text is built.
    props.value = { azimuth: azimuthAngle, elevation: elevationAngle };
    trigger('change', props.value);
}

// Mouse: down on the wrapper, move/up on window so the drag keeps
// tracking even when the pointer leaves the canvas.
wrapper.addEventListener('mousedown', onDown);
window.addEventListener('mousemove', onMove);
window.addEventListener('mouseup', onUp);

// Touch: registered non-passive so handlers may block default scrolling.
wrapper.addEventListener('touchstart', onDown, {passive: false});
window.addEventListener('touchmove', onMove, {passive: false});
window.addEventListener('touchend', onUp);
435
 
436
// Initial placement of the light and overlay text.
updatePositions();

// Continuous render loop.
function animate() {
    requestAnimationFrame(animate);
    renderer.render(scene, camera);
}
animate();

// Poll for external (slider-driven) value changes and mirror them into
// the 3D view. Skipped mid-drag so the user's gesture wins.
// BUG FIX: removed the dead `props.imageUrl !== currentTexture.sourceFile`
// branch — `sourceFile` never exists on a texture and the branch body was
// empty; image updates arrive via wrapper._updateTexture instead.
// NOTE(review): 100ms polling is a workaround for not having a prop
// observer on this custom component.
setInterval(() => {
    const v = props.value;
    if (v && !isDragging && (v.azimuth !== azimuthAngle || v.elevation !== elevationAngle)) {
        azimuthAngle = v.azimuth;
        elevationAngle = v.elevation;
        updatePositions();
    }
}, 100);

// Expose the texture updater so the Python side can push new images.
wrapper._updateTexture = updateTextureFromUrl;
461
+ }
462
+ initScene();
463
+ })();
464
+ """
465
+
466
+ super().__init__(
467
+ value=value,
468
+ html_template=html_template,
469
+ js_on_load=js_on_load,
470
+ imageUrl=imageUrl,
471
+ **kwargs
472
+ )
473
 
474
# --- UI Layout ---
# Global stylesheet: centers the main column, fixes progress-text color in
# dark mode, and reserves vertical space for the 3D light controller.
css = '''
#col-container { max-width: 1200px; margin: 0 auto; }
.dark .progress-text { color: white !important; }
#light-3d-control { min-height: 450px; }
'''
480
 
481
# Main UI.
# BUG FIX: `css` and `head` are gr.Blocks() constructor parameters — they are
# not accepted by Blocks.launch() — so pass them here. `head` pulls in
# three.js, which the LightControl3D component's front-end code requires.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=css,
    head='<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>',
) as demo:
    gr.Markdown("""
    # 💡 Qwen Edit 2509 — Multi-Angle Lighting Control

    Control the **Light Source Direction** relative to your subject.
    Drag the **Yellow Sun** in the 3D viewport or use the sliders to position the light.
    """)

    with gr.Row():
        # Left column: input image + light-direction controls.
        with gr.Column(scale=1):
            image = gr.Image(label="Input Image", type="pil", height=300)

            gr.Markdown("### 🎮 3D Light Source Controller")

            # Custom three.js widget; value is {"azimuth": deg, "elevation": deg}.
            light_3d = LightControl3D(
                value={"azimuth": 0, "elevation": 0},
                elem_id="light-3d-control"
            )

            run_btn = gr.Button("🚀 Generate Lighting", variant="primary", size="lg")

            gr.Markdown("### 🎚️ Fine-Tune Position")

            azimuth_slider = gr.Slider(
                label="Light Azimuth (Horizontal)",
                minimum=0, maximum=359, step=45, value=0,
                info="0°=Front, 90°=Right, 180°=Rear, 270°=Left"
            )

            elevation_slider = gr.Slider(
                label="Light Elevation (Vertical)",
                minimum=-90, maximum=90, step=10, value=0,
                info=">60° = Above, <-60° = Below"
            )

            # Read-only preview of the prompt derived from the angles.
            prompt_preview = gr.Textbox(
                label="Active Lighting Prompt",
                value="Light source from the Front",
                interactive=False
            )

        # Right column: generated output + advanced generation knobs.
        with gr.Column(scale=1):
            result = gr.Image(label="Relighted Result", height=500)

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=20, step=1, value=4)
                height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
                width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)
534
 
535
+ # --- Event Wiring ---
 
 
 
 
 
 
536
 
537
+ def sync_3d_to_sliders(val):
538
+ """Syncs 3D controller movement to sliders and prompt"""
539
+ az = val.get('azimuth', 0)
540
+ el = val.get('elevation', 0)
541
+ prompt = build_lighting_prompt(az, el)
542
+ return az, el, prompt
543
+
544
+ def sync_sliders_to_3d(az, el):
545
+ """Syncs slider movement to 3D controller and prompt"""
546
+ prompt = build_lighting_prompt(az, el)
547
+ return {"azimuth": az, "elevation": el}, prompt
548
+
549
    # 3D -> sliders: dragging the sun updates both sliders and the prompt.
    light_3d.change(
        fn=sync_3d_to_sliders,
        inputs=[light_3d],
        outputs=[azimuth_slider, elevation_slider, prompt_preview]
    )

    # Sliders -> 3D: either slider repositions the sun and rebuilds the prompt.
    # NOTE(review): programmatic slider updates triggered by light_3d.change
    # also fire .change here, causing one redundant (but terminating) round
    # trip; the JS side skips re-sync while the user is dragging.
    azimuth_slider.change(
        fn=sync_sliders_to_3d,
        inputs=[azimuth_slider, elevation_slider],
        outputs=[light_3d, prompt_preview]
    )
    elevation_slider.change(
        fn=sync_sliders_to_3d,
        inputs=[azimuth_slider, elevation_slider],
        outputs=[light_3d, prompt_preview]
    )
567
 
568
+ # Upload Image handling
569
+ def update_on_upload(img):
570
+ w, h = update_dimensions_on_upload(img)
571
+ # Convert to data URL for 3D texture
572
+ if img is None:
573
+ return w, h, gr.update(imageUrl=None)
574
+
575
+ import base64
576
+ from io import BytesIO
577
+ buffered = BytesIO()
578
+ img.save(buffered, format="PNG")
579
+ img_str = base64.b64encode(buffered.getvalue()).decode()
580
+ data_url = f"data:image/png;base64,{img_str}"
581
+ return w, h, gr.update(imageUrl=data_url)
582
+
583
+ image.upload(
584
+ fn=update_on_upload,
585
+ inputs=[image],
586
+ outputs=[width, height, light_3d]
587
+ )
588
+
589
    # Generate: relight the input image from the selected light direction.
    # Outputs the edited image, the seed actually used, and the prompt text.
    run_btn.click(
        fn=infer_lighting_edit,
        inputs=[image, azimuth_slider, elevation_slider, seed, randomize_seed, guidance_scale, num_inference_steps, height, width],
        outputs=[result, seed, prompt_preview]
    )
595
 
596
if __name__ == "__main__":
    # three.js is required by the LightControl3D component's front-end code.
    head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
    # BUG FIX: Blocks.launch() accepts neither `head` nor `css` keyword
    # arguments (they are gr.Blocks() constructor parameters), so passing
    # them to launch() raised TypeError. Attach them to the Blocks instance
    # before launching instead.
    demo.head = head
    demo.css = css
    demo.launch()