prithivMLmods committed on
Commit
32154ee
·
verified ·
1 Parent(s): 8344210

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +265 -334
app.py CHANGED
@@ -12,7 +12,6 @@ try:
12
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
13
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
14
  except ImportError:
15
- # Fallback/Instruction if custom packages are missing
16
  raise ImportError("Please ensure the 'qwenimage' package is installed.")
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
@@ -21,7 +20,6 @@ MAX_SEED = np.iinfo(np.int32).max
21
  dtype = torch.bfloat16
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
- # 1. Load the Pipeline
25
  pipe = QwenImageEditPlusPipeline.from_pretrained(
26
  "Qwen/Qwen-Image-Edit-2511",
27
  transformer=QwenImageTransformer2DModel.from_pretrained(
@@ -32,14 +30,12 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
32
  torch_dtype=dtype
33
  ).to(device)
34
 
35
- # 2. Set Flash Attention 3 (if available)
36
  try:
37
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
38
  print("Flash Attention 3 Processor set successfully.")
39
  except Exception as e:
40
  print(f"Warning: Could not set FA3 processor: {e}")
41
 
42
- # 3. Adapter Specs (Lighting LoRA)
43
  ADAPTER_SPECS = {
44
  "Multi-Angle-Lighting": {
45
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
@@ -48,13 +44,9 @@ ADAPTER_SPECS = {
48
  }
49
  }
50
 
51
- # Global state to track currently loaded adapter
52
  CURRENT_LOADED_ADAPTER = None
53
 
54
- # --- Logic: Mappings & Prompt Building ---
55
-
56
- # Lighting mappings for Azimuth (Horizontal)
57
- # 0 = Front, moving clockwise
58
  LIGHTING_AZIMUTH_MAP = {
59
  0: "Light source from the Front",
60
  45: "Light source from the Right Front",
@@ -67,135 +59,74 @@ LIGHTING_AZIMUTH_MAP = {
67
  }
68
 
69
  def snap_to_nearest_key(value, keys):
70
- """Finds the nearest key in a list of numbers."""
71
  return min(keys, key=lambda x: abs(x - value))
72
 
73
  def build_lighting_prompt(azimuth: float, elevation: float) -> str:
74
- """
75
- Constructs the specific text prompt required by the LoRA.
76
- Logic:
77
- 1. Prioritize Vertical Extremes (>60° or <-60°)
78
- 2. Fallback to Horizontal Azimuth mappings
79
- """
80
- # 1. Vertical Extremes
81
- if elevation >= 60:
82
- return "Light source from Above"
83
- if elevation <= -60:
84
- return "Light source from Below"
85
-
86
- # 2. Horizontal Snap
87
  keys = list(LIGHTING_AZIMUTH_MAP.keys())
88
- # Handle the 360 wrap-around for "Front" (0 vs 360)
89
- # If azimuth is > 337.5, it snaps to 0
90
- if azimuth > 337.5:
91
- azimuth = 0
92
-
93
  azimuth_snapped = snap_to_nearest_key(azimuth, keys)
94
  return LIGHTING_AZIMUTH_MAP[azimuth_snapped]
95
 
96
- # --- Inference Function ---
97
-
98
  @spaces.GPU
99
- def infer_lighting_edit(
100
- image: Image.Image,
101
- azimuth: float = 0.0,
102
- elevation: float = 0.0,
103
- seed: int = 0,
104
- randomize_seed: bool = True,
105
- guidance_scale: float = 5.0,
106
- num_inference_steps: int = 4,
107
- height: int = 1024,
108
- width: int = 1024,
109
- ):
110
  global CURRENT_LOADED_ADAPTER
111
-
112
- # 1. Lazy Load Adapter
113
  spec = ADAPTER_SPECS["Multi-Angle-Lighting"]
114
  if CURRENT_LOADED_ADAPTER != spec["adapter_name"]:
115
- print(f"⚙️ Lazy loading adapter: {spec['adapter_name']}...")
116
- pipe.load_lora_weights(
117
- spec["repo"],
118
- weight_name=spec["weights"],
119
- adapter_name=spec["adapter_name"]
120
- )
121
  pipe.set_adapters([spec["adapter_name"]], adapter_weights=[1.0])
122
  CURRENT_LOADED_ADAPTER = spec["adapter_name"]
123
 
124
- # 2. Build Prompt
125
  prompt = build_lighting_prompt(azimuth, elevation)
126
  print(f"💡 Generated Prompt: {prompt}")
127
 
128
- # 3. Prepare Inputs
129
- if image is None:
130
- raise gr.Error("Please upload an image first.")
131
-
132
- if randomize_seed:
133
- seed = random.randint(0, MAX_SEED)
134
  generator = torch.Generator(device=device).manual_seed(seed)
135
-
136
  pil_image = image.convert("RGB")
137
 
138
- # 4. Run Inference
139
  result = pipe(
140
- image=[pil_image],
141
- prompt=prompt,
142
- height=height,
143
- width=width,
144
- num_inference_steps=num_inference_steps,
145
- generator=generator,
146
- guidance_scale=guidance_scale,
147
- num_images_per_prompt=1,
148
  ).images[0]
149
-
150
  return result, seed, prompt
151
 
152
  def update_dimensions_on_upload(image):
153
- """Resizes image to nearest multiple of 8, max 1024, preserving aspect ratio."""
154
- if image is None:
155
- return 1024, 1024
156
  w, h = image.size
157
-
158
- # Constraint: Max dimension 1024
159
- if w > h:
160
- new_w = 1024
161
- new_h = int(new_w * (h / w))
162
- else:
163
- new_h = 1024
164
- new_w = int(new_h * (w / h))
165
-
166
- # Constraint: Multiple of 8
167
- new_w = (new_w // 8) * 8
168
- new_h = (new_h // 8) * 8
169
-
170
- return new_w, new_h
171
-
172
- # --- Enhanced 3D Component ---
173
 
 
174
  class LightControl3D(gr.HTML):
175
  """
176
- Advanced 3D Light Controller using Three.js.
177
- Features: Hemisphere guide, Beam visualization, Dynamic color feedback.
178
  """
179
  def __init__(self, value=None, imageUrl=None, **kwargs):
180
  if value is None: value = {"azimuth": 0, "elevation": 0}
181
 
182
- # HTML Container
183
  html_template = """
184
- <div id="light-control-wrapper" style="width: 100%; height: 500px; position: relative; background: radial-gradient(circle at center, #1a1a1a 0%, #000000 100%); border-radius: 12px; overflow: hidden; border: 1px solid #333; box-shadow: inset 0 0 20px #000;">
185
- <div id="prompt-badge" style="position: absolute; top: 15px; left: 50%; transform: translateX(-50%);
186
- background: rgba(0,0,0,0.8); border: 1px solid #FFD700; color: #FFD700;
187
- padding: 8px 24px; border-radius: 30px; font-family: monospace; font-weight: bold; font-size: 14px;
188
- z-index: 10; pointer-events: none; transition: all 0.2s ease;">
 
189
  Light Source: Front
190
  </div>
191
 
192
- <div style="position: absolute; bottom: 15px; right: 15px; color: #555; font-size: 10px; font-family: sans-serif; pointer-events: none;">
193
- Drag to rotate Scroll to zoom
 
194
  </div>
195
  </div>
196
  """
197
 
198
- # JavaScript Logic
199
  js_on_load = """
200
  (() => {
201
  const wrapper = element.querySelector('#light-control-wrapper');
@@ -204,114 +135,154 @@ class LightControl3D(gr.HTML):
204
  const initScene = () => {
205
  if (typeof THREE === 'undefined') { setTimeout(initScene, 100); return; }
206
 
207
- // --- 1. Scene & Camera ---
208
  const scene = new THREE.Scene();
209
- // No background color set here, letting CSS gradient show through
 
210
 
211
- const camera = new THREE.PerspectiveCamera(45, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
212
- camera.position.set(4, 3, 4); // Isometric-ish view
213
  camera.lookAt(0, 0.5, 0);
214
 
215
- const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
216
  renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
217
  renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
 
 
218
  wrapper.appendChild(renderer.domElement);
219
 
220
- // --- 2. Helpers (Grid & Dome) ---
221
  const CENTER = new THREE.Vector3(0, 0.75, 0);
222
- const RADIUS = 2.5;
 
 
 
 
 
 
 
 
 
223
 
224
- // Floor Grid
225
- const grid = new THREE.GridHelper(6, 12, 0x444444, 0x111111);
226
  scene.add(grid);
 
 
 
 
227
 
228
- // Hemisphere Guide (Wireframe Dome)
229
- const domeGeo = new THREE.SphereGeometry(RADIUS, 16, 8, 0, Math.PI * 2, 0, Math.PI * 0.5);
230
- const domeMat = new THREE.MeshBasicMaterial({ color: 0x333333, wireframe: true, transparent: true, opacity: 0.15 });
231
- const dome = new THREE.Mesh(domeGeo, domeMat);
232
- dome.position.y = CENTER.y - 0.75; // Ground the dome
233
- scene.add(dome);
234
-
235
- // Elevation Rings (Visual guides for 0, 45, 60 degrees)
236
- const ringMat = new THREE.MeshBasicMaterial({ color: 0x555555, transparent: true, opacity: 0.3, side: THREE.DoubleSide });
237
- const eqRing = new THREE.Mesh(new THREE.TorusGeometry(RADIUS, 0.01, 8, 64), ringMat);
238
- eqRing.rotation.x = Math.PI / 2;
239
- eqRing.position.y = CENTER.y;
240
- scene.add(eqRing);
241
-
242
- // --- 3. The Subject (Image Plane) ---
243
- let planeMesh;
244
- const planeMat = new THREE.MeshBasicMaterial({ color: 0x222222, side: THREE.DoubleSide });
245
 
246
- function createPlane(width=1.2, height=1.2) {
247
- if(planeMesh) scene.remove(planeMesh);
248
- planeMesh = new THREE.Mesh(new THREE.PlaneGeometry(width, height), planeMat);
249
- planeMesh.position.copy(CENTER);
250
- planeMesh.lookAt(camera.position); // Billboarding slightly? No, fixed upright.
251
- planeMesh.rotation.set(0,0,0); // Reset rotation
252
- scene.add(planeMesh);
 
 
 
 
253
  }
254
- createPlane();
255
 
256
  // Texture Loader
257
  function updateTexture(url) {
258
- if (!url) {
259
- planeMat.map = null;
260
- planeMat.needsUpdate = true;
261
- return;
262
- }
263
  new THREE.TextureLoader().load(url, (tex) => {
264
- planeMat.map = tex;
265
- planeMat.needsUpdate = true;
266
- // Adjust Aspect Ratio
 
 
267
  const img = tex.image;
268
  if(img && img.width && img.height) {
269
  const aspect = img.width / img.height;
270
- const size = 1.4; // Max dimension
271
- if (aspect > 1) createPlane(size, size/aspect);
272
- else createPlane(size*aspect, size);
273
  }
274
  });
275
  }
276
  if (props.imageUrl) updateTexture(props.imageUrl);
277
 
278
- // --- 4. The Light Gizmo (Interactive) ---
279
- const lightGroup = new THREE.Group();
280
- scene.add(lightGroup);
281
-
282
- // The Orb
283
- const orb = new THREE.Mesh(
284
- new THREE.SphereGeometry(0.2, 32, 32),
285
- new THREE.MeshBasicMaterial({ color: 0xFFD700 })
286
- );
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
- // The Glow
289
- const glow = new THREE.Mesh(
290
- new THREE.SphereGeometry(0.35, 32, 32),
291
- new THREE.MeshBasicMaterial({ color: 0xFFD700, transparent: true, opacity: 0.4 })
292
- );
293
- orb.add(glow);
294
- lightGroup.add(orb);
295
-
296
- // The Beam (Cone)
297
- const beamGeo = new THREE.ConeGeometry(0.4, RADIUS, 32, 1, true);
298
- beamGeo.translate(0, -RADIUS/2, 0); // Pivot at base
299
- beamGeo.rotateX(-Math.PI / 2); // Point along Z
 
 
 
 
 
 
 
 
300
  const beamMat = new THREE.MeshBasicMaterial({
301
- color: 0xFFD700,
302
- transparent: true,
303
- opacity: 0.15,
304
- side: THREE.DoubleSide,
305
- depthWrite: false,
306
- blending: THREE.AdditiveBlending
307
  });
308
  const beam = new THREE.Mesh(beamGeo, beamMat);
309
- beam.lookAt(CENTER); // This will need dynamic updating
310
- lightGroup.add(beam);
 
 
 
 
 
 
311
 
312
  // --- 5. State & Logic ---
313
- let az = props.value?.azimuth || 0;
314
- let el = props.value?.elevation || 0;
 
 
 
315
 
316
  const AZ_MAP = {
317
  0: 'Front', 45: 'Right Front', 90: 'Right', 135: 'Right Rear',
@@ -321,114 +292,74 @@ class LightControl3D(gr.HTML):
321
  function getPrompt(a, e) {
322
  if (e >= 60) return "Light source from Above";
323
  if (e <= -60) return "Light source from Below";
324
- // Snap
325
  const steps = [0,45,90,135,180,225,270,315];
326
- // Handle wrapped 360
327
- let normalized = a % 360;
328
- if(normalized < 0) normalized += 360;
329
- const snapped = steps.reduce((p, c) => Math.abs(c-normalized) < Math.abs(p-normalized) ? c : p);
330
  return `Light source from the ${AZ_MAP[snapped]}`;
331
  }
332
 
333
- function updateGizmo() {
334
- const r_az = THREE.MathUtils.degToRad(az);
335
- const r_el = THREE.MathUtils.degToRad(el);
336
-
337
- // Orbit Calculation
338
- const x = RADIUS * Math.sin(r_az) * Math.cos(r_el);
339
- const y = RADIUS * Math.sin(r_el) + CENTER.y;
340
- const z = RADIUS * Math.cos(r_az) * Math.cos(r_el);
341
-
342
- lightGroup.position.set(x, y, z);
343
- lightGroup.lookAt(CENTER); // Points the Beam at center
344
-
345
- // UI Updates
346
- const text = getPrompt(az, el);
347
- badge.innerText = text;
348
-
349
- // Color Logic (Warning for Above/Below)
350
- let mainColor = 0xFFD700; // Gold
351
- if (el >= 60 || el <= -60) mainColor = 0xFF4500; // OrangeRed
352
-
353
- orb.material.color.setHex(mainColor);
354
- glow.material.color.setHex(mainColor);
355
- beam.material.color.setHex(mainColor);
356
- badge.style.borderColor = '#' + new THREE.Color(mainColor).getHexString();
357
- badge.style.color = '#' + new THREE.Color(mainColor).getHexString();
358
- }
359
-
360
- // --- 6. Interaction (Drag) ---
361
  const raycaster = new THREE.Raycaster();
362
  const mouse = new THREE.Vector2();
363
  let isDragging = false;
364
 
365
- // Invisible Drag Sphere (Larger hit area)
366
- const dragSphere = new THREE.Mesh(
367
- new THREE.SphereGeometry(RADIUS, 32, 16),
368
- new THREE.MeshBasicMaterial({ visible: false, side: THREE.DoubleSide })
369
  );
370
- dragSphere.position.copy(CENTER);
371
- scene.add(dragSphere);
372
-
373
- function getMouse(e) {
374
- const rect = wrapper.getBoundingClientRect();
375
- const clientX = e.clientX || (e.touches ? e.touches[0].clientX : 0);
376
- const clientY = e.clientY || (e.touches ? e.touches[0].clientY : 0);
377
- return {
378
- x: ((clientX - rect.left) / rect.width) * 2 - 1,
379
- y: -((clientY - rect.top) / rect.height) * 2 + 1
380
- };
 
381
  }
382
 
383
  function onDown(e) {
384
- const m = getMouse(e);
385
- raycaster.setFromCamera(m, camera);
386
- // Check if clicked near the light orb
387
- const intersects = raycaster.intersectObject(dragSphere);
 
 
 
 
388
  if(intersects.length > 0) {
389
- // Check distance to current light pos to prevent jumping if clicked far away
390
- if (intersects[0].point.distanceTo(lightGroup.position) < 1.0) {
391
- isDragging = true;
392
- wrapper.style.cursor = 'none'; // Hide cursor while dragging for immersion
393
- }
394
  }
395
  }
396
 
397
  function onMove(e) {
398
  if (!isDragging) {
399
- // Hover state
400
- const m = getMouse(e);
401
- raycaster.setFromCamera(m, camera);
402
- const hits = raycaster.intersectObject(dragSphere);
403
- if (hits.length > 0 && hits[0].point.distanceTo(lightGroup.position) < 0.8) {
404
- wrapper.style.cursor = 'pointer';
405
- } else {
406
- wrapper.style.cursor = 'default';
407
- }
408
  return;
409
  }
410
-
411
- const m = getMouse(e);
412
- raycaster.setFromCamera(m, camera);
413
- const intersects = raycaster.intersectObject(dragSphere);
414
 
 
 
 
 
 
 
 
 
415
  if (intersects.length > 0) {
416
- const p = intersects[0].point;
417
- const rel = new THREE.Vector3().subVectors(p, CENTER);
418
-
419
- // Convert Cartesian to Spherical (Azimuth/Elevation)
420
- let newAz = Math.atan2(rel.x, rel.z) * (180 / Math.PI);
421
- if (newAz < 0) newAz += 360;
422
-
423
- const distXZ = Math.sqrt(rel.x*rel.x + rel.z*rel.z);
424
- let newEl = Math.atan2(rel.y, distXZ) * (180 / Math.PI);
425
-
426
- // Limits
427
- newEl = Math.max(-89, Math.min(89, newEl));
428
-
429
- az = newAz;
430
- el = newEl;
431
- updateGizmo();
432
  }
433
  }
434
 
@@ -436,13 +367,11 @@ class LightControl3D(gr.HTML):
436
  if(isDragging) {
437
  isDragging = false;
438
  wrapper.style.cursor = 'default';
439
- // Propagate value back to Gradio
440
- props.value = { azimuth: az, elevation: el };
441
  trigger('change', props.value);
442
  }
443
  }
444
 
445
- // Event Listeners
446
  wrapper.addEventListener('mousedown', onDown);
447
  window.addEventListener('mousemove', onMove);
448
  window.addEventListener('mouseup', onUp);
@@ -450,34 +379,64 @@ class LightControl3D(gr.HTML):
450
  window.addEventListener('touchmove', onMove, {passive: false});
451
  window.addEventListener('touchend', onUp);
452
 
453
- // --- 7. Loop & Watchers ---
454
- updateGizmo(); // Init
455
-
456
  function animate() {
457
  requestAnimationFrame(animate);
458
- // Subtle idle animation for the glow
459
- glow.scale.setScalar(1 + Math.sin(Date.now() * 0.003) * 0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  renderer.render(scene, camera);
461
  }
462
  animate();
463
 
464
- // Watch for changes from Python/Sliders
465
  setInterval(() => {
466
- // Texture change
467
- if (props.imageUrl && (!planeMat.map || props.imageUrl !== planeMat.map.image.src)) {
468
- // handled by dedicated updater usually, but fail-safe
469
- }
470
- // Value change
471
  if (props.value && !isDragging) {
472
- if (Math.abs(props.value.azimuth - az) > 0.1 || Math.abs(props.value.elevation - el) > 0.1) {
473
- az = props.value.azimuth;
474
- el = props.value.elevation;
475
- updateGizmo();
476
  }
477
  }
478
  }, 100);
479
 
480
- // Expose updater
481
  wrapper._updateTexture = updateTexture;
482
  };
483
  initScene();
@@ -493,106 +452,79 @@ class LightControl3D(gr.HTML):
493
  )
494
 
495
  # --- UI Layout ---
496
-
497
  css = """
498
- #col-container { max-width: 1200px; margin: 0 auto; }
499
- #3d-container { border: 1px solid #333; border-radius: 12px; overflow: hidden; }
500
  .range-slider { accent-color: #FFD700 !important; }
501
  """
502
 
503
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="yellow")) as demo:
504
  gr.Markdown("""
505
- # 💡 Qwen Edit 2509 — 3D Lighting Studio
506
-
507
- **Interactive Relighting:** Drag the ☀️ Sun in the 3D Viewport to change the lighting direction.
508
  """)
509
 
510
  with gr.Row():
511
  # --- Left Column: Controls ---
512
  with gr.Column(scale=5):
513
- # Input
514
- image_input = gr.Image(label="Input Image", type="pil", height=320)
 
 
515
 
516
- gr.Markdown("### 🎮 3D Controller")
517
  light_controller = LightControl3D(
518
  value={"azimuth": 0, "elevation": 0},
519
  elem_id="3d-container"
520
  )
521
 
522
- # Action
523
- run_btn = gr.Button("✨ Generate Lighting", variant="primary", size="lg")
524
 
525
- # Fine Tuning
526
- with gr.Accordion("🎚️ Fine-Tune & Advanced", open=False):
527
  with gr.Row():
528
  az_slider = gr.Slider(0, 359, value=0, label="Azimuth", step=1)
529
  el_slider = gr.Slider(-90, 90, value=0, label="Elevation", step=1)
530
-
531
  with gr.Row():
532
  seed = gr.Slider(0, MAX_SEED, value=42, label="Seed", step=1)
533
  randomize_seed = gr.Checkbox(True, label="Randomize")
534
-
535
  with gr.Row():
536
  cfg = gr.Slider(1.0, 10.0, value=5.0, label="Guidance (CFG)")
537
  steps = gr.Slider(1, 20, value=4, step=1, label="Steps")
538
-
539
- prompt_display = gr.Textbox(label="Actual Prompt sent to Model", interactive=False)
540
-
541
- # --- Right Column: Output ---
542
- with gr.Column(scale=4):
543
- result_output = gr.Image(label="Result", height=600)
544
 
545
- # --- wiring ---
546
 
547
- # 1. Sync 3D -> Sliders & Text
548
  def on_3d_change(val):
549
- az = val.get('azimuth', 0)
550
- el = val.get('elevation', 0)
551
- prompt = build_lighting_prompt(az, el)
552
- return az, el, prompt
553
-
554
- light_controller.change(
555
- on_3d_change,
556
- inputs=[light_controller],
557
- outputs=[az_slider, el_slider, prompt_display]
558
- )
559
 
560
- # 2. Sync Sliders -> 3D & Text
561
  def on_slider_change(az, el):
562
- prompt = build_lighting_prompt(az, el)
563
- return {"azimuth": az, "elevation": el}, prompt
564
 
565
  az_slider.change(on_slider_change, inputs=[az_slider, el_slider], outputs=[light_controller, prompt_display])
566
  el_slider.change(on_slider_change, inputs=[az_slider, el_slider], outputs=[light_controller, prompt_display])
567
 
568
- # 3. Handle Image Upload (Resize + Update 3D Texture)
569
  def on_upload(img):
570
  w, h = update_dimensions_on_upload(img)
571
- if img is None:
572
- return w, h, gr.update(imageUrl=None)
573
-
574
- # Convert to Base64 for Three.js
575
  import base64
576
  from io import BytesIO
577
  buffered = BytesIO()
578
  img.save(buffered, format="PNG")
579
  img_str = base64.b64encode(buffered.getvalue()).decode()
580
- data_url = f"data:image/png;base64,{img_str}"
581
- return w, h, gr.update(imageUrl=data_url)
582
-
583
- image_input.upload(
584
- on_upload,
585
- inputs=[image_input],
586
- outputs=[gr.State(), gr.State(), light_controller] # We store W/H in state mostly, or just pass to infer
587
- ).then(
588
- # Pass W/H to hidden sliders or just recalc in infer for simplicity
589
- None, None, None
590
- )
591
 
592
- # 4. Generate
 
 
593
  def run_inference_wrapper(img, az, el, seed, rand, cfg, steps):
594
- w, h = update_dimensions_on_upload(img) # Recalc dims here for safety
595
- res, used_seed, p = infer_lighting_edit(img, az, el, seed, rand, cfg, steps, h, w)
596
  return res
597
 
598
  run_btn.click(
@@ -602,6 +534,5 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="yellow")) as demo:
602
  )
603
 
604
  if __name__ == "__main__":
605
- # CDN Load Three.js
606
  head_js = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
607
  demo.launch(head=head_js)
 
12
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
13
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
14
  except ImportError:
 
15
  raise ImportError("Please ensure the 'qwenimage' package is installed.")
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
 
20
  dtype = torch.bfloat16
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
 
23
  pipe = QwenImageEditPlusPipeline.from_pretrained(
24
  "Qwen/Qwen-Image-Edit-2511",
25
  transformer=QwenImageTransformer2DModel.from_pretrained(
 
30
  torch_dtype=dtype
31
  ).to(device)
32
 
 
33
  try:
34
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
35
  print("Flash Attention 3 Processor set successfully.")
36
  except Exception as e:
37
  print(f"Warning: Could not set FA3 processor: {e}")
38
 
 
39
  ADAPTER_SPECS = {
40
  "Multi-Angle-Lighting": {
41
  "repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
 
44
  }
45
  }
46
 
 
47
  CURRENT_LOADED_ADAPTER = None
48
 
49
+ # --- Logic: Mappings ---
 
 
 
50
  LIGHTING_AZIMUTH_MAP = {
51
  0: "Light source from the Front",
52
  45: "Light source from the Right Front",
 
59
  }
60
 
61
  def snap_to_nearest_key(value, keys):
 
62
  return min(keys, key=lambda x: abs(x - value))
63
 
64
  def build_lighting_prompt(azimuth: float, elevation: float) -> str:
65
+ if elevation >= 60: return "Light source from Above"
66
+ if elevation <= -60: return "Light source from Below"
 
 
 
 
 
 
 
 
 
 
 
67
  keys = list(LIGHTING_AZIMUTH_MAP.keys())
68
+ if azimuth > 337.5: azimuth = 0
 
 
 
 
69
  azimuth_snapped = snap_to_nearest_key(azimuth, keys)
70
  return LIGHTING_AZIMUTH_MAP[azimuth_snapped]
71
 
72
+ # --- Inference ---
 
73
  @spaces.GPU
74
+ def infer_lighting_edit(image, azimuth, elevation, seed, randomize_seed, guidance_scale, num_inference_steps, height, width):
 
 
 
 
 
 
 
 
 
 
75
  global CURRENT_LOADED_ADAPTER
 
 
76
  spec = ADAPTER_SPECS["Multi-Angle-Lighting"]
77
  if CURRENT_LOADED_ADAPTER != spec["adapter_name"]:
78
+ pipe.load_lora_weights(spec["repo"], weight_name=spec["weights"], adapter_name=spec["adapter_name"])
 
 
 
 
 
79
  pipe.set_adapters([spec["adapter_name"]], adapter_weights=[1.0])
80
  CURRENT_LOADED_ADAPTER = spec["adapter_name"]
81
 
 
82
  prompt = build_lighting_prompt(azimuth, elevation)
83
  print(f"💡 Generated Prompt: {prompt}")
84
 
85
+ if image is None: raise gr.Error("Please upload an image.")
86
+ if randomize_seed: seed = random.randint(0, MAX_SEED)
87
+
 
 
 
88
  generator = torch.Generator(device=device).manual_seed(seed)
 
89
  pil_image = image.convert("RGB")
90
 
 
91
  result = pipe(
92
+ image=[pil_image], prompt=prompt, height=height, width=width,
93
+ num_inference_steps=num_inference_steps, generator=generator,
94
+ guidance_scale=guidance_scale, num_images_per_prompt=1
 
 
 
 
 
95
  ).images[0]
 
96
  return result, seed, prompt
97
 
98
  def update_dimensions_on_upload(image):
99
+ if image is None: return 1024, 1024
 
 
100
  w, h = image.size
101
+ if w > h: new_w, new_h = 1024, int(1024 * (h / w))
102
+ else: new_h, new_w = 1024, int(1024 * (w / h))
103
+ return (new_w // 8) * 8, (new_h // 8) * 8
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ # --- 🌟 OPTIMIZED 3D CONTROLLER 🌟 ---
106
  class LightControl3D(gr.HTML):
107
  """
108
+ High-Fidelity 3D Studio Light Controller
 
109
  """
110
  def __init__(self, value=None, imageUrl=None, **kwargs):
111
  if value is None: value = {"azimuth": 0, "elevation": 0}
112
 
 
113
  html_template = """
114
+ <div id="light-control-wrapper" style="width: 100%; height: 500px; position: relative; background: #0f0f0f; border-radius: 12px; overflow: hidden; border: 1px solid #333; box-shadow: 0 4px 20px rgba(0,0,0,0.5);">
115
+ <div id="prompt-badge" style="position: absolute; top: 20px; left: 50%; transform: translateX(-50%);
116
+ background: rgba(10,10,10,0.85); border: 1px solid #FFD700; color: #FFD700;
117
+ padding: 10px 30px; border-radius: 8px; font-family: 'Segoe UI', monospace; font-weight: 600; font-size: 14px;
118
+ z-index: 10; pointer-events: none; backdrop-filter: blur(4px); box-shadow: 0 4px 12px rgba(0,0,0,0.5);
119
+ transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);">
120
  Light Source: Front
121
  </div>
122
 
123
+ <div style="position: absolute; bottom: 20px; left: 20px; color: #666; font-size: 11px; font-family: sans-serif; pointer-events: none; line-height: 1.5;">
124
+ <span style="color: #FFD700">●</span> Light Position<br>
125
+ <span style="color: #444">■</span> Shadow Caster
126
  </div>
127
  </div>
128
  """
129
 
 
130
  js_on_load = """
131
  (() => {
132
  const wrapper = element.querySelector('#light-control-wrapper');
 
135
  const initScene = () => {
136
  if (typeof THREE === 'undefined') { setTimeout(initScene, 100); return; }
137
 
138
+ // --- 1. Scene & Renderer (Enable Shadows) ---
139
  const scene = new THREE.Scene();
140
+ scene.background = new THREE.Color(0x0f0f0f);
141
+ scene.fog = new THREE.Fog(0x0f0f0f, 4, 12);
142
 
143
+ const camera = new THREE.PerspectiveCamera(45, wrapper.clientWidth / wrapper.clientHeight, 0.1, 100);
144
+ camera.position.set(5, 4, 5);
145
  camera.lookAt(0, 0.5, 0);
146
 
147
+ const renderer = new THREE.WebGLRenderer({ antialias: true });
148
  renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
149
  renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
150
+ renderer.shadowMap.enabled = true; // CRITICAL: Enable Shadows
151
+ renderer.shadowMap.type = THREE.PCFSoftShadowMap;
152
  wrapper.appendChild(renderer.domElement);
153
 
154
+ // --- 2. Environment (Studio Floor) ---
155
  const CENTER = new THREE.Vector3(0, 0.75, 0);
156
+ const RADIUS = 2.8;
157
+
158
+ // Floor (Shadow Catcher)
159
+ const planeGeo = new THREE.PlaneGeometry(20, 20);
160
+ const planeMat = new THREE.ShadowMaterial({ opacity: 0.5, color: 0x000000 });
161
+ const floor = new THREE.Mesh(planeGeo, planeMat);
162
+ floor.rotation.x = -Math.PI / 2;
163
+ floor.position.y = 0;
164
+ floor.receiveShadow = true;
165
+ scene.add(floor);
166
 
167
+ // Grid (Visual Reference)
168
+ const grid = new THREE.GridHelper(10, 20, 0x333333, 0x111111);
169
  scene.add(grid);
170
+
171
+ // --- 3. The Subject (Thick Slab) ---
172
+ let subjectGroup = new THREE.Group();
173
+ scene.add(subjectGroup);
174
 
175
+ // Base Stand
176
+ const standGeo = new THREE.CylinderGeometry(0.8, 0.8, 0.05, 32);
177
+ const standMat = new THREE.MeshStandardMaterial({ color: 0x222222, roughness: 0.2, metalness: 0.8 });
178
+ const stand = new THREE.Mesh(standGeo, standMat);
179
+ stand.position.y = 0.025;
180
+ stand.receiveShadow = true;
181
+ subjectGroup.add(stand);
182
+
183
+ // The Image Box
184
+ let imageMesh;
185
+ const imageMat = new THREE.MeshStandardMaterial({
186
+ color: 0xaaaaaa, roughness: 0.4, side: THREE.DoubleSide
187
+ });
 
 
 
 
188
 
189
+ function createSubject(width=1.5, height=1.5) {
190
+ if(imageMesh) {
191
+ subjectGroup.remove(imageMesh);
192
+ imageMesh.geometry.dispose();
193
+ }
194
+ // A box with slight thickness to cast better shadows
195
+ imageMesh = new THREE.Mesh(new THREE.BoxGeometry(width, height, 0.05), imageMat);
196
+ imageMesh.position.set(0, height/2 + 0.1, 0);
197
+ imageMesh.castShadow = true; // CAST SHADOW
198
+ imageMesh.receiveShadow = true;
199
+ subjectGroup.add(imageMesh);
200
  }
201
+ createSubject();
202
 
203
  // Texture Loader
204
  function updateTexture(url) {
205
+ if (!url) { imageMat.map = null; imageMat.needsUpdate = true; return; }
 
 
 
 
206
  new THREE.TextureLoader().load(url, (tex) => {
207
+ tex.encoding = THREE.sRGBEncoding;
208
+ imageMat.map = tex;
209
+ imageMat.color.setHex(0xffffff); // Reset tint
210
+ imageMat.needsUpdate = true;
211
+
212
  const img = tex.image;
213
  if(img && img.width && img.height) {
214
  const aspect = img.width / img.height;
215
+ const maxDim = 1.8;
216
+ if (aspect > 1) createSubject(maxDim, maxDim/aspect);
217
+ else createSubject(maxDim*aspect, maxDim);
218
  }
219
  });
220
  }
221
  if (props.imageUrl) updateTexture(props.imageUrl);
222
 
223
+ // --- 4. The Light System (Visual + Actual Light) ---
224
+ const lightPivot = new THREE.Group();
225
+ scene.add(lightPivot);
226
+
227
+ // Actual Light Source (for Shadows)
228
+ const dirLight = new THREE.DirectionalLight(0xffffff, 1.5);
229
+ dirLight.castShadow = true;
230
+ dirLight.shadow.mapSize.width = 1024;
231
+ dirLight.shadow.mapSize.height = 1024;
232
+ dirLight.shadow.camera.near = 0.5;
233
+ dirLight.shadow.camera.far = 10;
234
+ lightPivot.add(dirLight);
235
+
236
+ // Visual Representation (The Gizmo)
237
+ const gizmoGroup = new THREE.Group();
238
+ lightPivot.add(gizmoGroup);
239
+
240
+ // The "Sun" Mesh
241
+ const sunGeo = new THREE.SphereGeometry(0.15, 32, 32);
242
+ const sunMat = new THREE.MeshBasicMaterial({ color: 0xFFD700 });
243
+ const sun = new THREE.Mesh(sunGeo, sunMat);
244
+ gizmoGroup.add(sun);
245
 
246
+ // The "Glow" Sprite
247
+ const canvas = document.createElement('canvas');
248
+ canvas.width = 128; canvas.height = 128;
249
+ const context = canvas.getContext('2d');
250
+ const gradient = context.createRadialGradient(64, 64, 0, 64, 64, 64);
251
+ gradient.addColorStop(0, 'rgba(255, 215, 0, 1)');
252
+ gradient.addColorStop(0.4, 'rgba(255, 215, 0, 0.2)');
253
+ gradient.addColorStop(1, 'rgba(0, 0, 0, 0)');
254
+ context.fillStyle = gradient;
255
+ context.fillRect(0, 0, 128, 128);
256
+ const glowTex = new THREE.CanvasTexture(canvas);
257
+ const glowMat = new THREE.SpriteMaterial({ map: glowTex, transparent: true, blending: THREE.AdditiveBlending });
258
+ const glow = new THREE.Sprite(glowMat);
259
+ glow.scale.set(1.5, 1.5, 1);
260
+ gizmoGroup.add(glow);
261
+
262
+ // The Volumetric Cone (Beam)
263
+ const beamGeo = new THREE.ConeGeometry(0.35, RADIUS, 32, 1, true);
264
+ beamGeo.translate(0, -RADIUS/2, 0);
265
+ beamGeo.rotateX(-Math.PI / 2);
266
  const beamMat = new THREE.MeshBasicMaterial({
267
+ color: 0xFFD700, transparent: true, opacity: 0.08,
268
+ side: THREE.DoubleSide, depthWrite: false, blending: THREE.AdditiveBlending
 
 
 
 
269
  });
270
  const beam = new THREE.Mesh(beamGeo, beamMat);
271
+ beam.lookAt(new THREE.Vector3(0,0,0)); // Initialize pointing center
272
+ gizmoGroup.add(beam);
273
+
274
+ // Dome Visual Guide (Subtle)
275
+ const domeGeo = new THREE.IcosahedronGeometry(RADIUS, 1);
276
+ const domeWire = new THREE.WireframeGeometry(domeGeo);
277
+ const domeLine = new THREE.LineSegments(domeWire, new THREE.LineBasicMaterial({ color: 0x333333, transparent: true, opacity: 0.1 }));
278
+ scene.add(domeLine);
279
 
280
  // --- 5. State & Logic ---
281
+ // We keep two sets of angles: Target (where mouse is) and Current (where light is, for animation)
282
+ let targetAz = props.value?.azimuth || 0;
283
+ let targetEl = props.value?.elevation || 0;
284
+ let currentAz = targetAz;
285
+ let currentEl = targetEl;
286
 
287
  const AZ_MAP = {
288
  0: 'Front', 45: 'Right Front', 90: 'Right', 135: 'Right Rear',
 
292
  function getPrompt(a, e) {
293
  if (e >= 60) return "Light source from Above";
294
  if (e <= -60) return "Light source from Below";
295
+ let n = a % 360; if(n < 0) n += 360;
296
  const steps = [0,45,90,135,180,225,270,315];
297
+ const snapped = steps.reduce((p, c) => Math.abs(c-n) < Math.abs(p-n) ? c : p);
 
 
 
298
  return `Light source from the ${AZ_MAP[snapped]}`;
299
  }
300
 
301
+ // --- 6. Input & Raycasting ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  const raycaster = new THREE.Raycaster();
303
  const mouse = new THREE.Vector2();
304
  let isDragging = false;
305
 
306
+ // Invisible interaction sphere
307
+ const dragMesh = new THREE.Mesh(
308
+ new THREE.SphereGeometry(RADIUS, 32, 32),
309
+ new THREE.MeshBasicMaterial({ visible: false, side: THREE.BackSide })
310
  );
311
+ scene.add(dragMesh);
312
+
313
+ function updateFromIntersect(point) {
314
+ const rel = new THREE.Vector3().subVectors(point, CENTER);
315
+ let az = Math.atan2(rel.x, rel.z) * (180 / Math.PI);
316
+ if (az < 0) az += 360;
317
+ const distXZ = Math.sqrt(rel.x*rel.x + rel.z*rel.z);
318
+ let el = Math.atan2(rel.y, distXZ) * (180 / Math.PI);
319
+ el = Math.max(-89, Math.min(89, el)); // Clamp
320
+
321
+ targetAz = az;
322
+ targetEl = el;
323
  }
324
 
325
  function onDown(e) {
326
+ const rect = wrapper.getBoundingClientRect();
327
+ const cx = e.clientX || (e.touches ? e.touches[0].clientX : 0);
328
+ const cy = e.clientY || (e.touches ? e.touches[0].clientY : 0);
329
+ mouse.x = ((cx - rect.left) / rect.width) * 2 - 1;
330
+ mouse.y = -((cy - rect.top) / rect.height) * 2 + 1;
331
+
332
+ raycaster.setFromCamera(mouse, camera);
333
+ const intersects = raycaster.intersectObject(dragMesh);
334
  if(intersects.length > 0) {
335
+ isDragging = true;
336
+ wrapper.style.cursor = 'none'; // Immersion
337
+ updateFromIntersect(intersects[0].point);
 
 
338
  }
339
  }
340
 
341
  function onMove(e) {
342
  if (!isDragging) {
343
+ // Hover cursor logic
344
+ const rect = wrapper.getBoundingClientRect();
345
+ mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
346
+ mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
347
+ raycaster.setFromCamera(mouse, camera);
348
+ const hits = raycaster.intersectObject(dragMesh);
349
+ wrapper.style.cursor = (hits.length > 0) ? 'grab' : 'default';
 
 
350
  return;
351
  }
 
 
 
 
352
 
353
+ const rect = wrapper.getBoundingClientRect();
354
+ const cx = e.clientX || (e.touches ? e.touches[0].clientX : 0);
355
+ const cy = e.clientY || (e.touches ? e.touches[0].clientY : 0);
356
+ mouse.x = ((cx - rect.left) / rect.width) * 2 - 1;
357
+ mouse.y = -((cy - rect.top) / rect.height) * 2 + 1;
358
+
359
+ raycaster.setFromCamera(mouse, camera);
360
+ const intersects = raycaster.intersectObject(dragMesh);
361
  if (intersects.length > 0) {
362
+ updateFromIntersect(intersects[0].point);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  }
364
  }
365
 
 
367
  if(isDragging) {
368
  isDragging = false;
369
  wrapper.style.cursor = 'default';
370
+ props.value = { azimuth: targetAz, elevation: targetEl };
 
371
  trigger('change', props.value);
372
  }
373
  }
374
 
 
375
  wrapper.addEventListener('mousedown', onDown);
376
  window.addEventListener('mousemove', onMove);
377
  window.addEventListener('mouseup', onUp);
 
379
  window.addEventListener('touchmove', onMove, {passive: false});
380
  window.addEventListener('touchend', onUp);
381
 
382
+ // --- 7. Animation Loop (Visuals) ---
 
 
383
  function animate() {
384
  requestAnimationFrame(animate);
385
+
386
+ // LERPING: Smoothly move current light pos to target
387
+ // This creates the "swish" feeling
388
+ const lerpSpeed = 0.15;
389
+
390
+ // Handle Azimuth wraparound (359 -> 0)
391
+ let dAz = targetAz - currentAz;
392
+ if (dAz > 180) dAz -= 360;
393
+ if (dAz < -180) dAz += 360;
394
+ currentAz += dAz * lerpSpeed;
395
+
396
+ currentEl += (targetEl - currentEl) * lerpSpeed;
397
+
398
+ // Calculate Cartesian
399
+ const r_az = THREE.MathUtils.degToRad(currentAz);
400
+ const r_el = THREE.MathUtils.degToRad(currentEl);
401
+ const x = RADIUS * Math.sin(r_az) * Math.cos(r_el);
402
+ const y = RADIUS * Math.sin(r_el) + CENTER.y;
403
+ const z = RADIUS * Math.cos(r_az) * Math.cos(r_el);
404
+
405
+ // Update Light Position
406
+ gizmoGroup.position.set(x, y, z);
407
+ gizmoGroup.lookAt(CENTER);
408
+
409
+ dirLight.position.set(x, y, z); // Update actual shadow-casting light
410
+
411
+ // Update Text
412
+ badge.textContent = getPrompt(currentAz, currentEl);
413
+
414
+ // Dynamic Coloring
415
+ let colorHex = 0xFFD700; // Gold
416
+ if (currentEl > 55 || currentEl < -55) colorHex = 0xFF4500; // Warning Orange
417
+
418
+ const col = new THREE.Color(colorHex);
419
+ sunMat.color.lerp(col, 0.1);
420
+ beamMat.color.lerp(col, 0.1);
421
+ glowMat.color.lerp(col, 0.1);
422
+ badge.style.borderColor = '#' + col.getHexString();
423
+ badge.style.color = '#' + col.getHexString();
424
+
425
  renderer.render(scene, camera);
426
  }
427
  animate();
428
 
429
+ // External Updates
430
  setInterval(() => {
 
 
 
 
 
431
  if (props.value && !isDragging) {
432
+ // If controls updated externally (sliders), snap target
433
+ if (Math.abs(props.value.azimuth - targetAz) > 0.5 || Math.abs(props.value.elevation - targetEl) > 0.5) {
434
+ targetAz = props.value.azimuth;
435
+ targetEl = props.value.elevation;
436
  }
437
  }
438
  }, 100);
439
 
 
440
  wrapper._updateTexture = updateTexture;
441
  };
442
  initScene();
 
452
  )
453
 
454
  # --- UI Layout ---
 
455
  css = """
456
+ #col-container { max-width: 1400px; margin: 0 auto; }
457
+ #3d-container { border-radius: 12px; overflow: hidden; }
458
  .range-slider { accent-color: #FFD700 !important; }
459
  """
460
 
461
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="yellow")) as demo:
462
  gr.Markdown("""
463
+ # 💡 Qwen Edit 2509 — Studio Light Controller
464
+ **Interactive Relighting:** Drag the light in the 3D studio to cast realistic shadows and generate lighting prompts.
 
465
  """)
466
 
467
  with gr.Row():
468
  # --- Left Column: Controls ---
469
  with gr.Column(scale=5):
470
+ gr.Markdown("### 📸 Input & Control")
471
+ with gr.Row():
472
+ image_input = gr.Image(label="Input", type="pil", height=280)
473
+ result_output = gr.Image(label="Result", height=280)
474
 
475
+ gr.Markdown("### 🎮 Studio Controller")
476
  light_controller = LightControl3D(
477
  value={"azimuth": 0, "elevation": 0},
478
  elem_id="3d-container"
479
  )
480
 
481
+ run_btn = gr.Button("✨ Relight Image", variant="primary", size="lg")
 
482
 
483
+ with gr.Accordion("🎚️ Fine-Tune Parameters", open=False):
 
484
  with gr.Row():
485
  az_slider = gr.Slider(0, 359, value=0, label="Azimuth", step=1)
486
  el_slider = gr.Slider(-90, 90, value=0, label="Elevation", step=1)
 
487
  with gr.Row():
488
  seed = gr.Slider(0, MAX_SEED, value=42, label="Seed", step=1)
489
  randomize_seed = gr.Checkbox(True, label="Randomize")
 
490
  with gr.Row():
491
  cfg = gr.Slider(1.0, 10.0, value=5.0, label="Guidance (CFG)")
492
  steps = gr.Slider(1, 20, value=4, step=1, label="Steps")
493
+ prompt_display = gr.Textbox(label="Active Prompt", interactive=False)
 
 
 
 
 
494
 
495
+ # --- Event Wiring ---
496
 
497
+ # 3D -> Sliders
498
  def on_3d_change(val):
499
+ az, el = val.get('azimuth', 0), val.get('elevation', 0)
500
+ return az, el, build_lighting_prompt(az, el)
501
+
502
+ light_controller.change(on_3d_change, inputs=[light_controller], outputs=[az_slider, el_slider, prompt_display])
 
 
 
 
 
 
503
 
504
+ # Sliders -> 3D
505
  def on_slider_change(az, el):
506
+ return {"azimuth": az, "elevation": el}, build_lighting_prompt(az, el)
 
507
 
508
  az_slider.change(on_slider_change, inputs=[az_slider, el_slider], outputs=[light_controller, prompt_display])
509
  el_slider.change(on_slider_change, inputs=[az_slider, el_slider], outputs=[light_controller, prompt_display])
510
 
511
+ # Upload
512
  def on_upload(img):
513
  w, h = update_dimensions_on_upload(img)
514
+ if img is None: return w, h, gr.update(imageUrl=None)
 
 
 
515
  import base64
516
  from io import BytesIO
517
  buffered = BytesIO()
518
  img.save(buffered, format="PNG")
519
  img_str = base64.b64encode(buffered.getvalue()).decode()
520
+ return w, h, gr.update(imageUrl=f"data:image/png;base64,{img_str}")
 
 
 
 
 
 
 
 
 
 
521
 
522
+ image_input.upload(on_upload, inputs=[image_input], outputs=[gr.State(), gr.State(), light_controller])
523
+
524
+ # Run
525
  def run_inference_wrapper(img, az, el, seed, rand, cfg, steps):
526
+ w, h = update_dimensions_on_upload(img)
527
+ res, _, _ = infer_lighting_edit(img, az, el, seed, rand, cfg, steps, h, w)
528
  return res
529
 
530
  run_btn.click(
 
534
  )
535
 
536
  if __name__ == "__main__":
 
537
  head_js = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
538
  demo.launch(head=head_js)