gagndeep committed on
Commit
e719a67
·
1 Parent(s): 7e390f4
Files changed (2) hide show
  1. app.py +18 -37
  2. model_utils.py +20 -240
app.py CHANGED
@@ -1,8 +1,8 @@
1
  """
2
  SHARP Gradio Demo
3
- - Standard Native Layout (Clean Two-Column)
4
- - Logic: Matches original Apple implementation (Robust Enum & Resolution handling)
5
- - System: ZeroGPU compatible
6
  """
7
 
8
  from __future__ import annotations
@@ -17,6 +17,7 @@ import gradio as gr
17
  try:
18
  import spaces
19
  except ImportError:
 
20
  class spaces:
21
  @staticmethod
22
  def GPU(func):
@@ -83,10 +84,7 @@ def get_example_files() -> list[list[str]]:
83
  examples.append([str(img)])
84
  return examples
85
 
86
- # -----------------------------------------------------------------------------
87
- # Main Inference Logic
88
- # -----------------------------------------------------------------------------
89
-
90
  @spaces.GPU(duration=120)
91
  def run_sharp(
92
  image_path: str | None,
@@ -104,26 +102,20 @@ def run_sharp(
104
  if not image_path:
105
  raise gr.Error("Please upload an image first.")
106
 
107
- # 1. Logic: Robust Enum Conversion
108
- # The model likely expects the Enum object, not the string.
109
- try:
110
- # Try exact match (e.g. "swipe" -> TrajectoryType.swipe)
111
- traj_enum = TrajectoryType[trajectory_type]
112
- except KeyError:
113
- try:
114
- # Try upper case (e.g. "swipe" -> TrajectoryType.SWIPE)
115
- traj_enum = TrajectoryType[trajectory_type.upper()]
116
- except KeyError:
117
- # Fallback: pass the string itself
118
- traj_enum = trajectory_type
119
-
120
- # 2. Logic: Handle Resolution
121
  out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
 
 
 
 
 
 
 
122
 
123
  try:
124
  progress(0.1, desc="Initializing SHARP model on GPU...")
125
 
126
- # 3. Call Backend
127
  video_path, ply_path = predict_and_maybe_render_gpu(
128
  image_path,
129
  trajectory_type=traj_enum,
@@ -133,14 +125,12 @@ def run_sharp(
133
  render_video=bool(render_video),
134
  )
135
 
136
- # 4. Prepare Outputs
137
  status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
138
 
139
  video_result = str(video_path) if video_path else None
140
  if video_path:
141
  status_msg += f"\nVideo: `{video_path.name}`"
142
- else:
143
- status_msg += "\n(Video rendering disabled or failed)"
144
 
145
  # Explicitly update the Download Button
146
  download_btn_update = gr.DownloadButton(
@@ -172,7 +162,6 @@ def build_demo() -> gr.Blocks:
172
 
173
  with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
174
 
175
- # --- Header ---
176
  with gr.Row():
177
  with gr.Column(scale=1):
178
  gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
@@ -189,7 +178,7 @@ def build_demo() -> gr.Blocks:
189
  interactive=True
190
  )
191
 
192
- # Configs (Updated with Full Options from Original File)
193
  with gr.Group():
194
  with gr.Row():
195
  trajectory = gr.Dropdown(
@@ -200,15 +189,7 @@ def build_demo() -> gr.Blocks:
200
  )
201
  output_res = gr.Dropdown(
202
  label="Output Resolution",
203
- # Full list from the original logic
204
- choices=[
205
- ("Match input", 0),
206
- ("512", 512),
207
- ("768", 768),
208
- ("1024", 1024),
209
- ("1280", 1280),
210
- ("1536", 1536),
211
- ],
212
  value=0,
213
  scale=1
214
  )
@@ -242,7 +223,7 @@ def build_demo() -> gr.Blocks:
242
 
243
  with gr.Group():
244
  status_md = gr.Markdown("Ready to generate.")
245
- # Button starts hidden, becomes visible on success
246
  ply_download = gr.DownloadButton(
247
  label="Download .PLY File",
248
  variant="secondary",
 
1
  """
2
  SHARP Gradio Demo
3
+ - Standard Native Layout
4
+ - Fixed: Added @spaces.GPU for ZeroGPU compatibility (Fixes 'dummy' output)
5
+ - Fixed: Download Button visibility logic
6
  """
7
 
8
  from __future__ import annotations
 
17
  try:
18
  import spaces
19
  except ImportError:
20
+ # Fallback for local testing if spaces is not installed
21
  class spaces:
22
  @staticmethod
23
  def GPU(func):
 
84
  examples.append([str(img)])
85
  return examples
86
 
87
+ # --- 2. Apply @spaces.GPU Decorator ---
 
 
 
88
  @spaces.GPU(duration=120)
89
  def run_sharp(
90
  image_path: str | None,
 
102
  if not image_path:
103
  raise gr.Error("Please upload an image first.")
104
 
105
+ # Validate inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
107
+
108
+ # Convert trajectory string to Enum safely
109
+ traj_key = trajectory_type.upper()
110
+ if hasattr(TrajectoryType, traj_key):
111
+ traj_enum = TrajectoryType[traj_key]
112
+ else:
113
+ traj_enum = trajectory_type
114
 
115
  try:
116
  progress(0.1, desc="Initializing SHARP model on GPU...")
117
 
118
+ # Call the backend model
119
  video_path, ply_path = predict_and_maybe_render_gpu(
120
  image_path,
121
  trajectory_type=traj_enum,
 
125
  render_video=bool(render_video),
126
  )
127
 
128
+ # Prepare outputs
129
  status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
130
 
131
  video_result = str(video_path) if video_path else None
132
  if video_path:
133
  status_msg += f"\nVideo: `{video_path.name}`"
 
 
134
 
135
  # Explicitly update the Download Button
136
  download_btn_update = gr.DownloadButton(
 
162
 
163
  with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
164
 
 
165
  with gr.Row():
166
  with gr.Column(scale=1):
167
  gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
 
178
  interactive=True
179
  )
180
 
181
+ # Configs
182
  with gr.Group():
183
  with gr.Row():
184
  trajectory = gr.Dropdown(
 
189
  )
190
  output_res = gr.Dropdown(
191
  label="Output Resolution",
192
+ choices=[("Original", 0), ("512px", 512), ("1024px", 1024)],
 
 
 
 
 
 
 
 
193
  value=0,
194
  scale=1
195
  )
 
223
 
224
  with gr.Group():
225
  status_md = gr.Markdown("Ready to generate.")
226
+ # Button starts hidden
227
  ply_download = gr.DownloadButton(
228
  label="Download .PLY File",
229
  variant="secondary",
model_utils.py CHANGED
@@ -4,7 +4,6 @@ Design goals:
4
  - Reuse SHARP's own predict/render pipeline (no subprocess calls).
5
  - Be robust on Hugging Face Spaces + ZeroGPU.
6
  - Cache model weights and predictor construction across requests.
7
- - Support extended camera trajectories beyond the defaults.
8
 
9
  Public API (used by the Gradio app):
10
  - TrajectoryType
@@ -13,7 +12,6 @@ Public API (used by the Gradio app):
13
 
14
  from __future__ import annotations
15
 
16
- import math
17
  import os
18
  import threading
19
  import time
@@ -44,24 +42,7 @@ from sharp.utils import camera, io
44
  from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
45
  from sharp.utils.gsplat import GSplatRenderer
46
 
47
- # Extended list of supported trajectories (22 types)
48
- TrajectoryType = Literal[
49
- # Standard SHARP defaults
50
- "swipe", "shake", "rotate", "rotate_forward",
51
- # Extended Rotations
52
- "rotate_reverse", "rotate_up", "rotate_down",
53
- # Zooms & Dollies
54
- "zoom_in", "zoom_out", "dolly_in", "dolly_out",
55
- # Pans (planar movement)
56
- "pan_left", "pan_right", "pan_up", "pan_down",
57
- # Complex Paths
58
- "spiral_in", "spiral_out", "figure_eight", "loop", "heart",
59
- "bounce", "ken_burns"
60
- ]
61
-
62
- STANDARD_TRAJECTORIES: Final[set[str]] = {
63
- "swipe", "shake", "rotate", "rotate_forward"
64
- }
65
 
66
  # -----------------------------------------------------------------------------
67
  # Helpers
@@ -101,189 +82,6 @@ def _select_device(preference: str = "auto") -> torch.device:
101
  return torch.device("cpu")
102
 
103
 
104
- # -----------------------------------------------------------------------------
105
- # Custom Trajectory Generation
106
- # -----------------------------------------------------------------------------
107
-
108
- def _generate_custom_trajectory(
109
- gaussians: torch.Tensor,
110
- resolution: tuple[int, int],
111
- focal_length: float,
112
- traj_type: str,
113
- num_frames: int
114
- ) -> list[torch.Tensor]:
115
- """
116
- Generates a list of camera eye positions (tensors) for custom paths.
117
- Uses the standard 'rotate' path to establish a baseline radius/elevation.
118
- """
119
- # 1. Get baseline Radius (R) and Elevation (Y) from the standard generator
120
- # We generate just 1 step of the standard 'rotate' to see where SHARP puts the camera.
121
- base_params = camera.TrajectoryParams(type="rotate", num_steps=1)
122
- base_traj = camera.create_eye_trajectory(
123
- gaussians, base_params, resolution_px=resolution, f_px=focal_length
124
- )
125
- start_pos = list(base_traj)[0].cpu() # [3] tensor (x, y, z)
126
-
127
- # Calculate spherical coordinates from start_pos
128
- # Assuming LookAt(0,0,0), radius is norm, elevation is y.
129
- radius = float(torch.norm(start_pos))
130
- base_y = float(start_pos[1])
131
-
132
- # Starting azimuth (theta). Usually start_pos is roughly [0, 0, radius] or [radius, 0, 0]
133
- # We'll compute it to be safe.
134
- base_theta = math.atan2(start_pos[2], start_pos[0])
135
-
136
- positions = []
137
-
138
- # Time steps 0..1
139
- t_vals = [i / (num_frames - 1) for i in range(num_frames)]
140
-
141
- for t in t_vals:
142
- x, y, z = 0.0, 0.0, 0.0
143
-
144
- # --- Logic for 20+ movements ---
145
-
146
- if traj_type == "rotate_reverse":
147
- # Orbit opposite direction
148
- theta = base_theta - (2 * math.pi * t)
149
- x = radius * math.cos(theta)
150
- z = radius * math.sin(theta)
151
- y = base_y
152
-
153
- elif traj_type == "rotate_up":
154
- # Orbit over the top (vertical orbit)
155
- phi = (math.pi / 4) * math.sin(2 * math.pi * t)
156
- # Modulate Y significantly
157
- theta = base_theta + (0.5 * math.pi * t) # Slow rotate
158
- curr_r = radius
159
- x = curr_r * math.cos(theta)
160
- z = curr_r * math.sin(theta)
161
- y = base_y + (radius * 0.8 * math.sin(math.pi * t)) # Arc up
162
-
163
- elif traj_type == "rotate_down":
164
- theta = base_theta + (0.5 * math.pi * t)
165
- y = base_y - (radius * 0.5 * math.sin(math.pi * t))
166
- x = radius * math.cos(theta)
167
- z = radius * math.sin(theta)
168
-
169
- elif traj_type in ["zoom_in", "dolly_in"]:
170
- # Move from Radius to Radius*0.4
171
- cur_r = radius * (1.0 - 0.6 * t)
172
- x = cur_r * math.cos(base_theta)
173
- z = cur_r * math.sin(base_theta)
174
- y = base_y
175
-
176
- elif traj_type in ["zoom_out", "dolly_out"]:
177
- # Move from Radius*0.5 to Radius*1.2
178
- cur_r = (radius * 0.5) + (radius * 0.7 * t)
179
- x = cur_r * math.cos(base_theta)
180
- z = cur_r * math.sin(base_theta)
181
- y = base_y
182
-
183
- elif traj_type == "pan_left":
184
- # Linear slide perpendicular to view vector
185
- # Approx: move X relative to view
186
- offset = (t - 0.5) * 2.0 * (radius * 0.5)
187
- x = start_pos[0] + offset
188
- y = start_pos[1]
189
- z = start_pos[2] # Simple approximation
190
-
191
- elif traj_type == "pan_right":
192
- offset = (0.5 - t) * 2.0 * (radius * 0.5)
193
- x = start_pos[0] + offset
194
- y = start_pos[1]
195
- z = start_pos[2]
196
-
197
- elif traj_type == "pan_up":
198
- offset = (t - 0.5) * (radius * 0.8)
199
- x = start_pos[0]
200
- y = base_y - offset # In 3D, Y usually up, but check coord sys. usually Y is down in some CV.
201
- # Assuming Y is Up for scene.
202
- y = base_y + offset
203
- z = start_pos[2]
204
-
205
- elif traj_type == "pan_down":
206
- offset = (t - 0.5) * (radius * 0.8)
207
- x = start_pos[0]
208
- y = base_y - offset
209
- z = start_pos[2]
210
-
211
- elif traj_type == "spiral_in":
212
- # Rotate while getting closer
213
- theta = base_theta + (2 * math.pi * t)
214
- cur_r = radius * (1.0 - 0.6 * t)
215
- x = cur_r * math.cos(theta)
216
- z = cur_r * math.sin(theta)
217
- y = base_y + (0.2 * radius * math.sin(4 * math.pi * t)) # Slight wobble
218
-
219
- elif traj_type == "spiral_out":
220
- theta = base_theta + (2 * math.pi * t)
221
- cur_r = (radius * 0.4) + (radius * 0.8 * t)
222
- x = cur_r * math.cos(theta)
223
- z = cur_r * math.sin(theta)
224
- y = base_y
225
-
226
- elif traj_type == "figure_eight":
227
- # Lemniscate on sphere surface
228
- scale = 2 * math.pi * t
229
- # Lissajous-ish
230
- theta = base_theta + (0.5 * math.sin(scale))
231
- phi_offset = 0.3 * math.sin(2 * scale)
232
- y = base_y + (radius * phi_offset)
233
- x = radius * math.cos(theta)
234
- z = radius * math.sin(theta)
235
-
236
- elif traj_type == "loop":
237
- # Vertical circle
238
- angle = 2 * math.pi * t
239
- y_off = 0.5 * radius * math.sin(angle)
240
- x_off = 0.2 * radius * math.cos(angle)
241
- x = start_pos[0] + x_off
242
- y = base_y + y_off
243
- z = start_pos[2]
244
-
245
- elif traj_type == "heart":
246
- # Heart shape in XY plane projection
247
- angle = 2 * math.pi * t
248
- # Heart formula
249
- h_x = 16 * math.sin(angle)**3
250
- h_y = 13 * math.cos(angle) - 5*math.cos(2*angle) - 2*math.cos(3*angle) - math.cos(4*angle)
251
- # Scale down
252
- scale = radius * 0.02
253
- x = start_pos[0] + (h_x * scale)
254
- y = base_y + (h_y * scale)
255
- z = start_pos[2]
256
-
257
- elif traj_type == "bounce":
258
- # Decay bounce
259
- freq = 3 * math.pi
260
- amp = abs(math.cos(freq * t)) * (1-t)
261
- y = base_y + (radius * 0.5 * amp)
262
- x = start_pos[0]
263
- z = start_pos[2]
264
-
265
- elif traj_type == "ken_burns":
266
- # Pan diagonal + slow zoom
267
- zoom_fac = 1.0 - (0.3 * t) # Zoom in 30%
268
- pan_x = (t - 0.5) * (radius * 0.3)
269
- pan_y = (t - 0.5) * (radius * 0.2)
270
-
271
- cur_r = radius * zoom_fac
272
- x = (cur_r * math.cos(base_theta)) + pan_x
273
- y = base_y + pan_y
274
- z = (cur_r * math.sin(base_theta))
275
-
276
- else:
277
- # Fallback for anything else (or minor variations)
278
- return list(base_traj) # Should be caught by caller, but safe fallback
279
-
280
- # Construct tensor
281
- pos_tensor = torch.tensor([x, y, z], dtype=torch.float32, device=gaussians.device)
282
- positions.append(pos_tensor)
283
-
284
- return positions
285
-
286
-
287
  # -----------------------------------------------------------------------------
288
  # Prediction outputs
289
  # -----------------------------------------------------------------------------
@@ -608,13 +406,10 @@ class ModelWrapper:
608
  if fps < 1:
609
  raise ValueError("fps must be >= 1")
610
 
611
- # FAST PATH: Standard SHARP trajectories + Default Resolution
612
- # We only use the optimized CLI shortcut if it's a standard type AND default res.
613
- is_standard_traj = trajectory_type in STANDARD_TRAJECTORIES
614
-
615
- if output_long_side is None and int(fps) == 30 and is_standard_traj:
616
  params = camera.TrajectoryParams(
617
- type=trajectory_type, # type: ignore
618
  num_steps=int(num_frames),
619
  num_repeats=1,
620
  )
@@ -633,7 +428,7 @@ class ModelWrapper:
633
  pass
634
  return output_path
635
 
636
- # CUSTOM PATH: Manual loop (Handles Custom Res, FPS, or Custom Trajectories)
637
  src_w, src_h = metadata.resolution_px
638
  src_f = float(metadata.focal_length_px)
639
 
@@ -646,37 +441,15 @@ class ModelWrapper:
646
  out_h = _make_even(max(2, int(round(src_h * scale))))
647
  out_f = src_f * scale
648
 
 
 
 
 
 
 
649
  device = torch.device("cuda")
650
  gaussians_cuda = gaussians.to(device)
651
 
652
- # 1. Generate Camera Trajectory
653
- if is_standard_traj:
654
- # Use SHARP's built-in generator
655
- traj_params = camera.TrajectoryParams(
656
- type=trajectory_type, # type: ignore
657
- num_steps=int(num_frames),
658
- num_repeats=1,
659
- )
660
- trajectory = camera.create_eye_trajectory(
661
- gaussians_cuda,
662
- traj_params,
663
- resolution_px=(out_w, out_h),
664
- f_px=out_f,
665
- )
666
- lookat_mode = traj_params.lookat_mode
667
- else:
668
- # Use our custom generator
669
- trajectory = _generate_custom_trajectory(
670
- gaussians_cuda,
671
- resolution=(out_w, out_h),
672
- focal_length=out_f,
673
- traj_type=trajectory_type,
674
- num_frames=num_frames
675
- )
676
- # Custom trajectories always look at origin (0,0,0) for now
677
- lookat_mode = "scene" # Assuming SHARP 'scene' mode implies look-at-center
678
-
679
- # 2. Setup Camera Model
680
  intrinsics = torch.tensor(
681
  [
682
  [out_f, 0.0, (out_w - 1) / 2.0, 0.0],
@@ -692,7 +465,14 @@ class ModelWrapper:
692
  gaussians_cuda,
693
  intrinsics,
694
  resolution_px=(out_w, out_h),
695
- lookat_mode=lookat_mode,
 
 
 
 
 
 
 
696
  )
697
 
698
  renderer = GSplatRenderer(color_space=metadata.color_space)
@@ -829,4 +609,4 @@ def predict_and_maybe_render(
829
  if spaces is not None:
830
  predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
831
  else: # pragma: no cover
832
- predict_and_maybe_render_gpu = predict_and_maybe_render
 
4
  - Reuse SHARP's own predict/render pipeline (no subprocess calls).
5
  - Be robust on Hugging Face Spaces + ZeroGPU.
6
  - Cache model weights and predictor construction across requests.
 
7
 
8
  Public API (used by the Gradio app):
9
  - TrajectoryType
 
12
 
13
  from __future__ import annotations
14
 
 
15
  import os
16
  import threading
17
  import time
 
42
  from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
43
  from sharp.utils.gsplat import GSplatRenderer
44
 
45
+ TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # -----------------------------------------------------------------------------
48
  # Helpers
 
82
  return torch.device("cpu")
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # -----------------------------------------------------------------------------
86
  # Prediction outputs
87
  # -----------------------------------------------------------------------------
 
406
  if fps < 1:
407
  raise ValueError("fps must be >= 1")
408
 
409
+ # Keep aligned with upstream CLI pipeline where possible.
410
+ if output_long_side is None and int(fps) == 30:
 
 
 
411
  params = camera.TrajectoryParams(
412
+ type=trajectory_type,
413
  num_steps=int(num_frames),
414
  num_repeats=1,
415
  )
 
428
  pass
429
  return output_path
430
 
431
+ # Adapted pipeline for custom output resolution / FPS.
432
  src_w, src_h = metadata.resolution_px
433
  src_f = float(metadata.focal_length_px)
434
 
 
441
  out_h = _make_even(max(2, int(round(src_h * scale))))
442
  out_f = src_f * scale
443
 
444
+ traj_params = camera.TrajectoryParams(
445
+ type=trajectory_type,
446
+ num_steps=int(num_frames),
447
+ num_repeats=1,
448
+ )
449
+
450
  device = torch.device("cuda")
451
  gaussians_cuda = gaussians.to(device)
452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  intrinsics = torch.tensor(
454
  [
455
  [out_f, 0.0, (out_w - 1) / 2.0, 0.0],
 
465
  gaussians_cuda,
466
  intrinsics,
467
  resolution_px=(out_w, out_h),
468
+ lookat_mode=traj_params.lookat_mode,
469
+ )
470
+
471
+ trajectory = camera.create_eye_trajectory(
472
+ gaussians_cuda,
473
+ traj_params,
474
+ resolution_px=(out_w, out_h),
475
+ f_px=out_f,
476
  )
477
 
478
  renderer = GSplatRenderer(color_space=metadata.color_space)
 
609
  if spaces is not None:
610
  predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
611
  else: # pragma: no cover
612
+ predict_and_maybe_render_gpu = predict_and_maybe_render