ysharma HF Staff commited on
Commit
090805a
·
verified ·
1 Parent(s): f0d6a54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -264
app.py CHANGED
@@ -1,395 +1,350 @@
1
- """
2
- Camera Control App with Working Arrow Interface
3
- Complete version with Qwen model integration
4
- """
5
-
6
  import gradio as gr
7
- import torch
8
  import numpy as np
9
  import random
10
- from PIL import Image
11
- import spaces
12
- from diffusers import DiffusionPipeline
13
  import base64
14
  from io import BytesIO
 
 
 
 
15
 
16
- # Model configuration
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
- dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
  pipe = None
22
 
23
  def load_model():
24
- """Load the Qwen diffusion model with camera control LoRAs."""
25
  global pipe
26
  if pipe is None:
27
- pipe = DiffusionPipeline.from_pretrained(
28
- "multimodalart/qwen-image-multiple-angles-3d-camera",
29
- torch_dtype=dtype,
30
  ).to(device)
31
- pipe.load_lora_weights("multimodalart/qwen-image-multiple-angles-3d-camera", weight_name="lightning.safetensors")
32
- pipe.fuse_lora(lora_scale=1.0)
33
- pipe.load_lora_weights("multimodalart/qwen-image-multiple-angles-3d-camera", weight_name="multi_angles.safetensors", adapter_name="multi_angles")
34
- pipe.set_adapters(["default", "multi_angles"], adapter_weights=[1.0, 1.0])
 
 
 
 
 
 
 
 
 
 
 
 
35
  return pipe
36
 
37
- # Camera parameter mappings
38
- azimuth_mapping = {
39
- 0: "front view",
40
  45: "front-right quarter view",
41
  90: "right side view",
42
- 135: "back-right quarter view",
43
- 180: "back view",
44
  225: "back-left quarter view",
45
  270: "left side view",
46
  315: "front-left quarter view"
47
  }
48
 
49
- elevation_mapping = {
50
- -30: "low-angle shot",
51
  0: "eye-level shot",
52
- 30: "elevated shot",
53
  60: "high-angle shot"
54
  }
55
 
56
- distance_mapping = {
57
- 0.6: "close-up",
58
- 1.0: "medium shot",
59
  1.8: "wide shot"
60
  }
61
 
62
  def snap_to_nearest(value, steps):
63
- """Snap a value to the nearest step in a list."""
64
  return min(steps, key=lambda x: abs(x - value))
65
 
66
  def build_camera_prompt(azimuth, elevation, distance):
67
- """Build camera prompt from numerical parameters."""
68
  azimuth_steps = [0, 45, 90, 135, 180, 225, 270, 315]
69
  elevation_steps = [-30, 0, 30, 60]
70
  distance_steps = [0.6, 1.0, 1.8]
71
 
72
- azimuth_snapped = snap_to_nearest(azimuth, azimuth_steps)
73
- elevation_snapped = snap_to_nearest(elevation, elevation_steps)
74
- distance_snapped = snap_to_nearest(distance, distance_steps)
75
 
76
- azimuth_name = azimuth_mapping[azimuth_snapped]
77
- elevation_name = elevation_mapping[elevation_snapped]
78
- distance_name = distance_mapping[distance_snapped]
79
 
80
- return f"<sks> {azimuth_name} {elevation_name} {distance_name}"
81
-
82
- @spaces.GPU(duration=5)
83
- def infer_camera_edit(
84
- image: Image.Image,
85
- azimuth: float = 0.0,
86
- elevation: float = 0.0,
87
- distance: float = 1.0,
88
- seed: int = 0,
89
- randomize_seed: bool = True,
90
- guidance_scale: float = 1.0,
91
- num_inference_steps: int = 4,
92
- height: int = 1024,
93
- width: int = 1024,
94
- ):
95
- """Generate new camera view using Qwen model."""
96
- prompt = build_camera_prompt(azimuth, elevation, distance)
97
- print(f"Generated Prompt: {prompt}")
98
 
 
 
99
  if randomize_seed:
100
  seed = random.randint(0, MAX_SEED)
 
101
  generator = torch.Generator(device=device).manual_seed(seed)
102
-
103
- if image is None:
104
- raise gr.Error("Please upload an image first.")
105
-
106
- pil_image = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
107
-
108
- # Load model only when needed
109
- current_pipe = load_model()
110
-
111
- result = current_pipe(
112
- image=[pil_image],
113
  prompt=prompt,
114
- height=height if height != 0 else None,
115
- width=width if width != 0 else None,
116
- num_inference_steps=num_inference_steps,
117
- generator=generator,
118
  guidance_scale=guidance_scale,
119
- num_images_per_prompt=1,
 
120
  ).images[0]
121
-
122
  return result, seed, prompt
123
 
124
  def create_camera_control_app():
