File size: 4,614 Bytes
a54f9b5 6fe70f4 a54f9b5 75e3709 a54f9b5 75e3709 a54f9b5 6fe70f4 75e3709 6fe70f4 a54f9b5 6fe70f4 a54f9b5 6fe70f4 a54f9b5 6fe70f4 a54f9b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
import io
import zipfile
import base64
import tempfile
from pathlib import Path
import requests
import gradio as gr
# Front-end Gradio app that calls the backend FastAPI service hosted on GPU cloud.
# Configure the backend base URL through an environment variable (on Hugging Face
# Spaces: a secret named API_BASE_URL).
# Example: API_BASE_URL = "https://your-api.example.com"
#
# The value is normalized once at import time: surrounding whitespace and any
# trailing "/" are stripped so that f"{API_BASE_URL}/predict" never contains a
# double slash. An unset or blank variable becomes None, which the predict
# functions report as "backend not configured".
API_BASE_URL = (os.getenv("API_BASE_URL") or "").strip().rstrip("/") or None

# Text returned to the UI when API_BASE_URL is not set.
# NOTE(review): the example below embeds a concrete IP address; consider
# replacing it with a placeholder host if that address is not meant to be public.
MISSING_BACKEND_MSG = (
    "Backend API is not configured. Set API_BASE_URL in Spaces Secrets "
    "(e.g., http://134.199.133.78:80)"
)
def _files_payload(images):
    """Prepare multipart/form-data payload for requests.post(files=...).

    Args:
        images: Iterable of upload values as handed over by the Gradio
            components. Three shapes are recognized:
            a ``str`` file path, an object exposing the path as ``.name``,
            and a ``dict`` with a ``"name"`` key. ``None`` entries and
            entries without a usable path are skipped.

    Returns:
        A list of ``("files", (filename, content, "image/*"))`` tuples, one
        per usable input, in input order. ``content`` is the file's raw bytes.

    Raises:
        OSError: If a resolved path cannot be read.
    """
    files = []
    for img in images:
        if img is None:
            continue
        if isinstance(img, str):
            # gr.Image(type="filepath") returns a string path
            path = img
        else:
            # gr.File returns objects with a .name attribute (path), or dict-like in some cases
            path = getattr(img, "name", None)
            if path is None and isinstance(img, dict):
                path = img.get("name")
        if not path:
            continue
        source = Path(path)
        # Read eagerly instead of handing out an open file object: no caller
        # closes the handles, so every request used to leak one descriptor
        # per image. requests accepts bytes as the file content.
        files.append(("files", (source.name, source.read_bytes(), "image/*")))
    return files
def predict_single(image):
    """Call /predict on backend for a single image and return one PLY file to download.

    Args:
        image: The value of the single-image input; a falsy value means
            nothing was uploaded.

    Returns:
        A ``(ply_path, info)`` tuple. On success ``ply_path`` is the path of a
        temporary ``.ply`` file (not deleted automatically) and ``info`` is a
        one-line summary. On any failure ``ply_path`` is ``None`` and ``info``
        holds the error message.
    """
    if not image:
        return None, "No image provided."
    files = _files_payload([image])
    try:
        if not files:
            return None, "Invalid image input."
        if not API_BASE_URL:
            return None, MISSING_BACKEND_MSG
        try:
            resp = requests.post(f"{API_BASE_URL}/predict", files=files, timeout=120)
            resp.raise_for_status()
            data = resp.json()
        except Exception as e:
            # UI boundary: surface any transport/HTTP/JSON failure as text
            # instead of crashing the event handler.
            return None, f"Backend error: {e}"
    finally:
        # Release upload handles on every exit path (early returns included).
        # Entries whose content is plain bytes have no close() and are skipped.
        for _field, (_name, handle, _mime) in files:
            close = getattr(handle, "close", None)
            if callable(close):
                close()

    # Validate and decode everything BEFORE touching the filesystem so that a
    # malformed response neither raises into the UI nor orphans a temp file.
    try:
        results = data.get("results", [])
        if not results:
            return None, "No result."
        item = results[0]
        if "error" in item:
            return None, item["error"]
        ply_bytes = base64.b64decode(item["ply_data"])
        meta = (
            f"{item['ply_filename']} ({item['width']}x{item['height']}), "
            f"f={item['focal_length']:.2f}"
        )
    except (AttributeError, KeyError, TypeError, ValueError) as e:
        return None, f"Malformed backend response: {e!r}"

    # Decode base64 PLY to a temporary file; delete=False keeps it on disk so
    # the download component can serve it after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmpf:
        tmpf.write(ply_bytes)
    return tmpf.name, meta
