File size: 4,614 Bytes
a54f9b5
 
6fe70f4
 
a54f9b5
 
 
 
 
 
 
 
 
75e3709
 
 
 
 
a54f9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75e3709
 
 
a54f9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fe70f4
75e3709
 
 
6fe70f4
a54f9b5
 
 
6fe70f4
a54f9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fe70f4
a54f9b5
 
 
 
6fe70f4
 
a54f9b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import io
import zipfile
import base64
import tempfile
from pathlib import Path

import requests
import gradio as gr

# Front-end Gradio app that calls the backend FastAPI service hosted on a GPU cloud.
# Configure the backend base URL through the API_BASE_URL environment variable
# (on Hugging Face Spaces, set it in the Space's secrets/variables).
# Example: API_BASE_URL = "https://your-api.example.com"
# Surrounding whitespace and trailing slashes are stripped so that
# f"{API_BASE_URL}/predict" never becomes ".../ /predict" or "...//predict";
# an unset or blank value becomes None, which the handlers report with
# MISSING_BACKEND_MSG instead of sending a request to a bogus URL.
API_BASE_URL = (os.getenv("API_BASE_URL") or "").strip().rstrip("/") or None
MISSING_BACKEND_MSG = (
    "Backend API is not configured. Set API_BASE_URL in Spaces Secrets "
    "(e.g., http://134.199.133.78:80)"
)


def _input_path(img):
    """Return the filesystem path carried by one Gradio input value, or None.

    Handles a path string / ``os.PathLike`` (``gr.Image(type="filepath")``),
    an object exposing the path as ``.name`` (tempfile-style wrappers from
    ``gr.File``), and a dict holding it under ``"name"`` or ``"path"``.
    """
    if img is None:
        return None
    if isinstance(img, (str, os.PathLike)):
        return img
    path = getattr(img, "name", None)
    if path is None and isinstance(img, dict):
        # "name" is checked first to keep the original lookup; "path" is the
        # key used by dict-shaped file payloads that carry no "name".
        path = img.get("name") or img.get("path")
    return path


def _files_payload(images):
    """Prepare the multipart/form-data payload for ``requests.post(files=...)``.

    Every usable input becomes one ``("files", (filename, content, "image/*"))``
    entry, in input order. ``None`` values and inputs without a usable path
    (including empty strings) are skipped.

    File contents are read eagerly as ``bytes``. The previous version handed
    ``requests`` open file objects that nobody ever closed, which leaked one
    descriptor per uploaded image. ``requests`` buffers the multipart body in
    memory anyway, so reading up front costs nothing extra.

    Args:
        images: Iterable of Gradio image/file values.

    Returns:
        List of ``(field_name, (filename, content_bytes, content_type))`` tuples.

    Raises:
        OSError: If a referenced file cannot be read.
    """
    files = []
    for img in images:
        path = _input_path(img)
        if not path:
            continue
        source = Path(path)
        files.append(("files", (source.name, source.read_bytes(), "image/*")))
    return files


def _save_single_ply(item):
    """Decode one successful backend result and write it to a local PLY file.

    Args:
        item: One entry of the backend's ``results`` list. The keys read are
            ``ply_data`` (base64 text), ``ply_filename``, ``width``,
            ``height`` and ``focal_length``.

    Returns:
        ``(ply_path, meta)``: the written file path (str) and a one-line
        summary for the Info textbox.

    Raises:
        KeyError, TypeError, ValueError: if a field is missing or has the wrong
            type (``binascii.Error`` from bad base64 is a ``ValueError``).
    """
    # Decode and format before touching disk so a malformed item leaves no file behind.
    ply_bytes = base64.b64decode(item["ply_data"])
    meta = (
        f"{item['ply_filename']} ({item['width']}x{item['height']}), "
        f"f={item['focal_length']:.2f}"
    )
    # Keep the backend's file name for the download (rather than tmpXXXX.ply),
    # reduced to its final component so it cannot point outside the temp dir.
    name = Path(str(item["ply_filename"])).name
    if name in ("", ".."):
        name = "output.ply"
    ply_path = Path(tempfile.mkdtemp(prefix="sharp_")) / name
    ply_path.write_bytes(ply_bytes)
    return str(ply_path), meta


def predict_single(image):
    """Call ``/predict`` on the backend for a single image and return one PLY file.

    Args:
        image: Value from ``gr.Image(type="filepath")`` — normally a path string.

    Returns:
        ``(ply_path, info)`` for the ``gr.File`` and ``gr.Textbox`` outputs.
        ``ply_path`` is ``None`` whenever no PLY could be produced; ``info``
        then carries the reason.
    """
    if not image:
        return None, "No image provided."
    files = _files_payload([image])
    if not files:
        return None, "Invalid image input."

    if not API_BASE_URL:
        return None, MISSING_BACKEND_MSG

    try:
        resp = requests.post(f"{API_BASE_URL}/predict", files=files, timeout=120)
        resp.raise_for_status()
        data = resp.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError: body was not JSON (older requests versions raise plain ValueError).
        return None, f"Backend error: {e}"

    if not isinstance(data, dict):
        return None, "Malformed backend response: expected a JSON object."
    results = data.get("results")
    if not isinstance(results, list) or not results:
        return None, "No result."
    item = results[0]
    if not isinstance(item, dict):
        return None, "Malformed backend response: result is not an object."
    if "error" in item:
        return None, str(item["error"])

    try:
        return _save_single_ply(item)
    except (KeyError, TypeError, ValueError) as e:
        return None, f"Malformed backend response: {e}"