125
- """Create the complete working camera control app."""
126
 
127
- with gr.Blocks(title="Camera Control with Directional Arrows", theme=gr.themes.Soft()) as demo:
128
  gr.Markdown("# 📸 Camera Control with Directional Arrows")
129
- gr.Markdown("Upload an image and use the directional arrows to control camera angles")
130
 
131
  with gr.Row():
132
- # Left column: Input image and settings
133
  with gr.Column(scale=1):
134
  image = gr.Image(label="Upload Image", type="pil", height=400)
135
 
136
- # Camera parameter inputs (visible for debugging, can be hidden later)
137
  js_azimuth = gr.Textbox("0", visible=True, elem_id="js-azimuth", label="Azimuth")
138
  js_elevation = gr.Textbox("0", visible=True, elem_id="js-elevation", label="Elevation")
139
  js_distance = gr.Textbox("1.0", visible=True, elem_id="js-distance", label="Distance")
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  prompt_display = gr.Textbox(
142
- label="Current Camera Prompt",
143
  value="<sks> front view eye-level shot medium shot",
144
  interactive=False
145
  )
146
 
147
- # Advanced settings
148
- with gr.Accordion("⚙️ Advanced Settings", open=False):
149
- seed = gr.Slider(
150
- label="Seed",
151
- minimum=0,
152
- maximum=MAX_SEED,
153
- step=1,
154
- value=0,
155
- )
156
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
157
-
158
- with gr.Row():
159
- guidance_scale = gr.Slider(
160
- label="Guidance scale",
161
- minimum=0.1,
162
- maximum=2.0,
163
- step=0.1,
164
- value=1.0,
165
- )
166
- num_inference_steps = gr.Slider(
167
- label="Number of inference steps",
168
- minimum=1,
169
- maximum=8,
170
- step=1,
171
- value=4,
172
- )
173
-
174
- with gr.Row():
175
- height = gr.Slider(
176
- label="Height",
177
- minimum=256,
178
- maximum=1024,
179
- step=32,
180
- value=1024,
181
- )
182
- width = gr.Slider(
183
- label="Width",
184
- minimum=256,
185
- maximum=1024,
186
- step=32,
187
- value=1024,
188
- )
189
-
190
- # Right column: Interactive image view
191
  with gr.Column(scale=1):
192
  gr.Markdown("### 🎯 Interactive Image View")
193
- gr.Markdown("*Upload an image, then hover to see camera controls and click arrows to generate new views*")
194
 
195
  # Interactive HTML component using working pattern
196
  result_display = gr.HTML(
197
- value=\"\"\"
198
- <div style=\"width: 100%; height: 500px; background: #f8f8f8; border: 2px solid #e0e0e0; border-radius: 12px;
199
- position: relative; display: flex; align-items: center; justify-content: center;\">
200
- <div style=\"text-align: center; color: #999;\">
201
- <div style=\"font-size: 48px; margin-bottom: 10px;\">📸</div>
202
  <p>Upload an image on the left to begin</p>
203
  <p>Then hover here to see camera controls</p>
204
  </div>
205
  </div>
206
- \"\"\",
207
  elem_id="result-display"
208
  )
209
 
210
- # ===== FUNCTIONS INSIDE BLOCKS CONTEXT =====
211
-
212
- def update_dimensions_on_upload(input_image):
213
- \"\"\"Compute recommended dimensions preserving aspect ratio.\"\"\"
214
- if input_image is None:
215
- return 1024, 1024
216
-
217
- original_width, original_height = input_image.size
218
- aspect_ratio = original_width / original_height
219
-
220
- if aspect_ratio > 1:
221
- # Landscape
222
- new_width = 1024
223
- new_height = round(1024 / aspect_ratio / 32) * 32
224
- else:
225
- # Portrait or square
226
- new_height = 1024
227
- new_width = round(1024 * aspect_ratio / 32) * 32
228
-
229
- # Ensure minimum size
230
- new_width = max(256, min(1024, new_width))
231
- new_height = max(256, min(1024, new_height))
232
-
233
- return new_width, new_height
234
 
 
235
  def show_uploaded_image_with_arrows(uploaded_image):
236
- \"\"\"Show uploaded image with working arrow controls.\"\"\"
237
  if uploaded_image is None:
