gagndeep committed on
Commit
7e390f4
·
1 Parent(s): 671c57a
Files changed (2) hide show
  1. app.py +37 -18
  2. model_utils.py +240 -20
app.py CHANGED
@@ -1,8 +1,8 @@
1
  """
2
  SHARP Gradio Demo
3
- - Standard Native Layout
4
- - Fixed: Added @spaces.GPU for ZeroGPU compatibility (Fixes 'dummy' output)
5
- - Fixed: Download Button visibility logic
6
  """
7
 
8
  from __future__ import annotations
@@ -17,7 +17,6 @@ import gradio as gr
17
  try:
18
  import spaces
19
  except ImportError:
20
- # Fallback for local testing if spaces is not installed
21
  class spaces:
22
  @staticmethod
23
  def GPU(func):
@@ -84,7 +83,10 @@ def get_example_files() -> list[list[str]]:
84
  examples.append([str(img)])
85
  return examples
86
 
87
- # --- 2. Apply @spaces.GPU Decorator ---
 
 
 
88
  @spaces.GPU(duration=120)
89
  def run_sharp(
90
  image_path: str | None,
@@ -102,20 +104,26 @@ def run_sharp(
102
  if not image_path:
103
  raise gr.Error("Please upload an image first.")
104
 
105
- # Validate inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
107
-
108
- # Convert trajectory string to Enum safely
109
- traj_key = trajectory_type.upper()
110
- if hasattr(TrajectoryType, traj_key):
111
- traj_enum = TrajectoryType[traj_key]
112
- else:
113
- traj_enum = trajectory_type
114
 
115
  try:
116
  progress(0.1, desc="Initializing SHARP model on GPU...")
117
 
118
- # Call the backend model
119
  video_path, ply_path = predict_and_maybe_render_gpu(
120
  image_path,
121
  trajectory_type=traj_enum,
@@ -125,12 +133,14 @@ def run_sharp(
125
  render_video=bool(render_video),
126
  )
127
 
128
- # Prepare outputs
129
  status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
130
 
131
  video_result = str(video_path) if video_path else None
132
  if video_path:
133
  status_msg += f"\nVideo: `{video_path.name}`"
 
 
134
 
135
  # Explicitly update the Download Button
136
  download_btn_update = gr.DownloadButton(
@@ -162,6 +172,7 @@ def build_demo() -> gr.Blocks:
162
 
163
  with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
164
 
 
165
  with gr.Row():
166
  with gr.Column(scale=1):
167
  gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
@@ -178,7 +189,7 @@ def build_demo() -> gr.Blocks:
178
  interactive=True
179
  )
180
 
181
- # Configs
182
  with gr.Group():
183
  with gr.Row():
184
  trajectory = gr.Dropdown(
@@ -189,7 +200,15 @@ def build_demo() -> gr.Blocks:
189
  )
190
  output_res = gr.Dropdown(
191
  label="Output Resolution",
192
- choices=[("Original", 0), ("512px", 512), ("1024px", 1024)],
 
 
 
 
 
 
 
 
193
  value=0,
194
  scale=1
195
  )
@@ -223,7 +242,7 @@ def build_demo() -> gr.Blocks:
223
 
224
  with gr.Group():
225
  status_md = gr.Markdown("Ready to generate.")
226
- # Button starts hidden
227
  ply_download = gr.DownloadButton(
228
  label="Download .PLY File",
229
  variant="secondary",
 
1
  """
2
  SHARP Gradio Demo
3
+ - Standard Native Layout (Clean Two-Column)
4
+ - Logic: Matches original Apple implementation (Robust Enum & Resolution handling)
5
+ - System: ZeroGPU compatible
6
  """
7
 
8
  from __future__ import annotations
 
17
  try:
18
  import spaces
19
  except ImportError:
 
20
  class spaces:
21
  @staticmethod
22
  def GPU(func):
 
83
  examples.append([str(img)])
84
  return examples
85
 
86
+ # -----------------------------------------------------------------------------
87
+ # Main Inference Logic
88
+ # -----------------------------------------------------------------------------
89
+
90
  @spaces.GPU(duration=120)
91
  def run_sharp(
92
  image_path: str | None,
 
104
  if not image_path:
105
  raise gr.Error("Please upload an image first.")
106
 
107
+ # 1. Logic: Robust Enum Conversion
108
+ # The model likely expects the Enum object, not the string.
109
+ try:
110
+ # Try exact match (e.g. "swipe" -> TrajectoryType.swipe)
111
+ traj_enum = TrajectoryType[trajectory_type]
112
+ except KeyError:
113
+ try:
114
+ # Try upper case (e.g. "swipe" -> TrajectoryType.SWIPE)
115
+ traj_enum = TrajectoryType[trajectory_type.upper()]
116
+ except KeyError:
117
+ # Fallback: pass the string itself
118
+ traj_enum = trajectory_type
119
+
120
+ # 2. Logic: Handle Resolution
121
  out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
 
 
 
 
 
 
 
122
 
123
  try:
124
  progress(0.1, desc="Initializing SHARP model on GPU...")
125
 
126
+ # 3. Call Backend
127
  video_path, ply_path = predict_and_maybe_render_gpu(
128
  image_path,
129
  trajectory_type=traj_enum,
 
133
  render_video=bool(render_video),
134
  )
135
 
136
+ # 4. Prepare Outputs
137
  status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
138
 
139
  video_result = str(video_path) if video_path else None
140
  if video_path:
141
  status_msg += f"\nVideo: `{video_path.name}`"
142
+ else:
143
+ status_msg += "\n(Video rendering disabled or failed)"
144
 
145
  # Explicitly update the Download Button
146
  download_btn_update = gr.DownloadButton(
 
172
 
173
  with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
174
 
175
+ # --- Header ---
176
  with gr.Row():
177
  with gr.Column(scale=1):
178
  gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
 
189
  interactive=True
190
  )
191
 
192
+ # Configs (Updated with Full Options from Original File)
193
  with gr.Group():
194
  with gr.Row():
195
  trajectory = gr.Dropdown(
 
200
  )
201
  output_res = gr.Dropdown(
202
  label="Output Resolution",
203
+ # Full list from the original logic
204
+ choices=[
205
+ ("Match input", 0),
206
+ ("512", 512),
207
+ ("768", 768),
208
+ ("1024", 1024),
209
+ ("1280", 1280),
210
+ ("1536", 1536),
211
+ ],
212
  value=0,
213
  scale=1
214
  )
 
242
 
243
  with gr.Group():
244
  status_md = gr.Markdown("Ready to generate.")
245
+ # Button starts hidden, becomes visible on success
246
  ply_download = gr.DownloadButton(
247
  label="Download .PLY File",
248
  variant="secondary",
model_utils.py CHANGED
@@ -4,6 +4,7 @@ Design goals:
4
  - Reuse SHARP's own predict/render pipeline (no subprocess calls).
5
  - Be robust on Hugging Face Spaces + ZeroGPU.
6
  - Cache model weights and predictor construction across requests.
 
7
 
8
  Public API (used by the Gradio app):
9
  - TrajectoryType
@@ -12,6 +13,7 @@ Public API (used by the Gradio app):
12
 
13
  from __future__ import annotations
14
 
 
15
  import os
16
  import threading
17
  import time
@@ -42,7 +44,24 @@ from sharp.utils import camera, io
42
  from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
43
  from sharp.utils.gsplat import GSplatRenderer
44
 
45
- TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # -----------------------------------------------------------------------------
48
  # Helpers
@@ -82,6 +101,189 @@ def _select_device(preference: str = "auto") -> torch.device:
82
  return torch.device("cpu")
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # -----------------------------------------------------------------------------
86
  # Prediction outputs
87
  # -----------------------------------------------------------------------------
@@ -406,10 +608,13 @@ class ModelWrapper:
406
  if fps < 1:
407
  raise ValueError("fps must be >= 1")
408
 
409
- # Keep aligned with upstream CLI pipeline where possible.
410
- if output_long_side is None and int(fps) == 30:
 
 
 
411
  params = camera.TrajectoryParams(
412
- type=trajectory_type,
413
  num_steps=int(num_frames),
414
  num_repeats=1,
415
  )
@@ -428,7 +633,7 @@ class ModelWrapper:
428
  pass
429
  return output_path
430
 
431
- # Adapted pipeline for custom output resolution / FPS.
432
  src_w, src_h = metadata.resolution_px
433
  src_f = float(metadata.focal_length_px)
434
 
@@ -441,15 +646,37 @@ class ModelWrapper:
441
  out_h = _make_even(max(2, int(round(src_h * scale))))
442
  out_f = src_f * scale
443
 
444
- traj_params = camera.TrajectoryParams(
445
- type=trajectory_type,
446
- num_steps=int(num_frames),
447
- num_repeats=1,
448
- )
449
-
450
  device = torch.device("cuda")
451
  gaussians_cuda = gaussians.to(device)
452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  intrinsics = torch.tensor(
454
  [
455
  [out_f, 0.0, (out_w - 1) / 2.0, 0.0],
@@ -465,14 +692,7 @@ class ModelWrapper:
465
  gaussians_cuda,
466
  intrinsics,
467
  resolution_px=(out_w, out_h),
468
- lookat_mode=traj_params.lookat_mode,
469
- )
470
-
471
- trajectory = camera.create_eye_trajectory(
472
- gaussians_cuda,
473
- traj_params,
474
- resolution_px=(out_w, out_h),
475
- f_px=out_f,
476
  )
477
 
478
  renderer = GSplatRenderer(color_space=metadata.color_space)
@@ -609,4 +829,4 @@ def predict_and_maybe_render(
609
  if spaces is not None:
610
  predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
611
  else: # pragma: no cover
612
- predict_and_maybe_render_gpu = predict_and_maybe_render
 
4
  - Reuse SHARP's own predict/render pipeline (no subprocess calls).
5
  - Be robust on Hugging Face Spaces + ZeroGPU.
6
  - Cache model weights and predictor construction across requests.
7
+ - Support extended camera trajectories beyond the defaults.
8
 
9
  Public API (used by the Gradio app):
10
  - TrajectoryType
 
13
 
14
  from __future__ import annotations
15
 
16
+ import math
17
  import os
18
  import threading
19
  import time
 
44
  from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
45
  from sharp.utils.gsplat import GSplatRenderer
46
 
47
# Extended list of supported trajectories (22 types).
TrajectoryType = Literal[
    # Standard SHARP defaults
    "swipe", "shake", "rotate", "rotate_forward",
    # Extended rotations
    "rotate_reverse", "rotate_up", "rotate_down",
    # Zooms & dollies
    "zoom_in", "zoom_out", "dolly_in", "dolly_out",
    # Pans (planar movement)
    "pan_left", "pan_right", "pan_up", "pan_down",
    # Complex paths
    "spiral_in", "spiral_out", "figure_eight", "loop", "heart",
    "bounce", "ken_burns",
]

# Trajectory names that are forwarded as-is to `camera.TrajectoryParams`;
# every other name in `TrajectoryType` is produced by the custom generator.
# Frozen so that this module-level constant cannot be mutated at runtime.
STANDARD_TRAJECTORIES: Final[frozenset[str]] = frozenset(
    {"swipe", "shake", "rotate", "rotate_forward"}
)
65
 
66
  # -----------------------------------------------------------------------------
67
  # Helpers
 
101
  return torch.device("cpu")
102
 
103
 
104
+ # -----------------------------------------------------------------------------
105
+ # Custom Trajectory Generation
106
+ # -----------------------------------------------------------------------------
107
+
108
def _custom_eye_position(
    traj_type: str,
    t: float,
    radius: float,
    base_theta: float,
    start: tuple[float, float, float],
) -> "tuple[float, float, float] | None":
    """Return the (x, y, z) eye position of a custom trajectory at time ``t``.

    Args:
        traj_type: Name of the custom trajectory.
        t: Normalised time in ``[0, 1]``.
        radius: Euclidean norm of the baseline eye position.
        base_theta: ``atan2(z, x)`` of the baseline eye position.
        start: Baseline eye position as plain floats ``(x, y, z)``.

    Returns:
        The eye position, or ``None`` when ``traj_type`` is not a custom type.
    """
    start_x, start_y, start_z = start
    sin, cos, pi = math.sin, math.cos, math.pi

    if traj_type == "rotate_reverse":
        # Full orbit with decreasing azimuth.
        theta = base_theta - (2 * pi * t)
        return radius * cos(theta), start_y, radius * sin(theta)

    if traj_type == "rotate_up":
        # Quarter-turn orbit while arcing the height up and back.
        theta = base_theta + (0.5 * pi * t)
        y = start_y + (radius * 0.8 * sin(pi * t))
        return radius * cos(theta), y, radius * sin(theta)

    if traj_type == "rotate_down":
        theta = base_theta + (0.5 * pi * t)
        y = start_y - (radius * 0.5 * sin(pi * t))
        return radius * cos(theta), y, radius * sin(theta)

    if traj_type in ("zoom_in", "dolly_in"):
        # Radius shrinks linearly from 1.0x to 0.4x.
        cur_r = radius * (1.0 - 0.6 * t)
        return cur_r * cos(base_theta), start_y, cur_r * sin(base_theta)

    if traj_type in ("zoom_out", "dolly_out"):
        # Radius grows linearly from 0.5x to 1.2x.
        cur_r = (radius * 0.5) + (radius * 0.7 * t)
        return cur_r * cos(base_theta), start_y, cur_r * sin(base_theta)

    if traj_type == "pan_left":
        # Linear slide along X only (simple approximation of a lateral pan).
        offset = (t - 0.5) * 2.0 * (radius * 0.5)
        return start_x + offset, start_y, start_z

    if traj_type == "pan_right":
        offset = (0.5 - t) * 2.0 * (radius * 0.5)
        return start_x + offset, start_y, start_z

    if traj_type == "pan_up":
        # NOTE(review): treats +Y as "up"; confirm against the scene's
        # coordinate convention.
        offset = (t - 0.5) * (radius * 0.8)
        return start_x, start_y + offset, start_z

    if traj_type == "pan_down":
        offset = (t - 0.5) * (radius * 0.8)
        return start_x, start_y - offset, start_z

    if traj_type == "spiral_in":
        # Full orbit while closing in, with a slight vertical wobble.
        theta = base_theta + (2 * pi * t)
        cur_r = radius * (1.0 - 0.6 * t)
        y = start_y + (0.2 * radius * sin(4 * pi * t))
        return cur_r * cos(theta), y, cur_r * sin(theta)

    if traj_type == "spiral_out":
        theta = base_theta + (2 * pi * t)
        cur_r = (radius * 0.4) + (radius * 0.8 * t)
        return cur_r * cos(theta), start_y, cur_r * sin(theta)

    if traj_type == "figure_eight":
        # Lissajous-style path: azimuth at 1x frequency, height at 2x.
        phase = 2 * pi * t
        theta = base_theta + (0.5 * sin(phase))
        y = start_y + (radius * 0.3 * sin(2 * phase))
        return radius * cos(theta), y, radius * sin(theta)

    if traj_type == "loop":
        # Ellipse in the X/Y plane around the baseline position.
        angle = 2 * pi * t
        x = start_x + (0.2 * radius * cos(angle))
        y = start_y + (0.5 * radius * sin(angle))
        return x, y, start_z

    if traj_type == "heart":
        # Parametric heart curve in the X/Y plane, scaled by 2% of radius.
        angle = 2 * pi * t
        h_x = 16 * sin(angle) ** 3
        h_y = (
            13 * cos(angle)
            - 5 * cos(2 * angle)
            - 2 * cos(3 * angle)
            - cos(4 * angle)
        )
        scale = radius * 0.02
        return start_x + (h_x * scale), start_y + (h_y * scale), start_z

    if traj_type == "bounce":
        # Rectified cosine whose amplitude decays linearly to zero.
        amp = abs(cos(3 * pi * t)) * (1 - t)
        return start_x, start_y + (radius * 0.5 * amp), start_z

    if traj_type == "ken_burns":
        # Diagonal pan combined with a 30% zoom-in.
        cur_r = radius * (1.0 - (0.3 * t))
        pan_x = (t - 0.5) * (radius * 0.3)
        pan_y = (t - 0.5) * (radius * 0.2)
        return (
            (cur_r * cos(base_theta)) + pan_x,
            start_y + pan_y,
            cur_r * sin(base_theta),
        )

    return None


def _generate_custom_trajectory(
    gaussians: torch.Tensor,
    resolution: tuple[int, int],
    focal_length: float,
    traj_type: str,
    num_frames: int
) -> list[torch.Tensor]:
    """Generate camera eye positions for a non-standard trajectory.

    A single step of the built-in ``"rotate"`` trajectory is generated first
    and used as the baseline position; its norm, Y component and azimuth
    parameterise every custom path.

    Args:
        gaussians: Scene passed through to ``camera.create_eye_trajectory``;
            its ``.device`` is used for the returned tensors.
            NOTE(review): annotated as a tensor but the caller appears to
            pass the scene's Gaussians object - confirm it exposes
            ``.device``.
        resolution: Output resolution in pixels, forwarded as
            ``resolution_px``.
        focal_length: Focal length in pixels, forwarded as ``f_px``.
        traj_type: Name of the custom trajectory.
        num_frames: Number of eye positions to generate.

    Returns:
        A list of ``num_frames`` float32 tensors ``[x, y, z]``. When
        ``traj_type`` is not a known custom type the one-step baseline
        trajectory is returned instead. An empty list is returned when
        ``num_frames`` is less than 1.
    """
    base_params = camera.TrajectoryParams(type="rotate", num_steps=1)
    # Materialise once: the result may be a one-shot iterator, and both the
    # start position and the fallback below need it.
    base_positions = list(
        camera.create_eye_trajectory(
            gaussians, base_params, resolution_px=resolution, f_px=focal_length
        )
    )
    start_pos = base_positions[0].cpu()

    # Work in plain floats so the path math never mixes tensors and scalars.
    start = (float(start_pos[0]), float(start_pos[1]), float(start_pos[2]))
    radius = float(torch.norm(start_pos))
    base_theta = math.atan2(start[2], start[0])

    num_frames = int(num_frames)
    if num_frames < 1:
        return []

    # A single frame sits at t == 0 instead of dividing by zero.
    last_index = max(num_frames - 1, 1)

    positions: list[torch.Tensor] = []
    for i in range(num_frames):
        xyz = _custom_eye_position(
            traj_type, i / last_index, radius, base_theta, start
        )
        if xyz is None:
            # Unknown type: callers should not reach this, but fall back to
            # the baseline rather than failing.
            return base_positions
        positions.append(
            torch.tensor(xyz, dtype=torch.float32, device=gaussians.device)
        )

    return positions
285
+
286
+
287
  # -----------------------------------------------------------------------------
288
  # Prediction outputs
289
  # -----------------------------------------------------------------------------
 
608
  if fps < 1:
609
  raise ValueError("fps must be >= 1")
610
 
611
+ # FAST PATH: Standard SHARP trajectories + Default Resolution
612
+ # We only use the optimized CLI shortcut if it's a standard type AND default res.
613
+ is_standard_traj = trajectory_type in STANDARD_TRAJECTORIES
614
+
615
+ if output_long_side is None and int(fps) == 30 and is_standard_traj:
616
  params = camera.TrajectoryParams(
617
+ type=trajectory_type, # type: ignore
618
  num_steps=int(num_frames),
619
  num_repeats=1,
620
  )
 
633
  pass
634
  return output_path
635
 
636
+ # CUSTOM PATH: Manual loop (Handles Custom Res, FPS, or Custom Trajectories)
637
  src_w, src_h = metadata.resolution_px
638
  src_f = float(metadata.focal_length_px)
639
 
 
646
  out_h = _make_even(max(2, int(round(src_h * scale))))
647
  out_f = src_f * scale
648
 
 
 
 
 
 
 
649
  device = torch.device("cuda")
650
  gaussians_cuda = gaussians.to(device)
651
 
652
+ # 1. Generate Camera Trajectory
653
+ if is_standard_traj:
654
+ # Use SHARP's built-in generator
655
+ traj_params = camera.TrajectoryParams(
656
+ type=trajectory_type, # type: ignore
657
+ num_steps=int(num_frames),
658
+ num_repeats=1,
659
+ )
660
+ trajectory = camera.create_eye_trajectory(
661
+ gaussians_cuda,
662
+ traj_params,
663
+ resolution_px=(out_w, out_h),
664
+ f_px=out_f,
665
+ )
666
+ lookat_mode = traj_params.lookat_mode
667
+ else:
668
+ # Use our custom generator
669
+ trajectory = _generate_custom_trajectory(
670
+ gaussians_cuda,
671
+ resolution=(out_w, out_h),
672
+ focal_length=out_f,
673
+ traj_type=trajectory_type,
674
+ num_frames=num_frames
675
+ )
676
+ # Custom trajectories always look at origin (0,0,0) for now
677
+ lookat_mode = "scene" # Assuming SHARP 'scene' mode implies look-at-center
678
+
679
+ # 2. Setup Camera Model
680
  intrinsics = torch.tensor(
681
  [
682
  [out_f, 0.0, (out_w - 1) / 2.0, 0.0],
 
692
  gaussians_cuda,
693
  intrinsics,
694
  resolution_px=(out_w, out_h),
695
+ lookat_mode=lookat_mode,
 
 
 
 
 
 
 
696
  )
697
 
698
  renderer = GSplatRenderer(color_space=metadata.color_space)
 
829
  if spaces is not None:
830
  predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
831
  else: # pragma: no cover
832
+ predict_and_maybe_render_gpu = predict_and_maybe_render