File size: 3,464 Bytes
45e56d0
 
 
 
 
 
17b4a27
45e56d0
df5c44f
904249c
45e56d0
 
 
 
 
 
6c969e9
45e56d0
17b4a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45e56d0
 
 
 
17b4a27
45e56d0
 
 
 
 
17b4a27
45e56d0
 
6c969e9
17b4a27
 
45e56d0
17b4a27
6c969e9
45e56d0
17b4a27
45e56d0
17b4a27
 
45e56d0
 
6c969e9
17b4a27
45e56d0
6c969e9
17b4a27
 
45e56d0
6c969e9
17b4a27
 
45e56d0
 
 
 
17b4a27
6c969e9
 
45e56d0
 
6c969e9
 
45e56d0
6c969e9
45e56d0
17b4a27
 
 
 
 
 
45e56d0
 
b13d14a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import subprocess
import os
import uuid
import shutil
import tempfile
import sys

# Markdown shown at the top of the page (rendered via ``gr.Markdown`` in the
# UI block): what the app does, how to use it, and the CPU/first-run caveat.
DESCRIPTION = """# SHARP · 3D from a Single Photo
Upload any photo and get a **3D Gaussian Splat (.ply)** in seconds.
Powered by Apple's [SHARP](https://github.com/apple/ml-sharp) monocular view synthesis model.

**How to use:**
1. Upload or drag & drop an image
2. Click **Generate 3D Splat**
3. Download the `.ply` file
4. View it at [SuperSplat](https://playcanvas.com/supersplat/editor)

> Running on CPU — first run downloads the ~2.6GB model and may take 5–10 min."""


def get_sharp_bin():
    """Locate the ``sharp`` CLI executable.

    The directory containing the running interpreter is searched first (in a
    virtualenv that is where console scripts such as ``sharp`` are installed),
    so the copy belonging to this environment wins over any other on ``PATH``.
    If it is not there, a regular ``PATH`` search is performed.

    Returns:
        str: Path to the ``sharp`` executable.

    Raises:
        FileNotFoundError: If no executable named ``sharp`` can be found.
    """
    # shutil.which checks the executable bit and, on Windows, applies PATHEXT
    # (so ``sharp.exe`` is found). It also avoids shelling out to the
    # POSIX-only ``which`` binary, whose absence used to surface as a
    # misleading FileNotFoundError about ``which`` itself.
    # ``sys.executable`` can be empty/None in embedded interpreters; an empty
    # search path makes shutil.which return None, so we fall through to PATH.
    interpreter_dir = os.path.dirname(sys.executable or "")
    found = shutil.which("sharp", path=interpreter_dir) or shutil.which("sharp")
    if found:
        return found
    raise FileNotFoundError(
        "Could not locate the `sharp` binary. "
        "Make sure `git+https://github.com/apple/ml-sharp.git` is in requirements.txt."
    )


def generate_splat(image_path):
    """Run SHARP on one image and return the generated Gaussian-splat file.

    The upload is copied into a private scratch directory and
    ``sharp predict`` is run on it. The first ``.ply`` it writes (by sorted
    file name) is then copied to a stable path in the system temp directory,
    so it can still be served after the scratch directory is removed.

    Args:
        image_path: Filesystem path of the uploaded image, or ``None`` when
            nothing was uploaded.

    Returns:
        tuple[str, str]: Path of the output ``.ply`` file and a Markdown
        status message.

    Raises:
        gr.Error: If no image was given, ``sharp`` cannot be found, SHARP
            exits non-zero, no ``.ply`` is produced, or the run exceeds the
            600-second timeout.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    job_id = str(uuid.uuid4())
    # mkdtemp creates the directory atomically with owner-only permissions
    # (unlike a predictable name under gettempdir()). Keeping input and output
    # under one root means a single cleanup in ``finally``, and nothing leaks
    # if creating the subdirectories fails.
    work_dir = tempfile.mkdtemp(prefix=f"sharp_{job_id}_")
    try:
        input_dir = os.path.join(work_dir, "in")
        output_dir = os.path.join(work_dir, "out")
        os.makedirs(input_dir)
        os.makedirs(output_dir)

        # Keep the original extension; default to .jpg when there is none.
        ext = os.path.splitext(image_path)[1] or ".jpg"
        shutil.copy(image_path, os.path.join(input_dir, f"input{ext}"))

        sharp_bin = get_sharp_bin()
        result = subprocess.run(
            [sharp_bin, "predict", "-i", input_dir, "-o", output_dir],
            capture_output=True,
            text=True,
            # Undecodable bytes in the tool's output must not crash the handler.
            errors="replace",
            timeout=600,
        )

        if result.returncode != 0:
            # stderr may be empty (e.g. if the tool logs to stdout); fall back
            # so the user never sees a bare "SHARP failed:" with no details.
            details = (result.stderr or result.stdout or "").strip()
            raise gr.Error(f"SHARP failed:\n{details[-800:] or '(no output captured)'}")

        # os.listdir order is arbitrary; sort so the chosen file is deterministic.
        ply_files = sorted(f for f in os.listdir(output_dir) if f.endswith(".ply"))
        if not ply_files:
            raise gr.Error("No .ply file generated. Try a different image.")

        # Copy out of work_dir (deleted in ``finally``) to a path that outlives
        # this call so the UI can offer it for download.
        out_path = os.path.join(tempfile.gettempdir(), f"output_{job_id}.ply")
        shutil.copy(os.path.join(output_dir, ply_files[0]), out_path)
        return out_path, "✅ Done! Download your .ply above, then open it in SuperSplat."

    except subprocess.TimeoutExpired as exc:
        raise gr.Error("Timed out after 10 minutes.") from exc
    except FileNotFoundError as exc:
        # Raised by get_sharp_bin() or by copying a vanished upload.
        raise gr.Error(str(exc)) from exc
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


# Build the UI: image upload + run button on the left, download + status on
# the right. ``demo`` is the module-level app object launched below.
with gr.Blocks(title="SHARP 3D") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Input column. ``type="filepath"`` makes the component hand the
        # handler a path on disk, which generate_splat copies and passes to
        # the sharp CLI.
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Input Image")
            run_btn = gr.Button("Generate 3D Splat", variant="primary")
        # Output column: the .ply download and a Markdown status line.
        with gr.Column():
            file_output = gr.File(label="Download .ply file")
            status_output = gr.Markdown("")

    # generate_splat returns (ply_path, status_markdown); the order of
    # ``outputs`` must match that tuple.
    run_btn.click(
        fn=generate_splat,
        inputs=image_input,
        outputs=[file_output, status_output]
    )

if __name__ == "__main__":
    # Bind to all network interfaces (not just localhost) so the app is
    # reachable from outside a container or VM.
    demo.launch(server_name="0.0.0.0")