def predict_batch(images):
    """Call /predict on backend for multiple images and return a ZIP of PLY files.

    Args:
        images: The value of the multi-file input; a falsy value means
            nothing was uploaded.

    Returns:
        A ``(zip_path, info)`` tuple. ``zip_path`` is the path of a temporary
        ``.zip`` file (not deleted automatically) holding every successfully
        generated PLY, or ``None`` when nothing could be produced. ``info``
        has one line per backend result (success summary or ``ERROR ...``),
        or a single error message.
    """
    if not images:
        return None, "No images provided."
    files = _files_payload(images)
    try:
        if not files:
            return None, "Invalid inputs."
        if not API_BASE_URL:
            return None, MISSING_BACKEND_MSG
        try:
            resp = requests.post(f"{API_BASE_URL}/predict", files=files, timeout=300)
            resp.raise_for_status()
            data = resp.json()
        except Exception as e:
            # UI boundary: surface any transport/HTTP/JSON failure as text
            # instead of crashing the event handler.
            return None, f"Backend error: {e}"
    finally:
        # Release upload handles on every exit path (early returns included).
        # Entries whose content is plain bytes have no close() and are skipped.
        for _field, (_name, handle, _mime) in files:
            close = getattr(handle, "close", None)
            if callable(close):
                close()

    results = data.get("results", []) if isinstance(data, dict) else []

    # Decode first, write later: a malformed item becomes an ERROR line
    # instead of aborting the whole batch half-way through the archive.
    entries = []
    metas = []
    for item in results:
        try:
            if "error" in item:
                metas.append(f"{item.get('filename', '?')}: ERROR {item['error']}")
                continue
            ply_bytes = base64.b64decode(item["ply_data"])
            summary = (
                f"{item['filename']} -> {item['ply_filename']} "
                f"({item['width']}x{item['height']}, f={item['focal_length']:.2f})"
            )
            entries.append((item["ply_filename"], ply_bytes))
            metas.append(summary)
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            metas.append(f"?: ERROR malformed result ({e!r})")

    if not entries:
        # Nothing to download; report the per-item errors if there were any.
        return None, "\n".join(metas) or "No result."

    # Write the archive to a real file and return its path: a file-output
    # component serves a path, not an in-memory BytesIO buffer.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpf:
        with zipfile.ZipFile(tmpf, "w") as zf:
            for ply_name, payload in entries:
                zf.writestr(ply_name, payload)
    return tmpf.name, "\n".join(metas)
# UI definition: one page with two tabs, each wiring an upload widget to the
# matching predict_* function and showing a downloadable file plus a text box.
with gr.Blocks(title="SHARP View Synthesis") as demo:
    gr.Markdown(
        "# SHARP View Synthesis\nUpload image(s) to generate 3D Gaussian PLY files via the backend API."
    )

    # Tab 1: a single image in, a single PLY file out.
    with gr.Tab("Single Image"):
        in_img = gr.Image(type="filepath", label="Input Image")
        out_file = gr.File(label="Generated PLY")
        out_info = gr.Textbox(label="Info")
        btn = gr.Button("Predict")
        btn.click(
            fn=predict_single,
            inputs=in_img,
            outputs=[out_file, out_info],
        )

    # Tab 2: several images in, one ZIP archive of PLY files out.
    with gr.Tab("Batch"):
        in_imgs = gr.File(
            label="Input Images",
            file_count="multiple",
            file_types=["image"],
        )
        out_zip = gr.File(label="PLY ZIP")
        out_info2 = gr.Textbox(label="Info")
        btn2 = gr.Button("Predict Batch")
        btn2.click(
            fn=predict_batch,
            inputs=in_imgs,
            outputs=[out_zip, out_info2],
        )


if __name__ == "__main__":
    # On Hugging Face Spaces, API_BASE_URL must point to your GPU cloud FastAPI server.
    # Listen on all interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")
|