|
|
""" |
|
|
SHARP Gradio Demo |
|
|
- Standard Native Layout |
|
|
- Fixed: Added @spaces.GPU for ZeroGPU compatibility (Fixes 'dummy' output) |
|
|
- Fixed: Download Button visibility logic |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import json
import traceback
import warnings
from pathlib import Path
from typing import Final

import gradio as gr
|
|
|
|
|
|
|
|
try:
    import spaces
except ImportError:
    # Local fallback when the Hugging Face `spaces` package is not installed
    # (e.g. running outside a ZeroGPU Space). It must support both decorator
    # forms: bare `@spaces.GPU` and the factory form `@spaces.GPU(duration=...)`,
    # which is how `run_sharp` is decorated in this file. The previous
    # fallback only accepted `GPU(func)` and raised TypeError on the factory
    # form, breaking the module import.
    class spaces:  # noqa: N801 - stands in for the `spaces` module namespace
        """No-op replacement for the `spaces` module."""

        @staticmethod
        def GPU(func=None, **_kwargs):
            """Return *func* unchanged, or a pass-through decorator.

            ``spaces.GPU(fn)`` returns ``fn``. ``spaces.GPU(duration=120)`` or
            ``spaces.GPU()`` returns a decorator that returns its argument
            unchanged. Keyword options are accepted and ignored.
            """
            if callable(func):
                return func

            def _passthrough(fn):
                return fn

            return _passthrough


# Ignore FutureWarnings attributed to the torch.distributed module. This is
# installed before `model_utils` is imported so it also covers that import.
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.distributed")
|
|
|
|
|
|
|
|
from model_utils import TrajectoryType, predict_and_maybe_render_gpu |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory layout, anchored at the (symlink-resolved) folder holding this script.
APP_DIR: Final[Path] = Path(__file__).resolve().parent
ASSETS_DIR: Final[Path] = APP_DIR.joinpath("assets")
OUTPUTS_DIR: Final[Path] = APP_DIR.joinpath("outputs")
EXAMPLES_DIR: Final[Path] = ASSETS_DIR.joinpath("examples")

# File suffixes treated as images when scanning the examples folder.
IMAGE_EXTS: Final[tuple[str, ...]] = (
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SEO_HEAD = """ |
|
|
<meta name="description" content="Turn 2D images into 3D Gaussian Splats instantly. SHARP (Apple) AI Demo."> |
|
|
<meta name="viewport" content="width=device-width, initial-scale=1"> |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_dir(path: Path) -> Path: |
|
|
path.mkdir(parents=True, exist_ok=True) |
|
|
return path |
|
|
|
|
|
def get_example_files() -> list[list[str]]:
    """Discover example images under ``EXAMPLES_DIR`` for ``gr.Examples``.

    Resolution order:

    1. If ``EXAMPLES_DIR/manifest.json`` is a readable JSON list, use each
       entry that is an object with a string ``"image"`` key naming an
       existing file (relative to ``EXAMPLES_DIR``), in manifest order.
       Malformed entries are skipped rather than discarding the whole
       manifest.
    2. Otherwise (no manifest, unreadable or invalid JSON, or no usable
       entries), scan the directory for files whose suffix matches
       ``IMAGE_EXTS`` case-insensitively. Results are ordered by extension
       (in ``IMAGE_EXTS`` order) and then by path.

    Creates ``EXAMPLES_DIR`` if it does not exist.

    Returns:
        One ``[path_str]`` row per image (the row shape ``gr.Examples`` uses
        for a single input component). May be empty.
    """
    _ensure_dir(EXAMPLES_DIR)

    manifest_path = EXAMPLES_DIR / "manifest.json"
    if manifest_path.is_file():
        try:
            data = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Manifest error: {e}")
        else:
            if not isinstance(data, list):
                print("Manifest error: expected a JSON list of entries")
                data = []
            examples = []
            for entry in data:
                image_name = entry.get("image") if isinstance(entry, dict) else None
                if not isinstance(image_name, str):
                    continue  # skip malformed entries, keep the rest
                img_path = EXAMPLES_DIR / image_name
                if img_path.is_file():
                    examples.append([str(img_path)])
            if examples:
                return examples

    # Fallback: scan the folder. Path.glob is case-sensitive on POSIX, so
    # compare lower-cased suffixes to also pick up e.g. "photo.JPG", and skip
    # directories that merely look like image files.
    ext_rank = {ext.lower(): rank for rank, ext in enumerate(IMAGE_EXTS)}
    found = [
        p
        for p in EXAMPLES_DIR.iterdir()
        if p.is_file() and p.suffix.lower() in ext_rank
    ]
    found.sort(key=lambda p: (ext_rank[p.suffix.lower()], p))
    return [[str(p)] for p in found]
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def run_sharp(
    image_path: str | None,
    trajectory_type: str,
    output_long_side: int,
    num_frames: int,
    fps: int,
    render_video: bool,
    progress=gr.Progress(),
) -> tuple[str | None, gr.DownloadButton, str]:
    """Run SHARP on one input image and package the results for the UI.

    Decorated with ``@spaces.GPU`` so that, on a ZeroGPU Space, the call runs
    on a GPU worker (``duration=120`` is the requested time budget).

    Args:
        image_path: Path of the uploaded image. Empty or None when nothing
            was uploaded.
        trajectory_type: Camera-movement name from the UI (e.g.
            ``"rotate_forward"``). It is mapped to the ``TrajectoryType``
            member of the same upper-cased name when one exists; otherwise it
            is passed through as-is.
        output_long_side: Target long side in pixels. A value ``<= 0`` means
            "keep the original resolution" and is forwarded as ``None``.
        num_frames: Number of frames to render, forwarded as ``int``.
        fps: Frame rate of the rendered video, forwarded as ``int``.
        render_video: Whether to also render a video preview.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(video_path_or_None, download_button_update, status_markdown)``.
        On failure, returns ``(None, hidden download button, error markdown)``,
        and the traceback is printed to stderr for the server logs.

    Raises:
        gr.Error: If no image was provided.
    """
    if not image_path:
        raise gr.Error("Please upload an image first.")

    # The "Original" resolution choice is 0; the model helper expects None.
    out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side)

    traj_key = trajectory_type.upper()
    if hasattr(TrajectoryType, traj_key):
        traj_enum = TrajectoryType[traj_key]
    else:
        traj_enum = trajectory_type

    try:
        progress(0.1, desc="Initializing SHARP model on GPU...")

        video_path, ply_path = predict_and_maybe_render_gpu(
            image_path,
            trajectory_type=traj_enum,
            num_frames=int(num_frames),
            fps=int(fps),
            output_long_side=out_long_side_val,
            render_video=bool(render_video),
        )

        status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`"

        video_result = str(video_path) if video_path else None
        if video_path:
            status_msg += f"\nVideo: `{video_path.name}`"

        # Returning a component instance updates the existing button in place.
        download_btn_update = gr.DownloadButton(
            value=str(ply_path),
            visible=True,
            label=f"Download {ply_path.name}",
        )

        return (
            video_result,
            download_btn_update,
            status_msg,
        )

    except Exception as e:  # UI boundary: report to the user, log for the operator
        # Previously the traceback was discarded, leaving no trace of the
        # failure in the server logs.
        traceback.print_exc()
        # Some exceptions carry an empty message; fall back to the type name
        # so the user never sees a blank error.
        detail = str(e) or type(e).__name__
        return (
            None,
            gr.DownloadButton(visible=False),
            f"### ❌ Error\n{detail}",
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_demo() -> gr.Blocks:
    """Assemble (but do not launch) the SHARP Gradio interface.

    Layout: a title banner, then two columns.

    - Left column: the input image, the camera-movement and output-resolution
      dropdowns, the frames/FPS sliders, the video-preview checkbox, the
      generate button and, when any are found, the example images.
    - Right column: the preview video, a status Markdown, and a .ply download
      button that starts hidden.

    The generate button calls ``run_sharp`` with at most one concurrent run.

    Returns:
        The configured ``gr.Blocks`` app.
    """
    trajectory_choices = ["swipe", "shake", "rotate", "rotate_forward"]
    resolution_choices = [("Original", 0), ("512px", 512), ("1024px", 1024)]
    header_md = (
        "# SHARP: Single-Image 3D Generator\n"
        "Convert any static image into a 3D Gaussian Splat scene instantly."
    )

    with gr.Blocks(title="SHARP 3D Generator", head=SEO_HEAD, theme=gr.themes.Default()) as app:
        # Title banner.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(header_md)

        with gr.Row(equal_height=False):
            # ---- Left column: inputs and generation settings ----
            with gr.Column(scale=1):
                input_image = gr.Image(
                    interactive=True,
                    sources=["upload", "clipboard"],
                    type="filepath",
                    label="Input Image",
                )

                with gr.Group():
                    with gr.Row():
                        camera_path = gr.Dropdown(
                            choices=trajectory_choices,
                            value="rotate_forward",
                            label="Camera Movement",
                            scale=2,
                        )
                        resolution = gr.Dropdown(
                            choices=resolution_choices,
                            value=0,
                            label="Output Resolution",
                            scale=1,
                        )
                    with gr.Row():
                        frame_count = gr.Slider(
                            minimum=24, maximum=120, step=1, value=60, label="Frames"
                        )
                        fps_slider = gr.Slider(
                            minimum=8, maximum=60, step=1, value=30, label="FPS"
                        )
                    render_preview = gr.Checkbox(value=True, label="Render Video Preview")

                generate_btn = gr.Button("🚀 Generate 3D Scene", size="lg", variant="primary")

                examples = get_example_files()
                if examples:
                    gr.Examples(
                        examples=examples,
                        inputs=[input_image],
                        label="Examples",
                        cache_examples=False,
                        run_on_click=False,
                    )

            # ---- Right column: results ----
            with gr.Column(scale=1):
                preview_video = gr.Video(
                    interactive=False,
                    autoplay=True,
                    elem_id="output-video",
                    label="3D Preview",
                )

                with gr.Group():
                    status_box = gr.Markdown("Ready to generate.")
                    ply_button = gr.DownloadButton(
                        visible=False,
                        variant="secondary",
                        label="Download .PLY File",
                    )

        # Wire the button; the input order must match run_sharp's parameters.
        generate_btn.click(
            fn=run_sharp,
            inputs=[input_image, camera_path, resolution, frame_count, fps_slider, render_preview],
            outputs=[preview_video, ply_button, status_box],
            concurrency_limit=1,
        )

    return app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Create the outputs folder as soon as this module is imported.
_ensure_dir(OUTPUTS_DIR)

if __name__ == "__main__":
    demo = build_demo()
    # Files under the assets folder may be served to the browser.
    served_dirs = [str(ASSETS_DIR)]
    demo.queue()
    demo.launch(ssr_mode=False, allowed_paths=served_dirs)