def _unique_arcname(name, used):
    """Return a safe ZIP member name derived from *name* that is not in *used*, and record it."""
    # Keep only the final path component ("\\" counts as a separator too) so
    # archive members never carry "../" segments or absolute paths.
    base = Path(str(name).replace("\\", "/")).name
    if base in ("", ".."):
        base = "output.ply"
    stem, suffix = Path(base).stem, Path(base).suffix
    candidate, counter = base, 1
    # Two inputs with the same backend file name must not produce duplicate members.
    while candidate in used:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    used.add(candidate)
    return candidate


def _zip_ply_results(results):
    """Pack successful backend results into a ZIP file on disk.

    Args:
        results: The backend's ``results`` list. Each entry is either an error
            item (``{"filename", "error"}``) or a success item with
            ``filename``, ``ply_filename``, ``ply_data`` (base64), ``width``,
            ``height`` and ``focal_length``.

    Returns:
        ``(zip_path, info)``. ``zip_path`` is a ``.zip`` file path, or ``None``
        when no result produced a PLY, so the UI never offers an empty archive.
        ``info`` has one status line per result.
    """
    metas = []
    used = set()
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for item in results:
            if not isinstance(item, dict):
                metas.append(f"?: ERROR malformed result {item!r}")
                continue
            label = item.get("filename", "?")
            if "error" in item:
                metas.append(f"{label}: ERROR {item['error']}")
                continue
            try:
                ply_bytes = base64.b64decode(item["ply_data"])
                dims = f"({item['width']}x{item['height']}, f={item['focal_length']:.2f})"
                arcname = _unique_arcname(item["ply_filename"], used)
            except (KeyError, TypeError, ValueError) as e:
                # One malformed item should not abort the rest of the batch.
                metas.append(f"{label}: ERROR malformed result ({e})")
                continue
            zf.writestr(arcname, ply_bytes)
            metas.append(f"{label} -> {arcname} {dims}")
    info = "\n".join(metas)
    if not used:  # every item errored or was malformed: nothing was archived
        return None, info
    # gr.File outputs need a path on disk, not an in-memory buffer.
    with tempfile.NamedTemporaryFile(prefix="sharp_plys_", suffix=".zip", delete=False) as tmp:
        tmp.write(buf.getvalue())
    return tmp.name, info


def predict_batch(images):
    """Call ``/predict`` on the backend for multiple images and return a ZIP of PLY files.

    Args:
        images: Value from ``gr.File(file_count="multiple")`` — a list of
            uploaded files (paths or file-like wrappers).

    Returns:
        ``(zip_path, info)`` for the ``gr.File`` and ``gr.Textbox`` outputs.
        ``zip_path`` is a local ``.zip`` path, or ``None`` when nothing was
        produced. ``info`` is a per-image report or an error message.
    """
    if not images:
        return None, "No images provided."
    files = _files_payload(images)
    if not files:
        return None, "Invalid inputs."

    if not API_BASE_URL:
        return None, MISSING_BACKEND_MSG

    try:
        resp = requests.post(f"{API_BASE_URL}/predict", files=files, timeout=300)
        resp.raise_for_status()
        data = resp.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError: body was not JSON (older requests versions raise plain ValueError).
        return None, f"Backend error: {e}"

    if not isinstance(data, dict):
        return None, "Malformed backend response: expected a JSON object."
    results = data.get("results")
    if not isinstance(results, list) or not results:
        return None, "No result."
    return _zip_ply_results(results)


# ---------------------------------------------------------------------------
# UI: two tabs that both post to the backend's /predict endpoint.
#   * "Single Image": one image in, one downloadable PLY out.
#   * "Batch": several image files in, one ZIP of PLYs out.
# ---------------------------------------------------------------------------
with gr.Blocks(title="SHARP View Synthesis") as demo:
    gr.Markdown(
        value=(
            "# SHARP View Synthesis\n"
            "Upload image(s) to generate 3D Gaussian PLY files via the backend API."
        )
    )

    # One image -> (PLY file, info line).
    with gr.Tab("Single Image"):
        in_img = gr.Image(label="Input Image", type="filepath")
        out_file = gr.File(label="Generated PLY")
        out_info = gr.Textbox(label="Info")
        btn = gr.Button("Predict")
        btn.click(fn=predict_single, inputs=[in_img], outputs=[out_file, out_info])

    # Many images -> (ZIP of PLYs, per-image report).
    with gr.Tab("Batch"):
        in_imgs = gr.File(
            label="Input Images",
            file_count="multiple",
            file_types=["image"],
        )
        out_zip = gr.File(label="PLY ZIP")
        out_info2 = gr.Textbox(label="Info")
        btn2 = gr.Button("Predict Batch")
        btn2.click(fn=predict_batch, inputs=[in_imgs], outputs=[out_zip, out_info2])

if __name__ == "__main__":
    # Listen on all interfaces, port 7860. When deployed on Hugging Face Spaces,
    # API_BASE_URL must point at the GPU-cloud FastAPI server.
    demo.launch(server_port=7860, server_name="0.0.0.0")