238
- return gr.update(value=\"\"\"
239
- <div style=\"width: 100%; height: 500px; background: #f8f8f8; border: 2px solid #e0e0e0; border-radius: 12px;
240
- position: relative; display: flex; align-items: center; justify-content: center;\">
241
- <div style=\"text-align: center; color: #999;\">
242
- <div style=\"font-size: 48px; margin-bottom: 10px;\">📸</div>
243
  <p>Upload an image on the left to begin</p>
244
  <p>Then hover here to see camera controls</p>
245
  </div>
246
  </div>
247
- \"\"\")
248
 
249
  # Convert to data URL
250
  buffered = BytesIO()
251
- uploaded_image.save(buffered, format=\"PNG\")
252
  img_str = base64.b64encode(buffered.getvalue()).decode()
253
- data_url = f\"data:image/png;base64,{img_str}\"
254
 
255
- return gr.update(value=f\"\"\"
256
- <div style=\"width: 100%; height: 500px; background: #f8f8f8; border: 2px solid #e0e0e0; border-radius: 12px;
257
- position: relative; display: flex; align-items: center; justify-content: center;\"
258
- onmouseenter=\"this.querySelector('#arrow-controls').style.opacity='1'\"
259
- onmouseleave=\"this.querySelector('#arrow-controls').style.opacity='0'\">
260
 
261
  <!-- Image -->
262
- <img src=\"{data_url}\" style=\"max-width: 100%; max-height: 100%; object-fit: contain;\">
263
 
264
  <!-- Arrow controls -->
265
- <div id=\"arrow-controls\" style=\"position: absolute; inset: 0; opacity: 0; transition: opacity 0.3s ease; z-index: 10;\">
266
 
267
- <!-- Left Arrow (Azimuth -45°) -->
268
- <button onclick=\"
269
  var azInput = document.getElementById('js-azimuth').querySelector('input');
270
  var newAz = (parseInt(azInput.value) - 45 + 360) % 360;
271
  azInput.value = newAz;
272
  azInput.dispatchEvent(new Event('input', {{bubbles: true}}));
273
  document.getElementById('status-az').textContent = newAz;
274
- \"
275
- style=\"position: absolute; left: 20px; top: 50%; transform: translateY(-50%);
276
- width: 60px; height: 60px; background: rgba(0,255,136,0.95); border: none;
277
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
278
- box-shadow: 0 6px 20px rgba(0,0,0,0.4); transition: transform 0.2s;\"
279
- onmouseover=\"this.style.transform += ' scale(1.1)'\"
280
- onmouseout=\"this.style.transform = this.style.transform.replace(' scale(1.1)', '')\"
281
- title=\"Rotate Left (Azimuth -45°)\">
282
 
283
  </button>
284
 
285
- <!-- Right Arrow (Azimuth +45°) -->
286
- <button onclick=\"
287
  var azInput = document.getElementById('js-azimuth').querySelector('input');
288
  var newAz = (parseInt(azInput.value) + 45) % 360;
289
  azInput.value = newAz;
290
  azInput.dispatchEvent(new Event('input', {{bubbles: true}}));
291
  document.getElementById('status-az').textContent = newAz;
292
- \"
293
- style=\"position: absolute; right: 20px; top: 50%; transform: translateY(-50%);
294
- width: 60px; height: 60px; background: rgba(0,255,136,0.95); border: none;
295
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
296
- box-shadow: 0 6px 20px rgba(0,0,0,0.4); transition: transform 0.2s;\"
297
- onmouseover=\"this.style.transform += ' scale(1.1)'\"
298
- onmouseout=\"this.style.transform = this.style.transform.replace(' scale(1.1)', '')\"
299
- title=\"Rotate Right (Azimuth +45°)\">
300
 
301
  </button>
302
 
303
- <!-- Up Arrow (Elevation +30°) -->
304
- <button onclick=\"
305
  var elInput = document.getElementById('js-elevation').querySelector('input');
306
  var newEl = Math.min(60, parseInt(elInput.value) + 30);
307
  elInput.value = newEl;
308
  elInput.dispatchEvent(new Event('input', {{bubbles: true}}));
309
  document.getElementById('status-el').textContent = newEl;
310
- \"
311
- style=\"position: absolute; top: 20px; left: 50%; transform: translateX(-50%);
312
- width: 60px; height: 60px; background: rgba(255,105,180,0.95); border: none;
313
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
314
- box-shadow: 0 6px 20px rgba(0,0,0,0.4); transition: transform 0.2s;\"
315
- onmouseover=\"this.style.transform += ' scale(1.1)'\"
316
- onmouseout=\"this.style.transform = this.style.transform.replace(' scale(1.1)', '')\"
317
- title=\"Look Up (Elevation +30°)\">
318
 
319
  </button>
320
 
321
- <!-- Down Arrow (Elevation -30°) -->
322
- <button onclick=\"
323
  var elInput = document.getElementById('js-elevation').querySelector('input');
324
  var newEl = Math.max(-30, parseInt(elInput.value) - 30);
325
  elInput.value = newEl;
326
  elInput.dispatchEvent(new Event('input', {{bubbles: true}}));
327
  document.getElementById('status-el').textContent = newEl;
328
- \"
329
- style=\"position: absolute; bottom: 80px; left: 50%; transform: translateX(-50%);
330
- width: 60px; height: 60px; background: rgba(255,105,180,0.95); border: none;
331
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
332
- box-shadow: 0 6px 20px rgba(0,0,0,0.4); transition: transform 0.2s;\"
333
- onmouseover=\"this.style.transform += ' scale(1.1)'\"
334
- onmouseout=\"this.style.transform = this.style.transform.replace(' scale(1.1)', '')\"
335
- title=\"Look Down (Elevation -30°)\">
336
 
337
  </button>
338
 
339
  <!-- Zoom Controls -->
340
- <div style=\"position: absolute; bottom: 20px; left: 50%; transform: translateX(-50%);
341
- display: flex; gap: 15px;\">
342
 
343
- <!-- Zoom Out -->
344
- <button onclick=\"
345
  var distInput = document.getElementById('js-distance').querySelector('input');
346
  var newDist = Math.min(1.8, parseFloat(distInput.value) + 0.4);
347
  distInput.value = newDist.toFixed(1);
348
  distInput.dispatchEvent(new Event('input', {{bubbles: true}}));
349
  document.getElementById('status-dist').textContent = newDist.toFixed(1);
350
- \"
351
- style=\"width: 55px; height: 55px; background: rgba(255,165,0,0.95); border: none;
352
- border-radius: 50%; color: white; font-size: 28px; cursor: pointer;
353
- box-shadow: 0 6px 20px rgba(0,0,0,0.4); transition: transform 0.2s;\"
354
- onmouseover=\"this.style.transform = 'scale(1.1)'\"
355
- onmouseout=\"this.style.transform = ''\"
356
- title=\"Zoom Out (Distance +0.4)\">
357
 
358
  </button>
359
 
360
- <!-- Zoom In -->
361
- <button onclick=\"
362
  var distInput = document.getElementById('js-distance').querySelector('input');
363
  var newDist = Math.max(0.6, parseFloat(distInput.value) - 0.4);
364
  distInput.value = newDist.toFixed(1);
365
  distInput.dispatchEvent(new Event('input', {{bubbles: true}}));
366
  document.getElementById('status-dist').textContent = newDist.toFixed(1);
367
- \"
368
- style=\"width: 55px; height: 55px; background: rgba(255,165,0,0.95); border: none;
369
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
370
- box-shadow: 0 6px 20px rgba(0,0,0,0.4); transition: transform 0.2s;\"
371
- onmouseover=\"this.style.transform = 'scale(1.1)'\"
372
- onmouseout=\"this.style.transform = ''\"
373
- title=\"Zoom In (Distance -0.4)\">
374
  +
375
  </button>
376
  </div>
377
 
378
  <!-- Status Display -->
379
- <div style=\"position: absolute; top: 15px; right: 15px; background: rgba(0,0,0,0.9);
380
- color: white; padding: 12px 16px; border-radius: 10px; font-family: monospace;
381
- font-size: 14px; box-shadow: 0 6px 20px rgba(0,0,0,0.4); min-width: 200px;\">
382
- <div style=\"margin-bottom: 4px;\">Az: <span id=\"status-az\">0</span>° | El: <span id=\"status-el\">0</span>° | Dist: <span id=\"status-dist\">1.0</span></div>
383
- <div id=\"status-prompt\" style=\"color: #00ff88; font-size: 12px; line-height: 1.3;\">
384
- <sks> front view eye-level shot medium shot
385
- </div>
386
  </div>
387
  </div>
388
  </div>
389
- \"\"\")
390
 
391
  def handle_parameter_change(az, el, dist, input_image, seed_val, randomize_seed_val, guidance_val, steps_val, h_val, w_val):
392
- \"\"\"Handle camera parameter changes and generate new view.\"\"\"
393
  try:
394
  azimuth = float(az)
395
  elevation = float(el)
@@ -403,62 +358,48 @@ def create_camera_control_app():
403
  generated_image, final_seed, final_prompt = infer_camera_edit(
404
  image=input_image,
405
  azimuth=azimuth,
406
- elevation=elevation,
407
  distance=distance,
408
  seed=seed_val,
409
  randomize_seed=randomize_seed_val,
410
  guidance_scale=guidance_val,
411
  num_inference_steps=steps_val,
412
- height=h_val,
413
- width=w_val
414
  )
415
 
416
- # Show generated image with arrows
417
- html_result = show_uploaded_image_with_arrows(generated_image)
418
- return html_result.value, final_seed, final_prompt
419
 
420
- return gr.update(), seed_val, prompt
421
 
422
  except Exception as e:
423
- print(f\"Generation error: {e}\")
424
- import traceback
425
- traceback.print_exc()
426
- raise gr.Error(f\"Generation failed: {str(e)}\")
427
 
428
- # ===== EVENT HANDLERS INSIDE BLOCKS CONTEXT =====
429
-
430
- # Auto-update dimensions when image is uploaded
431
  image.upload(
432
  fn=update_dimensions_on_upload,
433
  inputs=[image],
434
  outputs=[width, height]
435
  )
436
 
437
- # Show uploaded image immediately
438
  image.upload(
439
  fn=show_uploaded_image_with_arrows,
440
  inputs=[image],
441
  outputs=[result_display]
442
  )
443
 
444
- # Auto-generation handler triggered by input changes
445
- def auto_generate_on_change(js_az, js_el, js_dist, input_image, seed_val, randomize_seed_val, guidance_val, steps_val, h_val, w_val):
446
- \"\"\"Auto-generate when camera parameters change from arrow clicks.\"\"\"
447
- if input_image is None:
448
- return gr.update(), seed_val, \"<sks> front view eye-level shot medium shot\"
449
-
450
- return handle_parameter_change(js_az, js_el, js_dist, input_image, seed_val, randomize_seed_val, guidance_val, steps_val, h_val, w_val)
451
-
452
- # Set up auto-generation on parameter changes
453
- for input_component in [js_azimuth, js_elevation, js_distance]:
454
- input_component.change(
455
- fn=auto_generate_on_change,
456
  inputs=[js_azimuth, js_elevation, js_distance, image, seed, randomize_seed, guidance_scale, num_inference_steps, height, width],
457
- outputs=[result_display, seed, prompt_display]
458
  )
459
 
460
  return demo
461
 
462
- if __name__ == \"__main__\":
463
  demo = create_camera_control_app()
464
  demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import numpy as np
3
  import random
4
+ import torch
 
 
5
  import base64
6
  from io import BytesIO
7
+ from PIL import Image
8
+ from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
9
+
10
+ MAX_SEED = np.iinfo(np.int32).max
11
 
12
+ # --- Model Loading ---
13
+ dtype = torch.bfloat16
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
 
16
+ # Initialize pipe as None - will be loaded when needed
17
  pipe = None
18
 
19
  def load_model():
20
+ """Load the model only when needed to avoid initialization errors."""
21
  global pipe
22
  if pipe is None:
23
+ pipe = QwenImageEditPlusPipeline.from_pretrained(
24
+ "Qwen/Qwen-Image-Edit-2511",
25
+ torch_dtype=dtype
26
  ).to(device)
27
+
28
+ # Load the lightning LoRA for fast inference
29
+ pipe.load_lora_weights(
30
+ "lightx2v/Qwen-Image-Edit-2511-Lightning",
31
+ weight_name="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors",
32
+ adapter_name="lightning"
33
+ )
34
+
35
+ # Load the multi-angles LoRA
36
+ pipe.load_lora_weights(
37
+ "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA",
38
+ weight_name="qwen-image-edit-2511-multiple-angles-lora.safetensors",
39
+ adapter_name="angles"
40
+ )
41
+
42
+ pipe.set_adapters(["lightning", "angles"], adapter_weights=[1.0, 1.0])
43
  return pipe
44
 
45
+ # --- Camera Parameter Mappings ---
46
+ AZIMUTH_MAP = {
47
+ 0: "front view",
48
  45: "front-right quarter view",
49
  90: "right side view",
50
+ 135: "back-right quarter view",
51
+ 180: "back view",
52
  225: "back-left quarter view",
53
  270: "left side view",
54
  315: "front-left quarter view"
55
  }
56
 
57
+ ELEVATION_MAP = {
58
+ -30: "low-angle shot",
59
  0: "eye-level shot",
60
+ 30: "elevated shot",
61
  60: "high-angle shot"
62
  }
63
 
64
+ DISTANCE_MAP = {
65
+ 0.6: "close-up",
66
+ 1.0: "medium shot",
67
  1.8: "wide shot"
68
  }
69
 
70
  def snap_to_nearest(value, steps):
71
+ """Snap value to nearest step."""
72
  return min(steps, key=lambda x: abs(x - value))
73
 
74
  def build_camera_prompt(azimuth, elevation, distance):
75
+ """Build camera prompt from parameters."""
76
  azimuth_steps = [0, 45, 90, 135, 180, 225, 270, 315]
77
  elevation_steps = [-30, 0, 30, 60]
78
  distance_steps = [0.6, 1.0, 1.8]
79
 
80
+ az_snap = snap_to_nearest(azimuth, azimuth_steps)
81
+ el_snap = snap_to_nearest(elevation, elevation_steps)
82
+ dist_snap = snap_to_nearest(distance, distance_steps)
83
 
84
+ az_name = AZIMUTH_MAP[az_snap]
85
+ el_name = ELEVATION_MAP[el_snap]
86
+ dist_name = DISTANCE_MAP[dist_snap]
87
 
88
+ return f"<sks> {az_name} {el_name} {dist_name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ def infer_camera_edit(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, num_inference_steps, height, width):
91
+ """Generate new camera view using the Qwen model."""
92
  if randomize_seed:
93
  seed = random.randint(0, MAX_SEED)
94
+
95
  generator = torch.Generator(device=device).manual_seed(seed)
96
+
97
+ # Build the camera prompt
98
+ prompt = build_camera_prompt(azimuth, elevation, distance)
99
+
100
+ # Load model if not already loaded
101
+ model = load_model()
102
+
103
+ # Generate the new view
104
+ result = model(
105
+ image=image,
 
106
  prompt=prompt,
107
+ height=height,
108
+ width=width,
 
 
109
  guidance_scale=guidance_scale,
110
+ num_inference_steps=num_inference_steps,
111
+ generator=generator
112
  ).images[0]
113
+
114
  return result, seed, prompt
115
 
116
  def create_camera_control_app():
117
+ """Create the complete camera control app."""
118
 
119
+ with gr.Blocks(title="Camera Control with Directional Arrows") as demo:
120
  gr.Markdown("# 📸 Camera Control with Directional Arrows")
121
+ gr.Markdown("Upload an image and use arrows to control camera angles for 3D view generation")
122
 
123
  with gr.Row():
124
+ # Left column: Image upload and controls
125
  with gr.Column(scale=1):
126
  image = gr.Image(label="Upload Image", type="pil", height=400)
127
 
128
+ # Camera parameter inputs (visible for debugging)
129
  js_azimuth = gr.Textbox("0", visible=True, elem_id="js-azimuth", label="Azimuth")
130
  js_elevation = gr.Textbox("0", visible=True, elem_id="js-elevation", label="Elevation")
131
  js_distance = gr.Textbox("1.0", visible=True, elem_id="js-distance", label="Distance")
132
 
133
+ # Generation settings
134
+ with gr.Accordion("⚙️ Generation Settings", open=False):
135
+ seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed")
136
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
137
+ guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.1, value=7.5, label="Guidance scale")
138
+ num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=4, label="Number of inference steps")
139
+
140
+ def update_dimensions_on_upload(input_image):
141
+ if input_image is None:
142
+ return 1024, 1024
143
+
144
+ original_width, original_height = input_image.size
145
+ aspect_ratio = original_width / original_height
146
+
147
+ if aspect_ratio > 1:
148
+ # Landscape
149
+ new_width = 1024
150
+ new_height = round(1024 / aspect_ratio / 32) * 32
151
+ else:
152
+ # Portrait or square
153
+ new_height = 1024
154
+ new_width = round(1024 * aspect_ratio / 32) * 32
155
+
156
+ # Ensure minimum size
157
+ new_width = max(256, min(1024, new_width))
158
+ new_height = max(256, min(1024, new_height))
159
+
160
+ return new_width, new_height
161
+
162
+ height = gr.Slider(minimum=256, maximum=1024, step=32, value=1024, label="Height")
163
+ width = gr.Slider(minimum=256, maximum=1024, step=32, value=1024, label="Width")
164
+
165
  prompt_display = gr.Textbox(
166
+ label="Current Camera Prompt",
167
  value="<sks> front view eye-level shot medium shot",
168
  interactive=False
169
  )
170
 
171
+ # Right column: Interactive image view
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Column(scale=1):
173
  gr.Markdown("### 🎯 Interactive Image View")
174
+ gr.Markdown("*Upload an image, then hover to see controls and click arrows to generate new views*")
175
 
176
  # Interactive HTML component using working pattern
177
  result_display = gr.HTML(
178
+ value="""
179
+ <div style="width: 100%; height: 500px; background: #f8f8f8; border: 2px solid #e0e0e0; border-radius: 12px;
180
+ position: relative; display: flex; align-items: center; justify-content: center;">
181
+ <div style="text-align: center; color: #999;">
182
+ <div style="font-size: 48px; margin-bottom: 10px;">📸</div>
183
  <p>Upload an image on the left to begin</p>
184
  <p>Then hover here to see camera controls</p>
185
  </div>
186
  </div>
187
+ """,
188
  elem_id="result-display"
189
  )
190
 
191
+ # Debug output
192
+ debug_output = gr.Textbox(label="Debug Output", visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
+ # Functions for handling interactions (inside Blocks context)
195
  def show_uploaded_image_with_arrows(uploaded_image):
196
+ """Show uploaded image with working arrow controls."""
197
  if uploaded_image is None:
198
+ return gr.update(value="""
199
+ <div style="width: 100%; height: 500px; background: #f8f8f8; border: 2px solid #e0e0e0; border-radius: 12px;
200
+ position: relative; display: flex; align-items: center; justify-content: center;">
201
+ <div style="text-align: center; color: #999;">
202
+ <div style="font-size: 48px; margin-bottom: 10px;">📸</div>
203
  <p>Upload an image on the left to begin</p>
204
  <p>Then hover here to see camera controls</p>
205
  </div>
206
  </div>
207
+ """)
208
 
209
  # Convert to data URL
210
  buffered = BytesIO()
211
+ uploaded_image.save(buffered, format="PNG")
212
  img_str = base64.b64encode(buffered.getvalue()).decode()
213
+ data_url = f"data:image/png;base64,{img_str}"
214
 
215
+ return gr.update(value=f"""
216
+ <div style="width: 100%; height: 500px; background: #f8f8f8; border: 2px solid #e0e0e0; border-radius: 12px;
217
+ position: relative; display: flex; align-items: center; justify-content: center;"
218
+ onmouseenter="this.querySelector('#arrow-controls').style.opacity='1'"
219
+ onmouseleave="this.querySelector('#arrow-controls').style.opacity='0'">
220
 
221
  <!-- Image -->
222
+ <img src="{data_url}" style="max-width: 100%; max-height: 100%; object-fit: contain;">
223
 
224
  <!-- Arrow controls -->
225
+ <div id="arrow-controls" style="position: absolute; inset: 0; opacity: 0; transition: opacity 0.3s ease; z-index: 10;">
226
 
227
+ <!-- Left Arrow -->
228
+ <button onclick="
229
  var azInput = document.getElementById('js-azimuth').querySelector('input');
230
  var newAz = (parseInt(azInput.value) - 45 + 360) % 360;
231
  azInput.value = newAz;
232
  azInput.dispatchEvent(new Event('input', {{bubbles: true}}));
233
  document.getElementById('status-az').textContent = newAz;
234
+ "
235
+ style="position: absolute; left: 20px; top: 50%; transform: translateY(-50%);
236
+ width: 60px; height: 60px; background: rgba(0,255,136,0.9); border: none;
237
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
238
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: transform 0.2s;"
239
+ onmouseover="this.style.transform += ' scale(1.1)'"
240
+ onmouseout="this.style.transform = this.style.transform.replace(' scale(1.1)', '')"
241
+ title="Rotate Left">
242
 
243
  </button>
244
 
245
+ <!-- Right Arrow -->
246
+ <button onclick="
247
  var azInput = document.getElementById('js-azimuth').querySelector('input');
248
  var newAz = (parseInt(azInput.value) + 45) % 360;
249
  azInput.value = newAz;
250
  azInput.dispatchEvent(new Event('input', {{bubbles: true}}));
251
  document.getElementById('status-az').textContent = newAz;
252
+ "
253
+ style="position: absolute; right: 20px; top: 50%; transform: translateY(-50%);
254
+ width: 60px; height: 60px; background: rgba(0,255,136,0.9); border: none;
255
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
256
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: transform 0.2s;"
257
+ onmouseover="this.style.transform += ' scale(1.1)'"
258
+ onmouseout="this.style.transform = this.style.transform.replace(' scale(1.1)', '')"
259
+ title="Rotate Right">
260
 
261
  </button>
262
 
263
+ <!-- Up Arrow -->
264
+ <button onclick="
265
  var elInput = document.getElementById('js-elevation').querySelector('input');
266
  var newEl = Math.min(60, parseInt(elInput.value) + 30);
267
  elInput.value = newEl;
268
  elInput.dispatchEvent(new Event('input', {{bubbles: true}}));
269
  document.getElementById('status-el').textContent = newEl;
270
+ "
271
+ style="position: absolute; top: 20px; left: 50%; transform: translateX(-50%);
272
+ width: 60px; height: 60px; background: rgba(255,105,180,0.9); border: none;
273
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
274
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: transform 0.2s;"
275
+ onmouseover="this.style.transform += ' scale(1.1)'"
276
+ onmouseout="this.style.transform = this.style.transform.replace(' scale(1.1)', '')"
277
+ title="Look Up">
278
 
279
  </button>
280
 
281
+ <!-- Down Arrow -->
282
+ <button onclick="
283
  var elInput = document.getElementById('js-elevation').querySelector('input');
284
  var newEl = Math.max(-30, parseInt(elInput.value) - 30);
285
  elInput.value = newEl;
286
  elInput.dispatchEvent(new Event('input', {{bubbles: true}}));
287
  document.getElementById('status-el').textContent = newEl;
288
+ "
289
+ style="position: absolute; bottom: 80px; left: 50%; transform: translateX(-50%);
290
+ width: 60px; height: 60px; background: rgba(255,105,180,0.9); border: none;
291
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
292
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: transform 0.2s;"
293
+ onmouseover="this.style.transform += ' scale(1.1)'"
294
+ onmouseout="this.style.transform = this.style.transform.replace(' scale(1.1)', '')"
295
+ title="Look Down">
296
 
297
  </button>
298
 
299
  <!-- Zoom Controls -->
300
+ <div style="position: absolute; bottom: 20px; left: 50%; transform: translateX(-50%);
301
+ display: flex; gap: 15px;">
302
 
303
+ <button onclick="
 
304
  var distInput = document.getElementById('js-distance').querySelector('input');
305
  var newDist = Math.min(1.8, parseFloat(distInput.value) + 0.4);
306
  distInput.value = newDist.toFixed(1);
307
  distInput.dispatchEvent(new Event('input', {{bubbles: true}}));
308
  document.getElementById('status-dist').textContent = newDist.toFixed(1);
309
+ "
310
+ style="width: 55px; height: 55px; background: rgba(255,165,0,0.9); border: none;
311
+ border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
312
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: transform 0.2s;"
313
+ onmouseover="this.style.transform = 'scale(1.1)'"
314
+ onmouseout="this.style.transform = ''"
315
+ title="Zoom Out">
316
 
317
  </button>
318
 
319
+ <button onclick="
 
320
  var distInput = document.getElementById('js-distance').querySelector('input');
321
  var newDist = Math.max(0.6, parseFloat(distInput.value) - 0.4);
322
  distInput.value = newDist.toFixed(1);
323
  distInput.dispatchEvent(new Event('input', {{bubbles: true}}));
324
  document.getElementById('status-dist').textContent = newDist.toFixed(1);
325
+ "
326
+ style="width: 55px; height: 55px; background: rgba(255,165,0,0.9); border: none;
327
  border-radius: 50%; color: white; font-size: 24px; cursor: pointer;
328
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: transform 0.2s;"
329
+ onmouseover="this.style.transform = 'scale(1.1)'"
330
+ onmouseout="this.style.transform = ''"
331
+ title="Zoom In">
332
  +
333
  </button>
334
  </div>
335
 
336
  <!-- Status Display -->
337
+ <div style="position: absolute; top: 15px; right: 15px; background: rgba(0,0,0,0.85);
338
+ color: white; padding: 10px 14px; border-radius: 8px; font-family: monospace;
339
+ font-size: 13px; box-shadow: 0 4px 12px rgba(0,0,0,0.4);">
340
+ <div>Az: <span id="status-az">0</span>° | El: <span id="status-el">0</span>° | Dist: <span id="status-dist">1.0</span></div>
 
 
 
341
  </div>
342
  </div>
343
  </div>
344
+ """)
345
 
346
  def handle_parameter_change(az, el, dist, input_image, seed_val, randomize_seed_val, guidance_val, steps_val, h_val, w_val):
347
+ """Handle camera parameter changes and generate new view."""
348
  try:
349
  azimuth = float(az)
350
  elevation = float(el)
 
358
  generated_image, final_seed, final_prompt = infer_camera_edit(
359
  image=input_image,
360
  azimuth=azimuth,
361
+ elevation=elevation,
362
  distance=distance,
363
  seed=seed_val,
364
  randomize_seed=randomize_seed_val,
365
  guidance_scale=guidance_val,
366
  num_inference_steps=steps_val,
367
+ height=int(h_val),
368
+ width=int(w_val)
369
  )
370
 
371
+ # Update the HTML display with the generated image
372
+ return show_uploaded_image_with_arrows(generated_image).value, prompt, f"Generated view: Az={azimuth}°, El={elevation}°, Dist={distance}, Seed={final_seed}"
 
373
 
374
+ return gr.update(), prompt, f"Parameters updated: Az={azimuth}°, El={elevation}°, Dist={distance}"
375
 
376
  except Exception as e:
377
+ return gr.update(), f"Error: {str(e)}", f"Error processing parameters: {str(e)}"
 
 
 
378
 
379
+ # Update dimensions when image is uploaded
 
 
380
  image.upload(
381
  fn=update_dimensions_on_upload,
382
  inputs=[image],
383
  outputs=[width, height]
384
  )
385
 
386
+ # Image upload handler
387
  image.upload(
388
  fn=show_uploaded_image_with_arrows,
389
  inputs=[image],
390
  outputs=[result_display]
391
  )
392
 
393
+ # Parameter change handlers (triggered by arrow clicks)
394
+ for param_input in [js_azimuth, js_elevation, js_distance]:
395
+ param_input.change(
396
+ fn=handle_parameter_change,
 
 
 
 
 
 
 
 
397
  inputs=[js_azimuth, js_elevation, js_distance, image, seed, randomize_seed, guidance_scale, num_inference_steps, height, width],
398
+ outputs=[result_display, prompt_display, debug_output]
399
  )
400
 
401
  return demo
402
 
403
+ if __name__ == "__main__":
404
  demo = create_camera_control_app()
405
  demo.launch()