Spaces:
Running
on
Zero
Running
on
Zero
updates
Browse files- app.py +18 -37
- model_utils.py +20 -240
app.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
SHARP Gradio Demo
|
| 3 |
-
- Standard Native Layout
|
| 4 |
-
-
|
| 5 |
-
-
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
|
@@ -17,6 +17,7 @@ import gradio as gr
|
|
| 17 |
try:
|
| 18 |
import spaces
|
| 19 |
except ImportError:
|
|
|
|
| 20 |
class spaces:
|
| 21 |
@staticmethod
|
| 22 |
def GPU(func):
|
|
@@ -83,10 +84,7 @@ def get_example_files() -> list[list[str]]:
|
|
| 83 |
examples.append([str(img)])
|
| 84 |
return examples
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
# Main Inference Logic
|
| 88 |
-
# -----------------------------------------------------------------------------
|
| 89 |
-
|
| 90 |
@spaces.GPU(duration=120)
|
| 91 |
def run_sharp(
|
| 92 |
image_path: str | None,
|
|
@@ -104,26 +102,20 @@ def run_sharp(
|
|
| 104 |
if not image_path:
|
| 105 |
raise gr.Error("Please upload an image first.")
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
# The model likely expects the Enum object, not the string.
|
| 109 |
-
try:
|
| 110 |
-
# Try exact match (e.g. "swipe" -> TrajectoryType.swipe)
|
| 111 |
-
traj_enum = TrajectoryType[trajectory_type]
|
| 112 |
-
except KeyError:
|
| 113 |
-
try:
|
| 114 |
-
# Try upper case (e.g. "swipe" -> TrajectoryType.SWIPE)
|
| 115 |
-
traj_enum = TrajectoryType[trajectory_type.upper()]
|
| 116 |
-
except KeyError:
|
| 117 |
-
# Fallback: pass the string itself
|
| 118 |
-
traj_enum = trajectory_type
|
| 119 |
-
|
| 120 |
-
# 2. Logic: Handle Resolution
|
| 121 |
out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
try:
|
| 124 |
progress(0.1, desc="Initializing SHARP model on GPU...")
|
| 125 |
|
| 126 |
-
#
|
| 127 |
video_path, ply_path = predict_and_maybe_render_gpu(
|
| 128 |
image_path,
|
| 129 |
trajectory_type=traj_enum,
|
|
@@ -133,14 +125,12 @@ def run_sharp(
|
|
| 133 |
render_video=bool(render_video),
|
| 134 |
)
|
| 135 |
|
| 136 |
-
#
|
| 137 |
status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
|
| 138 |
|
| 139 |
video_result = str(video_path) if video_path else None
|
| 140 |
if video_path:
|
| 141 |
status_msg += f"\nVideo: `{video_path.name}`"
|
| 142 |
-
else:
|
| 143 |
-
status_msg += "\n(Video rendering disabled or failed)"
|
| 144 |
|
| 145 |
# Explicitly update the Download Button
|
| 146 |
download_btn_update = gr.DownloadButton(
|
|
@@ -172,7 +162,6 @@ def build_demo() -> gr.Blocks:
|
|
| 172 |
|
| 173 |
with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
|
| 174 |
|
| 175 |
-
# --- Header ---
|
| 176 |
with gr.Row():
|
| 177 |
with gr.Column(scale=1):
|
| 178 |
gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
|
|
@@ -189,7 +178,7 @@ def build_demo() -> gr.Blocks:
|
|
| 189 |
interactive=True
|
| 190 |
)
|
| 191 |
|
| 192 |
-
# Configs
|
| 193 |
with gr.Group():
|
| 194 |
with gr.Row():
|
| 195 |
trajectory = gr.Dropdown(
|
|
@@ -200,15 +189,7 @@ def build_demo() -> gr.Blocks:
|
|
| 200 |
)
|
| 201 |
output_res = gr.Dropdown(
|
| 202 |
label="Output Resolution",
|
| 203 |
-
|
| 204 |
-
choices=[
|
| 205 |
-
("Match input", 0),
|
| 206 |
-
("512", 512),
|
| 207 |
-
("768", 768),
|
| 208 |
-
("1024", 1024),
|
| 209 |
-
("1280", 1280),
|
| 210 |
-
("1536", 1536),
|
| 211 |
-
],
|
| 212 |
value=0,
|
| 213 |
scale=1
|
| 214 |
)
|
|
@@ -242,7 +223,7 @@ def build_demo() -> gr.Blocks:
|
|
| 242 |
|
| 243 |
with gr.Group():
|
| 244 |
status_md = gr.Markdown("Ready to generate.")
|
| 245 |
-
# Button starts hidden
|
| 246 |
ply_download = gr.DownloadButton(
|
| 247 |
label="Download .PLY File",
|
| 248 |
variant="secondary",
|
|
|
|
| 1 |
"""
|
| 2 |
SHARP Gradio Demo
|
| 3 |
+
- Standard Native Layout
|
| 4 |
+
- Fixed: Added @spaces.GPU for ZeroGPU compatibility (Fixes 'dummy' output)
|
| 5 |
+
- Fixed: Download Button visibility logic
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
|
|
|
| 17 |
try:
|
| 18 |
import spaces
|
| 19 |
except ImportError:
|
| 20 |
+
# Fallback for local testing if spaces is not installed
|
| 21 |
class spaces:
|
| 22 |
@staticmethod
|
| 23 |
def GPU(func):
|
|
|
|
| 84 |
examples.append([str(img)])
|
| 85 |
return examples
|
| 86 |
|
| 87 |
+
# --- 2. Apply @spaces.GPU Decorator ---
|
|
|
|
|
|
|
|
|
|
| 88 |
@spaces.GPU(duration=120)
|
| 89 |
def run_sharp(
|
| 90 |
image_path: str | None,
|
|
|
|
| 102 |
if not image_path:
|
| 103 |
raise gr.Error("Please upload an image first.")
|
| 104 |
|
| 105 |
+
# Validate inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
|
| 107 |
+
|
| 108 |
+
# Convert trajectory string to Enum safely
|
| 109 |
+
traj_key = trajectory_type.upper()
|
| 110 |
+
if hasattr(TrajectoryType, traj_key):
|
| 111 |
+
traj_enum = TrajectoryType[traj_key]
|
| 112 |
+
else:
|
| 113 |
+
traj_enum = trajectory_type
|
| 114 |
|
| 115 |
try:
|
| 116 |
progress(0.1, desc="Initializing SHARP model on GPU...")
|
| 117 |
|
| 118 |
+
# Call the backend model
|
| 119 |
video_path, ply_path = predict_and_maybe_render_gpu(
|
| 120 |
image_path,
|
| 121 |
trajectory_type=traj_enum,
|
|
|
|
| 125 |
render_video=bool(render_video),
|
| 126 |
)
|
| 127 |
|
| 128 |
+
# Prepare outputs
|
| 129 |
status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
|
| 130 |
|
| 131 |
video_result = str(video_path) if video_path else None
|
| 132 |
if video_path:
|
| 133 |
status_msg += f"\nVideo: `{video_path.name}`"
|
|
|
|
|
|
|
| 134 |
|
| 135 |
# Explicitly update the Download Button
|
| 136 |
download_btn_update = gr.DownloadButton(
|
|
|
|
| 162 |
|
| 163 |
with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
|
| 164 |
|
|
|
|
| 165 |
with gr.Row():
|
| 166 |
with gr.Column(scale=1):
|
| 167 |
gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
|
|
|
|
| 178 |
interactive=True
|
| 179 |
)
|
| 180 |
|
| 181 |
+
# Configs
|
| 182 |
with gr.Group():
|
| 183 |
with gr.Row():
|
| 184 |
trajectory = gr.Dropdown(
|
|
|
|
| 189 |
)
|
| 190 |
output_res = gr.Dropdown(
|
| 191 |
label="Output Resolution",
|
| 192 |
+
choices=[("Original", 0), ("512px", 512), ("1024px", 1024)],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
value=0,
|
| 194 |
scale=1
|
| 195 |
)
|
|
|
|
| 223 |
|
| 224 |
with gr.Group():
|
| 225 |
status_md = gr.Markdown("Ready to generate.")
|
| 226 |
+
# Button starts hidden
|
| 227 |
ply_download = gr.DownloadButton(
|
| 228 |
label="Download .PLY File",
|
| 229 |
variant="secondary",
|
model_utils.py
CHANGED
|
@@ -4,7 +4,6 @@ Design goals:
|
|
| 4 |
- Reuse SHARP's own predict/render pipeline (no subprocess calls).
|
| 5 |
- Be robust on Hugging Face Spaces + ZeroGPU.
|
| 6 |
- Cache model weights and predictor construction across requests.
|
| 7 |
-
- Support extended camera trajectories beyond the defaults.
|
| 8 |
|
| 9 |
Public API (used by the Gradio app):
|
| 10 |
- TrajectoryType
|
|
@@ -13,7 +12,6 @@ Public API (used by the Gradio app):
|
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
-
import math
|
| 17 |
import os
|
| 18 |
import threading
|
| 19 |
import time
|
|
@@ -44,24 +42,7 @@ from sharp.utils import camera, io
|
|
| 44 |
from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
|
| 45 |
from sharp.utils.gsplat import GSplatRenderer
|
| 46 |
|
| 47 |
-
|
| 48 |
-
TrajectoryType = Literal[
|
| 49 |
-
# Standard SHARP defaults
|
| 50 |
-
"swipe", "shake", "rotate", "rotate_forward",
|
| 51 |
-
# Extended Rotations
|
| 52 |
-
"rotate_reverse", "rotate_up", "rotate_down",
|
| 53 |
-
# Zooms & Dollies
|
| 54 |
-
"zoom_in", "zoom_out", "dolly_in", "dolly_out",
|
| 55 |
-
# Pans (planar movement)
|
| 56 |
-
"pan_left", "pan_right", "pan_up", "pan_down",
|
| 57 |
-
# Complex Paths
|
| 58 |
-
"spiral_in", "spiral_out", "figure_eight", "loop", "heart",
|
| 59 |
-
"bounce", "ken_burns"
|
| 60 |
-
]
|
| 61 |
-
|
| 62 |
-
STANDARD_TRAJECTORIES: Final[set[str]] = {
|
| 63 |
-
"swipe", "shake", "rotate", "rotate_forward"
|
| 64 |
-
}
|
| 65 |
|
| 66 |
# -----------------------------------------------------------------------------
|
| 67 |
# Helpers
|
|
@@ -101,189 +82,6 @@ def _select_device(preference: str = "auto") -> torch.device:
|
|
| 101 |
return torch.device("cpu")
|
| 102 |
|
| 103 |
|
| 104 |
-
# -----------------------------------------------------------------------------
|
| 105 |
-
# Custom Trajectory Generation
|
| 106 |
-
# -----------------------------------------------------------------------------
|
| 107 |
-
|
| 108 |
-
def _generate_custom_trajectory(
|
| 109 |
-
gaussians: torch.Tensor,
|
| 110 |
-
resolution: tuple[int, int],
|
| 111 |
-
focal_length: float,
|
| 112 |
-
traj_type: str,
|
| 113 |
-
num_frames: int
|
| 114 |
-
) -> list[torch.Tensor]:
|
| 115 |
-
"""
|
| 116 |
-
Generates a list of camera eye positions (tensors) for custom paths.
|
| 117 |
-
Uses the standard 'rotate' path to establish a baseline radius/elevation.
|
| 118 |
-
"""
|
| 119 |
-
# 1. Get baseline Radius (R) and Elevation (Y) from the standard generator
|
| 120 |
-
# We generate just 1 step of the standard 'rotate' to see where SHARP puts the camera.
|
| 121 |
-
base_params = camera.TrajectoryParams(type="rotate", num_steps=1)
|
| 122 |
-
base_traj = camera.create_eye_trajectory(
|
| 123 |
-
gaussians, base_params, resolution_px=resolution, f_px=focal_length
|
| 124 |
-
)
|
| 125 |
-
start_pos = list(base_traj)[0].cpu() # [3] tensor (x, y, z)
|
| 126 |
-
|
| 127 |
-
# Calculate spherical coordinates from start_pos
|
| 128 |
-
# Assuming LookAt(0,0,0), radius is norm, elevation is y.
|
| 129 |
-
radius = float(torch.norm(start_pos))
|
| 130 |
-
base_y = float(start_pos[1])
|
| 131 |
-
|
| 132 |
-
# Starting azimuth (theta). Usually start_pos is roughly [0, 0, radius] or [radius, 0, 0]
|
| 133 |
-
# We'll compute it to be safe.
|
| 134 |
-
base_theta = math.atan2(start_pos[2], start_pos[0])
|
| 135 |
-
|
| 136 |
-
positions = []
|
| 137 |
-
|
| 138 |
-
# Time steps 0..1
|
| 139 |
-
t_vals = [i / (num_frames - 1) for i in range(num_frames)]
|
| 140 |
-
|
| 141 |
-
for t in t_vals:
|
| 142 |
-
x, y, z = 0.0, 0.0, 0.0
|
| 143 |
-
|
| 144 |
-
# --- Logic for 20+ movements ---
|
| 145 |
-
|
| 146 |
-
if traj_type == "rotate_reverse":
|
| 147 |
-
# Orbit opposite direction
|
| 148 |
-
theta = base_theta - (2 * math.pi * t)
|
| 149 |
-
x = radius * math.cos(theta)
|
| 150 |
-
z = radius * math.sin(theta)
|
| 151 |
-
y = base_y
|
| 152 |
-
|
| 153 |
-
elif traj_type == "rotate_up":
|
| 154 |
-
# Orbit over the top (vertical orbit)
|
| 155 |
-
phi = (math.pi / 4) * math.sin(2 * math.pi * t)
|
| 156 |
-
# Modulate Y significantly
|
| 157 |
-
theta = base_theta + (0.5 * math.pi * t) # Slow rotate
|
| 158 |
-
curr_r = radius
|
| 159 |
-
x = curr_r * math.cos(theta)
|
| 160 |
-
z = curr_r * math.sin(theta)
|
| 161 |
-
y = base_y + (radius * 0.8 * math.sin(math.pi * t)) # Arc up
|
| 162 |
-
|
| 163 |
-
elif traj_type == "rotate_down":
|
| 164 |
-
theta = base_theta + (0.5 * math.pi * t)
|
| 165 |
-
y = base_y - (radius * 0.5 * math.sin(math.pi * t))
|
| 166 |
-
x = radius * math.cos(theta)
|
| 167 |
-
z = radius * math.sin(theta)
|
| 168 |
-
|
| 169 |
-
elif traj_type in ["zoom_in", "dolly_in"]:
|
| 170 |
-
# Move from Radius to Radius*0.4
|
| 171 |
-
cur_r = radius * (1.0 - 0.6 * t)
|
| 172 |
-
x = cur_r * math.cos(base_theta)
|
| 173 |
-
z = cur_r * math.sin(base_theta)
|
| 174 |
-
y = base_y
|
| 175 |
-
|
| 176 |
-
elif traj_type in ["zoom_out", "dolly_out"]:
|
| 177 |
-
# Move from Radius*0.5 to Radius*1.2
|
| 178 |
-
cur_r = (radius * 0.5) + (radius * 0.7 * t)
|
| 179 |
-
x = cur_r * math.cos(base_theta)
|
| 180 |
-
z = cur_r * math.sin(base_theta)
|
| 181 |
-
y = base_y
|
| 182 |
-
|
| 183 |
-
elif traj_type == "pan_left":
|
| 184 |
-
# Linear slide perpendicular to view vector
|
| 185 |
-
# Approx: move X relative to view
|
| 186 |
-
offset = (t - 0.5) * 2.0 * (radius * 0.5)
|
| 187 |
-
x = start_pos[0] + offset
|
| 188 |
-
y = start_pos[1]
|
| 189 |
-
z = start_pos[2] # Simple approximation
|
| 190 |
-
|
| 191 |
-
elif traj_type == "pan_right":
|
| 192 |
-
offset = (0.5 - t) * 2.0 * (radius * 0.5)
|
| 193 |
-
x = start_pos[0] + offset
|
| 194 |
-
y = start_pos[1]
|
| 195 |
-
z = start_pos[2]
|
| 196 |
-
|
| 197 |
-
elif traj_type == "pan_up":
|
| 198 |
-
offset = (t - 0.5) * (radius * 0.8)
|
| 199 |
-
x = start_pos[0]
|
| 200 |
-
y = base_y - offset # In 3D, Y usually up, but check coord sys. usually Y is down in some CV.
|
| 201 |
-
# Assuming Y is Up for scene.
|
| 202 |
-
y = base_y + offset
|
| 203 |
-
z = start_pos[2]
|
| 204 |
-
|
| 205 |
-
elif traj_type == "pan_down":
|
| 206 |
-
offset = (t - 0.5) * (radius * 0.8)
|
| 207 |
-
x = start_pos[0]
|
| 208 |
-
y = base_y - offset
|
| 209 |
-
z = start_pos[2]
|
| 210 |
-
|
| 211 |
-
elif traj_type == "spiral_in":
|
| 212 |
-
# Rotate while getting closer
|
| 213 |
-
theta = base_theta + (2 * math.pi * t)
|
| 214 |
-
cur_r = radius * (1.0 - 0.6 * t)
|
| 215 |
-
x = cur_r * math.cos(theta)
|
| 216 |
-
z = cur_r * math.sin(theta)
|
| 217 |
-
y = base_y + (0.2 * radius * math.sin(4 * math.pi * t)) # Slight wobble
|
| 218 |
-
|
| 219 |
-
elif traj_type == "spiral_out":
|
| 220 |
-
theta = base_theta + (2 * math.pi * t)
|
| 221 |
-
cur_r = (radius * 0.4) + (radius * 0.8 * t)
|
| 222 |
-
x = cur_r * math.cos(theta)
|
| 223 |
-
z = cur_r * math.sin(theta)
|
| 224 |
-
y = base_y
|
| 225 |
-
|
| 226 |
-
elif traj_type == "figure_eight":
|
| 227 |
-
# Lemniscate on sphere surface
|
| 228 |
-
scale = 2 * math.pi * t
|
| 229 |
-
# Lissajous-ish
|
| 230 |
-
theta = base_theta + (0.5 * math.sin(scale))
|
| 231 |
-
phi_offset = 0.3 * math.sin(2 * scale)
|
| 232 |
-
y = base_y + (radius * phi_offset)
|
| 233 |
-
x = radius * math.cos(theta)
|
| 234 |
-
z = radius * math.sin(theta)
|
| 235 |
-
|
| 236 |
-
elif traj_type == "loop":
|
| 237 |
-
# Vertical circle
|
| 238 |
-
angle = 2 * math.pi * t
|
| 239 |
-
y_off = 0.5 * radius * math.sin(angle)
|
| 240 |
-
x_off = 0.2 * radius * math.cos(angle)
|
| 241 |
-
x = start_pos[0] + x_off
|
| 242 |
-
y = base_y + y_off
|
| 243 |
-
z = start_pos[2]
|
| 244 |
-
|
| 245 |
-
elif traj_type == "heart":
|
| 246 |
-
# Heart shape in XY plane projection
|
| 247 |
-
angle = 2 * math.pi * t
|
| 248 |
-
# Heart formula
|
| 249 |
-
h_x = 16 * math.sin(angle)**3
|
| 250 |
-
h_y = 13 * math.cos(angle) - 5*math.cos(2*angle) - 2*math.cos(3*angle) - math.cos(4*angle)
|
| 251 |
-
# Scale down
|
| 252 |
-
scale = radius * 0.02
|
| 253 |
-
x = start_pos[0] + (h_x * scale)
|
| 254 |
-
y = base_y + (h_y * scale)
|
| 255 |
-
z = start_pos[2]
|
| 256 |
-
|
| 257 |
-
elif traj_type == "bounce":
|
| 258 |
-
# Decay bounce
|
| 259 |
-
freq = 3 * math.pi
|
| 260 |
-
amp = abs(math.cos(freq * t)) * (1-t)
|
| 261 |
-
y = base_y + (radius * 0.5 * amp)
|
| 262 |
-
x = start_pos[0]
|
| 263 |
-
z = start_pos[2]
|
| 264 |
-
|
| 265 |
-
elif traj_type == "ken_burns":
|
| 266 |
-
# Pan diagonal + slow zoom
|
| 267 |
-
zoom_fac = 1.0 - (0.3 * t) # Zoom in 30%
|
| 268 |
-
pan_x = (t - 0.5) * (radius * 0.3)
|
| 269 |
-
pan_y = (t - 0.5) * (radius * 0.2)
|
| 270 |
-
|
| 271 |
-
cur_r = radius * zoom_fac
|
| 272 |
-
x = (cur_r * math.cos(base_theta)) + pan_x
|
| 273 |
-
y = base_y + pan_y
|
| 274 |
-
z = (cur_r * math.sin(base_theta))
|
| 275 |
-
|
| 276 |
-
else:
|
| 277 |
-
# Fallback for anything else (or minor variations)
|
| 278 |
-
return list(base_traj) # Should be caught by caller, but safe fallback
|
| 279 |
-
|
| 280 |
-
# Construct tensor
|
| 281 |
-
pos_tensor = torch.tensor([x, y, z], dtype=torch.float32, device=gaussians.device)
|
| 282 |
-
positions.append(pos_tensor)
|
| 283 |
-
|
| 284 |
-
return positions
|
| 285 |
-
|
| 286 |
-
|
| 287 |
# -----------------------------------------------------------------------------
|
| 288 |
# Prediction outputs
|
| 289 |
# -----------------------------------------------------------------------------
|
|
@@ -608,13 +406,10 @@ class ModelWrapper:
|
|
| 608 |
if fps < 1:
|
| 609 |
raise ValueError("fps must be >= 1")
|
| 610 |
|
| 611 |
-
#
|
| 612 |
-
|
| 613 |
-
is_standard_traj = trajectory_type in STANDARD_TRAJECTORIES
|
| 614 |
-
|
| 615 |
-
if output_long_side is None and int(fps) == 30 and is_standard_traj:
|
| 616 |
params = camera.TrajectoryParams(
|
| 617 |
-
type=trajectory_type,
|
| 618 |
num_steps=int(num_frames),
|
| 619 |
num_repeats=1,
|
| 620 |
)
|
|
@@ -633,7 +428,7 @@ class ModelWrapper:
|
|
| 633 |
pass
|
| 634 |
return output_path
|
| 635 |
|
| 636 |
-
#
|
| 637 |
src_w, src_h = metadata.resolution_px
|
| 638 |
src_f = float(metadata.focal_length_px)
|
| 639 |
|
|
@@ -646,37 +441,15 @@ class ModelWrapper:
|
|
| 646 |
out_h = _make_even(max(2, int(round(src_h * scale))))
|
| 647 |
out_f = src_f * scale
|
| 648 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
device = torch.device("cuda")
|
| 650 |
gaussians_cuda = gaussians.to(device)
|
| 651 |
|
| 652 |
-
# 1. Generate Camera Trajectory
|
| 653 |
-
if is_standard_traj:
|
| 654 |
-
# Use SHARP's built-in generator
|
| 655 |
-
traj_params = camera.TrajectoryParams(
|
| 656 |
-
type=trajectory_type, # type: ignore
|
| 657 |
-
num_steps=int(num_frames),
|
| 658 |
-
num_repeats=1,
|
| 659 |
-
)
|
| 660 |
-
trajectory = camera.create_eye_trajectory(
|
| 661 |
-
gaussians_cuda,
|
| 662 |
-
traj_params,
|
| 663 |
-
resolution_px=(out_w, out_h),
|
| 664 |
-
f_px=out_f,
|
| 665 |
-
)
|
| 666 |
-
lookat_mode = traj_params.lookat_mode
|
| 667 |
-
else:
|
| 668 |
-
# Use our custom generator
|
| 669 |
-
trajectory = _generate_custom_trajectory(
|
| 670 |
-
gaussians_cuda,
|
| 671 |
-
resolution=(out_w, out_h),
|
| 672 |
-
focal_length=out_f,
|
| 673 |
-
traj_type=trajectory_type,
|
| 674 |
-
num_frames=num_frames
|
| 675 |
-
)
|
| 676 |
-
# Custom trajectories always look at origin (0,0,0) for now
|
| 677 |
-
lookat_mode = "scene" # Assuming SHARP 'scene' mode implies look-at-center
|
| 678 |
-
|
| 679 |
-
# 2. Setup Camera Model
|
| 680 |
intrinsics = torch.tensor(
|
| 681 |
[
|
| 682 |
[out_f, 0.0, (out_w - 1) / 2.0, 0.0],
|
|
@@ -692,7 +465,14 @@ class ModelWrapper:
|
|
| 692 |
gaussians_cuda,
|
| 693 |
intrinsics,
|
| 694 |
resolution_px=(out_w, out_h),
|
| 695 |
-
lookat_mode=lookat_mode,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
)
|
| 697 |
|
| 698 |
renderer = GSplatRenderer(color_space=metadata.color_space)
|
|
@@ -829,4 +609,4 @@ def predict_and_maybe_render(
|
|
| 829 |
if spaces is not None:
|
| 830 |
predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
|
| 831 |
else: # pragma: no cover
|
| 832 |
-
predict_and_maybe_render_gpu = predict_and_maybe_render
|
|
|
|
| 4 |
- Reuse SHARP's own predict/render pipeline (no subprocess calls).
|
| 5 |
- Be robust on Hugging Face Spaces + ZeroGPU.
|
| 6 |
- Cache model weights and predictor construction across requests.
|
|
|
|
| 7 |
|
| 8 |
Public API (used by the Gradio app):
|
| 9 |
- TrajectoryType
|
|
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
| 15 |
import os
|
| 16 |
import threading
|
| 17 |
import time
|
|
|
|
| 42 |
from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
|
| 43 |
from sharp.utils.gsplat import GSplatRenderer
|
| 44 |
|
| 45 |
+
TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
# -----------------------------------------------------------------------------
|
| 48 |
# Helpers
|
|
|
|
| 82 |
return torch.device("cpu")
|
| 83 |
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
# -----------------------------------------------------------------------------
|
| 86 |
# Prediction outputs
|
| 87 |
# -----------------------------------------------------------------------------
|
|
|
|
| 406 |
if fps < 1:
|
| 407 |
raise ValueError("fps must be >= 1")
|
| 408 |
|
| 409 |
+
# Keep aligned with upstream CLI pipeline where possible.
|
| 410 |
+
if output_long_side is None and int(fps) == 30:
|
|
|
|
|
|
|
|
|
|
| 411 |
params = camera.TrajectoryParams(
|
| 412 |
+
type=trajectory_type,
|
| 413 |
num_steps=int(num_frames),
|
| 414 |
num_repeats=1,
|
| 415 |
)
|
|
|
|
| 428 |
pass
|
| 429 |
return output_path
|
| 430 |
|
| 431 |
+
# Adapted pipeline for custom output resolution / FPS.
|
| 432 |
src_w, src_h = metadata.resolution_px
|
| 433 |
src_f = float(metadata.focal_length_px)
|
| 434 |
|
|
|
|
| 441 |
out_h = _make_even(max(2, int(round(src_h * scale))))
|
| 442 |
out_f = src_f * scale
|
| 443 |
|
| 444 |
+
traj_params = camera.TrajectoryParams(
|
| 445 |
+
type=trajectory_type,
|
| 446 |
+
num_steps=int(num_frames),
|
| 447 |
+
num_repeats=1,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
device = torch.device("cuda")
|
| 451 |
gaussians_cuda = gaussians.to(device)
|
| 452 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
intrinsics = torch.tensor(
|
| 454 |
[
|
| 455 |
[out_f, 0.0, (out_w - 1) / 2.0, 0.0],
|
|
|
|
| 465 |
gaussians_cuda,
|
| 466 |
intrinsics,
|
| 467 |
resolution_px=(out_w, out_h),
|
| 468 |
+
lookat_mode=traj_params.lookat_mode,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
trajectory = camera.create_eye_trajectory(
|
| 472 |
+
gaussians_cuda,
|
| 473 |
+
traj_params,
|
| 474 |
+
resolution_px=(out_w, out_h),
|
| 475 |
+
f_px=out_f,
|
| 476 |
)
|
| 477 |
|
| 478 |
renderer = GSplatRenderer(color_space=metadata.color_space)
|
|
|
|
| 609 |
if spaces is not None:
|
| 610 |
predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
|
| 611 |
else: # pragma: no cover
|
| 612 |
+
predict_and_maybe_render_gpu = predict_and_maybe_render
|