gagndeep commited on
Commit
81e21b3
·
1 Parent(s): 3a2ddb6
Files changed (1) hide show
  1. app.py +51 -21
app.py CHANGED
@@ -1,8 +1,8 @@
1
  """
2
  SHARP Gradio Demo
3
- - Standard Native Layout (No Custom CSS)
4
- - Two-Column Split
5
- - Clean Grouping for reduced congestion
6
  """
7
 
8
  from __future__ import annotations
@@ -13,7 +13,17 @@ from pathlib import Path
13
  from typing import Final
14
  import gradio as gr
15
 
16
- # Suppress internal warnings to keep logs clean
 
 
 
 
 
 
 
 
 
 
17
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch.distributed")
18
 
19
  # Ensure model_utils is present in your directory
@@ -31,7 +41,7 @@ EXAMPLES_DIR: Final[Path] = ASSETS_DIR / "examples"
31
  IMAGE_EXTS: Final[tuple[str, ...]] = (".png", ".jpg", ".jpeg", ".webp")
32
 
33
  # -----------------------------------------------------------------------------
34
- # SEO (Meta Tags Only)
35
  # -----------------------------------------------------------------------------
36
 
37
  SEO_HEAD = """
@@ -74,6 +84,8 @@ def get_example_files() -> list[list[str]]:
74
  examples.append([str(img)])
75
  return examples
76
 
 
 
77
  def run_sharp(
78
  image_path: str | None,
79
  trajectory_type: str,
@@ -82,9 +94,10 @@ def run_sharp(
82
  fps: int,
83
  render_video: bool,
84
  progress=gr.Progress()
85
- ) -> tuple[str | None, str | None, str]:
86
  """
87
  Main Inference Function
 
88
  """
89
  if not image_path:
90
  raise gr.Error("Please upload an image first.")
@@ -92,12 +105,17 @@ def run_sharp(
92
  # Validate inputs
93
  out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
94
 
95
- # Convert trajectory string to Enum or pass as is
96
- traj_enum = TrajectoryType[trajectory_type.upper()] if hasattr(TrajectoryType, trajectory_type.upper()) else trajectory_type
 
 
 
 
97
 
98
  try:
99
- progress(0.1, desc="Initializing SHARP model...")
100
 
 
101
  video_path, ply_path = predict_and_maybe_render_gpu(
102
  image_path,
103
  trajectory_type=traj_enum,
@@ -107,37 +125,48 @@ def run_sharp(
107
  render_video=bool(render_video),
108
  )
109
 
110
- status_msg = f"### ✅ Success\nPLY: `{ply_path.name}`"
 
 
 
111
  if video_path:
112
  status_msg += f"\nVideo: `{video_path.name}`"
 
 
 
 
 
 
 
113
 
114
  return (
115
- str(video_path) if video_path else None,
116
- str(ply_path),
117
  status_msg
118
  )
119
 
120
  except Exception as e:
121
- raise gr.Error(f"Error: {str(e)}")
 
 
 
 
 
122
 
123
  # -----------------------------------------------------------------------------
124
  # UI Construction
125
  # -----------------------------------------------------------------------------
126
 
127
  def build_demo() -> gr.Blocks:
128
- # Use standard theme.
129
- # To reduce congestion further, you could try `theme=gr.themes.Soft()` or `Base()`
130
  theme = gr.themes.Default()
131
 
132
  with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
133
 
134
- # --- Header ---
135
  with gr.Row():
136
  with gr.Column(scale=1):
137
  gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
138
 
139
  # --- Main Layout (Strict Two Columns) ---
140
- # Removed 'variant="panel"' to remove the grey box/padding that restricts width
141
  with gr.Row(equal_height=False):
142
 
143
  # --- LEFT COLUMN: Input & Controls ---
@@ -149,7 +178,7 @@ def build_demo() -> gr.Blocks:
149
  interactive=True
150
  )
151
 
152
- # Settings grouped cleanly below image
153
  with gr.Group():
154
  with gr.Row():
155
  trajectory = gr.Dropdown(
@@ -185,19 +214,20 @@ def build_demo() -> gr.Blocks:
185
 
186
  # --- RIGHT COLUMN: Output ---
187
  with gr.Column(scale=1):
188
- # Removed fixed height so it fills the column naturally
189
  video_out = gr.Video(
190
  label="3D Preview",
191
  autoplay=True,
192
- elem_id="output-video"
 
193
  )
194
 
195
  with gr.Group():
196
  status_md = gr.Markdown("Ready to generate.")
 
197
  ply_download = gr.DownloadButton(
198
  label="Download .PLY File",
199
  variant="secondary",
200
- visible=True
201
  )
202
 
203
  # --- Logic Binding ---
 
1
  """
2
  SHARP Gradio Demo
3
+ - Standard Native Layout
4
+ - Fixed: Added @spaces.GPU for ZeroGPU compatibility (Fixes 'dummy' output)
5
+ - Fixed: Download Button visibility logic
6
  """
7
 
8
  from __future__ import annotations
 
13
  from typing import Final
14
  import gradio as gr
15
 
16
+ # --- 1. Import Spaces for ZeroGPU Support ---
17
+ try:
18
+ import spaces
19
+ except ImportError:
20
+ # Fallback for local testing if spaces is not installed
21
+ class spaces:
22
+ @staticmethod
23
+ def GPU(func):
24
+ return func
25
+
26
+ # Suppress internal warnings
27
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch.distributed")
28
 
29
  # Ensure model_utils is present in your directory
 
41
  IMAGE_EXTS: Final[tuple[str, ...]] = (".png", ".jpg", ".jpeg", ".webp")
42
 
43
  # -----------------------------------------------------------------------------
44
+ # SEO
45
  # -----------------------------------------------------------------------------
46
 
47
  SEO_HEAD = """
 
84
  examples.append([str(img)])
85
  return examples
86
 
87
+ # --- 2. Apply @spaces.GPU Decorator ---
88
+ @spaces.GPU(duration=120)
89
  def run_sharp(
90
  image_path: str | None,
91
  trajectory_type: str,
 
94
  fps: int,
95
  render_video: bool,
96
  progress=gr.Progress()
97
+ ) -> tuple[str | None, dict, str]:
98
  """
99
  Main Inference Function
100
+ Decorated with @spaces.GPU to ensure it runs on the GPU node.
101
  """
102
  if not image_path:
103
  raise gr.Error("Please upload an image first.")
 
105
  # Validate inputs
106
  out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)
107
 
108
+ # Convert trajectory string to Enum safely
109
+ traj_key = trajectory_type.upper()
110
+ if hasattr(TrajectoryType, traj_key):
111
+ traj_enum = TrajectoryType[traj_key]
112
+ else:
113
+ traj_enum = trajectory_type
114
 
115
  try:
116
+ progress(0.1, desc="Initializing SHARP model on GPU...")
117
 
118
+ # Call the backend model
119
  video_path, ply_path = predict_and_maybe_render_gpu(
120
  image_path,
121
  trajectory_type=traj_enum,
 
125
  render_video=bool(render_video),
126
  )
127
 
128
+ # Prepare outputs
129
+ status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"
130
+
131
+ video_result = str(video_path) if video_path else None
132
  if video_path:
133
  status_msg += f"\nVideo: `{video_path.name}`"
134
+
135
+ # Explicitly update the Download Button
136
+ download_btn_update = gr.DownloadButton(
137
+ value=str(ply_path),
138
+ visible=True,
139
+ label=f"Download {ply_path.name}"
140
+ )
141
 
142
  return (
143
+ video_result,
144
+ download_btn_update,
145
  status_msg
146
  )
147
 
148
  except Exception as e:
149
+ # If it fails, we return None for video, hide button, and show error
150
+ return (
151
+ None,
152
+ gr.DownloadButton(visible=False),
153
+ f"### ❌ Error\n{str(e)}"
154
+ )
155
 
156
  # -----------------------------------------------------------------------------
157
  # UI Construction
158
  # -----------------------------------------------------------------------------
159
 
160
  def build_demo() -> gr.Blocks:
 
 
161
  theme = gr.themes.Default()
162
 
163
  with gr.Blocks(theme=theme, head=SEO_HEAD, title="SHARP 3D Generator") as demo:
164
 
 
165
  with gr.Row():
166
  with gr.Column(scale=1):
167
  gr.Markdown("# SHARP: Single-Image 3D Generator\nConvert any static image into a 3D Gaussian Splat scene instantly.")
168
 
169
  # --- Main Layout (Strict Two Columns) ---
 
170
  with gr.Row(equal_height=False):
171
 
172
  # --- LEFT COLUMN: Input & Controls ---
 
178
  interactive=True
179
  )
180
 
181
+ # Configs
182
  with gr.Group():
183
  with gr.Row():
184
  trajectory = gr.Dropdown(
 
214
 
215
  # --- RIGHT COLUMN: Output ---
216
  with gr.Column(scale=1):
 
217
  video_out = gr.Video(
218
  label="3D Preview",
219
  autoplay=True,
220
+ elem_id="output-video",
221
+ interactive=False
222
  )
223
 
224
  with gr.Group():
225
  status_md = gr.Markdown("Ready to generate.")
226
+ # Button starts hidden
227
  ply_download = gr.DownloadButton(
228
  label="Download .PLY File",
229
  variant="secondary",
230
+ visible=False
231
  )
232
 
233
  # --- Logic Binding ---